0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-11 20:42:23 +01:00

Make auth & transactions more testable (#3499)

This commit is contained in:
Amber Brown 2018-07-14 07:34:49 +10:00 committed by GitHub
parent 2aba1f549c
commit 33b60c01b5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 96 additions and 97 deletions

0
changelog.d/3499.misc Normal file
View file

View file

@ -193,7 +193,7 @@ class Auth(object):
synapse.types.create_requester(user_id, app_service=app_service) synapse.types.create_requester(user_id, app_service=app_service)
) )
access_token = get_access_token_from_request( access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
@ -239,7 +239,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_appservice_user_id(self, request): def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
get_access_token_from_request( self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
) )
@ -513,7 +513,7 @@ class Auth(object):
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
try: try:
token = get_access_token_from_request( token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
service = self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)
@ -673,8 +673,8 @@ class Auth(object):
" edit its room list entry" " edit its room list entry"
) )
@staticmethod
def has_access_token(request): def has_access_token(request):
"""Checks if the request has an access_token. """Checks if the request has an access_token.
Returns: Returns:
@ -684,8 +684,8 @@ def has_access_token(request):
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers) return bool(query_params) or bool(auth_headers)
@staticmethod
def get_access_token_from_request(request, token_not_found_http_status=401): def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request. """Extracts the access_token from the request.
Args: Args:

View file

@ -17,14 +17,28 @@
to ensure idempotency when performing PUTs using the REST API.""" to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
def get_transaction_key(request):
class HttpTransactionCache(object):
def __init__(self, hs):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def _get_transaction_key(self, request):
"""A helper function which returns a transaction key that can be used """A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests. with TransactionCache for idempotent requests.
@ -38,24 +52,9 @@ def get_transaction_key(request):
Returns: Returns:
str: A transaction key str: A transaction key
""" """
token = get_access_token_from_request(request) token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token return request.path + "/" + token
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
def __init__(self, clock):
self.clock = clock
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def fetch_or_execute_request(self, request, fn, *args, **kwargs): def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts """A helper function for fetch_or_execute which extracts
a transaction key from the given request. a transaction key from the given request.
@ -64,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute fetch_or_execute
""" """
return self.fetch_or_execute( return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs self._get_transaction_key(request), fn, *args, **kwargs
) )
def fetch_or_execute(self, txn_key, fn, *args, **kwargs): def fetch_or_execute(self, txn_key, fn, *args, **kwargs):

View file

@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs self.hs = hs
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock()) self.txns = HttpTransactionCache(hs)

View file

@ -17,7 +17,6 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None: if requester.device_id is None:
# the acccess token wasn't associated with a device. # the acccess token wasn't associated with a device.
# Just delete the access token # Just delete the access token
access_token = get_access_token_from_request(request) access_token = self._auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token) yield self._auth_handler.delete_access_token(access_token)
else: else:
yield self._device_handler.delete_device( yield self._device_handler.delete_device(

View file

@ -23,7 +23,6 @@ from six import string_types
from twisted.internet import defer from twisted.internet import defer
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
@ -67,6 +66,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage # TODO: persistent storage
self.sessions = {} self.sessions = {}
self.enable_registration = hs.config.enable_registration self.enable_registration = hs.config.enable_registration
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@ -310,7 +310,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_app_service(self, request, register_json, session): def _do_app_service(self, request, register_json, session):
as_token = get_access_token_from_request(request) as_token = self.auth.get_access_token_from_request(request)
if "user" not in register_json: if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.") raise SynapseError(400, "Expected 'user' key.")
@ -400,7 +400,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
user_json = parse_json_object_from_request(request) user_json = parse_json_object_from_request(request)
access_token = get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
access_token access_token
) )

View file

@ -20,7 +20,6 @@ from six.moves import http_client
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet):
# #
# In the second case, we require a password to confirm their identity. # In the second case, we require a password to confirm their identity.
if has_access_token(request): if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth( params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request), requester, body, self.hs.get_ip_from_request(request),

View file

@ -24,7 +24,6 @@ from twisted.internet import defer
import synapse import synapse
import synapse.types import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username'] desired_username = body['username']
appservice = None appservice = None
if has_access_token(request): if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request) appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which # fork off as soon as possible for ASes and shared secret auth which
@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet):
# because the IRC bridges rely on being able to register stupid # because the IRC bridges rely on being able to register stupid
# IDs. # IDs.
access_token = get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types): if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration( result = yield self._do_appservice_registration(

View file

@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__() super(SendToDeviceRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock()) self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler() self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id): def on_PUT(self, request, message_type, txn_id):

View file

@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
self.cache = HttpTransactionCache(self.clock) self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (200, "GOOD JOB!") self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo" self.mock_key = "foo"