mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-21 01:02:01 +01:00
Convert the registration handler to async/await. (#7649)
This commit is contained in:
parent
375ca0cceb
commit
3c45a78090
3 changed files with 48 additions and 68 deletions
1
changelog.d/7649.misc
Normal file
1
changelog.d/7649.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert registration handler to async/await.
|
|
@ -16,8 +16,6 @@
|
||||||
"""Contains functions for registering clients."""
|
"""Contains functions for registering clients."""
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse import types
|
from synapse import types
|
||||||
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
|
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
|
||||||
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
|
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
|
||||||
|
@ -75,8 +73,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
self.session_lifetime = hs.config.session_lifetime
|
self.session_lifetime = hs.config.session_lifetime
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_username(
|
||||||
def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
|
self, localpart, guest_access_token=None, assigned_user_id=None
|
||||||
|
):
|
||||||
if types.contains_invalid_mxid_characters(localpart):
|
if types.contains_invalid_mxid_characters(localpart):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
|
@ -113,13 +112,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
Codes.INVALID_USERNAME,
|
Codes.INVALID_USERNAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
users = yield self.store.get_users_by_id_case_insensitive(user_id)
|
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
if users:
|
if users:
|
||||||
if not guest_access_token:
|
if not guest_access_token:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||||
)
|
)
|
||||||
user_data = yield self.auth.get_user_by_access_token(guest_access_token)
|
user_data = await self.auth.get_user_by_access_token(guest_access_token)
|
||||||
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
|
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
|
@ -137,8 +136,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def register_user(
|
||||||
def register_user(
|
|
||||||
self,
|
self,
|
||||||
localpart=None,
|
localpart=None,
|
||||||
password_hash=None,
|
password_hash=None,
|
||||||
|
@ -169,18 +167,18 @@ class RegistrationHandler(BaseHandler):
|
||||||
by_admin (bool): True if this registration is being made via the
|
by_admin (bool): True if this registration is being made via the
|
||||||
admin api, otherwise False.
|
admin api, otherwise False.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[str]: user_id
|
str: user_id
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if there was a problem registering.
|
SynapseError if there was a problem registering.
|
||||||
"""
|
"""
|
||||||
yield self.check_registration_ratelimit(address)
|
self.check_registration_ratelimit(address)
|
||||||
|
|
||||||
# do not check_auth_blocking if the call is coming through the Admin API
|
# do not check_auth_blocking if the call is coming through the Admin API
|
||||||
if not by_admin:
|
if not by_admin:
|
||||||
yield self.auth.check_auth_blocking(threepid=threepid)
|
await self.auth.check_auth_blocking(threepid=threepid)
|
||||||
|
|
||||||
if localpart is not None:
|
if localpart is not None:
|
||||||
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
await self.check_username(localpart, guest_access_token=guest_access_token)
|
||||||
|
|
||||||
was_guest = guest_access_token is not None
|
was_guest = guest_access_token is not None
|
||||||
|
|
||||||
|
@ -194,7 +192,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
elif default_display_name is None:
|
elif default_display_name is None:
|
||||||
default_display_name = localpart
|
default_display_name = localpart
|
||||||
|
|
||||||
yield self.register_with_store(
|
await self.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
was_guest=was_guest,
|
was_guest=was_guest,
|
||||||
|
@ -206,12 +204,10 @@ class RegistrationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.hs.config.user_directory_search_all_users:
|
if self.hs.config.user_directory_search_all_users:
|
||||||
profile = yield self.store.get_profileinfo(localpart)
|
profile = await self.store.get_profileinfo(localpart)
|
||||||
yield defer.ensureDeferred(
|
await self.user_directory_handler.handle_local_profile_change(
|
||||||
self.user_directory_handler.handle_local_profile_change(
|
|
||||||
user_id, profile
|
user_id, profile
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# autogen a sequential user ID
|
# autogen a sequential user ID
|
||||||
|
@ -222,14 +218,14 @@ class RegistrationHandler(BaseHandler):
|
||||||
if fail_count > 10:
|
if fail_count > 10:
|
||||||
raise SynapseError(500, "Unable to find a suitable guest user ID")
|
raise SynapseError(500, "Unable to find a suitable guest user ID")
|
||||||
|
|
||||||
localpart = yield self._generate_user_id()
|
localpart = await self._generate_user_id()
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
yield self.check_user_id_not_appservice_exclusive(user_id)
|
self.check_user_id_not_appservice_exclusive(user_id)
|
||||||
if default_display_name is None:
|
if default_display_name is None:
|
||||||
default_display_name = localpart
|
default_display_name = localpart
|
||||||
try:
|
try:
|
||||||
yield self.register_with_store(
|
await self.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
make_guest=make_guest,
|
make_guest=make_guest,
|
||||||
|
@ -252,7 +248,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield defer.ensureDeferred(self._auto_join_rooms(user_id))
|
await self._auto_join_rooms(user_id)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Skipping auto-join for %s because consent is required at registration",
|
"Skipping auto-join for %s because consent is required at registration",
|
||||||
|
@ -270,7 +266,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
}
|
}
|
||||||
|
|
||||||
# Bind email to new account
|
# Bind email to new account
|
||||||
yield self._register_email_threepid(user_id, threepid_dict, None)
|
await self._register_email_threepid(user_id, threepid_dict, None)
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
@ -335,8 +331,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
await self._auto_join_rooms(user_id)
|
await self._auto_join_rooms(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def appservice_register(self, user_localpart, as_token):
|
||||||
def appservice_register(self, user_localpart, as_token):
|
|
||||||
user = UserID(user_localpart, self.hs.hostname)
|
user = UserID(user_localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
service = self.store.get_app_service_by_token(as_token)
|
service = self.store.get_app_service_by_token(as_token)
|
||||||
|
@ -351,11 +346,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
service_id = service.id if service.is_exclusive_user(user_id) else None
|
service_id = service.id if service.is_exclusive_user(user_id) else None
|
||||||
|
|
||||||
yield self.check_user_id_not_appservice_exclusive(
|
self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service)
|
||||||
user_id, allowed_appservice=service
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.register_with_store(
|
await self.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password_hash="",
|
password_hash="",
|
||||||
appservice_id=service_id,
|
appservice_id=service_id,
|
||||||
|
@ -387,13 +380,12 @@ class RegistrationHandler(BaseHandler):
|
||||||
errcode=Codes.EXCLUSIVE,
|
errcode=Codes.EXCLUSIVE,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _generate_user_id(self):
|
||||||
def _generate_user_id(self):
|
|
||||||
if self._next_generated_user_id is None:
|
if self._next_generated_user_id is None:
|
||||||
with (yield self._generate_user_id_linearizer.queue(())):
|
with await self._generate_user_id_linearizer.queue(()):
|
||||||
if self._next_generated_user_id is None:
|
if self._next_generated_user_id is None:
|
||||||
self._next_generated_user_id = (
|
self._next_generated_user_id = (
|
||||||
yield self.store.find_next_generated_user_id_localpart()
|
await self.store.find_next_generated_user_id_localpart()
|
||||||
)
|
)
|
||||||
|
|
||||||
id = self._next_generated_user_id
|
id = self._next_generated_user_id
|
||||||
|
@ -496,8 +488,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_type=user_type,
|
user_type=user_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def register_device(
|
||||||
def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
|
self, user_id, device_id, initial_display_name, is_guest=False
|
||||||
|
):
|
||||||
"""Register a device for a user and generate an access token.
|
"""Register a device for a user and generate an access token.
|
||||||
|
|
||||||
The access token will be limited by the homeserver's session_lifetime config.
|
The access token will be limited by the homeserver's session_lifetime config.
|
||||||
|
@ -511,11 +504,11 @@ class RegistrationHandler(BaseHandler):
|
||||||
is_guest (bool): Whether this is a guest account
|
is_guest (bool): Whether this is a guest account
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
|
tuple[str, str]: Tuple of device ID and access token
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.hs.config.worker_app:
|
if self.hs.config.worker_app:
|
||||||
r = yield self._register_device_client(
|
r = await self._register_device_client(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
initial_display_name=initial_display_name,
|
initial_display_name=initial_display_name,
|
||||||
|
@ -531,7 +524,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
valid_until_ms = self.clock.time_msec() + self.session_lifetime
|
valid_until_ms = self.clock.time_msec() + self.session_lifetime
|
||||||
|
|
||||||
device_id = yield self.device_handler.check_device_registered(
|
device_id = await self.device_handler.check_device_registered(
|
||||||
user_id, device_id, initial_display_name
|
user_id, device_id, initial_display_name
|
||||||
)
|
)
|
||||||
if is_guest:
|
if is_guest:
|
||||||
|
@ -540,11 +533,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id, ["guest = true"]
|
user_id, ["guest = true"]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
access_token = yield defer.ensureDeferred(
|
access_token = await self._auth_handler.get_access_token_for_user_id(
|
||||||
self._auth_handler.get_access_token_for_user_id(
|
|
||||||
user_id, device_id=device_id, valid_until_ms=valid_until_ms
|
user_id, device_id=device_id, valid_until_ms=valid_until_ms
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return (device_id, access_token)
|
return (device_id, access_token)
|
||||||
|
|
||||||
|
@ -594,8 +585,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
await self.store.user_set_consent_version(user_id, consent_version)
|
await self.store.user_set_consent_version(user_id, consent_version)
|
||||||
await self.post_consent_actions(user_id)
|
await self.post_consent_actions(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _register_email_threepid(self, user_id, threepid, token):
|
||||||
def _register_email_threepid(self, user_id, threepid, token):
|
|
||||||
"""Add an email address as a 3pid identifier
|
"""Add an email address as a 3pid identifier
|
||||||
|
|
||||||
Also adds an email pusher for the email address, if configured in the
|
Also adds an email pusher for the email address, if configured in the
|
||||||
|
@ -608,8 +598,6 @@ class RegistrationHandler(BaseHandler):
|
||||||
threepid (object): m.login.email.identity auth response
|
threepid (object): m.login.email.identity auth response
|
||||||
token (str|None): access_token for the user, or None if not logged
|
token (str|None): access_token for the user, or None if not logged
|
||||||
in.
|
in.
|
||||||
Returns:
|
|
||||||
defer.Deferred:
|
|
||||||
"""
|
"""
|
||||||
reqd = ("medium", "address", "validated_at")
|
reqd = ("medium", "address", "validated_at")
|
||||||
if any(x not in threepid for x in reqd):
|
if any(x not in threepid for x in reqd):
|
||||||
|
@ -617,13 +605,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
logger.info("Can't add incomplete 3pid")
|
logger.info("Can't add incomplete 3pid")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield defer.ensureDeferred(
|
await self._auth_handler.add_threepid(
|
||||||
self._auth_handler.add_threepid(
|
user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
|
||||||
user_id,
|
|
||||||
threepid["medium"],
|
|
||||||
threepid["address"],
|
|
||||||
threepid["validated_at"],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# And we add an email pusher for them by default, but only
|
# And we add an email pusher for them by default, but only
|
||||||
|
@ -639,10 +622,10 @@ class RegistrationHandler(BaseHandler):
|
||||||
# It would really make more sense for this to be passed
|
# It would really make more sense for this to be passed
|
||||||
# up when the access token is saved, but that's quite an
|
# up when the access token is saved, but that's quite an
|
||||||
# invasive change I'd rather do separately.
|
# invasive change I'd rather do separately.
|
||||||
user_tuple = yield self.store.get_user_by_access_token(token)
|
user_tuple = await self.store.get_user_by_access_token(token)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple["token_id"]
|
||||||
|
|
||||||
yield self.pusher_pool.add_pusher(
|
await self.pusher_pool.add_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="email",
|
kind="email",
|
||||||
|
@ -654,8 +637,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
data={},
|
data={},
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _register_msisdn_threepid(self, user_id, threepid):
|
||||||
def _register_msisdn_threepid(self, user_id, threepid):
|
|
||||||
"""Add a phone number as a 3pid identifier
|
"""Add a phone number as a 3pid identifier
|
||||||
|
|
||||||
Must be called on master.
|
Must be called on master.
|
||||||
|
@ -663,8 +645,6 @@ class RegistrationHandler(BaseHandler):
|
||||||
Args:
|
Args:
|
||||||
user_id (str): id of user
|
user_id (str): id of user
|
||||||
threepid (object): m.login.msisdn auth response
|
threepid (object): m.login.msisdn auth response
|
||||||
Returns:
|
|
||||||
defer.Deferred:
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
|
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
|
||||||
|
@ -675,11 +655,6 @@ class RegistrationHandler(BaseHandler):
|
||||||
return None
|
return None
|
||||||
raise
|
raise
|
||||||
|
|
||||||
yield defer.ensureDeferred(
|
await self._auth_handler.add_threepid(
|
||||||
self._auth_handler.add_threepid(
|
user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
|
||||||
user_id,
|
|
||||||
threepid["medium"],
|
|
||||||
threepid["address"],
|
|
||||||
threepid["validated_at"],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -128,8 +128,12 @@ class ModuleApi(object):
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[str]: user_id
|
Deferred[str]: user_id
|
||||||
"""
|
"""
|
||||||
return self._hs.get_registration_handler().register_user(
|
return defer.ensureDeferred(
|
||||||
localpart=localpart, default_display_name=displayname, bind_emails=emails
|
self._hs.get_registration_handler().register_user(
|
||||||
|
localpart=localpart,
|
||||||
|
default_display_name=displayname,
|
||||||
|
bind_emails=emails,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def register_device(self, user_id, device_id=None, initial_display_name=None):
|
def register_device(self, user_id, device_id=None, initial_display_name=None):
|
||||||
|
|
Loading…
Add table
Reference in a new issue