mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-22 09:40:10 +01:00
Merge pull request #2727 from matrix-org/rav/refactor_ui_auth_return
Refactor UI auth implementation
This commit is contained in:
commit
aa6ecf0984
7 changed files with 103 additions and 48 deletions
|
@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InteractiveAuthIncompleteError(Exception):
|
||||||
|
"""An error raised when UI auth is not yet complete
|
||||||
|
|
||||||
|
(This indicates we should return a 401 with 'result' as the body)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
result (dict): the server response to the request, which should be
|
||||||
|
passed back to the client
|
||||||
|
"""
|
||||||
|
def __init__(self, result):
|
||||||
|
super(InteractiveAuthIncompleteError, self).__init__(
|
||||||
|
"Interactive auth not yet complete",
|
||||||
|
)
|
||||||
|
self.result = result
|
||||||
|
|
||||||
|
|
||||||
class UnrecognizedRequestError(SynapseError):
|
class UnrecognizedRequestError(SynapseError):
|
||||||
"""An error indicating we don't understand the request you're trying to make"""
|
"""An error indicating we don't understand the request you're trying to make"""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
|
@ -17,7 +17,10 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
from synapse.api.errors import (
|
||||||
|
AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
|
||||||
|
SynapseError,
|
||||||
|
)
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
@ -95,26 +98,36 @@ class AuthHandler(BaseHandler):
|
||||||
session with a map, which maps each auth-type (str) to the relevant
|
session with a map, which maps each auth-type (str) to the relevant
|
||||||
identity authenticated by that auth-type (mostly str, but for captcha, bool).
|
identity authenticated by that auth-type (mostly str, but for captcha, bool).
|
||||||
|
|
||||||
|
If no auth flows have been completed successfully, raises an
|
||||||
|
InteractiveAuthIncompleteError. To handle this, you can use
|
||||||
|
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
|
||||||
|
decorator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
flows (list): A list of login flows. Each flow is an ordered list of
|
flows (list): A list of login flows. Each flow is an ordered list of
|
||||||
strings representing auth-types. At least one full
|
strings representing auth-types. At least one full
|
||||||
flow must be completed in order for auth to be successful.
|
flow must be completed in order for auth to be successful.
|
||||||
|
|
||||||
clientdict: The dictionary from the client root level, not the
|
clientdict: The dictionary from the client root level, not the
|
||||||
'auth' key: this method prompts for auth if none is sent.
|
'auth' key: this method prompts for auth if none is sent.
|
||||||
|
|
||||||
clientip (str): The IP address of the client.
|
clientip (str): The IP address of the client.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (authed, dict, dict, session_id) where authed is true if
|
defer.Deferred[dict, dict, str]: a deferred tuple of
|
||||||
the client has successfully completed an auth flow. If it is true
|
(creds, params, session_id).
|
||||||
the first dict contains the authenticated credentials of each stage.
|
|
||||||
|
|
||||||
If authed is false, the first dictionary is the server response to
|
'creds' contains the authenticated credentials of each stage.
|
||||||
the login request and should be passed back to the client.
|
|
||||||
|
|
||||||
In either case, the second dict contains the parameters for this
|
'params' contains the parameters for this request (which may
|
||||||
request (which may have been given only in a previous call).
|
have been given only in a previous call).
|
||||||
|
|
||||||
session_id is the ID of this session, either passed in by the client
|
'session_id' is the ID of this session, either passed in by the
|
||||||
or assigned by the call to check_auth
|
client or assigned by this call
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InteractiveAuthIncompleteError if the client has not yet completed
|
||||||
|
all the stages in any of the permitted flows.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
authdict = None
|
authdict = None
|
||||||
|
@ -142,11 +155,8 @@ class AuthHandler(BaseHandler):
|
||||||
clientdict = session['clientdict']
|
clientdict = session['clientdict']
|
||||||
|
|
||||||
if not authdict:
|
if not authdict:
|
||||||
defer.returnValue(
|
raise InteractiveAuthIncompleteError(
|
||||||
(
|
self._auth_dict_for_flows(flows, session),
|
||||||
False, self._auth_dict_for_flows(flows, session),
|
|
||||||
clientdict, session['id']
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'creds' not in session:
|
if 'creds' not in session:
|
||||||
|
@ -190,12 +200,14 @@ class AuthHandler(BaseHandler):
|
||||||
"Auth completed with creds: %r. Client dict has keys: %r",
|
"Auth completed with creds: %r. Client dict has keys: %r",
|
||||||
creds, clientdict.keys()
|
creds, clientdict.keys()
|
||||||
)
|
)
|
||||||
defer.returnValue((True, creds, clientdict, session['id']))
|
defer.returnValue((creds, clientdict, session['id']))
|
||||||
|
|
||||||
ret = self._auth_dict_for_flows(flows, session)
|
ret = self._auth_dict_for_flows(flows, session)
|
||||||
ret['completed'] = creds.keys()
|
ret['completed'] = creds.keys()
|
||||||
ret.update(errordict)
|
ret.update(errordict)
|
||||||
defer.returnValue((False, ret, clientdict, session['id']))
|
raise InteractiveAuthIncompleteError(
|
||||||
|
ret,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_oob_auth(self, stagetype, authdict, clientip):
|
def add_oob_auth(self, stagetype, authdict, clientip):
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
|
|
||||||
"""This module contains base REST classes for constructing client v1 servlets.
|
"""This module contains base REST classes for constructing client v1 servlets.
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import logging
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||||
|
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
|
||||||
filter_json['room']['timeline']["limit"] = min(
|
filter_json['room']['timeline']["limit"] = min(
|
||||||
filter_json['room']['timeline']['limit'],
|
filter_json['room']['timeline']['limit'],
|
||||||
filter_timeline_limit)
|
filter_timeline_limit)
|
||||||
|
|
||||||
|
|
||||||
|
def interactive_auth_handler(orig):
|
||||||
|
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
|
||||||
|
|
||||||
|
Takes a on_POST method which returns a deferred (errcode, body) response
|
||||||
|
and adds exception handling to turn a InteractiveAuthIncompleteError into
|
||||||
|
a 401 response.
|
||||||
|
|
||||||
|
Normal usage is:
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
# ...
|
||||||
|
yield self.auth_handler.check_auth
|
||||||
|
"""
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
res = defer.maybeDeferred(orig, *args, **kwargs)
|
||||||
|
res.addErrback(_catch_incomplete_interactive_auth)
|
||||||
|
return res
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def _catch_incomplete_interactive_auth(f):
|
||||||
|
"""helper for interactive_auth_handler
|
||||||
|
|
||||||
|
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
|
||||||
|
|
||||||
|
Args:
|
||||||
|
f (failure.Failure):
|
||||||
|
"""
|
||||||
|
f.trap(InteractiveAuthIncompleteError)
|
||||||
|
return 401, f.value.result
|
||||||
|
|
|
@ -26,7 +26,7 @@ from synapse.http.servlet import (
|
||||||
)
|
)
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -100,21 +100,19 @@ class PasswordRestServlet(RestServlet):
|
||||||
self.datastore = self.hs.get_datastore()
|
self.datastore = self.hs.get_datastore()
|
||||||
self._set_password_handler = hs.get_set_password_handler()
|
self._set_password_handler = hs.get_set_password_handler()
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
result, params, _ = yield self.auth_handler.check_auth([
|
||||||
[LoginType.PASSWORD],
|
[LoginType.PASSWORD],
|
||||||
[LoginType.EMAIL_IDENTITY],
|
[LoginType.EMAIL_IDENTITY],
|
||||||
[LoginType.MSISDN],
|
[LoginType.MSISDN],
|
||||||
], body, self.hs.get_ip_from_request(request))
|
], body, self.hs.get_ip_from_request(request))
|
||||||
|
|
||||||
if not authed:
|
|
||||||
defer.returnValue((401, result))
|
|
||||||
|
|
||||||
user_id = None
|
user_id = None
|
||||||
requester = None
|
requester = None
|
||||||
|
|
||||||
|
@ -168,6 +166,7 @@ class DeactivateAccountRestServlet(RestServlet):
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self._deactivate_account_handler = hs.get_deactivate_account_handler()
|
self._deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
@ -186,13 +185,10 @@ class DeactivateAccountRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
result, params, _ = yield self.auth_handler.check_auth([
|
||||||
[LoginType.PASSWORD],
|
[LoginType.PASSWORD],
|
||||||
], body, self.hs.get_ip_from_request(request))
|
], body, self.hs.get_ip_from_request(request))
|
||||||
|
|
||||||
if not authed:
|
|
||||||
defer.returnValue((401, result))
|
|
||||||
|
|
||||||
if LoginType.PASSWORD in result:
|
if LoginType.PASSWORD in result:
|
||||||
user_id = result[LoginType.PASSWORD]
|
user_id = result[LoginType.PASSWORD]
|
||||||
# if using password, they should also be logged in
|
# if using password, they should also be logged in
|
||||||
|
|
|
@ -19,7 +19,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api import constants, errors
|
from synapse.api import constants, errors
|
||||||
from synapse.http import servlet
|
from synapse.http import servlet
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -60,6 +60,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
try:
|
try:
|
||||||
|
@ -77,13 +78,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
||||||
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
|
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
|
||||||
)
|
)
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
result, params, _ = yield self.auth_handler.check_auth([
|
||||||
[constants.LoginType.PASSWORD],
|
[constants.LoginType.PASSWORD],
|
||||||
], body, self.hs.get_ip_from_request(request))
|
], body, self.hs.get_ip_from_request(request))
|
||||||
|
|
||||||
if not authed:
|
|
||||||
defer.returnValue((401, result))
|
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
yield self.device_handler.delete_devices(
|
yield self.device_handler.delete_devices(
|
||||||
requester.user.to_string(),
|
requester.user.to_string(),
|
||||||
|
@ -115,6 +113,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
||||||
)
|
)
|
||||||
defer.returnValue((200, device))
|
defer.returnValue((200, device))
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_DELETE(self, request, device_id):
|
def on_DELETE(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
@ -130,13 +129,10 @@ class DeviceRestServlet(servlet.RestServlet):
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
result, params, _ = yield self.auth_handler.check_auth([
|
||||||
[constants.LoginType.PASSWORD],
|
[constants.LoginType.PASSWORD],
|
||||||
], body, self.hs.get_ip_from_request(request))
|
], body, self.hs.get_ip_from_request(request))
|
||||||
|
|
||||||
if not authed:
|
|
||||||
defer.returnValue((401, result))
|
|
||||||
|
|
||||||
# check that the UI auth matched the access token
|
# check that the UI auth matched the access token
|
||||||
user_id = result[constants.LoginType.PASSWORD]
|
user_id = result[constants.LoginType.PASSWORD]
|
||||||
if user_id != requester.user.to_string():
|
if user_id != requester.user.to_string():
|
||||||
|
|
|
@ -27,7 +27,7 @@ from synapse.http.servlet import (
|
||||||
)
|
)
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import hmac
|
import hmac
|
||||||
|
@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
|
||||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
||||||
])
|
])
|
||||||
|
|
||||||
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
|
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not authed:
|
|
||||||
defer.returnValue((401, auth_result))
|
|
||||||
return
|
|
||||||
|
|
||||||
if registered_user_id is not None:
|
if registered_user_id is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Already registered user ID %r for this session",
|
"Already registered user ID %r for this session",
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
|
from twisted.python import failure
|
||||||
|
|
||||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
side_effect=lambda x: self.appservice)
|
side_effect=lambda x: self.appservice)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.auth_result = (False, None, None, None)
|
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
|
||||||
self.auth_handler = Mock(
|
self.auth_handler = Mock(
|
||||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||||
get_session_data=Mock(return_value=None)
|
get_session_data=Mock(return_value=None)
|
||||||
|
@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.request.args = {
|
self.request.args = {
|
||||||
"access_token": "i_am_an_app_service"
|
"access_token": "i_am_an_app_service"
|
||||||
}
|
}
|
||||||
|
|
||||||
self.request_data = json.dumps({
|
self.request_data = json.dumps({
|
||||||
"username": "kermit"
|
"username": "kermit"
|
||||||
})
|
})
|
||||||
|
@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
})
|
})
|
||||||
self.registration_handler.check_username = Mock(return_value=True)
|
self.registration_handler.check_username = Mock(return_value=True)
|
||||||
self.auth_result = (True, None, {
|
self.auth_result = (None, {
|
||||||
"username": "kermit",
|
"username": "kermit",
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
}, None)
|
}, None)
|
||||||
|
@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
})
|
})
|
||||||
self.registration_handler.check_username = Mock(return_value=True)
|
self.registration_handler.check_username = Mock(return_value=True)
|
||||||
self.auth_result = (True, None, {
|
self.auth_result = (None, {
|
||||||
"username": "kermit",
|
"username": "kermit",
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
}, None)
|
}, None)
|
||||||
|
|
Loading…
Add table
Reference in a new issue