mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 05:33:46 +01:00
Move register_device into handler
This commit is contained in:
parent
8b9ae6d3a6
commit
af691e415c
5 changed files with 97 additions and 172 deletions
|
@ -27,6 +27,7 @@ from synapse.api.errors import (
|
|||
SynapseError,
|
||||
)
|
||||
from synapse.http.client import CaptchaServerHttpClient
|
||||
from synapse.replication.http.login import RegisterDeviceReplicationServlet
|
||||
from synapse.replication.http.register import ReplicationRegisterServlet
|
||||
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
@ -64,6 +65,11 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
if hs.config.worker_app:
|
||||
self._register_client = ReplicationRegisterServlet.make_client(hs)
|
||||
self._register_device_client = (
|
||||
RegisterDeviceReplicationServlet.make_client(hs)
|
||||
)
|
||||
else:
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_username(self, localpart, guest_access_token=None,
|
||||
|
@ -159,7 +165,7 @@ class RegistrationHandler(BaseHandler):
|
|||
yield self.auth.check_auth_blocking(threepid=threepid)
|
||||
password_hash = None
|
||||
if password:
|
||||
password_hash = yield self.auth_handler().hash(password)
|
||||
password_hash = yield self._auth_handler.hash(password)
|
||||
|
||||
if localpart:
|
||||
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
||||
|
@ -516,9 +522,6 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
def auth_handler(self):
|
||||
return self.hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
|
||||
"""Get a guest access token for a 3PID, creating a guest account if
|
||||
|
@ -628,3 +631,43 @@ class RegistrationHandler(BaseHandler):
|
|||
admin=admin,
|
||||
user_type=user_type,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register_device(self, user_id, device_id, initial_display_name,
|
||||
is_guest=False):
|
||||
"""Register a device for a user and generate an access token.
|
||||
|
||||
Args:
|
||||
user_id (str): full canonical @user:id
|
||||
device_id (str|None): The device ID to check, or None to generate
|
||||
a new one.
|
||||
initial_display_name (str|None): An optional display name for the
|
||||
device.
|
||||
is_guest (bool): Whether this is a guest account
|
||||
|
||||
Returns:
|
||||
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
|
||||
"""
|
||||
|
||||
if self.hs.config.worker_app:
|
||||
r = yield self._register_device_client(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
initial_display_name=initial_display_name,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
defer.returnValue((r["device_id"], r["access_token"]))
|
||||
else:
|
||||
device_id = yield self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
if is_guest:
|
||||
access_token = self.macaroon_gen.generate_access_token(
|
||||
user_id, ["guest = true"]
|
||||
)
|
||||
else:
|
||||
access_token = yield self._auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
)
|
||||
|
||||
defer.returnValue((device_id, access_token))
|
||||
|
|
|
@ -35,9 +35,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||
|
||||
def __init__(self, hs):
|
||||
super(RegisterDeviceReplicationServlet, self).__init__(hs)
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
|
||||
|
@ -62,19 +60,10 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||
initial_display_name = content["initial_display_name"]
|
||||
is_guest = content["is_guest"]
|
||||
|
||||
device_id = yield self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name,
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name, is_guest,
|
||||
)
|
||||
|
||||
if is_guest:
|
||||
access_token = self.macaroon_gen.generate_access_token(
|
||||
user_id, ["guest = true"]
|
||||
)
|
||||
else:
|
||||
access_token = yield self.auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"device_id": device_id,
|
||||
"access_token": access_token,
|
||||
|
|
|
@ -94,7 +94,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
self.jwt_algorithm = hs.config.jwt_algorithm
|
||||
self.cas_enabled = hs.config.cas_enabled
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.device_handler = self.hs.get_device_handler()
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
self.handlers = hs.get_handlers()
|
||||
self._well_known_builder = WellKnownBuilder(hs)
|
||||
|
||||
|
@ -220,11 +220,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
login_submission,
|
||||
)
|
||||
|
||||
device_id = yield self._register_device(
|
||||
canonical_user_id, login_submission,
|
||||
)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
canonical_user_id, device_id,
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
canonical_user_id, device_id, initial_display_name,
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -246,10 +245,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id,
|
||||
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name,
|
||||
)
|
||||
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
|
@ -286,11 +288,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
auth_handler = self.auth_handler
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if registered_user_id:
|
||||
device_id = yield self._register_device(
|
||||
registered_user_id, login_submission
|
||||
)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
registered_user_id, device_id,
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
registered_user_id, device_id, initial_display_name,
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -299,12 +300,16 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
"home_server": self.hs.hostname,
|
||||
}
|
||||
else:
|
||||
# TODO: we should probably check that the register isn't going
|
||||
# to fonx/change our user_id before registering the device
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
user_id, access_token = (
|
||||
yield self.handlers.registration_handler.register(localpart=user)
|
||||
)
|
||||
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
registered_user_id, device_id, initial_display_name,
|
||||
)
|
||||
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
|
@ -313,26 +318,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
|
||||
defer.returnValue(result)
|
||||
|
||||
def _register_device(self, user_id, login_submission):
|
||||
"""Register a device for a user.
|
||||
|
||||
This is called after the user's credentials have been validated, but
|
||||
before the access token has been issued.
|
||||
|
||||
Args:
|
||||
(str) user_id: full canonical @user:id
|
||||
(object) login_submission: dictionary supplied to /login call, from
|
||||
which we pull device_id and initial_device_name
|
||||
Returns:
|
||||
defer.Deferred: (str) device_id
|
||||
"""
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get(
|
||||
"initial_device_display_name")
|
||||
return self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
|
||||
class CasRedirectServlet(RestServlet):
|
||||
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
|
||||
|
|
|
@ -33,7 +33,6 @@ from synapse.http.servlet import (
|
|||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.replication.http.login import RegisterDeviceReplicationServlet
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
|
@ -193,13 +192,6 @@ class RegisterRestServlet(RestServlet):
|
|||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
|
||||
if self.hs.config.worker_app:
|
||||
self._register_device_client = (
|
||||
RegisterDeviceReplicationServlet.make_client(hs)
|
||||
)
|
||||
else:
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -642,7 +634,7 @@ class RegisterRestServlet(RestServlet):
|
|||
if not params.get("inhibit_login", False):
|
||||
device_id = params.get("device_id")
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
device_id, access_token = yield self._register_device(
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name, is_guest=False,
|
||||
)
|
||||
|
||||
|
@ -652,43 +644,6 @@ class RegisterRestServlet(RestServlet):
|
|||
})
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _register_device(self, user_id, device_id, initial_display_name,
|
||||
is_guest):
|
||||
"""Register a device for a user and generate an access token.
|
||||
|
||||
Args:
|
||||
user_id (str): full canonical @user:id
|
||||
device_id (str|None): The device ID to check, or None to generate
|
||||
a new one.
|
||||
initial_display_name (str|None): An optional display name for the
|
||||
device.
|
||||
is_guest (bool): Whether this is a guest account
|
||||
Returns:
|
||||
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
|
||||
"""
|
||||
if self.hs.config.worker_app:
|
||||
r = yield self._register_device_client(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
initial_display_name=initial_display_name,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
defer.returnValue((r["device_id"], r["access_token"]))
|
||||
else:
|
||||
device_id = yield self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
if is_guest:
|
||||
access_token = self.macaroon_gen.generate_access_token(
|
||||
user_id, ["guest = true"]
|
||||
)
|
||||
else:
|
||||
access_token = yield self.auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
)
|
||||
defer.returnValue((device_id, access_token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_guest_registration(self, params):
|
||||
if not self.hs.config.allow_guest_access:
|
||||
|
@ -702,7 +657,7 @@ class RegisterRestServlet(RestServlet):
|
|||
# we have nowhere to store it.
|
||||
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
device_id, access_token = yield self._register_device(
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name, is_guest=True,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
import json
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.python import failure
|
||||
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.rest.client.v2_alpha.register import register_servlets
|
||||
|
||||
from tests import unittest
|
||||
|
@ -18,61 +15,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.url = b"/_matrix/client/r0/register"
|
||||
|
||||
self.appservice = None
|
||||
self.auth = Mock(
|
||||
get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
|
||||
)
|
||||
|
||||
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
|
||||
self.auth_handler = Mock(
|
||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||
get_session_data=Mock(return_value=None),
|
||||
)
|
||||
self.registration_handler = Mock()
|
||||
self.identity_handler = Mock()
|
||||
self.login_handler = Mock()
|
||||
self.device_handler = Mock()
|
||||
|
||||
def check_device_registered(user_id, device_id, initial_display_name):
|
||||
# Just echo back the given device ID, or return a new "FAKE" device
|
||||
# ID
|
||||
if device_id:
|
||||
return device_id
|
||||
else:
|
||||
return "FAKE"
|
||||
|
||||
self.device_handler.check_device_registered = Mock(
|
||||
side_effect=check_device_registered,
|
||||
)
|
||||
|
||||
self.datastore = Mock(return_value=Mock())
|
||||
self.datastore.get_current_state_deltas = Mock(return_value=[])
|
||||
|
||||
# do the dance to hook it up to the hs global
|
||||
self.handlers = Mock(
|
||||
registration_handler=self.registration_handler,
|
||||
identity_handler=self.identity_handler,
|
||||
login_handler=self.login_handler,
|
||||
)
|
||||
self.hs = self.setup_test_homeserver()
|
||||
self.hs.get_auth = Mock(return_value=self.auth)
|
||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
||||
self.hs.get_datastore = Mock(return_value=self.datastore)
|
||||
self.hs.config.enable_registration = True
|
||||
self.hs.config.registrations_require_3pid = []
|
||||
self.hs.config.auto_join_rooms = []
|
||||
self.hs.config.enable_registration_captcha = False
|
||||
|
||||
return self.hs
|
||||
|
||||
def test_POST_appservice_registration_valid(self):
|
||||
user_id = "@kermit:muppet"
|
||||
token = "kermits_access_token"
|
||||
self.appservice = {"id": "1234"}
|
||||
self.registration_handler.appservice_register = Mock(return_value=user_id)
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||
request_data = json.dumps({"username": "kermit"})
|
||||
user_id = "@as_user_kermit:test"
|
||||
as_token = "i_am_an_app_service"
|
||||
|
||||
appservice = ApplicationService(
|
||||
as_token, self.hs.config.hostname,
|
||||
id="1234",
|
||||
namespaces={
|
||||
"users": [{"regex": r"@as_user.*", "exclusive": True}],
|
||||
},
|
||||
)
|
||||
|
||||
self.hs.get_datastore().services_cache.append(appservice)
|
||||
request_data = json.dumps({"username": "as_user_kermit"})
|
||||
|
||||
request, channel = self.make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
|
@ -82,7 +46,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||
|
@ -114,37 +77,30 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEquals(channel.json_body["error"], "Invalid username")
|
||||
|
||||
def test_POST_user_valid(self):
|
||||
user_id = "@kermit:muppet"
|
||||
token = "kermits_access_token"
|
||||
user_id = "@kermit:test"
|
||||
device_id = "frogfone"
|
||||
params = {"username": "kermit", "password": "monkey", "device_id": device_id}
|
||||
params = {
|
||||
"username": "kermit",
|
||||
"password": "monkey",
|
||||
"device_id": device_id,
|
||||
"auth": {"type": LoginType.DUMMY},
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (None, params, None)
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||
|
||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.render(request)
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||
self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id, initial_device_display_name=None
|
||||
)
|
||||
|
||||
def test_POST_disabled_registration(self):
|
||||
self.hs.config.enable_registration = False
|
||||
request_data = json.dumps({"username": "kermit", "password": "monkey"})
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
||||
|
||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.render(request)
|
||||
|
@ -153,16 +109,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
|
||||
|
||||
def test_POST_guest_registration(self):
|
||||
user_id = "a@b"
|
||||
self.hs.config.macaroon_secret_key = "test"
|
||||
self.hs.config.allow_guest_access = True
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
|
||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
self.render(request)
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": "guest_device",
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue