Merge branch 'develop' into matthew/sync_deleted_devices

This commit is contained in:
Matthew Hodgson 2018-07-23 10:03:28 +01:00 committed by GitHub
commit 9b34f3ea3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
103 changed files with 5421 additions and 4713 deletions

View file

@ -23,6 +23,9 @@ matrix:
- python: 3.6 - python: 3.6
env: TOX_ENV=py36 env: TOX_ENV=py36
- python: 3.6
env: TOX_ENV=check_isort
- python: 3.6 - python: 3.6
env: TOX_ENV=check-newsfragment env: TOX_ENV=check-newsfragment

2470
CHANGES.md Normal file

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -2,6 +2,7 @@ include synctl
include LICENSE include LICENSE
include VERSION include VERSION
include *.rst include *.rst
include *.md
include demo/README include demo/README
include demo/demo.tls.dh include demo/demo.tls.dh
include demo/*.py include demo/*.py

View file

@ -71,7 +71,7 @@ We'd like to invite you to join #matrix:matrix.org (via
https://matrix.org/docs/projects/try-matrix-now.html), run a homeserver, take a look https://matrix.org/docs/projects/try-matrix-now.html), run a homeserver, take a look
at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the
`APIs <https://matrix.org/docs/api>`_ and `Client SDKs `APIs <https://matrix.org/docs/api>`_ and `Client SDKs
<http://matrix.org/docs/projects/try-matrix-now.html#client-sdks>`_. <https://matrix.org/docs/projects/try-matrix-now.html#client-sdks>`_.
Thanks for using Matrix! Thanks for using Matrix!
@ -283,7 +283,7 @@ Connecting to Synapse from a client
The easiest way to try out your new Synapse installation is by connecting to it The easiest way to try out your new Synapse installation is by connecting to it
from a web client. The easiest option is probably the one at from a web client. The easiest option is probably the one at
http://riot.im/app. You will need to specify a "Custom server" when you log on https://riot.im/app. You will need to specify a "Custom server" when you log on
or register: set this to ``https://domain.tld`` if you setup a reverse proxy or register: set this to ``https://domain.tld`` if you setup a reverse proxy
following the recommended setup, or ``https://localhost:8448`` - remember to specify the following the recommended setup, or ``https://localhost:8448`` - remember to specify the
port (``:8448``) if not ``:443`` unless you changed the configuration. (Leave the identity port (``:8448``) if not ``:443`` unless you changed the configuration. (Leave the identity
@ -329,7 +329,7 @@ Security Note
============= =============
Matrix serves raw user generated data in some APIs - specifically the `content Matrix serves raw user generated data in some APIs - specifically the `content
repository endpoints <http://matrix.org/docs/spec/client_server/latest.html#get-matrix-media-r0-download-servername-mediaid>`_. repository endpoints <https://matrix.org/docs/spec/client_server/latest.html#get-matrix-media-r0-download-servername-mediaid>`_.
Whilst we have tried to mitigate against possible XSS attacks (e.g. Whilst we have tried to mitigate against possible XSS attacks (e.g.
https://github.com/matrix-org/synapse/pull/1021) we recommend running https://github.com/matrix-org/synapse/pull/1021) we recommend running
@ -348,7 +348,7 @@ Platform-Specific Instructions
Debian Debian
------ ------
Matrix provides official Debian packages via apt from http://matrix.org/packages/debian/. Matrix provides official Debian packages via apt from https://matrix.org/packages/debian/.
Note that these packages do not include a client - choose one from Note that these packages do not include a client - choose one from
https://matrix.org/docs/projects/try-matrix-now.html (or build your own with one of our SDKs :) https://matrix.org/docs/projects/try-matrix-now.html (or build your own with one of our SDKs :)
@ -524,7 +524,7 @@ Troubleshooting Running
----------------------- -----------------------
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for to manually upgrade PyNaCL, as synapse uses NaCl (https://nacl.cr.yp.to/) for
encryption and digital signatures. encryption and digital signatures.
Unfortunately PyNACL currently has a few issues Unfortunately PyNACL currently has a few issues
(https://github.com/pyca/pynacl/issues/53) and (https://github.com/pyca/pynacl/issues/53) and
@ -672,8 +672,8 @@ useful just for development purposes. See `<demo/README>`_.
Using PostgreSQL Using PostgreSQL
================ ================
As of Synapse 0.9, `PostgreSQL <http://www.postgresql.org>`_ is supported as an As of Synapse 0.9, `PostgreSQL <https://www.postgresql.org>`_ is supported as an
alternative to the `SQLite <http://sqlite.org/>`_ database that Synapse has alternative to the `SQLite <https://sqlite.org/>`_ database that Synapse has
traditionally used for convenience and simplicity. traditionally used for convenience and simplicity.
The advantages of Postgres include: The advantages of Postgres include:
@ -697,7 +697,7 @@ Using a reverse proxy with Synapse
It is recommended to put a reverse proxy such as It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_, `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or
`HAProxy <http://www.haproxy.org/>`_ in front of Synapse. One advantage of `HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
doing so is that it means that you can expose the default https port (443) to doing so is that it means that you can expose the default https port (443) to
Matrix clients without needing to run Synapse with root privileges. Matrix clients without needing to run Synapse with root privileges.

1
changelog.d/3367.misc Normal file
View file

@ -0,0 +1 @@
Remove unnecessary event re-signing hacks

View file

@ -1 +0,0 @@
Include CPU time from database threads in request/block metrics.

View file

@ -1 +0,0 @@
Add CPU metrics for _fetch_event_list

View file

1
changelog.d/3514.bugfix Normal file
View file

@ -0,0 +1 @@
Don't generate TURN credentials if no TURN config options are set

1
changelog.d/3548.bugfix Normal file
View file

@ -0,0 +1 @@
Catch failures saving metrics captured by Measure, and instead log the faulty metrics information for further analysis.

1
changelog.d/3552.misc Normal file
View file

@ -0,0 +1 @@
Release notes are now in the Markdown format.

1
changelog.d/3553.feature Normal file
View file

@ -0,0 +1 @@
Add metrics to track resource usage by background processes

1
changelog.d/3554.feature Normal file
View file

@ -0,0 +1 @@
Add `code` label to `synapse_http_server_response_time_seconds` prometheus metric

1
changelog.d/3556.feature Normal file
View file

@ -0,0 +1 @@
Add metrics to track resource usage by background processes

1
changelog.d/3559.misc Normal file
View file

@ -0,0 +1 @@
add config for pep8

1
changelog.d/3570.bugfix Normal file
View file

@ -0,0 +1 @@
Fix potential stack overflow and deadlock under heavy load

1
changelog.d/3571.misc Normal file
View file

@ -0,0 +1 @@
Merge Linearizer and Limiter

1
changelog.d/3572.misc Normal file
View file

@ -0,0 +1 @@
Merge Linearizer and Limiter

View file

@ -0,0 +1,63 @@
Shared-Secret Registration
==========================
This API allows for the creation of users in an administrative and
non-interactive way. This is generally used for bootstrapping a Synapse
instance with administrator accounts.
To authenticate yourself to the server, you will need both the shared secret
(``registration_shared_secret`` in the homeserver configuration), and a
one-time nonce. If the registration shared secret is not configured, this API
is not enabled.
To fetch the nonce, you need to request one from the API::
> GET /_matrix/client/r0/admin/register
< {"nonce": "thisisanonce"}
Once you have the nonce, you can make a ``POST`` to the same URL with a JSON
body containing the nonce, username, password, whether they are an admin
(optional, False by default), and a HMAC digest of the content.
As an example::
> POST /_matrix/client/r0/admin/register
> {
"nonce": "thisisanonce",
"username": "pepper_roni",
"password": "pizza",
"admin": true,
"mac": "mac_digest_here"
}
< {
"access_token": "token_here",
"user_id": "@pepper_roni@test",
"home_server": "test",
"device_id": "device_id_here"
}
The MAC is the hex digest output of the HMAC-SHA1 algorithm, with the key being
the shared secret and the content being the nonce, user, password, and either
the string "admin" or "notadmin", each separated by NULs. For an example of
generation in Python::
import hmac, hashlib
def generate_mac(nonce, user, password, admin=False):
mac = hmac.new(
key=shared_secret,
digestmod=hashlib.sha1,
)
mac.update(nonce.encode('utf8'))
mac.update(b"\x00")
mac.update(user.encode('utf8'))
mac.update(b"\x00")
mac.update(password.encode('utf8'))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
return mac.hexdigest()

View file

@ -1,5 +1,30 @@
[tool.towncrier] [tool.towncrier]
package = "synapse" package = "synapse"
filename = "CHANGES.rst" filename = "CHANGES.md"
directory = "changelog.d" directory = "changelog.d"
issue_format = "`#{issue} <https://github.com/matrix-org/synapse/issues/{issue}>`_" issue_format = "[\\#{issue}](https://github.com/matrix-org/synapse/issues/{issue}>)"
[[tool.towncrier.type]]
directory = "feature"
name = "Features"
showcontent = true
[[tool.towncrier.type]]
directory = "bugfix"
name = "Bugfixes"
showcontent = true
[[tool.towncrier.type]]
directory = "doc"
name = "Improved Documentation"
showcontent = true
[[tool.towncrier.type]]
directory = "removal"
name = "Deprecations and Removals"
showcontent = true
[[tool.towncrier.type]]
directory = "misc"
name = "Internal Changes"
showcontent = true

View file

@ -26,11 +26,37 @@ import yaml
def request_registration(user, password, server_location, shared_secret, admin=False): def request_registration(user, password, server_location, shared_secret, admin=False):
req = urllib2.Request(
"%s/_matrix/client/r0/admin/register" % (server_location,),
headers={'Content-Type': 'application/json'}
)
try:
if sys.version_info[:3] >= (2, 7, 9):
# As of version 2.7.9, urllib2 now checks SSL certs
import ssl
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
else:
f = urllib2.urlopen(req)
body = f.read()
f.close()
nonce = json.loads(body)["nonce"]
except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,)
if 400 <= e.code < 500:
if e.info().type == "application/json":
resp = json.load(e)
if "error" in resp:
print resp["error"]
sys.exit(1)
mac = hmac.new( mac = hmac.new(
key=shared_secret, key=shared_secret,
digestmod=hashlib.sha1, digestmod=hashlib.sha1,
) )
mac.update(nonce)
mac.update("\x00")
mac.update(user) mac.update(user)
mac.update("\x00") mac.update("\x00")
mac.update(password) mac.update(password)
@ -40,10 +66,10 @@ def request_registration(user, password, server_location, shared_secret, admin=F
mac = mac.hexdigest() mac = mac.hexdigest()
data = { data = {
"user": user, "nonce": nonce,
"username": user,
"password": password, "password": password,
"mac": mac, "mac": mac,
"type": "org.matrix.login.shared_secret",
"admin": admin, "admin": admin,
} }
@ -52,7 +78,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F
print "Sending registration request..." print "Sending registration request..."
req = urllib2.Request( req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,), "%s/_matrix/client/r0/admin/register" % (server_location,),
data=json.dumps(data), data=json.dumps(data),
headers={'Content-Type': 'application/json'} headers={'Content-Type': 'application/json'}
) )

View file

@ -14,12 +14,17 @@ ignore =
pylint.cfg pylint.cfg
tox.ini tox.ini
[flake8] [pep8]
max-line-length = 90 max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. # W503 requires that binary operators be at the end, not start, of lines. Erik
# E203 is contrary to PEP8. # doesn't like it. E203 is contrary to PEP8.
ignore = W503,E203 ignore = W503,E203
[flake8]
# note that flake8 inherits the "ignore" settings from "pep8" (because it uses
# pep8 to do those checks), but not the "max-line-length" setting
max-line-length = 90
[isort] [isort]
line_length = 89 line_length = 89
not_skip = __init__.py not_skip = __init__.py
@ -31,3 +36,4 @@ known_compat = mock,six
known_twisted=twisted,OpenSSL known_twisted=twisted,OpenSSL
multi_line_output=3 multi_line_output=3
include_trailing_comma=true include_trailing_comma=true
combine_as_imports=true

View file

@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.32.2" __version__ = "0.33.0"

View file

@ -193,7 +193,7 @@ class Auth(object):
synapse.types.create_requester(user_id, app_service=app_service) synapse.types.create_requester(user_id, app_service=app_service)
) )
access_token = get_access_token_from_request( access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
@ -239,7 +239,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_appservice_user_id(self, request): def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
get_access_token_from_request( self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
) )
@ -513,7 +513,7 @@ class Auth(object):
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
try: try:
token = get_access_token_from_request( token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
service = self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)
@ -673,67 +673,67 @@ class Auth(object):
" edit its room list entry" " edit its room list entry"
) )
@staticmethod
def has_access_token(request):
"""Checks if the request has an access_token.
def has_access_token(request): Returns:
"""Checks if the request has an access_token. bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
Returns: @staticmethod
bool: False if no access_token was given, True otherwise. def get_access_token_from_request(request, token_not_found_http_status=401):
""" """Extracts the access_token from the request.
query_params = request.args.get("access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
def get_access_token_from_request(request, token_not_found_http_status=401): auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
"""Extracts the access_token from the request. query_params = request.args.get(b"access_token")
if auth_headers:
Args: # Try the get the access_token from a "Authorization: Bearer"
request: The http request. # header
token_not_found_http_status(int): The HTTP status code to set in the if query_params is not None:
AuthError if the token isn't found. This is used in some of the raise AuthError(
legacy APIs to change the status code to 403 from the default of token_not_found_http_status,
401 since some of the old clients depended on auth errors returning "Mixing Authorization headers and access_token query parameters.",
403. errcode=Codes.MISSING_TOKEN,
Returns: )
str: The access_token if len(auth_headers) > 1:
Raises: raise AuthError(
AuthError: If there isn't an access_token in the request. token_not_found_http_status,
""" "Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") )
query_params = request.args.get(b"access_token") parts = auth_headers[0].split(" ")
if auth_headers: if parts[0] == "Bearer" and len(parts) == 2:
# Try the get the access_token from a "Authorization: Bearer" return parts[1]
# header else:
if query_params is not None: raise AuthError(
raise AuthError( token_not_found_http_status,
token_not_found_http_status, "Invalid Authorization header.",
"Mixing Authorization headers and access_token query parameters.", errcode=Codes.MISSING_TOKEN,
errcode=Codes.MISSING_TOKEN, )
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
else: else:
raise AuthError( # Try to get the access_token from the query params.
token_not_found_http_status, if not query_params:
"Invalid Authorization header.", raise AuthError(
errcode=Codes.MISSING_TOKEN, token_not_found_http_status,
) "Missing access token.",
else: errcode=Codes.MISSING_TOKEN
# Try to get the access_token from the query params. )
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
return query_params[0] return query_params[0]

View file

@ -18,6 +18,8 @@ import logging
import os import os
import sys import sys
from six import iteritems
from twisted.application import service from twisted.application import service
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import EncodingResourceWrapper, NoResource from twisted.web.resource import EncodingResourceWrapper, NoResource
@ -442,7 +444,7 @@ def run(hs):
stats["total_nonbridged_users"] = total_nonbridged_users stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = yield hs.get_datastore().count_daily_user_type() daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
for name, count in daily_user_type_results.iteritems(): for name, count in iteritems(daily_user_type_results):
stats["daily_user_type_" + name] = count stats["daily_user_type_" + name] = count
room_count = yield hs.get_datastore().get_room_count() room_count = yield hs.get_datastore().get_room_count()
@ -453,7 +455,7 @@ def run(hs):
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users() r30_results = yield hs.get_datastore().count_r30_users()
for name, count in r30_results.iteritems(): for name, count in iteritems(r30_results):
stats["r30_users_" + name] = count stats["r30_users_" + name] = count
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()

View file

@ -25,6 +25,8 @@ import subprocess
import sys import sys
import time import time
from six import iteritems
import yaml import yaml
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
@ -173,7 +175,7 @@ def main():
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
cache_factors = config.get("synctl_cache_factors", {}) cache_factors = config.get("synctl_cache_factors", {})
for cache_name, factor in cache_factors.iteritems(): for cache_name, factor in iteritems(cache_factors):
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
worker_configfiles = [] worker_configfiles = []

View file

@ -30,10 +30,10 @@ class VoipConfig(Config):
## Turn ## ## Turn ##
# The public URIs of the TURN server to give to clients # The public URIs of the TURN server to give to clients
turn_uris: [] #turn_uris: []
# The shared secret used to compute passwords for the TURN server # The shared secret used to compute passwords for the TURN server
turn_shared_secret: "YOUR_SHARED_SECRET" #turn_shared_secret: "YOUR_SHARED_SECRET"
# The Username and password if the TURN server needs them and # The Username and password if the TURN server needs them and
# does not use a token # does not use a token

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from six import iteritems
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer from twisted.internet import defer
@ -159,7 +161,7 @@ def _encode_state_dict(state_dict):
return [ return [
(etype, state_key, v) (etype, state_key, v)
for (etype, state_key), v in state_dict.iteritems() for (etype, state_key), v in iteritems(state_dict)
] ]

View file

@ -23,7 +23,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_request from synapse.http.servlet import assert_params_in_dict
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -199,7 +199,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
""" """
# we could probably enforce a bunch of other fields here (room_id, sender, # we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc) # origin, etc etc)
assert_params_in_request(pdu_json, ('event_id', 'type', 'depth')) assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
depth = pdu_json['depth'] depth = pdu_json['depth']
if not isinstance(depth, six.integer_types): if not isinstance(depth, six.integer_types):

View file

@ -30,7 +30,8 @@ from synapse.metrics import (
sent_edus_counter, sent_edus_counter,
sent_transactions_counter, sent_transactions_counter,
) )
from synapse.util import PreserveLoggingContext, logcontext from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import logcontext
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@ -165,10 +166,11 @@ class TransactionQueue(object):
if self._is_processing: if self._is_processing:
return return
# fire off a processing loop in the background. It's likely it will # fire off a processing loop in the background
# outlast the current request, so run it in the sentinel logcontext. run_as_background_process(
with PreserveLoggingContext(): "process_event_queue_for_federation",
self._process_event_queue_loop() self._process_event_queue_loop,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _process_event_queue_loop(self): def _process_event_queue_loop(self):
@ -432,14 +434,11 @@ class TransactionQueue(object):
logger.debug("TX [%s] Starting transaction loop", destination) logger.debug("TX [%s] Starting transaction loop", destination)
# Drop the logcontext before starting the transaction. It doesn't run_as_background_process(
# really make sense to log all the outbound transactions against "federation_transaction_transmission_loop",
# whatever path led us to this point: that's pretty arbitrary really. self._transaction_transmission_loop,
# destination,
# (this also means we can fire off _perform_transaction without )
# yielding)
with logcontext.PreserveLoggingContext():
self._transaction_transmission_loop(destination)
@defer.inlineCallbacks @defer.inlineCallbacks
def _transaction_transmission_loop(self, destination): def _transaction_transmission_loop(self, destination):

View file

@ -21,8 +21,8 @@ import logging
import sys import sys
import six import six
from six import iteritems from six import iteritems, itervalues
from six.moves import http_client from six.moves import http_client, zip
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
@ -43,7 +43,6 @@ from synapse.crypto.event_signing import (
add_hashes_and_signatures, add_hashes_and_signatures,
compute_event_signature, compute_event_signature,
) )
from synapse.events.utils import prune_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.state import resolve_events_with_factory from synapse.state import resolve_events_with_factory
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
@ -52,8 +51,8 @@ from synapse.util.async import Linearizer
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
from ._base import BaseHandler from ._base import BaseHandler
@ -501,137 +500,6 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
@measure_func("_filter_events_for_server")
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
"""Filter the given events for the given server, redacting those the
server can't see.
Assumes the server is currently in the room.
Returns
list[FrozenEvent]
"""
# First lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
)
)
visibility_ids = set()
for sids in event_to_state_ids.itervalues():
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist:
visibility_ids.add(hist)
# If we failed to find any history visibility events then the default
# is "shared" visiblity.
if not visibility_ids:
defer.returnValue(events)
event_map = yield self.store.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.itervalues()
)
if all_open:
defer.returnValue(events)
# Ok, so we're dealing with events that have non-trivial visibility
# rules, so we need to also get the memberships of the room.
event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
# We only want to pull out member events that correspond to the
# server's domain.
def check_match(id):
try:
return server_name == get_domain_from_id(id)
except Exception:
return False
# Parses mapping `event_id -> (type, state_key) -> state event_id`
# to get all state ids that we're interested in.
event_map = yield self.store.get_events([
e_id
for key_to_eid in list(event_to_state_ids.values())
for key, e_id in key_to_eid.items()
if key[0] != EventTypes.Member or check_match(key[1])
])
event_to_state = {
e_id: {
key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.iteritems()
if inner_e_id in event_map
}
for e_id, key_to_eid in event_to_state_ids.iteritems()
}
erased_senders = yield self.store.are_users_erased(
e.sender for e in events,
)
def redact_disallowed(event, state):
# if the sender has been gdpr17ed, always return a redacted
# copy of the event.
if erased_senders[event.sender]:
logger.info(
"Sender of %s has been erased, redacting",
event.event_id,
)
return prune_event(event)
if not state:
return event
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
for ev in state.itervalues():
if ev.type != EventTypes.Member:
continue
try:
domain = get_domain_from_id(ev.state_key)
except Exception:
continue
if domain != server_name:
continue
memtype = ev.membership
if memtype == Membership.JOIN:
return event
elif memtype == Membership.INVITE:
if visibility == "invited":
return event
else:
return prune_event(event)
return event
defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
for e in events
])
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities): def backfill(self, dest, room_id, limit, extremities):
@ -863,7 +731,7 @@ class FederationHandler(BaseHandler):
""" """
joined_users = [ joined_users = [
(state_key, int(event.depth)) (state_key, int(event.depth))
for (e_type, state_key), event in state.iteritems() for (e_type, state_key), event in iteritems(state)
if e_type == EventTypes.Member if e_type == EventTypes.Member
and event.membership == Membership.JOIN and event.membership == Membership.JOIN
] ]
@ -880,7 +748,7 @@ class FederationHandler(BaseHandler):
except Exception: except Exception:
pass pass
return sorted(joined_domains.iteritems(), key=lambda d: d[1]) return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state) curr_domains = get_domains_from_state(curr_state)
@ -943,7 +811,7 @@ class FederationHandler(BaseHandler):
tried_domains = set(likely_domains) tried_domains = set(likely_domains)
tried_domains.add(self.server_name) tried_domains.add(self.server_name)
event_ids = list(extremities.iterkeys()) event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn( resolve = logcontext.preserve_fn(
@ -959,15 +827,15 @@ class FederationHandler(BaseHandler):
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = yield self.store.get_events(
[e_id for ids in states.itervalues() for e_id in ids.itervalues()], [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False get_prev_content=False
) )
states = { states = {
key: { key: {
k: state_map[e_id] k: state_map[e_id]
for k, e_id in state_dict.iteritems() for k, e_id in iteritems(state_dict)
if e_id in state_map if e_id in state_map
} for key, state_dict in states.iteritems() } for key, state_dict in iteritems(states)
} }
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
@ -1038,16 +906,6 @@ class FederationHandler(BaseHandler):
[auth_id for auth_id, _ in event.auth_events], [auth_id for auth_id, _ in event.auth_events],
include_given=True include_given=True
) )
for event in auth:
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
defer.returnValue([e for e in auth]) defer.returnValue([e for e in auth])
@log_function @log_function
@ -1503,18 +1361,6 @@ class FederationHandler(BaseHandler):
del results[(event.type, event.state_key)] del results[(event.type, event.state_key)]
res = list(results.values()) res = list(results.values())
for event in res:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
if self.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
defer.returnValue(res) defer.returnValue(res)
else: else:
defer.returnValue([]) defer.returnValue([])
@ -1558,7 +1404,7 @@ class FederationHandler(BaseHandler):
limit limit
) )
events = yield self._filter_events_for_server(origin, room_id, events) events = yield filter_events_for_server(self.store, origin, events)
defer.returnValue(events) defer.returnValue(events)
@ -1586,18 +1432,6 @@ class FederationHandler(BaseHandler):
) )
if event: if event:
if self.is_mine_id(event.event_id):
# FIXME: This is a temporary work around where we occasionally
# return events slightly differently than when they were
# originally signed
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
in_room = yield self.auth.check_host_in_room( in_room = yield self.auth.check_host_in_room(
event.room_id, event.room_id,
origin origin
@ -1605,8 +1439,8 @@ class FederationHandler(BaseHandler):
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
events = yield self._filter_events_for_server( events = yield filter_events_for_server(
origin, event.room_id, [event] self.store, origin, [event],
) )
event = events[0] event = events[0]
defer.returnValue(event) defer.returnValue(event)
@ -1681,7 +1515,7 @@ class FederationHandler(BaseHandler):
yield self.store.persist_events( yield self.store.persist_events(
[ [
(ev_info["event"], context) (ev_info["event"], context)
for ev_info, context in itertools.izip(event_infos, contexts) for ev_info, context in zip(event_infos, contexts)
], ],
backfilled=backfilled, backfilled=backfilled,
) )
@ -1862,15 +1696,6 @@ class FederationHandler(BaseHandler):
local_auth_chain, remote_auth_chain local_auth_chain, remote_auth_chain
) )
for event in ret["auth_chain"]:
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
logger.debug("on_query_auth returning: %s", ret) logger.debug("on_query_auth returning: %s", ret)
defer.returnValue(ret) defer.returnValue(ret)
@ -1896,8 +1721,8 @@ class FederationHandler(BaseHandler):
min_depth=min_depth, min_depth=min_depth,
) )
missing_events = yield self._filter_events_for_server( missing_events = yield filter_events_for_server(
origin, room_id, missing_events, self.store, origin, missing_events,
) )
defer.returnValue(missing_events) defer.returnValue(missing_events)

View file

@ -33,7 +33,7 @@ from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import send_event_to_master from synapse.replication.http.send_event import send_event_to_master
from synapse.types import RoomAlias, RoomStreamToken, UserID from synapse.types import RoomAlias, RoomStreamToken, UserID
from synapse.util.async import Limiter, ReadWriteLock from synapse.util.async import Linearizer, ReadWriteLock
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import run_in_background from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
@ -427,7 +427,7 @@ class EventCreationHandler(object):
# We arbitrarily limit concurrent event creation for a room to 5. # We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much. # This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5) self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
self.action_generator = hs.get_action_generator() self.action_generator = hs.get_action_generator()

View file

@ -38,7 +38,7 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list. # This is used to indicate we should only return rooms published to the main list.
EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
@ -50,7 +50,7 @@ class RoomListHandler(BaseHandler):
def get_local_public_room_list(self, limit=None, since_token=None, def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None, search_filter=None,
network_tuple=EMTPY_THIRD_PARTY_ID,): network_tuple=EMPTY_THIRD_PARTY_ID,):
"""Generate a local public room list. """Generate a local public room list.
There are multiple different lists: the main one plus one per third There are multiple different lists: the main one plus one per third
@ -87,7 +87,7 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None, def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None, search_filter=None,
network_tuple=EMTPY_THIRD_PARTY_ID,): network_tuple=EMPTY_THIRD_PARTY_ID,):
if since_token and since_token != "END": if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token) since_token = RoomListNextBatch.from_token(since_token)
else: else:

View file

@ -26,9 +26,11 @@ from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, protocol, reactor, ssl, task from twisted.internet import defer, protocol, reactor, ssl, task
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, BrowserLikeRedirectAgent, ContentDecoderAgent
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.client import ( from twisted.web.client import (
Agent,
BrowserLikeRedirectAgent,
ContentDecoderAgent,
FileBodyProducer as TwistedFileBodyProducer,
GzipDecoder, GzipDecoder,
HTTPConnectionPool, HTTPConnectionPool,
PartialDownloadError, PartialDownloadError,

View file

@ -38,7 +38,8 @@ outgoing_responses_counter = Counter(
) )
response_timer = Histogram( response_timer = Histogram(
"synapse_http_server_response_time_seconds", "sec", ["method", "servlet", "tag"] "synapse_http_server_response_time_seconds", "sec",
["method", "servlet", "tag", "code"],
) )
response_ru_utime = Counter( response_ru_utime = Counter(
@ -171,11 +172,13 @@ class RequestMetrics(object):
) )
return return
outgoing_responses_counter.labels(request.method, str(request.code)).inc() response_code = str(request.code)
outgoing_responses_counter.labels(request.method, response_code).inc()
response_count.labels(request.method, self.name, tag).inc() response_count.labels(request.method, self.name, tag).inc()
response_timer.labels(request.method, self.name, tag).observe( response_timer.labels(request.method, self.name, tag, response_code).observe(
time_sec - self.start time_sec - self.start
) )

View file

@ -206,7 +206,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
return content return content
def assert_params_in_request(body, required): def assert_params_in_dict(body, required):
absent = [] absent = []
for k in required: for k in required:
if k not in body: if k not in body:

View file

@ -20,7 +20,7 @@ from twisted.web.server import Request, Site
from synapse.http import redact_uri from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics from synapse.http.request_metrics import RequestMetrics
from synapse.util.logcontext import LoggingContext, ContextResourceUsage from synapse.util.logcontext import ContextResourceUsage, LoggingContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,9 +42,10 @@ class SynapseRequest(Request):
which is handling the request, and returns a context manager. which is handling the request, and returns a context manager.
""" """
def __init__(self, site, *args, **kw): def __init__(self, site, channel, *args, **kw):
Request.__init__(self, *args, **kw) Request.__init__(self, channel, *args, **kw)
self.site = site self.site = site
self._channel = channel
self.authenticated_entity = None self.authenticated_entity = None
self.start_time = 0 self.start_time = 0

View file

@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
from twisted.internet import defer
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
_background_process_start_count = Counter(
"synapse_background_process_start_count",
"Number of background processes started",
["name"],
)
# we set registry=None in all of these to stop them getting registered with
# the default registry. Instead we collect them all via the CustomCollector,
# which ensures that we can update them before they are collected.
#
_background_process_ru_utime = Counter(
"synapse_background_process_ru_utime_seconds",
"User CPU time used by background processes, in seconds",
["name"],
registry=None,
)
_background_process_ru_stime = Counter(
"synapse_background_process_ru_stime_seconds",
"System CPU time used by background processes, in seconds",
["name"],
registry=None,
)
_background_process_db_txn_count = Counter(
"synapse_background_process_db_txn_count",
"Number of database transactions done by background processes",
["name"],
registry=None,
)
_background_process_db_txn_duration = Counter(
"synapse_background_process_db_txn_duration_seconds",
("Seconds spent by background processes waiting for database "
"transactions, excluding scheduling time"),
["name"],
registry=None,
)
_background_process_db_sched_duration = Counter(
"synapse_background_process_db_sched_duration_seconds",
"Seconds spent by background processes waiting for database connections",
["name"],
registry=None,
)
# map from description to a counter, so that we can name our logcontexts
# incrementally. (It actually duplicates _background_process_start_count, but
# it's much simpler to do so than to try to combine them.)
_background_process_counts = dict() # type: dict[str, int]
# map from description to the currently running background processes.
#
# it's kept as a dict of sets rather than a big set so that we can keep track
# of process descriptions that no longer have any active processes.
_background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
class _Collector(object):
"""A custom metrics collector for the background process metrics.
Ensures that all of the metrics are up-to-date with any in-flight processes
before they are returned.
"""
def collect(self):
background_process_in_flight_count = GaugeMetricFamily(
"synapse_background_process_in_flight_count",
"Number of background processes in flight",
labels=["name"],
)
for desc, processes in six.iteritems(_background_processes):
background_process_in_flight_count.add_metric(
(desc,), len(processes),
)
for process in processes:
process.update_metrics()
yield background_process_in_flight_count
# now we need to run collect() over each of the static Counters, and
# yield each metric they return.
for m in (
_background_process_ru_utime,
_background_process_ru_stime,
_background_process_db_txn_count,
_background_process_db_txn_duration,
_background_process_db_sched_duration,
):
for r in m.collect():
yield r
REGISTRY.register(_Collector())
class _BackgroundProcess(object):
def __init__(self, desc, ctx):
self.desc = desc
self._context = ctx
self._reported_stats = None
def update_metrics(self):
"""Updates the metrics with values from this process."""
new_stats = self._context.get_resource_usage()
if self._reported_stats is None:
diff = new_stats
else:
diff = new_stats - self._reported_stats
self._reported_stats = new_stats
_background_process_ru_utime.labels(self.desc).inc(diff.ru_utime)
_background_process_ru_stime.labels(self.desc).inc(diff.ru_stime)
_background_process_db_txn_count.labels(self.desc).inc(
diff.db_txn_count,
)
_background_process_db_txn_duration.labels(self.desc).inc(
diff.db_txn_duration_sec,
)
_background_process_db_sched_duration.labels(self.desc).inc(
diff.db_sched_duration_sec,
)
def run_as_background_process(desc, func, *args, **kwargs):
"""Run the given function in its own logcontext, with resource metrics
This should be used to wrap processes which are fired off to run in the
background, instead of being associated with a particular request.
Args:
desc (str): a description for this background process type
func: a function, which may return a Deferred
args: positional args for func
kwargs: keyword args for func
Returns: None
"""
@defer.inlineCallbacks
def run():
count = _background_process_counts.get(desc, 0)
_background_process_counts[desc] = count + 1
_background_process_start_count.labels(desc).inc()
with LoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count)
proc = _BackgroundProcess(desc, context)
_background_processes.setdefault(desc, set()).add(proc)
try:
yield func(*args, **kwargs)
finally:
proc.update_metrics()
_background_processes[desc].remove(proc)
with PreserveLoggingContext():
run()

View file

@ -49,7 +49,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,)) self._get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate( self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type) (user_id, room_id, receipt_type)
) )

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,13 +14,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from six import PY3
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.rest.client import versions from synapse.rest.client import versions
from synapse.rest.client.v1 import admin, directory, events, initial_sync from synapse.rest.client.v1 import (
from synapse.rest.client.v1 import login as v1_login admin,
from synapse.rest.client.v1 import logout, presence, profile, push_rule, pusher directory,
from synapse.rest.client.v1 import register as v1_register events,
from synapse.rest.client.v1 import room, voip initial_sync,
login as v1_login,
logout,
presence,
profile,
push_rule,
pusher,
room,
voip,
)
from synapse.rest.client.v2_alpha import ( from synapse.rest.client.v2_alpha import (
account, account,
account_data, account_data,
@ -42,6 +54,11 @@ from synapse.rest.client.v2_alpha import (
user_directory, user_directory,
) )
if not PY3:
from synapse.rest.client.v1_only import (
register as v1_register,
)
class ClientRestResource(JsonResource): class ClientRestResource(JsonResource):
"""A resource for version 1 of the matrix client API.""" """A resource for version 1 of the matrix client API."""
@ -54,14 +71,22 @@ class ClientRestResource(JsonResource):
def register_servlets(client_resource, hs): def register_servlets(client_resource, hs):
versions.register_servlets(client_resource) versions.register_servlets(client_resource)
# "v1" if not PY3:
room.register_servlets(hs, client_resource) # "v1" (Python 2 only)
v1_register.register_servlets(hs, client_resource)
# Deprecated in r0
initial_sync.register_servlets(hs, client_resource)
room.register_deprecated_servlets(hs, client_resource)
# Partially deprecated in r0
events.register_servlets(hs, client_resource) events.register_servlets(hs, client_resource)
v1_register.register_servlets(hs, client_resource)
# "v1" + "r0"
room.register_servlets(hs, client_resource)
v1_login.register_servlets(hs, client_resource) v1_login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource) presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource) directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource)

View file

@ -17,38 +17,20 @@
to ensure idempotency when performing PUTs using the REST API.""" to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_transaction_key(request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = get_access_token_from_request(request)
return request.path + "/" + token
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object): class HttpTransactionCache(object):
def __init__(self, clock): def __init__(self, hs):
self.clock = clock self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
self.transactions = { self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
} }
@ -56,6 +38,23 @@ class HttpTransactionCache(object):
# for at *LEAST* 30 mins, and at *MOST* 60 mins. # for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def _get_transaction_key(self, request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs): def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts """A helper function for fetch_or_execute which extracts
a transaction key from the given request. a transaction key from the given request.
@ -64,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute fetch_or_execute
""" """
return self.fetch_or_execute( return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs self._get_transaction_key(request), fn, *args, **kwargs
) )
def fetch_or_execute(self, txn_key, fn, *args, **kwargs): def fetch_or_execute(self, txn_key, fn, *args, **kwargs):

View file

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib
import hmac
import logging import logging
from six.moves import http_client from six.moves import http_client
@ -22,7 +24,12 @@ from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import (
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
)
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -58,6 +65,125 @@ class UsersRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret)) defer.returnValue((200, ret))
class UserRegisterServlet(ClientV1RestServlet):
"""
Attributes:
NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
nonces (dict[str, int]): The nonces that we will accept. A dict of
nonce to the time it was generated, in int seconds.
"""
PATTERNS = client_path_patterns("/admin/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
super(UserRegisterServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.reactor = hs.get_reactor()
self.nonces = {}
self.hs = hs
def _clear_old_nonces(self):
"""
Clear out old nonces that are older than NONCE_TIMEOUT.
"""
now = int(self.reactor.seconds())
for k, v in list(self.nonces.items()):
if now - v > self.NONCE_TIMEOUT:
del self.nonces[k]
def on_GET(self, request):
"""
Generate a new nonce.
"""
self._clear_old_nonces()
nonce = self.hs.get_secrets().token_hex(64)
self.nonces[nonce] = int(self.reactor.seconds())
return (200, {"nonce": nonce.encode('ascii')})
@defer.inlineCallbacks
def on_POST(self, request):
self._clear_old_nonces()
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
body = parse_json_object_from_request(request)
if "nonce" not in body:
raise SynapseError(
400, "nonce must be specified", errcode=Codes.BAD_JSON,
)
nonce = body["nonce"]
if nonce not in self.nonces:
raise SynapseError(
400, "unrecognised nonce",
)
# Delete the nonce, so it can't be reused, even if it's invalid
del self.nonces[nonce]
if "username" not in body:
raise SynapseError(
400, "username must be specified", errcode=Codes.BAD_JSON,
)
else:
if (not isinstance(body['username'], str) or len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
username = body["username"].encode("utf-8")
if b"\x00" in username:
raise SynapseError(400, "Invalid username")
if "password" not in body:
raise SynapseError(
400, "password must be specified", errcode=Codes.BAD_JSON,
)
else:
if (not isinstance(body['password'], str) or len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
password = body["password"].encode("utf-8")
if b"\x00" in password:
raise SynapseError(400, "Invalid password")
admin = body.get("admin", None)
got_mac = body["mac"]
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
want_mac.update(nonce)
want_mac.update(b"\x00")
want_mac.update(username)
want_mac.update(b"\x00")
want_mac.update(password)
want_mac.update(b"\x00")
want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if not hmac.compare_digest(want_mac, got_mac):
raise SynapseError(
403, "HMAC incorrect",
)
# Reuse the parts of RegisterRestServlet to reduce code duplication
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
register = RegisterRestServlet(self.hs)
(user_id, _) = yield register.registration_handler.register(
localpart=username.lower(), password=password, admin=bool(admin),
generate_token=False,
)
result = yield register._create_registration_details(user_id, body)
defer.returnValue((200, result))
class WhoisRestServlet(ClientV1RestServlet): class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
@ -98,16 +224,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
before_ts = request.args.get("before_ts", None) before_ts = parse_integer(request, "before_ts", required=True)
if not before_ts: logger.info("before_ts: %r", before_ts)
raise SynapseError(400, "Missing 'before_ts' arg")
logger.info("before_ts: %r", before_ts[0])
try:
before_ts = int(before_ts[0])
except Exception:
raise SynapseError(400, "Invalid 'before_ts' arg")
ret = yield self.media_repository.delete_old_remote_media(before_ts) ret = yield self.media_repository.delete_old_remote_media(before_ts)
@ -300,10 +418,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["new_room_user_id"])
new_room_user_id = content.get("new_room_user_id") new_room_user_id = content["new_room_user_id"]
if not new_room_user_id:
raise SynapseError(400, "Please provide field `new_room_user_id`")
room_creator_requester = create_requester(new_room_user_id) room_creator_requester = create_requester(new_room_user_id)
@ -464,9 +580,8 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request) params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["new_password"])
new_password = params['new_password'] new_password = params['new_password']
if not new_password:
raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password) logger.info("new_password: %r", new_password)
@ -514,12 +629,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Can only users a local user") raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table order = "name" # order by name in user table
start = request.args.get("start")[0] start = parse_integer(request, "start", required=True)
limit = request.args.get("limit")[0] limit = parse_integer(request, "limit", required=True)
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate( ret = yield self.handlers.admin_handler.get_users_paginate(
@ -551,12 +663,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
order = "name" # order by name in user table order = "name" # order by name in user table
params = parse_json_object_from_request(request) params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["limit", "start"])
limit = params['limit'] limit = params['limit']
start = params['start'] start = params['start']
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate( ret = yield self.handlers.admin_handler.get_users_paginate(
@ -604,10 +713,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user") raise SynapseError(400, "Can only users a local user")
term = request.args.get("term")[0] term = parse_string(request, "term", required=True)
if not term:
raise SynapseError(400, "Missing 'term' arg")
logger.info("term: %s ", term) logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users( ret = yield self.handlers.admin_handler.search_users(
@ -629,3 +735,4 @@ def register_servlets(hs, http_server):
ShutdownRoomRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server)
ListMediaInRoom(hs).register(http_server) ListMediaInRoom(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)

View file

@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs self.hs = hs
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock()) self.txns = HttpTransactionCache(hs)

View file

@ -52,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
if "room_id" not in content: if "room_id" not in content:
raise SynapseError(400, "Missing room_id key", raise SynapseError(400, 'Missing params: ["room_id"]',
errcode=Codes.BAD_JSON) errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
room_alias = RoomAlias.from_string(room_alias)
logger.debug("Got room name: %s", room_alias.to_string()) logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"] room_id = content["room_id"]

View file

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_boolean
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -33,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
include_archived = request.args.get("archived", None) == ["true"] include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms( content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,

View file

@ -17,7 +17,6 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None: if requester.device_id is None:
# the acccess token wasn't associated with a device. # the acccess token wasn't associated with a device.
# Just delete the access token # Just delete the access token
access_token = get_access_token_from_request(request) access_token = self._auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token) yield self._auth_handler.delete_access_token(access_token)
else: else:
yield self._device_handler.delete_device( yield self._device_handler.delete_device(

View file

@ -21,7 +21,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.http.servlet import parse_json_value_from_request from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
@ -75,13 +75,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
before = request.args.get("before", None) before = parse_string(request, "before")
if before: if before:
before = _namespaced_rule_id(spec, before[0]) before = _namespaced_rule_id(spec, before)
after = request.args.get("after", None) after = parse_string(request, "after")
if after: if after:
after = _namespaced_rule_id(spec, after[0]) after = _namespaced_rule_id(spec, after)
try: try:
yield self.store.add_push_rule( yield self.store.add_push_rule(

View file

@ -21,6 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
@ -91,15 +92,11 @@ class PushersSetRestServlet(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
reqd = ['kind', 'app_id', 'app_display_name', assert_params_in_dict(
'device_display_name', 'pushkey', 'lang', 'data'] content,
missing = [] ['kind', 'app_id', 'app_display_name',
for i in reqd: 'device_display_name', 'pushkey', 'lang', 'data']
if i not in content: )
missing.append(i)
if len(missing):
raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content) logger.debug("Got pushers request with body: %r", content)
@ -148,7 +145,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>" SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs): def __init__(self, hs):
super(RestServlet, self).__init__() super(PushersRemoveRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.auth = hs.get_auth() self.auth = hs.get_auth()

View file

@ -28,6 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2, serialize_event from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.http.servlet import ( from synapse.http.servlet import (
assert_params_in_dict,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
@ -435,9 +436,9 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10, request, default_limit=10,
) )
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None) filter_bytes = parse_string(request, "filter")
if filter_bytes: if filter_bytes:
filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8") filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
event_filter = Filter(json.loads(filter_json)) event_filter = Filter(json.loads(filter_json))
else: else:
event_filter = None event_filter = None
@ -530,7 +531,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0]) limit = parse_integer(request, "limit", default=10)
results = yield self.handlers.room_context_handler.get_event_context( results = yield self.handlers.room_context_handler.get_event_context(
requester.user, requester.user,
@ -636,8 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
target = requester.user target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]: if membership_action in ["invite", "ban", "unban", "kick"]:
if "user_id" not in content: assert_params_in_dict(content, ["user_id"])
raise SynapseError(400, "Missing user_id key.")
target = UserID.from_string(content["user_id"]) target = UserID.from_string(content["user_id"])
event_content = None event_content = None
@ -764,7 +764,7 @@ class SearchRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
batch = request.args.get("next_batch", [None])[0] batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search( results = yield self.handlers.search_handler.search(
requester.user, requester.user,
content, content,
@ -832,10 +832,13 @@ def register_servlets(hs, http_server):
RoomSendEventRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server)
RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server)
RoomEventServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server):
RoomInitialSyncRestServlet(hs).register(http_server)

View file

@ -0,0 +1,3 @@
"""
REST APIs that are only used in v1 (the legacy API).
"""

View file

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
import re
from synapse.api.urls import CLIENT_PREFIX
def v1_only_client_path_patterns(path_regex, include_in_unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
list of SRE_Pattern
"""
patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
if include_in_unstable:
unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
return patterns

View file

@ -18,18 +18,16 @@ import hmac
import logging import logging
from hashlib import sha1 from hashlib import sha1
from six import string_types
from twisted.internet import defer from twisted.internet import defer
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.rest.client.v1.base import ClientV1RestServlet
from synapse.types import create_requester from synapse.types import create_requester
from .base import ClientV1RestServlet, client_path_patterns from .base import v1_only_client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,7 +50,7 @@ class RegisterRestServlet(ClientV1RestServlet):
handler doesn't have a concept of multi-stages or sessions. handler doesn't have a concept of multi-stages or sessions.
""" """
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) PATTERNS = v1_only_client_path_patterns("/register$", include_in_unstable=False)
def __init__(self, hs): def __init__(self, hs):
""" """
@ -67,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage # TODO: persistent storage
self.sessions = {} self.sessions = {}
self.enable_registration = hs.config.enable_registration self.enable_registration = hs.config.enable_registration
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
@ -124,8 +123,7 @@ class RegisterRestServlet(ClientV1RestServlet):
session = (register_json["session"] session = (register_json["session"]
if "session" in register_json else None) if "session" in register_json else None)
login_type = None login_type = None
if "type" not in register_json: assert_params_in_dict(register_json, ["type"])
raise SynapseError(400, "Missing 'type' key.")
try: try:
login_type = register_json["type"] login_type = register_json["type"]
@ -310,11 +308,9 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_app_service(self, request, register_json, session): def _do_app_service(self, request, register_json, session):
as_token = get_access_token_from_request(request) as_token = self.auth.get_access_token_from_request(request)
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
assert_params_in_dict(register_json, ["user"])
user_localpart = register_json["user"].encode("utf-8") user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
@ -331,12 +327,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session): def _do_shared_secret(self, request, register_json, session):
if not isinstance(register_json.get("mac", None), string_types): assert_params_in_dict(register_json, ["mac", "user", "password"])
raise SynapseError(400, "Expected mac.")
if not isinstance(register_json.get("user", None), string_types):
raise SynapseError(400, "Expected 'user' key.")
if not isinstance(register_json.get("password", None), string_types):
raise SynapseError(400, "Expected 'password' key.")
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled") raise SynapseError(400, "Shared secret registration is not enabled")
@ -389,7 +380,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface """Handles user creation via a server-to-server interface
""" """
PATTERNS = client_path_patterns("/createUser$", releases=()) PATTERNS = v1_only_client_path_patterns("/createUser$")
def __init__(self, hs): def __init__(self, hs):
super(CreateUserRestServlet, self).__init__(hs) super(CreateUserRestServlet, self).__init__(hs)
@ -400,7 +391,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
user_json = parse_json_object_from_request(request) user_json = parse_json_object_from_request(request)
access_token = get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
access_token access_token
) )
@ -419,11 +410,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_create(self, requester, user_json): def _do_create(self, requester, user_json):
if "localpart" not in user_json: assert_params_in_dict(user_json, ["localpart", "displayname"])
raise SynapseError(400, "Expected 'localpart' key.")
if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.")
localpart = user_json["localpart"].encode("utf-8") localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8")

View file

@ -20,12 +20,11 @@ from six.moves import http_client
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_request, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
@ -48,7 +47,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt' 'id_server', 'client_secret', 'email', 'send_attempt'
]) ])
@ -81,7 +80,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ assert_params_in_dict(body, [
'id_server', 'client_secret', 'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt', 'country', 'phone_number', 'send_attempt',
]) ])
@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet):
# #
# In the second case, we require a password to confirm their identity. # In the second case, we require a password to confirm their identity.
if has_access_token(request): if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth( params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request), requester, body, self.hs.get_ip_from_request(request),
@ -160,11 +159,10 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id user_id = threepid_user_id
else: else:
logger.error("Auth succeeded but no known type!", result.keys()) logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN) raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params: assert_params_in_dict(params, ["new_password"])
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password'] new_password = params['new_password']
yield self._set_password_handler.set_password( yield self._set_password_handler.set_password(
@ -229,15 +227,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(
required = ['id_server', 'client_secret', 'email', 'send_attempt'] body,
absent = [] ['id_server', 'client_secret', 'email', 'send_attempt'],
for k in required: )
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
if not check_3pid_allowed(self.hs, "email", body['email']): if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError( raise SynapseError(
@ -267,18 +260,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, [
required = [
'id_server', 'client_secret', 'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt', 'country', 'phone_number', 'send_attempt',
] ])
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
@ -373,15 +358,7 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ['medium', 'address'])
required = ['medium', 'address']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()

View file

@ -18,14 +18,18 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api import errors from synapse.api import errors
from synapse.http import servlet from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet): class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
@ -47,7 +51,7 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices})) defer.returnValue((200, {"devices": devices}))
class DeleteDevicesRestServlet(servlet.RestServlet): class DeleteDevicesRestServlet(RestServlet):
""" """
API for bulk deletion of devices. Accepts a JSON object with a devices API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth. key which lists the device_ids to delete. Requires user interactive auth.
@ -67,19 +71,17 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
try: try:
body = servlet.parse_json_object_from_request(request) body = parse_json_object_from_request(request)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# deal with older clients which didn't pass a J*DELETESON dict # DELETE
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = {}
else: else:
raise e raise e
if 'devices' not in body: assert_params_in_dict(body, ["devices"])
raise errors.SynapseError(
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)
yield self.auth_handler.validate_user_via_ui_auth( yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request), requester, body, self.hs.get_ip_from_request(request),
@ -92,7 +94,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class DeviceRestServlet(servlet.RestServlet): class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False) PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
@ -121,7 +123,7 @@ class DeviceRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
try: try:
body = servlet.parse_json_object_from_request(request) body = parse_json_object_from_request(request)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
@ -144,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
def on_PUT(self, request, device_id): def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
body = servlet.parse_json_object_from_request(request) body = parse_json_object_from_request(request)
yield self.device_handler.update_device( yield self.device_handler.update_device(
requester.user.to_string(), requester.user.to_string(),
device_id, device_id,

View file

@ -24,12 +24,11 @@ from twisted.internet import defer
import synapse import synapse
import synapse.types import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_request, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
@ -69,7 +68,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt' 'id_server', 'client_secret', 'email', 'send_attempt'
]) ])
@ -105,7 +104,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ assert_params_in_dict(body, [
'id_server', 'client_secret', 'id_server', 'client_secret',
'country', 'phone_number', 'country', 'phone_number',
'send_attempt', 'send_attempt',
@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username'] desired_username = body['username']
appservice = None appservice = None
if has_access_token(request): if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request) appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which # fork off as soon as possible for ASes and shared secret auth which
@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet):
# because the IRC bridges rely on being able to register stupid # because the IRC bridges rely on being able to register stupid
# IDs. # IDs.
access_token = get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types): if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration( result = yield self._do_appservice_registration(
@ -387,9 +386,7 @@ class RegisterRestServlet(RestServlet):
add_msisdn = False add_msisdn = False
else: else:
# NB: This may be from the auth handler and NOT from the POST # NB: This may be from the auth handler and NOT from the POST
if 'password' not in params: assert_params_in_dict(params, ["password"])
raise SynapseError(400, "Missing password.",
Codes.MISSING_PARAM)
desired_username = params.get("username", None) desired_username = params.get("username", None)
new_password = params.get("password", None) new_password = params.get("password", None)
@ -566,11 +563,14 @@ class RegisterRestServlet(RestServlet):
Returns: Returns:
defer.Deferred: defer.Deferred:
""" """
reqd = ('medium', 'address', 'validated_at') try:
if any(x not in threepid for x in reqd): assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
# This will only happen if the ID server returns a malformed response except SynapseError as ex:
logger.info("Can't add incomplete 3pid") if ex.errcode == Codes.MISSING_PARAM:
defer.returnValue() # This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
defer.returnValue(None)
raise
yield self.auth_handler.add_threepid( yield self.auth_handler.add_threepid(
user_id, user_id,
@ -643,7 +643,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_guest_registration(self, params): def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled")) raise SynapseError(403, "Guest access is disabled")
user_id, _ = yield self.registration_handler.register( user_id, _ = yield self.registration_handler.register(
generate_token=False, generate_token=False,
make_guest=True make_guest=True

View file

@ -15,9 +15,17 @@
import logging import logging
from six import string_types
from six.moves import http_client
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -42,12 +50,26 @@ class ReportEventRestServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("reason", "score"))
if not isinstance(body["reason"], string_types):
raise SynapseError(
http_client.BAD_REQUEST,
"Param 'reason' must be a string",
Codes.BAD_JSON,
)
if not isinstance(body["score"], int):
raise SynapseError(
http_client.BAD_REQUEST,
"Param 'score' must be an integer",
Codes.BAD_JSON,
)
yield self.store.add_event_report( yield self.store.add_event_report(
room_id=room_id, room_id=room_id,
event_id=event_id, event_id=event_id,
user_id=user_id, user_id=user_id,
reason=body.get("reason"), reason=body["reason"],
content=body, content=body,
received_ts=self.clock.time_msec(), received_ts=self.clock.time_msec(),
) )

View file

@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__() super(SendToDeviceRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock()) self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler() self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id): def on_PUT(self, request, message_type, txn_id):

View file

@ -16,6 +16,8 @@ from pydenticon import Generator
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.http.servlet import parse_integer
FOREGROUND = [ FOREGROUND = [
"rgb(45,79,255)", "rgb(45,79,255)",
"rgb(254,180,44)", "rgb(254,180,44)",
@ -56,8 +58,8 @@ class IdenticonResource(Resource):
def render_GET(self, request): def render_GET(self, request):
name = "/".join(request.postpath) name = "/".join(request.postpath)
width = int(request.args.get("width", [96])[0]) width = parse_integer(request, "width", default=96)
height = int(request.args.get("height", [96])[0]) height = parse_integer(request, "height", default=96)
identicon_bytes = self.generate_identicon(name, width, height) identicon_bytes = self.generate_identicon(name, width, height)
request.setHeader(b"Content-Type", b"image/png") request.setHeader(b"Content-Type", b"image/png")
request.setHeader( request.setHeader(

View file

@ -40,6 +40,7 @@ from synapse.http.server import (
respond_with_json_bytes, respond_with_json_bytes,
wrap_json_request_handler, wrap_json_request_handler,
) )
from synapse.http.servlet import parse_integer, parse_string
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -96,9 +97,9 @@ class PreviewUrlResource(Resource):
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
url = request.args.get("url")[0] url = parse_string(request, "url")
if "ts" in request.args: if "ts" in request.args:
ts = int(request.args.get("ts")[0]) ts = parse_integer(request, "ts")
else: else:
ts = self.clock.time_msec() ts = self.clock.time_msec()

View file

@ -21,6 +21,7 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_json, wrap_json_request_handler from synapse.http.server import respond_with_json, wrap_json_request_handler
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -65,10 +66,10 @@ class UploadResource(Resource):
code=413, code=413,
) )
upload_name = request.args.get("filename", None) upload_name = parse_string(request, "filename")
if upload_name: if upload_name:
try: try:
upload_name = upload_name[0].decode('UTF-8') upload_name = upload_name.decode('UTF-8')
except UnicodeDecodeError: except UnicodeDecodeError:
raise SynapseError( raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), msg="Invalid UTF-8 filename parameter: %r" % (upload_name),

42
synapse/secrets.py Normal file
View file

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Injectable secrets module for Synapse.
See https://docs.python.org/3/library/secrets.html#module-secrets for the API
used in Python 3.6, and the API emulated in Python 2.7.
"""
import six
if six.PY3:
import secrets
def Secrets():
return secrets
else:
import os
import binascii
class Secrets(object):
def token_bytes(self, nbytes=32):
return os.urandom(nbytes)
def token_hex(self, nbytes=32):
return binascii.hexlify(self.token_bytes(nbytes))

View file

@ -74,6 +74,7 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository, MediaRepository,
MediaRepositoryResource, MediaRepositoryResource,
) )
from synapse.secrets import Secrets
from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_manager import ServerNoticesManager
from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.server_notices_sender import ServerNoticesSender
from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
@ -158,6 +159,7 @@ class HomeServer(object):
'groups_server_handler', 'groups_server_handler',
'groups_attestation_signing', 'groups_attestation_signing',
'groups_attestation_renewer', 'groups_attestation_renewer',
'secrets',
'spam_checker', 'spam_checker',
'room_member_handler', 'room_member_handler',
'federation_registry', 'federation_registry',
@ -405,6 +407,9 @@ class HomeServer(object):
def build_groups_attestation_renewer(self): def build_groups_attestation_renewer(self):
return GroupAttestionRenewer(self) return GroupAttestionRenewer(self)
def build_secrets(self):
return Secrets()
def build_spam_checker(self): def build_spam_checker(self):
return SpamChecker(self) return SpamChecker(self)

View file

@ -18,7 +18,7 @@ import hashlib
import logging import logging
from collections import namedtuple from collections import namedtuple
from six import iteritems, itervalues from six import iteritems, iterkeys, itervalues
from frozendict import frozendict from frozendict import frozendict
@ -647,7 +647,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
for event_id in event_ids for event_id in event_ids
) )
if event_map is not None: if event_map is not None:
needed_events -= set(event_map.iterkeys()) needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d conflicted events", len(needed_events)) logger.info("Asking for %d conflicted events", len(needed_events))
@ -668,7 +668,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
new_needed_events = set(itervalues(auth_events)) new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events new_needed_events -= needed_events
if event_map is not None: if event_map is not None:
new_needed_events -= set(event_map.iterkeys()) new_needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d auth events", len(new_needed_events)) logger.info("Asking for %d auth events", len(new_needed_events))

View file

@ -344,7 +344,7 @@ class SQLBaseStore(object):
parent_context = LoggingContext.current_context() parent_context = LoggingContext.current_context()
if parent_context == LoggingContext.sentinel: if parent_context == LoggingContext.sentinel:
logger.warn( logger.warn(
"Running db txn from sentinel context: metrics will be lost", "Starting db connection from sentinel context: metrics will be lost",
) )
parent_context = None parent_context = None

View file

@ -19,6 +19,8 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines from . import engines
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -87,10 +89,14 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_handlers = {} self._background_update_handlers = {}
self._all_done = False self._all_done = False
@defer.inlineCallbacks
def start_doing_background_updates(self): def start_doing_background_updates(self):
logger.info("Starting background schema updates") run_as_background_process(
"background_updates", self._run_background_updates,
)
@defer.inlineCallbacks
def _run_background_updates(self):
logger.info("Starting background schema updates")
while True: while True:
yield self.hs.get_clock().sleep( yield self.hs.get_clock().sleep(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)

View file

@ -19,6 +19,7 @@ from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches import CACHE_SIZE_FACTOR
from . import background_updates from . import background_updates
@ -93,10 +94,16 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now) self._batch_row_update[key] = (user_agent, device_id, now)
def _update_client_ips_batch(self): def _update_client_ips_batch(self):
to_update = self._batch_row_update def update():
self._batch_row_update = {} to_update = self._batch_row_update
return self.runInteraction( self._batch_row_update = {}
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update return self.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn,
to_update,
)
run_as_background_process(
"update_client_ips", update,
) )
def _update_client_ips_batch_txn(self, txn, to_update): def _update_client_ips_batch_txn(self, txn, to_update):

View file

@ -33,12 +33,13 @@ from synapse.api.errors import SynapseError
# these are only included to make the type annotations work # these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -155,11 +156,8 @@ class _EventPeristenceQueue(object):
self._event_persist_queues[room_id] = queue self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id) self._currently_persisting_rooms.discard(room_id)
# set handle_queue_loop off on the background. We don't want to # set handle_queue_loop off in the background
# attribute work done in it to the current request, so we drop the run_as_background_process("persist_events", handle_queue_loop)
# logcontext altogether.
with PreserveLoggingContext():
handle_queue_loop()
def _get_drainining_queue(self, room_id): def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque()) queue = self._event_persist_queues.setdefault(room_id, deque())

View file

@ -25,6 +25,7 @@ from synapse.events import EventBase # noqa: F401
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import ( from synapse.util.logcontext import (
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
@ -322,10 +323,11 @@ class EventsWorkerStore(SQLBaseStore):
should_start = False should_start = False
if should_start: if should_start:
with PreserveLoggingContext(): run_as_background_process(
self.runWithConnection( "fetch_events",
self._do_fetch self.runWithConnection,
) self._do_fetch,
)
logger.debug("Loading %d events", len(events)) logger.debug("Loading %d events", len(events))
with PreserveLoggingContext(): with PreserveLoggingContext():

View file

@ -140,7 +140,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
""" """
room_ids = set(room_ids) room_ids = set(room_ids)
if from_key: if from_key is not None:
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
room_ids = yield self._receipts_stream_cache.get_entities_changed( room_ids = yield self._receipts_stream_cache.get_entities_changed(
room_ids, from_key room_ids, from_key
) )
@ -151,7 +153,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res]) defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.
@ -162,7 +163,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
from the start. from the start.
Returns: Returns:
list: A list of receipts. Deferred[list]: A list of receipts.
"""
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
defer.succeed([])
return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cachedInlineCallbacks(num_args=3, tree=True)
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""See get_linearized_receipts_for_room
""" """
def f(txn): def f(txn):
if from_key: if from_key:
@ -211,7 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"content": content, "content": content,
}]) }])
@cachedList(cached_method_name="get_linearized_receipts_for_room", @cachedList(cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids", num_args=3, inlineCallbacks=True) list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids: if not room_ids:
@ -373,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type) self.get_receipts_for_user.invalidate, (user_id, receipt_type)
) )
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
txn.call_after( txn.call_after(
self._receipts_stream_cache.entity_has_changed, self._receipts_stream_cache.entity_has_changed,
@ -493,7 +506,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type) self.get_receipts_for_user.invalidate, (user_id, receipt_type)
) )
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,)) txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,

View file

@ -16,6 +16,7 @@
import logging import logging
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.types import StreamToken from synapse.types import StreamToken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,23 +57,10 @@ class PaginationConfig(object):
@classmethod @classmethod
def from_request(cls, request, raise_invalid_params=True, def from_request(cls, request, raise_invalid_params=True,
default_limit=None): default_limit=None):
def get_param(name, default=None): direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
lst = request.args.get(name, [])
if len(lst) > 1:
raise SynapseError(
400, "%s must be specified only once" % (name,)
)
elif len(lst) == 1:
return lst[0]
else:
return default
direction = get_param("dir", 'f') from_tok = parse_string(request, "from")
if direction not in ['f', 'b']: to_tok = parse_string(request, "to")
raise SynapseError(400, "'dir' parameter is invalid.")
from_tok = get_param("from")
to_tok = get_param("to")
try: try:
if from_tok == "END": if from_tok == "END":
@ -88,12 +76,10 @@ class PaginationConfig(object):
except Exception: except Exception:
raise SynapseError(400, "'to' paramater is invalid") raise SynapseError(400, "'to' paramater is invalid")
limit = get_param("limit", None) limit = parse_integer(request, "limit", default=default_limit)
if limit is not None and not limit.isdigit():
raise SynapseError(400, "'limit' parameter must be an integer.")
if limit is None: if limit and limit < 0:
limit = default_limit raise SynapseError(400, "Limit must be 0 or above")
try: try:
return PaginationConfig(from_tok, to_tok, direction, limit) return PaginationConfig(from_tok, to_tok, direction, limit)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
@ -156,54 +157,72 @@ def concurrently_execute(func, args, limit):
class Linearizer(object): class Linearizer(object):
"""Linearizes access to resources based on a key. Useful to ensure only one """Limits concurrent access to resources based on a key. Useful to ensure
thing is happening at a time on a given resource. only a few things happen at a time on a given resource.
Example: Example:
with (yield linearizer.queue("test_key")): with (yield limiter.queue("test_key")):
# do some work. # do some work.
""" """
def __init__(self, name=None, clock=None): def __init__(self, name=None, max_count=1, clock=None):
"""
Args:
max_count(int): The maximum number of concurrent accesses
"""
if name is None: if name is None:
self.name = id(self) self.name = id(self)
else: else:
self.name = name self.name = name
self.key_to_defer = {}
if not clock: if not clock:
from twisted.internet import reactor from twisted.internet import reactor
clock = Clock(reactor) clock = Clock(reactor)
self._clock = clock self._clock = clock
self.max_count = max_count
# key_to_defer is a map from the key to a 2 element list where
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
self.key_to_defer = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def queue(self, key): def queue(self, key):
# If there is already a deferred in the queue, we pull it out so that entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
# we can wait on it later.
# Then we replace it with a deferred that we resolve *after* the
# context manager has exited.
# We only return the context manager after the previous deferred has
# resolved.
# This all has the net effect of creating a chain of deferreds that
# wait for the previous deferred before starting their work.
current_defer = self.key_to_defer.get(key)
new_defer = defer.Deferred() # If the number of things executing is greater than the maximum
self.key_to_defer[key] = new_defer # then add a deferred to the list of blocked items
# When on of the things currently executing finishes it will callback
# this item so that it can continue executing.
if entry[0] >= self.max_count:
new_defer = defer.Deferred()
entry[1][new_defer] = 1
if current_defer:
logger.info( logger.info(
"Waiting to acquire linearizer lock %r for key %r", self.name, key "Waiting to acquire linearizer lock %r for key %r", self.name, key,
) )
try: try:
with PreserveLoggingContext(): yield make_deferred_yieldable(new_defer)
yield current_defer except Exception as e:
except Exception: if isinstance(e, CancelledError):
logger.exception("Unexpected exception in Linearizer") logger.info(
"Cancelling wait for linearizer lock %r for key %r",
self.name, key,
)
else:
logger.warn(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name, key,
)
logger.info("Acquired linearizer lock %r for key %r", self.name, # we just have to take ourselves back out of the queue.
key) del entry[1][new_defer]
raise
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
entry[0] += 1
# if the code holding the lock completes synchronously, then it # if the code holding the lock completes synchronously, then it
# will recursively run the next claimant on the list. That can # will recursively run the next claimant on the list. That can
@ -213,15 +232,15 @@ class Linearizer(object):
# In order to break the cycle, we add a cheeky sleep(0) here to # In order to break the cycle, we add a cheeky sleep(0) here to
# ensure that we fall back to the reactor between each iteration. # ensure that we fall back to the reactor between each iteration.
# #
# (There's no particular need for it to happen before we return # (This needs to happen while we hold the lock, and the context manager's exit
# the context manager, but it needs to happen while we hold the # code must be synchronous, so this is the only sensible place.)
# lock, and the context manager's exit code must be synchronous,
# so actually this is the only sensible place.
yield self._clock.sleep(0) yield self._clock.sleep(0)
else: else:
logger.info("Acquired uncontended linearizer lock %r for key %r", logger.info(
self.name, key) "Acquired uncontended linearizer lock %r for key %r", self.name, key,
)
entry[0] += 1
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():
@ -229,73 +248,15 @@ class Linearizer(object):
yield yield
finally: finally:
logger.info("Releasing linearizer lock %r for key %r", self.name, key) logger.info("Releasing linearizer lock %r for key %r", self.name, key)
with PreserveLoggingContext():
new_defer.callback(None)
current_d = self.key_to_defer.get(key)
if current_d is new_defer:
self.key_to_defer.pop(key, None)
defer.returnValue(_ctx_manager())
class Limiter(object):
"""Limits concurrent access to resources based on a key. Useful to ensure
only a few thing happen at a time on a given resource.
Example:
with (yield limiter.queue("test_key")):
# do some work.
"""
def __init__(self, max_count):
"""
Args:
max_count(int): The maximum number of concurrent access
"""
self.max_count = max_count
# key_to_defer is a map from the key to a 2 element list where
# the first element is the number of things executing
# the second element is a list of deferreds for the things blocked from
# executing.
self.key_to_defer = {}
@defer.inlineCallbacks
def queue(self, key):
entry = self.key_to_defer.setdefault(key, [0, []])
# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
# When on of the things currently executing finishes it will callback
# this item so that it can continue executing.
if entry[0] >= self.max_count:
new_defer = defer.Deferred()
entry[1].append(new_defer)
logger.info("Waiting to acquire limiter lock for key %r", key)
with PreserveLoggingContext():
yield new_defer
logger.info("Acquired limiter lock for key %r", key)
else:
logger.info("Acquired uncontended limiter lock for key %r", key)
entry[0] += 1
@contextmanager
def _ctx_manager():
try:
yield
finally:
logger.info("Releasing limiter lock for key %r", key)
# We've finished executing so check if there are any things # We've finished executing so check if there are any things
# blocked waiting to execute and start one of them # blocked waiting to execute and start one of them
entry[0] -= 1 entry[0] -= 1
if entry[1]: if entry[1]:
next_def = entry[1].pop(0) (next_def, _) = entry[1].popitem(last=False)
# we need to run the next thing in the sentinel context.
with PreserveLoggingContext(): with PreserveLoggingContext():
next_def.callback(None) next_def.callback(None)
elif entry[0] == 0: elif entry[0] == 0:

View file

@ -16,6 +16,7 @@
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,7 +64,10 @@ class ExpiringCache(object):
return return
def f(): def f():
self._prune_cache() run_as_background_process(
"prune_cache_%s" % self._cache_name,
self._prune_cache,
)
self._clock.looping_call(f, self._expiry_ms / 2) self._clock.looping_call(f, self._expiry_ms / 2)

View file

@ -74,14 +74,13 @@ class StreamChangeCache(object):
assert type(stream_pos) is int assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos: if stream_pos >= self._earliest_known_stream_pos:
not_known_entities = set(entities) - set(self._entity_to_key) changed_entities = {
self._cache[k] for k in self._cache.islice(
start=self._cache.bisect_right(stream_pos),
)
}
result = ( result = changed_entities.intersection(entities)
{self._cache[k] for k in self._cache.islice(
start=self._cache.bisect_right(stream_pos))}
.intersection(entities)
.union(not_known_entities)
)
self.metrics.inc_hits() self.metrics.inc_hits()
else: else:

View file

@ -17,20 +17,18 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.util import unwrapFirstError from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def user_left_room(distributor, user, room_id): def user_left_room(distributor, user, room_id):
with PreserveLoggingContext(): distributor.fire("user_left_room", user=user, room_id=room_id)
distributor.fire("user_left_room", user=user, room_id=room_id)
def user_joined_room(distributor, user, room_id): def user_joined_room(distributor, user, room_id):
with PreserveLoggingContext(): distributor.fire("user_joined_room", user=user, room_id=room_id)
distributor.fire("user_joined_room", user=user, room_id=room_id)
class Distributor(object): class Distributor(object):
@ -44,9 +42,7 @@ class Distributor(object):
model will do for today. model will do for today.
""" """
def __init__(self, suppress_failures=True): def __init__(self):
self.suppress_failures = suppress_failures
self.signals = {} self.signals = {}
self.pre_registration = {} self.pre_registration = {}
@ -56,7 +52,6 @@ class Distributor(object):
self.signals[name] = Signal( self.signals[name] = Signal(
name, name,
suppress_failures=self.suppress_failures,
) )
if name in self.pre_registration: if name in self.pre_registration:
@ -75,10 +70,18 @@ class Distributor(object):
self.pre_registration[name].append(observer) self.pre_registration[name].append(observer)
def fire(self, name, *args, **kwargs): def fire(self, name, *args, **kwargs):
"""Dispatches the given signal to the registered observers.
Runs the observers as a background process. Does not return a deferred.
"""
if name not in self.signals: if name not in self.signals:
raise KeyError("%r does not have a signal named %s" % (self, name)) raise KeyError("%r does not have a signal named %s" % (self, name))
return self.signals[name].fire(*args, **kwargs) run_as_background_process(
name,
self.signals[name].fire,
*args, **kwargs
)
class Signal(object): class Signal(object):
@ -91,9 +94,8 @@ class Signal(object):
method into all of the observers. method into all of the observers.
""" """
def __init__(self, name, suppress_failures): def __init__(self, name):
self.name = name self.name = name
self.suppress_failures = suppress_failures
self.observers = [] self.observers = []
def observe(self, observer): def observe(self, observer):
@ -103,7 +105,6 @@ class Signal(object):
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -121,22 +122,17 @@ class Signal(object):
failure.type, failure.type,
failure.value, failure.value,
failure.getTracebackObject())) failure.getTracebackObject()))
if not self.suppress_failures:
return failure
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext(): deferreds = [
deferreds = [ run_in_background(do, o)
do(observer) for o in self.observers
for observer in self.observers ]
]
res = yield defer.gatherResults( return make_deferred_yieldable(defer.gatherResults(
deferreds, consumeErrors=True deferreds, consumeErrors=True,
).addErrback(unwrapFirstError) ))
defer.returnValue(res)
def __repr__(self): def __repr__(self):
return "<Signal name=%r>" % (self.name,) return "<Signal name=%r>" % (self.name,)

View file

@ -99,6 +99,17 @@ class ContextResourceUsage(object):
self.db_sched_duration_sec = 0 self.db_sched_duration_sec = 0
self.evt_db_fetch_count = 0 self.evt_db_fetch_count = 0
def __repr__(self):
return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', "
"db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
self.ru_stime,
self.ru_utime,
self.db_txn_count,
self.db_txn_duration_sec,
self.db_sched_duration_sec,
self.evt_db_fetch_count,)
def __iadd__(self, other): def __iadd__(self, other):
"""Add another ContextResourceUsage's stats to this one's. """Add another ContextResourceUsage's stats to this one's.

View file

@ -104,12 +104,19 @@ class Measure(object):
logger.warn("Expected context. (%r)", self.name) logger.warn("Expected context. (%r)", self.name)
return return
usage = context.get_resource_usage() - self.start_usage current = context.get_resource_usage()
block_ru_utime.labels(self.name).inc(usage.ru_utime) usage = current - self.start_usage
block_ru_stime.labels(self.name).inc(usage.ru_stime) try:
block_db_txn_count.labels(self.name).inc(usage.db_txn_count) block_ru_utime.labels(self.name).inc(usage.ru_utime)
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_ru_stime.labels(self.name).inc(usage.ru_stime)
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
except ValueError:
logger.warn(
"Failed to save metrics! OLD: %r, NEW: %r",
self.start_usage, current
)
if self.created_context: if self.created_context:
self.start_context.__exit__(exc_type, exc_val, exc_tb) self.start_context.__exit__(exc_type, exc_val, exc_tb)

View file

@ -92,13 +92,22 @@ class _PerHostRatelimiter(object):
self.window_size = window_size self.window_size = window_size
self.sleep_limit = sleep_limit self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec self.sleep_sec = sleep_msec / 1000.0
self.reject_limit = reject_limit self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests self.concurrent_requests = concurrent_requests
# request_id objects for requests which have been slept
self.sleeping_requests = set() self.sleeping_requests = set()
# map from request_id object to Deferred for requests which are ready
# for processing but have been queued
self.ready_request_queue = collections.OrderedDict() self.ready_request_queue = collections.OrderedDict()
# request id objects for requests which are in progress
self.current_processing = set() self.current_processing = set()
# times at which we have recently (within the last window_size ms)
# received requests.
self.request_times = [] self.request_times = []
@contextlib.contextmanager @contextlib.contextmanager
@ -117,11 +126,15 @@ class _PerHostRatelimiter(object):
def _on_enter(self, request_id): def _on_enter(self, request_id):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# remove any entries from request_times which aren't within the window
self.request_times[:] = [ self.request_times[:] = [
r for r in self.request_times r for r in self.request_times
if time_now - r < self.window_size if time_now - r < self.window_size
] ]
# reject the request if we already have too many queued up (either
# sleeping or in the ready queue).
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit: if queue_size > self.reject_limit:
raise LimitExceededError( raise LimitExceededError(
@ -134,9 +147,13 @@ class _PerHostRatelimiter(object):
def queue_request(): def queue_request():
if len(self.current_processing) > self.concurrent_requests: if len(self.current_processing) > self.concurrent_requests:
logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred() queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer self.ready_request_queue[request_id] = queue_defer
logger.info(
"Ratelimiter: queueing request (queue now %i items)",
len(self.ready_request_queue),
)
return queue_defer return queue_defer
else: else:
return defer.succeed(None) return defer.succeed(None)
@ -148,10 +165,9 @@ class _PerHostRatelimiter(object):
if len(self.request_times) > self.sleep_limit: if len(self.request_times) > self.sleep_limit:
logger.debug( logger.debug(
"Ratelimit [%s]: sleeping req", "Ratelimiter: sleeping request for %f sec", self.sleep_sec,
id(request_id),
) )
ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0) ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id) self.sleeping_requests.add(request_id)
@ -200,11 +216,8 @@ class _PerHostRatelimiter(object):
) )
self.current_processing.discard(request_id) self.current_processing.discard(request_id)
try: try:
request_id, deferred = self.ready_request_queue.popitem() # start processing the next item on the queue.
_, deferred = self.ready_request_queue.popitem(last=False)
# XXX: why do we do the following? the on_start callback above will
# do it for us.
self.current_processing.add(request_id)
with PreserveLoggingContext(): with PreserveLoggingContext():
deferred.callback(None) deferred.callback(None)

View file

@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools
import logging import logging
import operator import operator
from six import iteritems, itervalues
from six.moves import map
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.types import get_domain_from_id
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -218,10 +222,161 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
return event return event
# check each event: gives an iterable[None|EventBase] # check each event: gives an iterable[None|EventBase]
filtered_events = itertools.imap(allowed, events) filtered_events = map(allowed, events)
# remove the None entries # remove the None entries
filtered_events = filter(operator.truth, filtered_events) filtered_events = filter(operator.truth, filtered_events)
# we turn it into a list before returning it. # we turn it into a list before returning it.
defer.returnValue(list(filtered_events)) defer.returnValue(list(filtered_events))
@defer.inlineCallbacks
def filter_events_for_server(store, server_name, events):
# Whatever else we do, we need to check for senders which have requested
# erasure of their data.
erased_senders = yield store.are_users_erased(
e.sender for e in events,
)
def redact_disallowed(event, state):
# if the sender has been gdpr17ed, always return a redacted
# copy of the event.
if erased_senders[event.sender]:
logger.info(
"Sender of %s has been erased, redacting",
event.event_id,
)
return prune_event(event)
# state will be None if we decided we didn't need to filter by
# room membership.
if not state:
return event
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
for ev in itervalues(state):
if ev.type != EventTypes.Member:
continue
try:
domain = get_domain_from_id(ev.state_key)
except Exception:
continue
if domain != server_name:
continue
memtype = ev.membership
if memtype == Membership.JOIN:
return event
elif memtype == Membership.INVITE:
if visibility == "invited":
return event
else:
# server has no users in the room: redact
return prune_event(event)
return event
# Next lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
)
)
visibility_ids = set()
for sids in itervalues(event_to_state_ids):
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist:
visibility_ids.add(hist)
# If we failed to find any history visibility events then the default
# is "shared" visiblity.
if not visibility_ids:
all_open = True
else:
event_map = yield store.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in itervalues(event_map)
)
if all_open:
# all the history_visibility state affecting these events is open, so
# we don't need to filter by membership state. We *do* need to check
# for user erasure, though.
if erased_senders:
events = [
redact_disallowed(e, None)
for e in events
]
defer.returnValue(events)
# Ok, so we're dealing with events that have non-trivial visibility
# rules, so we need to also get the memberships of the room.
# first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events.
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
# We only want to pull out member events that correspond to the
# server's domain.
#
# event_to_state_ids contains lots of duplicates, so it turns out to be
# cheaper to build a complete set of unique
# ((type, state_key), event_id) tuples, and then filter out the ones we
# don't want.
#
state_key_to_event_id_set = {
e
for key_to_eid in itervalues(event_to_state_ids)
for e in key_to_eid.items()
}
def include(typ, state_key):
if typ != EventTypes.Member:
return True
# we avoid using get_domain_from_id here for efficiency.
idx = state_key.find(":")
if idx == -1:
return False
return state_key[idx + 1:] == server_name
event_map = yield store.get_events([
e_id
for key, e_id in state_key_to_event_id_set
if include(key[0], key[1])
])
event_to_state = {
e_id: {
key: event_map[inner_e_id]
for key, inner_e_id in iteritems(key_to_eid)
if inner_e_id in event_map
}
for e_id, key_to_eid in iteritems(event_to_state_ids)
}
defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
for e in events
])

View file

@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
self.cache = HttpTransactionCache(self.clock) self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (200, "GOOD JOB!") self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo" self.mock_key = "foo"

View file

@ -0,0 +1,305 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import hmac
import json
from mock import Mock
from synapse.http.server import JsonResource
from synapse.rest.client.v1.admin import register_servlets
from synapse.util import Clock
from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
class UserRegisterTestCase(unittest.TestCase):
def setUp(self):
self.clock = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/admin/register"
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
self.device_handler = Mock()
self.device_handler.check_device_registered = Mock(return_value="FAKE")
self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[])
self.secrets = Mock()
self.hs = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.hs.config.registration_shared_secret = u"shared"
self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()
self.resource = JsonResource(self.hs)
register_servlets(self.hs, self.resource)
def test_disabled(self):
"""
If there is no shared secret, registration through this method will be
prevented.
"""
self.hs.config.registration_shared_secret = None
request, channel = make_request("POST", self.url, b'{}')
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
'Shared secret registration is not enabled', channel.json_body["error"]
)
def test_get_nonce(self):
"""
Calling GET on the endpoint will return a randomised nonce, using the
homeserver's secrets provider.
"""
secrets = Mock()
secrets.token_hex = Mock(return_value="abcd")
self.hs.get_secrets = Mock(return_value=secrets)
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
def test_expired_nonce(self):
"""
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
# 59 seconds
self.clock.advance(59)
body = json.dumps({"nonce": nonce})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# 61 seconds
self.clock.advance(2)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
def test_register_incorrect_nonce(self):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
).encode('utf8')
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self):
"""
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
).encode('utf8')
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self):
"""
A valid unrecognised nonce.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
{
"nonce": nonce,
"username": "bob",
"password": "abc123",
"admin": True,
"mac": want_mac,
}
).encode('utf8')
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
def test_missing_parts(self):
"""
Synapse will complain if you don't give nonce, username, password, and
mac. Admin is optional. Additional checks are done for length and
type.
"""
def nonce():
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
return channel.json_body["nonce"]
#
# Nonce check
#
# Must be present
body = json.dumps({})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('nonce must be specified', channel.json_body["error"])
#
# Username checks
#
# Must be present
body = json.dumps({"nonce": nonce()})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
#
# Username checks
#
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('password must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])

View file

@ -14,100 +14,30 @@
# limitations under the License. # limitations under the License.
""" Tests REST events for /events paths.""" """ Tests REST events for /events paths."""
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
from six import PY3
# twisted imports
from twisted.internet import defer from twisted.internet import defer
import synapse.rest.client.v1.events
import synapse.rest.client.v1.register
import synapse.rest.client.v1.room
from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver from ....utils import MockHttpResource, setup_test_homeserver
from .utils import RestTestCase from .utils import RestTestCase
PATH_PREFIX = "/_matrix/client/api/v1" PATH_PREFIX = "/_matrix/client/api/v1"
class EventStreamPaginationApiTestCase(unittest.TestCase):
""" Tests event streaming query parameters and start/end keys used in the
Pagination stream API. """
user_id = "sid1"
def setUp(self):
# configure stream and inject items
pass
def tearDown(self):
pass
def TODO_test_long_poll(self):
# stream from 'end' key, send (self+other) message, expect message.
# stream from 'END', send (self+other) message, expect message.
# stream from 'end' key, send (self+other) topic, expect topic.
# stream from 'END', send (self+other) topic, expect topic.
# stream from 'end' key, send (self+other) invite, expect invite.
# stream from 'END', send (self+other) invite, expect invite.
pass
def TODO_test_stream_forward(self):
# stream from START, expect injected items
# stream from 'start' key, expect same content
# stream from 'end' key, expect nothing
# stream from 'END', expect nothing
# The following is needed for cases where content is removed e.g. you
# left a room, so the token you're streaming from is > the one that
# would be returned naturally from START>END.
# stream from very new token (higher than end key), expect same token
# returned as end key
pass
def TODO_test_limits(self):
# stream from a key, expect limit_num items
# stream from START, expect limit_num items
pass
def TODO_test_range(self):
# stream from key to key, expect X items
# stream from key to END, expect X items
# stream from START to key, expect X items
# stream from START to END, expect all items
pass
def TODO_test_direction(self):
# stream from END to START and fwds, expect newest first
# stream from END to START and bwds, expect oldest first
# stream from START to END and fwds, expect oldest first
# stream from START to END and bwds, expect newest first
pass
class EventStreamPermissionsTestCase(RestTestCase): class EventStreamPermissionsTestCase(RestTestCase):
""" Tests event streaming (GET /events). """ """ Tests event streaming (GET /events). """
if PY3:
skip = "Skip on Py3 until ported to use not V1 only register."
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
import synapse.rest.client.v1.events
import synapse.rest.client.v1_only.register
import synapse.rest.client.v1.room
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
@ -125,7 +55,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource) synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource)
synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource)
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)

View file

@ -16,27 +16,26 @@
import json import json
from mock import Mock from mock import Mock
from six import PY3
from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactorClock
from synapse.rest.client.v1.register import CreateUserRestServlet from synapse.http.server import JsonResource
from synapse.rest.client.v1_only.register import register_servlets
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.utils import mock_getRawHeaders from tests.server import make_request, setup_test_homeserver
class CreateUserServletTestCase(unittest.TestCase): class CreateUserServletTestCase(unittest.TestCase):
"""
Tests for CreateUserRestServlet.
"""
if PY3:
skip = "Not ported to Python 3."
def setUp(self): def setUp(self):
# do the dance to hook up request data to self.request_data
self.request_data = ""
self.request = Mock(
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
path='/_matrix/client/api/v1/createUser'
)
self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.registration_handler = Mock() self.registration_handler = Mock()
self.appservice = Mock(sender="@as:test") self.appservice = Mock(sender="@as:test")
@ -44,39 +43,49 @@ class CreateUserServletTestCase(unittest.TestCase):
get_app_service_by_token=Mock(return_value=self.appservice) get_app_service_by_token=Mock(return_value=self.appservice)
) )
# do the dance to hook things up to the hs global handlers = Mock(registration_handler=self.registration_handler)
handlers = Mock( self.clock = MemoryReactorClock()
registration_handler=self.registration_handler, self.hs_clock = Clock(self.clock)
self.hs = self.hs = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.clock
) )
self.hs = Mock()
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=handlers) self.hs.get_handlers = Mock(return_value=handlers)
self.servlet = CreateUserRestServlet(self.hs)
@defer.inlineCallbacks
def test_POST_createuser_with_valid_user(self): def test_POST_createuser_with_valid_user(self):
res = JsonResource(self.hs)
register_servlets(self.hs, res)
request_data = json.dumps(
{
"localpart": "someone",
"displayname": "someone interesting",
"duration_seconds": 200,
}
)
url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service'
user_id = "@someone:interesting" user_id = "@someone:interesting"
token = "my token" token = "my token"
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"localpart": "someone",
"displayname": "someone interesting",
"duration_seconds": 200
})
self.registration_handler.get_or_create_user = Mock( self.registration_handler.get_or_create_user = Mock(
return_value=(user_id, token) return_value=(user_id, token)
) )
(code, result) = yield self.servlet.on_POST(self.request) request, channel = make_request(b"POST", url, request_data)
self.assertEquals(code, 200) request.render(res)
# Advance the clock because it waits
self.clock.advance(1)
self.assertEquals(channel.result["code"], b"200")
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"home_server": self.hs.hostname "home_server": self.hs.hostname,
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))

File diff suppressed because it is too large Load diff

View file

@ -16,13 +16,14 @@
import json import json
import time import time
# twisted imports import attr
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
# trial imports
from tests import unittest from tests import unittest
from tests.server import make_request, wait_until_result
class RestTestCase(unittest.TestCase): class RestTestCase(unittest.TestCase):
@ -133,3 +134,113 @@ class RestTestCase(unittest.TestCase):
for key in required: for key in required:
self.assertEquals(required[key], actual[key], self.assertEquals(required[key], actual[key],
msg="%s mismatch. %s" % (key, actual)) msg="%s mismatch. %s" % (key, actual))
@attr.s
class RestHelper(object):
"""Contains extra helper functions to quickly and clearly perform a given
REST action, which isn't the focus of the test.
"""
hs = attr.ib()
resource = attr.ib()
auth_user_id = attr.ib()
def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = b"/_matrix/client/r0/createRoom"
content = {}
if not is_public:
content["visibility"] = "private"
if tok:
path = path + b"?access_token=%s" % tok.encode('ascii')
request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8'))
request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel)
assert channel.result["code"] == b"200", channel.result
self.auth_user_id = temp_id
return channel.json_body["room_id"]
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
room=room,
src=src,
targ=targ,
tok=tok,
membership=Membership.INVITE,
expect_code=expect_code,
)
def join(self, room=None, user=None, expect_code=200, tok=None):
self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
membership=Membership.JOIN,
expect_code=expect_code,
)
def leave(self, room=None, user=None, expect_code=200, tok=None):
self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
membership=Membership.LEAVE,
expect_code=expect_code,
)
def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
temp_id = self.auth_user_id
self.auth_user_id = src
path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
if tok:
path = path + "?access_token=%s" % tok
data = {"membership": membership}
request, channel = make_request(
b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8')
)
request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel)
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
% (expect_code, int(channel.result["code"]), channel.result["body"])
)
self.auth_user_id = temp_id
@defer.inlineCallbacks
def register(self, user_id):
(code, response) = yield self.mock_resource.trigger(
"POST",
"/_matrix/client/r0/register",
json.dumps(
{"user": user_id, "password": "test", "type": "m.login.password"}
),
)
self.assertEquals(200, code)
defer.returnValue(response)
@defer.inlineCallbacks
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
if body is None:
body = "body_text_here"
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = '{"msgtype":"m.text","body":"%s"}' % body
if tok:
path = path + "?access_token=%s" % tok
(code, response) = yield self.mock_resource.trigger("PUT", path, content)
self.assertEquals(expect_code, code, msg=str(response))

View file

@ -1,61 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet import defer
from synapse.types import UserID
from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
PATH_PREFIX = "/_matrix/client/v2_alpha"
class V2AlphaRestTestCase(unittest.TestCase):
# Consumer must define
# USER_ID = <some string>
# TO_REGISTER = [<list of REST servlets to register>]
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver(
datastore=self.make_datastore_mock(),
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
)
def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
"is_guest": False,
}
hs.get_auth().get_user_by_access_token = get_user_by_access_token
for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource)
def make_datastore_mock(self):
store = Mock(spec=[
"insert_client_ip",
])
store.get_app_service_by_token = Mock(return_value=None)
return store

View file

@ -13,35 +13,37 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import synapse.types import synapse.types
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import filter from synapse.rest.client.v2_alpha import filter
from synapse.types import UserID from synapse.types import UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import (
from ....utils import MockHttpResource, setup_test_homeserver ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
setup_test_homeserver,
wait_until_result,
)
PATH_PREFIX = "/_matrix/client/v2_alpha" PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase): class FilterTestCase(unittest.TestCase):
USER_ID = "@apple:test" USER_ID = b"@apple:test"
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}' EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter] TO_REGISTER = [filter]
@defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.hs = yield setup_test_homeserver( self.hs = setup_test_homeserver(
http_client=None, http_client=None, clock=self.hs_clock, reactor=self.clock
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
) )
self.auth = self.hs.get_auth() self.auth = self.hs.get_auth()
@ -55,82 +57,103 @@ class FilterTestCase(unittest.TestCase):
def get_user_by_req(request, allow_guest=False, rights="access"): def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester( return synapse.types.create_requester(
UserID.from_string(self.USER_ID), 1, False, None) UserID.from_string(self.USER_ID), 1, False, None
)
self.auth.get_user_by_access_token = get_user_by_access_token self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering() self.filtering = self.hs.get_filtering()
self.resource = JsonResource(self.hs)
for r in self.TO_REGISTER: for r in self.TO_REGISTER:
r.register_servlets(self.hs, self.mock_resource) r.register_servlets(self.hs, self.resource)
@defer.inlineCallbacks
def test_add_filter(self): def test_add_filter(self):
(code, response) = yield self.mock_resource.trigger( request, channel = make_request(
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON b"POST",
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
) )
self.assertEquals(200, code) request.render(self.resource)
self.assertEquals({"filter_id": "0"}, response) wait_until_result(self.clock, channel)
filter = yield self.store.get_user_filter(
user_localpart='apple', self.assertEqual(channel.result["code"], b"200")
filter_id=0, self.assertEqual(channel.json_body, {"filter_id": "0"})
) filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
self.assertEquals(filter, self.EXAMPLE_FILTER) self.clock.advance(0)
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
@defer.inlineCallbacks
def test_add_filter_for_other_user(self): def test_add_filter_for_other_user(self):
(code, response) = yield self.mock_resource.trigger( request, channel = make_request(
"POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON b"POST",
b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
) )
self.assertEquals(403, code) request.render(self.resource)
self.assertEquals(response['errcode'], Codes.FORBIDDEN) wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@defer.inlineCallbacks
def test_add_filter_non_local_user(self): def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine _is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False self.hs.is_mine = lambda target_user: False
(code, response) = yield self.mock_resource.trigger( request, channel = make_request(
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON b"POST",
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
) )
request.render(self.resource)
wait_until_result(self.clock, channel)
self.hs.is_mine = _is_mine self.hs.is_mine = _is_mine
self.assertEquals(403, code) self.assertEqual(channel.result["code"], b"403")
self.assertEquals(response['errcode'], Codes.FORBIDDEN) self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@defer.inlineCallbacks
def test_get_filter(self): def test_get_filter(self):
filter_id = yield self.filtering.add_user_filter( filter_id = self.filtering.add_user_filter(
user_localpart='apple', user_localpart="apple", user_filter=self.EXAMPLE_FILTER
user_filter=self.EXAMPLE_FILTER
) )
(code, response) = yield self.mock_resource.trigger_get( self.clock.advance(1)
"/user/%s/filter/%s" % (self.USER_ID, filter_id) filter_id = filter_id.result
request, channel = make_request(
b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
) )
self.assertEquals(200, code) request.render(self.resource)
self.assertEquals(self.EXAMPLE_FILTER, response) wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
@defer.inlineCallbacks
def test_get_filter_non_existant(self): def test_get_filter_non_existant(self):
(code, response) = yield self.mock_resource.trigger_get( request, channel = make_request(
"/user/%s/filter/12382148321" % (self.USER_ID) b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
) )
self.assertEquals(400, code) request.render(self.resource)
self.assertEquals(response['errcode'], Codes.NOT_FOUND) wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"400")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode # Currently invalid params do not have an appropriate errcode
# in errors.py # in errors.py
@defer.inlineCallbacks
def test_get_filter_invalid_id(self): def test_get_filter_invalid_id(self):
(code, response) = yield self.mock_resource.trigger_get( request, channel = make_request(
"/user/%s/filter/foobar" % (self.USER_ID) b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
) )
self.assertEquals(400, code) request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error # No ID also returns an invalid_id error
@defer.inlineCallbacks
def test_get_filter_no_id(self): def test_get_filter_no_id(self):
(code, response) = yield self.mock_resource.trigger_get( request, channel = make_request(
"/user/%s/filter/" % (self.USER_ID) b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
) )
self.assertEquals(400, code) request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"400")

View file

@ -2,165 +2,192 @@ import json
from mock import Mock from mock import Mock
from twisted.internet import defer
from twisted.python import failure from twisted.python import failure
from twisted.test.proto_helpers import MemoryReactorClock
from synapse.api.errors import InteractiveAuthIncompleteError, SynapseError from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha.register import register_servlets
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.utils import mock_getRawHeaders from tests.server import make_request, setup_test_homeserver, wait_until_result
class RegisterRestServletTestCase(unittest.TestCase): class RegisterRestServletTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
# do the dance to hook up request data to self.request_data
self.request_data = "" self.clock = MemoryReactorClock()
self.request = Mock( self.hs_clock = Clock(self.clock)
content=Mock(read=Mock(side_effect=lambda: self.request_data)), self.url = b"/_matrix/client/r0/register"
path='/_matrix/api/v2_alpha/register'
)
self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.appservice = None self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock( self.auth = Mock(
side_effect=lambda x: self.appservice) get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
) )
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None)) self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock( self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None) get_session_data=Mock(return_value=None),
) )
self.registration_handler = Mock() self.registration_handler = Mock()
self.identity_handler = Mock() self.identity_handler = Mock()
self.login_handler = Mock() self.login_handler = Mock()
self.device_handler = Mock() self.device_handler = Mock()
self.device_handler.check_device_registered = Mock(return_value="FAKE")
self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[])
# do the dance to hook it up to the hs global # do the dance to hook it up to the hs global
self.handlers = Mock( self.handlers = Mock(
registration_handler=self.registration_handler, registration_handler=self.registration_handler,
identity_handler=self.identity_handler, identity_handler=self.identity_handler,
login_handler=self.login_handler login_handler=self.login_handler,
)
self.hs = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.clock
) )
self.hs = Mock()
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.config.enable_registration = True self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = [] self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = [] self.hs.config.auto_join_rooms = []
# init the thing we're testing self.resource = JsonResource(self.hs)
self.servlet = RegisterRestServlet(self.hs) register_servlets(self.hs, self.resource)
@defer.inlineCallbacks
def test_POST_appservice_registration_valid(self): def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet" user_id = "@kermit:muppet"
token = "kermits_access_token" token = "kermits_access_token"
self.request.args = { self.appservice = {"id": "1234"}
"access_token": "i_am_an_app_service" self.registration_handler.appservice_register = Mock(return_value=user_id)
} self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
self.request_data = json.dumps({ request_data = json.dumps({"username": "kermit"})
"username": "kermit"
})
self.appservice = {
"id": "1234"
}
self.registration_handler.appservice_register = Mock(
return_value=user_id
)
self.auth_handler.get_access_token_for_user_id = Mock(
return_value=token
)
(code, result) = yield self.servlet.on_POST(self.request) request, channel = make_request(
self.assertEquals(code, 200) b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"home_server": self.hs.hostname "home_server": self.hs.hostname,
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
@defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self): def test_POST_appservice_registration_invalid(self):
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
self.appservice = None # no application service exists self.appservice = None # no application service exists
result = yield self.servlet.on_POST(self.request) request_data = json.dumps({"username": "kermit"})
self.assertEquals(result, (401, None)) request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self): def test_POST_bad_password(self):
self.request_data = json.dumps({ request_data = json.dumps({"username": "kermit", "password": 666})
"username": "kermit", request, channel = make_request(b"POST", self.url, request_data)
"password": 666 request.render(self.resource)
}) wait_until_result(self.clock, channel)
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(
json.loads(channel.result["body"])["error"], "Invalid password"
)
def test_POST_bad_username(self): def test_POST_bad_username(self):
self.request_data = json.dumps({ request_data = json.dumps({"username": 777, "password": "monkey"})
"username": 777, request, channel = make_request(b"POST", self.url, request_data)
"password": "monkey" request.render(self.resource)
}) wait_until_result(self.clock, channel)
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(
json.loads(channel.result["body"])["error"], "Invalid username"
)
@defer.inlineCallbacks
def test_POST_user_valid(self): def test_POST_user_valid(self):
user_id = "@kermit:muppet" user_id = "@kermit:muppet"
token = "kermits_access_token" token = "kermits_access_token"
device_id = "frogfone" device_id = "frogfone"
self.request_data = json.dumps({ request_data = json.dumps(
"username": "kermit", {"username": "kermit", "password": "monkey", "device_id": device_id}
"password": "monkey",
"device_id": device_id,
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_access_token_for_user_id = Mock(
return_value=token
) )
self.device_handler.check_device_registered = \ self.registration_handler.check_username = Mock(return_value=True)
Mock(return_value=device_id) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
self.device_handler.check_device_registered = Mock(return_value=device_id)
request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertDictContainsSubset(det_data, result) self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
self.auth_handler.get_login_tuple_for_user_id( self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None) user_id, device_id=device_id, initial_device_display_name=None
)
def test_POST_disabled_registration(self): def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False self.hs.config.enable_registration = False
self.request_data = json.dumps({ request_data = json.dumps({"username": "kermit", "password": "monkey"})
"username": "kermit",
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True) self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, { self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
"username": "kermit",
"password": "monkey"
}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t")) self.registration_handler.register = Mock(return_value=("@user:id", "t"))
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError) request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
json.loads(channel.result["body"])["error"],
"Registration has been disabled",
)
def test_POST_guest_registration(self):
user_id = "a@b"
self.hs.config.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
self.registration_handler.register = Mock(return_value=(user_id, None))
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource)
wait_until_result(self.clock, channel)
det_data = {
"user_id": user_id,
"home_server": self.hs.hostname,
"device_id": "guest_device",
}
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
json.loads(channel.result["body"])["error"], "Guest access is disabled"
)

View file

@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.types
from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import sync
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
setup_test_homeserver,
wait_until_result,
)
PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase):
USER_ID = b"@apple:test"
TO_REGISTER = [sync]
def setUp(self):
self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.hs = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.auth = self.hs.get_auth()
def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
"is_guest": False,
}
def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
UserID.from_string(self.USER_ID), 1, False, None
)
self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering()
self.resource = JsonResource(self.hs)
for r in self.TO_REGISTER:
r.register_servlets(self.hs, self.resource)
def test_sync_argless(self):
request, channel = make_request(b"GET", b"/_matrix/client/r0/sync")
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"200")
self.assertTrue(
set(
[
"next_batch",
"rooms",
"presence",
"account_data",
"to_device",
"device_lists",
]
).issubset(set(channel.json_body.keys()))
)

View file

@ -80,6 +80,11 @@ def make_request(method, path, content=b""):
content, and return the Request and the Channel underneath. content, and return the Request and the Channel underneath.
""" """
# Decorate it to be the full path
if not path.startswith(b"/_matrix"):
path = b"/_matrix/client/r0/" + path
path = path.replace("//", "/")
if isinstance(content, text_type): if isinstance(content, text_type):
content = content.encode('utf8') content = content.encode('utf8')
@ -110,6 +115,11 @@ def wait_until_result(clock, channel, timeout=100):
clock.advance(0.1) clock.advance(0.1)
def render(request, resource, clock):
request.render(resource)
wait_until_result(clock, request._channel)
class ThreadedMemoryReactorClock(MemoryReactorClock): class ThreadedMemoryReactorClock(MemoryReactorClock):
""" """
A MemoryReactorClock that supports callFromThread. A MemoryReactorClock that supports callFromThread.

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,8 +16,6 @@
from mock import Mock, patch from mock import Mock, patch
from twisted.internet import defer
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from . import unittest from . import unittest
@ -27,38 +26,15 @@ class DistributorTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dist = Distributor() self.dist = Distributor()
@defer.inlineCallbacks
def test_signal_dispatch(self): def test_signal_dispatch(self):
self.dist.declare("alert") self.dist.declare("alert")
observer = Mock() observer = Mock()
self.dist.observe("alert", observer) self.dist.observe("alert", observer)
d = self.dist.fire("alert", 1, 2, 3) self.dist.fire("alert", 1, 2, 3)
yield d
self.assertTrue(d.called)
observer.assert_called_with(1, 2, 3) observer.assert_called_with(1, 2, 3)
@defer.inlineCallbacks
def test_signal_dispatch_deferred(self):
self.dist.declare("whine")
d_inner = defer.Deferred()
def observer():
return d_inner
self.dist.observe("whine", observer)
d_outer = self.dist.fire("whine")
self.assertFalse(d_outer.called)
d_inner.callback(None)
yield d_outer
self.assertTrue(d_outer.called)
@defer.inlineCallbacks
def test_signal_catch(self): def test_signal_catch(self):
self.dist.declare("alarm") self.dist.declare("alarm")
@ -71,9 +47,7 @@ class DistributorTestCase(unittest.TestCase):
with patch( with patch(
"synapse.util.distributor.logger", spec=["warning"] "synapse.util.distributor.logger", spec=["warning"]
) as mock_logger: ) as mock_logger:
d = self.dist.fire("alarm", "Go") self.dist.fire("alarm", "Go")
yield d
self.assertTrue(d.called)
observers[0].assert_called_once_with("Go") observers[0].assert_called_once_with("Go")
observers[1].assert_called_once_with("Go") observers[1].assert_called_once_with("Go")
@ -83,34 +57,12 @@ class DistributorTestCase(unittest.TestCase):
mock_logger.warning.call_args[0][0], str mock_logger.warning.call_args[0][0], str
) )
@defer.inlineCallbacks
def test_signal_catch_no_suppress(self):
# Gut-wrenching
self.dist.suppress_failures = False
self.dist.declare("whail")
class MyException(Exception):
pass
@defer.inlineCallbacks
def observer():
raise MyException("Oopsie")
self.dist.observe("whail", observer)
d = self.dist.fire("whail")
yield self.assertFailure(d, MyException)
self.dist.suppress_failures = True
@defer.inlineCallbacks
def test_signal_prereg(self): def test_signal_prereg(self):
observer = Mock() observer = Mock()
self.dist.observe("flare", observer) self.dist.observe("flare", observer)
self.dist.declare("flare") self.dist.declare("flare")
yield self.dist.fire("flare", 4, 5) self.dist.fire("flare", 4, 5)
observer.assert_called_with(4, 5) observer.assert_called_with(4, 5)

View file

@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
) )
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
@unittest.DEBUG
def test_cant_hide_past_history(self): def test_cant_hide_past_history(self):
""" """
If you send a message, you must be able to provide the direct If you send a message, you must be able to provide the direct
@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
for x, y in d.items() for x, y in d.items()
if x == ("m.room.member", "@us:test") if x == ("m.room.member", "@us:test")
], ],
"auth_chain_ids": d.values(), "auth_chain_ids": list(d.values()),
} }
) )

View file

@ -33,9 +33,11 @@ class JsonResourceTests(unittest.TestCase):
return (200, kwargs) return (200, kwargs)
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback) res.register_paths(
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
)
request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83") request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
request.render(res) request.render(res)
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
@ -51,9 +53,9 @@ class JsonResourceTests(unittest.TestCase):
raise Exception("boo") raise Exception("boo")
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/foo") request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res) request.render(res)
self.assertEqual(channel.result["code"], b'500') self.assertEqual(channel.result["code"], b'500')
@ -74,9 +76,9 @@ class JsonResourceTests(unittest.TestCase):
return d return d
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/foo") request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res) request.render(res)
# No error has been raised yet # No error has been raised yet
@ -96,9 +98,9 @@ class JsonResourceTests(unittest.TestCase):
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/foo") request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res) request.render(res)
self.assertEqual(channel.result["code"], b'403') self.assertEqual(channel.result["code"], b'403')
@ -118,9 +120,9 @@ class JsonResourceTests(unittest.TestCase):
self.fail("shouldn't ever get here") self.fail("shouldn't ever get here")
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/foobar") request, channel = make_request(b"GET", b"/_matrix/foobar")
request.render(res) request.render(res)
self.assertEqual(channel.result["code"], b'400') self.assertEqual(channel.result["code"], b'400')

324
tests/test_visibility.py Normal file
View file

@ -0,0 +1,324 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from twisted.internet.defer import succeed
from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server
import tests.unittest
from tests.utils import setup_test_homeserver
logger = logging.getLogger(__name__)
TEST_ROOM_ID = "!TEST:ROOM"
class FilterEventsForServerTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_filtering(self):
#
# The events to be filtered consist of 10 membership events (it doesn't
# really matter if they are joins or leaves, so let's make them joins).
# One of those membership events is going to be for a user on the
# server we are filtering for (so we can check the filtering is doing
# the right thing).
#
# before we do that, we persist some other events to act as state.
self.inject_visibility("@admin:hs", "joined")
for i in range(0, 10):
yield self.inject_room_member("@resident%i:hs" % i)
events_to_filter = []
for i in range(0, 10):
user = "@user%i:%s" % (
i, "test_server" if i == 5 else "other_server"
)
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
filtered = yield filter_events_for_server(
self.store, "test_server", events_to_filter,
)
# the result should be 5 redacted events, and 5 unredacted events.
for i in range(0, 5):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertNotIn("a", filtered[i].content)
for i in range(5, 10):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["a"], "b")
@tests.unittest.DEBUG
@defer.inlineCallbacks
def test_erased_user(self):
# 4 message events, from erased and unerased users, with a membership
# change in the middle of them.
events_to_filter = []
evt = yield self.inject_message("@unerased:local_hs")
events_to_filter.append(evt)
evt = yield self.inject_message("@erased:local_hs")
events_to_filter.append(evt)
evt = yield self.inject_room_member("@joiner:remote_hs")
events_to_filter.append(evt)
evt = yield self.inject_message("@unerased:local_hs")
events_to_filter.append(evt)
evt = yield self.inject_message("@erased:local_hs")
events_to_filter.append(evt)
# the erasey user gets erased
self.hs.get_datastore().mark_user_erased("@erased:local_hs")
# ... and the filtering happens.
filtered = yield filter_events_for_server(
self.store, "test_server", events_to_filter,
)
for i in range(0, len(events_to_filter)):
self.assertEqual(
events_to_filter[i].event_id, filtered[i].event_id,
"Unexpected event at result position %i" % (i, )
)
for i in (0, 3):
self.assertEqual(
events_to_filter[i].content["body"], filtered[i].content["body"],
"Unexpected event content at result position %i" % (i,)
)
for i in (1, 4):
self.assertNotIn("body", filtered[i].content)
@defer.inlineCallbacks
def inject_visibility(self, user_id, visibility):
content = {"history_visibility": visibility}
builder = self.event_builder_factory.new({
"type": "m.room.history_visibility",
"sender": user_id,
"state_key": "",
"room_id": TEST_ROOM_ID,
"content": content,
})
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
yield self.hs.get_datastore().persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_room_member(self, user_id, membership="join", extra_content={}):
content = {"membership": membership}
content.update(extra_content)
builder = self.event_builder_factory.new({
"type": "m.room.member",
"sender": user_id,
"state_key": user_id,
"room_id": TEST_ROOM_ID,
"content": content,
})
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
yield self.hs.get_datastore().persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, user_id, content=None):
if content is None:
content = {"body": "testytest"}
builder = self.event_builder_factory.new({
"type": "m.room.message",
"sender": user_id,
"room_id": TEST_ROOM_ID,
"content": content,
})
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
yield self.hs.get_datastore().persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def test_large_room(self):
# see what happens when we have a large room with hundreds of thousands
# of membership events
# As above, the events to be filtered consist of 10 membership events,
# where one of them is for a user on the server we are filtering for.
import cProfile
import pstats
import time
# we stub out the store, because building up all that state the normal
# way is very slow.
test_store = _TestStore()
# our initial state is 100000 membership events and one
# history_visibility event.
room_state = []
history_visibility_evt = FrozenEvent({
"event_id": "$history_vis",
"type": "m.room.history_visibility",
"sender": "@resident_user_0:test.com",
"state_key": "",
"room_id": TEST_ROOM_ID,
"content": {"history_visibility": "joined"},
})
room_state.append(history_visibility_evt)
test_store.add_event(history_visibility_evt)
for i in range(0, 100000):
user = "@resident_user_%i:test.com" % (i, )
evt = FrozenEvent({
"event_id": "$res_event_%i" % (i, ),
"type": "m.room.member",
"state_key": user,
"sender": user,
"room_id": TEST_ROOM_ID,
"content": {
"membership": "join",
"extra": "zzz,"
},
})
room_state.append(evt)
test_store.add_event(evt)
events_to_filter = []
for i in range(0, 10):
user = "@user%i:%s" % (
i, "test_server" if i == 5 else "other_server"
)
evt = FrozenEvent({
"event_id": "$evt%i" % (i, ),
"type": "m.room.member",
"state_key": user,
"sender": user,
"room_id": TEST_ROOM_ID,
"content": {
"membership": "join",
"extra": "zzz",
},
})
events_to_filter.append(evt)
room_state.append(evt)
test_store.add_event(evt)
test_store.set_state_ids_for_event(evt, {
(e.type, e.state_key): e.event_id for e in room_state
})
pr = cProfile.Profile()
pr.enable()
logger.info("Starting filtering")
start = time.time()
filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter,
)
logger.info("Filtering took %f seconds", time.time() - start)
pr.disable()
with open("filter_events_for_server.profile", "w+") as f:
ps = pstats.Stats(pr, stream=f).sort_stats('cumulative')
ps.print_stats()
# the result should be 5 redacted events, and 5 unredacted events.
for i in range(0, 5):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertNotIn("extra", filtered[i].content)
for i in range(5, 10):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["extra"], "zzz")
test_large_room.skip = "Disabled by default because it's slow"
class _TestStore(object):
"""Implements a few methods of the DataStore, so that we can test
filter_events_for_server
"""
def __init__(self):
# data for get_events: a map from event_id to event
self.events = {}
# data for get_state_ids_for_events mock: a map from event_id to
# a map from (type_state_key) -> event_id for the state at that
# event
self.state_ids_for_events = {}
def add_event(self, event):
self.events[event.event_id] = event
def set_state_ids_for_event(self, event, state):
self.state_ids_for_events[event.event_id] = state
def get_state_ids_for_events(self, events, types):
res = {}
include_memberships = False
for (type, state_key) in types:
if type == "m.room.history_visibility":
continue
if type != "m.room.member" or state_key is not None:
raise RuntimeError(
"Unimplemented: get_state_ids with type (%s, %s)" %
(type, state_key),
)
include_memberships = True
if include_memberships:
for event_id in events:
res[event_id] = self.state_ids_for_events[event_id]
else:
k = ("m.room.history_visibility", "")
for event_id in events:
hve = self.state_ids_for_events[event_id][k]
res[event_id] = {k: hve}
return succeed(res)
def get_events(self, events):
return succeed({
event_id: self.events[event_id] for event_id in events
})
def are_users_erased(self, users):
return succeed({u: False for u in users})

View file

@ -109,6 +109,17 @@ class TestCase(unittest.TestCase):
except AssertionError as e: except AssertionError as e:
raise (type(e))(e.message + " for '.%s'" % key) raise (type(e))(e.message + " for '.%s'" % key)
def assert_dict(self, required, actual):
"""Does a partial assert of a dict.
Args:
required (dict): The keys and value which MUST be in 'actual'.
actual (dict): The test result. Extra keys will not be checked.
"""
for key in required:
self.assertEquals(required[key], actual[key],
msg="%s mismatch. %s" % (key, actual))
def DEBUG(target): def DEBUG(target):
"""A decorator to set the .loglevel attribute to logging.DEBUG. """A decorator to set the .loglevel attribute to logging.DEBUG.

View file

@ -1,70 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.util.async import Limiter
from tests import unittest
class LimiterTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_limiter(self):
limiter = Limiter(3)
key = object()
d1 = limiter.queue(key)
cm1 = yield d1
d2 = limiter.queue(key)
cm2 = yield d2
d3 = limiter.queue(key)
cm3 = yield d3
d4 = limiter.queue(key)
self.assertFalse(d4.called)
d5 = limiter.queue(key)
self.assertFalse(d5.called)
with cm1:
self.assertFalse(d4.called)
self.assertFalse(d5.called)
self.assertTrue(d4.called)
self.assertFalse(d5.called)
with cm3:
self.assertFalse(d5.called)
self.assertTrue(d5.called)
with cm2:
pass
with (yield d4):
pass
with (yield d5):
pass
d6 = limiter.queue(key)
with (yield d6):
pass

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,6 +17,7 @@
from six.moves import range from six.moves import range
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
from synapse.util import Clock, logcontext from synapse.util import Clock, logcontext
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
@ -65,3 +67,79 @@ class LinearizerTestCase(unittest.TestCase):
func(i) func(i)
return func(1000) return func(1000)
@defer.inlineCallbacks
def test_multiple_entries(self):
limiter = Linearizer(max_count=3)
key = object()
d1 = limiter.queue(key)
cm1 = yield d1
d2 = limiter.queue(key)
cm2 = yield d2
d3 = limiter.queue(key)
cm3 = yield d3
d4 = limiter.queue(key)
self.assertFalse(d4.called)
d5 = limiter.queue(key)
self.assertFalse(d5.called)
with cm1:
self.assertFalse(d4.called)
self.assertFalse(d5.called)
cm4 = yield d4
self.assertFalse(d5.called)
with cm3:
self.assertFalse(d5.called)
cm5 = yield d5
with cm2:
pass
with cm4:
pass
with cm5:
pass
d6 = limiter.queue(key)
with (yield d6):
pass
@defer.inlineCallbacks
def test_cancellation(self):
linearizer = Linearizer()
key = object()
d1 = linearizer.queue(key)
cm1 = yield d1
d2 = linearizer.queue(key)
self.assertFalse(d2.called)
d3 = linearizer.queue(key)
self.assertFalse(d3.called)
d2.cancel()
with cm1:
pass
self.assertTrue(d2.called)
try:
yield d2
self.fail("Expected d2 to raise CancelledError")
except CancelledError:
pass
with (yield d3):
pass

Some files were not shown because too many files have changed in this diff Show more