From 455579ca90dd5479dae785b5a1b9bdd201654ea6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Mar 2015 10:55:55 +0000 Subject: [PATCH] Make database selection configurable --- synapse/app/homeserver.py | 44 ++++++++++++++++++++++++++++---------- synapse/config/database.py | 9 ++++++++ 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 15c454af7..a2fca2e02 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -61,6 +61,7 @@ import resource import subprocess import sqlite3 import syweb +import yaml logger = logging.getLogger(__name__) @@ -108,14 +109,14 @@ class SynapseHomeServer(HomeServer): return None def build_db_pool(self): - return adbapi.ConnectionPool( - "sqlite3", self.get_db_name(), - check_same_thread=False, - cp_min=1, - cp_max=1, - cp_openfun=prepare_database, # Prepare the database for each conn - # so that :memory: sqlite works - ) + name = self.db_config.pop("name", None) + if name == "MySQLdb": + return adbapi.ConnectionPool( + name, + **self.db_config + ) + + raise RuntimeError("Unsupported database type") def create_resource_tree(self, redirect_root_to_web_client): """Create the resource tree for this Home Server. @@ -357,11 +358,29 @@ def setup(config_options): tls_context_factory = context_factory.ServerContextFactory(config) + if config.database_config: + with open(config.database_config, 'r') as f: + db_config = yaml.safe_load(f) + + name = db_config.get("name", None) + if name == "MySQLdb": + db_config.update({ + "sql_mode": "TRADITIONAL", + "charset": "utf8", + "use_unicode": True, + }) + else: + db_config = { + "name": "sqlite3", + "database": config.database_path, + } + hs = SynapseHomeServer( config.server_name, domain_with_port=domain_with_port, upload_dir=os.path.abspath("uploads"), db_name=config.database_path, + db_config=db_config, tls_context_factory=tls_context_factory, config=config, content_addr=config.content_addr, @@ -377,9 +396,12 @@ def setup(config_options): logger.info("Preparing database: %s...", db_name) try: - with sqlite3.connect(db_name) as db_conn: - prepare_sqlite3_database(db_conn) - prepare_database(db_conn) + # with sqlite3.connect(db_name) as db_conn: + # prepare_sqlite3_database(db_conn) + # prepare_database(db_conn) + import MySQLdb + db_conn = MySQLdb.connect(**db_config) + prepare_database(db_conn) except UpgradeDatabaseException: sys.stderr.write( "\nFailed to upgrade database.\n" diff --git a/synapse/config/database.py b/synapse/config/database.py index 87efe5464..8dc9873f8 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -26,6 +26,11 @@ class DatabaseConfig(Config): self.database_path = self.abspath(args.database_path) self.event_cache_size = self.parse_size(args.event_cache_size) + if args.database_config: + self.database_config = self.abspath(args.database_config) + else: + self.database_config = None + @classmethod def add_arguments(cls, parser): super(DatabaseConfig, cls).add_arguments(parser) @@ -38,6 +43,10 @@ class DatabaseConfig(Config): "--event-cache-size", default="100K", help="Number of events to cache in memory." ) + db_group.add_argument( + "--database-config", default=None, + help="Location of the database configuration file." + ) @classmethod def generate_config(cls, args, config_dir_path):