0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-11 12:31:58 +01:00

Make auth & transactions more testable (#3499)

This commit is contained in:
Amber Brown 2018-07-14 07:34:49 +10:00 committed by GitHub
parent 2aba1f549c
commit 33b60c01b5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 96 additions and 97 deletions

0
changelog.d/3499.misc Normal file
View file

View file

@ -193,7 +193,7 @@ class Auth(object):
synapse.types.create_requester(user_id, app_service=app_service)
)
access_token = get_access_token_from_request(
access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
@ -239,7 +239,7 @@ class Auth(object):
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
get_access_token_from_request(
self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
)
@ -513,7 +513,7 @@ class Auth(object):
def get_appservice_by_req(self, request):
try:
token = get_access_token_from_request(
token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = self.store.get_app_service_by_token(token)
@ -673,67 +673,67 @@ class Auth(object):
" edit its room list entry"
)
@staticmethod
def has_access_token(request):
"""Checks if the request has an access_token.
def has_access_token(request):
"""Checks if the request has an access_token.
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@staticmethod
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
return query_params[0]
return query_params[0]

View file

@ -17,38 +17,20 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
def get_transaction_key(request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = get_access_token_from_request(request)
return request.path + "/" + token
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
def __init__(self, clock):
self.clock = clock
def __init__(self, hs):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
@ -56,6 +38,23 @@ class HttpTransactionCache(object):
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def _get_transaction_key(self, request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@ -64,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute
"""
return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs
self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):

View file

@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
self.txns = HttpTransactionCache(hs)

View file

@ -17,7 +17,6 @@ import logging
from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = get_access_token_from_request(request)
access_token = self._auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(

View file

@ -23,7 +23,6 @@ from six import string_types
from twisted.internet import defer
import synapse.util.stringutils as stringutils
from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request
@ -67,6 +66,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@ -310,7 +310,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
as_token = get_access_token_from_request(request)
as_token = self.auth.get_access_token_from_request(request)
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
@ -400,7 +400,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
access_token = get_access_token_from_request(request)
access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token(
access_token
)

View file

@ -20,7 +20,6 @@ from six.moves import http_client
from twisted.internet import defer
from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet):
#
# In the second case, we require a password to confirm their identity.
if has_access_token(request):
if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),

View file

@ -24,7 +24,6 @@ from twisted.internet import defer
import synapse
import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import (
@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username']
appservice = None
if has_access_token(request):
if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet):
# because the IRC bridges rely on being able to register stupid
# IDs.
access_token = get_access_token_from_request(request)
access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(

View file

@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id):

View file

@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
self.cache = HttpTransactionCache(self.clock)
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"