mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-16 19:23:51 +01:00
Move are_all_users_on_domain checks to main data store.
This commit is contained in:
parent
9a4fb457cf
commit
d64bb32a73
3 changed files with 24 additions and 24 deletions
|
@ -68,7 +68,7 @@ from synapse.rest.key.v2 import KeyApiV2Resource
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
from synapse.rest.well_known import WellKnownResource
|
from synapse.rest.well_known import WellKnownResource
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage import DataStore, are_all_users_on_domain
|
from synapse.storage import DataStore
|
||||||
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
|
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
|
||||||
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
|
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
|
||||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
|
@ -295,16 +295,6 @@ class SynapseHomeServer(HomeServer):
|
||||||
logger.warning("Unrecognized listener type: %s", listener["type"])
|
logger.warning("Unrecognized listener type: %s", listener["type"])
|
||||||
|
|
||||||
def run_startup_checks(self, db_conn, database_engine):
|
def run_startup_checks(self, db_conn, database_engine):
|
||||||
all_users_native = are_all_users_on_domain(
|
|
||||||
db_conn.cursor(), database_engine, self.hostname
|
|
||||||
)
|
|
||||||
if not all_users_native:
|
|
||||||
quit_with_error(
|
|
||||||
"Found users in database not native to %s!\n"
|
|
||||||
"You cannot changed a synapse server_name after it's been configured"
|
|
||||||
% (self.hostname,)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
database_engine.check_database(db_conn.cursor())
|
database_engine.check_database(db_conn.cursor())
|
||||||
except IncorrectDatabaseSetup as e:
|
except IncorrectDatabaseSetup as e:
|
||||||
|
|
|
@ -49,15 +49,3 @@ class Storage(object):
|
||||||
self.persistence = EventsPersistenceStorage(hs, stores)
|
self.persistence = EventsPersistenceStorage(hs, stores)
|
||||||
self.purge_events = PurgeEventsStorage(hs, stores)
|
self.purge_events = PurgeEventsStorage(hs, stores)
|
||||||
self.state = StateGroupStorage(hs, stores)
|
self.state = StateGroupStorage(hs, stores)
|
||||||
|
|
||||||
|
|
||||||
def are_all_users_on_domain(txn, database_engine, domain):
|
|
||||||
sql = database_engine.convert_param_style(
|
|
||||||
"SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
|
|
||||||
)
|
|
||||||
pat = "%:" + domain
|
|
||||||
txn.execute(sql, (pat,))
|
|
||||||
num_not_matching = txn.fetchall()[0][0]
|
|
||||||
if num_not_matching == 0:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
|
@ -115,7 +115,17 @@ class DataStore(
|
||||||
def __init__(self, database: Database, db_conn, hs):
|
def __init__(self, database: Database, db_conn, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.database_engine = hs.database_engine
|
self.database_engine = database.engine
|
||||||
|
|
||||||
|
all_users_native = are_all_users_on_domain(
|
||||||
|
db_conn.cursor(), database.engine, hs.hostname
|
||||||
|
)
|
||||||
|
if not all_users_native:
|
||||||
|
raise Exception(
|
||||||
|
"Found users in database not native to %s!\n"
|
||||||
|
"You cannot changed a synapse server_name after it's been configured"
|
||||||
|
% (self.hostname,)
|
||||||
|
)
|
||||||
|
|
||||||
self._stream_id_gen = StreamIdGenerator(
|
self._stream_id_gen = StreamIdGenerator(
|
||||||
db_conn,
|
db_conn,
|
||||||
|
@ -555,3 +565,15 @@ class DataStore(
|
||||||
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
|
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
|
||||||
desc="search_users",
|
desc="search_users",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def are_all_users_on_domain(txn, database_engine, domain):
|
||||||
|
sql = database_engine.convert_param_style(
|
||||||
|
"SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
|
||||||
|
)
|
||||||
|
pat = "%:" + domain
|
||||||
|
txn.execute(sql, (pat,))
|
||||||
|
num_not_matching = txn.fetchall()[0][0]
|
||||||
|
if num_not_matching == 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
Loading…
Reference in a new issue