mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 09:03:51 +01:00
Split registration store
This commit is contained in:
parent
1a6c7cdf54
commit
fafa3e7114
2 changed files with 64 additions and 72 deletions
|
@ -14,20 +14,8 @@
|
|||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.registration import RegistrationStore
|
||||
from synapse.storage.registration import RegistrationWorkerStore
|
||||
|
||||
|
||||
class SlavedRegistrationStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
|
||||
|
||||
# TODO: use the cached version and invalidate deleted tokens
|
||||
get_user_by_access_token = RegistrationStore.__dict__[
|
||||
"get_user_by_access_token"
|
||||
]
|
||||
|
||||
_query_for_auth = DataStore._query_for_auth.__func__
|
||||
get_user_by_id = RegistrationStore.__dict__[
|
||||
"get_user_by_id"
|
||||
]
|
||||
class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
|
||||
pass
|
||||
|
|
|
@ -19,10 +19,70 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import StoreError, Codes
|
||||
from synapse.storage import background_updates
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
|
||||
|
||||
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||
class RegistrationWorkerStore(SQLBaseStore):
|
||||
@cached()
|
||||
def get_user_by_id(self, user_id):
|
||||
return self._simple_select_one(
|
||||
table="users",
|
||||
keyvalues={
|
||||
"name": user_id,
|
||||
},
|
||||
retcols=["name", "password_hash", "is_guest"],
|
||||
allow_none=True,
|
||||
desc="get_user_by_id",
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_user_by_access_token(self, token):
|
||||
"""Get a user from the given access token.
|
||||
|
||||
Args:
|
||||
token (str): The access token of a user.
|
||||
Returns:
|
||||
defer.Deferred: None, if the token did not match, otherwise dict
|
||||
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_user_by_access_token",
|
||||
self._query_for_auth,
|
||||
token
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_server_admin(self, user):
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user.to_string()},
|
||||
retcol="admin",
|
||||
allow_none=True,
|
||||
desc="is_server_admin",
|
||||
)
|
||||
|
||||
defer.returnValue(res if res else False)
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
sql = (
|
||||
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
|
||||
" access_tokens.device_id"
|
||||
" FROM users"
|
||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
||||
" WHERE token = ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (token,))
|
||||
rows = self.cursor_to_dict(txn)
|
||||
if rows:
|
||||
return rows[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class RegistrationStore(RegistrationWorkerStore,
|
||||
background_updates.BackgroundUpdateStore):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(RegistrationStore, self).__init__(db_conn, hs)
|
||||
|
@ -187,18 +247,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
)
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
@cached()
|
||||
def get_user_by_id(self, user_id):
|
||||
return self._simple_select_one(
|
||||
table="users",
|
||||
keyvalues={
|
||||
"name": user_id,
|
||||
},
|
||||
retcols=["name", "password_hash", "is_guest"],
|
||||
allow_none=True,
|
||||
desc="get_user_by_id",
|
||||
)
|
||||
|
||||
def get_users_by_id_case_insensitive(self, user_id):
|
||||
"""Gets users that match user_id case insensitively.
|
||||
Returns a mapping of user_id -> password_hash.
|
||||
|
@ -304,34 +352,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
|
||||
return self.runInteraction("delete_access_token", f)
|
||||
|
||||
@cached()
|
||||
def get_user_by_access_token(self, token):
|
||||
"""Get a user from the given access token.
|
||||
|
||||
Args:
|
||||
token (str): The access token of a user.
|
||||
Returns:
|
||||
defer.Deferred: None, if the token did not match, otherwise dict
|
||||
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_user_by_access_token",
|
||||
self._query_for_auth,
|
||||
token
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_server_admin(self, user):
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user.to_string()},
|
||||
retcol="admin",
|
||||
allow_none=True,
|
||||
desc="is_server_admin",
|
||||
)
|
||||
|
||||
defer.returnValue(res if res else False)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def is_guest(self, user_id):
|
||||
res = yield self._simple_select_one_onecol(
|
||||
|
@ -344,22 +364,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
|
||||
defer.returnValue(res if res else False)
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
sql = (
|
||||
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
|
||||
" access_tokens.device_id"
|
||||
" FROM users"
|
||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
||||
" WHERE token = ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (token,))
|
||||
rows = self.cursor_to_dict(txn)
|
||||
if rows:
|
||||
return rows[0]
|
||||
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||
yield self._simple_upsert("user_threepids", {
|
||||
|
|
Loading…
Reference in a new issue