forked from MirrorHub/synapse
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/refactor_repl_servlet
This commit is contained in:
commit
cb298ff623
67 changed files with 1074 additions and 648 deletions
5
.github/ISSUE_TEMPLATE.md
vendored
5
.github/ISSUE_TEMPLATE.md
vendored
|
@ -27,8 +27,9 @@ Describe here the problem that you are experiencing, or the feature you are requ
|
||||||
|
|
||||||
Describe how what happens differs from what you expected.
|
Describe how what happens differs from what you expected.
|
||||||
|
|
||||||
If you can identify any relevant log snippets from _homeserver.log_, please include
|
<!-- If you can identify any relevant log snippets from _homeserver.log_, please include
|
||||||
those here (please be careful to remove any personal or private data):
|
those (please be careful to remove any personal or private data). Please surround them with
|
||||||
|
``` (three backticks, on a line on their own), so that they are formatted legibly. -->
|
||||||
|
|
||||||
### Version information
|
### Version information
|
||||||
|
|
||||||
|
|
|
@ -63,3 +63,6 @@ Christoph Witzany <christoph at web.crofting.com>
|
||||||
|
|
||||||
Pierre Jaury <pierre at jaury.eu>
|
Pierre Jaury <pierre at jaury.eu>
|
||||||
* Docker packaging
|
* Docker packaging
|
||||||
|
|
||||||
|
Serban Constantin <serban.constantin at gmail dot com>
|
||||||
|
* Small bug fix
|
|
@ -1,3 +1,12 @@
|
||||||
|
Synapse 0.33.1 (2018-08-02)
|
||||||
|
===========================
|
||||||
|
|
||||||
|
SECURITY FIXES
|
||||||
|
--------------
|
||||||
|
|
||||||
|
- Fix a potential issue where servers could request events for rooms they have not joined. ([\#3641](https://github.com/matrix-org/synapse/issues/3641))
|
||||||
|
- Fix a potential issue where users could see events in private rooms before they joined. ([\#3642](https://github.com/matrix-org/synapse/issues/3642))
|
||||||
|
|
||||||
Synapse 0.33.0 (2018-07-19)
|
Synapse 0.33.0 (2018-07-19)
|
||||||
===========================
|
===========================
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ makes it horribly hard to review otherwise.
|
||||||
Changelog
|
Changelog
|
||||||
~~~~~~~~~
|
~~~~~~~~~
|
||||||
|
|
||||||
All changes, even minor ones, need a corresponding changelog
|
All changes, even minor ones, need a corresponding changelog / newsfragment
|
||||||
entry. These are managed by Towncrier
|
entry. These are managed by Towncrier
|
||||||
(https://github.com/hawkowl/towncrier).
|
(https://github.com/hawkowl/towncrier).
|
||||||
|
|
||||||
|
|
22
Dockerfile
22
Dockerfile
|
@ -1,16 +1,32 @@
|
||||||
FROM docker.io/python:2-alpine3.7
|
FROM docker.io/python:2-alpine3.7
|
||||||
|
|
||||||
RUN apk add --no-cache --virtual .nacl_deps su-exec build-base libffi-dev zlib-dev libressl-dev libjpeg-turbo-dev linux-headers postgresql-dev libxslt-dev
|
RUN apk add --no-cache --virtual .nacl_deps \
|
||||||
|
build-base \
|
||||||
|
libffi-dev \
|
||||||
|
libjpeg-turbo-dev \
|
||||||
|
libressl-dev \
|
||||||
|
libxslt-dev \
|
||||||
|
linux-headers \
|
||||||
|
postgresql-dev \
|
||||||
|
su-exec \
|
||||||
|
zlib-dev
|
||||||
|
|
||||||
COPY . /synapse
|
COPY . /synapse
|
||||||
|
|
||||||
# A wheel cache may be provided in ./cache for faster build
|
# A wheel cache may be provided in ./cache for faster build
|
||||||
RUN cd /synapse \
|
RUN cd /synapse \
|
||||||
&& pip install --upgrade pip setuptools psycopg2 lxml \
|
&& pip install --upgrade \
|
||||||
|
lxml \
|
||||||
|
pip \
|
||||||
|
psycopg2 \
|
||||||
|
setuptools \
|
||||||
&& mkdir -p /synapse/cache \
|
&& mkdir -p /synapse/cache \
|
||||||
&& pip install -f /synapse/cache --upgrade --process-dependency-links . \
|
&& pip install -f /synapse/cache --upgrade --process-dependency-links . \
|
||||||
&& mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
|
&& mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
|
||||||
&& rm -rf setup.py setup.cfg synapse
|
&& rm -rf \
|
||||||
|
setup.cfg \
|
||||||
|
setup.py \
|
||||||
|
synapse
|
||||||
|
|
||||||
VOLUME ["/data"]
|
VOLUME ["/data"]
|
||||||
|
|
||||||
|
|
1
changelog.d/2952.bugfix
Normal file
1
changelog.d/2952.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Make /directory/list API return 404 for room not found instead of 400
|
1
changelog.d/3384.misc
Normal file
1
changelog.d/3384.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Rewrite cache list decorator
|
1
changelog.d/3543.misc
Normal file
1
changelog.d/3543.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve Dockerfile and docker-compose instructions
|
1
changelog.d/3569.bugfix
Normal file
1
changelog.d/3569.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password.
|
1
changelog.d/3612.misc
Normal file
1
changelog.d/3612.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Make EventStore inherit from EventFederationStore
|
1
changelog.d/3621.misc
Normal file
1
changelog.d/3621.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Refactor FederationHandler to move DB writes into separate functions
|
1
changelog.d/3628.misc
Normal file
1
changelog.d/3628.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Remove unused field "pdu_failures" from transactions.
|
1
changelog.d/3630.feature
Normal file
1
changelog.d/3630.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add ability to limit number of monthly active users on the server
|
1
changelog.d/3634.misc
Normal file
1
changelog.d/3634.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
rename replication_layer to federation_client
|
1
changelog.d/3638.misc
Normal file
1
changelog.d/3638.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Factor out exception handling in federation_client
|
1
changelog.d/3639.feature
Normal file
1
changelog.d/3639.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
When we fail to join a room over federation, pass the error code back to the client.
|
1
changelog.d/3645.misc
Normal file
1
changelog.d/3645.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update CONTRIBUTING to mention newsfragments.
|
|
@ -9,13 +9,7 @@ use that server.
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
Build the docker image with the `docker build` command from the root of the synapse repository.
|
Build the docker image with the `docker-compose build` command.
|
||||||
|
|
||||||
```
|
|
||||||
docker build -t docker.io/matrixdotorg/synapse .
|
|
||||||
```
|
|
||||||
|
|
||||||
The `-t` option sets the image tag. Official images are tagged `matrixdotorg/synapse:<version>` where `<version>` is the same as the release tag in the synapse git repository.
|
|
||||||
|
|
||||||
You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.
|
You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ version: '3'
|
||||||
services:
|
services:
|
||||||
|
|
||||||
synapse:
|
synapse:
|
||||||
|
build: ../..
|
||||||
image: docker.io/matrixdotorg/synapse:latest
|
image: docker.io/matrixdotorg/synapse:latest
|
||||||
# Since snyapse does not retry to connect to the database, restart upon
|
# Since snyapse does not retry to connect to the database, restart upon
|
||||||
# failure
|
# failure
|
||||||
|
|
6
contrib/grafana/README.md
Normal file
6
contrib/grafana/README.md
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
# Using the Synapse Grafana dashboard
|
||||||
|
|
||||||
|
0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
|
||||||
|
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst
|
||||||
|
2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
|
||||||
|
3. Set up additional recording rules
|
|
@ -17,4 +17,4 @@
|
||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.33.0"
|
__version__ = "0.33.1"
|
||||||
|
|
|
@ -252,10 +252,10 @@ class Auth(object):
|
||||||
if ip_address not in app_service.ip_range_whitelist:
|
if ip_address not in app_service.ip_range_whitelist:
|
||||||
defer.returnValue((None, None))
|
defer.returnValue((None, None))
|
||||||
|
|
||||||
if "user_id" not in request.args:
|
if b"user_id" not in request.args:
|
||||||
defer.returnValue((app_service.sender, app_service))
|
defer.returnValue((app_service.sender, app_service))
|
||||||
|
|
||||||
user_id = request.args["user_id"][0]
|
user_id = request.args[b"user_id"][0].decode('utf8')
|
||||||
if app_service.sender == user_id:
|
if app_service.sender == user_id:
|
||||||
defer.returnValue((app_service.sender, app_service))
|
defer.returnValue((app_service.sender, app_service))
|
||||||
|
|
||||||
|
|
|
@ -55,6 +55,7 @@ class Codes(object):
|
||||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||||
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
|
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
|
||||||
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
|
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
|
||||||
|
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(RuntimeError):
|
class CodeMessageException(RuntimeError):
|
||||||
|
@ -69,20 +70,6 @@ class CodeMessageException(RuntimeError):
|
||||||
self.code = code
|
self.code = code
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
|
||||||
def error_dict(self):
|
|
||||||
return cs_error(self.msg)
|
|
||||||
|
|
||||||
|
|
||||||
class MatrixCodeMessageException(CodeMessageException):
|
|
||||||
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
|
|
||||||
"""
|
|
||||||
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
|
|
||||||
super(MatrixCodeMessageException, self).__init__(code, msg)
|
|
||||||
self.errcode = errcode
|
|
||||||
|
|
||||||
|
|
||||||
class SynapseError(CodeMessageException):
|
class SynapseError(CodeMessageException):
|
||||||
"""A base exception type for matrix errors which have an errcode and error
|
"""A base exception type for matrix errors which have an errcode and error
|
||||||
|
@ -108,38 +95,28 @@ class SynapseError(CodeMessageException):
|
||||||
self.errcode,
|
self.errcode,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_http_response_exception(cls, err):
|
|
||||||
"""Make a SynapseError based on an HTTPResponseException
|
|
||||||
|
|
||||||
This is useful when a proxied request has failed, and we need to
|
class ProxiedRequestError(SynapseError):
|
||||||
decide how to map the failure onto a matrix error to send back to the
|
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
|
||||||
client.
|
|
||||||
|
|
||||||
An attempt is made to parse the body of the http response as a matrix
|
Attributes:
|
||||||
error. If that succeeds, the errcode and error message from the body
|
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
|
||||||
are used as the errcode and error message in the new synapse error.
|
|
||||||
|
|
||||||
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
|
|
||||||
set to the reason code from the HTTP response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
err (HttpResponseException):
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SynapseError:
|
|
||||||
"""
|
"""
|
||||||
# try to parse the body as json, to get better errcode/msg, but
|
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
|
||||||
# default to M_UNKNOWN with the HTTP status as the error text
|
super(ProxiedRequestError, self).__init__(
|
||||||
try:
|
code, msg, errcode
|
||||||
j = json.loads(err.response)
|
)
|
||||||
except ValueError:
|
if additional_fields is None:
|
||||||
j = {}
|
self._additional_fields = {}
|
||||||
errcode = j.get('errcode', Codes.UNKNOWN)
|
else:
|
||||||
errmsg = j.get('error', err.msg)
|
self._additional_fields = dict(additional_fields)
|
||||||
|
|
||||||
res = SynapseError(err.code, errmsg, errcode)
|
def error_dict(self):
|
||||||
return res
|
return cs_error(
|
||||||
|
self.msg,
|
||||||
|
self.errcode,
|
||||||
|
**self._additional_fields
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConsentNotGivenError(SynapseError):
|
class ConsentNotGivenError(SynapseError):
|
||||||
|
@ -308,14 +285,6 @@ class LimitExceededError(SynapseError):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cs_exception(exception):
|
|
||||||
if isinstance(exception, CodeMessageException):
|
|
||||||
return exception.error_dict()
|
|
||||||
else:
|
|
||||||
logger.error("Unknown exception type: %s", type(exception))
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
|
def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
|
||||||
""" Utility method for constructing an error response for client-server
|
""" Utility method for constructing an error response for client-server
|
||||||
interactions.
|
interactions.
|
||||||
|
@ -372,7 +341,7 @@ class HttpResponseException(CodeMessageException):
|
||||||
Represents an HTTP-level failure of an outbound request
|
Represents an HTTP-level failure of an outbound request
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
response (str): body of response
|
response (bytes): body of response
|
||||||
"""
|
"""
|
||||||
def __init__(self, code, msg, response):
|
def __init__(self, code, msg, response):
|
||||||
"""
|
"""
|
||||||
|
@ -380,7 +349,39 @@ class HttpResponseException(CodeMessageException):
|
||||||
Args:
|
Args:
|
||||||
code (int): HTTP status code
|
code (int): HTTP status code
|
||||||
msg (str): reason phrase from HTTP response status line
|
msg (str): reason phrase from HTTP response status line
|
||||||
response (str): body of response
|
response (bytes): body of response
|
||||||
"""
|
"""
|
||||||
super(HttpResponseException, self).__init__(code, msg)
|
super(HttpResponseException, self).__init__(code, msg)
|
||||||
self.response = response
|
self.response = response
|
||||||
|
|
||||||
|
def to_synapse_error(self):
|
||||||
|
"""Make a SynapseError based on an HTTPResponseException
|
||||||
|
|
||||||
|
This is useful when a proxied request has failed, and we need to
|
||||||
|
decide how to map the failure onto a matrix error to send back to the
|
||||||
|
client.
|
||||||
|
|
||||||
|
An attempt is made to parse the body of the http response as a matrix
|
||||||
|
error. If that succeeds, the errcode and error message from the body
|
||||||
|
are used as the errcode and error message in the new synapse error.
|
||||||
|
|
||||||
|
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
|
||||||
|
set to the reason code from the HTTP response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SynapseError:
|
||||||
|
"""
|
||||||
|
# try to parse the body as json, to get better errcode/msg, but
|
||||||
|
# default to M_UNKNOWN with the HTTP status as the error text
|
||||||
|
try:
|
||||||
|
j = json.loads(self.response)
|
||||||
|
except ValueError:
|
||||||
|
j = {}
|
||||||
|
|
||||||
|
if not isinstance(j, dict):
|
||||||
|
j = {}
|
||||||
|
|
||||||
|
errcode = j.pop('errcode', Codes.UNKNOWN)
|
||||||
|
errmsg = j.pop('error', self.msg)
|
||||||
|
|
||||||
|
return ProxiedRequestError(self.code, errmsg, errcode, j)
|
||||||
|
|
|
@ -20,6 +20,8 @@ import sys
|
||||||
|
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
|
||||||
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
||||||
|
@ -300,6 +302,11 @@ class SynapseHomeServer(HomeServer):
|
||||||
quit_with_error(e.message)
|
quit_with_error(e.message)
|
||||||
|
|
||||||
|
|
||||||
|
# Gauges to expose monthly active user control metrics
|
||||||
|
current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
|
||||||
|
max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
|
||||||
|
|
||||||
|
|
||||||
def setup(config_options):
|
def setup(config_options):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -512,6 +519,18 @@ def run(hs):
|
||||||
# table will decrease
|
# table will decrease
|
||||||
clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
|
clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def generate_monthly_active_users():
|
||||||
|
count = 0
|
||||||
|
if hs.config.limit_usage_by_mau:
|
||||||
|
count = yield hs.get_datastore().count_monthly_users()
|
||||||
|
current_mau_gauge.set(float(count))
|
||||||
|
max_mau_value_gauge.set(float(hs.config.max_mau_value))
|
||||||
|
|
||||||
|
generate_monthly_active_users()
|
||||||
|
if hs.config.limit_usage_by_mau:
|
||||||
|
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
|
||||||
|
|
||||||
if hs.config.report_stats:
|
if hs.config.report_stats:
|
||||||
logger.info("Scheduling stats reporting for 3 hour intervals")
|
logger.info("Scheduling stats reporting for 3 hour intervals")
|
||||||
clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
|
clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
|
||||||
|
|
|
@ -67,6 +67,14 @@ class ServerConfig(Config):
|
||||||
"block_non_admin_invites", False,
|
"block_non_admin_invites", False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Options to control access by tracking MAU
|
||||||
|
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
|
||||||
|
if self.limit_usage_by_mau:
|
||||||
|
self.max_mau_value = config.get(
|
||||||
|
"max_mau_value", 0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.max_mau_value = 0
|
||||||
# FIXME: federation_domain_whitelist needs sytests
|
# FIXME: federation_domain_whitelist needs sytests
|
||||||
self.federation_domain_whitelist = None
|
self.federation_domain_whitelist = None
|
||||||
federation_domain_whitelist = config.get(
|
federation_domain_whitelist = config.get(
|
||||||
|
@ -209,6 +217,8 @@ class ServerConfig(Config):
|
||||||
# different cores. See
|
# different cores. See
|
||||||
# https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
|
# https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
|
||||||
#
|
#
|
||||||
|
# This setting requires the affinity package to be installed!
|
||||||
|
#
|
||||||
# cpu_affinity: 0xFFFFFFFF
|
# cpu_affinity: 0xFFFFFFFF
|
||||||
|
|
||||||
# Whether to serve a web client from the HTTP/HTTPS root resource.
|
# Whether to serve a web client from the HTTP/HTTPS root resource.
|
||||||
|
|
|
@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
|
||||||
PDU_RETRY_TIME_MS = 1 * 60 * 1000
|
PDU_RETRY_TIME_MS = 1 * 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidResponseError(RuntimeError):
|
||||||
|
"""Helper for _try_destination_list: indicates that the server returned a response
|
||||||
|
we couldn't parse
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FederationClient(FederationBase):
|
class FederationClient(FederationBase):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(FederationClient, self).__init__(hs)
|
super(FederationClient, self).__init__(hs)
|
||||||
|
@ -458,6 +465,61 @@ class FederationClient(FederationBase):
|
||||||
defer.returnValue(signed_auth)
|
defer.returnValue(signed_auth)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
def _try_destination_list(self, description, destinations, callback):
|
||||||
|
"""Try an operation on a series of servers, until it succeeds
|
||||||
|
|
||||||
|
Args:
|
||||||
|
description (unicode): description of the operation we're doing, for logging
|
||||||
|
|
||||||
|
destinations (Iterable[unicode]): list of server_names to try
|
||||||
|
|
||||||
|
callback (callable): Function to run for each server. Passed a single
|
||||||
|
argument: the server_name to try. May return a deferred.
|
||||||
|
|
||||||
|
If the callback raises a CodeMessageException with a 300/400 code,
|
||||||
|
attempts to perform the operation stop immediately and the exception is
|
||||||
|
reraised.
|
||||||
|
|
||||||
|
Otherwise, if the callback raises an Exception the error is logged and the
|
||||||
|
next server tried. Normally the stacktrace is logged but this is
|
||||||
|
suppressed if the exception is an InvalidResponseError.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The [Deferred] result of callback, if it succeeds
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError if the chosen remote server returns a 300/400 code.
|
||||||
|
|
||||||
|
RuntimeError if no servers were reachable.
|
||||||
|
"""
|
||||||
|
for destination in destinations:
|
||||||
|
if destination == self.server_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = yield callback(destination)
|
||||||
|
defer.returnValue(res)
|
||||||
|
except InvalidResponseError as e:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to %s via %s: %s",
|
||||||
|
description, destination, e,
|
||||||
|
)
|
||||||
|
except HttpResponseException as e:
|
||||||
|
if not 500 <= e.code < 600:
|
||||||
|
raise e.to_synapse_error()
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to %s via %s: %i %s",
|
||||||
|
description, destination, e.code, e.message,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to %s via %s",
|
||||||
|
description, destination, exc_info=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError("Failed to %s via any server", description)
|
||||||
|
|
||||||
def make_membership_event(self, destinations, room_id, user_id, membership,
|
def make_membership_event(self, destinations, room_id, user_id, membership,
|
||||||
content={},):
|
content={},):
|
||||||
"""
|
"""
|
||||||
|
@ -481,7 +543,7 @@ class FederationClient(FederationBase):
|
||||||
Deferred: resolves to a tuple of (origin (str), event (object))
|
Deferred: resolves to a tuple of (origin (str), event (object))
|
||||||
where origin is the remote homeserver which generated the event.
|
where origin is the remote homeserver which generated the event.
|
||||||
|
|
||||||
Fails with a ``CodeMessageException`` if the chosen remote server
|
Fails with a ``SynapseError`` if the chosen remote server
|
||||||
returns a 300/400 code.
|
returns a 300/400 code.
|
||||||
|
|
||||||
Fails with a ``RuntimeError`` if no servers were reachable.
|
Fails with a ``RuntimeError`` if no servers were reachable.
|
||||||
|
@ -492,11 +554,9 @@ class FederationClient(FederationBase):
|
||||||
"make_membership_event called with membership='%s', must be one of %s" %
|
"make_membership_event called with membership='%s', must be one of %s" %
|
||||||
(membership, ",".join(valid_memberships))
|
(membership, ",".join(valid_memberships))
|
||||||
)
|
)
|
||||||
for destination in destinations:
|
|
||||||
if destination == self.server_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
@defer.inlineCallbacks
|
||||||
|
def send_request(destination):
|
||||||
ret = yield self.transport_layer.make_membership_event(
|
ret = yield self.transport_layer.make_membership_event(
|
||||||
destination, room_id, user_id, membership
|
destination, room_id, user_id, membership
|
||||||
)
|
)
|
||||||
|
@ -518,24 +578,11 @@ class FederationClient(FederationBase):
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
(destination, ev)
|
(destination, ev)
|
||||||
)
|
)
|
||||||
break
|
|
||||||
except CodeMessageException as e:
|
return self._try_destination_list(
|
||||||
if not 500 <= e.code < 600:
|
"make_" + membership, destinations, send_request,
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to make_%s via %s: %s",
|
|
||||||
membership, destination, e.message
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to make_%s via %s: %s",
|
|
||||||
membership, destination, e.message
|
|
||||||
)
|
)
|
||||||
|
|
||||||
raise RuntimeError("Failed to send to any server.")
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def send_join(self, destinations, pdu):
|
def send_join(self, destinations, pdu):
|
||||||
"""Sends a join event to one of a list of homeservers.
|
"""Sends a join event to one of a list of homeservers.
|
||||||
|
|
||||||
|
@ -552,17 +599,14 @@ class FederationClient(FederationBase):
|
||||||
giving the serer the event was sent to, ``state`` (?) and
|
giving the serer the event was sent to, ``state`` (?) and
|
||||||
``auth_chain``.
|
``auth_chain``.
|
||||||
|
|
||||||
Fails with a ``CodeMessageException`` if the chosen remote server
|
Fails with a ``SynapseError`` if the chosen remote server
|
||||||
returns a 300/400 code.
|
returns a 300/400 code.
|
||||||
|
|
||||||
Fails with a ``RuntimeError`` if no servers were reachable.
|
Fails with a ``RuntimeError`` if no servers were reachable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for destination in destinations:
|
@defer.inlineCallbacks
|
||||||
if destination == self.server_name:
|
def send_request(destination):
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
_, content = yield self.transport_layer.send_join(
|
_, content = yield self.transport_layer.send_join(
|
||||||
destination=destination,
|
destination=destination,
|
||||||
|
@ -624,31 +668,22 @@ class FederationClient(FederationBase):
|
||||||
"auth_chain": signed_auth,
|
"auth_chain": signed_auth,
|
||||||
"origin": destination,
|
"origin": destination,
|
||||||
})
|
})
|
||||||
except CodeMessageException as e:
|
return self._try_destination_list("send_join", destinations, send_request)
|
||||||
if not 500 <= e.code < 600:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to send_join via %s: %s",
|
|
||||||
destination, e.message
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to send_join via %s: %s",
|
|
||||||
destination, e.message
|
|
||||||
)
|
|
||||||
|
|
||||||
raise RuntimeError("Failed to send to any server.")
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_invite(self, destination, room_id, event_id, pdu):
|
def send_invite(self, destination, room_id, event_id, pdu):
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
|
try:
|
||||||
code, content = yield self.transport_layer.send_invite(
|
code, content = yield self.transport_layer.send_invite(
|
||||||
destination=destination,
|
destination=destination,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
event_id=event_id,
|
event_id=event_id,
|
||||||
content=pdu.get_pdu_json(time_now),
|
content=pdu.get_pdu_json(time_now),
|
||||||
)
|
)
|
||||||
|
except HttpResponseException as e:
|
||||||
|
if e.code == 403:
|
||||||
|
raise e.to_synapse_error()
|
||||||
|
raise
|
||||||
|
|
||||||
pdu_dict = content["event"]
|
pdu_dict = content["event"]
|
||||||
|
|
||||||
|
@ -663,7 +698,6 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
defer.returnValue(pdu)
|
defer.returnValue(pdu)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def send_leave(self, destinations, pdu):
|
def send_leave(self, destinations, pdu):
|
||||||
"""Sends a leave event to one of a list of homeservers.
|
"""Sends a leave event to one of a list of homeservers.
|
||||||
|
|
||||||
|
@ -680,16 +714,13 @@ class FederationClient(FederationBase):
|
||||||
Return:
|
Return:
|
||||||
Deferred: resolves to None.
|
Deferred: resolves to None.
|
||||||
|
|
||||||
Fails with a ``CodeMessageException`` if the chosen remote server
|
Fails with a ``SynapseError`` if the chosen remote server
|
||||||
returns a non-200 code.
|
returns a 300/400 code.
|
||||||
|
|
||||||
Fails with a ``RuntimeError`` if no servers were reachable.
|
Fails with a ``RuntimeError`` if no servers were reachable.
|
||||||
"""
|
"""
|
||||||
for destination in destinations:
|
@defer.inlineCallbacks
|
||||||
if destination == self.server_name:
|
def send_request(destination):
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
_, content = yield self.transport_layer.send_leave(
|
_, content = yield self.transport_layer.send_leave(
|
||||||
destination=destination,
|
destination=destination,
|
||||||
|
@ -700,15 +731,8 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
logger.debug("Got content: %s", content)
|
logger.debug("Got content: %s", content)
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
except CodeMessageException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to send_leave via %s: %s",
|
|
||||||
destination, e.message
|
|
||||||
)
|
|
||||||
|
|
||||||
raise RuntimeError("Failed to send to any server.")
|
return self._try_destination_list("send_leave", destinations, send_request)
|
||||||
|
|
||||||
def get_public_rooms(self, destination, limit=None, since_token=None,
|
def get_public_rooms(self, destination, limit=None, since_token=None,
|
||||||
search_filter=None, include_all_networks=False,
|
search_filter=None, include_all_networks=False,
|
||||||
|
|
|
@ -207,10 +207,6 @@ class FederationServer(FederationBase):
|
||||||
edu.content
|
edu.content
|
||||||
)
|
)
|
||||||
|
|
||||||
pdu_failures = getattr(transaction, "pdu_failures", [])
|
|
||||||
for fail in pdu_failures:
|
|
||||||
logger.info("Got failure %r", fail)
|
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"pdus": pdu_results,
|
"pdus": pdu_results,
|
||||||
}
|
}
|
||||||
|
@ -430,6 +426,7 @@ class FederationServer(FederationBase):
|
||||||
ret = yield self.handler.on_query_auth(
|
ret = yield self.handler.on_query_auth(
|
||||||
origin,
|
origin,
|
||||||
event_id,
|
event_id,
|
||||||
|
room_id,
|
||||||
signed_auth,
|
signed_auth,
|
||||||
content.get("rejects", []),
|
content.get("rejects", []),
|
||||||
content.get("missing", []),
|
content.get("missing", []),
|
||||||
|
|
|
@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
|
||||||
|
|
||||||
self.edus = SortedDict() # stream position -> Edu
|
self.edus = SortedDict() # stream position -> Edu
|
||||||
|
|
||||||
self.failures = SortedDict() # stream position -> (destination, Failure)
|
|
||||||
|
|
||||||
self.device_messages = SortedDict() # stream position -> destination
|
self.device_messages = SortedDict() # stream position -> destination
|
||||||
|
|
||||||
self.pos = 1
|
self.pos = 1
|
||||||
|
@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
|
||||||
|
|
||||||
for queue_name in [
|
for queue_name in [
|
||||||
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
|
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
|
||||||
"edus", "failures", "device_messages", "pos_time",
|
"edus", "device_messages", "pos_time",
|
||||||
]:
|
]:
|
||||||
register(queue_name, getattr(self, queue_name))
|
register(queue_name, getattr(self, queue_name))
|
||||||
|
|
||||||
|
@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
|
||||||
for key in keys[:i]:
|
for key in keys[:i]:
|
||||||
del self.edus[key]
|
del self.edus[key]
|
||||||
|
|
||||||
# Delete things out of failure map
|
|
||||||
keys = self.failures.keys()
|
|
||||||
i = self.failures.bisect_left(position_to_delete)
|
|
||||||
for key in keys[:i]:
|
|
||||||
del self.failures[key]
|
|
||||||
|
|
||||||
# Delete things out of device map
|
# Delete things out of device map
|
||||||
keys = self.device_messages.keys()
|
keys = self.device_messages.keys()
|
||||||
i = self.device_messages.bisect_left(position_to_delete)
|
i = self.device_messages.bisect_left(position_to_delete)
|
||||||
|
@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
|
||||||
|
|
||||||
self.notifier.on_new_replication_data()
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def send_failure(self, failure, destination):
|
|
||||||
"""As per TransactionQueue"""
|
|
||||||
pos = self._next_pos()
|
|
||||||
|
|
||||||
self.failures[pos] = (destination, str(failure))
|
|
||||||
self.notifier.on_new_replication_data()
|
|
||||||
|
|
||||||
def send_device_messages(self, destination):
|
def send_device_messages(self, destination):
|
||||||
"""As per TransactionQueue"""
|
"""As per TransactionQueue"""
|
||||||
pos = self._next_pos()
|
pos = self._next_pos()
|
||||||
|
@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
|
||||||
for (pos, edu) in edus:
|
for (pos, edu) in edus:
|
||||||
rows.append((pos, EduRow(edu)))
|
rows.append((pos, EduRow(edu)))
|
||||||
|
|
||||||
# Fetch changed failures
|
|
||||||
i = self.failures.bisect_right(from_token)
|
|
||||||
j = self.failures.bisect_right(to_token) + 1
|
|
||||||
failures = self.failures.items()[i:j]
|
|
||||||
|
|
||||||
for (pos, (destination, failure)) in failures:
|
|
||||||
rows.append((pos, FailureRow(
|
|
||||||
destination=destination,
|
|
||||||
failure=failure,
|
|
||||||
)))
|
|
||||||
|
|
||||||
# Fetch changed device messages
|
# Fetch changed device messages
|
||||||
i = self.device_messages.bisect_right(from_token)
|
i = self.device_messages.bisect_right(from_token)
|
||||||
j = self.device_messages.bisect_right(to_token) + 1
|
j = self.device_messages.bisect_right(to_token) + 1
|
||||||
|
@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
|
||||||
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
|
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
|
||||||
|
|
||||||
|
|
||||||
class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
|
|
||||||
"destination", # str
|
|
||||||
"failure",
|
|
||||||
))):
|
|
||||||
"""Streams failures to a remote server. Failures are issued when there was
|
|
||||||
something wrong with a transaction the remote sent us, e.g. it included
|
|
||||||
an event that was invalid.
|
|
||||||
"""
|
|
||||||
|
|
||||||
TypeId = "f"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_data(data):
|
|
||||||
return FailureRow(
|
|
||||||
destination=data["destination"],
|
|
||||||
failure=data["failure"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_data(self):
|
|
||||||
return {
|
|
||||||
"destination": self.destination,
|
|
||||||
"failure": self.failure,
|
|
||||||
}
|
|
||||||
|
|
||||||
def add_to_buffer(self, buff):
|
|
||||||
buff.failures.setdefault(self.destination, []).append(self.failure)
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
|
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
|
||||||
"destination", # str
|
"destination", # str
|
||||||
))):
|
))):
|
||||||
|
@ -471,7 +417,6 @@ TypeToRow = {
|
||||||
PresenceRow,
|
PresenceRow,
|
||||||
KeyedEduRow,
|
KeyedEduRow,
|
||||||
EduRow,
|
EduRow,
|
||||||
FailureRow,
|
|
||||||
DeviceRow,
|
DeviceRow,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
|
||||||
"presence", # list(UserPresenceState)
|
"presence", # list(UserPresenceState)
|
||||||
"keyed_edus", # dict of destination -> { key -> Edu }
|
"keyed_edus", # dict of destination -> { key -> Edu }
|
||||||
"edus", # dict of destination -> [Edu]
|
"edus", # dict of destination -> [Edu]
|
||||||
"failures", # dict of destination -> [failures]
|
|
||||||
"device_destinations", # set of destinations
|
"device_destinations", # set of destinations
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
|
||||||
presence=[],
|
presence=[],
|
||||||
keyed_edus={},
|
keyed_edus={},
|
||||||
edus={},
|
edus={},
|
||||||
failures={},
|
|
||||||
device_destinations=set(),
|
device_destinations=set(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
|
||||||
edu.destination, edu.edu_type, edu.content, key=None,
|
edu.destination, edu.edu_type, edu.content, key=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
for destination, failure_list in iteritems(buff.failures):
|
|
||||||
for failure in failure_list:
|
|
||||||
transaction_queue.send_failure(destination, failure)
|
|
||||||
|
|
||||||
for destination in buff.device_destinations:
|
for destination in buff.device_destinations:
|
||||||
transaction_queue.send_device_messages(destination)
|
transaction_queue.send_device_messages(destination)
|
||||||
|
|
|
@ -116,9 +116,6 @@ class TransactionQueue(object):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# destination -> list of tuple(failure, deferred)
|
|
||||||
self.pending_failures_by_dest = {}
|
|
||||||
|
|
||||||
# destination -> stream_id of last successfully sent to-device message.
|
# destination -> stream_id of last successfully sent to-device message.
|
||||||
# NB: may be a long or an int.
|
# NB: may be a long or an int.
|
||||||
self.last_device_stream_id_by_dest = {}
|
self.last_device_stream_id_by_dest = {}
|
||||||
|
@ -382,19 +379,6 @@ class TransactionQueue(object):
|
||||||
|
|
||||||
self._attempt_new_transaction(destination)
|
self._attempt_new_transaction(destination)
|
||||||
|
|
||||||
def send_failure(self, failure, destination):
|
|
||||||
if destination == self.server_name or destination == "localhost":
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.can_send_to(destination):
|
|
||||||
return
|
|
||||||
|
|
||||||
self.pending_failures_by_dest.setdefault(
|
|
||||||
destination, []
|
|
||||||
).append(failure)
|
|
||||||
|
|
||||||
self._attempt_new_transaction(destination)
|
|
||||||
|
|
||||||
def send_device_messages(self, destination):
|
def send_device_messages(self, destination):
|
||||||
if destination == self.server_name or destination == "localhost":
|
if destination == self.server_name or destination == "localhost":
|
||||||
return
|
return
|
||||||
|
@ -469,7 +453,6 @@ class TransactionQueue(object):
|
||||||
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
||||||
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
||||||
pending_presence = self.pending_presence_by_dest.pop(destination, {})
|
pending_presence = self.pending_presence_by_dest.pop(destination, {})
|
||||||
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
|
||||||
|
|
||||||
pending_edus.extend(
|
pending_edus.extend(
|
||||||
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
|
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
|
||||||
|
@ -497,7 +480,7 @@ class TransactionQueue(object):
|
||||||
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
||||||
destination, len(pending_pdus))
|
destination, len(pending_pdus))
|
||||||
|
|
||||||
if not pending_pdus and not pending_edus and not pending_failures:
|
if not pending_pdus and not pending_edus:
|
||||||
logger.debug("TX [%s] Nothing to send", destination)
|
logger.debug("TX [%s] Nothing to send", destination)
|
||||||
self.last_device_stream_id_by_dest[destination] = (
|
self.last_device_stream_id_by_dest[destination] = (
|
||||||
device_stream_id
|
device_stream_id
|
||||||
|
@ -507,7 +490,7 @@ class TransactionQueue(object):
|
||||||
# END CRITICAL SECTION
|
# END CRITICAL SECTION
|
||||||
|
|
||||||
success = yield self._send_new_transaction(
|
success = yield self._send_new_transaction(
|
||||||
destination, pending_pdus, pending_edus, pending_failures,
|
destination, pending_pdus, pending_edus,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
sent_transactions_counter.inc()
|
sent_transactions_counter.inc()
|
||||||
|
@ -584,14 +567,12 @@ class TransactionQueue(object):
|
||||||
|
|
||||||
@measure_func("_send_new_transaction")
|
@measure_func("_send_new_transaction")
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
|
def _send_new_transaction(self, destination, pending_pdus, pending_edus):
|
||||||
pending_failures):
|
|
||||||
|
|
||||||
# Sort based on the order field
|
# Sort based on the order field
|
||||||
pending_pdus.sort(key=lambda t: t[1])
|
pending_pdus.sort(key=lambda t: t[1])
|
||||||
pdus = [x[0] for x in pending_pdus]
|
pdus = [x[0] for x in pending_pdus]
|
||||||
edus = pending_edus
|
edus = pending_edus
|
||||||
failures = [x.get_dict() for x in pending_failures]
|
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
|
|
||||||
|
@ -601,11 +582,10 @@ class TransactionQueue(object):
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"TX [%s] {%s} Attempting new transaction"
|
"TX [%s] {%s} Attempting new transaction"
|
||||||
" (pdus: %d, edus: %d, failures: %d)",
|
" (pdus: %d, edus: %d)",
|
||||||
destination, txn_id,
|
destination, txn_id,
|
||||||
len(pdus),
|
len(pdus),
|
||||||
len(edus),
|
len(edus),
|
||||||
len(failures)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("TX [%s] Persisting transaction...", destination)
|
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||||
|
@ -617,7 +597,6 @@ class TransactionQueue(object):
|
||||||
destination=destination,
|
destination=destination,
|
||||||
pdus=pdus,
|
pdus=pdus,
|
||||||
edus=edus,
|
edus=edus,
|
||||||
pdu_failures=failures,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._next_txn_id += 1
|
self._next_txn_id += 1
|
||||||
|
@ -627,12 +606,11 @@ class TransactionQueue(object):
|
||||||
logger.debug("TX [%s] Persisted transaction", destination)
|
logger.debug("TX [%s] Persisted transaction", destination)
|
||||||
logger.info(
|
logger.info(
|
||||||
"TX [%s] {%s} Sending transaction [%s],"
|
"TX [%s] {%s} Sending transaction [%s],"
|
||||||
" (PDUs: %d, EDUs: %d, failures: %d)",
|
" (PDUs: %d, EDUs: %d)",
|
||||||
destination, txn_id,
|
destination, txn_id,
|
||||||
transaction.transaction_id,
|
transaction.transaction_id,
|
||||||
len(pdus),
|
len(pdus),
|
||||||
len(edus),
|
len(edus),
|
||||||
len(failures),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Actually send the transaction
|
# Actually send the transaction
|
||||||
|
|
|
@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
|
||||||
param_dict = dict(kv.split("=") for kv in params)
|
param_dict = dict(kv.split("=") for kv in params)
|
||||||
|
|
||||||
def strip_quotes(value):
|
def strip_quotes(value):
|
||||||
if value.startswith(b"\""):
|
if value.startswith("\""):
|
||||||
return value[1:-1]
|
return value[1:-1]
|
||||||
else:
|
else:
|
||||||
return value
|
return value
|
||||||
|
@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
|
"Received txn %s from %s. (PDUs: %d, EDUs: %d)",
|
||||||
transaction_id, origin,
|
transaction_id, origin,
|
||||||
len(transaction_data.get("pdus", [])),
|
len(transaction_data.get("pdus", [])),
|
||||||
len(transaction_data.get("edus", [])),
|
len(transaction_data.get("edus", [])),
|
||||||
len(transaction_data.get("failures", [])),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# We should ideally be getting this from the security layer.
|
# We should ideally be getting this from the security layer.
|
||||||
|
|
|
@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
|
||||||
"previous_ids",
|
"previous_ids",
|
||||||
"pdus",
|
"pdus",
|
||||||
"edus",
|
"edus",
|
||||||
"pdu_failures",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
internal_keys = [
|
internal_keys = [
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
@ -519,6 +520,7 @@ class AuthHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||||
access_token = yield self.issue_access_token(user_id, device_id)
|
access_token = yield self.issue_access_token(user_id, device_id)
|
||||||
|
yield self._check_mau_limits()
|
||||||
|
|
||||||
# the device *should* have been registered before we got here; however,
|
# the device *should* have been registered before we got here; however,
|
||||||
# it's possible we raced against a DELETE operation. The thing we
|
# it's possible we raced against a DELETE operation. The thing we
|
||||||
|
@ -626,6 +628,7 @@ class AuthHandler(BaseHandler):
|
||||||
# special case to check for "password" for the check_password interface
|
# special case to check for "password" for the check_password interface
|
||||||
# for the auth providers
|
# for the auth providers
|
||||||
password = login_submission.get("password")
|
password = login_submission.get("password")
|
||||||
|
|
||||||
if login_type == LoginType.PASSWORD:
|
if login_type == LoginType.PASSWORD:
|
||||||
if not self._password_enabled:
|
if not self._password_enabled:
|
||||||
raise SynapseError(400, "Password login has been disabled.")
|
raise SynapseError(400, "Password login has been disabled.")
|
||||||
|
@ -707,9 +710,10 @@ class AuthHandler(BaseHandler):
|
||||||
multiple inexact matches.
|
multiple inexact matches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): complete @user:id
|
user_id (unicode): complete @user:id
|
||||||
|
password (unicode): the provided password
|
||||||
Returns:
|
Returns:
|
||||||
(str) the canonical_user_id, or None if unknown user / bad password
|
(unicode) the canonical_user_id, or None if unknown user / bad password
|
||||||
"""
|
"""
|
||||||
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
|
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
if not lookupres:
|
if not lookupres:
|
||||||
|
@ -728,15 +732,18 @@ class AuthHandler(BaseHandler):
|
||||||
device_id)
|
device_id)
|
||||||
defer.returnValue(access_token)
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||||
|
yield self._check_mau_limits()
|
||||||
auth_api = self.hs.get_auth()
|
auth_api = self.hs.get_auth()
|
||||||
|
user_id = None
|
||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
||||||
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
||||||
auth_api.validate_macaroon(macaroon, "login", True, user_id)
|
auth_api.validate_macaroon(macaroon, "login", True, user_id)
|
||||||
return user_id
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||||
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_access_token(self, access_token):
|
def delete_access_token(self, access_token):
|
||||||
|
@ -849,14 +856,19 @@ class AuthHandler(BaseHandler):
|
||||||
"""Computes a secure hash of password.
|
"""Computes a secure hash of password.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
password (str): Password to hash.
|
password (unicode): Password to hash.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(str): Hashed password.
|
Deferred(unicode): Hashed password.
|
||||||
"""
|
"""
|
||||||
def _do_hash():
|
def _do_hash():
|
||||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
# Normalise the Unicode in the password
|
||||||
bcrypt.gensalt(self.bcrypt_rounds))
|
pw = unicodedata.normalize("NFKC", password)
|
||||||
|
|
||||||
|
return bcrypt.hashpw(
|
||||||
|
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
|
||||||
|
bcrypt.gensalt(self.bcrypt_rounds),
|
||||||
|
).decode('ascii')
|
||||||
|
|
||||||
return make_deferred_yieldable(
|
return make_deferred_yieldable(
|
||||||
threads.deferToThreadPool(
|
threads.deferToThreadPool(
|
||||||
|
@ -868,16 +880,19 @@ class AuthHandler(BaseHandler):
|
||||||
"""Validates that self.hash(password) == stored_hash.
|
"""Validates that self.hash(password) == stored_hash.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
password (str): Password to hash.
|
password (unicode): Password to hash.
|
||||||
stored_hash (str): Expected hash value.
|
stored_hash (unicode): Expected hash value.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(bool): Whether self.hash(password) == stored_hash.
|
Deferred(bool): Whether self.hash(password) == stored_hash.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _do_validate_hash():
|
def _do_validate_hash():
|
||||||
|
# Normalise the Unicode in the password
|
||||||
|
pw = unicodedata.normalize("NFKC", password)
|
||||||
|
|
||||||
return bcrypt.checkpw(
|
return bcrypt.checkpw(
|
||||||
password.encode('utf8') + self.hs.config.password_pepper,
|
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
|
||||||
stored_hash.encode('utf8')
|
stored_hash.encode('utf8')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -892,6 +907,19 @@ class AuthHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
return defer.succeed(False)
|
return defer.succeed(False)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _check_mau_limits(self):
|
||||||
|
"""
|
||||||
|
Ensure that if mau blocking is enabled that invalid users cannot
|
||||||
|
log in.
|
||||||
|
"""
|
||||||
|
if self.hs.config.limit_usage_by_mau is True:
|
||||||
|
current_mau = yield self.store.count_monthly_users()
|
||||||
|
if current_mau >= self.hs.config.max_mau_value:
|
||||||
|
raise AuthError(
|
||||||
|
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class MacaroonGenerator(object):
|
class MacaroonGenerator(object):
|
||||||
|
|
|
@ -19,10 +19,12 @@ import random
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
|
from synapse.api.errors import AuthError
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
@ -129,11 +131,13 @@ class EventStreamHandler(BaseHandler):
|
||||||
class EventHandler(BaseHandler):
|
class EventHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_event(self, user, event_id):
|
def get_event(self, user, room_id, event_id):
|
||||||
"""Retrieve a single specified event.
|
"""Retrieve a single specified event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user (synapse.types.UserID): The user requesting the event
|
user (synapse.types.UserID): The user requesting the event
|
||||||
|
room_id (str|None): The expected room id. We'll return None if the
|
||||||
|
event's room does not match.
|
||||||
event_id (str): The event ID to obtain.
|
event_id (str): The event ID to obtain.
|
||||||
Returns:
|
Returns:
|
||||||
dict: An event, or None if there is no event matching this ID.
|
dict: An event, or None if there is no event matching this ID.
|
||||||
|
@ -142,13 +146,26 @@ class EventHandler(BaseHandler):
|
||||||
AuthError if the user does not have the rights to inspect this
|
AuthError if the user does not have the rights to inspect this
|
||||||
event.
|
event.
|
||||||
"""
|
"""
|
||||||
event = yield self.store.get_event(event_id)
|
event = yield self.store.get_event(event_id, check_room_id=room_id)
|
||||||
|
|
||||||
if not event:
|
if not event:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
return
|
return
|
||||||
|
|
||||||
if hasattr(event, "room_id"):
|
users = yield self.store.get_users_in_room(event.room_id)
|
||||||
yield self.auth.check_joined_room(event.room_id, user.to_string())
|
is_peeking = user.to_string() not in users
|
||||||
|
|
||||||
|
filtered = yield filter_events_for_client(
|
||||||
|
self.store,
|
||||||
|
user.to_string(),
|
||||||
|
[event],
|
||||||
|
is_peeking=is_peeking
|
||||||
|
)
|
||||||
|
|
||||||
|
if not filtered:
|
||||||
|
raise AuthError(
|
||||||
|
403,
|
||||||
|
"You don't have permission to access that event."
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
|
@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.replication_layer = hs.get_federation_client()
|
self.federation_client = hs.get_federation_client()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
|
@ -255,7 +255,7 @@ class FederationHandler(BaseHandler):
|
||||||
# know about
|
# know about
|
||||||
for p in prevs - seen:
|
for p in prevs - seen:
|
||||||
state, got_auth_chain = (
|
state, got_auth_chain = (
|
||||||
yield self.replication_layer.get_state_for_room(
|
yield self.federation_client.get_state_for_room(
|
||||||
origin, pdu.room_id, p
|
origin, pdu.room_id, p
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -338,7 +338,7 @@ class FederationHandler(BaseHandler):
|
||||||
#
|
#
|
||||||
# see https://github.com/matrix-org/synapse/pull/1744
|
# see https://github.com/matrix-org/synapse/pull/1744
|
||||||
|
|
||||||
missing_events = yield self.replication_layer.get_missing_events(
|
missing_events = yield self.federation_client.get_missing_events(
|
||||||
origin,
|
origin,
|
||||||
pdu.room_id,
|
pdu.room_id,
|
||||||
earliest_events_ids=list(latest),
|
earliest_events_ids=list(latest),
|
||||||
|
@ -400,7 +400,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
yield self._persist_auth_tree(
|
||||||
origin, auth_chain, state, event
|
origin, auth_chain, state, event
|
||||||
)
|
)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
|
@ -444,7 +444,7 @@ class FederationHandler(BaseHandler):
|
||||||
yield self._handle_new_events(origin, event_infos)
|
yield self._handle_new_events(origin, event_infos)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
context = yield self._handle_new_event(
|
||||||
origin,
|
origin,
|
||||||
event,
|
event,
|
||||||
state=state,
|
state=state,
|
||||||
|
@ -469,17 +469,6 @@ class FederationHandler(BaseHandler):
|
||||||
except StoreError:
|
except StoreError:
|
||||||
logger.exception("Failed to store room.")
|
logger.exception("Failed to store room.")
|
||||||
|
|
||||||
extra_users = []
|
|
||||||
if event.type == EventTypes.Member:
|
|
||||||
target_user_id = event.state_key
|
|
||||||
target_user = UserID.from_string(target_user_id)
|
|
||||||
extra_users.append(target_user)
|
|
||||||
|
|
||||||
self.notifier.on_new_room_event(
|
|
||||||
event, event_stream_id, max_stream_id,
|
|
||||||
extra_users=extra_users
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
# Only fire user_joined_room if the user has acutally
|
# Only fire user_joined_room if the user has acutally
|
||||||
|
@ -501,7 +490,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
if newly_joined:
|
if newly_joined:
|
||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
yield user_joined_room(self.distributor, user, event.room_id)
|
yield self.user_joined_room(user, event.room_id)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -522,7 +511,7 @@ class FederationHandler(BaseHandler):
|
||||||
if dest == self.server_name:
|
if dest == self.server_name:
|
||||||
raise SynapseError(400, "Can't backfill from self.")
|
raise SynapseError(400, "Can't backfill from self.")
|
||||||
|
|
||||||
events = yield self.replication_layer.backfill(
|
events = yield self.federation_client.backfill(
|
||||||
dest,
|
dest,
|
||||||
room_id,
|
room_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
@ -570,7 +559,7 @@ class FederationHandler(BaseHandler):
|
||||||
state_events = {}
|
state_events = {}
|
||||||
events_to_state = {}
|
events_to_state = {}
|
||||||
for e_id in edges:
|
for e_id in edges:
|
||||||
state, auth = yield self.replication_layer.get_state_for_room(
|
state, auth = yield self.federation_client.get_state_for_room(
|
||||||
destination=dest,
|
destination=dest,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
event_id=e_id
|
event_id=e_id
|
||||||
|
@ -612,7 +601,7 @@ class FederationHandler(BaseHandler):
|
||||||
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||||
[
|
[
|
||||||
logcontext.run_in_background(
|
logcontext.run_in_background(
|
||||||
self.replication_layer.get_pdu,
|
self.federation_client.get_pdu,
|
||||||
[dest],
|
[dest],
|
||||||
event_id,
|
event_id,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
|
@ -893,7 +882,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
Invites must be signed by the invitee's server before distribution.
|
Invites must be signed by the invitee's server before distribution.
|
||||||
"""
|
"""
|
||||||
pdu = yield self.replication_layer.send_invite(
|
pdu = yield self.federation_client.send_invite(
|
||||||
destination=target_host,
|
destination=target_host,
|
||||||
room_id=event.room_id,
|
room_id=event.room_id,
|
||||||
event_id=event.event_id,
|
event_id=event.event_id,
|
||||||
|
@ -942,7 +931,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
self.room_queues[room_id] = []
|
self.room_queues[room_id] = []
|
||||||
|
|
||||||
yield self.store.clean_room_for_join(room_id)
|
yield self._clean_room_for_join(room_id)
|
||||||
|
|
||||||
handled_events = set()
|
handled_events = set()
|
||||||
|
|
||||||
|
@ -955,7 +944,7 @@ class FederationHandler(BaseHandler):
|
||||||
target_hosts.insert(0, origin)
|
target_hosts.insert(0, origin)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
ret = yield self.replication_layer.send_join(target_hosts, event)
|
ret = yield self.federation_client.send_join(target_hosts, event)
|
||||||
|
|
||||||
origin = ret["origin"]
|
origin = ret["origin"]
|
||||||
state = ret["state"]
|
state = ret["state"]
|
||||||
|
@ -981,15 +970,10 @@ class FederationHandler(BaseHandler):
|
||||||
# FIXME
|
# FIXME
|
||||||
pass
|
pass
|
||||||
|
|
||||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
yield self._persist_auth_tree(
|
||||||
origin, auth_chain, state, event
|
origin, auth_chain, state, event
|
||||||
)
|
)
|
||||||
|
|
||||||
self.notifier.on_new_room_event(
|
|
||||||
event, event_stream_id, max_stream_id,
|
|
||||||
extra_users=[joinee]
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Finished joining %s to %s", joinee, room_id)
|
logger.debug("Finished joining %s to %s", joinee, room_id)
|
||||||
finally:
|
finally:
|
||||||
room_queue = self.room_queues[room_id]
|
room_queue = self.room_queues[room_id]
|
||||||
|
@ -1084,7 +1068,7 @@ class FederationHandler(BaseHandler):
|
||||||
# would introduce the danger of backwards-compatibility problems.
|
# would introduce the danger of backwards-compatibility problems.
|
||||||
event.internal_metadata.send_on_behalf_of = origin
|
event.internal_metadata.send_on_behalf_of = origin
|
||||||
|
|
||||||
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
context = yield self._handle_new_event(
|
||||||
origin, event
|
origin, event
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1094,20 +1078,10 @@ class FederationHandler(BaseHandler):
|
||||||
event.signatures,
|
event.signatures,
|
||||||
)
|
)
|
||||||
|
|
||||||
extra_users = []
|
|
||||||
if event.type == EventTypes.Member:
|
|
||||||
target_user_id = event.state_key
|
|
||||||
target_user = UserID.from_string(target_user_id)
|
|
||||||
extra_users.append(target_user)
|
|
||||||
|
|
||||||
self.notifier.on_new_room_event(
|
|
||||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
if event.content["membership"] == Membership.JOIN:
|
if event.content["membership"] == Membership.JOIN:
|
||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
yield user_joined_room(self.distributor, user, event.room_id)
|
yield self.user_joined_room(user, event.room_id)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
|
||||||
|
@ -1176,17 +1150,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
context = yield self.state_handler.compute_event_context(event)
|
context = yield self.state_handler.compute_event_context(event)
|
||||||
|
yield self._persist_events([(event, context)])
|
||||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
|
||||||
event,
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
target_user = UserID.from_string(event.state_key)
|
|
||||||
self.notifier.on_new_room_event(
|
|
||||||
event, event_stream_id, max_stream_id,
|
|
||||||
extra_users=[target_user],
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
|
@ -1211,30 +1175,20 @@ class FederationHandler(BaseHandler):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
yield self.replication_layer.send_leave(
|
yield self.federation_client.send_leave(
|
||||||
target_hosts,
|
target_hosts,
|
||||||
event
|
event
|
||||||
)
|
)
|
||||||
|
|
||||||
context = yield self.state_handler.compute_event_context(event)
|
context = yield self.state_handler.compute_event_context(event)
|
||||||
|
yield self._persist_events([(event, context)])
|
||||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
|
||||||
event,
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
target_user = UserID.from_string(event.state_key)
|
|
||||||
self.notifier.on_new_room_event(
|
|
||||||
event, event_stream_id, max_stream_id,
|
|
||||||
extra_users=[target_user],
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
|
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
|
||||||
content={},):
|
content={},):
|
||||||
origin, pdu = yield self.replication_layer.make_membership_event(
|
origin, pdu = yield self.federation_client.make_membership_event(
|
||||||
target_hosts,
|
target_hosts,
|
||||||
room_id,
|
room_id,
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -1318,7 +1272,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
event.internal_metadata.outlier = False
|
event.internal_metadata.outlier = False
|
||||||
|
|
||||||
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
yield self._handle_new_event(
|
||||||
origin, event
|
origin, event
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1328,22 +1282,17 @@ class FederationHandler(BaseHandler):
|
||||||
event.signatures,
|
event.signatures,
|
||||||
)
|
)
|
||||||
|
|
||||||
extra_users = []
|
|
||||||
if event.type == EventTypes.Member:
|
|
||||||
target_user_id = event.state_key
|
|
||||||
target_user = UserID.from_string(target_user_id)
|
|
||||||
extra_users.append(target_user)
|
|
||||||
|
|
||||||
self.notifier.on_new_room_event(
|
|
||||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_for_pdu(self, room_id, event_id):
|
def get_state_for_pdu(self, room_id, event_id):
|
||||||
"""Returns the state at the event. i.e. not including said event.
|
"""Returns the state at the event. i.e. not including said event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
event = yield self.store.get_event(
|
||||||
|
event_id, allow_none=False, check_room_id=room_id,
|
||||||
|
)
|
||||||
|
|
||||||
state_groups = yield self.store.get_state_groups(
|
state_groups = yield self.store.get_state_groups(
|
||||||
room_id, [event_id]
|
room_id, [event_id]
|
||||||
)
|
)
|
||||||
|
@ -1354,8 +1303,7 @@ class FederationHandler(BaseHandler):
|
||||||
(e.type, e.state_key): e for e in state
|
(e.type, e.state_key): e for e in state
|
||||||
}
|
}
|
||||||
|
|
||||||
event = yield self.store.get_event(event_id)
|
if event.is_state():
|
||||||
if event and event.is_state():
|
|
||||||
# Get previous state
|
# Get previous state
|
||||||
if "replaces_state" in event.unsigned:
|
if "replaces_state" in event.unsigned:
|
||||||
prev_id = event.unsigned["replaces_state"]
|
prev_id = event.unsigned["replaces_state"]
|
||||||
|
@ -1374,6 +1322,10 @@ class FederationHandler(BaseHandler):
|
||||||
def get_state_ids_for_pdu(self, room_id, event_id):
|
def get_state_ids_for_pdu(self, room_id, event_id):
|
||||||
"""Returns the state at the event. i.e. not including said event.
|
"""Returns the state at the event. i.e. not including said event.
|
||||||
"""
|
"""
|
||||||
|
event = yield self.store.get_event(
|
||||||
|
event_id, allow_none=False, check_room_id=room_id,
|
||||||
|
)
|
||||||
|
|
||||||
state_groups = yield self.store.get_state_groups_ids(
|
state_groups = yield self.store.get_state_groups_ids(
|
||||||
room_id, [event_id]
|
room_id, [event_id]
|
||||||
)
|
)
|
||||||
|
@ -1382,8 +1334,7 @@ class FederationHandler(BaseHandler):
|
||||||
_, state = state_groups.items().pop()
|
_, state = state_groups.items().pop()
|
||||||
results = state
|
results = state
|
||||||
|
|
||||||
event = yield self.store.get_event(event_id)
|
if event.is_state():
|
||||||
if event and event.is_state():
|
|
||||||
# Get previous state
|
# Get previous state
|
||||||
if "replaces_state" in event.unsigned:
|
if "replaces_state" in event.unsigned:
|
||||||
prev_id = event.unsigned["replaces_state"]
|
prev_id = event.unsigned["replaces_state"]
|
||||||
|
@ -1472,9 +1423,8 @@ class FederationHandler(BaseHandler):
|
||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
yield self._persist_events(
|
||||||
event,
|
[(event, context)],
|
||||||
context=context,
|
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
except: # noqa: E722, as we reraise the exception this is fine.
|
except: # noqa: E722, as we reraise the exception this is fine.
|
||||||
|
@ -1487,15 +1437,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
six.reraise(tp, value, tb)
|
six.reraise(tp, value, tb)
|
||||||
|
|
||||||
if not backfilled:
|
defer.returnValue(context)
|
||||||
# this intentionally does not yield: we don't care about the result
|
|
||||||
# and don't need to wait for it.
|
|
||||||
logcontext.run_in_background(
|
|
||||||
self.pusher_pool.on_new_notifications,
|
|
||||||
event_stream_id, max_stream_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue((context, event_stream_id, max_stream_id))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _handle_new_events(self, origin, event_infos, backfilled=False):
|
def _handle_new_events(self, origin, event_infos, backfilled=False):
|
||||||
|
@ -1503,6 +1445,8 @@ class FederationHandler(BaseHandler):
|
||||||
should not depend on one another, e.g. this should be used to persist
|
should not depend on one another, e.g. this should be used to persist
|
||||||
a bunch of outliers, but not a chunk of individual events that depend
|
a bunch of outliers, but not a chunk of individual events that depend
|
||||||
on each other for state calculations.
|
on each other for state calculations.
|
||||||
|
|
||||||
|
Notifies about the events where appropriate.
|
||||||
"""
|
"""
|
||||||
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||||
[
|
[
|
||||||
|
@ -1517,7 +1461,7 @@ class FederationHandler(BaseHandler):
|
||||||
], consumeErrors=True,
|
], consumeErrors=True,
|
||||||
))
|
))
|
||||||
|
|
||||||
yield self.store.persist_events(
|
yield self._persist_events(
|
||||||
[
|
[
|
||||||
(ev_info["event"], context)
|
(ev_info["event"], context)
|
||||||
for ev_info, context in zip(event_infos, contexts)
|
for ev_info, context in zip(event_infos, contexts)
|
||||||
|
@ -1529,7 +1473,8 @@ class FederationHandler(BaseHandler):
|
||||||
def _persist_auth_tree(self, origin, auth_events, state, event):
|
def _persist_auth_tree(self, origin, auth_events, state, event):
|
||||||
"""Checks the auth chain is valid (and passes auth checks) for the
|
"""Checks the auth chain is valid (and passes auth checks) for the
|
||||||
state and event. Then persists the auth chain and state atomically.
|
state and event. Then persists the auth chain and state atomically.
|
||||||
Persists the event seperately.
|
Persists the event separately. Notifies about the persisted events
|
||||||
|
where appropriate.
|
||||||
|
|
||||||
Will attempt to fetch missing auth events.
|
Will attempt to fetch missing auth events.
|
||||||
|
|
||||||
|
@ -1540,8 +1485,7 @@ class FederationHandler(BaseHandler):
|
||||||
event (Event)
|
event (Event)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
2-tuple of (event_stream_id, max_stream_id) from the persist_event
|
Deferred
|
||||||
call for `event`
|
|
||||||
"""
|
"""
|
||||||
events_to_context = {}
|
events_to_context = {}
|
||||||
for e in itertools.chain(auth_events, state):
|
for e in itertools.chain(auth_events, state):
|
||||||
|
@ -1567,7 +1511,7 @@ class FederationHandler(BaseHandler):
|
||||||
missing_auth_events.add(e_id)
|
missing_auth_events.add(e_id)
|
||||||
|
|
||||||
for e_id in missing_auth_events:
|
for e_id in missing_auth_events:
|
||||||
m_ev = yield self.replication_layer.get_pdu(
|
m_ev = yield self.federation_client.get_pdu(
|
||||||
[origin],
|
[origin],
|
||||||
e_id,
|
e_id,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
|
@ -1605,7 +1549,7 @@ class FederationHandler(BaseHandler):
|
||||||
raise
|
raise
|
||||||
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
|
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
|
||||||
|
|
||||||
yield self.store.persist_events(
|
yield self._persist_events(
|
||||||
[
|
[
|
||||||
(e, events_to_context[e.event_id])
|
(e, events_to_context[e.event_id])
|
||||||
for e in itertools.chain(auth_events, state)
|
for e in itertools.chain(auth_events, state)
|
||||||
|
@ -1616,12 +1560,10 @@ class FederationHandler(BaseHandler):
|
||||||
event, old_state=state
|
event, old_state=state
|
||||||
)
|
)
|
||||||
|
|
||||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
yield self._persist_events(
|
||||||
event, new_event_context,
|
[(event, new_event_context)],
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((event_stream_id, max_stream_id))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _prep_event(self, origin, event, state=None, auth_events=None):
|
def _prep_event(self, origin, event, state=None, auth_events=None):
|
||||||
"""
|
"""
|
||||||
|
@ -1678,8 +1620,19 @@ class FederationHandler(BaseHandler):
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects,
|
||||||
missing):
|
missing):
|
||||||
|
in_room = yield self.auth.check_host_in_room(
|
||||||
|
room_id,
|
||||||
|
origin
|
||||||
|
)
|
||||||
|
if not in_room:
|
||||||
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
|
event = yield self.store.get_event(
|
||||||
|
event_id, allow_none=False, check_room_id=room_id
|
||||||
|
)
|
||||||
|
|
||||||
# Just go through and process each event in `remote_auth_chain`. We
|
# Just go through and process each event in `remote_auth_chain`. We
|
||||||
# don't want to fall into the trap of `missing` being wrong.
|
# don't want to fall into the trap of `missing` being wrong.
|
||||||
for e in remote_auth_chain:
|
for e in remote_auth_chain:
|
||||||
|
@ -1689,7 +1642,6 @@ class FederationHandler(BaseHandler):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Now get the current auth_chain for the event.
|
# Now get the current auth_chain for the event.
|
||||||
event = yield self.store.get_event(event_id)
|
|
||||||
local_auth_chain = yield self.store.get_auth_chain(
|
local_auth_chain = yield self.store.get_auth_chain(
|
||||||
[auth_id for auth_id, _ in event.auth_events],
|
[auth_id for auth_id, _ in event.auth_events],
|
||||||
include_given=True
|
include_given=True
|
||||||
|
@ -1777,7 +1729,7 @@ class FederationHandler(BaseHandler):
|
||||||
logger.info("Missing auth: %s", missing_auth)
|
logger.info("Missing auth: %s", missing_auth)
|
||||||
# If we don't have all the auth events, we need to get them.
|
# If we don't have all the auth events, we need to get them.
|
||||||
try:
|
try:
|
||||||
remote_auth_chain = yield self.replication_layer.get_event_auth(
|
remote_auth_chain = yield self.federation_client.get_event_auth(
|
||||||
origin, event.room_id, event.event_id
|
origin, event.room_id, event.event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1893,7 +1845,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 2. Get remote difference.
|
# 2. Get remote difference.
|
||||||
result = yield self.replication_layer.query_auth(
|
result = yield self.federation_client.query_auth(
|
||||||
origin,
|
origin,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
event.event_id,
|
event.event_id,
|
||||||
|
@ -2192,7 +2144,7 @@ class FederationHandler(BaseHandler):
|
||||||
yield member_handler.send_membership_event(None, event, context)
|
yield member_handler.send_membership_event(None, event, context)
|
||||||
else:
|
else:
|
||||||
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
|
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
|
||||||
yield self.replication_layer.forward_third_party_invite(
|
yield self.federation_client.forward_third_party_invite(
|
||||||
destinations,
|
destinations,
|
||||||
room_id,
|
room_id,
|
||||||
event_dict,
|
event_dict,
|
||||||
|
@ -2347,3 +2299,69 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
if "valid" not in response or not response["valid"]:
|
if "valid" not in response or not response["valid"]:
|
||||||
raise AuthError(403, "Third party certificate was invalid")
|
raise AuthError(403, "Third party certificate was invalid")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _persist_events(self, event_and_contexts, backfilled=False):
|
||||||
|
"""Persists events and tells the notifier/pushers about them, if
|
||||||
|
necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_and_contexts(list[tuple[FrozenEvent, EventContext]])
|
||||||
|
backfilled (bool): Whether these events are a result of
|
||||||
|
backfilling or not
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
max_stream_id = yield self.store.persist_events(
|
||||||
|
event_and_contexts,
|
||||||
|
backfilled=backfilled,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not backfilled: # Never notify for backfilled events
|
||||||
|
for event, _ in event_and_contexts:
|
||||||
|
self._notify_persisted_event(event, max_stream_id)
|
||||||
|
|
||||||
|
def _notify_persisted_event(self, event, max_stream_id):
|
||||||
|
"""Checks to see if notifier/pushers should be notified about the
|
||||||
|
event or not.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (FrozenEvent)
|
||||||
|
max_stream_id (int): The max_stream_id returned by persist_events
|
||||||
|
"""
|
||||||
|
|
||||||
|
extra_users = []
|
||||||
|
if event.type == EventTypes.Member:
|
||||||
|
target_user_id = event.state_key
|
||||||
|
|
||||||
|
# We notify for memberships if its an invite for one of our
|
||||||
|
# users
|
||||||
|
if event.internal_metadata.is_outlier():
|
||||||
|
if event.membership != Membership.INVITE:
|
||||||
|
if not self.is_mine_id(target_user_id):
|
||||||
|
return
|
||||||
|
|
||||||
|
target_user = UserID.from_string(target_user_id)
|
||||||
|
extra_users.append(target_user)
|
||||||
|
elif event.internal_metadata.is_outlier():
|
||||||
|
return
|
||||||
|
|
||||||
|
event_stream_id = event.internal_metadata.stream_ordering
|
||||||
|
self.notifier.on_new_room_event(
|
||||||
|
event, event_stream_id, max_stream_id,
|
||||||
|
extra_users=extra_users
|
||||||
|
)
|
||||||
|
|
||||||
|
logcontext.run_in_background(
|
||||||
|
self.pusher_pool.on_new_notifications,
|
||||||
|
event_stream_id, max_stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clean_room_for_join(self, room_id):
|
||||||
|
return self.store.clean_room_for_join(room_id)
|
||||||
|
|
||||||
|
def user_joined_room(self, user, room_id):
|
||||||
|
"""Called when a new user has joined the room
|
||||||
|
"""
|
||||||
|
return user_joined_room(self.distributor, user, room_id)
|
||||||
|
|
|
@ -26,7 +26,7 @@ from twisted.internet import defer
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException,
|
CodeMessageException,
|
||||||
Codes,
|
Codes,
|
||||||
MatrixCodeMessageException,
|
HttpResponseException,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -85,7 +85,6 @@ class IdentityHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
data = {}
|
|
||||||
try:
|
try:
|
||||||
data = yield self.http_client.get_json(
|
data = yield self.http_client.get_json(
|
||||||
"https://%s%s" % (
|
"https://%s%s" % (
|
||||||
|
@ -94,11 +93,9 @@ class IdentityHandler(BaseHandler):
|
||||||
),
|
),
|
||||||
{'sid': creds['sid'], 'client_secret': client_secret}
|
{'sid': creds['sid'], 'client_secret': client_secret}
|
||||||
)
|
)
|
||||||
except MatrixCodeMessageException as e:
|
except HttpResponseException as e:
|
||||||
logger.info("getValidated3pid failed with Matrix error: %r", e)
|
logger.info("getValidated3pid failed with Matrix error: %r", e)
|
||||||
raise SynapseError(e.code, e.msg, e.errcode)
|
raise e.to_synapse_error()
|
||||||
except CodeMessageException as e:
|
|
||||||
data = json.loads(e.msg)
|
|
||||||
|
|
||||||
if 'medium' in data:
|
if 'medium' in data:
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
@ -136,7 +133,7 @@ class IdentityHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
logger.debug("bound threepid %r to %s", creds, mxid)
|
logger.debug("bound threepid %r to %s", creds, mxid)
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
data = json.loads(e.msg)
|
data = json.loads(e.msg) # XXX WAT?
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -209,12 +206,9 @@ class IdentityHandler(BaseHandler):
|
||||||
params
|
params
|
||||||
)
|
)
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
except MatrixCodeMessageException as e:
|
except HttpResponseException as e:
|
||||||
logger.info("Proxied requestToken failed with Matrix error: %r", e)
|
|
||||||
raise SynapseError(e.code, e.msg, e.errcode)
|
|
||||||
except CodeMessageException as e:
|
|
||||||
logger.info("Proxied requestToken failed: %r", e)
|
logger.info("Proxied requestToken failed: %r", e)
|
||||||
raise e
|
raise e.to_synapse_error()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def requestMsisdnToken(
|
def requestMsisdnToken(
|
||||||
|
@ -244,9 +238,6 @@ class IdentityHandler(BaseHandler):
|
||||||
params
|
params
|
||||||
)
|
)
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
except MatrixCodeMessageException as e:
|
except HttpResponseException as e:
|
||||||
logger.info("Proxied requestToken failed with Matrix error: %r", e)
|
|
||||||
raise SynapseError(e.code, e.msg, e.errcode)
|
|
||||||
except CodeMessageException as e:
|
|
||||||
logger.info("Proxied requestToken failed: %r", e)
|
logger.info("Proxied requestToken failed: %r", e)
|
||||||
raise e
|
raise e.to_synapse_error()
|
||||||
|
|
|
@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
hs (synapse.server.HomeServer):
|
hs (synapse.server.HomeServer):
|
||||||
"""
|
"""
|
||||||
super(RegistrationHandler, self).__init__(hs)
|
super(RegistrationHandler, self).__init__(hs)
|
||||||
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self.profile_handler = hs.get_profile_handler()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
Args:
|
Args:
|
||||||
localpart : The local part of the user ID to register. If None,
|
localpart : The local part of the user ID to register. If None,
|
||||||
one will be generated.
|
one will be generated.
|
||||||
password (str) : The password to assign to this user so they can
|
password (unicode) : The password to assign to this user so they can
|
||||||
login again. This can be None which means they cannot login again
|
login again. This can be None which means they cannot login again
|
||||||
via a password (e.g. the user is an application service user).
|
via a password (e.g. the user is an application service user).
|
||||||
generate_token (bool): Whether a new access token should be
|
generate_token (bool): Whether a new access token should be
|
||||||
|
@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
Raises:
|
Raises:
|
||||||
RegistrationError if there was a problem registering.
|
RegistrationError if there was a problem registering.
|
||||||
"""
|
"""
|
||||||
|
yield self._check_mau_limits()
|
||||||
password_hash = None
|
password_hash = None
|
||||||
if password:
|
if password:
|
||||||
password_hash = yield self.auth_handler().hash(password)
|
password_hash = yield self.auth_handler().hash(password)
|
||||||
|
@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
400,
|
400,
|
||||||
"User ID can only contain characters a-z, 0-9, or '=_-./'",
|
"User ID can only contain characters a-z, 0-9, or '=_-./'",
|
||||||
)
|
)
|
||||||
|
yield self._check_mau_limits()
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
if localpart is None:
|
if localpart is None:
|
||||||
raise SynapseError(400, "Request must include user id")
|
raise SynapseError(400, "Request must include user id")
|
||||||
|
yield self._check_mau_limits()
|
||||||
need_register = True
|
need_register = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -531,3 +533,16 @@ class RegistrationHandler(BaseHandler):
|
||||||
remote_room_hosts=remote_room_hosts,
|
remote_room_hosts=remote_room_hosts,
|
||||||
action="join",
|
action="join",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _check_mau_limits(self):
|
||||||
|
"""
|
||||||
|
Do not accept registrations if monthly active user limits exceeded
|
||||||
|
and limiting is enabled
|
||||||
|
"""
|
||||||
|
if self.hs.config.limit_usage_by_mau is True:
|
||||||
|
current_mau = yield self.store.count_monthly_users()
|
||||||
|
if current_mau >= self.hs.config.max_mau_value:
|
||||||
|
raise RegistrationError(
|
||||||
|
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
|
||||||
|
)
|
||||||
|
|
|
@ -39,12 +39,7 @@ from twisted.web.client import (
|
||||||
from twisted.web.http import PotentialDataLoss
|
from twisted.web.http import PotentialDataLoss
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
CodeMessageException,
|
|
||||||
Codes,
|
|
||||||
MatrixCodeMessageException,
|
|
||||||
SynapseError,
|
|
||||||
)
|
|
||||||
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
|
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
|
||||||
from synapse.http.endpoint import SpiderEndpoint
|
from synapse.http.endpoint import SpiderEndpoint
|
||||||
from synapse.util.async import add_timeout_to_deferred
|
from synapse.util.async import add_timeout_to_deferred
|
||||||
|
@ -132,6 +127,11 @@ class SimpleHttpClient(object):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[object]: parsed json
|
Deferred[object]: parsed json
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HttpResponseException: On a non-2xx HTTP response.
|
||||||
|
|
||||||
|
ValueError: if the response was not JSON
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: Do we ever want to log message contents?
|
# TODO: Do we ever want to log message contents?
|
||||||
|
@ -155,7 +155,10 @@ class SimpleHttpClient(object):
|
||||||
|
|
||||||
body = yield make_deferred_yieldable(readBody(response))
|
body = yield make_deferred_yieldable(readBody(response))
|
||||||
|
|
||||||
|
if 200 <= response.code < 300:
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
else:
|
||||||
|
raise HttpResponseException(response.code, response.phrase, body)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def post_json_get_json(self, uri, post_json, headers=None):
|
def post_json_get_json(self, uri, post_json, headers=None):
|
||||||
|
@ -169,6 +172,11 @@ class SimpleHttpClient(object):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[object]: parsed json
|
Deferred[object]: parsed json
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HttpResponseException: On a non-2xx HTTP response.
|
||||||
|
|
||||||
|
ValueError: if the response was not JSON
|
||||||
"""
|
"""
|
||||||
json_str = encode_canonical_json(post_json)
|
json_str = encode_canonical_json(post_json)
|
||||||
|
|
||||||
|
@ -193,9 +201,7 @@ class SimpleHttpClient(object):
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
else:
|
else:
|
||||||
raise self._exceptionFromFailedRequest(response, body)
|
raise HttpResponseException(response.code, response.phrase, body)
|
||||||
|
|
||||||
defer.returnValue(json.loads(body))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_json(self, uri, args={}, headers=None):
|
def get_json(self, uri, args={}, headers=None):
|
||||||
|
@ -213,14 +219,12 @@ class SimpleHttpClient(object):
|
||||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||||
HTTP body as JSON.
|
HTTP body as JSON.
|
||||||
Raises:
|
Raises:
|
||||||
On a non-2xx HTTP response. The response body will be used as the
|
HttpResponseException On a non-2xx HTTP response.
|
||||||
error message.
|
|
||||||
|
ValueError: if the response was not JSON
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
body = yield self.get_raw(uri, args, headers=headers)
|
body = yield self.get_raw(uri, args, headers=headers)
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
except CodeMessageException as e:
|
|
||||||
raise self._exceptionFromFailedRequest(e.code, e.msg)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def put_json(self, uri, json_body, args={}, headers=None):
|
def put_json(self, uri, json_body, args={}, headers=None):
|
||||||
|
@ -239,7 +243,9 @@ class SimpleHttpClient(object):
|
||||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||||
HTTP body as JSON.
|
HTTP body as JSON.
|
||||||
Raises:
|
Raises:
|
||||||
On a non-2xx HTTP response.
|
HttpResponseException On a non-2xx HTTP response.
|
||||||
|
|
||||||
|
ValueError: if the response was not JSON
|
||||||
"""
|
"""
|
||||||
if len(args):
|
if len(args):
|
||||||
query_bytes = urllib.urlencode(args, True)
|
query_bytes = urllib.urlencode(args, True)
|
||||||
|
@ -266,10 +272,7 @@ class SimpleHttpClient(object):
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
else:
|
else:
|
||||||
# NB: This is explicitly not json.loads(body)'d because the contract
|
raise HttpResponseException(response.code, response.phrase, body)
|
||||||
# of CodeMessageException is a *string* message. Callers can always
|
|
||||||
# load it into JSON if they want.
|
|
||||||
raise CodeMessageException(response.code, body)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_raw(self, uri, args={}, headers=None):
|
def get_raw(self, uri, args={}, headers=None):
|
||||||
|
@ -287,8 +290,7 @@ class SimpleHttpClient(object):
|
||||||
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
|
||||||
HTTP body at text.
|
HTTP body at text.
|
||||||
Raises:
|
Raises:
|
||||||
On a non-2xx HTTP response. The response body will be used as the
|
HttpResponseException on a non-2xx HTTP response.
|
||||||
error message.
|
|
||||||
"""
|
"""
|
||||||
if len(args):
|
if len(args):
|
||||||
query_bytes = urllib.urlencode(args, True)
|
query_bytes = urllib.urlencode(args, True)
|
||||||
|
@ -311,16 +313,7 @@ class SimpleHttpClient(object):
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
defer.returnValue(body)
|
defer.returnValue(body)
|
||||||
else:
|
else:
|
||||||
raise CodeMessageException(response.code, body)
|
raise HttpResponseException(response.code, response.phrase, body)
|
||||||
|
|
||||||
def _exceptionFromFailedRequest(self, response, body):
|
|
||||||
try:
|
|
||||||
jsonBody = json.loads(body)
|
|
||||||
errcode = jsonBody['errcode']
|
|
||||||
error = jsonBody['error']
|
|
||||||
return MatrixCodeMessageException(response.code, error, errcode)
|
|
||||||
except (ValueError, KeyError):
|
|
||||||
return CodeMessageException(response.code, body)
|
|
||||||
|
|
||||||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
||||||
# The two should be factored out.
|
# The two should be factored out.
|
||||||
|
|
|
@ -13,12 +13,13 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import cgi
|
import cgi
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
|
||||||
|
|
||||||
from six.moves import http_client
|
from six import PY3
|
||||||
|
from six.moves import http_client, urllib
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
||||||
|
|
||||||
|
@ -35,7 +36,6 @@ from synapse.api.errors import (
|
||||||
Codes,
|
Codes,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
UnrecognizedRequestError,
|
UnrecognizedRequestError,
|
||||||
cs_exception,
|
|
||||||
)
|
)
|
||||||
from synapse.http.request_metrics import requests_counter
|
from synapse.http.request_metrics import requests_counter
|
||||||
from synapse.util.caches import intern_dict
|
from synapse.util.caches import intern_dict
|
||||||
|
@ -76,16 +76,13 @@ def wrap_json_request_handler(h):
|
||||||
def wrapped_request_handler(self, request):
|
def wrapped_request_handler(self, request):
|
||||||
try:
|
try:
|
||||||
yield h(self, request)
|
yield h(self, request)
|
||||||
except CodeMessageException as e:
|
except SynapseError as e:
|
||||||
code = e.code
|
code = e.code
|
||||||
if isinstance(e, SynapseError):
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"%s SynapseError: %s - %s", request, code, e.msg
|
"%s SynapseError: %s - %s", request, code, e.msg
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
logger.exception(e)
|
|
||||||
respond_with_json(
|
respond_with_json(
|
||||||
request, code, cs_exception(e), send_cors=True,
|
request, code, e.error_dict(), send_cors=True,
|
||||||
pretty_print=_request_user_agent_is_curl(request),
|
pretty_print=_request_user_agent_is_curl(request),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -264,6 +261,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def register_paths(self, method, path_patterns, callback):
|
def register_paths(self, method, path_patterns, callback):
|
||||||
|
method = method.encode("utf-8") # method is bytes on py3
|
||||||
for path_pattern in path_patterns:
|
for path_pattern in path_patterns:
|
||||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||||
self.path_regexs.setdefault(method, []).append(
|
self.path_regexs.setdefault(method, []).append(
|
||||||
|
@ -296,8 +294,19 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
# here. If it throws an exception, that is handled by the wrapper
|
# here. If it throws an exception, that is handled by the wrapper
|
||||||
# installed by @request_handler.
|
# installed by @request_handler.
|
||||||
|
|
||||||
|
def _unquote(s):
|
||||||
|
if PY3:
|
||||||
|
# On Python 3, unquote is unicode -> unicode
|
||||||
|
return urllib.parse.unquote(s)
|
||||||
|
else:
|
||||||
|
# On Python 2, unquote is bytes -> bytes We need to encode the
|
||||||
|
# URL again (as it was decoded by _get_handler_for request), as
|
||||||
|
# ASCII because it's a URL, and then decode it to get the UTF-8
|
||||||
|
# characters that were quoted.
|
||||||
|
return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
|
||||||
|
|
||||||
kwargs = intern_dict({
|
kwargs = intern_dict({
|
||||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
name: _unquote(value) if value else value
|
||||||
for name, value in group_dict.items()
|
for name, value in group_dict.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -313,9 +322,9 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
request (twisted.web.http.Request):
|
request (twisted.web.http.Request):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Callable, dict[str, str]]: callback method, and the dict
|
Tuple[Callable, dict[unicode, unicode]]: callback method, and the
|
||||||
mapping keys to path components as specified in the handler's
|
dict mapping keys to path components as specified in the
|
||||||
path match regexp.
|
handler's path match regexp.
|
||||||
|
|
||||||
The callback will normally be a method registered via
|
The callback will normally be a method registered via
|
||||||
register_paths, so will return (possibly via Deferred) either
|
register_paths, so will return (possibly via Deferred) either
|
||||||
|
@ -327,7 +336,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
# Loop through all the registered callbacks to check if the method
|
# Loop through all the registered callbacks to check if the method
|
||||||
# and path regex match
|
# and path regex match
|
||||||
for path_entry in self.path_regexs.get(request.method, []):
|
for path_entry in self.path_regexs.get(request.method, []):
|
||||||
m = path_entry.pattern.match(request.path)
|
m = path_entry.pattern.match(request.path.decode('ascii'))
|
||||||
if m:
|
if m:
|
||||||
# We found a match!
|
# We found a match!
|
||||||
return path_entry.callback, m.groupdict()
|
return path_entry.callback, m.groupdict()
|
||||||
|
@ -383,7 +392,7 @@ class RootRedirect(resource.Resource):
|
||||||
self.url = path
|
self.url = path
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
return redirectTo(self.url, request)
|
return redirectTo(self.url.encode('ascii'), request)
|
||||||
|
|
||||||
def getChild(self, name, request):
|
def getChild(self, name, request):
|
||||||
if len(name) == 0:
|
if len(name) == 0:
|
||||||
|
@ -404,12 +413,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
|
||||||
return
|
return
|
||||||
|
|
||||||
if pretty_print:
|
if pretty_print:
|
||||||
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
json_bytes = (encode_pretty_printed_json(json_object) + "\n"
|
||||||
|
).encode("utf-8")
|
||||||
else:
|
else:
|
||||||
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
||||||
|
# canonicaljson already encodes to bytes
|
||||||
json_bytes = encode_canonical_json(json_object)
|
json_bytes = encode_canonical_json(json_object)
|
||||||
else:
|
else:
|
||||||
json_bytes = json.dumps(json_object)
|
json_bytes = json.dumps(json_object).encode("utf-8")
|
||||||
|
|
||||||
return respond_with_json_bytes(
|
return respond_with_json_bytes(
|
||||||
request, code, json_bytes,
|
request, code, json_bytes,
|
||||||
|
|
|
@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
|
||||||
if not content_bytes and allow_empty_body:
|
if not content_bytes and allow_empty_body:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Decode to Unicode so that simplejson will return Unicode strings on
|
||||||
|
# Python 2
|
||||||
try:
|
try:
|
||||||
content = json.loads(content_bytes)
|
content_unicode = content_bytes.decode('utf8')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
logger.warn("Unable to decode UTF-8")
|
||||||
|
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = json.loads(content_unicode)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn("Unable to parse JSON: %s", e)
|
logger.warn("Unable to parse JSON: %s", e)
|
||||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||||
|
|
|
@ -23,8 +23,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException,
|
CodeMessageException,
|
||||||
MatrixCodeMessageException,
|
HttpResponseException,
|
||||||
SynapseError,
|
|
||||||
)
|
)
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
@ -160,11 +159,11 @@ class ReplicationEndpoint(object):
|
||||||
# If we timed out we probably don't need to worry about backing
|
# If we timed out we probably don't need to worry about backing
|
||||||
# off too much, but lets just wait a little anyway.
|
# off too much, but lets just wait a little anyway.
|
||||||
yield clock.sleep(1)
|
yield clock.sleep(1)
|
||||||
except MatrixCodeMessageException as e:
|
except HttpResponseException as e:
|
||||||
# We convert to SynapseError as we know that it was a SynapseError
|
# We convert to SynapseError as we know that it was a SynapseError
|
||||||
# on the master process that we should send to the client. (And
|
# on the master process that we should send to the client. (And
|
||||||
# importantly, not stack traces everywhere)
|
# importantly, not stack traces everywhere)
|
||||||
raise SynapseError(e.code, e.msg, e.errcode)
|
raise e.to_synapse_error()
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from six import text_type
|
||||||
from six.moves import http_client
|
from six.moves import http_client
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet):
|
||||||
400, "username must be specified", errcode=Codes.BAD_JSON,
|
400, "username must be specified", errcode=Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if (not isinstance(body['username'], str) or len(body['username']) > 512):
|
if (
|
||||||
|
not isinstance(body['username'], text_type)
|
||||||
|
or len(body['username']) > 512
|
||||||
|
):
|
||||||
raise SynapseError(400, "Invalid username")
|
raise SynapseError(400, "Invalid username")
|
||||||
|
|
||||||
username = body["username"].encode("utf-8")
|
username = body["username"].encode("utf-8")
|
||||||
|
@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet):
|
||||||
400, "password must be specified", errcode=Codes.BAD_JSON,
|
400, "password must be specified", errcode=Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if (not isinstance(body['password'], str) or len(body['password']) > 512):
|
if (
|
||||||
|
not isinstance(body['password'], text_type)
|
||||||
|
or len(body['password']) > 512
|
||||||
|
):
|
||||||
raise SynapseError(400, "Invalid password")
|
raise SynapseError(400, "Invalid password")
|
||||||
|
|
||||||
password = body["password"].encode("utf-8")
|
password = body["password"].encode("utf-8")
|
||||||
|
@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet):
|
||||||
want_mac.update(b"admin" if admin else b"notadmin")
|
want_mac.update(b"admin" if admin else b"notadmin")
|
||||||
want_mac = want_mac.hexdigest()
|
want_mac = want_mac.hexdigest()
|
||||||
|
|
||||||
if not hmac.compare_digest(want_mac, got_mac):
|
if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
|
||||||
raise SynapseError(
|
raise SynapseError(403, "HMAC incorrect")
|
||||||
403, "HMAC incorrect",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
||||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||||
|
|
||||||
register = RegisterRestServlet(self.hs)
|
register = RegisterRestServlet(self.hs)
|
||||||
|
|
||||||
(user_id, _) = yield register.registration_handler.register(
|
(user_id, _) = yield register.registration_handler.register(
|
||||||
localpart=username.lower(), password=password, admin=bool(admin),
|
localpart=body['username'].lower(),
|
||||||
|
password=body["password"],
|
||||||
|
admin=bool(admin),
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.types import RoomAlias
|
from synapse.types import RoomAlias
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
room = yield self.store.get_room(room_id)
|
room = yield self.store.get_room(room_id)
|
||||||
if room is None:
|
if room is None:
|
||||||
raise SynapseError(400, "Unknown room")
|
raise NotFoundError("Unknown room")
|
||||||
|
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"visibility": "public" if room["is_public"] else "private"
|
"visibility": "public" if room["is_public"] else "private"
|
||||||
|
|
|
@ -88,7 +88,7 @@ class EventRestServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, event_id):
|
def on_GET(self, request, event_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
event = yield self.event_handler.get_event(requester.user, event_id)
|
event = yield self.event_handler.get_event(requester.user, None, event_id)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
if event:
|
if event:
|
||||||
|
|
|
@ -506,7 +506,7 @@ class RoomEventServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id, event_id):
|
def on_GET(self, request, room_id, event_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
event = yield self.event_handler.get_event(requester.user, event_id)
|
event = yield self.event_handler.get_event(requester.user, room_id, event_id)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
if event:
|
if event:
|
||||||
|
|
|
@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
kind = "user"
|
kind = b"user"
|
||||||
if "kind" in request.args:
|
if b"kind" in request.args:
|
||||||
kind = request.args["kind"][0]
|
kind = request.args[b"kind"][0]
|
||||||
|
|
||||||
if kind == "guest":
|
if kind == b"guest":
|
||||||
ret = yield self._do_guest_registration(body)
|
ret = yield self._do_guest_registration(body)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
return
|
return
|
||||||
elif kind != "user":
|
elif kind != b"user":
|
||||||
raise UnrecognizedRequestError(
|
raise UnrecognizedRequestError(
|
||||||
"Do not understand membership kind: %s" % (kind,)
|
"Do not understand membership kind: %s" % (kind,)
|
||||||
)
|
)
|
||||||
|
@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet):
|
||||||
assert_params_in_dict(params, ["password"])
|
assert_params_in_dict(params, ["password"])
|
||||||
|
|
||||||
desired_username = params.get("username", None)
|
desired_username = params.get("username", None)
|
||||||
new_password = params.get("password", None)
|
|
||||||
guest_access_token = params.get("guest_access_token", None)
|
guest_access_token = params.get("guest_access_token", None)
|
||||||
|
new_password = params.get("password", None)
|
||||||
|
|
||||||
if desired_username is not None:
|
if desired_username is not None:
|
||||||
desired_username = desired_username.lower()
|
desired_username = desired_username.lower()
|
||||||
|
|
|
@ -379,7 +379,7 @@ class MediaRepository(object):
|
||||||
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||||
server_name, media_id, e.response)
|
server_name, media_id, e.response)
|
||||||
if e.code == twisted.web.http.NOT_FOUND:
|
if e.code == twisted.web.http.NOT_FOUND:
|
||||||
raise SynapseError.from_http_response_exception(e)
|
raise e.to_synapse_error()
|
||||||
raise SynapseError(502, "Failed to fetch remote media")
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
|
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
|
|
|
@ -177,7 +177,7 @@ class MediaStorage(object):
|
||||||
if res:
|
if res:
|
||||||
with res:
|
with res:
|
||||||
consumer = BackgroundFileConsumer(
|
consumer = BackgroundFileConsumer(
|
||||||
open(local_path, "w"), self.hs.get_reactor())
|
open(local_path, "wb"), self.hs.get_reactor())
|
||||||
yield res.write_to_consumer(consumer)
|
yield res.write_to_consumer(consumer)
|
||||||
yield consumer.wait()
|
yield consumer.wait()
|
||||||
defer.returnValue(local_path)
|
defer.returnValue(local_path)
|
||||||
|
|
|
@ -577,7 +577,7 @@ def _make_state_cache_entry(
|
||||||
|
|
||||||
def _ordered_events(events):
|
def _ordered_events(events):
|
||||||
def key_func(e):
|
def key_func(e):
|
||||||
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
|
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
|
||||||
|
|
||||||
return sorted(events, key=key_func)
|
return sorted(events, key=key_func)
|
||||||
|
|
||||||
|
|
|
@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
PresenceStore, TransactionStore,
|
PresenceStore, TransactionStore,
|
||||||
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
||||||
ApplicationServiceStore,
|
ApplicationServiceStore,
|
||||||
|
EventsStore,
|
||||||
EventFederationStore,
|
EventFederationStore,
|
||||||
MediaRepositoryStore,
|
MediaRepositoryStore,
|
||||||
RejectionsStore,
|
RejectionsStore,
|
||||||
|
@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
PusherStore,
|
PusherStore,
|
||||||
PushRuleStore,
|
PushRuleStore,
|
||||||
ApplicationServiceTransactionStore,
|
ApplicationServiceTransactionStore,
|
||||||
EventsStore,
|
|
||||||
ReceiptsStore,
|
ReceiptsStore,
|
||||||
EndToEndKeyStore,
|
EndToEndKeyStore,
|
||||||
SearchStore,
|
SearchStore,
|
||||||
|
@ -94,6 +94,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.database_engine = hs.database_engine
|
self.database_engine = hs.database_engine
|
||||||
|
|
||||||
|
self.db_conn = db_conn
|
||||||
self._stream_id_gen = StreamIdGenerator(
|
self._stream_id_gen = StreamIdGenerator(
|
||||||
db_conn, "events", "stream_ordering",
|
db_conn, "events", "stream_ordering",
|
||||||
extra_tables=[("local_invites", "stream_id")]
|
extra_tables=[("local_invites", "stream_id")]
|
||||||
|
@ -266,6 +267,31 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
|
|
||||||
return self.runInteraction("count_users", _count_users)
|
return self.runInteraction("count_users", _count_users)
|
||||||
|
|
||||||
|
def count_monthly_users(self):
|
||||||
|
"""Counts the number of users who used this homeserver in the last 30 days
|
||||||
|
|
||||||
|
This method should be refactored with count_daily_users - the only
|
||||||
|
reason not to is waiting on definition of mau
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Defered[int]
|
||||||
|
"""
|
||||||
|
def _count_monthly_users(txn):
|
||||||
|
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||||
|
sql = """
|
||||||
|
SELECT COALESCE(count(*), 0) FROM (
|
||||||
|
SELECT user_id FROM user_ips
|
||||||
|
WHERE last_seen > ?
|
||||||
|
GROUP BY user_id
|
||||||
|
) u
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (thirty_days_ago,))
|
||||||
|
count, = txn.fetchone()
|
||||||
|
return count
|
||||||
|
|
||||||
|
return self.runInteraction("count_monthly_users", _count_monthly_users)
|
||||||
|
|
||||||
def count_r30_users(self):
|
def count_r30_users(self):
|
||||||
"""
|
"""
|
||||||
Counts the number of 30 day retained users, defined as:-
|
Counts the number of 30 day retained users, defined as:-
|
||||||
|
|
|
@ -22,7 +22,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.appservice import AppServiceTransaction
|
from synapse.appservice import AppServiceTransaction
|
||||||
from synapse.config.appservice import load_appservices
|
from synapse.config.appservice import load_appservices
|
||||||
from synapse.storage.events import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ from twisted.internet import defer
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.events import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.signatures import SignatureWorkerStore
|
from synapse.storage.signatures import SignatureWorkerStore
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
@ -343,6 +343,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||||
table="events",
|
table="events",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"event_id": event_id,
|
"event_id": event_id,
|
||||||
|
"room_id": room_id,
|
||||||
},
|
},
|
||||||
retcol="depth",
|
retcol="depth",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
|
|
@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError
|
||||||
from synapse.events import EventBase # noqa: F401
|
from synapse.events import EventBase # noqa: F401
|
||||||
from synapse.events.snapshot import EventContext # noqa: F401
|
from synapse.events.snapshot import EventContext # noqa: F401
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||||
|
from synapse.storage.event_federation import EventFederationStore
|
||||||
from synapse.storage.events_worker import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.types import RoomStreamToken, get_domain_from_id
|
from synapse.types import RoomStreamToken, get_domain_from_id
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
|
@ -65,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
|
||||||
|
|
||||||
|
|
||||||
def encode_json(json_object):
|
def encode_json(json_object):
|
||||||
return frozendict_json_encoder.encode(json_object)
|
"""
|
||||||
|
Encode a Python object as JSON and return it in a Unicode string.
|
||||||
|
"""
|
||||||
|
out = frozendict_json_encoder.encode(json_object)
|
||||||
|
if isinstance(out, bytes):
|
||||||
|
out = out.decode('utf8')
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class _EventPeristenceQueue(object):
|
class _EventPeristenceQueue(object):
|
||||||
|
@ -193,7 +201,9 @@ def _retry_on_integrity_error(func):
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
class EventsStore(EventsWorkerStore):
|
# inherits from EventFederationStore so that we can call _update_backward_extremities
|
||||||
|
# and _handle_mult_prev_events (though arguably those could both be moved in here)
|
||||||
|
class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
|
||||||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||||
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
||||||
|
|
||||||
|
@ -231,12 +241,18 @@ class EventsStore(EventsWorkerStore):
|
||||||
|
|
||||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def persist_events(self, events_and_contexts, backfilled=False):
|
def persist_events(self, events_and_contexts, backfilled=False):
|
||||||
"""
|
"""
|
||||||
Write events to the database
|
Write events to the database
|
||||||
Args:
|
Args:
|
||||||
events_and_contexts: list of tuples of (event, context)
|
events_and_contexts: list of tuples of (event, context)
|
||||||
backfilled: ?
|
backfilled (bool): Whether the results are retrieved from federation
|
||||||
|
via backfill or not. Used to determine if they're "new" events
|
||||||
|
which might update the current state etc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[int]: the stream ordering of the latest persisted event
|
||||||
"""
|
"""
|
||||||
partitioned = {}
|
partitioned = {}
|
||||||
for event, ctx in events_and_contexts:
|
for event, ctx in events_and_contexts:
|
||||||
|
@ -253,10 +269,14 @@ class EventsStore(EventsWorkerStore):
|
||||||
for room_id in partitioned:
|
for room_id in partitioned:
|
||||||
self._maybe_start_persisting(room_id)
|
self._maybe_start_persisting(room_id)
|
||||||
|
|
||||||
return make_deferred_yieldable(
|
yield make_deferred_yieldable(
|
||||||
defer.gatherResults(deferreds, consumeErrors=True)
|
defer.gatherResults(deferreds, consumeErrors=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
max_persisted_id = yield self._stream_id_gen.get_current_token()
|
||||||
|
|
||||||
|
defer.returnValue(max_persisted_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def persist_event(self, event, context, backfilled=False):
|
def persist_event(self, event, context, backfilled=False):
|
||||||
|
@ -1054,7 +1074,7 @@ class EventsStore(EventsWorkerStore):
|
||||||
|
|
||||||
metadata_json = encode_json(
|
metadata_json = encode_json(
|
||||||
event.internal_metadata.get_dict()
|
event.internal_metadata.get_dict()
|
||||||
).decode("UTF-8")
|
)
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"UPDATE event_json SET internal_metadata = ?"
|
"UPDATE event_json SET internal_metadata = ?"
|
||||||
|
@ -1168,8 +1188,8 @@ class EventsStore(EventsWorkerStore):
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"internal_metadata": encode_json(
|
"internal_metadata": encode_json(
|
||||||
event.internal_metadata.get_dict()
|
event.internal_metadata.get_dict()
|
||||||
).decode("UTF-8"),
|
),
|
||||||
"json": encode_json(event_dict(event)).decode("UTF-8"),
|
"json": encode_json(event_dict(event)),
|
||||||
}
|
}
|
||||||
for event, _ in events_and_contexts
|
for event, _ in events_and_contexts
|
||||||
],
|
],
|
||||||
|
|
|
@ -19,7 +19,7 @@ from canonicaljson import json
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import NotFoundError
|
||||||
# these are only included to make the type annotations work
|
# these are only included to make the type annotations work
|
||||||
from synapse.events import EventBase # noqa: F401
|
from synapse.events import EventBase # noqa: F401
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
|
@ -77,7 +77,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_event(self, event_id, check_redacted=True,
|
def get_event(self, event_id, check_redacted=True,
|
||||||
get_prev_content=False, allow_rejected=False,
|
get_prev_content=False, allow_rejected=False,
|
||||||
allow_none=False):
|
allow_none=False, check_room_id=None):
|
||||||
"""Get an event from the database by event_id.
|
"""Get an event from the database by event_id.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -88,7 +88,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
include the previous states content in the unsigned field.
|
include the previous states content in the unsigned field.
|
||||||
allow_rejected (bool): If True return rejected events.
|
allow_rejected (bool): If True return rejected events.
|
||||||
allow_none (bool): If True, return None if no event found, if
|
allow_none (bool): If True, return None if no event found, if
|
||||||
False throw an exception.
|
False throw a NotFoundError
|
||||||
|
check_room_id (str|None): if not None, check the room of the found event.
|
||||||
|
If there is a mismatch, behave as per allow_none.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred : A FrozenEvent.
|
Deferred : A FrozenEvent.
|
||||||
|
@ -100,10 +102,16 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
allow_rejected=allow_rejected,
|
allow_rejected=allow_rejected,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not events and not allow_none:
|
event = events[0] if events else None
|
||||||
raise SynapseError(404, "Could not find event %s" % (event_id,))
|
|
||||||
|
|
||||||
defer.returnValue(events[0] if events else None)
|
if event is not None and check_room_id is not None:
|
||||||
|
if event.room_id != check_room_id:
|
||||||
|
event = None
|
||||||
|
|
||||||
|
if event is None and not allow_none:
|
||||||
|
raise NotFoundError("Could not find event %s" % (event_id,))
|
||||||
|
|
||||||
|
defer.returnValue(event)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_events(self, event_ids, check_redacted=True,
|
def get_events(self, event_ids, check_redacted=True,
|
||||||
|
|
|
@ -24,7 +24,7 @@ from canonicaljson import json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.storage.events import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
|
|
|
@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
|
||||||
txn (cursor):
|
txn (cursor):
|
||||||
event_id (str): Id for the Event.
|
event_id (str): Id for the Event.
|
||||||
Returns:
|
Returns:
|
||||||
A dict of algorithm -> hash.
|
A dict[unicode, bytes] of algorithm -> hash.
|
||||||
"""
|
"""
|
||||||
query = (
|
query = (
|
||||||
"SELECT algorithm, hash"
|
"SELECT algorithm, hash"
|
||||||
|
|
|
@ -43,7 +43,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.events import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
|
|
|
@ -137,7 +137,7 @@ class DomainSpecificString(
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, s):
|
def from_string(cls, s):
|
||||||
"""Parse the string given by 's' into a structure object."""
|
"""Parse the string given by 's' into a structure object."""
|
||||||
if len(s) < 1 or s[0] != cls.SIGIL:
|
if len(s) < 1 or s[0:1] != cls.SIGIL:
|
||||||
raise SynapseError(400, "Expected %s string to start with '%s'" % (
|
raise SynapseError(400, "Expected %s string to start with '%s'" % (
|
||||||
cls.__name__, cls.SIGIL,
|
cls.__name__, cls.SIGIL,
|
||||||
))
|
))
|
||||||
|
|
|
@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
# If we're passed a cache_context then we'll want to call its
|
||||||
# whenever we are invalidated
|
# invalidate() whenever we are invalidated
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||||
list_args = arg_dict[self.list_name]
|
list_args = arg_dict[self.list_name]
|
||||||
|
|
||||||
# cached is a dict arg -> deferred, where deferred results in a
|
|
||||||
# 2-tuple (`arg`, `result`)
|
|
||||||
results = {}
|
results = {}
|
||||||
cached_defers = {}
|
|
||||||
missing = []
|
def update_results_dict(res, arg):
|
||||||
|
results[arg] = res
|
||||||
|
|
||||||
|
# list of deferreds to wait for
|
||||||
|
cached_defers = []
|
||||||
|
|
||||||
|
missing = set()
|
||||||
|
|
||||||
# If the cache takes a single arg then that is used as the key,
|
# If the cache takes a single arg then that is used as the key,
|
||||||
# otherwise a tuple is used.
|
# otherwise a tuple is used.
|
||||||
if num_args == 1:
|
if num_args == 1:
|
||||||
def cache_get(arg):
|
def arg_to_cache_key(arg):
|
||||||
return cache.get(arg, callback=invalidate_callback)
|
return arg
|
||||||
else:
|
else:
|
||||||
key = list(keyargs)
|
keylist = list(keyargs)
|
||||||
|
|
||||||
def cache_get(arg):
|
def arg_to_cache_key(arg):
|
||||||
key[self.list_pos] = arg
|
keylist[self.list_pos] = arg
|
||||||
return cache.get(tuple(key), callback=invalidate_callback)
|
return tuple(keylist)
|
||||||
|
|
||||||
for arg in list_args:
|
for arg in list_args:
|
||||||
try:
|
try:
|
||||||
res = cache_get(arg)
|
res = cache.get(arg_to_cache_key(arg),
|
||||||
|
callback=invalidate_callback)
|
||||||
if not isinstance(res, ObservableDeferred):
|
if not isinstance(res, ObservableDeferred):
|
||||||
results[arg] = res
|
results[arg] = res
|
||||||
elif not res.has_succeeded():
|
elif not res.has_succeeded():
|
||||||
res = res.observe()
|
res = res.observe()
|
||||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
res.addCallback(update_results_dict, arg)
|
||||||
cached_defers[arg] = res
|
cached_defers.append(res)
|
||||||
else:
|
else:
|
||||||
results[arg] = res.get_result()
|
results[arg] = res.get_result()
|
||||||
except KeyError:
|
except KeyError:
|
||||||
missing.append(arg)
|
missing.add(arg)
|
||||||
|
|
||||||
if missing:
|
if missing:
|
||||||
args_to_call = dict(arg_dict)
|
# we need an observable deferred for each entry in the list,
|
||||||
args_to_call[self.list_name] = missing
|
# which we put in the cache. Each deferred resolves with the
|
||||||
|
# relevant result for that key.
|
||||||
|
deferreds_map = {}
|
||||||
|
for arg in missing:
|
||||||
|
deferred = defer.Deferred()
|
||||||
|
deferreds_map[arg] = deferred
|
||||||
|
key = arg_to_cache_key(arg)
|
||||||
|
observable = ObservableDeferred(deferred)
|
||||||
|
cache.set(key, observable, callback=invalidate_callback)
|
||||||
|
|
||||||
ret_d = defer.maybeDeferred(
|
def complete_all(res):
|
||||||
|
# the wrapped function has completed. It returns a
|
||||||
|
# a dict. We can now resolve the observable deferreds in
|
||||||
|
# the cache and update our own result map.
|
||||||
|
for e in missing:
|
||||||
|
val = res.get(e, None)
|
||||||
|
deferreds_map[e].callback(val)
|
||||||
|
results[e] = val
|
||||||
|
|
||||||
|
def errback(f):
|
||||||
|
# the wrapped function has failed. Invalidate any cache
|
||||||
|
# entries we're supposed to be populating, and fail
|
||||||
|
# their deferreds.
|
||||||
|
for e in missing:
|
||||||
|
key = arg_to_cache_key(e)
|
||||||
|
cache.invalidate(key)
|
||||||
|
deferreds_map[e].errback(f)
|
||||||
|
|
||||||
|
# return the failure, to propagate to our caller.
|
||||||
|
return f
|
||||||
|
|
||||||
|
args_to_call = dict(arg_dict)
|
||||||
|
args_to_call[self.list_name] = list(missing)
|
||||||
|
|
||||||
|
cached_defers.append(defer.maybeDeferred(
|
||||||
logcontext.preserve_fn(self.function_to_call),
|
logcontext.preserve_fn(self.function_to_call),
|
||||||
**args_to_call
|
**args_to_call
|
||||||
)
|
).addCallbacks(complete_all, errback))
|
||||||
|
|
||||||
ret_d = ObservableDeferred(ret_d)
|
|
||||||
|
|
||||||
# We need to create deferreds for each arg in the list so that
|
|
||||||
# we can insert the new deferred into the cache.
|
|
||||||
for arg in missing:
|
|
||||||
observer = ret_d.observe()
|
|
||||||
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
|
||||||
|
|
||||||
observer = ObservableDeferred(observer)
|
|
||||||
|
|
||||||
if num_args == 1:
|
|
||||||
cache.set(
|
|
||||||
arg, observer,
|
|
||||||
callback=invalidate_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def invalidate(f, key):
|
|
||||||
cache.invalidate(key)
|
|
||||||
return f
|
|
||||||
observer.addErrback(invalidate, arg)
|
|
||||||
else:
|
|
||||||
key = list(keyargs)
|
|
||||||
key[self.list_pos] = arg
|
|
||||||
cache.set(
|
|
||||||
tuple(key), observer,
|
|
||||||
callback=invalidate_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def invalidate(f, key):
|
|
||||||
cache.invalidate(key)
|
|
||||||
return f
|
|
||||||
observer.addErrback(invalidate, tuple(key))
|
|
||||||
|
|
||||||
res = observer.observe()
|
|
||||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
|
||||||
|
|
||||||
cached_defers[arg] = res
|
|
||||||
|
|
||||||
if cached_defers:
|
if cached_defers:
|
||||||
def update_results_dict(res):
|
d = defer.gatherResults(
|
||||||
results.update(res)
|
cached_defers,
|
||||||
return results
|
|
||||||
|
|
||||||
return logcontext.make_deferred_yieldable(defer.gatherResults(
|
|
||||||
list(cached_defers.values()),
|
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addCallback(update_results_dict).addErrback(
|
).addCallbacks(
|
||||||
|
lambda _: results,
|
||||||
unwrapFirstError
|
unwrapFirstError
|
||||||
))
|
)
|
||||||
|
return logcontext.make_deferred_yieldable(d)
|
||||||
else:
|
else:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
|
||||||
cache.
|
cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache (Cache): The underlying cache to use.
|
cached_method_name (str): The name of the single-item lookup method.
|
||||||
|
This is only used to find the cache to use.
|
||||||
list_name (str): The name of the argument that is the list to use to
|
list_name (str): The name of the argument that is the list to use to
|
||||||
do batch lookups in the cache.
|
do batch lookups in the cache.
|
||||||
num_args (int): Number of arguments to use as the key in the cache
|
num_args (int): Number of arguments to use as the key in the cache
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from six import string_types
|
from six import binary_type, text_type
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
@ -26,7 +26,7 @@ def freeze(o):
|
||||||
if isinstance(o, frozendict):
|
if isinstance(o, frozendict):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
if isinstance(o, string_types):
|
if isinstance(o, (binary_type, text_type)):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -41,7 +41,7 @@ def unfreeze(o):
|
||||||
if isinstance(o, (dict, frozendict)):
|
if isinstance(o, (dict, frozendict)):
|
||||||
return dict({k: unfreeze(v) for k, v in o.items()})
|
return dict({k: unfreeze(v) for k, v in o.items()})
|
||||||
|
|
||||||
if isinstance(o, string_types):
|
if isinstance(o, (binary_type, text_type)):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.auth = Auth(self.hs)
|
self.auth = Auth(self.hs)
|
||||||
|
|
||||||
self.test_user = "@foo:bar"
|
self.test_user = "@foo:bar"
|
||||||
self.test_token = "_test_token_"
|
self.test_token = b"_test_token_"
|
||||||
|
|
||||||
# this is overridden for the appservice tests
|
# this is overridden for the appservice tests
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
|
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "192.168.10.10"
|
request.getClientIP.return_value = "192.168.10.10"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "131.111.8.42"
|
request.getClientIP.return_value = "131.111.8.42"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
||||||
masquerading_user_id = "@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar", url="a_url", sender=self.test_user,
|
||||||
ip_range_whitelist=None,
|
ip_range_whitelist=None,
|
||||||
|
@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args["user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), masquerading_user_id)
|
self.assertEquals(
|
||||||
|
requester.user.to_string(),
|
||||||
|
masquerading_user_id.decode('utf8')
|
||||||
|
)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||||
masquerading_user_id = "@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar", url="a_url", sender=self.test_user,
|
||||||
ip_range_whitelist=None,
|
ip_range_whitelist=None,
|
||||||
|
@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args["user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# check the token works
|
# check the token works
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [token]
|
request.args[b"access_token"] = [token.encode('ascii')]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
||||||
|
@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# the token should *not* work now
|
# the token should *not* work now
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [guest_tok]
|
request.args[b"access_token"] = [guest_tok.encode('ascii')]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
|
||||||
with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
|
@ -19,6 +20,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
import synapse.api.errors
|
import synapse.api.errors
|
||||||
|
from synapse.api.errors import AuthError
|
||||||
from synapse.handlers.auth import AuthHandler
|
from synapse.handlers.auth import AuthHandler
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.hs.handlers = AuthHandlers(self.hs)
|
self.hs.handlers = AuthHandlers(self.hs)
|
||||||
self.auth_handler = self.hs.handlers.auth_handler
|
self.auth_handler = self.hs.handlers.auth_handler
|
||||||
self.macaroon_generator = self.hs.get_macaroon_generator()
|
self.macaroon_generator = self.hs.get_macaroon_generator()
|
||||||
|
# MAU tests
|
||||||
|
self.hs.config.max_mau_value = 50
|
||||||
|
self.small_number_of_users = 1
|
||||||
|
self.large_number_of_users = 100
|
||||||
|
|
||||||
def test_token_is_a_macaroon(self):
|
def test_token_is_a_macaroon(self):
|
||||||
token = self.macaroon_generator.generate_access_token("some_user")
|
token = self.macaroon_generator.generate_access_token("some_user")
|
||||||
|
@ -71,38 +77,37 @@ class AuthTestCase(unittest.TestCase):
|
||||||
v.satisfy_general(verify_nonce)
|
v.satisfy_general(verify_nonce)
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_short_term_login_token_gives_user_id(self):
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
self.hs.clock.now = 1000
|
self.hs.clock.now = 1000
|
||||||
|
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
"a_user", 5000
|
"a_user", 5000
|
||||||
)
|
)
|
||||||
|
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
self.assertEqual(
|
|
||||||
"a_user",
|
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
|
||||||
token
|
token
|
||||||
)
|
)
|
||||||
)
|
self.assertEqual("a_user", user_id)
|
||||||
|
|
||||||
# when we advance the clock, the token should be rejected
|
# when we advance the clock, the token should be rejected
|
||||||
self.hs.clock.now = 6000
|
self.hs.clock.now = 6000
|
||||||
with self.assertRaises(synapse.api.errors.AuthError):
|
with self.assertRaises(synapse.api.errors.AuthError):
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
token
|
token
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
"a_user", 5000
|
"a_user", 5000
|
||||||
)
|
)
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
self.assertEqual(
|
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
"a_user",
|
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
|
||||||
macaroon.serialize()
|
macaroon.serialize()
|
||||||
)
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
"a_user", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# add another "user_id" caveat, which might allow us to override the
|
# add another "user_id" caveat, which might allow us to override the
|
||||||
|
@ -110,6 +115,57 @@ class AuthTestCase(unittest.TestCase):
|
||||||
macaroon.add_first_party_caveat("user_id = b_user")
|
macaroon.add_first_party_caveat("user_id = b_user")
|
||||||
|
|
||||||
with self.assertRaises(synapse.api.errors.AuthError):
|
with self.assertRaises(synapse.api.errors.AuthError):
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
macaroon.serialize()
|
macaroon.serialize()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_mau_limits_disabled(self):
|
||||||
|
self.hs.config.limit_usage_by_mau = False
|
||||||
|
# Ensure does not throw exception
|
||||||
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
self._get_macaroon().serialize()
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_mau_limits_exceeded(self):
|
||||||
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
self.hs.get_datastore().count_monthly_users = Mock(
|
||||||
|
return_value=defer.succeed(self.large_number_of_users)
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(AuthError):
|
||||||
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
|
self.hs.get_datastore().count_monthly_users = Mock(
|
||||||
|
return_value=defer.succeed(self.large_number_of_users)
|
||||||
|
)
|
||||||
|
with self.assertRaises(AuthError):
|
||||||
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
self._get_macaroon().serialize()
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_mau_limits_not_exceeded(self):
|
||||||
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
|
||||||
|
self.hs.get_datastore().count_monthly_users = Mock(
|
||||||
|
return_value=defer.succeed(self.small_number_of_users)
|
||||||
|
)
|
||||||
|
# Ensure does not raise exception
|
||||||
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
|
self.hs.get_datastore().count_monthly_users = Mock(
|
||||||
|
return_value=defer.succeed(self.small_number_of_users)
|
||||||
|
)
|
||||||
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
self._get_macaroon().serialize()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_macaroon(self):
|
||||||
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
|
"user_a", 5000
|
||||||
|
)
|
||||||
|
return pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
|
@ -17,6 +17,7 @@ from mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import RegistrationError
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.register import RegistrationHandler
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
|
@ -77,3 +78,53 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
requester, local_part, display_name)
|
requester, local_part, display_name)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cannot_register_when_mau_limits_exceeded(self):
|
||||||
|
local_part = "someone"
|
||||||
|
display_name = "someone"
|
||||||
|
requester = create_requester("@as:test")
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.hs.config.limit_usage_by_mau = False
|
||||||
|
self.hs.config.max_mau_value = 50
|
||||||
|
lots_of_users = 100
|
||||||
|
small_number_users = 1
|
||||||
|
|
||||||
|
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||||
|
|
||||||
|
# Ensure does not throw exception
|
||||||
|
yield self.handler.get_or_create_user(requester, 'a', display_name)
|
||||||
|
|
||||||
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
|
||||||
|
with self.assertRaises(RegistrationError):
|
||||||
|
yield self.handler.get_or_create_user(requester, 'b', display_name)
|
||||||
|
|
||||||
|
store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
|
||||||
|
|
||||||
|
self._macaroon_mock_generator("another_secret")
|
||||||
|
|
||||||
|
# Ensure does not throw exception
|
||||||
|
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
|
||||||
|
|
||||||
|
self._macaroon_mock_generator("another another secret")
|
||||||
|
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||||
|
|
||||||
|
with self.assertRaises(RegistrationError):
|
||||||
|
yield self.handler.register(localpart=local_part)
|
||||||
|
|
||||||
|
self._macaroon_mock_generator("another another secret")
|
||||||
|
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||||
|
|
||||||
|
with self.assertRaises(RegistrationError):
|
||||||
|
yield self.handler.register_saml2(local_part)
|
||||||
|
|
||||||
|
def _macaroon_mock_generator(self, secret):
|
||||||
|
"""
|
||||||
|
Reset macaroon generator in the case where the test creates multiple users
|
||||||
|
"""
|
||||||
|
macaroon_generator = Mock(
|
||||||
|
generate_access_token=Mock(return_value=secret))
|
||||||
|
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
|
||||||
|
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||||
|
self.handler = self.hs.get_handlers().registration_handler
|
||||||
|
|
|
@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"pdu_failures": [],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
65
tests/storage/test__init__.py
Normal file
65
tests/storage/test__init__.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import tests.utils
|
||||||
|
|
||||||
|
|
||||||
|
class InitTestCase(tests.unittest.TestCase):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(InitTestCase, self).__init__(*args, **kwargs)
|
||||||
|
self.store = None # type: synapse.storage.DataStore
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
hs = yield tests.utils.setup_test_homeserver()
|
||||||
|
|
||||||
|
hs.config.max_mau_value = 50
|
||||||
|
hs.config.limit_usage_by_mau = True
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_count_monthly_users(self):
|
||||||
|
count = yield self.store.count_monthly_users()
|
||||||
|
self.assertEqual(0, count)
|
||||||
|
|
||||||
|
yield self._insert_user_ips("@user:server1")
|
||||||
|
yield self._insert_user_ips("@user:server2")
|
||||||
|
|
||||||
|
count = yield self.store.count_monthly_users()
|
||||||
|
self.assertEqual(2, count)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _insert_user_ips(self, user):
|
||||||
|
"""
|
||||||
|
Helper function to populate user_ips without using batch insertion infra
|
||||||
|
args:
|
||||||
|
user (str): specify username i.e. @user:server.com
|
||||||
|
"""
|
||||||
|
yield self.store._simple_upsert(
|
||||||
|
table="user_ips",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user,
|
||||||
|
"access_token": "access_token",
|
||||||
|
"ip": "ip",
|
||||||
|
"user_agent": "user_agent",
|
||||||
|
"device_id": "device_id",
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"last_seen": self.clock.time_msec(),
|
||||||
|
}
|
||||||
|
)
|
|
@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
r = yield obj.fn(2, 3)
|
r = yield obj.fn(2, 3)
|
||||||
self.assertEqual(r, 'chips')
|
self.assertEqual(r, 'chips')
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache(self):
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
|
||||||
|
def list_fn(self, args1, arg2):
|
||||||
|
assert (
|
||||||
|
logcontext.LoggingContext.current_context().request == "c1"
|
||||||
|
)
|
||||||
|
# we want this to behave like an asynchronous function
|
||||||
|
yield run_on_reactor()
|
||||||
|
assert (
|
||||||
|
logcontext.LoggingContext.current_context().request == "c1"
|
||||||
|
)
|
||||||
|
defer.returnValue(self.mock(args1, arg2))
|
||||||
|
|
||||||
|
with logcontext.LoggingContext() as c1:
|
||||||
|
c1.request = "c1"
|
||||||
|
obj = Cls()
|
||||||
|
obj.mock.return_value = {10: 'fish', 20: 'chips'}
|
||||||
|
d1 = obj.list_fn([10, 20], 2)
|
||||||
|
self.assertEqual(
|
||||||
|
logcontext.LoggingContext.current_context(),
|
||||||
|
logcontext.LoggingContext.sentinel,
|
||||||
|
)
|
||||||
|
r = yield d1
|
||||||
|
self.assertEqual(
|
||||||
|
logcontext.LoggingContext.current_context(),
|
||||||
|
c1
|
||||||
|
)
|
||||||
|
obj.mock.assert_called_once_with([10, 20], 2)
|
||||||
|
self.assertEqual(r, {10: 'fish', 20: 'chips'})
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = {30: 'peas'}
|
||||||
|
r = yield obj.list_fn([20, 30], 2)
|
||||||
|
obj.mock.assert_called_once_with([30], 2)
|
||||||
|
self.assertEqual(r, {20: 'chips', 30: 'peas'})
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# all the values should now be cached
|
||||||
|
r = yield obj.fn(10, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
r = yield obj.fn(20, 2)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
r = yield obj.fn(30, 2)
|
||||||
|
self.assertEqual(r, 'peas')
|
||||||
|
r = yield obj.list_fn([10, 20, 30], 2)
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
def test_invalidate(self):
    """Check that on_invalidate callbacks fire when cached entries are dropped."""

    class Fixture(object):
        def __init__(self):
            # backing mock: records the arguments each real fetch is made with
            self.mock = mock.Mock()

        @descriptors.cached()
        def fn(self, arg1, arg2):
            pass

        @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
        def list_fn(self, args1, arg2):
            # simulate an asynchronous lookup so the deferred does not
            # resolve synchronously
            yield run_on_reactor()
            defer.returnValue(self.mock(args1, arg2))

    obj = Fixture()
    on_invalidate_miss = mock.Mock()
    on_invalidate_hit = mock.Mock()

    # first fetch misses the cache and goes through to the backing mock
    obj.mock.return_value = {10: 'fish', 20: 'chips'}
    first = yield obj.list_fn([10, 20], 2, on_invalidate=on_invalidate_miss)
    obj.mock.assert_called_once_with([10, 20], 2)
    self.assertEqual(first, {10: 'fish', 20: 'chips'})
    obj.mock.reset_mock()

    # second fetch is served entirely from the cache
    second = yield obj.list_fn([10, 20], 2, on_invalidate=on_invalidate_hit)
    obj.mock.assert_not_called()
    self.assertEqual(second, {10: 'fish', 20: 'chips'})

    # neither callback should have fired yet
    on_invalidate_miss.assert_not_called()
    on_invalidate_hit.assert_not_called()

    # invalidating one of the underlying keys must notify both registrants
    obj.fn.invalidate((10, 2))
    on_invalidate_miss.assert_called_once()
    on_invalidate_hit.assert_called_once()
|
||||||
|
|
|
@@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
def trigger_get(self, path):
|
def trigger_get(self, path):
|
||||||
return self.trigger("GET", path, None)
|
return self.trigger(b"GET", path, None)
|
||||||
|
|
||||||
@patch('twisted.web.http.Request')
|
@patch('twisted.web.http.Request')
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
if federation_auth:
|
if federation_auth:
|
||||||
headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="]
|
headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
|
||||||
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
|
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
|
||||||
|
|
||||||
# return the right path if the event requires it
|
# return the right path if the event requires it
|
||||||
|
@@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if isinstance(path, bytes):
|
||||||
|
path = path.decode('utf8')
|
||||||
|
|
||||||
for (method, pattern, func) in self.callbacks:
|
for (method, pattern, func) in self.callbacks:
|
||||||
if http_method != method:
|
if http_method != method:
|
||||||
continue
|
continue
|
||||||
|
@@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
|
||||||
if matcher:
|
if matcher:
|
||||||
try:
|
try:
|
||||||
args = [
|
args = [
|
||||||
urlparse.unquote(u).decode("UTF-8")
|
urlparse.unquote(u)
|
||||||
for u in matcher.groups()
|
for u in matcher.groups()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue