forked from MirrorHub/synapse
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/refactor_repl_servlet
This commit is contained in:
commit
cb298ff623
67 changed files with 1074 additions and 648 deletions
5
.github/ISSUE_TEMPLATE.md
vendored
5
.github/ISSUE_TEMPLATE.md
vendored
|
@ -27,8 +27,9 @@ Describe here the problem that you are experiencing, or the feature you are requ
|
|||
|
||||
Describe how what happens differs from what you expected.
|
||||
|
||||
If you can identify any relevant log snippets from _homeserver.log_, please include
|
||||
those here (please be careful to remove any personal or private data):
|
||||
<!-- If you can identify any relevant log snippets from _homeserver.log_, please include
|
||||
those (please be careful to remove any personal or private data). Please surround them with
|
||||
``` (three backticks, on a line on their own), so that they are formatted legibly. -->
|
||||
|
||||
### Version information
|
||||
|
||||
|
|
|
@ -62,4 +62,7 @@ Christoph Witzany <christoph at web.crofting.com>
|
|||
* Add LDAP support for authentication
|
||||
|
||||
Pierre Jaury <pierre at jaury.eu>
|
||||
* Docker packaging
|
||||
* Docker packaging
|
||||
|
||||
Serban Constantin <serban.constantin at gmail dot com>
|
||||
* Small bug fix
|
|
@ -1,3 +1,12 @@
|
|||
Synapse 0.33.1 (2018-08-02)
|
||||
===========================
|
||||
|
||||
SECURITY FIXES
|
||||
--------------
|
||||
|
||||
- Fix a potential issue where servers could request events for rooms they have not joined. ([\#3641](https://github.com/matrix-org/synapse/issues/3641))
|
||||
- Fix a potential issue where users could see events in private rooms before they joined. ([\#3642](https://github.com/matrix-org/synapse/issues/3642))
|
||||
|
||||
Synapse 0.33.0 (2018-07-19)
|
||||
===========================
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ makes it horribly hard to review otherwise.
|
|||
Changelog
|
||||
~~~~~~~~~
|
||||
|
||||
All changes, even minor ones, need a corresponding changelog
|
||||
All changes, even minor ones, need a corresponding changelog / newsfragment
|
||||
entry. These are managed by Towncrier
|
||||
(https://github.com/hawkowl/towncrier).
|
||||
|
||||
|
|
22
Dockerfile
22
Dockerfile
|
@ -1,16 +1,32 @@
|
|||
FROM docker.io/python:2-alpine3.7
|
||||
|
||||
RUN apk add --no-cache --virtual .nacl_deps su-exec build-base libffi-dev zlib-dev libressl-dev libjpeg-turbo-dev linux-headers postgresql-dev libxslt-dev
|
||||
RUN apk add --no-cache --virtual .nacl_deps \
|
||||
build-base \
|
||||
libffi-dev \
|
||||
libjpeg-turbo-dev \
|
||||
libressl-dev \
|
||||
libxslt-dev \
|
||||
linux-headers \
|
||||
postgresql-dev \
|
||||
su-exec \
|
||||
zlib-dev
|
||||
|
||||
COPY . /synapse
|
||||
|
||||
# A wheel cache may be provided in ./cache for faster build
|
||||
RUN cd /synapse \
|
||||
&& pip install --upgrade pip setuptools psycopg2 lxml \
|
||||
&& pip install --upgrade \
|
||||
lxml \
|
||||
pip \
|
||||
psycopg2 \
|
||||
setuptools \
|
||||
&& mkdir -p /synapse/cache \
|
||||
&& pip install -f /synapse/cache --upgrade --process-dependency-links . \
|
||||
&& mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
|
||||
&& rm -rf setup.py setup.cfg synapse
|
||||
&& rm -rf \
|
||||
setup.cfg \
|
||||
setup.py \
|
||||
synapse
|
||||
|
||||
VOLUME ["/data"]
|
||||
|
||||
|
|
1
changelog.d/2952.bugfix
Normal file
1
changelog.d/2952.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Make /directory/list API return 404 for room not found instead of 400
|
1
changelog.d/3384.misc
Normal file
1
changelog.d/3384.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Rewrite cache list decorator
|
1
changelog.d/3543.misc
Normal file
1
changelog.d/3543.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Improve Dockerfile and docker-compose instructions
|
1
changelog.d/3569.bugfix
Normal file
1
changelog.d/3569.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password.
|
1
changelog.d/3612.misc
Normal file
1
changelog.d/3612.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Make EventStore inherit from EventFederationStore
|
1
changelog.d/3621.misc
Normal file
1
changelog.d/3621.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactor FederationHandler to move DB writes into separate functions
|
1
changelog.d/3628.misc
Normal file
1
changelog.d/3628.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Remove unused field "pdu_failures" from transactions.
|
1
changelog.d/3630.feature
Normal file
1
changelog.d/3630.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add ability to limit number of monthly active users on the server
|
1
changelog.d/3634.misc
Normal file
1
changelog.d/3634.misc
Normal file
|
@ -0,0 +1 @@
|
|||
rename replication_layer to federation_client
|
1
changelog.d/3638.misc
Normal file
1
changelog.d/3638.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Factor out exception handling in federation_client
|
1
changelog.d/3639.feature
Normal file
1
changelog.d/3639.feature
Normal file
|
@ -0,0 +1 @@
|
|||
When we fail to join a room over federation, pass the error code back to the client.
|
1
changelog.d/3645.misc
Normal file
1
changelog.d/3645.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Update CONTRIBUTING to mention newsfragments.
|
|
@ -9,13 +9,7 @@ use that server.
|
|||
|
||||
## Build
|
||||
|
||||
Build the docker image with the `docker build` command from the root of the synapse repository.
|
||||
|
||||
```
|
||||
docker build -t docker.io/matrixdotorg/synapse .
|
||||
```
|
||||
|
||||
The `-t` option sets the image tag. Official images are tagged `matrixdotorg/synapse:<version>` where `<version>` is the same as the release tag in the synapse git repository.
|
||||
Build the docker image with the `docker-compose build` command.
|
||||
|
||||
You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ version: '3'
|
|||
services:
|
||||
|
||||
synapse:
|
||||
build: ../..
|
||||
image: docker.io/matrixdotorg/synapse:latest
|
||||
# Since snyapse does not retry to connect to the database, restart upon
|
||||
# failure
|
||||
|
|
6
contrib/grafana/README.md
Normal file
6
contrib/grafana/README.md
Normal file
|
@ -0,0 +1,6 @@
|
|||
# Using the Synapse Grafana dashboard
|
||||
|
||||
0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
|
||||
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst
|
||||
2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
|
||||
3. Set up additional recording rules
|
|
@ -17,4 +17,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.33.0"
|
||||
__version__ = "0.33.1"
|
||||
|
|
|
@ -252,10 +252,10 @@ class Auth(object):
|
|||
if ip_address not in app_service.ip_range_whitelist:
|
||||
defer.returnValue((None, None))
|
||||
|
||||
if "user_id" not in request.args:
|
||||
if b"user_id" not in request.args:
|
||||
defer.returnValue((app_service.sender, app_service))
|
||||
|
||||
user_id = request.args["user_id"][0]
|
||||
user_id = request.args[b"user_id"][0].decode('utf8')
|
||||
if app_service.sender == user_id:
|
||||
defer.returnValue((app_service.sender, app_service))
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ class Codes(object):
|
|||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
|
||||
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
|
||||
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
|
@ -69,20 +70,6 @@ class CodeMessageException(RuntimeError):
|
|||
self.code = code
|
||||
self.msg = msg
|
||||
|
||||
def error_dict(self):
|
||||
return cs_error(self.msg)
|
||||
|
||||
|
||||
class MatrixCodeMessageException(CodeMessageException):
|
||||
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
|
||||
|
||||
Attributes:
|
||||
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
|
||||
"""
|
||||
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
|
||||
super(MatrixCodeMessageException, self).__init__(code, msg)
|
||||
self.errcode = errcode
|
||||
|
||||
|
||||
class SynapseError(CodeMessageException):
|
||||
"""A base exception type for matrix errors which have an errcode and error
|
||||
|
@ -108,38 +95,28 @@ class SynapseError(CodeMessageException):
|
|||
self.errcode,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_http_response_exception(cls, err):
|
||||
"""Make a SynapseError based on an HTTPResponseException
|
||||
|
||||
This is useful when a proxied request has failed, and we need to
|
||||
decide how to map the failure onto a matrix error to send back to the
|
||||
client.
|
||||
class ProxiedRequestError(SynapseError):
|
||||
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
|
||||
|
||||
An attempt is made to parse the body of the http response as a matrix
|
||||
error. If that succeeds, the errcode and error message from the body
|
||||
are used as the errcode and error message in the new synapse error.
|
||||
Attributes:
|
||||
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
|
||||
"""
|
||||
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
|
||||
super(ProxiedRequestError, self).__init__(
|
||||
code, msg, errcode
|
||||
)
|
||||
if additional_fields is None:
|
||||
self._additional_fields = {}
|
||||
else:
|
||||
self._additional_fields = dict(additional_fields)
|
||||
|
||||
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
|
||||
set to the reason code from the HTTP response.
|
||||
|
||||
Args:
|
||||
err (HttpResponseException):
|
||||
|
||||
Returns:
|
||||
SynapseError:
|
||||
"""
|
||||
# try to parse the body as json, to get better errcode/msg, but
|
||||
# default to M_UNKNOWN with the HTTP status as the error text
|
||||
try:
|
||||
j = json.loads(err.response)
|
||||
except ValueError:
|
||||
j = {}
|
||||
errcode = j.get('errcode', Codes.UNKNOWN)
|
||||
errmsg = j.get('error', err.msg)
|
||||
|
||||
res = SynapseError(err.code, errmsg, errcode)
|
||||
return res
|
||||
def error_dict(self):
|
||||
return cs_error(
|
||||
self.msg,
|
||||
self.errcode,
|
||||
**self._additional_fields
|
||||
)
|
||||
|
||||
|
||||
class ConsentNotGivenError(SynapseError):
|
||||
|
@ -308,14 +285,6 @@ class LimitExceededError(SynapseError):
|
|||
)
|
||||
|
||||
|
||||
def cs_exception(exception):
|
||||
if isinstance(exception, CodeMessageException):
|
||||
return exception.error_dict()
|
||||
else:
|
||||
logger.error("Unknown exception type: %s", type(exception))
|
||||
return {}
|
||||
|
||||
|
||||
def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
|
||||
""" Utility method for constructing an error response for client-server
|
||||
interactions.
|
||||
|
@ -372,7 +341,7 @@ class HttpResponseException(CodeMessageException):
|
|||
Represents an HTTP-level failure of an outbound request
|
||||
|
||||
Attributes:
|
||||
response (str): body of response
|
||||
response (bytes): body of response
|
||||
"""
|
||||
def __init__(self, code, msg, response):
|
||||
"""
|
||||
|
@ -380,7 +349,39 @@ class HttpResponseException(CodeMessageException):
|
|||
Args:
|
||||
code (int): HTTP status code
|
||||
msg (str): reason phrase from HTTP response status line
|
||||
response (str): body of response
|
||||
response (bytes): body of response
|
||||
"""
|
||||
super(HttpResponseException, self).__init__(code, msg)
|
||||
self.response = response
|
||||
|
||||
def to_synapse_error(self):
|
||||
"""Make a SynapseError based on an HTTPResponseException
|
||||
|
||||
This is useful when a proxied request has failed, and we need to
|
||||
decide how to map the failure onto a matrix error to send back to the
|
||||
client.
|
||||
|
||||
An attempt is made to parse the body of the http response as a matrix
|
||||
error. If that succeeds, the errcode and error message from the body
|
||||
are used as the errcode and error message in the new synapse error.
|
||||
|
||||
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
|
||||
set to the reason code from the HTTP response.
|
||||
|
||||
Returns:
|
||||
SynapseError:
|
||||
"""
|
||||
# try to parse the body as json, to get better errcode/msg, but
|
||||
# default to M_UNKNOWN with the HTTP status as the error text
|
||||
try:
|
||||
j = json.loads(self.response)
|
||||
except ValueError:
|
||||
j = {}
|
||||
|
||||
if not isinstance(j, dict):
|
||||
j = {}
|
||||
|
||||
errcode = j.pop('errcode', Codes.UNKNOWN)
|
||||
errmsg = j.pop('error', self.msg)
|
||||
|
||||
return ProxiedRequestError(self.code, errmsg, errcode, j)
|
||||
|
|
|
@ -20,6 +20,8 @@ import sys
|
|||
|
||||
from six import iteritems
|
||||
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from twisted.application import service
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
||||
|
@ -300,6 +302,11 @@ class SynapseHomeServer(HomeServer):
|
|||
quit_with_error(e.message)
|
||||
|
||||
|
||||
# Gauges to expose monthly active user control metrics
|
||||
current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
|
||||
max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
"""
|
||||
Args:
|
||||
|
@ -512,6 +519,18 @@ def run(hs):
|
|||
# table will decrease
|
||||
clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_monthly_active_users():
|
||||
count = 0
|
||||
if hs.config.limit_usage_by_mau:
|
||||
count = yield hs.get_datastore().count_monthly_users()
|
||||
current_mau_gauge.set(float(count))
|
||||
max_mau_value_gauge.set(float(hs.config.max_mau_value))
|
||||
|
||||
generate_monthly_active_users()
|
||||
if hs.config.limit_usage_by_mau:
|
||||
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
|
||||
|
||||
if hs.config.report_stats:
|
||||
logger.info("Scheduling stats reporting for 3 hour intervals")
|
||||
clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
|
||||
|
|
|
@ -67,6 +67,14 @@ class ServerConfig(Config):
|
|||
"block_non_admin_invites", False,
|
||||
)
|
||||
|
||||
# Options to control access by tracking MAU
|
||||
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
|
||||
if self.limit_usage_by_mau:
|
||||
self.max_mau_value = config.get(
|
||||
"max_mau_value", 0,
|
||||
)
|
||||
else:
|
||||
self.max_mau_value = 0
|
||||
# FIXME: federation_domain_whitelist needs sytests
|
||||
self.federation_domain_whitelist = None
|
||||
federation_domain_whitelist = config.get(
|
||||
|
@ -209,6 +217,8 @@ class ServerConfig(Config):
|
|||
# different cores. See
|
||||
# https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
|
||||
#
|
||||
# This setting requires the affinity package to be installed!
|
||||
#
|
||||
# cpu_affinity: 0xFFFFFFFF
|
||||
|
||||
# Whether to serve a web client from the HTTP/HTTPS root resource.
|
||||
|
|
|
@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
|
|||
PDU_RETRY_TIME_MS = 1 * 60 * 1000
|
||||
|
||||
|
||||
class InvalidResponseError(RuntimeError):
|
||||
"""Helper for _try_destination_list: indicates that the server returned a response
|
||||
we couldn't parse
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class FederationClient(FederationBase):
|
||||
def __init__(self, hs):
|
||||
super(FederationClient, self).__init__(hs)
|
||||
|
@ -458,6 +465,61 @@ class FederationClient(FederationBase):
|
|||
defer.returnValue(signed_auth)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _try_destination_list(self, description, destinations, callback):
|
||||
"""Try an operation on a series of servers, until it succeeds
|
||||
|
||||
Args:
|
||||
description (unicode): description of the operation we're doing, for logging
|
||||
|
||||
destinations (Iterable[unicode]): list of server_names to try
|
||||
|
||||
callback (callable): Function to run for each server. Passed a single
|
||||
argument: the server_name to try. May return a deferred.
|
||||
|
||||
If the callback raises a CodeMessageException with a 300/400 code,
|
||||
attempts to perform the operation stop immediately and the exception is
|
||||
reraised.
|
||||
|
||||
Otherwise, if the callback raises an Exception the error is logged and the
|
||||
next server tried. Normally the stacktrace is logged but this is
|
||||
suppressed if the exception is an InvalidResponseError.
|
||||
|
||||
Returns:
|
||||
The [Deferred] result of callback, if it succeeds
|
||||
|
||||
Raises:
|
||||
SynapseError if the chosen remote server returns a 300/400 code.
|
||||
|
||||
RuntimeError if no servers were reachable.
|
||||
"""
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
|
||||
try:
|
||||
res = yield callback(destination)
|
||||
defer.returnValue(res)
|
||||
except InvalidResponseError as e:
|
||||
logger.warn(
|
||||
"Failed to %s via %s: %s",
|
||||
description, destination, e,
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
if not 500 <= e.code < 600:
|
||||
raise e.to_synapse_error()
|
||||
else:
|
||||
logger.warn(
|
||||
"Failed to %s via %s: %i %s",
|
||||
description, destination, e.code, e.message,
|
||||
)
|
||||
except Exception:
|
||||
logger.warn(
|
||||
"Failed to %s via %s",
|
||||
description, destination, exc_info=1,
|
||||
)
|
||||
|
||||
raise RuntimeError("Failed to %s via any server", description)
|
||||
|
||||
def make_membership_event(self, destinations, room_id, user_id, membership,
|
||||
content={},):
|
||||
"""
|
||||
|
@ -481,7 +543,7 @@ class FederationClient(FederationBase):
|
|||
Deferred: resolves to a tuple of (origin (str), event (object))
|
||||
where origin is the remote homeserver which generated the event.
|
||||
|
||||
Fails with a ``CodeMessageException`` if the chosen remote server
|
||||
Fails with a ``SynapseError`` if the chosen remote server
|
||||
returns a 300/400 code.
|
||||
|
||||
Fails with a ``RuntimeError`` if no servers were reachable.
|
||||
|
@ -492,50 +554,35 @@ class FederationClient(FederationBase):
|
|||
"make_membership_event called with membership='%s', must be one of %s" %
|
||||
(membership, ",".join(valid_memberships))
|
||||
)
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
|
||||
try:
|
||||
ret = yield self.transport_layer.make_membership_event(
|
||||
destination, room_id, user_id, membership
|
||||
)
|
||||
@defer.inlineCallbacks
|
||||
def send_request(destination):
|
||||
ret = yield self.transport_layer.make_membership_event(
|
||||
destination, room_id, user_id, membership
|
||||
)
|
||||
|
||||
pdu_dict = ret["event"]
|
||||
pdu_dict = ret["event"]
|
||||
|
||||
logger.debug("Got response to make_%s: %s", membership, pdu_dict)
|
||||
logger.debug("Got response to make_%s: %s", membership, pdu_dict)
|
||||
|
||||
pdu_dict["content"].update(content)
|
||||
pdu_dict["content"].update(content)
|
||||
|
||||
# The protoevent received over the JSON wire may not have all
|
||||
# the required fields. Lets just gloss over that because
|
||||
# there's some we never care about
|
||||
if "prev_state" not in pdu_dict:
|
||||
pdu_dict["prev_state"] = []
|
||||
# The protoevent received over the JSON wire may not have all
|
||||
# the required fields. Lets just gloss over that because
|
||||
# there's some we never care about
|
||||
if "prev_state" not in pdu_dict:
|
||||
pdu_dict["prev_state"] = []
|
||||
|
||||
ev = builder.EventBuilder(pdu_dict)
|
||||
ev = builder.EventBuilder(pdu_dict)
|
||||
|
||||
defer.returnValue(
|
||||
(destination, ev)
|
||||
)
|
||||
break
|
||||
except CodeMessageException as e:
|
||||
if not 500 <= e.code < 600:
|
||||
raise
|
||||
else:
|
||||
logger.warn(
|
||||
"Failed to make_%s via %s: %s",
|
||||
membership, destination, e.message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"Failed to make_%s via %s: %s",
|
||||
membership, destination, e.message
|
||||
)
|
||||
defer.returnValue(
|
||||
(destination, ev)
|
||||
)
|
||||
|
||||
raise RuntimeError("Failed to send to any server.")
|
||||
return self._try_destination_list(
|
||||
"make_" + membership, destinations, send_request,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_join(self, destinations, pdu):
|
||||
"""Sends a join event to one of a list of homeservers.
|
||||
|
||||
|
@ -552,103 +599,91 @@ class FederationClient(FederationBase):
|
|||
giving the serer the event was sent to, ``state`` (?) and
|
||||
``auth_chain``.
|
||||
|
||||
Fails with a ``CodeMessageException`` if the chosen remote server
|
||||
Fails with a ``SynapseError`` if the chosen remote server
|
||||
returns a 300/400 code.
|
||||
|
||||
Fails with a ``RuntimeError`` if no servers were reachable.
|
||||
"""
|
||||
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
@defer.inlineCallbacks
|
||||
def send_request(destination):
|
||||
time_now = self._clock.time_msec()
|
||||
_, content = yield self.transport_layer.send_join(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
|
||||
try:
|
||||
time_now = self._clock.time_msec()
|
||||
_, content = yield self.transport_layer.send_join(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
logger.debug("Got content: %s", content)
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
state = [
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
for p in content.get("state", [])
|
||||
]
|
||||
|
||||
state = [
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
for p in content.get("state", [])
|
||||
]
|
||||
auth_chain = [
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
for p in content.get("auth_chain", [])
|
||||
]
|
||||
|
||||
auth_chain = [
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
for p in content.get("auth_chain", [])
|
||||
]
|
||||
pdus = {
|
||||
p.event_id: p
|
||||
for p in itertools.chain(state, auth_chain)
|
||||
}
|
||||
|
||||
pdus = {
|
||||
p.event_id: p
|
||||
for p in itertools.chain(state, auth_chain)
|
||||
}
|
||||
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, list(pdus.values()),
|
||||
outlier=True,
|
||||
)
|
||||
|
||||
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, list(pdus.values()),
|
||||
outlier=True,
|
||||
)
|
||||
valid_pdus_map = {
|
||||
p.event_id: p
|
||||
for p in valid_pdus
|
||||
}
|
||||
|
||||
valid_pdus_map = {
|
||||
p.event_id: p
|
||||
for p in valid_pdus
|
||||
}
|
||||
# NB: We *need* to copy to ensure that we don't have multiple
|
||||
# references being passed on, as that causes... issues.
|
||||
signed_state = [
|
||||
copy.copy(valid_pdus_map[p.event_id])
|
||||
for p in state
|
||||
if p.event_id in valid_pdus_map
|
||||
]
|
||||
|
||||
# NB: We *need* to copy to ensure that we don't have multiple
|
||||
# references being passed on, as that causes... issues.
|
||||
signed_state = [
|
||||
copy.copy(valid_pdus_map[p.event_id])
|
||||
for p in state
|
||||
if p.event_id in valid_pdus_map
|
||||
]
|
||||
signed_auth = [
|
||||
valid_pdus_map[p.event_id]
|
||||
for p in auth_chain
|
||||
if p.event_id in valid_pdus_map
|
||||
]
|
||||
|
||||
signed_auth = [
|
||||
valid_pdus_map[p.event_id]
|
||||
for p in auth_chain
|
||||
if p.event_id in valid_pdus_map
|
||||
]
|
||||
# NB: We *need* to copy to ensure that we don't have multiple
|
||||
# references being passed on, as that causes... issues.
|
||||
for s in signed_state:
|
||||
s.internal_metadata = copy.deepcopy(s.internal_metadata)
|
||||
|
||||
# NB: We *need* to copy to ensure that we don't have multiple
|
||||
# references being passed on, as that causes... issues.
|
||||
for s in signed_state:
|
||||
s.internal_metadata = copy.deepcopy(s.internal_metadata)
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
defer.returnValue({
|
||||
"state": signed_state,
|
||||
"auth_chain": signed_auth,
|
||||
"origin": destination,
|
||||
})
|
||||
except CodeMessageException as e:
|
||||
if not 500 <= e.code < 600:
|
||||
raise
|
||||
else:
|
||||
logger.exception(
|
||||
"Failed to send_join via %s: %s",
|
||||
destination, e.message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to send_join via %s: %s",
|
||||
destination, e.message
|
||||
)
|
||||
|
||||
raise RuntimeError("Failed to send to any server.")
|
||||
defer.returnValue({
|
||||
"state": signed_state,
|
||||
"auth_chain": signed_auth,
|
||||
"origin": destination,
|
||||
})
|
||||
return self._try_destination_list("send_join", destinations, send_request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_invite(self, destination, room_id, event_id, pdu):
|
||||
time_now = self._clock.time_msec()
|
||||
code, content = yield self.transport_layer.send_invite(
|
||||
destination=destination,
|
||||
room_id=room_id,
|
||||
event_id=event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
try:
|
||||
code, content = yield self.transport_layer.send_invite(
|
||||
destination=destination,
|
||||
room_id=room_id,
|
||||
event_id=event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
if e.code == 403:
|
||||
raise e.to_synapse_error()
|
||||
raise
|
||||
|
||||
pdu_dict = content["event"]
|
||||
|
||||
|
@ -663,7 +698,6 @@ class FederationClient(FederationBase):
|
|||
|
||||
defer.returnValue(pdu)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_leave(self, destinations, pdu):
|
||||
"""Sends a leave event to one of a list of homeservers.
|
||||
|
||||
|
@ -680,35 +714,25 @@ class FederationClient(FederationBase):
|
|||
Return:
|
||||
Deferred: resolves to None.
|
||||
|
||||
Fails with a ``CodeMessageException`` if the chosen remote server
|
||||
returns a non-200 code.
|
||||
Fails with a ``SynapseError`` if the chosen remote server
|
||||
returns a 300/400 code.
|
||||
|
||||
Fails with a ``RuntimeError`` if no servers were reachable.
|
||||
"""
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
@defer.inlineCallbacks
|
||||
def send_request(destination):
|
||||
time_now = self._clock.time_msec()
|
||||
_, content = yield self.transport_layer.send_leave(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
|
||||
try:
|
||||
time_now = self._clock.time_msec()
|
||||
_, content = yield self.transport_layer.send_leave(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
logger.debug("Got content: %s", content)
|
||||
defer.returnValue(None)
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
defer.returnValue(None)
|
||||
except CodeMessageException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to send_leave via %s: %s",
|
||||
destination, e.message
|
||||
)
|
||||
|
||||
raise RuntimeError("Failed to send to any server.")
|
||||
return self._try_destination_list("send_leave", destinations, send_request)
|
||||
|
||||
def get_public_rooms(self, destination, limit=None, since_token=None,
|
||||
search_filter=None, include_all_networks=False,
|
||||
|
|
|
@ -207,10 +207,6 @@ class FederationServer(FederationBase):
|
|||
edu.content
|
||||
)
|
||||
|
||||
pdu_failures = getattr(transaction, "pdu_failures", [])
|
||||
for fail in pdu_failures:
|
||||
logger.info("Got failure %r", fail)
|
||||
|
||||
response = {
|
||||
"pdus": pdu_results,
|
||||
}
|
||||
|
@ -430,6 +426,7 @@ class FederationServer(FederationBase):
|
|||
ret = yield self.handler.on_query_auth(
|
||||
origin,
|
||||
event_id,
|
||||
room_id,
|
||||
signed_auth,
|
||||
content.get("rejects", []),
|
||||
content.get("missing", []),
|
||||
|
|
|
@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
|
|||
|
||||
self.edus = SortedDict() # stream position -> Edu
|
||||
|
||||
self.failures = SortedDict() # stream position -> (destination, Failure)
|
||||
|
||||
self.device_messages = SortedDict() # stream position -> destination
|
||||
|
||||
self.pos = 1
|
||||
|
@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
|
|||
|
||||
for queue_name in [
|
||||
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
|
||||
"edus", "failures", "device_messages", "pos_time",
|
||||
"edus", "device_messages", "pos_time",
|
||||
]:
|
||||
register(queue_name, getattr(self, queue_name))
|
||||
|
||||
|
@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
|
|||
for key in keys[:i]:
|
||||
del self.edus[key]
|
||||
|
||||
# Delete things out of failure map
|
||||
keys = self.failures.keys()
|
||||
i = self.failures.bisect_left(position_to_delete)
|
||||
for key in keys[:i]:
|
||||
del self.failures[key]
|
||||
|
||||
# Delete things out of device map
|
||||
keys = self.device_messages.keys()
|
||||
i = self.device_messages.bisect_left(position_to_delete)
|
||||
|
@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
|
|||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
def send_failure(self, failure, destination):
|
||||
"""As per TransactionQueue"""
|
||||
pos = self._next_pos()
|
||||
|
||||
self.failures[pos] = (destination, str(failure))
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
def send_device_messages(self, destination):
|
||||
"""As per TransactionQueue"""
|
||||
pos = self._next_pos()
|
||||
|
@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
|
|||
for (pos, edu) in edus:
|
||||
rows.append((pos, EduRow(edu)))
|
||||
|
||||
# Fetch changed failures
|
||||
i = self.failures.bisect_right(from_token)
|
||||
j = self.failures.bisect_right(to_token) + 1
|
||||
failures = self.failures.items()[i:j]
|
||||
|
||||
for (pos, (destination, failure)) in failures:
|
||||
rows.append((pos, FailureRow(
|
||||
destination=destination,
|
||||
failure=failure,
|
||||
)))
|
||||
|
||||
# Fetch changed device messages
|
||||
i = self.device_messages.bisect_right(from_token)
|
||||
j = self.device_messages.bisect_right(to_token) + 1
|
||||
|
@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
|
|||
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
|
||||
|
||||
|
||||
class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
|
||||
"destination", # str
|
||||
"failure",
|
||||
))):
|
||||
"""Streams failures to a remote server. Failures are issued when there was
|
||||
something wrong with a transaction the remote sent us, e.g. it included
|
||||
an event that was invalid.
|
||||
"""
|
||||
|
||||
TypeId = "f"
|
||||
|
||||
@staticmethod
|
||||
def from_data(data):
|
||||
return FailureRow(
|
||||
destination=data["destination"],
|
||||
failure=data["failure"],
|
||||
)
|
||||
|
||||
def to_data(self):
|
||||
return {
|
||||
"destination": self.destination,
|
||||
"failure": self.failure,
|
||||
}
|
||||
|
||||
def add_to_buffer(self, buff):
|
||||
buff.failures.setdefault(self.destination, []).append(self.failure)
|
||||
|
||||
|
||||
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
|
||||
"destination", # str
|
||||
))):
|
||||
|
@ -471,7 +417,6 @@ TypeToRow = {
|
|||
PresenceRow,
|
||||
KeyedEduRow,
|
||||
EduRow,
|
||||
FailureRow,
|
||||
DeviceRow,
|
||||
)
|
||||
}
|
||||
|
@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
|
|||
"presence", # list(UserPresenceState)
|
||||
"keyed_edus", # dict of destination -> { key -> Edu }
|
||||
"edus", # dict of destination -> [Edu]
|
||||
"failures", # dict of destination -> [failures]
|
||||
"device_destinations", # set of destinations
|
||||
))
|
||||
|
||||
|
@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
|
|||
presence=[],
|
||||
keyed_edus={},
|
||||
edus={},
|
||||
failures={},
|
||||
device_destinations=set(),
|
||||
)
|
||||
|
||||
|
@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
|
|||
edu.destination, edu.edu_type, edu.content, key=None,
|
||||
)
|
||||
|
||||
for destination, failure_list in iteritems(buff.failures):
|
||||
for failure in failure_list:
|
||||
transaction_queue.send_failure(destination, failure)
|
||||
|
||||
for destination in buff.device_destinations:
|
||||
transaction_queue.send_device_messages(destination)
|
||||
|
|
|
@ -116,9 +116,6 @@ class TransactionQueue(object):
|
|||
),
|
||||
)
|
||||
|
||||
# destination -> list of tuple(failure, deferred)
|
||||
self.pending_failures_by_dest = {}
|
||||
|
||||
# destination -> stream_id of last successfully sent to-device message.
|
||||
# NB: may be a long or an int.
|
||||
self.last_device_stream_id_by_dest = {}
|
||||
|
@ -382,19 +379,6 @@ class TransactionQueue(object):
|
|||
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def send_failure(self, failure, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
return
|
||||
|
||||
if not self.can_send_to(destination):
|
||||
return
|
||||
|
||||
self.pending_failures_by_dest.setdefault(
|
||||
destination, []
|
||||
).append(failure)
|
||||
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def send_device_messages(self, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
return
|
||||
|
@ -469,7 +453,6 @@ class TransactionQueue(object):
|
|||
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
||||
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
||||
pending_presence = self.pending_presence_by_dest.pop(destination, {})
|
||||
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
||||
|
||||
pending_edus.extend(
|
||||
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
|
||||
|
@ -497,7 +480,7 @@ class TransactionQueue(object):
|
|||
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
||||
destination, len(pending_pdus))
|
||||
|
||||
if not pending_pdus and not pending_edus and not pending_failures:
|
||||
if not pending_pdus and not pending_edus:
|
||||
logger.debug("TX [%s] Nothing to send", destination)
|
||||
self.last_device_stream_id_by_dest[destination] = (
|
||||
device_stream_id
|
||||
|
@ -507,7 +490,7 @@ class TransactionQueue(object):
|
|||
# END CRITICAL SECTION
|
||||
|
||||
success = yield self._send_new_transaction(
|
||||
destination, pending_pdus, pending_edus, pending_failures,
|
||||
destination, pending_pdus, pending_edus,
|
||||
)
|
||||
if success:
|
||||
sent_transactions_counter.inc()
|
||||
|
@ -584,14 +567,12 @@ class TransactionQueue(object):
|
|||
|
||||
@measure_func("_send_new_transaction")
|
||||
@defer.inlineCallbacks
|
||||
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
|
||||
pending_failures):
|
||||
def _send_new_transaction(self, destination, pending_pdus, pending_edus):
|
||||
|
||||
# Sort based on the order field
|
||||
pending_pdus.sort(key=lambda t: t[1])
|
||||
pdus = [x[0] for x in pending_pdus]
|
||||
edus = pending_edus
|
||||
failures = [x.get_dict() for x in pending_failures]
|
||||
|
||||
success = True
|
||||
|
||||
|
@ -601,11 +582,10 @@ class TransactionQueue(object):
|
|||
|
||||
logger.debug(
|
||||
"TX [%s] {%s} Attempting new transaction"
|
||||
" (pdus: %d, edus: %d, failures: %d)",
|
||||
" (pdus: %d, edus: %d)",
|
||||
destination, txn_id,
|
||||
len(pdus),
|
||||
len(edus),
|
||||
len(failures)
|
||||
)
|
||||
|
||||
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||
|
@ -617,7 +597,6 @@ class TransactionQueue(object):
|
|||
destination=destination,
|
||||
pdus=pdus,
|
||||
edus=edus,
|
||||
pdu_failures=failures,
|
||||
)
|
||||
|
||||
self._next_txn_id += 1
|
||||
|
@ -627,12 +606,11 @@ class TransactionQueue(object):
|
|||
logger.debug("TX [%s] Persisted transaction", destination)
|
||||
logger.info(
|
||||
"TX [%s] {%s} Sending transaction [%s],"
|
||||
" (PDUs: %d, EDUs: %d, failures: %d)",
|
||||
" (PDUs: %d, EDUs: %d)",
|
||||
destination, txn_id,
|
||||
transaction.transaction_id,
|
||||
len(pdus),
|
||||
len(edus),
|
||||
len(failures),
|
||||
)
|
||||
|
||||
# Actually send the transaction
|
||||
|
|
|
@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
|
|||
param_dict = dict(kv.split("=") for kv in params)
|
||||
|
||||
def strip_quotes(value):
|
||||
if value.startswith(b"\""):
|
||||
if value.startswith("\""):
|
||||
return value[1:-1]
|
||||
else:
|
||||
return value
|
||||
|
@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
|
|||
)
|
||||
|
||||
logger.info(
|
||||
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
|
||||
"Received txn %s from %s. (PDUs: %d, EDUs: %d)",
|
||||
transaction_id, origin,
|
||||
len(transaction_data.get("pdus", [])),
|
||||
len(transaction_data.get("edus", [])),
|
||||
len(transaction_data.get("failures", [])),
|
||||
)
|
||||
|
||||
# We should ideally be getting this from the security layer.
|
||||
|
|
|
@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
|
|||
"previous_ids",
|
||||
"pdus",
|
||||
"edus",
|
||||
"pdu_failures",
|
||||
]
|
||||
|
||||
internal_keys = [
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import unicodedata
|
||||
|
||||
import attr
|
||||
import bcrypt
|
||||
|
@ -519,6 +520,7 @@ class AuthHandler(BaseHandler):
|
|||
"""
|
||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||
access_token = yield self.issue_access_token(user_id, device_id)
|
||||
yield self._check_mau_limits()
|
||||
|
||||
# the device *should* have been registered before we got here; however,
|
||||
# it's possible we raced against a DELETE operation. The thing we
|
||||
|
@ -626,6 +628,7 @@ class AuthHandler(BaseHandler):
|
|||
# special case to check for "password" for the check_password interface
|
||||
# for the auth providers
|
||||
password = login_submission.get("password")
|
||||
|
||||
if login_type == LoginType.PASSWORD:
|
||||
if not self._password_enabled:
|
||||
raise SynapseError(400, "Password login has been disabled.")
|
||||
|
@ -707,9 +710,10 @@ class AuthHandler(BaseHandler):
|
|||
multiple inexact matches.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
user_id (unicode): complete @user:id
|
||||
password (unicode): the provided password
|
||||
Returns:
|
||||
(str) the canonical_user_id, or None if unknown user / bad password
|
||||
(unicode) the canonical_user_id, or None if unknown user / bad password
|
||||
"""
|
||||
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
if not lookupres:
|
||||
|
@ -728,15 +732,18 @@ class AuthHandler(BaseHandler):
|
|||
device_id)
|
||||
defer.returnValue(access_token)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||
yield self._check_mau_limits()
|
||||
auth_api = self.hs.get_auth()
|
||||
user_id = None
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
||||
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
||||
auth_api.validate_macaroon(macaroon, "login", True, user_id)
|
||||
return user_id
|
||||
except Exception:
|
||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_access_token(self, access_token):
|
||||
|
@ -849,14 +856,19 @@ class AuthHandler(BaseHandler):
|
|||
"""Computes a secure hash of password.
|
||||
|
||||
Args:
|
||||
password (str): Password to hash.
|
||||
password (unicode): Password to hash.
|
||||
|
||||
Returns:
|
||||
Deferred(str): Hashed password.
|
||||
Deferred(unicode): Hashed password.
|
||||
"""
|
||||
def _do_hash():
|
||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
# Normalise the Unicode in the password
|
||||
pw = unicodedata.normalize("NFKC", password)
|
||||
|
||||
return bcrypt.hashpw(
|
||||
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
|
||||
bcrypt.gensalt(self.bcrypt_rounds),
|
||||
).decode('ascii')
|
||||
|
||||
return make_deferred_yieldable(
|
||||
threads.deferToThreadPool(
|
||||
|
@ -868,16 +880,19 @@ class AuthHandler(BaseHandler):
|
|||
"""Validates that self.hash(password) == stored_hash.
|
||||
|
||||
Args:
|
||||
password (str): Password to hash.
|
||||
stored_hash (str): Expected hash value.
|
||||
password (unicode): Password to hash.
|
||||
stored_hash (unicode): Expected hash value.
|
||||
|
||||
Returns:
|
||||
Deferred(bool): Whether self.hash(password) == stored_hash.
|
||||
"""
|
||||
|
||||
def _do_validate_hash():
|
||||
# Normalise the Unicode in the password
|
||||
pw = unicodedata.normalize("NFKC", password)
|
||||
|
||||
return bcrypt.checkpw(
|
||||
password.encode('utf8') + self.hs.config.password_pepper,
|
||||
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
|
||||
stored_hash.encode('utf8')
|
||||
)
|
||||
|
||||
|
@ -892,6 +907,19 @@ class AuthHandler(BaseHandler):
|
|||
else:
|
||||
return defer.succeed(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_mau_limits(self):
|
||||
"""
|
||||
Ensure that if mau blocking is enabled that invalid users cannot
|
||||
log in.
|
||||
"""
|
||||
if self.hs.config.limit_usage_by_mau is True:
|
||||
current_mau = yield self.store.count_monthly_users()
|
||||
if current_mau >= self.hs.config.max_mau_value:
|
||||
raise AuthError(
|
||||
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||
)
|
||||
|
||||
|
||||
@attr.s
|
||||
class MacaroonGenerator(object):
|
||||
|
|
|
@ -19,10 +19,12 @@ import random
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.types import UserID
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
@ -129,11 +131,13 @@ class EventStreamHandler(BaseHandler):
|
|||
class EventHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, user, event_id):
|
||||
def get_event(self, user, room_id, event_id):
|
||||
"""Retrieve a single specified event.
|
||||
|
||||
Args:
|
||||
user (synapse.types.UserID): The user requesting the event
|
||||
room_id (str|None): The expected room id. We'll return None if the
|
||||
event's room does not match.
|
||||
event_id (str): The event ID to obtain.
|
||||
Returns:
|
||||
dict: An event, or None if there is no event matching this ID.
|
||||
|
@ -142,13 +146,26 @@ class EventHandler(BaseHandler):
|
|||
AuthError if the user does not have the rights to inspect this
|
||||
event.
|
||||
"""
|
||||
event = yield self.store.get_event(event_id)
|
||||
event = yield self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
if not event:
|
||||
defer.returnValue(None)
|
||||
return
|
||||
|
||||
if hasattr(event, "room_id"):
|
||||
yield self.auth.check_joined_room(event.room_id, user.to_string())
|
||||
users = yield self.store.get_users_in_room(event.room_id)
|
||||
is_peeking = user.to_string() not in users
|
||||
|
||||
filtered = yield filter_events_for_client(
|
||||
self.store,
|
||||
user.to_string(),
|
||||
[event],
|
||||
is_peeking=is_peeking
|
||||
)
|
||||
|
||||
if not filtered:
|
||||
raise AuthError(
|
||||
403,
|
||||
"You don't have permission to access that event."
|
||||
)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
|
|
@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
|
|||
self.hs = hs
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.replication_layer = hs.get_federation_client()
|
||||
self.federation_client = hs.get_federation_client()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.server_name = hs.hostname
|
||||
self.keyring = hs.get_keyring()
|
||||
|
@ -255,7 +255,7 @@ class FederationHandler(BaseHandler):
|
|||
# know about
|
||||
for p in prevs - seen:
|
||||
state, got_auth_chain = (
|
||||
yield self.replication_layer.get_state_for_room(
|
||||
yield self.federation_client.get_state_for_room(
|
||||
origin, pdu.room_id, p
|
||||
)
|
||||
)
|
||||
|
@ -338,7 +338,7 @@ class FederationHandler(BaseHandler):
|
|||
#
|
||||
# see https://github.com/matrix-org/synapse/pull/1744
|
||||
|
||||
missing_events = yield self.replication_layer.get_missing_events(
|
||||
missing_events = yield self.federation_client.get_missing_events(
|
||||
origin,
|
||||
pdu.room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
|
@ -400,7 +400,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
try:
|
||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
||||
yield self._persist_auth_tree(
|
||||
origin, auth_chain, state, event
|
||||
)
|
||||
except AuthError as e:
|
||||
|
@ -444,7 +444,7 @@ class FederationHandler(BaseHandler):
|
|||
yield self._handle_new_events(origin, event_infos)
|
||||
|
||||
try:
|
||||
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||
context = yield self._handle_new_event(
|
||||
origin,
|
||||
event,
|
||||
state=state,
|
||||
|
@ -469,17 +469,6 @@ class FederationHandler(BaseHandler):
|
|||
except StoreError:
|
||||
logger.exception("Failed to store room.")
|
||||
|
||||
extra_users = []
|
||||
if event.type == EventTypes.Member:
|
||||
target_user_id = event.state_key
|
||||
target_user = UserID.from_string(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=extra_users
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.membership == Membership.JOIN:
|
||||
# Only fire user_joined_room if the user has acutally
|
||||
|
@ -501,7 +490,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
if newly_joined:
|
||||
user = UserID.from_string(event.state_key)
|
||||
yield user_joined_room(self.distributor, user, event.room_id)
|
||||
yield self.user_joined_room(user, event.room_id)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
|
@ -522,7 +511,7 @@ class FederationHandler(BaseHandler):
|
|||
if dest == self.server_name:
|
||||
raise SynapseError(400, "Can't backfill from self.")
|
||||
|
||||
events = yield self.replication_layer.backfill(
|
||||
events = yield self.federation_client.backfill(
|
||||
dest,
|
||||
room_id,
|
||||
limit=limit,
|
||||
|
@ -570,7 +559,7 @@ class FederationHandler(BaseHandler):
|
|||
state_events = {}
|
||||
events_to_state = {}
|
||||
for e_id in edges:
|
||||
state, auth = yield self.replication_layer.get_state_for_room(
|
||||
state, auth = yield self.federation_client.get_state_for_room(
|
||||
destination=dest,
|
||||
room_id=room_id,
|
||||
event_id=e_id
|
||||
|
@ -612,7 +601,7 @@ class FederationHandler(BaseHandler):
|
|||
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
logcontext.run_in_background(
|
||||
self.replication_layer.get_pdu,
|
||||
self.federation_client.get_pdu,
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
|
@ -893,7 +882,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
Invites must be signed by the invitee's server before distribution.
|
||||
"""
|
||||
pdu = yield self.replication_layer.send_invite(
|
||||
pdu = yield self.federation_client.send_invite(
|
||||
destination=target_host,
|
||||
room_id=event.room_id,
|
||||
event_id=event.event_id,
|
||||
|
@ -942,7 +931,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self.room_queues[room_id] = []
|
||||
|
||||
yield self.store.clean_room_for_join(room_id)
|
||||
yield self._clean_room_for_join(room_id)
|
||||
|
||||
handled_events = set()
|
||||
|
||||
|
@ -955,7 +944,7 @@ class FederationHandler(BaseHandler):
|
|||
target_hosts.insert(0, origin)
|
||||
except ValueError:
|
||||
pass
|
||||
ret = yield self.replication_layer.send_join(target_hosts, event)
|
||||
ret = yield self.federation_client.send_join(target_hosts, event)
|
||||
|
||||
origin = ret["origin"]
|
||||
state = ret["state"]
|
||||
|
@ -981,15 +970,10 @@ class FederationHandler(BaseHandler):
|
|||
# FIXME
|
||||
pass
|
||||
|
||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
||||
yield self._persist_auth_tree(
|
||||
origin, auth_chain, state, event
|
||||
)
|
||||
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=[joinee]
|
||||
)
|
||||
|
||||
logger.debug("Finished joining %s to %s", joinee, room_id)
|
||||
finally:
|
||||
room_queue = self.room_queues[room_id]
|
||||
|
@ -1084,7 +1068,7 @@ class FederationHandler(BaseHandler):
|
|||
# would introduce the danger of backwards-compatibility problems.
|
||||
event.internal_metadata.send_on_behalf_of = origin
|
||||
|
||||
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||
context = yield self._handle_new_event(
|
||||
origin, event
|
||||
)
|
||||
|
||||
|
@ -1094,20 +1078,10 @@ class FederationHandler(BaseHandler):
|
|||
event.signatures,
|
||||
)
|
||||
|
||||
extra_users = []
|
||||
if event.type == EventTypes.Member:
|
||||
target_user_id = event.state_key
|
||||
target_user = UserID.from_string(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.content["membership"] == Membership.JOIN:
|
||||
user = UserID.from_string(event.state_key)
|
||||
yield user_joined_room(self.distributor, user, event.room_id)
|
||||
yield self.user_joined_room(user, event.room_id)
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||
|
||||
|
@ -1176,17 +1150,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
|
||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
)
|
||||
|
||||
target_user = UserID.from_string(event.state_key)
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=[target_user],
|
||||
)
|
||||
yield self._persist_events([(event, context)])
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
|
@ -1211,30 +1175,20 @@ class FederationHandler(BaseHandler):
|
|||
except ValueError:
|
||||
pass
|
||||
|
||||
yield self.replication_layer.send_leave(
|
||||
yield self.federation_client.send_leave(
|
||||
target_hosts,
|
||||
event
|
||||
)
|
||||
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
|
||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
)
|
||||
|
||||
target_user = UserID.from_string(event.state_key)
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=[target_user],
|
||||
)
|
||||
yield self._persist_events([(event, context)])
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
|
||||
content={},):
|
||||
origin, pdu = yield self.replication_layer.make_membership_event(
|
||||
origin, pdu = yield self.federation_client.make_membership_event(
|
||||
target_hosts,
|
||||
room_id,
|
||||
user_id,
|
||||
|
@ -1318,7 +1272,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
event.internal_metadata.outlier = False
|
||||
|
||||
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||
yield self._handle_new_event(
|
||||
origin, event
|
||||
)
|
||||
|
||||
|
@ -1328,22 +1282,17 @@ class FederationHandler(BaseHandler):
|
|||
event.signatures,
|
||||
)
|
||||
|
||||
extra_users = []
|
||||
if event.type == EventTypes.Member:
|
||||
target_user_id = event.state_key
|
||||
target_user = UserID.from_string(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||
)
|
||||
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_pdu(self, room_id, event_id):
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
|
||||
event = yield self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id,
|
||||
)
|
||||
|
||||
state_groups = yield self.store.get_state_groups(
|
||||
room_id, [event_id]
|
||||
)
|
||||
|
@ -1354,8 +1303,7 @@ class FederationHandler(BaseHandler):
|
|||
(e.type, e.state_key): e for e in state
|
||||
}
|
||||
|
||||
event = yield self.store.get_event(event_id)
|
||||
if event and event.is_state():
|
||||
if event.is_state():
|
||||
# Get previous state
|
||||
if "replaces_state" in event.unsigned:
|
||||
prev_id = event.unsigned["replaces_state"]
|
||||
|
@ -1374,6 +1322,10 @@ class FederationHandler(BaseHandler):
|
|||
def get_state_ids_for_pdu(self, room_id, event_id):
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
event = yield self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id,
|
||||
)
|
||||
|
||||
state_groups = yield self.store.get_state_groups_ids(
|
||||
room_id, [event_id]
|
||||
)
|
||||
|
@ -1382,8 +1334,7 @@ class FederationHandler(BaseHandler):
|
|||
_, state = state_groups.items().pop()
|
||||
results = state
|
||||
|
||||
event = yield self.store.get_event(event_id)
|
||||
if event and event.is_state():
|
||||
if event.is_state():
|
||||
# Get previous state
|
||||
if "replaces_state" in event.unsigned:
|
||||
prev_id = event.unsigned["replaces_state"]
|
||||
|
@ -1472,9 +1423,8 @@ class FederationHandler(BaseHandler):
|
|||
event, context
|
||||
)
|
||||
|
||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
yield self._persist_events(
|
||||
[(event, context)],
|
||||
backfilled=backfilled,
|
||||
)
|
||||
except: # noqa: E722, as we reraise the exception this is fine.
|
||||
|
@ -1487,15 +1437,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
six.reraise(tp, value, tb)
|
||||
|
||||
if not backfilled:
|
||||
# this intentionally does not yield: we don't care about the result
|
||||
# and don't need to wait for it.
|
||||
logcontext.run_in_background(
|
||||
self.pusher_pool.on_new_notifications,
|
||||
event_stream_id, max_stream_id,
|
||||
)
|
||||
|
||||
defer.returnValue((context, event_stream_id, max_stream_id))
|
||||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_events(self, origin, event_infos, backfilled=False):
|
||||
|
@ -1503,6 +1445,8 @@ class FederationHandler(BaseHandler):
|
|||
should not depend on one another, e.g. this should be used to persist
|
||||
a bunch of outliers, but not a chunk of individual events that depend
|
||||
on each other for state calculations.
|
||||
|
||||
Notifies about the events where appropriate.
|
||||
"""
|
||||
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
|
@ -1517,7 +1461,7 @@ class FederationHandler(BaseHandler):
|
|||
], consumeErrors=True,
|
||||
))
|
||||
|
||||
yield self.store.persist_events(
|
||||
yield self._persist_events(
|
||||
[
|
||||
(ev_info["event"], context)
|
||||
for ev_info, context in zip(event_infos, contexts)
|
||||
|
@ -1529,7 +1473,8 @@ class FederationHandler(BaseHandler):
|
|||
def _persist_auth_tree(self, origin, auth_events, state, event):
|
||||
"""Checks the auth chain is valid (and passes auth checks) for the
|
||||
state and event. Then persists the auth chain and state atomically.
|
||||
Persists the event seperately.
|
||||
Persists the event separately. Notifies about the persisted events
|
||||
where appropriate.
|
||||
|
||||
Will attempt to fetch missing auth events.
|
||||
|
||||
|
@ -1540,8 +1485,7 @@ class FederationHandler(BaseHandler):
|
|||
event (Event)
|
||||
|
||||
Returns:
|
||||
2-tuple of (event_stream_id, max_stream_id) from the persist_event
|
||||
call for `event`
|
||||
Deferred
|
||||
"""
|
||||
events_to_context = {}
|
||||
for e in itertools.chain(auth_events, state):
|
||||
|
@ -1567,7 +1511,7 @@ class FederationHandler(BaseHandler):
|
|||
missing_auth_events.add(e_id)
|
||||
|
||||
for e_id in missing_auth_events:
|
||||
m_ev = yield self.replication_layer.get_pdu(
|
||||
m_ev = yield self.federation_client.get_pdu(
|
||||
[origin],
|
||||
e_id,
|
||||
outlier=True,
|
||||
|
@ -1605,7 +1549,7 @@ class FederationHandler(BaseHandler):
|
|||
raise
|
||||
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
yield self.store.persist_events(
|
||||
yield self._persist_events(
|
||||
[
|
||||
(e, events_to_context[e.event_id])
|
||||
for e in itertools.chain(auth_events, state)
|
||||
|
@ -1616,12 +1560,10 @@ class FederationHandler(BaseHandler):
|
|||
event, old_state=state
|
||||
)
|
||||
|
||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||
event, new_event_context,
|
||||
yield self._persist_events(
|
||||
[(event, new_event_context)],
|
||||
)
|
||||
|
||||
defer.returnValue((event_stream_id, max_stream_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _prep_event(self, origin, event, state=None, auth_events=None):
|
||||
"""
|
||||
|
@ -1678,8 +1620,19 @@ class FederationHandler(BaseHandler):
|
|||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
||||
def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects,
|
||||
missing):
|
||||
in_room = yield self.auth.check_host_in_room(
|
||||
room_id,
|
||||
origin
|
||||
)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
event = yield self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
|
||||
# Just go through and process each event in `remote_auth_chain`. We
|
||||
# don't want to fall into the trap of `missing` being wrong.
|
||||
for e in remote_auth_chain:
|
||||
|
@ -1689,7 +1642,6 @@ class FederationHandler(BaseHandler):
|
|||
pass
|
||||
|
||||
# Now get the current auth_chain for the event.
|
||||
event = yield self.store.get_event(event_id)
|
||||
local_auth_chain = yield self.store.get_auth_chain(
|
||||
[auth_id for auth_id, _ in event.auth_events],
|
||||
include_given=True
|
||||
|
@ -1777,7 +1729,7 @@ class FederationHandler(BaseHandler):
|
|||
logger.info("Missing auth: %s", missing_auth)
|
||||
# If we don't have all the auth events, we need to get them.
|
||||
try:
|
||||
remote_auth_chain = yield self.replication_layer.get_event_auth(
|
||||
remote_auth_chain = yield self.federation_client.get_event_auth(
|
||||
origin, event.room_id, event.event_id
|
||||
)
|
||||
|
||||
|
@ -1893,7 +1845,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
try:
|
||||
# 2. Get remote difference.
|
||||
result = yield self.replication_layer.query_auth(
|
||||
result = yield self.federation_client.query_auth(
|
||||
origin,
|
||||
event.room_id,
|
||||
event.event_id,
|
||||
|
@ -2192,7 +2144,7 @@ class FederationHandler(BaseHandler):
|
|||
yield member_handler.send_membership_event(None, event, context)
|
||||
else:
|
||||
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
|
||||
yield self.replication_layer.forward_third_party_invite(
|
||||
yield self.federation_client.forward_third_party_invite(
|
||||
destinations,
|
||||
room_id,
|
||||
event_dict,
|
||||
|
@ -2347,3 +2299,69 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
if "valid" not in response or not response["valid"]:
|
||||
raise AuthError(403, "Third party certificate was invalid")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _persist_events(self, event_and_contexts, backfilled=False):
|
||||
"""Persists events and tells the notifier/pushers about them, if
|
||||
necessary.
|
||||
|
||||
Args:
|
||||
event_and_contexts(list[tuple[FrozenEvent, EventContext]])
|
||||
backfilled (bool): Whether these events are a result of
|
||||
backfilling or not
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
max_stream_id = yield self.store.persist_events(
|
||||
event_and_contexts,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
||||
if not backfilled: # Never notify for backfilled events
|
||||
for event, _ in event_and_contexts:
|
||||
self._notify_persisted_event(event, max_stream_id)
|
||||
|
||||
def _notify_persisted_event(self, event, max_stream_id):
|
||||
"""Checks to see if notifier/pushers should be notified about the
|
||||
event or not.
|
||||
|
||||
Args:
|
||||
event (FrozenEvent)
|
||||
max_stream_id (int): The max_stream_id returned by persist_events
|
||||
"""
|
||||
|
||||
extra_users = []
|
||||
if event.type == EventTypes.Member:
|
||||
target_user_id = event.state_key
|
||||
|
||||
# We notify for memberships if its an invite for one of our
|
||||
# users
|
||||
if event.internal_metadata.is_outlier():
|
||||
if event.membership != Membership.INVITE:
|
||||
if not self.is_mine_id(target_user_id):
|
||||
return
|
||||
|
||||
target_user = UserID.from_string(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
elif event.internal_metadata.is_outlier():
|
||||
return
|
||||
|
||||
event_stream_id = event.internal_metadata.stream_ordering
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=extra_users
|
||||
)
|
||||
|
||||
logcontext.run_in_background(
|
||||
self.pusher_pool.on_new_notifications,
|
||||
event_stream_id, max_stream_id,
|
||||
)
|
||||
|
||||
def _clean_room_for_join(self, room_id):
|
||||
return self.store.clean_room_for_join(room_id)
|
||||
|
||||
def user_joined_room(self, user, room_id):
|
||||
"""Called when a new user has joined the room
|
||||
"""
|
||||
return user_joined_room(self.distributor, user, room_id)
|
||||
|
|
|
@ -26,7 +26,7 @@ from twisted.internet import defer
|
|||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
MatrixCodeMessageException,
|
||||
HttpResponseException,
|
||||
SynapseError,
|
||||
)
|
||||
|
||||
|
@ -85,7 +85,6 @@ class IdentityHandler(BaseHandler):
|
|||
)
|
||||
defer.returnValue(None)
|
||||
|
||||
data = {}
|
||||
try:
|
||||
data = yield self.http_client.get_json(
|
||||
"https://%s%s" % (
|
||||
|
@ -94,11 +93,9 @@ class IdentityHandler(BaseHandler):
|
|||
),
|
||||
{'sid': creds['sid'], 'client_secret': client_secret}
|
||||
)
|
||||
except MatrixCodeMessageException as e:
|
||||
except HttpResponseException as e:
|
||||
logger.info("getValidated3pid failed with Matrix error: %r", e)
|
||||
raise SynapseError(e.code, e.msg, e.errcode)
|
||||
except CodeMessageException as e:
|
||||
data = json.loads(e.msg)
|
||||
raise e.to_synapse_error()
|
||||
|
||||
if 'medium' in data:
|
||||
defer.returnValue(data)
|
||||
|
@ -136,7 +133,7 @@ class IdentityHandler(BaseHandler):
|
|||
)
|
||||
logger.debug("bound threepid %r to %s", creds, mxid)
|
||||
except CodeMessageException as e:
|
||||
data = json.loads(e.msg)
|
||||
data = json.loads(e.msg) # XXX WAT?
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -209,12 +206,9 @@ class IdentityHandler(BaseHandler):
|
|||
params
|
||||
)
|
||||
defer.returnValue(data)
|
||||
except MatrixCodeMessageException as e:
|
||||
logger.info("Proxied requestToken failed with Matrix error: %r", e)
|
||||
raise SynapseError(e.code, e.msg, e.errcode)
|
||||
except CodeMessageException as e:
|
||||
except HttpResponseException as e:
|
||||
logger.info("Proxied requestToken failed: %r", e)
|
||||
raise e
|
||||
raise e.to_synapse_error()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def requestMsisdnToken(
|
||||
|
@ -244,9 +238,6 @@ class IdentityHandler(BaseHandler):
|
|||
params
|
||||
)
|
||||
defer.returnValue(data)
|
||||
except MatrixCodeMessageException as e:
|
||||
logger.info("Proxied requestToken failed with Matrix error: %r", e)
|
||||
raise SynapseError(e.code, e.msg, e.errcode)
|
||||
except CodeMessageException as e:
|
||||
except HttpResponseException as e:
|
||||
logger.info("Proxied requestToken failed: %r", e)
|
||||
raise e
|
||||
raise e.to_synapse_error()
|
||||
|
|
|
@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler):
|
|||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
super(RegistrationHandler, self).__init__(hs)
|
||||
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler):
|
|||
Args:
|
||||
localpart : The local part of the user ID to register. If None,
|
||||
one will be generated.
|
||||
password (str) : The password to assign to this user so they can
|
||||
password (unicode) : The password to assign to this user so they can
|
||||
login again. This can be None which means they cannot login again
|
||||
via a password (e.g. the user is an application service user).
|
||||
generate_token (bool): Whether a new access token should be
|
||||
|
@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler):
|
|||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
"""
|
||||
yield self._check_mau_limits()
|
||||
password_hash = None
|
||||
if password:
|
||||
password_hash = yield self.auth_handler().hash(password)
|
||||
|
@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler):
|
|||
400,
|
||||
"User ID can only contain characters a-z, 0-9, or '=_-./'",
|
||||
)
|
||||
yield self._check_mau_limits()
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
|
@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler):
|
|||
"""
|
||||
if localpart is None:
|
||||
raise SynapseError(400, "Request must include user id")
|
||||
|
||||
yield self._check_mau_limits()
|
||||
need_register = True
|
||||
|
||||
try:
|
||||
|
@ -531,3 +533,16 @@ class RegistrationHandler(BaseHandler):
|
|||
remote_room_hosts=remote_room_hosts,
|
||||
action="join",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_mau_limits(self):
|
||||
"""
|
||||
Do not accept registrations if monthly active user limits exceeded
|
||||
and limiting is enabled
|
||||
"""
|
||||
if self.hs.config.limit_usage_by_mau is True:
|
||||
current_mau = yield self.store.count_monthly_users()
|
||||
if current_mau >= self.hs.config.max_mau_value:
|
||||
raise RegistrationError(
|
||||
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
|
||||
)
|
||||
|
|
|
@ -39,12 +39,7 @@ from twisted.web.client import (
|
|||
from twisted.web.http import PotentialDataLoss
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
MatrixCodeMessageException,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
|
||||
from synapse.http.endpoint import SpiderEndpoint
|
||||
from synapse.util.async import add_timeout_to_deferred
|
||||
|
@ -132,6 +127,11 @@ class SimpleHttpClient(object):
|
|||
|
||||
Returns:
|
||||
Deferred[object]: parsed json
|
||||
|
||||
Raises:
|
||||
HttpResponseException: On a non-2xx HTTP response.
|
||||
|
||||
ValueError: if the response was not JSON
|
||||
"""
|
||||
|
||||
# TODO: Do we ever want to log message contents?
|
||||
|
@ -155,7 +155,10 @@ class SimpleHttpClient(object):
|
|||
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
if 200 <= response.code < 300:
|
||||
defer.returnValue(json.loads(body))
|
||||
else:
|
||||
raise HttpResponseException(response.code, response.phrase, body)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_json_get_json(self, uri, post_json, headers=None):
|
||||
|
@ -169,6 +172,11 @@ class SimpleHttpClient(object):
|
|||
|
||||
Returns:
|
||||
Deferred[object]: parsed json
|
||||
|
||||
Raises:
|
||||
HttpResponseException: On a non-2xx HTTP response.
|
||||
|
||||
ValueError: if the response was not JSON
|
||||
"""
|
||||
json_str = encode_canonical_json(post_json)
|
||||
|
||||
|
@ -193,9 +201,7 @@ class SimpleHttpClient(object):
|
|||
if 200 <= response.code < 300:
|
||||
defer.returnValue(json.loads(body))
|
||||
else:
|
||||
raise self._exceptionFromFailedRequest(response, body)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
raise HttpResponseException(response.code, response.phrase, body)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, uri, args={}, headers=None):
|
||||
|
@ -213,14 +219,12 @@ class SimpleHttpClient(object):
|
|||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body as JSON.
|
||||
Raises:
|
||||
On a non-2xx HTTP response. The response body will be used as the
|
||||
error message.
|
||||
HttpResponseException On a non-2xx HTTP response.
|
||||
|
||||
ValueError: if the response was not JSON
|
||||
"""
|
||||
try:
|
||||
body = yield self.get_raw(uri, args, headers=headers)
|
||||
defer.returnValue(json.loads(body))
|
||||
except CodeMessageException as e:
|
||||
raise self._exceptionFromFailedRequest(e.code, e.msg)
|
||||
body = yield self.get_raw(uri, args, headers=headers)
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, uri, json_body, args={}, headers=None):
|
||||
|
@ -239,7 +243,9 @@ class SimpleHttpClient(object):
|
|||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body as JSON.
|
||||
Raises:
|
||||
On a non-2xx HTTP response.
|
||||
HttpResponseException On a non-2xx HTTP response.
|
||||
|
||||
ValueError: if the response was not JSON
|
||||
"""
|
||||
if len(args):
|
||||
query_bytes = urllib.urlencode(args, True)
|
||||
|
@ -266,10 +272,7 @@ class SimpleHttpClient(object):
|
|||
if 200 <= response.code < 300:
|
||||
defer.returnValue(json.loads(body))
|
||||
else:
|
||||
# NB: This is explicitly not json.loads(body)'d because the contract
|
||||
# of CodeMessageException is a *string* message. Callers can always
|
||||
# load it into JSON if they want.
|
||||
raise CodeMessageException(response.code, body)
|
||||
raise HttpResponseException(response.code, response.phrase, body)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_raw(self, uri, args={}, headers=None):
|
||||
|
@ -287,8 +290,7 @@ class SimpleHttpClient(object):
|
|||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||
HTTP body at text.
|
||||
Raises:
|
||||
On a non-2xx HTTP response. The response body will be used as the
|
||||
error message.
|
||||
HttpResponseException on a non-2xx HTTP response.
|
||||
"""
|
||||
if len(args):
|
||||
query_bytes = urllib.urlencode(args, True)
|
||||
|
@ -311,16 +313,7 @@ class SimpleHttpClient(object):
|
|||
if 200 <= response.code < 300:
|
||||
defer.returnValue(body)
|
||||
else:
|
||||
raise CodeMessageException(response.code, body)
|
||||
|
||||
def _exceptionFromFailedRequest(self, response, body):
|
||||
try:
|
||||
jsonBody = json.loads(body)
|
||||
errcode = jsonBody['errcode']
|
||||
error = jsonBody['error']
|
||||
return MatrixCodeMessageException(response.code, error, errcode)
|
||||
except (ValueError, KeyError):
|
||||
return CodeMessageException(response.code, body)
|
||||
raise HttpResponseException(response.code, response.phrase, body)
|
||||
|
||||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
||||
# The two should be factored out.
|
||||
|
|
|
@ -13,12 +13,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import cgi
|
||||
import collections
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
from six.moves import http_client
|
||||
from six import PY3
|
||||
from six.moves import http_client, urllib
|
||||
|
||||
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
||||
|
||||
|
@ -35,7 +36,6 @@ from synapse.api.errors import (
|
|||
Codes,
|
||||
SynapseError,
|
||||
UnrecognizedRequestError,
|
||||
cs_exception,
|
||||
)
|
||||
from synapse.http.request_metrics import requests_counter
|
||||
from synapse.util.caches import intern_dict
|
||||
|
@ -76,16 +76,13 @@ def wrap_json_request_handler(h):
|
|||
def wrapped_request_handler(self, request):
|
||||
try:
|
||||
yield h(self, request)
|
||||
except CodeMessageException as e:
|
||||
except SynapseError as e:
|
||||
code = e.code
|
||||
if isinstance(e, SynapseError):
|
||||
logger.info(
|
||||
"%s SynapseError: %s - %s", request, code, e.msg
|
||||
)
|
||||
else:
|
||||
logger.exception(e)
|
||||
logger.info(
|
||||
"%s SynapseError: %s - %s", request, code, e.msg
|
||||
)
|
||||
respond_with_json(
|
||||
request, code, cs_exception(e), send_cors=True,
|
||||
request, code, e.error_dict(), send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
)
|
||||
|
||||
|
@ -264,6 +261,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
self.hs = hs
|
||||
|
||||
def register_paths(self, method, path_patterns, callback):
|
||||
method = method.encode("utf-8") # method is bytes on py3
|
||||
for path_pattern in path_patterns:
|
||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||
self.path_regexs.setdefault(method, []).append(
|
||||
|
@ -296,8 +294,19 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
# here. If it throws an exception, that is handled by the wrapper
|
||||
# installed by @request_handler.
|
||||
|
||||
def _unquote(s):
|
||||
if PY3:
|
||||
# On Python 3, unquote is unicode -> unicode
|
||||
return urllib.parse.unquote(s)
|
||||
else:
|
||||
# On Python 2, unquote is bytes -> bytes We need to encode the
|
||||
# URL again (as it was decoded by _get_handler_for request), as
|
||||
# ASCII because it's a URL, and then decode it to get the UTF-8
|
||||
# characters that were quoted.
|
||||
return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
|
||||
|
||||
kwargs = intern_dict({
|
||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
||||
name: _unquote(value) if value else value
|
||||
for name, value in group_dict.items()
|
||||
})
|
||||
|
||||
|
@ -313,9 +322,9 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
request (twisted.web.http.Request):
|
||||
|
||||
Returns:
|
||||
Tuple[Callable, dict[str, str]]: callback method, and the dict
|
||||
mapping keys to path components as specified in the handler's
|
||||
path match regexp.
|
||||
Tuple[Callable, dict[unicode, unicode]]: callback method, and the
|
||||
dict mapping keys to path components as specified in the
|
||||
handler's path match regexp.
|
||||
|
||||
The callback will normally be a method registered via
|
||||
register_paths, so will return (possibly via Deferred) either
|
||||
|
@ -327,7 +336,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
# Loop through all the registered callbacks to check if the method
|
||||
# and path regex match
|
||||
for path_entry in self.path_regexs.get(request.method, []):
|
||||
m = path_entry.pattern.match(request.path)
|
||||
m = path_entry.pattern.match(request.path.decode('ascii'))
|
||||
if m:
|
||||
# We found a match!
|
||||
return path_entry.callback, m.groupdict()
|
||||
|
@ -383,7 +392,7 @@ class RootRedirect(resource.Resource):
|
|||
self.url = path
|
||||
|
||||
def render_GET(self, request):
|
||||
return redirectTo(self.url, request)
|
||||
return redirectTo(self.url.encode('ascii'), request)
|
||||
|
||||
def getChild(self, name, request):
|
||||
if len(name) == 0:
|
||||
|
@ -404,12 +413,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
|
|||
return
|
||||
|
||||
if pretty_print:
|
||||
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
||||
json_bytes = (encode_pretty_printed_json(json_object) + "\n"
|
||||
).encode("utf-8")
|
||||
else:
|
||||
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
||||
# canonicaljson already encodes to bytes
|
||||
json_bytes = encode_canonical_json(json_object)
|
||||
else:
|
||||
json_bytes = json.dumps(json_object)
|
||||
json_bytes = json.dumps(json_object).encode("utf-8")
|
||||
|
||||
return respond_with_json_bytes(
|
||||
request, code, json_bytes,
|
||||
|
|
|
@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
|
|||
if not content_bytes and allow_empty_body:
|
||||
return None
|
||||
|
||||
# Decode to Unicode so that simplejson will return Unicode strings on
|
||||
# Python 2
|
||||
try:
|
||||
content = json.loads(content_bytes)
|
||||
content_unicode = content_bytes.decode('utf8')
|
||||
except UnicodeDecodeError:
|
||||
logger.warn("Unable to decode UTF-8")
|
||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||
|
||||
try:
|
||||
content = json.loads(content_unicode)
|
||||
except Exception as e:
|
||||
logger.warn("Unable to parse JSON: %s", e)
|
||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||
|
|
|
@ -23,8 +23,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
MatrixCodeMessageException,
|
||||
SynapseError,
|
||||
HttpResponseException,
|
||||
)
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.stringutils import random_string
|
||||
|
@ -160,11 +159,11 @@ class ReplicationEndpoint(object):
|
|||
# If we timed out we probably don't need to worry about backing
|
||||
# off too much, but lets just wait a little anyway.
|
||||
yield clock.sleep(1)
|
||||
except MatrixCodeMessageException as e:
|
||||
except HttpResponseException as e:
|
||||
# We convert to SynapseError as we know that it was a SynapseError
|
||||
# on the master process that we should send to the client. (And
|
||||
# importantly, not stack traces everywhere)
|
||||
raise SynapseError(e.code, e.msg, e.errcode)
|
||||
raise e.to_synapse_error()
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ import hashlib
|
|||
import hmac
|
||||
import logging
|
||||
|
||||
from six import text_type
|
||||
from six.moves import http_client
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet):
|
|||
400, "username must be specified", errcode=Codes.BAD_JSON,
|
||||
)
|
||||
else:
|
||||
if (not isinstance(body['username'], str) or len(body['username']) > 512):
|
||||
if (
|
||||
not isinstance(body['username'], text_type)
|
||||
or len(body['username']) > 512
|
||||
):
|
||||
raise SynapseError(400, "Invalid username")
|
||||
|
||||
username = body["username"].encode("utf-8")
|
||||
|
@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet):
|
|||
400, "password must be specified", errcode=Codes.BAD_JSON,
|
||||
)
|
||||
else:
|
||||
if (not isinstance(body['password'], str) or len(body['password']) > 512):
|
||||
if (
|
||||
not isinstance(body['password'], text_type)
|
||||
or len(body['password']) > 512
|
||||
):
|
||||
raise SynapseError(400, "Invalid password")
|
||||
|
||||
password = body["password"].encode("utf-8")
|
||||
|
@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet):
|
|||
want_mac.update(b"admin" if admin else b"notadmin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
if not hmac.compare_digest(want_mac, got_mac):
|
||||
raise SynapseError(
|
||||
403, "HMAC incorrect",
|
||||
)
|
||||
if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
|
||||
raise SynapseError(403, "HMAC incorrect")
|
||||
|
||||
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||
|
||||
register = RegisterRestServlet(self.hs)
|
||||
|
||||
(user_id, _) = yield register.registration_handler.register(
|
||||
localpart=username.lower(), password=password, admin=bool(admin),
|
||||
localpart=body['username'].lower(),
|
||||
password=body["password"],
|
||||
admin=bool(admin),
|
||||
generate_token=False,
|
||||
)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.types import RoomAlias
|
||||
|
||||
|
@ -159,7 +159,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
|
|||
def on_GET(self, request, room_id):
|
||||
room = yield self.store.get_room(room_id)
|
||||
if room is None:
|
||||
raise SynapseError(400, "Unknown room")
|
||||
raise NotFoundError("Unknown room")
|
||||
|
||||
defer.returnValue((200, {
|
||||
"visibility": "public" if room["is_public"] else "private"
|
||||
|
|
|
@ -88,7 +88,7 @@ class EventRestServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
event = yield self.event_handler.get_event(requester.user, event_id)
|
||||
event = yield self.event_handler.get_event(requester.user, None, event_id)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
if event:
|
||||
|
|
|
@ -506,7 +506,7 @@ class RoomEventServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
event = yield self.event_handler.get_event(requester.user, event_id)
|
||||
event = yield self.event_handler.get_event(requester.user, room_id, event_id)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
if event:
|
||||
|
|
|
@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
kind = "user"
|
||||
if "kind" in request.args:
|
||||
kind = request.args["kind"][0]
|
||||
kind = b"user"
|
||||
if b"kind" in request.args:
|
||||
kind = request.args[b"kind"][0]
|
||||
|
||||
if kind == "guest":
|
||||
if kind == b"guest":
|
||||
ret = yield self._do_guest_registration(body)
|
||||
defer.returnValue(ret)
|
||||
return
|
||||
elif kind != "user":
|
||||
elif kind != b"user":
|
||||
raise UnrecognizedRequestError(
|
||||
"Do not understand membership kind: %s" % (kind,)
|
||||
)
|
||||
|
@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet):
|
|||
assert_params_in_dict(params, ["password"])
|
||||
|
||||
desired_username = params.get("username", None)
|
||||
new_password = params.get("password", None)
|
||||
guest_access_token = params.get("guest_access_token", None)
|
||||
new_password = params.get("password", None)
|
||||
|
||||
if desired_username is not None:
|
||||
desired_username = desired_username.lower()
|
||||
|
|
|
@ -379,7 +379,7 @@ class MediaRepository(object):
|
|||
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||
server_name, media_id, e.response)
|
||||
if e.code == twisted.web.http.NOT_FOUND:
|
||||
raise SynapseError.from_http_response_exception(e)
|
||||
raise e.to_synapse_error()
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
except SynapseError:
|
||||
|
|
|
@ -177,7 +177,7 @@ class MediaStorage(object):
|
|||
if res:
|
||||
with res:
|
||||
consumer = BackgroundFileConsumer(
|
||||
open(local_path, "w"), self.hs.get_reactor())
|
||||
open(local_path, "wb"), self.hs.get_reactor())
|
||||
yield res.write_to_consumer(consumer)
|
||||
yield consumer.wait()
|
||||
defer.returnValue(local_path)
|
||||
|
|
|
@ -577,7 +577,7 @@ def _make_state_cache_entry(
|
|||
|
||||
def _ordered_events(events):
|
||||
def key_func(e):
|
||||
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
|
||||
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
|
||||
|
||||
return sorted(events, key=key_func)
|
||||
|
||||
|
|
|
@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
PresenceStore, TransactionStore,
|
||||
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
||||
ApplicationServiceStore,
|
||||
EventsStore,
|
||||
EventFederationStore,
|
||||
MediaRepositoryStore,
|
||||
RejectionsStore,
|
||||
|
@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
PusherStore,
|
||||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
EventsStore,
|
||||
ReceiptsStore,
|
||||
EndToEndKeyStore,
|
||||
SearchStore,
|
||||
|
@ -94,6 +94,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
self._clock = hs.get_clock()
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
self.db_conn = db_conn
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
extra_tables=[("local_invites", "stream_id")]
|
||||
|
@ -266,6 +267,31 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
|
||||
return self.runInteraction("count_users", _count_users)
|
||||
|
||||
def count_monthly_users(self):
|
||||
"""Counts the number of users who used this homeserver in the last 30 days
|
||||
|
||||
This method should be refactored with count_daily_users - the only
|
||||
reason not to is waiting on definition of mau
|
||||
|
||||
Returns:
|
||||
Defered[int]
|
||||
"""
|
||||
def _count_monthly_users(txn):
|
||||
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||
sql = """
|
||||
SELECT COALESCE(count(*), 0) FROM (
|
||||
SELECT user_id FROM user_ips
|
||||
WHERE last_seen > ?
|
||||
GROUP BY user_id
|
||||
) u
|
||||
"""
|
||||
|
||||
txn.execute(sql, (thirty_days_ago,))
|
||||
count, = txn.fetchone()
|
||||
return count
|
||||
|
||||
return self.runInteraction("count_monthly_users", _count_monthly_users)
|
||||
|
||||
def count_r30_users(self):
|
||||
"""
|
||||
Counts the number of 30 day retained users, defined as:-
|
||||
|
|
|
@ -22,7 +22,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.appservice import AppServiceTransaction
|
||||
from synapse.config.appservice import load_appservices
|
||||
from synapse.storage.events import EventsWorkerStore
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from twisted.internet import defer
|
|||
from synapse.api.errors import StoreError
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.events import EventsWorkerStore
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
from synapse.storage.signatures import SignatureWorkerStore
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
|
@ -343,6 +343,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
|||
table="events",
|
||||
keyvalues={
|
||||
"event_id": event_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
retcol="depth",
|
||||
allow_none=True,
|
||||
|
|
|
@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.events import EventBase # noqa: F401
|
||||
from synapse.events.snapshot import EventContext # noqa: F401
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||
from synapse.storage.event_federation import EventFederationStore
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
from synapse.types import RoomStreamToken, get_domain_from_id
|
||||
from synapse.util.async import ObservableDeferred
|
||||
|
@ -65,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
|
|||
|
||||
|
||||
def encode_json(json_object):
|
||||
return frozendict_json_encoder.encode(json_object)
|
||||
"""
|
||||
Encode a Python object as JSON and return it in a Unicode string.
|
||||
"""
|
||||
out = frozendict_json_encoder.encode(json_object)
|
||||
if isinstance(out, bytes):
|
||||
out = out.decode('utf8')
|
||||
return out
|
||||
|
||||
|
||||
class _EventPeristenceQueue(object):
|
||||
|
@ -193,7 +201,9 @@ def _retry_on_integrity_error(func):
|
|||
return f
|
||||
|
||||
|
||||
class EventsStore(EventsWorkerStore):
|
||||
# inherits from EventFederationStore so that we can call _update_backward_extremities
|
||||
# and _handle_mult_prev_events (though arguably those could both be moved in here)
|
||||
class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
|
||||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
||||
|
||||
|
@ -231,12 +241,18 @@ class EventsStore(EventsWorkerStore):
|
|||
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def persist_events(self, events_and_contexts, backfilled=False):
|
||||
"""
|
||||
Write events to the database
|
||||
Args:
|
||||
events_and_contexts: list of tuples of (event, context)
|
||||
backfilled: ?
|
||||
backfilled (bool): Whether the results are retrieved from federation
|
||||
via backfill or not. Used to determine if they're "new" events
|
||||
which might update the current state etc.
|
||||
|
||||
Returns:
|
||||
Deferred[int]: the stream ordering of the latest persisted event
|
||||
"""
|
||||
partitioned = {}
|
||||
for event, ctx in events_and_contexts:
|
||||
|
@ -253,10 +269,14 @@ class EventsStore(EventsWorkerStore):
|
|||
for room_id in partitioned:
|
||||
self._maybe_start_persisting(room_id)
|
||||
|
||||
return make_deferred_yieldable(
|
||||
yield make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
max_persisted_id = yield self._stream_id_gen.get_current_token()
|
||||
|
||||
defer.returnValue(max_persisted_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_event(self, event, context, backfilled=False):
|
||||
|
@ -1054,7 +1074,7 @@ class EventsStore(EventsWorkerStore):
|
|||
|
||||
metadata_json = encode_json(
|
||||
event.internal_metadata.get_dict()
|
||||
).decode("UTF-8")
|
||||
)
|
||||
|
||||
sql = (
|
||||
"UPDATE event_json SET internal_metadata = ?"
|
||||
|
@ -1168,8 +1188,8 @@ class EventsStore(EventsWorkerStore):
|
|||
"room_id": event.room_id,
|
||||
"internal_metadata": encode_json(
|
||||
event.internal_metadata.get_dict()
|
||||
).decode("UTF-8"),
|
||||
"json": encode_json(event_dict(event)).decode("UTF-8"),
|
||||
),
|
||||
"json": encode_json(event_dict(event)),
|
||||
}
|
||||
for event, _ in events_and_contexts
|
||||
],
|
||||
|
|
|
@ -19,7 +19,7 @@ from canonicaljson import json
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import NotFoundError
|
||||
# these are only included to make the type annotations work
|
||||
from synapse.events import EventBase # noqa: F401
|
||||
from synapse.events import FrozenEvent
|
||||
|
@ -77,7 +77,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False,
|
||||
allow_none=False):
|
||||
allow_none=False, check_room_id=None):
|
||||
"""Get an event from the database by event_id.
|
||||
|
||||
Args:
|
||||
|
@ -88,7 +88,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
include the previous states content in the unsigned field.
|
||||
allow_rejected (bool): If True return rejected events.
|
||||
allow_none (bool): If True, return None if no event found, if
|
||||
False throw an exception.
|
||||
False throw a NotFoundError
|
||||
check_room_id (str|None): if not None, check the room of the found event.
|
||||
If there is a mismatch, behave as per allow_none.
|
||||
|
||||
Returns:
|
||||
Deferred : A FrozenEvent.
|
||||
|
@ -100,10 +102,16 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
if not events and not allow_none:
|
||||
raise SynapseError(404, "Could not find event %s" % (event_id,))
|
||||
event = events[0] if events else None
|
||||
|
||||
defer.returnValue(events[0] if events else None)
|
||||
if event is not None and check_room_id is not None:
|
||||
if event.room_id != check_room_id:
|
||||
event = None
|
||||
|
||||
if event is None and not allow_none:
|
||||
raise NotFoundError("Could not find event %s" % (event_id,))
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events(self, event_ids, check_redacted=True,
|
||||
|
|
|
@ -24,7 +24,7 @@ from canonicaljson import json
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.storage.events import EventsWorkerStore
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.caches import intern_string
|
||||
|
|
|
@ -88,5 +88,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
|
|||
"UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'",
|
||||
(sql, ),
|
||||
)
|
||||
cur.execute("PRAGMA schema_version=%i" % (oldver+1,))
|
||||
cur.execute("PRAGMA schema_version=%i" % (oldver + 1,))
|
||||
cur.execute("PRAGMA writable_schema=OFF")
|
||||
|
|
|
@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
|
|||
txn (cursor):
|
||||
event_id (str): Id for the Event.
|
||||
Returns:
|
||||
A dict of algorithm -> hash.
|
||||
A dict[unicode, bytes] of algorithm -> hash.
|
||||
"""
|
||||
query = (
|
||||
"SELECT algorithm, hash"
|
||||
|
|
|
@ -43,7 +43,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.events import EventsWorkerStore
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||
|
|
|
@ -137,7 +137,7 @@ class DomainSpecificString(
|
|||
@classmethod
|
||||
def from_string(cls, s):
|
||||
"""Parse the string given by 's' into a structure object."""
|
||||
if len(s) < 1 or s[0] != cls.SIGIL:
|
||||
if len(s) < 1 or s[0:1] != cls.SIGIL:
|
||||
raise SynapseError(400, "Expected %s string to start with '%s'" % (
|
||||
cls.__name__, cls.SIGIL,
|
||||
))
|
||||
|
|
|
@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
|||
|
||||
@functools.wraps(self.orig)
|
||||
def wrapped(*args, **kwargs):
|
||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||
# whenever we are invalidated
|
||||
# If we're passed a cache_context then we'll want to call its
|
||||
# invalidate() whenever we are invalidated
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
|
||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||
list_args = arg_dict[self.list_name]
|
||||
|
||||
# cached is a dict arg -> deferred, where deferred results in a
|
||||
# 2-tuple (`arg`, `result`)
|
||||
results = {}
|
||||
cached_defers = {}
|
||||
missing = []
|
||||
|
||||
def update_results_dict(res, arg):
|
||||
results[arg] = res
|
||||
|
||||
# list of deferreds to wait for
|
||||
cached_defers = []
|
||||
|
||||
missing = set()
|
||||
|
||||
# If the cache takes a single arg then that is used as the key,
|
||||
# otherwise a tuple is used.
|
||||
if num_args == 1:
|
||||
def cache_get(arg):
|
||||
return cache.get(arg, callback=invalidate_callback)
|
||||
def arg_to_cache_key(arg):
|
||||
return arg
|
||||
else:
|
||||
key = list(keyargs)
|
||||
keylist = list(keyargs)
|
||||
|
||||
def cache_get(arg):
|
||||
key[self.list_pos] = arg
|
||||
return cache.get(tuple(key), callback=invalidate_callback)
|
||||
def arg_to_cache_key(arg):
|
||||
keylist[self.list_pos] = arg
|
||||
return tuple(keylist)
|
||||
|
||||
for arg in list_args:
|
||||
try:
|
||||
res = cache_get(arg)
|
||||
|
||||
res = cache.get(arg_to_cache_key(arg),
|
||||
callback=invalidate_callback)
|
||||
if not isinstance(res, ObservableDeferred):
|
||||
results[arg] = res
|
||||
elif not res.has_succeeded():
|
||||
res = res.observe()
|
||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
||||
cached_defers[arg] = res
|
||||
res.addCallback(update_results_dict, arg)
|
||||
cached_defers.append(res)
|
||||
else:
|
||||
results[arg] = res.get_result()
|
||||
except KeyError:
|
||||
missing.append(arg)
|
||||
missing.add(arg)
|
||||
|
||||
if missing:
|
||||
args_to_call = dict(arg_dict)
|
||||
args_to_call[self.list_name] = missing
|
||||
# we need an observable deferred for each entry in the list,
|
||||
# which we put in the cache. Each deferred resolves with the
|
||||
# relevant result for that key.
|
||||
deferreds_map = {}
|
||||
for arg in missing:
|
||||
deferred = defer.Deferred()
|
||||
deferreds_map[arg] = deferred
|
||||
key = arg_to_cache_key(arg)
|
||||
observable = ObservableDeferred(deferred)
|
||||
cache.set(key, observable, callback=invalidate_callback)
|
||||
|
||||
ret_d = defer.maybeDeferred(
|
||||
def complete_all(res):
|
||||
# the wrapped function has completed. It returns a
|
||||
# a dict. We can now resolve the observable deferreds in
|
||||
# the cache and update our own result map.
|
||||
for e in missing:
|
||||
val = res.get(e, None)
|
||||
deferreds_map[e].callback(val)
|
||||
results[e] = val
|
||||
|
||||
def errback(f):
|
||||
# the wrapped function has failed. Invalidate any cache
|
||||
# entries we're supposed to be populating, and fail
|
||||
# their deferreds.
|
||||
for e in missing:
|
||||
key = arg_to_cache_key(e)
|
||||
cache.invalidate(key)
|
||||
deferreds_map[e].errback(f)
|
||||
|
||||
# return the failure, to propagate to our caller.
|
||||
return f
|
||||
|
||||
args_to_call = dict(arg_dict)
|
||||
args_to_call[self.list_name] = list(missing)
|
||||
|
||||
cached_defers.append(defer.maybeDeferred(
|
||||
logcontext.preserve_fn(self.function_to_call),
|
||||
**args_to_call
|
||||
)
|
||||
|
||||
ret_d = ObservableDeferred(ret_d)
|
||||
|
||||
# We need to create deferreds for each arg in the list so that
|
||||
# we can insert the new deferred into the cache.
|
||||
for arg in missing:
|
||||
observer = ret_d.observe()
|
||||
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
||||
|
||||
observer = ObservableDeferred(observer)
|
||||
|
||||
if num_args == 1:
|
||||
cache.set(
|
||||
arg, observer,
|
||||
callback=invalidate_callback
|
||||
)
|
||||
|
||||
def invalidate(f, key):
|
||||
cache.invalidate(key)
|
||||
return f
|
||||
observer.addErrback(invalidate, arg)
|
||||
else:
|
||||
key = list(keyargs)
|
||||
key[self.list_pos] = arg
|
||||
cache.set(
|
||||
tuple(key), observer,
|
||||
callback=invalidate_callback
|
||||
)
|
||||
|
||||
def invalidate(f, key):
|
||||
cache.invalidate(key)
|
||||
return f
|
||||
observer.addErrback(invalidate, tuple(key))
|
||||
|
||||
res = observer.observe()
|
||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
||||
|
||||
cached_defers[arg] = res
|
||||
).addCallbacks(complete_all, errback))
|
||||
|
||||
if cached_defers:
|
||||
def update_results_dict(res):
|
||||
results.update(res)
|
||||
return results
|
||||
|
||||
return logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
list(cached_defers.values()),
|
||||
d = defer.gatherResults(
|
||||
cached_defers,
|
||||
consumeErrors=True,
|
||||
).addCallback(update_results_dict).addErrback(
|
||||
).addCallbacks(
|
||||
lambda _: results,
|
||||
unwrapFirstError
|
||||
))
|
||||
)
|
||||
return logcontext.make_deferred_yieldable(d)
|
||||
else:
|
||||
return results
|
||||
|
||||
|
@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
|
|||
cache.
|
||||
|
||||
Args:
|
||||
cache (Cache): The underlying cache to use.
|
||||
cached_method_name (str): The name of the single-item lookup method.
|
||||
This is only used to find the cache to use.
|
||||
list_name (str): The name of the argument that is the list to use to
|
||||
do batch lookups in the cache.
|
||||
num_args (int): Number of arguments to use as the key in the cache
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from six import string_types
|
||||
from six import binary_type, text_type
|
||||
|
||||
from canonicaljson import json
|
||||
from frozendict import frozendict
|
||||
|
@ -26,7 +26,7 @@ def freeze(o):
|
|||
if isinstance(o, frozendict):
|
||||
return o
|
||||
|
||||
if isinstance(o, string_types):
|
||||
if isinstance(o, (binary_type, text_type)):
|
||||
return o
|
||||
|
||||
try:
|
||||
|
@ -41,7 +41,7 @@ def unfreeze(o):
|
|||
if isinstance(o, (dict, frozendict)):
|
||||
return dict({k: unfreeze(v) for k, v in o.items()})
|
||||
|
||||
if isinstance(o, string_types):
|
||||
if isinstance(o, (binary_type, text_type)):
|
||||
return o
|
||||
|
||||
try:
|
||||
|
|
|
@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.auth = Auth(self.hs)
|
||||
|
||||
self.test_user = "@foo:bar"
|
||||
self.test_token = "_test_token_"
|
||||
self.test_token = b"_test_token_"
|
||||
|
||||
# this is overridden for the appservice tests
|
||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||
|
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||
|
||||
request = Mock(args={})
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||
|
@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
d = self.auth.get_user_by_req(request)
|
||||
self.failureResultOf(d, AuthError)
|
||||
|
@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||
|
@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "192.168.10.10"
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||
|
@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "131.111.8.42"
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
d = self.auth.get_user_by_req(request)
|
||||
self.failureResultOf(d, AuthError)
|
||||
|
@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
d = self.auth.get_user_by_req(request)
|
||||
self.failureResultOf(d, AuthError)
|
||||
|
@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
||||
masquerading_user_id = "@doppelganger:matrix.org"
|
||||
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||
app_service = Mock(
|
||||
token="foobar", url="a_url", sender=self.test_user,
|
||||
ip_range_whitelist=None,
|
||||
|
@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.args["user_id"] = [masquerading_user_id]
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(requester.user.to_string(), masquerading_user_id)
|
||||
self.assertEquals(
|
||||
requester.user.to_string(),
|
||||
masquerading_user_id.decode('utf8')
|
||||
)
|
||||
|
||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||
masquerading_user_id = "@doppelganger:matrix.org"
|
||||
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||
app_service = Mock(
|
||||
token="foobar", url="a_url", sender=self.test_user,
|
||||
ip_range_whitelist=None,
|
||||
|
@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
request = Mock(args={})
|
||||
request.getClientIP.return_value = "127.0.0.1"
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.args["user_id"] = [masquerading_user_id]
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
d = self.auth.get_user_by_req(request)
|
||||
self.failureResultOf(d, AuthError)
|
||||
|
@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
# check the token works
|
||||
request = Mock(args={})
|
||||
request.args["access_token"] = [token]
|
||||
request.args[b"access_token"] = [token.encode('ascii')]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
||||
|
@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
# the token should *not* work now
|
||||
request = Mock(args={})
|
||||
request.args["access_token"] = [guest_tok]
|
||||
request.args[b"access_token"] = [guest_tok.encode('ascii')]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
with self.assertRaises(AuthError) as cm:
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from mock import Mock
|
||||
|
||||
import pymacaroons
|
||||
|
||||
|
@ -19,6 +20,7 @@ from twisted.internet import defer
|
|||
|
||||
import synapse
|
||||
import synapse.api.errors
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.handlers.auth import AuthHandler
|
||||
|
||||
from tests import unittest
|
||||
|
@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.hs.handlers = AuthHandlers(self.hs)
|
||||
self.auth_handler = self.hs.handlers.auth_handler
|
||||
self.macaroon_generator = self.hs.get_macaroon_generator()
|
||||
# MAU tests
|
||||
self.hs.config.max_mau_value = 50
|
||||
self.small_number_of_users = 1
|
||||
self.large_number_of_users = 100
|
||||
|
||||
def test_token_is_a_macaroon(self):
|
||||
token = self.macaroon_generator.generate_access_token("some_user")
|
||||
|
@ -71,38 +77,37 @@ class AuthTestCase(unittest.TestCase):
|
|||
v.satisfy_general(verify_nonce)
|
||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_short_term_login_token_gives_user_id(self):
|
||||
self.hs.clock.now = 1000
|
||||
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"a_user", 5000
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
"a_user",
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
token
|
||||
)
|
||||
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
token
|
||||
)
|
||||
self.assertEqual("a_user", user_id)
|
||||
|
||||
# when we advance the clock, the token should be rejected
|
||||
self.hs.clock.now = 6000
|
||||
with self.assertRaises(synapse.api.errors.AuthError):
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
token
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"a_user", 5000
|
||||
)
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
|
||||
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
macaroon.serialize()
|
||||
)
|
||||
self.assertEqual(
|
||||
"a_user",
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
macaroon.serialize()
|
||||
)
|
||||
"a_user", user_id
|
||||
)
|
||||
|
||||
# add another "user_id" caveat, which might allow us to override the
|
||||
|
@ -110,6 +115,57 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("user_id = b_user")
|
||||
|
||||
with self.assertRaises(synapse.api.errors.AuthError):
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
macaroon.serialize()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_mau_limits_disabled(self):
|
||||
self.hs.config.limit_usage_by_mau = False
|
||||
# Ensure does not throw exception
|
||||
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||
|
||||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
self._get_macaroon().serialize()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_mau_limits_exceeded(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.hs.get_datastore().count_monthly_users = Mock(
|
||||
return_value=defer.succeed(self.large_number_of_users)
|
||||
)
|
||||
|
||||
with self.assertRaises(AuthError):
|
||||
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||
|
||||
self.hs.get_datastore().count_monthly_users = Mock(
|
||||
return_value=defer.succeed(self.large_number_of_users)
|
||||
)
|
||||
with self.assertRaises(AuthError):
|
||||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
self._get_macaroon().serialize()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_mau_limits_not_exceeded(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
|
||||
self.hs.get_datastore().count_monthly_users = Mock(
|
||||
return_value=defer.succeed(self.small_number_of_users)
|
||||
)
|
||||
# Ensure does not raise exception
|
||||
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||
|
||||
self.hs.get_datastore().count_monthly_users = Mock(
|
||||
return_value=defer.succeed(self.small_number_of_users)
|
||||
)
|
||||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
self._get_macaroon().serialize()
|
||||
)
|
||||
|
||||
def _get_macaroon(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"user_a", 5000
|
||||
)
|
||||
return pymacaroons.Macaroon.deserialize(token)
|
||||
|
|
|
@ -17,6 +17,7 @@ from mock import Mock
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import RegistrationError
|
||||
from synapse.handlers.register import RegistrationHandler
|
||||
from synapse.types import UserID, create_requester
|
||||
|
||||
|
@ -77,3 +78,53 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
requester, local_part, display_name)
|
||||
self.assertEquals(result_user_id, user_id)
|
||||
self.assertEquals(result_token, 'secret')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_cannot_register_when_mau_limits_exceeded(self):
|
||||
local_part = "someone"
|
||||
display_name = "someone"
|
||||
requester = create_requester("@as:test")
|
||||
store = self.hs.get_datastore()
|
||||
self.hs.config.limit_usage_by_mau = False
|
||||
self.hs.config.max_mau_value = 50
|
||||
lots_of_users = 100
|
||||
small_number_users = 1
|
||||
|
||||
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||
|
||||
# Ensure does not throw exception
|
||||
yield self.handler.get_or_create_user(requester, 'a', display_name)
|
||||
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
|
||||
with self.assertRaises(RegistrationError):
|
||||
yield self.handler.get_or_create_user(requester, 'b', display_name)
|
||||
|
||||
store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
|
||||
|
||||
self._macaroon_mock_generator("another_secret")
|
||||
|
||||
# Ensure does not throw exception
|
||||
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
|
||||
|
||||
self._macaroon_mock_generator("another another secret")
|
||||
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||
|
||||
with self.assertRaises(RegistrationError):
|
||||
yield self.handler.register(localpart=local_part)
|
||||
|
||||
self._macaroon_mock_generator("another another secret")
|
||||
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||
|
||||
with self.assertRaises(RegistrationError):
|
||||
yield self.handler.register_saml2(local_part)
|
||||
|
||||
def _macaroon_mock_generator(self, secret):
|
||||
"""
|
||||
Reset macaroon generator in the case where the test creates multiple users
|
||||
"""
|
||||
macaroon_generator = Mock(
|
||||
generate_access_token=Mock(return_value=secret))
|
||||
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
|
||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||
self.handler = self.hs.get_handlers().registration_handler
|
||||
|
|
|
@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
|
|||
"content": content,
|
||||
}
|
||||
],
|
||||
"pdu_failures": [],
|
||||
}
|
||||
|
||||
|
||||
|
|
65
tests/storage/test__init__.py
Normal file
65
tests/storage/test__init__.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import tests.utils
|
||||
|
||||
|
||||
class InitTestCase(tests.unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(InitTestCase, self).__init__(*args, **kwargs)
|
||||
self.store = None # type: synapse.storage.DataStore
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
hs = yield tests.utils.setup_test_homeserver()
|
||||
|
||||
hs.config.max_mau_value = 50
|
||||
hs.config.limit_usage_by_mau = True
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_count_monthly_users(self):
|
||||
count = yield self.store.count_monthly_users()
|
||||
self.assertEqual(0, count)
|
||||
|
||||
yield self._insert_user_ips("@user:server1")
|
||||
yield self._insert_user_ips("@user:server2")
|
||||
|
||||
count = yield self.store.count_monthly_users()
|
||||
self.assertEqual(2, count)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _insert_user_ips(self, user):
|
||||
"""
|
||||
Helper function to populate user_ips without using batch insertion infra
|
||||
args:
|
||||
user (str): specify username i.e. @user:server.com
|
||||
"""
|
||||
yield self.store._simple_upsert(
|
||||
table="user_ips",
|
||||
keyvalues={
|
||||
"user_id": user,
|
||||
"access_token": "access_token",
|
||||
"ip": "ip",
|
||||
"user_agent": "user_agent",
|
||||
"device_id": "device_id",
|
||||
},
|
||||
values={
|
||||
"last_seen": self.clock.time_msec(),
|
||||
}
|
||||
)
|
|
@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
|
|||
r = yield obj.fn(2, 3)
|
||||
self.assertEqual(r, 'chips')
|
||||
obj.mock.assert_not_called()
|
||||
|
||||
|
||||
class CachedListDescriptorTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_cache(self):
|
||||
class Cls(object):
|
||||
def __init__(self):
|
||||
self.mock = mock.Mock()
|
||||
|
||||
@descriptors.cached()
|
||||
def fn(self, arg1, arg2):
|
||||
pass
|
||||
|
||||
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
|
||||
def list_fn(self, args1, arg2):
|
||||
assert (
|
||||
logcontext.LoggingContext.current_context().request == "c1"
|
||||
)
|
||||
# we want this to behave like an asynchronous function
|
||||
yield run_on_reactor()
|
||||
assert (
|
||||
logcontext.LoggingContext.current_context().request == "c1"
|
||||
)
|
||||
defer.returnValue(self.mock(args1, arg2))
|
||||
|
||||
with logcontext.LoggingContext() as c1:
|
||||
c1.request = "c1"
|
||||
obj = Cls()
|
||||
obj.mock.return_value = {10: 'fish', 20: 'chips'}
|
||||
d1 = obj.list_fn([10, 20], 2)
|
||||
self.assertEqual(
|
||||
logcontext.LoggingContext.current_context(),
|
||||
logcontext.LoggingContext.sentinel,
|
||||
)
|
||||
r = yield d1
|
||||
self.assertEqual(
|
||||
logcontext.LoggingContext.current_context(),
|
||||
c1
|
||||
)
|
||||
obj.mock.assert_called_once_with([10, 20], 2)
|
||||
self.assertEqual(r, {10: 'fish', 20: 'chips'})
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# a call with different params should call the mock again
|
||||
obj.mock.return_value = {30: 'peas'}
|
||||
r = yield obj.list_fn([20, 30], 2)
|
||||
obj.mock.assert_called_once_with([30], 2)
|
||||
self.assertEqual(r, {20: 'chips', 30: 'peas'})
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# all the values should now be cached
|
||||
r = yield obj.fn(10, 2)
|
||||
self.assertEqual(r, 'fish')
|
||||
r = yield obj.fn(20, 2)
|
||||
self.assertEqual(r, 'chips')
|
||||
r = yield obj.fn(30, 2)
|
||||
self.assertEqual(r, 'peas')
|
||||
r = yield obj.list_fn([10, 20, 30], 2)
|
||||
obj.mock.assert_not_called()
|
||||
self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invalidate(self):
|
||||
"""Make sure that invalidation callbacks are called."""
|
||||
class Cls(object):
|
||||
def __init__(self):
|
||||
self.mock = mock.Mock()
|
||||
|
||||
@descriptors.cached()
|
||||
def fn(self, arg1, arg2):
|
||||
pass
|
||||
|
||||
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
|
||||
def list_fn(self, args1, arg2):
|
||||
# we want this to behave like an asynchronous function
|
||||
yield run_on_reactor()
|
||||
defer.returnValue(self.mock(args1, arg2))
|
||||
|
||||
obj = Cls()
|
||||
invalidate0 = mock.Mock()
|
||||
invalidate1 = mock.Mock()
|
||||
|
||||
# cache miss
|
||||
obj.mock.return_value = {10: 'fish', 20: 'chips'}
|
||||
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
|
||||
obj.mock.assert_called_once_with([10, 20], 2)
|
||||
self.assertEqual(r1, {10: 'fish', 20: 'chips'})
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# cache hit
|
||||
r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
|
||||
obj.mock.assert_not_called()
|
||||
self.assertEqual(r2, {10: 'fish', 20: 'chips'})
|
||||
|
||||
invalidate0.assert_not_called()
|
||||
invalidate1.assert_not_called()
|
||||
|
||||
# now if we invalidate the keys, both invalidations should get called
|
||||
obj.fn.invalidate((10, 2))
|
||||
invalidate0.assert_called_once()
|
||||
invalidate1.assert_called_once()
|
||||
|
|
|
@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
|
|||
self.prefix = prefix
|
||||
|
||||
def trigger_get(self, path):
|
||||
return self.trigger("GET", path, None)
|
||||
return self.trigger(b"GET", path, None)
|
||||
|
||||
@patch('twisted.web.http.Request')
|
||||
@defer.inlineCallbacks
|
||||
|
@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
|
|||
|
||||
headers = {}
|
||||
if federation_auth:
|
||||
headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="]
|
||||
headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
|
||||
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
|
||||
|
||||
# return the right path if the event requires it
|
||||
|
@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(path, bytes):
|
||||
path = path.decode('utf8')
|
||||
|
||||
for (method, pattern, func) in self.callbacks:
|
||||
if http_method != method:
|
||||
continue
|
||||
|
@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
|
|||
if matcher:
|
||||
try:
|
||||
args = [
|
||||
urlparse.unquote(u).decode("UTF-8")
|
||||
urlparse.unquote(u)
|
||||
for u in matcher.groups()
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in a new issue