Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2018-07-17 10:04:33 +01:00
commit f793ff4571
37 changed files with 244 additions and 289 deletions

View file

@ -23,6 +23,9 @@ matrix:
- python: 3.6
env: TOX_ENV=py36
- python: 3.6
env: TOX_ENV=check_isort
- python: 3.6
env: TOX_ENV=check-newsfragment

0
changelog.d/3499.misc Normal file
View file

0
changelog.d/3530.misc Normal file
View file

1
changelog.d/3533.bugfix Normal file
View file

@ -0,0 +1 @@
Fix queued federation requests being processed in the wrong order

1
changelog.d/3534.misc Normal file
View file

@ -0,0 +1 @@
refactor: use parse_{string,integer} and assert's from http.servlet for deduplication

0
changelog.d/3535.misc Normal file
View file

1
changelog.d/3540.misc Normal file
View file

@ -0,0 +1 @@
check isort for each PR

View file

@ -193,7 +193,7 @@ class Auth(object):
synapse.types.create_requester(user_id, app_service=app_service)
)
access_token = get_access_token_from_request(
access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
@ -239,7 +239,7 @@ class Auth(object):
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
get_access_token_from_request(
self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
)
@ -513,7 +513,7 @@ class Auth(object):
def get_appservice_by_req(self, request):
try:
token = get_access_token_from_request(
token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = self.store.get_app_service_by_token(token)
@ -673,67 +673,67 @@ class Auth(object):
" edit its room list entry"
)
@staticmethod
def has_access_token(request):
"""Checks if the request has an access_token.
def has_access_token(request):
"""Checks if the request has an access_token.
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@staticmethod
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
return query_params[0]
return query_params[0]

View file

@ -23,7 +23,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_request
from synapse.http.servlet import assert_params_in_dict
from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__)
@ -199,7 +199,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
depth = pdu_json['depth']
if not isinstance(depth, six.integer_types):

View file

@ -38,7 +38,7 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list.
EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
@ -53,7 +53,7 @@ class RoomListHandler(BaseHandler):
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None,
network_tuple=EMTPY_THIRD_PARTY_ID,):
network_tuple=EMPTY_THIRD_PARTY_ID,):
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
@ -90,7 +90,7 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None,
network_tuple=EMTPY_THIRD_PARTY_ID,):
network_tuple=EMPTY_THIRD_PARTY_ID,):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:

View file

@ -206,7 +206,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
return content
def assert_params_in_request(body, required):
def assert_params_in_dict(body, required):
absent = []
for k in required:
if k not in body:

View file

@ -20,7 +20,7 @@ from twisted.web.server import Request, Site
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics
from synapse.util.logcontext import LoggingContext, ContextResourceUsage
from synapse.util.logcontext import ContextResourceUsage, LoggingContext
logger = logging.getLogger(__name__)

View file

@ -17,38 +17,20 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
def get_transaction_key(request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = get_access_token_from_request(request)
return request.path + "/" + token
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
def __init__(self, clock):
self.clock = clock
def __init__(self, hs):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
@ -56,6 +38,23 @@ class HttpTransactionCache(object):
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def _get_transaction_key(self, request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@ -64,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute
"""
return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs
self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):

View file

@ -22,7 +22,12 @@ from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import (
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
)
from synapse.types import UserID, create_requester
from .base import ClientV1RestServlet, client_path_patterns
@ -98,16 +103,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
before_ts = request.args.get("before_ts", None)
if not before_ts:
raise SynapseError(400, "Missing 'before_ts' arg")
logger.info("before_ts: %r", before_ts[0])
try:
before_ts = int(before_ts[0])
except Exception:
raise SynapseError(400, "Invalid 'before_ts' arg")
before_ts = parse_integer(request, "before_ts", required=True)
logger.info("before_ts: %r", before_ts)
ret = yield self.media_repository.delete_old_remote_media(before_ts)
@ -300,10 +297,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request)
new_room_user_id = content.get("new_room_user_id")
if not new_room_user_id:
raise SynapseError(400, "Please provide field `new_room_user_id`")
assert_params_in_dict(content, ["new_room_user_id"])
new_room_user_id = content["new_room_user_id"]
room_creator_requester = create_requester(new_room_user_id)
@ -464,9 +459,8 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
if not new_password:
raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password)
@ -514,12 +508,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table
start = request.args.get("start")[0]
limit = request.args.get("limit")[0]
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
start = parse_integer(request, "start", required=True)
limit = parse_integer(request, "limit", required=True)
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@ -551,12 +542,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["limit", "start"])
limit = params['limit']
start = params['start']
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@ -604,10 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
term = request.args.get("term")[0]
if not term:
raise SynapseError(400, "Missing 'term' arg")
term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(

View file

@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
self.txns = HttpTransactionCache(hs)

View file

@ -52,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
if "room_id" not in content:
raise SynapseError(400, "Missing room_id key",
raise SynapseError(400, 'Missing params: ["room_id"]',
errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content)
room_alias = RoomAlias.from_string(room_alias)
logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"]

View file

@ -15,6 +15,7 @@
from twisted.internet import defer
from synapse.http.servlet import parse_boolean
from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
@ -33,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
include_archived = request.args.get("archived", None) == ["true"]
include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,

View file

@ -17,7 +17,6 @@ import logging
from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = get_access_token_from_request(request)
access_token = self._auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(

View file

@ -21,7 +21,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
from synapse.http.servlet import parse_json_value_from_request
from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
@ -75,13 +75,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
before = request.args.get("before", None)
before = parse_string(request, "before")
if before:
before = _namespaced_rule_id(spec, before[0])
before = _namespaced_rule_id(spec, before)
after = request.args.get("after", None)
after = parse_string(request, "after")
if after:
after = _namespaced_rule_id(spec, after[0])
after = _namespaced_rule_id(spec, after)
try:
yield self.store.add_push_rule(

View file

@ -21,6 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
@ -91,15 +92,11 @@ class PushersSetRestServlet(ClientV1RestServlet):
)
defer.returnValue((200, {}))
reqd = ['kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data']
missing = []
for i in reqd:
if i not in content:
missing.append(i)
if len(missing):
raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM)
assert_params_in_dict(
content,
['kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data']
)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content)
@ -148,7 +145,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
super(RestServlet, self).__init__()
super(PushersRemoveRestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()

View file

@ -18,15 +18,12 @@ import hmac
import logging
from hashlib import sha1
from six import string_types
from twisted.internet import defer
import synapse.util.stringutils as stringutils
from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.types import create_requester
from .base import ClientV1RestServlet, client_path_patterns
@ -67,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@ -124,8 +122,7 @@ class RegisterRestServlet(ClientV1RestServlet):
session = (register_json["session"]
if "session" in register_json else None)
login_type = None
if "type" not in register_json:
raise SynapseError(400, "Missing 'type' key.")
assert_params_in_dict(register_json, ["type"])
try:
login_type = register_json["type"]
@ -310,11 +307,9 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
as_token = get_access_token_from_request(request)
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
as_token = self.auth.get_access_token_from_request(request)
assert_params_in_dict(register_json, ["user"])
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
@ -331,12 +326,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
if not isinstance(register_json.get("mac", None), string_types):
raise SynapseError(400, "Expected mac.")
if not isinstance(register_json.get("user", None), string_types):
raise SynapseError(400, "Expected 'user' key.")
if not isinstance(register_json.get("password", None), string_types):
raise SynapseError(400, "Expected 'password' key.")
assert_params_in_dict(register_json, ["mac", "user", "password"])
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@ -400,7 +390,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
access_token = get_access_token_from_request(request)
access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token(
access_token
)
@ -419,11 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_create(self, requester, user_json):
if "localpart" not in user_json:
raise SynapseError(400, "Expected 'localpart' key.")
if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.")
assert_params_in_dict(user_json, ["localpart", "displayname"])
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")

View file

@ -28,6 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.http.servlet import (
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
@ -435,9 +436,9 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None)
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
@ -530,7 +531,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0])
limit = parse_integer(request, "limit", default=10)
results = yield self.handlers.room_context_handler.get_event_context(
requester.user,
@ -636,8 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]:
if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.")
assert_params_in_dict(content, ["user_id"])
target = UserID.from_string(content["user_id"])
event_content = None
@ -764,7 +764,7 @@ class SearchRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
batch = request.args.get("next_batch", [None])[0]
batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search(
requester.user,
content,

View file

@ -20,12 +20,11 @@ from six.moves import http_client
from twisted.internet import defer
from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_request,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.util.msisdn import phone_number_to_msisdn
@ -48,7 +47,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_request(body, [
assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
@ -81,7 +80,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_request(body, [
assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
])
@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet):
#
# In the second case, we require a password to confirm their identity.
if has_access_token(request):
if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
@ -160,11 +159,10 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
else:
logger.error("Auth succeeded but no known type!", result.keys())
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
yield self._set_password_handler.set_password(
@ -229,15 +227,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
assert_params_in_dict(
body,
['id_server', 'client_secret', 'email', 'send_attempt'],
)
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
@ -267,18 +260,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = [
assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
]
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
@ -373,15 +358,7 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['medium', 'address']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
assert_params_in_dict(body, ['medium', 'address'])
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()

View file

@ -18,14 +18,18 @@ import logging
from twisted.internet import defer
from synapse.api import errors
from synapse.http import servlet
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs):
@ -47,7 +51,7 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices}))
class DeleteDevicesRestServlet(servlet.RestServlet):
class DeleteDevicesRestServlet(RestServlet):
"""
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
@ -67,19 +71,17 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request)
try:
body = servlet.parse_json_object_from_request(request)
body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
# deal with older clients which didn't pass a J*DELETESON dict
# DELETE
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = {}
else:
raise e
if 'devices' not in body:
raise errors.SynapseError(
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)
assert_params_in_dict(body, ["devices"])
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
@ -92,7 +94,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {}))
class DeviceRestServlet(servlet.RestServlet):
class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs):
@ -121,7 +123,7 @@ class DeviceRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request)
try:
body = servlet.parse_json_object_from_request(request)
body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
@ -144,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
body = servlet.parse_json_object_from_request(request)
body = parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,

