diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index e27c7332d2..7323bf0f1e 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -14,20 +14,8 @@ # limitations under the License. from ._base import BaseSlavedStore -from synapse.storage import DataStore -from synapse.storage.registration import RegistrationStore +from synapse.storage.registration import RegistrationWorkerStore -class SlavedRegistrationStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedRegistrationStore, self).__init__(db_conn, hs) - - # TODO: use the cached version and invalidate deleted tokens - get_user_by_access_token = RegistrationStore.__dict__[ - "get_user_by_access_token" - ] - - _query_for_auth = DataStore._query_for_auth.__func__ - get_user_by_id = RegistrationStore.__dict__[ - "get_user_by_id" - ] +class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore): + pass diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 95f75d6df1..d809b2ba46 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -19,10 +19,70 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes from synapse.storage import background_updates +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(background_updates.BackgroundUpdateStore): +class RegistrationWorkerStore(SQLBaseStore): + @cached() + def get_user_by_id(self, user_id): + return self._simple_select_one( + table="users", + keyvalues={ + "name": user_id, + }, + retcols=["name", "password_hash", "is_guest"], + allow_none=True, + desc="get_user_by_id", + ) + + @cached() + def get_user_by_access_token(self, token): + """Get a user from the given access token. + + Args: + token (str): The access token of a user. + Returns: + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. + """ + return self.runInteraction( + "get_user_by_access_token", + self._query_for_auth, + token + ) + + @defer.inlineCallbacks + def is_server_admin(self, user): + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user.to_string()}, + retcol="admin", + allow_none=True, + desc="is_server_admin", + ) + + defer.returnValue(res if res else False) + + def _query_for_auth(self, txn, token): + sql = ( + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" + " FROM users" + " INNER JOIN access_tokens on users.name = access_tokens.user_id" + " WHERE token = ?" + ) + + txn.execute(sql, (token,)) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0] + + return None + + +class RegistrationStore(RegistrationWorkerStore, + background_updates.BackgroundUpdateStore): def __init__(self, db_conn, hs): super(RegistrationStore, self).__init__(db_conn, hs) @@ -187,18 +247,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) txn.call_after(self.is_guest.invalidate, (user_id,)) - @cached() - def get_user_by_id(self, user_id): - return self._simple_select_one( - table="users", - keyvalues={ - "name": user_id, - }, - retcols=["name", "password_hash", "is_guest"], - allow_none=True, - desc="get_user_by_id", - ) - def get_users_by_id_case_insensitive(self, user_id): """Gets users that match user_id case insensitively. Returns a mapping of user_id -> password_hash. @@ -304,34 +352,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): return self.runInteraction("delete_access_token", f) - @cached() - def get_user_by_access_token(self, token): - """Get a user from the given access token. - - Args: - token (str): The access token of a user. - Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`. - """ - return self.runInteraction( - "get_user_by_access_token", - self._query_for_auth, - token - ) - - @defer.inlineCallbacks - def is_server_admin(self, user): - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user.to_string()}, - retcol="admin", - allow_none=True, - desc="is_server_admin", - ) - - defer.returnValue(res if res else False) - @cachedInlineCallbacks() def is_guest(self, user_id): res = yield self._simple_select_one_onecol( @@ -344,22 +364,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): defer.returnValue(res if res else False) - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," - " access_tokens.device_id" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) - - txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) - if rows: - return rows[0] - - return None - @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): yield self._simple_upsert("user_threepids", {