mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-11 20:42:23 +01:00
Make auth & transactions more testable (#3499)
This commit is contained in:
parent
2aba1f549c
commit
33b60c01b5
10 changed files with 96 additions and 97 deletions
0
changelog.d/3499.misc
Normal file
0
changelog.d/3499.misc
Normal file
|
@ -193,7 +193,7 @@ class Auth(object):
|
||||||
synapse.types.create_requester(user_id, app_service=app_service)
|
synapse.types.create_requester(user_id, app_service=app_service)
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = get_access_token_from_request(
|
access_token = self.get_access_token_from_request(
|
||||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -239,7 +239,7 @@ class Auth(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_appservice_user_id(self, request):
|
def _get_appservice_user_id(self, request):
|
||||||
app_service = self.store.get_app_service_by_token(
|
app_service = self.store.get_app_service_by_token(
|
||||||
get_access_token_from_request(
|
self.get_access_token_from_request(
|
||||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -513,7 +513,7 @@ class Auth(object):
|
||||||
|
|
||||||
def get_appservice_by_req(self, request):
|
def get_appservice_by_req(self, request):
|
||||||
try:
|
try:
|
||||||
token = get_access_token_from_request(
|
token = self.get_access_token_from_request(
|
||||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||||
)
|
)
|
||||||
service = self.store.get_app_service_by_token(token)
|
service = self.store.get_app_service_by_token(token)
|
||||||
|
@ -673,8 +673,8 @@ class Auth(object):
|
||||||
" edit its room list entry"
|
" edit its room list entry"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def has_access_token(request):
|
def has_access_token(request):
|
||||||
"""Checks if the request has an access_token.
|
"""Checks if the request has an access_token.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -684,8 +684,8 @@ def has_access_token(request):
|
||||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||||
return bool(query_params) or bool(auth_headers)
|
return bool(query_params) or bool(auth_headers)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def get_access_token_from_request(request, token_not_found_http_status=401):
|
def get_access_token_from_request(request, token_not_found_http_status=401):
|
||||||
"""Extracts the access_token from the request.
|
"""Extracts the access_token from the request.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -17,14 +17,28 @@
|
||||||
to ensure idempotency when performing PUTs using the REST API."""
|
to ensure idempotency when performing PUTs using the REST API."""
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.api.auth import get_access_token_from_request
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
|
||||||
|
|
||||||
def get_transaction_key(request):
|
|
||||||
|
class HttpTransactionCache(object):
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = self.hs.get_auth()
|
||||||
|
self.clock = self.hs.get_clock()
|
||||||
|
self.transactions = {
|
||||||
|
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
|
||||||
|
}
|
||||||
|
# Try to clean entries every 30 mins. This means entries will exist
|
||||||
|
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
|
||||||
|
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
|
||||||
|
|
||||||
|
def _get_transaction_key(self, request):
|
||||||
"""A helper function which returns a transaction key that can be used
|
"""A helper function which returns a transaction key that can be used
|
||||||
with TransactionCache for idempotent requests.
|
with TransactionCache for idempotent requests.
|
||||||
|
|
||||||
|
@ -38,24 +52,9 @@ def get_transaction_key(request):
|
||||||
Returns:
|
Returns:
|
||||||
str: A transaction key
|
str: A transaction key
|
||||||
"""
|
"""
|
||||||
token = get_access_token_from_request(request)
|
token = self.auth.get_access_token_from_request(request)
|
||||||
return request.path + "/" + token
|
return request.path + "/" + token
|
||||||
|
|
||||||
|
|
||||||
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
|
|
||||||
|
|
||||||
|
|
||||||
class HttpTransactionCache(object):
|
|
||||||
|
|
||||||
def __init__(self, clock):
|
|
||||||
self.clock = clock
|
|
||||||
self.transactions = {
|
|
||||||
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
|
|
||||||
}
|
|
||||||
# Try to clean entries every 30 mins. This means entries will exist
|
|
||||||
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
|
|
||||||
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
|
|
||||||
|
|
||||||
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
|
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
|
||||||
"""A helper function for fetch_or_execute which extracts
|
"""A helper function for fetch_or_execute which extracts
|
||||||
a transaction key from the given request.
|
a transaction key from the given request.
|
||||||
|
@ -64,7 +63,7 @@ class HttpTransactionCache(object):
|
||||||
fetch_or_execute
|
fetch_or_execute
|
||||||
"""
|
"""
|
||||||
return self.fetch_or_execute(
|
return self.fetch_or_execute(
|
||||||
get_transaction_key(request), fn, *args, **kwargs
|
self._get_transaction_key(request), fn, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
|
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
|
||||||
|
|
|
@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.builder_factory = hs.get_event_builder_factory()
|
self.builder_factory = hs.get_event_builder_factory()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.txns = HttpTransactionCache(hs.get_clock())
|
self.txns = HttpTransactionCache(hs)
|
||||||
|
|
|
@ -17,7 +17,6 @@ import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.auth import get_access_token_from_request
|
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
|
|
||||||
from .base import ClientV1RestServlet, client_path_patterns
|
from .base import ClientV1RestServlet, client_path_patterns
|
||||||
|
@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet):
|
||||||
if requester.device_id is None:
|
if requester.device_id is None:
|
||||||
# the acccess token wasn't associated with a device.
|
# the acccess token wasn't associated with a device.
|
||||||
# Just delete the access token
|
# Just delete the access token
|
||||||
access_token = get_access_token_from_request(request)
|
access_token = self._auth.get_access_token_from_request(request)
|
||||||
yield self._auth_handler.delete_access_token(access_token)
|
yield self._auth_handler.delete_access_token(access_token)
|
||||||
else:
|
else:
|
||||||
yield self._device_handler.delete_device(
|
yield self._device_handler.delete_device(
|
||||||
|
|
|
@ -23,7 +23,6 @@ from six import string_types
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.util.stringutils as stringutils
|
import synapse.util.stringutils as stringutils
|
||||||
from synapse.api.auth import get_access_token_from_request
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
@ -67,6 +66,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
# TODO: persistent storage
|
# TODO: persistent storage
|
||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
self.enable_registration = hs.config.enable_registration
|
self.enable_registration = hs.config.enable_registration
|
||||||
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
|
@ -310,7 +310,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_app_service(self, request, register_json, session):
|
def _do_app_service(self, request, register_json, session):
|
||||||
as_token = get_access_token_from_request(request)
|
as_token = self.auth.get_access_token_from_request(request)
|
||||||
|
|
||||||
if "user" not in register_json:
|
if "user" not in register_json:
|
||||||
raise SynapseError(400, "Expected 'user' key.")
|
raise SynapseError(400, "Expected 'user' key.")
|
||||||
|
@ -400,7 +400,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
user_json = parse_json_object_from_request(request)
|
user_json = parse_json_object_from_request(request)
|
||||||
|
|
||||||
access_token = get_access_token_from_request(request)
|
access_token = self.auth.get_access_token_from_request(request)
|
||||||
app_service = self.store.get_app_service_by_token(
|
app_service = self.store.get_app_service_by_token(
|
||||||
access_token
|
access_token
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,7 +20,6 @@ from six.moves import http_client
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.auth import has_access_token
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
|
@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet):
|
||||||
#
|
#
|
||||||
# In the second case, we require a password to confirm their identity.
|
# In the second case, we require a password to confirm their identity.
|
||||||
|
|
||||||
if has_access_token(request):
|
if self.auth.has_access_token(request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
params = yield self.auth_handler.validate_user_via_ui_auth(
|
params = yield self.auth_handler.validate_user_via_ui_auth(
|
||||||
requester, body, self.hs.get_ip_from_request(request),
|
requester, body, self.hs.get_ip_from_request(request),
|
||||||
|
|
|
@ -24,7 +24,6 @@ from twisted.internet import defer
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
import synapse.types
|
import synapse.types
|
||||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
|
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
|
@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
desired_username = body['username']
|
desired_username = body['username']
|
||||||
|
|
||||||
appservice = None
|
appservice = None
|
||||||
if has_access_token(request):
|
if self.auth.has_access_token(request):
|
||||||
appservice = yield self.auth.get_appservice_by_req(request)
|
appservice = yield self.auth.get_appservice_by_req(request)
|
||||||
|
|
||||||
# fork off as soon as possible for ASes and shared secret auth which
|
# fork off as soon as possible for ASes and shared secret auth which
|
||||||
|
@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
# because the IRC bridges rely on being able to register stupid
|
# because the IRC bridges rely on being able to register stupid
|
||||||
# IDs.
|
# IDs.
|
||||||
|
|
||||||
access_token = get_access_token_from_request(request)
|
access_token = self.auth.get_access_token_from_request(request)
|
||||||
|
|
||||||
if isinstance(desired_username, string_types):
|
if isinstance(desired_username, string_types):
|
||||||
result = yield self._do_appservice_registration(
|
result = yield self._do_appservice_registration(
|
||||||
|
|
|
@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
||||||
super(SendToDeviceRestServlet, self).__init__()
|
super(SendToDeviceRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.txns = HttpTransactionCache(hs.get_clock())
|
self.txns = HttpTransactionCache(hs)
|
||||||
self.device_message_handler = hs.get_device_message_handler()
|
self.device_message_handler = hs.get_device_message_handler()
|
||||||
|
|
||||||
def on_PUT(self, request, message_type, txn_id):
|
def on_PUT(self, request, message_type, txn_id):
|
||||||
|
|
|
@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.clock = MockClock()
|
self.clock = MockClock()
|
||||||
self.cache = HttpTransactionCache(self.clock)
|
self.hs = Mock()
|
||||||
|
self.hs.get_clock = Mock(return_value=self.clock)
|
||||||
|
self.hs.get_auth = Mock()
|
||||||
|
self.cache = HttpTransactionCache(self.hs)
|
||||||
|
|
||||||
self.mock_http_response = (200, "GOOD JOB!")
|
self.mock_http_response = (200, "GOOD JOB!")
|
||||||
self.mock_key = "foo"
|
self.mock_key = "foo"
|
||||||
|
|
Loading…
Reference in a new issue