Merge branch 'develop' of github.com:matrix-org/synapse into erikj/refactor_repl_servlet

This commit is contained in:
Erik Johnston 2018-08-03 09:25:15 +01:00
commit cb298ff623
67 changed files with 1074 additions and 648 deletions

View file

@ -27,8 +27,9 @@ Describe here the problem that you are experiencing, or the feature you are requ
Describe how what happens differs from what you expected. Describe how what happens differs from what you expected.
If you can identify any relevant log snippets from _homeserver.log_, please include <!-- If you can identify any relevant log snippets from _homeserver.log_, please include
those here (please be careful to remove any personal or private data): those (please be careful to remove any personal or private data). Please surround them with
``` (three backticks, on a line on their own), so that they are formatted legibly. -->
### Version information ### Version information

View file

@ -63,3 +63,6 @@ Christoph Witzany <christoph at web.crofting.com>
Pierre Jaury <pierre at jaury.eu> Pierre Jaury <pierre at jaury.eu>
* Docker packaging * Docker packaging
Serban Constantin <serban.constantin at gmail dot com>
* Small bug fix

View file

@ -1,3 +1,12 @@
Synapse 0.33.1 (2018-08-02)
===========================
SECURITY FIXES
--------------
- Fix a potential issue where servers could request events for rooms they have not joined. ([\#3641](https://github.com/matrix-org/synapse/issues/3641))
- Fix a potential issue where users could see events in private rooms before they joined. ([\#3642](https://github.com/matrix-org/synapse/issues/3642))
Synapse 0.33.0 (2018-07-19) Synapse 0.33.0 (2018-07-19)
=========================== ===========================

View file

@ -51,7 +51,7 @@ makes it horribly hard to review otherwise.
Changelog Changelog
~~~~~~~~~ ~~~~~~~~~
All changes, even minor ones, need a corresponding changelog All changes, even minor ones, need a corresponding changelog / newsfragment
entry. These are managed by Towncrier entry. These are managed by Towncrier
(https://github.com/hawkowl/towncrier). (https://github.com/hawkowl/towncrier).

View file

@ -1,16 +1,32 @@
FROM docker.io/python:2-alpine3.7 FROM docker.io/python:2-alpine3.7
RUN apk add --no-cache --virtual .nacl_deps su-exec build-base libffi-dev zlib-dev libressl-dev libjpeg-turbo-dev linux-headers postgresql-dev libxslt-dev RUN apk add --no-cache --virtual .nacl_deps \
build-base \
libffi-dev \
libjpeg-turbo-dev \
libressl-dev \
libxslt-dev \
linux-headers \
postgresql-dev \
su-exec \
zlib-dev
COPY . /synapse COPY . /synapse
# A wheel cache may be provided in ./cache for faster build # A wheel cache may be provided in ./cache for faster build
RUN cd /synapse \ RUN cd /synapse \
&& pip install --upgrade pip setuptools psycopg2 lxml \ && pip install --upgrade \
lxml \
pip \
psycopg2 \
setuptools \
&& mkdir -p /synapse/cache \ && mkdir -p /synapse/cache \
&& pip install -f /synapse/cache --upgrade --process-dependency-links . \ && pip install -f /synapse/cache --upgrade --process-dependency-links . \
&& mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \ && mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
&& rm -rf setup.py setup.cfg synapse && rm -rf \
setup.cfg \
setup.py \
synapse
VOLUME ["/data"] VOLUME ["/data"]

1
changelog.d/2952.bugfix Normal file
View file

@ -0,0 +1 @@
Make /directory/list API return 404 for room not found instead of 400

1
changelog.d/3384.misc Normal file
View file

@ -0,0 +1 @@
Rewrite cache list decorator

1
changelog.d/3543.misc Normal file
View file

@ -0,0 +1 @@
Improve Dockerfile and docker-compose instructions

1
changelog.d/3569.bugfix Normal file
View file

@ -0,0 +1 @@
Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password.

1
changelog.d/3612.misc Normal file
View file

@ -0,0 +1 @@
Make EventStore inherit from EventFederationStore

1
changelog.d/3621.misc Normal file
View file

@ -0,0 +1 @@
Refactor FederationHandler to move DB writes into separate functions

1
changelog.d/3628.misc Normal file
View file

@ -0,0 +1 @@
Remove unused field "pdu_failures" from transactions.

1
changelog.d/3630.feature Normal file
View file

@ -0,0 +1 @@
Add ability to limit number of monthly active users on the server

1
changelog.d/3634.misc Normal file
View file

@ -0,0 +1 @@
rename replication_layer to federation_client

1
changelog.d/3638.misc Normal file
View file

@ -0,0 +1 @@
Factor out exception handling in federation_client

1
changelog.d/3639.feature Normal file
View file

@ -0,0 +1 @@
When we fail to join a room over federation, pass the error code back to the client.

1
changelog.d/3645.misc Normal file
View file

@ -0,0 +1 @@
Update CONTRIBUTING to mention newsfragments.

View file

@ -9,13 +9,7 @@ use that server.
## Build ## Build
Build the docker image with the `docker build` command from the root of the synapse repository. Build the docker image with the `docker-compose build` command.
```
docker build -t docker.io/matrixdotorg/synapse .
```
The `-t` option sets the image tag. Official images are tagged `matrixdotorg/synapse:<version>` where `<version>` is the same as the release tag in the synapse git repository.
You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project. You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.

View file

@ -6,6 +6,7 @@ version: '3'
services: services:
synapse: synapse:
build: ../..
image: docker.io/matrixdotorg/synapse:latest image: docker.io/matrixdotorg/synapse:latest
# Since snyapse does not retry to connect to the database, restart upon # Since snyapse does not retry to connect to the database, restart upon
# failure # failure

View file

@ -0,0 +1,6 @@
# Using the Synapse Grafana dashboard
0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst
2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
3. Set up additional recording rules

View file

@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.33.0" __version__ = "0.33.1"

View file

@ -252,10 +252,10 @@ class Auth(object):
if ip_address not in app_service.ip_range_whitelist: if ip_address not in app_service.ip_range_whitelist:
defer.returnValue((None, None)) defer.returnValue((None, None))
if "user_id" not in request.args: if b"user_id" not in request.args:
defer.returnValue((app_service.sender, app_service)) defer.returnValue((app_service.sender, app_service))
user_id = request.args["user_id"][0] user_id = request.args[b"user_id"][0].decode('utf8')
if app_service.sender == user_id: if app_service.sender == user_id:
defer.returnValue((app_service.sender, app_service)) defer.returnValue((app_service.sender, app_service))

View file

@ -55,6 +55,7 @@ class Codes(object):
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):
@ -69,20 +70,6 @@ class CodeMessageException(RuntimeError):
self.code = code self.code = code
self.msg = msg self.msg = msg
def error_dict(self):
return cs_error(self.msg)
class MatrixCodeMessageException(CodeMessageException):
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
Attributes:
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
"""
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
super(MatrixCodeMessageException, self).__init__(code, msg)
self.errcode = errcode
class SynapseError(CodeMessageException): class SynapseError(CodeMessageException):
"""A base exception type for matrix errors which have an errcode and error """A base exception type for matrix errors which have an errcode and error
@ -108,38 +95,28 @@ class SynapseError(CodeMessageException):
self.errcode, self.errcode,
) )
@classmethod
def from_http_response_exception(cls, err):
"""Make a SynapseError based on an HTTPResponseException
This is useful when a proxied request has failed, and we need to class ProxiedRequestError(SynapseError):
decide how to map the failure onto a matrix error to send back to the """An error from a general matrix endpoint, eg. from a proxied Matrix API call.
client.
An attempt is made to parse the body of the http response as a matrix Attributes:
error. If that succeeds, the errcode and error message from the body errcode (str): Matrix error code e.g 'M_FORBIDDEN'
are used as the errcode and error message in the new synapse error.
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
set to the reason code from the HTTP response.
Args:
err (HttpResponseException):
Returns:
SynapseError:
""" """
# try to parse the body as json, to get better errcode/msg, but def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
# default to M_UNKNOWN with the HTTP status as the error text super(ProxiedRequestError, self).__init__(
try: code, msg, errcode
j = json.loads(err.response) )
except ValueError: if additional_fields is None:
j = {} self._additional_fields = {}
errcode = j.get('errcode', Codes.UNKNOWN) else:
errmsg = j.get('error', err.msg) self._additional_fields = dict(additional_fields)
res = SynapseError(err.code, errmsg, errcode) def error_dict(self):
return res return cs_error(
self.msg,
self.errcode,
**self._additional_fields
)
class ConsentNotGivenError(SynapseError): class ConsentNotGivenError(SynapseError):
@ -308,14 +285,6 @@ class LimitExceededError(SynapseError):
) )
def cs_exception(exception):
if isinstance(exception, CodeMessageException):
return exception.error_dict()
else:
logger.error("Unknown exception type: %s", type(exception))
return {}
def cs_error(msg, code=Codes.UNKNOWN, **kwargs): def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
""" Utility method for constructing an error response for client-server """ Utility method for constructing an error response for client-server
interactions. interactions.
@ -372,7 +341,7 @@ class HttpResponseException(CodeMessageException):
Represents an HTTP-level failure of an outbound request Represents an HTTP-level failure of an outbound request
Attributes: Attributes:
response (str): body of response response (bytes): body of response
""" """
def __init__(self, code, msg, response): def __init__(self, code, msg, response):
""" """
@ -380,7 +349,39 @@ class HttpResponseException(CodeMessageException):
Args: Args:
code (int): HTTP status code code (int): HTTP status code
msg (str): reason phrase from HTTP response status line msg (str): reason phrase from HTTP response status line
response (str): body of response response (bytes): body of response
""" """
super(HttpResponseException, self).__init__(code, msg) super(HttpResponseException, self).__init__(code, msg)
self.response = response self.response = response
def to_synapse_error(self):
"""Make a SynapseError based on an HTTPResponseException
This is useful when a proxied request has failed, and we need to
decide how to map the failure onto a matrix error to send back to the
client.
An attempt is made to parse the body of the http response as a matrix
error. If that succeeds, the errcode and error message from the body
are used as the errcode and error message in the new synapse error.
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
set to the reason code from the HTTP response.
Returns:
SynapseError:
"""
# try to parse the body as json, to get better errcode/msg, but
# default to M_UNKNOWN with the HTTP status as the error text
try:
j = json.loads(self.response)
except ValueError:
j = {}
if not isinstance(j, dict):
j = {}
errcode = j.pop('errcode', Codes.UNKNOWN)
errmsg = j.pop('error', self.msg)
return ProxiedRequestError(self.code, errmsg, errcode, j)

View file

@ -20,6 +20,8 @@ import sys
from six import iteritems from six import iteritems
from prometheus_client import Gauge
from twisted.application import service from twisted.application import service
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import EncodingResourceWrapper, NoResource from twisted.web.resource import EncodingResourceWrapper, NoResource
@ -300,6 +302,11 @@ class SynapseHomeServer(HomeServer):
quit_with_error(e.message) quit_with_error(e.message)
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
def setup(config_options): def setup(config_options):
""" """
Args: Args:
@ -512,6 +519,18 @@ def run(hs):
# table will decrease # table will decrease
clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
@defer.inlineCallbacks
def generate_monthly_active_users():
count = 0
if hs.config.limit_usage_by_mau:
count = yield hs.get_datastore().count_monthly_users()
current_mau_gauge.set(float(count))
max_mau_value_gauge.set(float(hs.config.max_mau_value))
generate_monthly_active_users()
if hs.config.limit_usage_by_mau:
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
if hs.config.report_stats: if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals") logger.info("Scheduling stats reporting for 3 hour intervals")
clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)

View file

@ -67,6 +67,14 @@ class ServerConfig(Config):
"block_non_admin_invites", False, "block_non_admin_invites", False,
) )
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
if self.limit_usage_by_mau:
self.max_mau_value = config.get(
"max_mau_value", 0,
)
else:
self.max_mau_value = 0
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None self.federation_domain_whitelist = None
federation_domain_whitelist = config.get( federation_domain_whitelist = config.get(
@ -209,6 +217,8 @@ class ServerConfig(Config):
# different cores. See # different cores. See
# https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/. # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
# #
# This setting requires the affinity package to be installed!
#
# cpu_affinity: 0xFFFFFFFF # cpu_affinity: 0xFFFFFFFF
# Whether to serve a web client from the HTTP/HTTPS root resource. # Whether to serve a web client from the HTTP/HTTPS root resource.

View file

@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
PDU_RETRY_TIME_MS = 1 * 60 * 1000 PDU_RETRY_TIME_MS = 1 * 60 * 1000
class InvalidResponseError(RuntimeError):
"""Helper for _try_destination_list: indicates that the server returned a response
we couldn't parse
"""
pass
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self, hs): def __init__(self, hs):
super(FederationClient, self).__init__(hs) super(FederationClient, self).__init__(hs)
@ -458,6 +465,61 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def _try_destination_list(self, description, destinations, callback):
"""Try an operation on a series of servers, until it succeeds
Args:
description (unicode): description of the operation we're doing, for logging
destinations (Iterable[unicode]): list of server_names to try
callback (callable): Function to run for each server. Passed a single
argument: the server_name to try. May return a deferred.
If the callback raises a CodeMessageException with a 300/400 code,
attempts to perform the operation stop immediately and the exception is
reraised.
Otherwise, if the callback raises an Exception the error is logged and the
next server tried. Normally the stacktrace is logged but this is
suppressed if the exception is an InvalidResponseError.
Returns:
The [Deferred] result of callback, if it succeeds
Raises:
SynapseError if the chosen remote server returns a 300/400 code.
RuntimeError if no servers were reachable.
"""
for destination in destinations:
if destination == self.server_name:
continue
try:
res = yield callback(destination)
defer.returnValue(res)
except InvalidResponseError as e:
logger.warn(
"Failed to %s via %s: %s",
description, destination, e,
)
except HttpResponseException as e:
if not 500 <= e.code < 600:
raise e.to_synapse_error()
else:
logger.warn(
"Failed to %s via %s: %i %s",
description, destination, e.code, e.message,
)
except Exception:
logger.warn(
"Failed to %s via %s",
description, destination, exc_info=1,
)
raise RuntimeError("Failed to %s via any server", description)
def make_membership_event(self, destinations, room_id, user_id, membership, def make_membership_event(self, destinations, room_id, user_id, membership,
content={},): content={},):
""" """
@ -481,7 +543,7 @@ class FederationClient(FederationBase):
Deferred: resolves to a tuple of (origin (str), event (object)) Deferred: resolves to a tuple of (origin (str), event (object))
where origin is the remote homeserver which generated the event. where origin is the remote homeserver which generated the event.
Fails with a ``CodeMessageException`` if the chosen remote server Fails with a ``SynapseError`` if the chosen remote server
returns a 300/400 code. returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable. Fails with a ``RuntimeError`` if no servers were reachable.
@ -492,11 +554,9 @@ class FederationClient(FederationBase):
"make_membership_event called with membership='%s', must be one of %s" % "make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships)) (membership, ",".join(valid_memberships))
) )
for destination in destinations:
if destination == self.server_name:
continue
try: @defer.inlineCallbacks
def send_request(destination):
ret = yield self.transport_layer.make_membership_event( ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership destination, room_id, user_id, membership
) )
@ -518,24 +578,11 @@ class FederationClient(FederationBase):
defer.returnValue( defer.returnValue(
(destination, ev) (destination, ev)
) )
break
except CodeMessageException as e: return self._try_destination_list(
if not 500 <= e.code < 600: "make_" + membership, destinations, send_request,
raise
else:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
) )
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def send_join(self, destinations, pdu): def send_join(self, destinations, pdu):
"""Sends a join event to one of a list of homeservers. """Sends a join event to one of a list of homeservers.
@ -552,17 +599,14 @@ class FederationClient(FederationBase):
giving the serer the event was sent to, ``state`` (?) and giving the serer the event was sent to, ``state`` (?) and
``auth_chain``. ``auth_chain``.
Fails with a ``CodeMessageException`` if the chosen remote server Fails with a ``SynapseError`` if the chosen remote server
returns a 300/400 code. returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable. Fails with a ``RuntimeError`` if no servers were reachable.
""" """
for destination in destinations: @defer.inlineCallbacks
if destination == self.server_name: def send_request(destination):
continue
try:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join( _, content = yield self.transport_layer.send_join(
destination=destination, destination=destination,
@ -624,31 +668,22 @@ class FederationClient(FederationBase):
"auth_chain": signed_auth, "auth_chain": signed_auth,
"origin": destination, "origin": destination,
}) })
except CodeMessageException as e: return self._try_destination_list("send_join", destinations, send_request)
if not 500 <= e.code < 600:
raise
else:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
except Exception as e:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu): def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try:
code, content = yield self.transport_layer.send_invite( code, content = yield self.transport_layer.send_invite(
destination=destination, destination=destination,
room_id=room_id, room_id=room_id,
event_id=event_id, event_id=event_id,
content=pdu.get_pdu_json(time_now), content=pdu.get_pdu_json(time_now),
) )
except HttpResponseException as e:
if e.code == 403:
raise e.to_synapse_error()
raise
pdu_dict = content["event"] pdu_dict = content["event"]
@ -663,7 +698,6 @@ class FederationClient(FederationBase):
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks
def send_leave(self, destinations, pdu): def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers. """Sends a leave event to one of a list of homeservers.
@ -680,16 +714,13 @@ class FederationClient(FederationBase):
Return: Return:
Deferred: resolves to None. Deferred: resolves to None.
Fails with a ``CodeMessageException`` if the chosen remote server Fails with a ``SynapseError`` if the chosen remote server
returns a non-200 code. returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable. Fails with a ``RuntimeError`` if no servers were reachable.
""" """
for destination in destinations: @defer.inlineCallbacks
if destination == self.server_name: def send_request(destination):
continue
try:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave( _, content = yield self.transport_layer.send_leave(
destination=destination, destination=destination,
@ -700,15 +731,8 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
defer.returnValue(None) defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_leave via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.") return self._try_destination_list("send_leave", destinations, send_request)
def get_public_rooms(self, destination, limit=None, since_token=None, def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None, include_all_networks=False, search_filter=None, include_all_networks=False,

View file

@ -207,10 +207,6 @@ class FederationServer(FederationBase):
edu.content edu.content
) )
pdu_failures = getattr(transaction, "pdu_failures", [])
for fail in pdu_failures:
logger.info("Got failure %r", fail)
response = { response = {
"pdus": pdu_results, "pdus": pdu_results,
} }
@ -430,6 +426,7 @@ class FederationServer(FederationBase):
ret = yield self.handler.on_query_auth( ret = yield self.handler.on_query_auth(
origin, origin,
event_id, event_id,
room_id,
signed_auth, signed_auth,
content.get("rejects", []), content.get("rejects", []),
content.get("missing", []), content.get("missing", []),

View file

@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
self.edus = SortedDict() # stream position -> Edu self.edus = SortedDict() # stream position -> Edu
self.failures = SortedDict() # stream position -> (destination, Failure)
self.device_messages = SortedDict() # stream position -> destination self.device_messages = SortedDict() # stream position -> destination
self.pos = 1 self.pos = 1
@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
for queue_name in [ for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
"edus", "failures", "device_messages", "pos_time", "edus", "device_messages", "pos_time",
]: ]:
register(queue_name, getattr(self, queue_name)) register(queue_name, getattr(self, queue_name))
@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
for key in keys[:i]: for key in keys[:i]:
del self.edus[key] del self.edus[key]
# Delete things out of failure map
keys = self.failures.keys()
i = self.failures.bisect_left(position_to_delete)
for key in keys[:i]:
del self.failures[key]
# Delete things out of device map # Delete things out of device map
keys = self.device_messages.keys() keys = self.device_messages.keys()
i = self.device_messages.bisect_left(position_to_delete) i = self.device_messages.bisect_left(position_to_delete)
@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
self.notifier.on_new_replication_data() self.notifier.on_new_replication_data()
def send_failure(self, failure, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.failures[pos] = (destination, str(failure))
self.notifier.on_new_replication_data()
def send_device_messages(self, destination): def send_device_messages(self, destination):
"""As per TransactionQueue""" """As per TransactionQueue"""
pos = self._next_pos() pos = self._next_pos()
@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
for (pos, edu) in edus: for (pos, edu) in edus:
rows.append((pos, EduRow(edu))) rows.append((pos, EduRow(edu)))
# Fetch changed failures
i = self.failures.bisect_right(from_token)
j = self.failures.bisect_right(to_token) + 1
failures = self.failures.items()[i:j]
for (pos, (destination, failure)) in failures:
rows.append((pos, FailureRow(
destination=destination,
failure=failure,
)))
# Fetch changed device messages # Fetch changed device messages
i = self.device_messages.bisect_right(from_token) i = self.device_messages.bisect_right(from_token)
j = self.device_messages.bisect_right(to_token) + 1 j = self.device_messages.bisect_right(to_token) + 1
@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
buff.edus.setdefault(self.edu.destination, []).append(self.edu) buff.edus.setdefault(self.edu.destination, []).append(self.edu)
class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
"destination", # str
"failure",
))):
"""Streams failures to a remote server. Failures are issued when there was
something wrong with a transaction the remote sent us, e.g. it included
an event that was invalid.
"""
TypeId = "f"
@staticmethod
def from_data(data):
return FailureRow(
destination=data["destination"],
failure=data["failure"],
)
def to_data(self):
return {
"destination": self.destination,
"failure": self.failure,
}
def add_to_buffer(self, buff):
buff.failures.setdefault(self.destination, []).append(self.failure)
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
"destination", # str "destination", # str
))): ))):
@ -471,7 +417,6 @@ TypeToRow = {
PresenceRow, PresenceRow,
KeyedEduRow, KeyedEduRow,
EduRow, EduRow,
FailureRow,
DeviceRow, DeviceRow,
) )
} }
@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
"presence", # list(UserPresenceState) "presence", # list(UserPresenceState)
"keyed_edus", # dict of destination -> { key -> Edu } "keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu] "edus", # dict of destination -> [Edu]
"failures", # dict of destination -> [failures]
"device_destinations", # set of destinations "device_destinations", # set of destinations
)) ))
@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
presence=[], presence=[],
keyed_edus={}, keyed_edus={},
edus={}, edus={},
failures={},
device_destinations=set(), device_destinations=set(),
) )
@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
edu.destination, edu.edu_type, edu.content, key=None, edu.destination, edu.edu_type, edu.content, key=None,
) )
for destination, failure_list in iteritems(buff.failures):
for failure in failure_list:
transaction_queue.send_failure(destination, failure)
for destination in buff.device_destinations: for destination in buff.device_destinations:
transaction_queue.send_device_messages(destination) transaction_queue.send_device_messages(destination)

View file

@ -116,9 +116,6 @@ class TransactionQueue(object):
), ),
) )
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# destination -> stream_id of last successfully sent to-device message. # destination -> stream_id of last successfully sent to-device message.
# NB: may be a long or an int. # NB: may be a long or an int.
self.last_device_stream_id_by_dest = {} self.last_device_stream_id_by_dest = {}
@ -382,19 +379,6 @@ class TransactionQueue(object):
self._attempt_new_transaction(destination) self._attempt_new_transaction(destination)
def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
if not self.can_send_to(destination):
return
self.pending_failures_by_dest.setdefault(
destination, []
).append(failure)
self._attempt_new_transaction(destination)
def send_device_messages(self, destination): def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
return return
@ -469,7 +453,6 @@ class TransactionQueue(object):
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {}) pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend( pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values() self.pending_edus_keyed_by_dest.pop(destination, {}).values()
@ -497,7 +480,7 @@ class TransactionQueue(object):
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus)) destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures: if not pending_pdus and not pending_edus:
logger.debug("TX [%s] Nothing to send", destination) logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = ( self.last_device_stream_id_by_dest[destination] = (
device_stream_id device_stream_id
@ -507,7 +490,7 @@ class TransactionQueue(object):
# END CRITICAL SECTION # END CRITICAL SECTION
success = yield self._send_new_transaction( success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures, destination, pending_pdus, pending_edus,
) )
if success: if success:
sent_transactions_counter.inc() sent_transactions_counter.inc()
@ -584,14 +567,12 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus, def _send_new_transaction(self, destination, pending_pdus, pending_edus):
pending_failures):
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[1]) pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus] pdus = [x[0] for x in pending_pdus]
edus = pending_edus edus = pending_edus
failures = [x.get_dict() for x in pending_failures]
success = True success = True
@ -601,11 +582,10 @@ class TransactionQueue(object):
logger.debug( logger.debug(
"TX [%s] {%s} Attempting new transaction" "TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)", " (pdus: %d, edus: %d)",
destination, txn_id, destination, txn_id,
len(pdus), len(pdus),
len(edus), len(edus),
len(failures)
) )
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
@ -617,7 +597,6 @@ class TransactionQueue(object):
destination=destination, destination=destination,
pdus=pdus, pdus=pdus,
edus=edus, edus=edus,
pdu_failures=failures,
) )
self._next_txn_id += 1 self._next_txn_id += 1
@ -627,12 +606,11 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination) logger.debug("TX [%s] Persisted transaction", destination)
logger.info( logger.info(
"TX [%s] {%s} Sending transaction [%s]," "TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)", " (PDUs: %d, EDUs: %d)",
destination, txn_id, destination, txn_id,
transaction.transaction_id, transaction.transaction_id,
len(pdus), len(pdus),
len(edus), len(edus),
len(failures),
) )
# Actually send the transaction # Actually send the transaction

View file

@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
param_dict = dict(kv.split("=") for kv in params) param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value): def strip_quotes(value):
if value.startswith(b"\""): if value.startswith("\""):
return value[1:-1] return value[1:-1]
else: else:
return value return value
@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
) )
logger.info( logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)", "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
transaction_id, origin, transaction_id, origin,
len(transaction_data.get("pdus", [])), len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])), len(transaction_data.get("edus", [])),
len(transaction_data.get("failures", [])),
) )
# We should ideally be getting this from the security layer. # We should ideally be getting this from the security layer.

View file

@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
"previous_ids", "previous_ids",
"pdus", "pdus",
"edus", "edus",
"pdu_failures",
] ]
internal_keys = [ internal_keys = [

View file

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import unicodedata
import attr import attr
import bcrypt import bcrypt
@ -519,6 +520,7 @@ class AuthHandler(BaseHandler):
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
yield self._check_mau_limits()
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -626,6 +628,7 @@ class AuthHandler(BaseHandler):
# special case to check for "password" for the check_password interface # special case to check for "password" for the check_password interface
# for the auth providers # for the auth providers
password = login_submission.get("password") password = login_submission.get("password")
if login_type == LoginType.PASSWORD: if login_type == LoginType.PASSWORD:
if not self._password_enabled: if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.") raise SynapseError(400, "Password login has been disabled.")
@ -707,9 +710,10 @@ class AuthHandler(BaseHandler):
multiple inexact matches. multiple inexact matches.
Args: Args:
user_id (str): complete @user:id user_id (unicode): complete @user:id
password (unicode): the provided password
Returns: Returns:
(str) the canonical_user_id, or None if unknown user / bad password (unicode) the canonical_user_id, or None if unknown user / bad password
""" """
lookupres = yield self._find_user_id_and_pwd_hash(user_id) lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres: if not lookupres:
@ -728,15 +732,18 @@ class AuthHandler(BaseHandler):
device_id) device_id)
defer.returnValue(access_token) defer.returnValue(access_token)
@defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
yield self._check_mau_limits()
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
user_id = None
try: try:
macaroon = pymacaroons.Macaroon.deserialize(login_token) macaroon = pymacaroons.Macaroon.deserialize(login_token)
user_id = auth_api.get_user_id_from_macaroon(macaroon) user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", True, user_id) auth_api.validate_macaroon(macaroon, "login", True, user_id)
return user_id
except Exception: except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
@ -849,14 +856,19 @@ class AuthHandler(BaseHandler):
"""Computes a secure hash of password. """Computes a secure hash of password.
Args: Args:
password (str): Password to hash. password (unicode): Password to hash.
Returns: Returns:
Deferred(str): Hashed password. Deferred(unicode): Hashed password.
""" """
def _do_hash(): def _do_hash():
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, # Normalise the Unicode in the password
bcrypt.gensalt(self.bcrypt_rounds)) pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
).decode('ascii')
return make_deferred_yieldable( return make_deferred_yieldable(
threads.deferToThreadPool( threads.deferToThreadPool(
@ -868,16 +880,19 @@ class AuthHandler(BaseHandler):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
Args: Args:
password (str): Password to hash. password (unicode): Password to hash.
stored_hash (str): Expected hash value. stored_hash (unicode): Expected hash value.
Returns: Returns:
Deferred(bool): Whether self.hash(password) == stored_hash. Deferred(bool): Whether self.hash(password) == stored_hash.
""" """
def _do_validate_hash(): def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw( return bcrypt.checkpw(
password.encode('utf8') + self.hs.config.password_pepper, pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
stored_hash.encode('utf8') stored_hash.encode('utf8')
) )
@ -892,6 +907,19 @@ class AuthHandler(BaseHandler):
else: else:
return defer.succeed(False) return defer.succeed(False)
@defer.inlineCallbacks
def _check_mau_limits(self):
"""
Ensure that if mau blocking is enabled that invalid users cannot
log in.
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.count_monthly_users()
if current_mau >= self.hs.config.max_mau_value:
raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
)
@attr.s @attr.s
class MacaroonGenerator(object): class MacaroonGenerator(object):

View file

@ -19,10 +19,12 @@ import random
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import UserID from synapse.types import UserID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -129,11 +131,13 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler): class EventHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, user, event_id): def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event. """Retrieve a single specified event.
Args: Args:
user (synapse.types.UserID): The user requesting the event user (synapse.types.UserID): The user requesting the event
room_id (str|None): The expected room id. We'll return None if the
event's room does not match.
event_id (str): The event ID to obtain. event_id (str): The event ID to obtain.
Returns: Returns:
dict: An event, or None if there is no event matching this ID. dict: An event, or None if there is no event matching this ID.
@ -142,13 +146,26 @@ class EventHandler(BaseHandler):
AuthError if the user does not have the rights to inspect this AuthError if the user does not have the rights to inspect this
event. event.
""" """
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id, check_room_id=room_id)
if not event: if not event:
defer.returnValue(None) defer.returnValue(None)
return return
if hasattr(event, "room_id"): users = yield self.store.get_users_in_room(event.room_id)
yield self.auth.check_joined_room(event.room_id, user.to_string()) is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
self.store,
user.to_string(),
[event],
is_peeking=is_peeking
)
if not filtered:
raise AuthError(
403,
"You don't have permission to access that event."
)
defer.returnValue(event) defer.returnValue(event)

View file

@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_federation_client() self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -255,7 +255,7 @@ class FederationHandler(BaseHandler):
# know about # know about
for p in prevs - seen: for p in prevs - seen:
state, got_auth_chain = ( state, got_auth_chain = (
yield self.replication_layer.get_state_for_room( yield self.federation_client.get_state_for_room(
origin, pdu.room_id, p origin, pdu.room_id, p
) )
) )
@ -338,7 +338,7 @@ class FederationHandler(BaseHandler):
# #
# see https://github.com/matrix-org/synapse/pull/1744 # see https://github.com/matrix-org/synapse/pull/1744
missing_events = yield self.replication_layer.get_missing_events( missing_events = yield self.federation_client.get_missing_events(
origin, origin,
pdu.room_id, pdu.room_id,
earliest_events_ids=list(latest), earliest_events_ids=list(latest),
@ -400,7 +400,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
event_stream_id, max_stream_id = yield self._persist_auth_tree( yield self._persist_auth_tree(
origin, auth_chain, state, event origin, auth_chain, state, event
) )
except AuthError as e: except AuthError as e:
@ -444,7 +444,7 @@ class FederationHandler(BaseHandler):
yield self._handle_new_events(origin, event_infos) yield self._handle_new_events(origin, event_infos)
try: try:
context, event_stream_id, max_stream_id = yield self._handle_new_event( context = yield self._handle_new_event(
origin, origin,
event, event,
state=state, state=state,
@ -469,17 +469,6 @@ class FederationHandler(BaseHandler):
except StoreError: except StoreError:
logger.exception("Failed to store room.") logger.exception("Failed to store room.")
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has acutally # Only fire user_joined_room if the user has acutally
@ -501,7 +490,7 @@ class FederationHandler(BaseHandler):
if newly_joined: if newly_joined:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield self.user_joined_room(user, event.room_id)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
@ -522,7 +511,7 @@ class FederationHandler(BaseHandler):
if dest == self.server_name: if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.") raise SynapseError(400, "Can't backfill from self.")
events = yield self.replication_layer.backfill( events = yield self.federation_client.backfill(
dest, dest,
room_id, room_id,
limit=limit, limit=limit,
@ -570,7 +559,7 @@ class FederationHandler(BaseHandler):
state_events = {} state_events = {}
events_to_state = {} events_to_state = {}
for e_id in edges: for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room( state, auth = yield self.federation_client.get_state_for_room(
destination=dest, destination=dest,
room_id=room_id, room_id=room_id,
event_id=e_id event_id=e_id
@ -612,7 +601,7 @@ class FederationHandler(BaseHandler):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults( results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
logcontext.run_in_background( logcontext.run_in_background(
self.replication_layer.get_pdu, self.federation_client.get_pdu,
[dest], [dest],
event_id, event_id,
outlier=True, outlier=True,
@ -893,7 +882,7 @@ class FederationHandler(BaseHandler):
Invites must be signed by the invitee's server before distribution. Invites must be signed by the invitee's server before distribution.
""" """
pdu = yield self.replication_layer.send_invite( pdu = yield self.federation_client.send_invite(
destination=target_host, destination=target_host,
room_id=event.room_id, room_id=event.room_id,
event_id=event.event_id, event_id=event.event_id,
@ -942,7 +931,7 @@ class FederationHandler(BaseHandler):
self.room_queues[room_id] = [] self.room_queues[room_id] = []
yield self.store.clean_room_for_join(room_id) yield self._clean_room_for_join(room_id)
handled_events = set() handled_events = set()
@ -955,7 +944,7 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin) target_hosts.insert(0, origin)
except ValueError: except ValueError:
pass pass
ret = yield self.replication_layer.send_join(target_hosts, event) ret = yield self.federation_client.send_join(target_hosts, event)
origin = ret["origin"] origin = ret["origin"]
state = ret["state"] state = ret["state"]
@ -981,15 +970,10 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
event_stream_id, max_stream_id = yield self._persist_auth_tree( yield self._persist_auth_tree(
origin, auth_chain, state, event origin, auth_chain, state, event
) )
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[joinee]
)
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
@ -1084,7 +1068,7 @@ class FederationHandler(BaseHandler):
# would introduce the danger of backwards-compatibility problems. # would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin event.internal_metadata.send_on_behalf_of = origin
context, event_stream_id, max_stream_id = yield self._handle_new_event( context = yield self._handle_new_event(
origin, event origin, event
) )
@ -1094,20 +1078,10 @@ class FederationHandler(BaseHandler):
event.signatures, event.signatures,
) )
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield self.user_joined_room(user, event.room_id)
prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = yield context.get_prev_state_ids(self.store)
@ -1176,17 +1150,7 @@ class FederationHandler(BaseHandler):
) )
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
yield self._persist_events([(event, context)])
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
)
target_user = UserID.from_string(event.state_key)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
defer.returnValue(event) defer.returnValue(event)
@ -1211,30 +1175,20 @@ class FederationHandler(BaseHandler):
except ValueError: except ValueError:
pass pass
yield self.replication_layer.send_leave( yield self.federation_client.send_leave(
target_hosts, target_hosts,
event event
) )
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
yield self._persist_events([(event, context)])
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
)
target_user = UserID.from_string(event.state_key)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={},): content={},):
origin, pdu = yield self.replication_layer.make_membership_event( origin, pdu = yield self.federation_client.make_membership_event(
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
@ -1318,7 +1272,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
context, event_stream_id, max_stream_id = yield self._handle_new_event( yield self._handle_new_event(
origin, event origin, event
) )
@ -1328,22 +1282,17 @@ class FederationHandler(BaseHandler):
event.signatures, event.signatures,
) )
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id): def get_state_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event. """Returns the state at the event. i.e. not including said event.
""" """
event = yield self.store.get_event(
event_id, allow_none=False, check_room_id=room_id,
)
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
room_id, [event_id] room_id, [event_id]
) )
@ -1354,8 +1303,7 @@ class FederationHandler(BaseHandler):
(e.type, e.state_key): e for e in state (e.type, e.state_key): e for e in state
} }
event = yield self.store.get_event(event_id) if event.is_state():
if event and event.is_state():
# Get previous state # Get previous state
if "replaces_state" in event.unsigned: if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"] prev_id = event.unsigned["replaces_state"]
@ -1374,6 +1322,10 @@ class FederationHandler(BaseHandler):
def get_state_ids_for_pdu(self, room_id, event_id): def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event. """Returns the state at the event. i.e. not including said event.
""" """
event = yield self.store.get_event(
event_id, allow_none=False, check_room_id=room_id,
)
state_groups = yield self.store.get_state_groups_ids( state_groups = yield self.store.get_state_groups_ids(
room_id, [event_id] room_id, [event_id]
) )
@ -1382,8 +1334,7 @@ class FederationHandler(BaseHandler):
_, state = state_groups.items().pop() _, state = state_groups.items().pop()
results = state results = state
event = yield self.store.get_event(event_id) if event.is_state():
if event and event.is_state():
# Get previous state # Get previous state
if "replaces_state" in event.unsigned: if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"] prev_id = event.unsigned["replaces_state"]
@ -1472,9 +1423,8 @@ class FederationHandler(BaseHandler):
event, context event, context
) )
event_stream_id, max_stream_id = yield self.store.persist_event( yield self._persist_events(
event, [(event, context)],
context=context,
backfilled=backfilled, backfilled=backfilled,
) )
except: # noqa: E722, as we reraise the exception this is fine. except: # noqa: E722, as we reraise the exception this is fine.
@ -1487,15 +1437,7 @@ class FederationHandler(BaseHandler):
six.reraise(tp, value, tb) six.reraise(tp, value, tb)
if not backfilled: defer.returnValue(context)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
logcontext.run_in_background(
self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False): def _handle_new_events(self, origin, event_infos, backfilled=False):
@ -1503,6 +1445,8 @@ class FederationHandler(BaseHandler):
should not depend on one another, e.g. this should be used to persist should not depend on one another, e.g. this should be used to persist
a bunch of outliers, but not a chunk of individual events that depend a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations. on each other for state calculations.
Notifies about the events where appropriate.
""" """
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
@ -1517,7 +1461,7 @@ class FederationHandler(BaseHandler):
], consumeErrors=True, ], consumeErrors=True,
)) ))
yield self.store.persist_events( yield self._persist_events(
[ [
(ev_info["event"], context) (ev_info["event"], context)
for ev_info, context in zip(event_infos, contexts) for ev_info, context in zip(event_infos, contexts)
@ -1529,7 +1473,8 @@ class FederationHandler(BaseHandler):
def _persist_auth_tree(self, origin, auth_events, state, event): def _persist_auth_tree(self, origin, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the """Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically. state and event. Then persists the auth chain and state atomically.
Persists the event seperately. Persists the event separately. Notifies about the persisted events
where appropriate.
Will attempt to fetch missing auth events. Will attempt to fetch missing auth events.
@ -1540,8 +1485,7 @@ class FederationHandler(BaseHandler):
event (Event) event (Event)
Returns: Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event Deferred
call for `event`
""" """
events_to_context = {} events_to_context = {}
for e in itertools.chain(auth_events, state): for e in itertools.chain(auth_events, state):
@ -1567,7 +1511,7 @@ class FederationHandler(BaseHandler):
missing_auth_events.add(e_id) missing_auth_events.add(e_id)
for e_id in missing_auth_events: for e_id in missing_auth_events:
m_ev = yield self.replication_layer.get_pdu( m_ev = yield self.federation_client.get_pdu(
[origin], [origin],
e_id, e_id,
outlier=True, outlier=True,
@ -1605,7 +1549,7 @@ class FederationHandler(BaseHandler):
raise raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
yield self.store.persist_events( yield self._persist_events(
[ [
(e, events_to_context[e.event_id]) (e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state) for e in itertools.chain(auth_events, state)
@ -1616,12 +1560,10 @@ class FederationHandler(BaseHandler):
event, old_state=state event, old_state=state
) )
event_stream_id, max_stream_id = yield self.store.persist_event( yield self._persist_events(
event, new_event_context, [(event, new_event_context)],
) )
defer.returnValue((event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None): def _prep_event(self, origin, event, state=None, auth_events=None):
""" """
@ -1678,8 +1620,19 @@ class FederationHandler(BaseHandler):
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects,
missing): missing):
in_room = yield self.auth.check_host_in_room(
room_id,
origin
)
if not in_room:
raise AuthError(403, "Host not in room.")
event = yield self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
# Just go through and process each event in `remote_auth_chain`. We # Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong. # don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain: for e in remote_auth_chain:
@ -1689,7 +1642,6 @@ class FederationHandler(BaseHandler):
pass pass
# Now get the current auth_chain for the event. # Now get the current auth_chain for the event.
event = yield self.store.get_event(event_id)
local_auth_chain = yield self.store.get_auth_chain( local_auth_chain = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events], [auth_id for auth_id, _ in event.auth_events],
include_given=True include_given=True
@ -1777,7 +1729,7 @@ class FederationHandler(BaseHandler):
logger.info("Missing auth: %s", missing_auth) logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them. # If we don't have all the auth events, we need to get them.
try: try:
remote_auth_chain = yield self.replication_layer.get_event_auth( remote_auth_chain = yield self.federation_client.get_event_auth(
origin, event.room_id, event.event_id origin, event.room_id, event.event_id
) )
@ -1893,7 +1845,7 @@ class FederationHandler(BaseHandler):
try: try:
# 2. Get remote difference. # 2. Get remote difference.
result = yield self.replication_layer.query_auth( result = yield self.federation_client.query_auth(
origin, origin,
event.room_id, event.room_id,
event.event_id, event.event_id,
@ -2192,7 +2144,7 @@ class FederationHandler(BaseHandler):
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
else: else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
yield self.replication_layer.forward_third_party_invite( yield self.federation_client.forward_third_party_invite(
destinations, destinations,
room_id, room_id,
event_dict, event_dict,
@ -2347,3 +2299,69 @@ class FederationHandler(BaseHandler):
) )
if "valid" not in response or not response["valid"]: if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid") raise AuthError(403, "Third party certificate was invalid")
@defer.inlineCallbacks
def _persist_events(self, event_and_contexts, backfilled=False):
"""Persists events and tells the notifier/pushers about them, if
necessary.
Args:
event_and_contexts(list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether these events are a result of
backfilling or not
Returns:
Deferred
"""
max_stream_id = yield self.store.persist_events(
event_and_contexts,
backfilled=backfilled,
)
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
self._notify_persisted_event(event, max_stream_id)
def _notify_persisted_event(self, event, max_stream_id):
"""Checks to see if notifier/pushers should be notified about the
event or not.
Args:
event (FrozenEvent)
max_stream_id (int): The max_stream_id returned by persist_events
"""
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
# We notify for memberships if its an invite for one of our
# users
if event.internal_metadata.is_outlier():
if event.membership != Membership.INVITE:
if not self.is_mine_id(target_user_id):
return
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
elif event.internal_metadata.is_outlier():
return
event_stream_id = event.internal_metadata.stream_ordering
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
logcontext.run_in_background(
self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id,
)
def _clean_room_for_join(self, room_id):
return self.store.clean_room_for_join(room_id)
def user_joined_room(self, user, room_id):
"""Called when a new user has joined the room
"""
return user_joined_room(self.distributor, user, room_id)

View file

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, CodeMessageException,
Codes, Codes,
MatrixCodeMessageException, HttpResponseException,
SynapseError, SynapseError,
) )
@ -85,7 +85,6 @@ class IdentityHandler(BaseHandler):
) )
defer.returnValue(None) defer.returnValue(None)
data = {}
try: try:
data = yield self.http_client.get_json( data = yield self.http_client.get_json(
"https://%s%s" % ( "https://%s%s" % (
@ -94,11 +93,9 @@ class IdentityHandler(BaseHandler):
), ),
{'sid': creds['sid'], 'client_secret': client_secret} {'sid': creds['sid'], 'client_secret': client_secret}
) )
except MatrixCodeMessageException as e: except HttpResponseException as e:
logger.info("getValidated3pid failed with Matrix error: %r", e) logger.info("getValidated3pid failed with Matrix error: %r", e)
raise SynapseError(e.code, e.msg, e.errcode) raise e.to_synapse_error()
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data: if 'medium' in data:
defer.returnValue(data) defer.returnValue(data)
@ -136,7 +133,7 @@ class IdentityHandler(BaseHandler):
) )
logger.debug("bound threepid %r to %s", creds, mxid) logger.debug("bound threepid %r to %s", creds, mxid)
except CodeMessageException as e: except CodeMessageException as e:
data = json.loads(e.msg) data = json.loads(e.msg) # XXX WAT?
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -209,12 +206,9 @@ class IdentityHandler(BaseHandler):
params params
) )
defer.returnValue(data) defer.returnValue(data)
except MatrixCodeMessageException as e: except HttpResponseException as e:
logger.info("Proxied requestToken failed with Matrix error: %r", e)
raise SynapseError(e.code, e.msg, e.errcode)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e) logger.info("Proxied requestToken failed: %r", e)
raise e raise e.to_synapse_error()
@defer.inlineCallbacks @defer.inlineCallbacks
def requestMsisdnToken( def requestMsisdnToken(
@ -244,9 +238,6 @@ class IdentityHandler(BaseHandler):
params params
) )
defer.returnValue(data) defer.returnValue(data)
except MatrixCodeMessageException as e: except HttpResponseException as e:
logger.info("Proxied requestToken failed with Matrix error: %r", e)
raise SynapseError(e.code, e.msg, e.errcode)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e) logger.info("Proxied requestToken failed: %r", e)
raise e raise e.to_synapse_error()

View file

@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler):
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
super(RegistrationHandler, self).__init__(hs) super(RegistrationHandler, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler):
Args: Args:
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be generated. one will be generated.
password (str) : The password to assign to this user so they can password (unicode) : The password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
generate_token (bool): Whether a new access token should be generate_token (bool): Whether a new access token should be
@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
RegistrationError if there was a problem registering. RegistrationError if there was a problem registering.
""" """
yield self._check_mau_limits()
password_hash = None password_hash = None
if password: if password:
password_hash = yield self.auth_handler().hash(password) password_hash = yield self.auth_handler().hash(password)
@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler):
400, 400,
"User ID can only contain characters a-z, 0-9, or '=_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",
) )
yield self._check_mau_limits()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler):
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")
yield self._check_mau_limits()
need_register = True need_register = True
try: try:
@ -531,3 +533,16 @@ class RegistrationHandler(BaseHandler):
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,
action="join", action="join",
) )
@defer.inlineCallbacks
def _check_mau_limits(self):
"""
Do not accept registrations if monthly active user limits exceeded
and limiting is enabled
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.count_monthly_users()
if current_mau >= self.hs.config.max_mau_value:
raise RegistrationError(
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
)

View file

@ -39,12 +39,7 @@ from twisted.web.client import (
from twisted.web.http import PotentialDataLoss from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from synapse.api.errors import ( from synapse.api.errors import Codes, HttpResponseException, SynapseError
CodeMessageException,
Codes,
MatrixCodeMessageException,
SynapseError,
)
from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http import cancelled_to_request_timed_out_error, redact_uri
from synapse.http.endpoint import SpiderEndpoint from synapse.http.endpoint import SpiderEndpoint
from synapse.util.async import add_timeout_to_deferred from synapse.util.async import add_timeout_to_deferred
@ -132,6 +127,11 @@ class SimpleHttpClient(object):
Returns: Returns:
Deferred[object]: parsed json Deferred[object]: parsed json
Raises:
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
""" """
# TODO: Do we ever want to log message contents? # TODO: Do we ever want to log message contents?
@ -155,7 +155,10 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response)) body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
else:
raise HttpResponseException(response.code, response.phrase, body)
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json_get_json(self, uri, post_json, headers=None): def post_json_get_json(self, uri, post_json, headers=None):
@ -169,6 +172,11 @@ class SimpleHttpClient(object):
Returns: Returns:
Deferred[object]: parsed json Deferred[object]: parsed json
Raises:
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
""" """
json_str = encode_canonical_json(post_json) json_str = encode_canonical_json(post_json)
@ -193,9 +201,7 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
else: else:
raise self._exceptionFromFailedRequest(response, body) raise HttpResponseException(response.code, response.phrase, body)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, uri, args={}, headers=None): def get_json(self, uri, args={}, headers=None):
@ -213,14 +219,12 @@ class SimpleHttpClient(object):
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON. HTTP body as JSON.
Raises: Raises:
On a non-2xx HTTP response. The response body will be used as the HttpResponseException On a non-2xx HTTP response.
error message.
ValueError: if the response was not JSON
""" """
try:
body = yield self.get_raw(uri, args, headers=headers) body = yield self.get_raw(uri, args, headers=headers)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
except CodeMessageException as e:
raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, uri, json_body, args={}, headers=None): def put_json(self, uri, json_body, args={}, headers=None):
@ -239,7 +243,9 @@ class SimpleHttpClient(object):
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON. HTTP body as JSON.
Raises: Raises:
On a non-2xx HTTP response. HttpResponseException On a non-2xx HTTP response.
ValueError: if the response was not JSON
""" """
if len(args): if len(args):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
@ -266,10 +272,7 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
else: else:
# NB: This is explicitly not json.loads(body)'d because the contract raise HttpResponseException(response.code, response.phrase, body)
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_raw(self, uri, args={}, headers=None): def get_raw(self, uri, args={}, headers=None):
@ -287,8 +290,7 @@ class SimpleHttpClient(object):
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text. HTTP body at text.
Raises: Raises:
On a non-2xx HTTP response. The response body will be used as the HttpResponseException on a non-2xx HTTP response.
error message.
""" """
if len(args): if len(args):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
@ -311,16 +313,7 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(body) defer.returnValue(body)
else: else:
raise CodeMessageException(response.code, body) raise HttpResponseException(response.code, response.phrase, body)
def _exceptionFromFailedRequest(self, response, body):
try:
jsonBody = json.loads(body)
errcode = jsonBody['errcode']
error = jsonBody['error']
return MatrixCodeMessageException(response.code, error, errcode)
except (ValueError, KeyError):
return CodeMessageException(response.code, body)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out. # The two should be factored out.

View file

@ -13,12 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import cgi import cgi
import collections import collections
import logging import logging
import urllib
from six.moves import http_client from six import PY3
from six.moves import http_client, urllib
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@ -35,7 +36,6 @@ from synapse.api.errors import (
Codes, Codes,
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
cs_exception,
) )
from synapse.http.request_metrics import requests_counter from synapse.http.request_metrics import requests_counter
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
@ -76,16 +76,13 @@ def wrap_json_request_handler(h):
def wrapped_request_handler(self, request): def wrapped_request_handler(self, request):
try: try:
yield h(self, request) yield h(self, request)
except CodeMessageException as e: except SynapseError as e:
code = e.code code = e.code
if isinstance(e, SynapseError):
logger.info( logger.info(
"%s SynapseError: %s - %s", request, code, e.msg "%s SynapseError: %s - %s", request, code, e.msg
) )
else:
logger.exception(e)
respond_with_json( respond_with_json(
request, code, cs_exception(e), send_cors=True, request, code, e.error_dict(), send_cors=True,
pretty_print=_request_user_agent_is_curl(request), pretty_print=_request_user_agent_is_curl(request),
) )
@ -264,6 +261,7 @@ class JsonResource(HttpServer, resource.Resource):
self.hs = hs self.hs = hs
def register_paths(self, method, path_patterns, callback): def register_paths(self, method, path_patterns, callback):
method = method.encode("utf-8") # method is bytes on py3
for path_pattern in path_patterns: for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern) logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
@ -296,8 +294,19 @@ class JsonResource(HttpServer, resource.Resource):
# here. If it throws an exception, that is handled by the wrapper # here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler. # installed by @request_handler.
def _unquote(s):
if PY3:
# On Python 3, unquote is unicode -> unicode
return urllib.parse.unquote(s)
else:
# On Python 2, unquote is bytes -> bytes We need to encode the
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
kwargs = intern_dict({ kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value name: _unquote(value) if value else value
for name, value in group_dict.items() for name, value in group_dict.items()
}) })
@ -313,9 +322,9 @@ class JsonResource(HttpServer, resource.Resource):
request (twisted.web.http.Request): request (twisted.web.http.Request):
Returns: Returns:
Tuple[Callable, dict[str, str]]: callback method, and the dict Tuple[Callable, dict[unicode, unicode]]: callback method, and the
mapping keys to path components as specified in the handler's dict mapping keys to path components as specified in the
path match regexp. handler's path match regexp.
The callback will normally be a method registered via The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either register_paths, so will return (possibly via Deferred) either
@ -327,7 +336,7 @@ class JsonResource(HttpServer, resource.Resource):
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path) m = path_entry.pattern.match(request.path.decode('ascii'))
if m: if m:
# We found a match! # We found a match!
return path_entry.callback, m.groupdict() return path_entry.callback, m.groupdict()
@ -383,7 +392,7 @@ class RootRedirect(resource.Resource):
self.url = path self.url = path
def render_GET(self, request): def render_GET(self, request):
return redirectTo(self.url, request) return redirectTo(self.url.encode('ascii'), request)
def getChild(self, name, request): def getChild(self, name, request):
if len(name) == 0: if len(name) == 0:
@ -404,12 +413,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
return return
if pretty_print: if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n" json_bytes = (encode_pretty_printed_json(json_object) + "\n"
).encode("utf-8")
else: else:
if canonical_json or synapse.events.USE_FROZEN_DICTS: if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes
json_bytes = encode_canonical_json(json_object) json_bytes = encode_canonical_json(json_object)
else: else:
json_bytes = json.dumps(json_object) json_bytes = json.dumps(json_object).encode("utf-8")
return respond_with_json_bytes( return respond_with_json_bytes(
request, code, json_bytes, request, code, json_bytes,

View file

@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
if not content_bytes and allow_empty_body: if not content_bytes and allow_empty_body:
return None return None
# Decode to Unicode so that simplejson will return Unicode strings on
# Python 2
try: try:
content = json.loads(content_bytes) content_unicode = content_bytes.decode('utf8')
except UnicodeDecodeError:
logger.warn("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
try:
content = json.loads(content_unicode)
except Exception as e: except Exception as e:
logger.warn("Unable to parse JSON: %s", e) logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View file

@ -23,8 +23,7 @@ from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, CodeMessageException,
MatrixCodeMessageException, HttpResponseException,
SynapseError,
) )
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -160,11 +159,11 @@ class ReplicationEndpoint(object):
# If we timed out we probably don't need to worry about backing # If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway. # off too much, but lets just wait a little anyway.
yield clock.sleep(1) yield clock.sleep(1)
except MatrixCodeMessageException as e: except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError # We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And # on the master process that we should send to the client. (And
# importantly, not stack traces everywhere) # importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode) raise e.to_synapse_error()
defer.returnValue(result) defer.returnValue(result)

View file

@ -18,6 +18,7 @@ import hashlib
import hmac import hmac
import logging import logging
from six import text_type
from six.moves import http_client from six.moves import http_client
from twisted.internet import defer from twisted.internet import defer
@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet):
400, "username must be specified", errcode=Codes.BAD_JSON, 400, "username must be specified", errcode=Codes.BAD_JSON,
) )
else: else:
if (not isinstance(body['username'], str) or len(body['username']) > 512): if (
not isinstance(body['username'], text_type)
or len(body['username']) > 512
):
raise SynapseError(400, "Invalid username") raise SynapseError(400, "Invalid username")
username = body["username"].encode("utf-8") username = body["username"].encode("utf-8")
@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet):
400, "password must be specified", errcode=Codes.BAD_JSON, 400, "password must be specified", errcode=Codes.BAD_JSON,
) )
else: else:
if (not isinstance(body['password'], str) or len(body['password']) > 512): if (
not isinstance(body['password'], text_type)
or len(body['password']) > 512
):
raise SynapseError(400, "Invalid password") raise SynapseError(400, "Invalid password")
password = body["password"].encode("utf-8") password = body["password"].encode("utf-8")
@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet):
want_mac.update(b"admin" if admin else b"notadmin") want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest() want_mac = want_mac.hexdigest()
if not hmac.compare_digest(want_mac, got_mac): if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
raise SynapseError( raise SynapseError(403, "HMAC incorrect")
403, "HMAC incorrect",
)
# Reuse the parts of RegisterRestServlet to reduce code duplication # Reuse the parts of RegisterRestServlet to reduce code duplication
from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet
register = RegisterRestServlet(self.hs) register = RegisterRestServlet(self.hs)
(user_id, _) = yield register.registration_handler.register( (user_id, _) = yield register.registration_handler.register(
localpart=username.lower(), password=password, admin=bool(admin), localpart=body['username'].lower(),
password=body["password"],
admin=bool(admin),
generate_token=False, generate_token=False,
) )

View file

@ -18,7 +18,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.types import RoomAlias from synapse.types import RoomAlias
@ -159,7 +159,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
room = yield self.store.get_room(room_id) room = yield self.store.get_room(room_id)
if room is None: if room is None:
raise SynapseError(400, "Unknown room") raise NotFoundError("Unknown room")
defer.returnValue((200, { defer.returnValue((200, {
"visibility": "public" if room["is_public"] else "private" "visibility": "public" if room["is_public"] else "private"

View file

@ -88,7 +88,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
event = yield self.event_handler.get_event(requester.user, event_id) event = yield self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:

View file

@ -506,7 +506,7 @@ class RoomEventServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
event = yield self.event_handler.get_event(requester.user, event_id) event = yield self.event_handler.get_event(requester.user, room_id, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:

View file

@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
kind = "user" kind = b"user"
if "kind" in request.args: if b"kind" in request.args:
kind = request.args["kind"][0] kind = request.args[b"kind"][0]
if kind == "guest": if kind == b"guest":
ret = yield self._do_guest_registration(body) ret = yield self._do_guest_registration(body)
defer.returnValue(ret) defer.returnValue(ret)
return return
elif kind != "user": elif kind != b"user":
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,) "Do not understand membership kind: %s" % (kind,)
) )
@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet):
assert_params_in_dict(params, ["password"]) assert_params_in_dict(params, ["password"])
desired_username = params.get("username", None) desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None) guest_access_token = params.get("guest_access_token", None)
new_password = params.get("password", None)
if desired_username is not None: if desired_username is not None:
desired_username = desired_username.lower() desired_username = desired_username.lower()

View file

@ -379,7 +379,7 @@ class MediaRepository(object):
logger.warn("HTTP error fetching remote media %s/%s: %s", logger.warn("HTTP error fetching remote media %s/%s: %s",
server_name, media_id, e.response) server_name, media_id, e.response)
if e.code == twisted.web.http.NOT_FOUND: if e.code == twisted.web.http.NOT_FOUND:
raise SynapseError.from_http_response_exception(e) raise e.to_synapse_error()
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
except SynapseError: except SynapseError:

View file

@ -177,7 +177,7 @@ class MediaStorage(object):
if res: if res:
with res: with res:
consumer = BackgroundFileConsumer( consumer = BackgroundFileConsumer(
open(local_path, "w"), self.hs.get_reactor()) open(local_path, "wb"), self.hs.get_reactor())
yield res.write_to_consumer(consumer) yield res.write_to_consumer(consumer)
yield consumer.wait() yield consumer.wait()
defer.returnValue(local_path) defer.returnValue(local_path)

View file

@ -577,7 +577,7 @@ def _make_state_cache_entry(
def _ordered_events(events): def _ordered_events(events):
def key_func(e): def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest() return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
return sorted(events, key=key_func) return sorted(events, key=key_func)

View file

@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore,
PresenceStore, TransactionStore, PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore, DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore, ApplicationServiceStore,
EventsStore,
EventFederationStore, EventFederationStore,
MediaRepositoryStore, MediaRepositoryStore,
RejectionsStore, RejectionsStore,
@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore,
PusherStore, PusherStore,
PushRuleStore, PushRuleStore,
ApplicationServiceTransactionStore, ApplicationServiceTransactionStore,
EventsStore,
ReceiptsStore, ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore, SearchStore,
@ -94,6 +94,7 @@ class DataStore(RoomMemberStore, RoomStore,
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self.db_conn = db_conn
self._stream_id_gen = StreamIdGenerator( self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")] extra_tables=[("local_invites", "stream_id")]
@ -266,6 +267,31 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_users", _count_users) return self.runInteraction("count_users", _count_users)
def count_monthly_users(self):
"""Counts the number of users who used this homeserver in the last 30 days
This method should be refactored with count_daily_users - the only
reason not to is waiting on definition of mau
Returns:
Defered[int]
"""
def _count_monthly_users(txn):
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
) u
"""
txn.execute(sql, (thirty_days_ago,))
count, = txn.fetchone()
return count
return self.runInteraction("count_monthly_users", _count_monthly_users)
def count_r30_users(self): def count_r30_users(self):
""" """
Counts the number of 30 day retained users, defined as:- Counts the number of 30 day retained users, defined as:-

View file

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.storage.events import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from ._base import SQLBaseStore from ._base import SQLBaseStore

View file

@ -25,7 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.events import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore from synapse.storage.signatures import SignatureWorkerStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -343,6 +343,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
table="events", table="events",
keyvalues={ keyvalues={
"event_id": event_id, "event_id": event_id,
"room_id": room_id,
}, },
retcol="depth", retcol="depth",
allow_none=True, allow_none=True,

View file

@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
@ -65,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
def encode_json(json_object): def encode_json(json_object):
return frozendict_json_encoder.encode(json_object) """
Encode a Python object as JSON and return it in a Unicode string.
"""
out = frozendict_json_encoder.encode(json_object)
if isinstance(out, bytes):
out = out.decode('utf8')
return out
class _EventPeristenceQueue(object): class _EventPeristenceQueue(object):
@ -193,7 +201,9 @@ def _retry_on_integrity_error(func):
return f return f
class EventsStore(EventsWorkerStore): # inherits from EventFederationStore so that we can call _update_backward_extremities
# and _handle_mult_prev_events (though arguably those could both be moved in here)
class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
@ -231,12 +241,18 @@ class EventsStore(EventsWorkerStore):
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False): def persist_events(self, events_and_contexts, backfilled=False):
""" """
Write events to the database Write events to the database
Args: Args:
events_and_contexts: list of tuples of (event, context) events_and_contexts: list of tuples of (event, context)
backfilled: ? backfilled (bool): Whether the results are retrieved from federation
via backfill or not. Used to determine if they're "new" events
which might update the current state etc.
Returns:
Deferred[int]: the stream ordering of the latest persisted event
""" """
partitioned = {} partitioned = {}
for event, ctx in events_and_contexts: for event, ctx in events_and_contexts:
@ -253,10 +269,14 @@ class EventsStore(EventsWorkerStore):
for room_id in partitioned: for room_id in partitioned:
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
return make_deferred_yieldable( yield make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue(max_persisted_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, backfilled=False): def persist_event(self, event, context, backfilled=False):
@ -1054,7 +1074,7 @@ class EventsStore(EventsWorkerStore):
metadata_json = encode_json( metadata_json = encode_json(
event.internal_metadata.get_dict() event.internal_metadata.get_dict()
).decode("UTF-8") )
sql = ( sql = (
"UPDATE event_json SET internal_metadata = ?" "UPDATE event_json SET internal_metadata = ?"
@ -1168,8 +1188,8 @@ class EventsStore(EventsWorkerStore):
"room_id": event.room_id, "room_id": event.room_id,
"internal_metadata": encode_json( "internal_metadata": encode_json(
event.internal_metadata.get_dict() event.internal_metadata.get_dict()
).decode("UTF-8"), ),
"json": encode_json(event_dict(event)).decode("UTF-8"), "json": encode_json(event_dict(event)),
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
], ],

View file

@ -19,7 +19,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import NotFoundError
# these are only included to make the type annotations work # these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase # noqa: F401
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -77,7 +77,7 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False, get_prev_content=False, allow_rejected=False,
allow_none=False): allow_none=False, check_room_id=None):
"""Get an event from the database by event_id. """Get an event from the database by event_id.
Args: Args:
@ -88,7 +88,9 @@ class EventsWorkerStore(SQLBaseStore):
include the previous states content in the unsigned field. include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events. allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if allow_none (bool): If True, return None if no event found, if
False throw an exception. False throw a NotFoundError
check_room_id (str|None): if not None, check the room of the found event.
If there is a mismatch, behave as per allow_none.
Returns: Returns:
Deferred : A FrozenEvent. Deferred : A FrozenEvent.
@ -100,10 +102,16 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected=allow_rejected, allow_rejected=allow_rejected,
) )
if not events and not allow_none: event = events[0] if events else None
raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None) if event is not None and check_room_id is not None:
if event.room_id != check_room_id:
event = None
if event is None and not allow_none:
raise NotFoundError("Could not find event %s" % (event_id,))
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True, def get_events(self, event_ids, check_redacted=True,

View file

@ -24,7 +24,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.storage.events import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string

View file

@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
txn (cursor): txn (cursor):
event_id (str): Id for the Event. event_id (str): Id for the Event.
Returns: Returns:
A dict of algorithm -> hash. A dict[unicode, bytes] of algorithm -> hash.
""" """
query = ( query = (
"SELECT algorithm, hash" "SELECT algorithm, hash"

View file

@ -43,7 +43,7 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.events import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background

View file

@ -137,7 +137,7 @@ class DomainSpecificString(
@classmethod @classmethod
def from_string(cls, s): def from_string(cls, s):
"""Parse the string given by 's' into a structure object.""" """Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0] != cls.SIGIL: if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(400, "Expected %s string to start with '%s'" % ( raise SynapseError(400, "Expected %s string to start with '%s'" % (
cls.__name__, cls.SIGIL, cls.__name__, cls.SIGIL,
)) ))

View file

@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its
# whenever we are invalidated # invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name] list_args = arg_dict[self.list_name]
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
results = {} results = {}
cached_defers = {}
missing = [] def update_results_dict(res, arg):
results[arg] = res
# list of deferreds to wait for
cached_defers = []
missing = set()
# If the cache takes a single arg then that is used as the key, # If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used. # otherwise a tuple is used.
if num_args == 1: if num_args == 1:
def cache_get(arg): def arg_to_cache_key(arg):
return cache.get(arg, callback=invalidate_callback) return arg
else: else:
key = list(keyargs) keylist = list(keyargs)
def cache_get(arg): def arg_to_cache_key(arg):
key[self.list_pos] = arg keylist[self.list_pos] = arg
return cache.get(tuple(key), callback=invalidate_callback) return tuple(keylist)
for arg in list_args: for arg in list_args:
try: try:
res = cache_get(arg) res = cache.get(arg_to_cache_key(arg),
callback=invalidate_callback)
if not isinstance(res, ObservableDeferred): if not isinstance(res, ObservableDeferred):
results[arg] = res results[arg] = res
elif not res.has_succeeded(): elif not res.has_succeeded():
res = res.observe() res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg) res.addCallback(update_results_dict, arg)
cached_defers[arg] = res cached_defers.append(res)
else: else:
results[arg] = res.get_result() results[arg] = res.get_result()
except KeyError: except KeyError:
missing.append(arg) missing.add(arg)
if missing: if missing:
args_to_call = dict(arg_dict) # we need an observable deferred for each entry in the list,
args_to_call[self.list_name] = missing # which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
for arg in missing:
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
cache.set(key, observable, callback=invalidate_callback)
ret_d = defer.maybeDeferred( def complete_all(res):
# the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in
# the cache and update our own result map.
for e in missing:
val = res.get(e, None)
deferreds_map[e].callback(val)
results[e] = val
def errback(f):
# the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail
# their deferreds.
for e in missing:
key = arg_to_cache_key(e)
cache.invalidate(key)
deferreds_map[e].errback(f)
# return the failure, to propagate to our caller.
return f
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = list(missing)
cached_defers.append(defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call), logcontext.preserve_fn(self.function_to_call),
**args_to_call **args_to_call
) ).addCallbacks(complete_all, errback))
ret_d = ObservableDeferred(ret_d)
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
if num_args == 1:
cache.set(
arg, observer,
callback=invalidate_callback
)
def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, arg)
else:
key = list(keyargs)
key[self.list_pos] = arg
cache.set(
tuple(key), observer,
callback=invalidate_callback
)
def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached_defers[arg] = res
if cached_defers: if cached_defers:
def update_results_dict(res): d = defer.gatherResults(
results.update(res) cached_defers,
return results
return logcontext.make_deferred_yieldable(defer.gatherResults(
list(cached_defers.values()),
consumeErrors=True, consumeErrors=True,
).addCallback(update_results_dict).addErrback( ).addCallbacks(
lambda _: results,
unwrapFirstError unwrapFirstError
)) )
return logcontext.make_deferred_yieldable(d)
else: else:
return results return results
@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
cache. cache.
Args: Args:
cache (Cache): The underlying cache to use. cached_method_name (str): The name of the single-item lookup method.
This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache. do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache num_args (int): Number of arguments to use as the key in the cache

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from six import string_types from six import binary_type, text_type
from canonicaljson import json from canonicaljson import json
from frozendict import frozendict from frozendict import frozendict
@ -26,7 +26,7 @@ def freeze(o):
if isinstance(o, frozendict): if isinstance(o, frozendict):
return o return o
if isinstance(o, string_types): if isinstance(o, (binary_type, text_type)):
return o return o
try: try:
@ -41,7 +41,7 @@ def unfreeze(o):
if isinstance(o, (dict, frozendict)): if isinstance(o, (dict, frozendict)):
return dict({k: unfreeze(v) for k, v in o.items()}) return dict({k: unfreeze(v) for k, v in o.items()})
if isinstance(o, string_types): if isinstance(o, (binary_type, text_type)):
return o return o
try: try:

View file

@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
self.auth = Auth(self.hs) self.auth = Auth(self.hs)
self.test_user = "@foo:bar" self.test_user = "@foo:bar"
self.test_token = "_test_token_" self.test_token = b"_test_token_"
# this is overridden for the appservice tests # this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10" request.getClientIP.return_value = "192.168.10.10"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42" request.getClientIP.return_value = "131.111.8.42"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self): def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None, ip_range_whitelist=None,
@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), masquerading_user_id) self.assertEquals(
requester.user.to_string(),
masquerading_user_id.decode('utf8')
)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None, ip_range_whitelist=None,
@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
# check the token works # check the token works
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [token] request.args[b"access_token"] = [token.encode('ascii')]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertEqual(UserID.from_string(USER_ID), requester.user)
@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
# the token should *not* work now # the token should *not* work now
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [guest_tok] request.args[b"access_token"] = [guest_tok.encode('ascii')]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(AuthError) as cm: with self.assertRaises(AuthError) as cm:

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock
import pymacaroons import pymacaroons
@ -19,6 +20,7 @@ from twisted.internet import defer
import synapse import synapse
import synapse.api.errors import synapse.api.errors
from synapse.api.errors import AuthError
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from tests import unittest from tests import unittest
@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = AuthHandlers(self.hs) self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator() self.macaroon_generator = self.hs.get_macaroon_generator()
# MAU tests
self.hs.config.max_mau_value = 50
self.small_number_of_users = 1
self.large_number_of_users = 100
def test_token_is_a_macaroon(self): def test_token_is_a_macaroon(self):
token = self.macaroon_generator.generate_access_token("some_user") token = self.macaroon_generator.generate_access_token("some_user")
@ -71,38 +77,37 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_nonce) v.satisfy_general(verify_nonce)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000 self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.assertEqual(
"a_user",
self.auth_handler.validate_short_term_login_token_and_get_user_id(
token token
) )
) self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected # when we advance the clock, the token should be rejected
self.hs.clock.now = 6000 self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError): with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
token token
) )
@defer.inlineCallbacks
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
self.assertEqual( user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
"a_user",
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize() macaroon.serialize()
) )
self.assertEqual(
"a_user", user_id
) )
# add another "user_id" caveat, which might allow us to override the # add another "user_id" caveat, which might allow us to override the
@ -110,6 +115,57 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user") macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError): with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize() macaroon.serialize()
) )
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
yield self.auth_handler.get_access_token_for_user_id('user_a')
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
@defer.inlineCallbacks
def test_mau_limits_exceeded(self):
self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().count_monthly_users = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(AuthError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().count_monthly_users = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token(
"user_a", 5000
)
return pymacaroons.Macaroon.deserialize(token)

View file

@ -17,6 +17,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import RegistrationError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -77,3 +78,53 @@ class RegistrationTestCase(unittest.TestCase):
requester, local_part, display_name) requester, local_part, display_name)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@defer.inlineCallbacks
def test_cannot_register_when_mau_limits_exceeded(self):
local_part = "someone"
display_name = "someone"
requester = create_requester("@as:test")
store = self.hs.get_datastore()
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
lots_of_users = 100
small_number_users = 1
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
# Ensure does not throw exception
yield self.handler.get_or_create_user(requester, 'a', display_name)
self.hs.config.limit_usage_by_mau = True
with self.assertRaises(RegistrationError):
yield self.handler.get_or_create_user(requester, 'b', display_name)
store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
self._macaroon_mock_generator("another_secret")
# Ensure does not throw exception
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
with self.assertRaises(RegistrationError):
yield self.handler.register(localpart=local_part)
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
with self.assertRaises(RegistrationError):
yield self.handler.register_saml2(local_part)
def _macaroon_mock_generator(self, secret):
"""
Reset macaroon generator in the case where the test creates multiple users
"""
macaroon_generator = Mock(
generate_access_token=Mock(return_value=secret))
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler

View file

@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
"content": content, "content": content,
} }
], ],
"pdu_failures": [],
} }

View file

@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.utils
class InitTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(InitTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
hs.config.max_mau_value = 50
hs.config.limit_usage_by_mau = True
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_count_monthly_users(self):
count = yield self.store.count_monthly_users()
self.assertEqual(0, count)
yield self._insert_user_ips("@user:server1")
yield self._insert_user_ips("@user:server2")
count = yield self.store.count_monthly_users()
self.assertEqual(2, count)
@defer.inlineCallbacks
def _insert_user_ips(self, user):
"""
Helper function to populate user_ips without using batch insertion infra
args:
user (str): specify username i.e. @user:server.com
"""
yield self.store._simple_upsert(
table="user_ips",
keyvalues={
"user_id": user,
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": "device_id",
},
values={
"last_seen": self.clock.time_msec(),
}
)

View file

@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
r = yield obj.fn(2, 3) r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips') self.assertEqual(r, 'chips')
obj.mock.assert_not_called() obj.mock.assert_not_called()
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached()
def fn(self, arg1, arg2):
pass
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
assert (
logcontext.LoggingContext.current_context().request == "c1"
)
# we want this to behave like an asynchronous function
yield run_on_reactor()
assert (
logcontext.LoggingContext.current_context().request == "c1"
)
defer.returnValue(self.mock(args1, arg2))
with logcontext.LoggingContext() as c1:
c1.request = "c1"
obj = Cls()
obj.mock.return_value = {10: 'fish', 20: 'chips'}
d1 = obj.list_fn([10, 20], 2)
self.assertEqual(
logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel,
)
r = yield d1
self.assertEqual(
logcontext.LoggingContext.current_context(),
c1
)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: 'fish', 20: 'chips'})
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = {30: 'peas'}
r = yield obj.list_fn([20, 30], 2)
obj.mock.assert_called_once_with([30], 2)
self.assertEqual(r, {20: 'chips', 30: 'peas'})
obj.mock.reset_mock()
# all the values should now be cached
r = yield obj.fn(10, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(20, 2)
self.assertEqual(r, 'chips')
r = yield obj.fn(30, 2)
self.assertEqual(r, 'peas')
r = yield obj.list_fn([10, 20, 30], 2)
obj.mock.assert_not_called()
self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
@defer.inlineCallbacks
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached()
def fn(self, arg1, arg2):
pass
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
yield run_on_reactor()
defer.returnValue(self.mock(args1, arg2))
obj = Cls()
invalidate0 = mock.Mock()
invalidate1 = mock.Mock()
# cache miss
obj.mock.return_value = {10: 'fish', 20: 'chips'}
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r1, {10: 'fish', 20: 'chips'})
obj.mock.reset_mock()
# cache hit
r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
obj.mock.assert_not_called()
self.assertEqual(r2, {10: 'fish', 20: 'chips'})
invalidate0.assert_not_called()
invalidate1.assert_not_called()
# now if we invalidate the keys, both invalidations should get called
obj.fn.invalidate((10, 2))
invalidate0.assert_called_once()
invalidate1.assert_called_once()

View file

@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
self.prefix = prefix self.prefix = prefix
def trigger_get(self, path): def trigger_get(self, path):
return self.trigger("GET", path, None) return self.trigger(b"GET", path, None)
@patch('twisted.web.http.Request') @patch('twisted.web.http.Request')
@defer.inlineCallbacks @defer.inlineCallbacks
@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
headers = {} headers = {}
if federation_auth: if federation_auth:
headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="] headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it # return the right path if the event requires it
@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
except Exception: except Exception:
pass pass
if isinstance(path, bytes):
path = path.decode('utf8')
for (method, pattern, func) in self.callbacks: for (method, pattern, func) in self.callbacks:
if http_method != method: if http_method != method:
continue continue
@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
if matcher: if matcher:
try: try:
args = [ args = [
urlparse.unquote(u).decode("UTF-8") urlparse.unquote(u)
for u in matcher.groups() for u in matcher.groups()
] ]