View file

@ -24,12 +24,11 @@ from twisted.internet import defer
import synapse
import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import (
RestServlet,
assert_params_in_request,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
@ -69,7 +68,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_request(body, [
assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
@ -105,7 +104,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_request(body, [
assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number',
'send_attempt',
@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username']
appservice = None
if has_access_token(request):
if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet):
# because the IRC bridges rely on being able to register stupid
# IDs.
access_token = get_access_token_from_request(request)
access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
@ -387,9 +386,7 @@ class RegisterRestServlet(RestServlet):
add_msisdn = False
else:
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "Missing password.",
Codes.MISSING_PARAM)
assert_params_in_dict(params, ["password"])
desired_username = params.get("username", None)
new_password = params.get("password", None)
@ -566,11 +563,14 @@ class RegisterRestServlet(RestServlet):
Returns:
defer.Deferred:
"""
reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
defer.returnValue()
try:
assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
except SynapseError as ex:
if ex.errcode == Codes.MISSING_PARAM:
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
defer.returnValue(None)
raise
yield self.auth_handler.add_threepid(
user_id,

View file

@ -23,7 +23,7 @@ from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_request,
assert_params_in_dict,
parse_json_object_from_request,
)
@ -50,7 +50,7 @@ class ReportEventRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
assert_params_in_request(body, ("reason", "score"))
assert_params_in_dict(body, ("reason", "score"))
if not isinstance(body["reason"], string_types):
raise SynapseError(

View file

@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id):

View file

@ -16,6 +16,8 @@ from pydenticon import Generator
from twisted.web.resource import Resource
from synapse.http.servlet import parse_integer
FOREGROUND = [
"rgb(45,79,255)",
"rgb(254,180,44)",
@ -56,8 +58,8 @@ class IdenticonResource(Resource):
def render_GET(self, request):
name = "/".join(request.postpath)
width = int(request.args.get("width", [96])[0])
height = int(request.args.get("height", [96])[0])
width = parse_integer(request, "width", default=96)
height = parse_integer(request, "height", default=96)
identicon_bytes = self.generate_identicon(name, width, height)
request.setHeader(b"Content-Type", b"image/png")
request.setHeader(

View file

@ -40,6 +40,7 @@ from synapse.http.server import (
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.util.async import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -96,9 +97,9 @@ class PreviewUrlResource(Resource):
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request)
url = request.args.get("url")[0]
url = parse_string(request, "url")
if "ts" in request.args:
ts = int(request.args.get("ts")[0])
ts = parse_integer(request, "ts")
else:
ts = self.clock.time_msec()

View file

@ -21,6 +21,7 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_json, wrap_json_request_handler
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
@ -65,10 +66,10 @@ class UploadResource(Resource):
code=413,
)
upload_name = request.args.get("filename", None)
upload_name = parse_string(request, "filename")
if upload_name:
try:
upload_name = upload_name[0].decode('UTF-8')
upload_name = upload_name.decode('UTF-8')
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),

View file

@ -16,6 +16,7 @@
import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.types import StreamToken
logger = logging.getLogger(__name__)
@ -56,23 +57,10 @@ class PaginationConfig(object):
@classmethod
def from_request(cls, request, raise_invalid_params=True,
default_limit=None):
def get_param(name, default=None):
lst = request.args.get(name, [])
if len(lst) > 1:
raise SynapseError(
400, "%s must be specified only once" % (name,)
)
elif len(lst) == 1:
return lst[0]
else:
return default
direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
direction = get_param("dir", 'f')
if direction not in ['f', 'b']:
raise SynapseError(400, "'dir' parameter is invalid.")
from_tok = get_param("from")
to_tok = get_param("to")
from_tok = parse_string(request, "from")
to_tok = parse_string(request, "to")
try:
if from_tok == "END":
@ -88,12 +76,10 @@ class PaginationConfig(object):
except Exception:
raise SynapseError(400, "'to' paramater is invalid")
limit = get_param("limit", None)
if limit is not None and not limit.isdigit():
raise SynapseError(400, "'limit' parameter must be an integer.")
limit = parse_integer(request, "limit", default=default_limit)
if limit is None:
limit = default_limit
if limit and limit < 0:
raise SynapseError(400, "Limit must be 0 or above")
try:
return PaginationConfig(from_tok, to_tok, direction, limit)

View file

@ -74,19 +74,12 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
changed_entities = {
result = {
self._cache[k] for k in self._cache.islice(
start=self._cache.bisect_right(stream_pos),
)
}
# we need to include entities which we don't know about, as well as
# those which are known to have changed since the stream pos.
result = {
e for e in entities
if e in changed_entities or e not in self._entity_to_key
}
self.metrics.inc_hits()
else:
result = set(entities)

View file

@ -92,13 +92,22 @@ class _PerHostRatelimiter(object):
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.sleep_sec = sleep_msec / 1000.0
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
# request_id objects for requests which have been slept
self.sleeping_requests = set()
# map from request_id object to Deferred for requests which are ready
# for processing but have been queued
self.ready_request_queue = collections.OrderedDict()
# request id objects for requests which are in progress
self.current_processing = set()
# times at which we have recently (within the last window_size ms)
# received requests.
self.request_times = []
@contextlib.contextmanager
@ -117,11 +126,15 @@ class _PerHostRatelimiter(object):
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
# remove any entries from request_times which aren't within the window
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
# reject the request if we already have too many queued up (either
# sleeping or in the ready queue).
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
@ -134,9 +147,13 @@ class _PerHostRatelimiter(object):
def queue_request():
if len(self.current_processing) > self.concurrent_requests:
logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
"Ratelimiter: queueing request (queue now %i items)",
len(self.ready_request_queue),
)
return queue_defer
else:
return defer.succeed(None)
@ -148,10 +165,9 @@ class _PerHostRatelimiter(object):
if len(self.request_times) > self.sleep_limit:
logger.debug(
"Ratelimit [%s]: sleeping req",
id(request_id),
"Ratelimiter: sleeping request for %f sec", self.sleep_sec,
)
ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
@ -200,11 +216,8 @@ class _PerHostRatelimiter(object):
)
self.current_processing.discard(request_id)
try:
request_id, deferred = self.ready_request_queue.popitem()
# XXX: why do we do the following? the on_start callback above will
# do it for us.
self.current_processing.add(request_id)
# start processing the next item on the queue.
_, deferred = self.ready_request_queue.popitem(last=False)
with PreserveLoggingContext():
deferred.callback(None)

View file

@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
self.cache = HttpTransactionCache(self.clock)
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"

View file

@ -141,8 +141,8 @@ class StreamChangeCacheTests(unittest.TestCase):
)
# Query all the entries mid-way through the stream, but include one
# that doesn't exist in it. We should get back the one that doesn't
# exist, too.
# that doesn't exist in it. We shouldn't get back the one that doesn't
# exist.
self.assertEqual(
cache.get_entities_changed(
[
@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=2,
),
set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
set(["bar@baz.net", "user@elsewhere.org"]),
)
# Query all the entries, but before the first known point. We will get

View file

@ -65,6 +65,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
config.federation_domain_whitelist = None
config.federation_rc_reject_limit = 10
config.federation_rc_sleep_limit = 10
config.federation_rc_sleep_delay = 100
config.federation_rc_concurrent = 10
config.filter_timeline_limit = 5000
config.user_directory_search_all_users = False

View file

@ -1,5 +1,5 @@
[tox]
envlist = packaging, py27, py36, pep8
envlist = packaging, py27, py36, pep8, check_isort
[testenv]
deps =
@ -103,10 +103,14 @@ deps =
flake8
commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}"
[testenv:check_isort]
skip_install = True
deps = isort
commands = /bin/sh -c "isort -c -sp setup.cfg -rc synapse tests"
[testenv:check-newsfragment]
skip_install = True
deps = towncrier>=18.6.0rc1
commands =
python -m towncrier.check --compare-with=origin/develop
basepython = python3.6
basepython = python3.6