Merge branch 'develop' into matrix-org-hotfixes

This commit is contained in:
Richard van der Hoff 2019-05-20 15:55:39 +01:00
commit ca03f90ee7
108 changed files with 2747 additions and 653 deletions

View file

@ -1,3 +1,9 @@
Synapse 0.99.4 (2019-05-15)
===========================
No significant changes.
Synapse 0.99.4rc1 (2019-05-13)
==============================
@ -17,8 +23,8 @@ Features
instead of the executable name, `python`.
Contributed by Christoph Müller. ([\#5023](https://github.com/matrix-org/synapse/issues/5023))
- Add time-based account expiration. ([\#5027](https://github.com/matrix-org/synapse/issues/5027), [\#5047](https://github.com/matrix-org/synapse/issues/5047), [\#5073](https://github.com/matrix-org/synapse/issues/5073), [\#5116](https://github.com/matrix-org/synapse/issues/5116))
- Add support for handling /verions, /voip and /push_rules client endpoints to client_reader worker. ([\#5063](https://github.com/matrix-org/synapse/issues/5063), [\#5065](https://github.com/matrix-org/synapse/issues/5065), [\#5070](https://github.com/matrix-org/synapse/issues/5070))
- Add an configuration option to require authentication on /publicRooms and /profile endpoints. ([\#5083](https://github.com/matrix-org/synapse/issues/5083))
- Add support for handling `/versions`, `/voip` and `/push_rules` client endpoints to client_reader worker. ([\#5063](https://github.com/matrix-org/synapse/issues/5063), [\#5065](https://github.com/matrix-org/synapse/issues/5065), [\#5070](https://github.com/matrix-org/synapse/issues/5070))
- Add a configuration option to require authentication on /publicRooms and /profile endpoints. ([\#5083](https://github.com/matrix-org/synapse/issues/5083))
- Move admin APIs to `/_synapse/admin/v1`. (The old paths are retained for backwards-compatibility, for now). ([\#5119](https://github.com/matrix-org/synapse/issues/5119))
- Implement an admin API for sending server notices. Many thanks to @krombel who provided a foundation for this work. ([\#5121](https://github.com/matrix-org/synapse/issues/5121), [\#5142](https://github.com/matrix-org/synapse/issues/5142))
@ -39,11 +45,9 @@ Bugfixes
- Workaround bug in twisted where attempting too many concurrent DNS requests could cause it to hang due to running out of file descriptors. ([\#5037](https://github.com/matrix-org/synapse/issues/5037))
- Make sure we're not registering the same 3pid twice on registration. ([\#5071](https://github.com/matrix-org/synapse/issues/5071))
- Don't crash on lack of expiry templates. ([\#5077](https://github.com/matrix-org/synapse/issues/5077))
- Fix the ratelimting on third party invites. ([\#5104](https://github.com/matrix-org/synapse/issues/5104))
- Fix the ratelimiting on third party invites. ([\#5104](https://github.com/matrix-org/synapse/issues/5104))
- Add some missing limitations to room alias creation. ([\#5124](https://github.com/matrix-org/synapse/issues/5124), [\#5128](https://github.com/matrix-org/synapse/issues/5128))
- Limit the number of EDUs in transactions to 100 as expected by synapse. Thanks to @superboum for this work! ([\#5138](https://github.com/matrix-org/synapse/issues/5138))
- Fix bogus imports in unit tests. ([\#5154](https://github.com/matrix-org/synapse/issues/5154))
Internal Changes
----------------
@ -78,6 +82,7 @@ Internal Changes
- Prevent an exception from being raised in a IResolutionReceiver and use a more generic error message for blacklisted URL previews. ([\#5155](https://github.com/matrix-org/synapse/issues/5155))
- Run `black` on the tests directory. ([\#5170](https://github.com/matrix-org/synapse/issues/5170))
- Fix CI after new release of isort. ([\#5179](https://github.com/matrix-org/synapse/issues/5179))
- Fix bogus imports in unit tests. ([\#5154](https://github.com/matrix-org/synapse/issues/5154))
Synapse 0.99.3.2 (2019-05-03)

View file

@ -35,7 +35,7 @@ virtualenv -p python3 ~/synapse/env
source ~/synapse/env/bin/activate
pip install --upgrade pip
pip install --upgrade setuptools
pip install matrix-synapse[all]
pip install matrix-synapse
```
This will download Synapse from [PyPI](https://pypi.org/project/matrix-synapse)
@ -48,7 +48,7 @@ update flag:
```
source ~/synapse/env/bin/activate
pip install -U matrix-synapse[all]
pip install -U matrix-synapse
```
Before you can start Synapse, you will need to generate a configuration

1
changelog.d/3484.misc Normal file
View file

@ -0,0 +1 @@
Make /sync attempt to return device updates for both joined and invited users. Note that this doesn't currently work correctly due to other bugs.

1
changelog.d/5039.bugfix Normal file
View file

@ -0,0 +1 @@
Fix image orientation when generating thumbnails (needs pillow>=4.3.0). Contributed by Pau Rodriguez-Estivill.

1
changelog.d/5043.feature Normal file
View file

@ -0,0 +1 @@
Add ability to blacklist IP ranges for the federation client.

1
changelog.d/5171.misc Normal file
View file

@ -0,0 +1 @@
Update tests to consistently be configured via the same code that is used when loading from configuration files.

1
changelog.d/5174.bugfix Normal file
View file

@ -0,0 +1 @@
Re-order stages in registration flows such that msisdn and email verification are done last.

1
changelog.d/5177.bugfix Normal file
View file

@ -0,0 +1 @@
Fix 3pid guest invites.

1
changelog.d/5181.feature Normal file
View file

@ -0,0 +1 @@
Ratelimiting configuration for clients sending messages and the federation server has been altered to match login ratelimiting. The old configuration names will continue working. Check the sample config for details of the new names.

1
changelog.d/5183.misc Normal file
View file

@ -0,0 +1 @@
Allow client event serialization to be async.

1
changelog.d/5184.misc Normal file
View file

@ -0,0 +1 @@
Expose DataStore._get_events as get_events_as_list.

1
changelog.d/5185.misc Normal file
View file

@ -0,0 +1 @@
Update tests to consistently be configured via the same code that is used when loading from configuration files.

1
changelog.d/5187.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug where the register endpoint would fail with M_THREEPID_IN_USE instead of returning an account previously registered in the same session.

1
changelog.d/5190.feature Normal file
View file

@ -0,0 +1 @@
Drop support for the undocumented /_matrix/client/v2_alpha API prefix.

1
changelog.d/5191.misc Normal file
View file

@ -0,0 +1 @@
Make generating SQL bounds for pagination generic.

1
changelog.d/5196.feature Normal file
View file

@ -0,0 +1 @@
Add an option to disable per-room profiles.

1
changelog.d/5197.misc Normal file
View file

@ -0,0 +1 @@
Stop telling people to install the optional dependencies by default.

1
changelog.d/5198.bugfix Normal file
View file

@ -0,0 +1 @@
Prevent registration for user ids that are to long to fit into a state key. Contributed by Reid Anderson.

1
changelog.d/5200.bugfix Normal file
View file

@ -0,0 +1 @@
Fix worker registration bug caused by ClientReaderSlavedStore being unable to see get_profileinfo.

1
changelog.d/5209.feature Normal file
View file

@ -0,0 +1 @@
Add experimental support for relations (aka reactions and edits).

1
changelog.d/5210.feature Normal file
View file

@ -0,0 +1 @@
Add a new room version which uses a new event ID format.

1
changelog.d/5211.feature Normal file
View file

@ -0,0 +1 @@
Add experimental support for relations (aka reactions and edits).

7
debian/changelog vendored
View file

@ -1,9 +1,12 @@
matrix-synapse-py3 (0.99.3.2+nmu1) UNRELEASED; urgency=medium
matrix-synapse-py3 (0.99.4) stable; urgency=medium
[ Christoph Müller ]
* Configure the systemd units to have a log identifier of `matrix-synapse`
-- Christoph Müller <iblzm@hotmail.de> Wed, 17 Apr 2019 16:17:32 +0200
[ Synapse Packaging team ]
* New synapse release 0.99.4.
-- Synapse Packaging team <packages@matrix.org> Wed, 15 May 2019 13:58:08 +0100
matrix-synapse-py3 (0.99.3.2) stable; urgency=medium

2
debian/test/.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
.vagrant
*.log

23
debian/test/provision.sh vendored Normal file
View file

@ -0,0 +1,23 @@
#!/bin/bash
#
# provisioning script for vagrant boxes for testing the matrix-synapse debs.
#
# Will install the most recent matrix-synapse-py3 deb for this platform from
# the /debs directory.
set -e
apt-get update
apt-get install -y lsb-release
deb=`ls /debs/matrix-synapse-py3_*+$(lsb_release -cs)*.deb | sort | tail -n1`
debconf-set-selections <<EOF
matrix-synapse matrix-synapse/report-stats boolean false
matrix-synapse matrix-synapse/server-name string localhost:18448
EOF
dpkg -i "$deb"
sed -i -e '/port: 8...$/{s/8448/18448/; s/8008/18008/}' -e '$aregistration_shared_secret: secret' /etc/matrix-synapse/homeserver.yaml
systemctl restart matrix-synapse

13
debian/test/stretch/Vagrantfile vendored Normal file
View file

@ -0,0 +1,13 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
ver = `cd ../../..; dpkg-parsechangelog -S Version`.strip()
Vagrant.configure("2") do |config|
config.vm.box = "debian/stretch64"
config.vm.synced_folder ".", "/vagrant", disabled: true
config.vm.synced_folder "../../../../debs", "/debs", type: "nfs"
config.vm.provision "shell", path: "../provision.sh"
end

10
debian/test/xenial/Vagrantfile vendored Normal file
View file

@ -0,0 +1,10 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
Vagrant.configure("2") do |config|
config.vm.box = "ubuntu/xenial64"
config.vm.synced_folder ".", "/vagrant", disabled: true
config.vm.synced_folder "../../../../debs", "/debs"
config.vm.provision "shell", path: "../provision.sh"
end

View file

@ -3,6 +3,28 @@ Using Postgres
Postgres version 9.4 or later is known to work.
Install postgres client libraries
=================================
Synapse will require the python postgres client library in order to connect to
a postgres database.
* If you are using the `matrix.org debian/ubuntu
packages <../INSTALL.md#matrixorg-packages>`_,
the necessary libraries will already be installed.
* For other pre-built packages, please consult the documentation from the
relevant package.
* If you installed synapse `in a virtualenv
<../INSTALL.md#installing-from-source>`_, you can install the library with::
~/synapse/env/bin/pip install matrix-synapse[postgres]
(substituting the path to your virtualenv for ``~/synapse/env``, if you used a
different path). You will require the postgres development files. These are in
the ``libpq-dev`` package on Debian-derived distributions.
Set up database
===============
@ -26,29 +48,6 @@ encoding use, e.g.::
This would create an appropriate database named ``synapse`` owned by the
``synapse_user`` user (which must already exist).
Set up client in Debian/Ubuntu
===========================
Postgres support depends on the postgres python connector ``psycopg2``. In the
virtual env::
sudo apt-get install libpq-dev
pip install psycopg2
Set up client in RHEL/CentOs 7
==============================
Make sure you have the appropriate version of postgres-devel installed. For a
postgres 9.4, use the postgres 9.4 packages from
[here](https://wiki.postgresql.org/wiki/YUM_Installation).
As with Debian/Ubuntu, postgres support depends on the postgres python connector
``psycopg2``. In the virtual env::
sudo yum install postgresql-devel libpqxx-devel.x86_64
export PATH=/usr/pgsql-9.4/bin/:$PATH
pip install psycopg2
Tuning Postgres
===============

View file

@ -115,6 +115,24 @@ pid_file: DATADIR/homeserver.pid
# - nyc.example.com
# - syd.example.com
# Prevent federation requests from being sent to the following
# blacklist IP address CIDR ranges. If this option is not specified, or
# specified with an empty list, no ip range blacklist will be enforced.
#
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
federation_ip_range_blacklist:
- '127.0.0.0/8'
- '10.0.0.0/8'
- '172.16.0.0/12'
- '192.168.0.0/16'
- '100.64.0.0/10'
- '169.254.0.0/16'
- '::1/128'
- 'fe80::/64'
- 'fc00::/7'
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
@ -258,6 +276,12 @@ listeners:
#
#require_membership_for_aliases: false
# Whether to allow per-room membership profiles through the send of membership
# events with profile information that differ from the target's global profile.
# Defaults to 'true'.
#
#allow_per_room_profiles: false
## TLS ##
@ -428,21 +452,15 @@ log_config: "CONFDIR/SERVERNAME.log.config"
## Ratelimiting ##
# Number of messages a client can send per second
#
#rc_messages_per_second: 0.2
# Number of message a client can send before being throttled
#
#rc_message_burst_count: 10.0
# Ratelimiting settings for registration and login.
# Ratelimiting settings for client actions (registration, login, messaging).
#
# Each ratelimiting configuration is made of two parameters:
# - per_second: number of requests a client can send per second.
# - burst_count: number of requests a client can send before being throttled.
#
# Synapse currently uses the following configurations:
# - one for messages that ratelimits sending based on the account the client
# is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
# - one for login that ratelimits login requests based on the client's IP
@ -455,6 +473,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
#
# The defaults are as shown below.
#
#rc_message:
# per_second: 0.2
# burst_count: 10
#
#rc_registration:
# per_second: 0.17
# burst_count: 3
@ -470,29 +492,28 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# per_second: 0.17
# burst_count: 3
# The federation window size in milliseconds
#
#federation_rc_window_size: 1000
# The number of federation requests from a single server in a window
# before the server will delay processing the request.
# Ratelimiting settings for incoming federation
#
#federation_rc_sleep_limit: 10
# The duration in milliseconds to delay processing events from
# remote servers by if they go over the sleep limit.
# The rc_federation configuration is made up of the following settings:
# - window_size: window size in milliseconds
# - sleep_limit: number of federation requests from a single server in
# a window before the server will delay processing the request.
# - sleep_delay: duration in milliseconds to delay processing events
# from remote servers by if they go over the sleep limit.
# - reject_limit: maximum number of concurrent federation requests
# allowed from a single server
# - concurrent: number of federation requests to concurrently process
# from a single server
#
#federation_rc_sleep_delay: 500
# The maximum number of concurrent federation requests allowed
# from a single server
# The defaults are as shown below.
#
#federation_rc_reject_limit: 50
# The number of federation requests to concurrently process from a
# single server
#
#federation_rc_concurrent: 3
#rc_federation:
# window_size: 1000
# sleep_limit: 10
# sleep_delay: 500
# reject_limit: 50
# concurrent: 3
# Target outgoing federation transaction frequency for sending read-receipts,
# per-room.

View file

@ -27,4 +27,4 @@ try:
except ImportError:
pass
__version__ = "0.99.4rc1"
__version__ = "0.99.4"

View file

@ -23,6 +23,9 @@ MAX_DEPTH = 2**63 - 1
# the maximum length for a room alias is 255 characters
MAX_ALIAS_LENGTH = 255
# the maximum length for a user id is 255 characters
MAX_USERID_LENGTH = 255
class Membership(object):
@ -116,3 +119,11 @@ class UserTypes(object):
"""
SUPPORT = "support"
ALL_USER_TYPES = (SUPPORT,)
class RelationTypes(object):
"""The types of relations known to this server.
"""
ANNOTATION = "m.annotation"
REPLACE = "m.replace"
REFERENCE = "m.reference"

View file

@ -19,13 +19,15 @@ class EventFormatVersions(object):
"""This is an internal enum for tracking the version of the event format,
independently from the room version.
"""
V1 = 1 # $id:server format
V2 = 2 # MSC1659-style $hash format: introduced for room v3
V1 = 1 # $id:server event id format
V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
V3 = 3 # MSC1884-style $hash format: introduced for room v4
KNOWN_EVENT_FORMAT_VERSIONS = {
EventFormatVersions.V1,
EventFormatVersions.V2,
EventFormatVersions.V3,
}
@ -75,6 +77,12 @@ class RoomVersions(object):
EventFormatVersions.V2,
StateResolutionVersions.V2,
)
EVENTID_NOSLASH_TEST = RoomVersion(
"eventid-noslash-test",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
)
# the version we will give rooms which are created on this server
@ -87,5 +95,6 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V2,
RoomVersions.V3,
RoomVersions.STATE_V2_TEST,
RoomVersions.EVENTID_NOSLASH_TEST,
)
} # type: dict[str, RoomVersion]

View file

@ -22,8 +22,7 @@ from six.moves.urllib.parse import urlencode
from synapse.config import ConfigError
CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
CLIENT_API_PREFIX = "/_matrix/client"
FEDERATION_PREFIX = "/_matrix/federation"
FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"

View file

@ -29,6 +29,7 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage import SlavedProfileStore
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@ -83,6 +84,7 @@ class ClientReaderSlavedStore(
SlavedTransactionStore,
SlavedClientIpStore,
BaseSlavedStore,
SlavedProfileStore,
):
pass

View file

@ -16,16 +16,56 @@ from ._base import Config
class RateLimitConfig(object):
def __init__(self, config):
self.per_second = config.get("per_second", 0.17)
self.burst_count = config.get("burst_count", 3.0)
def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}):
self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = config.get("burst_count", defaults["burst_count"])
class FederationRateLimitConfig(object):
_items_and_default = {
"window_size": 10000,
"sleep_limit": 10,
"sleep_delay": 500,
"reject_limit": 50,
"concurrent": 3,
}
def __init__(self, **kwargs):
for i in self._items_and_default.keys():
setattr(self, i, kwargs.get(i) or self._items_and_default[i])
class RatelimitConfig(Config):
def read_config(self, config):
self.rc_messages_per_second = config.get("rc_messages_per_second", 0.2)
self.rc_message_burst_count = config.get("rc_message_burst_count", 10.0)
# Load the new-style messages config if it exists. Otherwise fall back
# to the old method.
if "rc_message" in config:
self.rc_message = RateLimitConfig(
config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0}
)
else:
self.rc_message = RateLimitConfig(
{
"per_second": config.get("rc_messages_per_second", 0.2),
"burst_count": config.get("rc_message_burst_count", 10.0),
}
)
# Load the new-style federation config, if it exists. Otherwise, fall
# back to the old method.
if "federation_rc" in config:
self.rc_federation = FederationRateLimitConfig(**config["rc_federation"])
else:
self.rc_federation = FederationRateLimitConfig(
**{
"window_size": config.get("federation_rc_window_size"),
"sleep_limit": config.get("federation_rc_sleep_limit"),
"sleep_delay": config.get("federation_rc_sleep_delay"),
"reject_limit": config.get("federation_rc_reject_limit"),
"concurrent": config.get("federation_rc_concurrent"),
}
)
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
@ -33,38 +73,26 @@ class RatelimitConfig(Config):
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
self.rc_login_failed_attempts = RateLimitConfig(
rc_login_config.get("failed_attempts", {}),
rc_login_config.get("failed_attempts", {})
)
self.federation_rc_window_size = config.get("federation_rc_window_size", 1000)
self.federation_rc_sleep_limit = config.get("federation_rc_sleep_limit", 10)
self.federation_rc_sleep_delay = config.get("federation_rc_sleep_delay", 500)
self.federation_rc_reject_limit = config.get("federation_rc_reject_limit", 50)
self.federation_rc_concurrent = config.get("federation_rc_concurrent", 3)
self.federation_rr_transactions_per_room_per_second = config.get(
"federation_rr_transactions_per_room_per_second", 50,
"federation_rr_transactions_per_room_per_second", 50
)
def default_config(self, **kwargs):
return """\
## Ratelimiting ##
# Number of messages a client can send per second
#
#rc_messages_per_second: 0.2
# Number of message a client can send before being throttled
#
#rc_message_burst_count: 10.0
# Ratelimiting settings for registration and login.
# Ratelimiting settings for client actions (registration, login, messaging).
#
# Each ratelimiting configuration is made of two parameters:
# - per_second: number of requests a client can send per second.
# - burst_count: number of requests a client can send before being throttled.
#
# Synapse currently uses the following configurations:
# - one for messages that ratelimits sending based on the account the client
# is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
# - one for login that ratelimits login requests based on the client's IP
@ -77,6 +105,10 @@ class RatelimitConfig(Config):
#
# The defaults are as shown below.
#
#rc_message:
# per_second: 0.2
# burst_count: 10
#
#rc_registration:
# per_second: 0.17
# burst_count: 3
@ -92,29 +124,28 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
# The federation window size in milliseconds
#
#federation_rc_window_size: 1000
# The number of federation requests from a single server in a window
# before the server will delay processing the request.
# Ratelimiting settings for incoming federation
#
#federation_rc_sleep_limit: 10
# The duration in milliseconds to delay processing events from
# remote servers by if they go over the sleep limit.
# The rc_federation configuration is made up of the following settings:
# - window_size: window size in milliseconds
# - sleep_limit: number of federation requests from a single server in
# a window before the server will delay processing the request.
# - sleep_delay: duration in milliseconds to delay processing events
# from remote servers by if they go over the sleep limit.
# - reject_limit: maximum number of concurrent federation requests
# allowed from a single server
# - concurrent: number of federation requests to concurrently process
# from a single server
#
#federation_rc_sleep_delay: 500
# The maximum number of concurrent federation requests allowed
# from a single server
# The defaults are as shown below.
#
#federation_rc_reject_limit: 50
# The number of federation requests to concurrently process from a
# single server
#
#federation_rc_concurrent: 3
#rc_federation:
# window_size: 1000
# sleep_limit: 10
# sleep_delay: 500
# reject_limit: 50
# concurrent: 3
# Target outgoing federation transaction frequency for sending read-receipts,
# per-room.

View file

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,6 +18,8 @@
import logging
import os.path
from netaddr import IPSet
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
@ -98,6 +101,11 @@ class ServerConfig(Config):
"block_non_admin_invites", False,
)
# Whether to enable experimental MSC1849 (aka relations) support
self.experimental_msc1849_support_enabled = config.get(
"experimental_msc1849_support_enabled", False,
)
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = 0
@ -137,6 +145,24 @@ class ServerConfig(Config):
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
self.federation_ip_range_blacklist = config.get(
"federation_ip_range_blacklist", [],
)
# Attempt to create an IPSet from the given ranges
try:
self.federation_ip_range_blacklist = IPSet(
self.federation_ip_range_blacklist
)
# Always blacklist 0.0.0.0, ::
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
except Exception as e:
raise ConfigError(
"Invalid range(s) provided in "
"federation_ip_range_blacklist: %s" % e
)
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@ -153,6 +179,10 @@ class ServerConfig(Config):
"require_membership_for_aliases", True,
)
# Whether to allow per-room membership profiles through the send of membership
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
self.listeners = []
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
@ -386,6 +416,24 @@ class ServerConfig(Config):
# - nyc.example.com
# - syd.example.com
# Prevent federation requests from being sent to the following
# blacklist IP address CIDR ranges. If this option is not specified, or
# specified with an empty list, no ip range blacklist will be enforced.
#
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
federation_ip_range_blacklist:
- '127.0.0.0/8'
- '10.0.0.0/8'
- '172.16.0.0/12'
- '192.168.0.0/16'
- '100.64.0.0/10'
- '169.254.0.0/16'
- '::1/128'
- 'fe80::/64'
- 'fc00::/7'
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
@ -528,6 +576,12 @@ class ServerConfig(Config):
# Defaults to 'true'.
#
#require_membership_for_aliases: false
# Whether to allow per-room membership profiles through the send of membership
# events with profile information that differ from the target's global profile.
# Defaults to 'true'.
#
#allow_per_room_profiles: false
""" % locals()
def read_arguments(self, args):

View file

@ -335,13 +335,32 @@ class FrozenEventV2(EventBase):
return self.__repr__()
def __repr__(self):
return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % (
return "<%s event_id='%s', type='%s', state_key='%s'>" % (
self.__class__.__name__,
self.event_id,
self.get("type", None),
self.get("state_key", None),
)
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
format_version = EventFormatVersions.V3 # All events of this type are V3
@property
def event_id(self):
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
if self._event_id:
return self._event_id
self._event_id = "$" + encode_base64(
compute_event_reference_hash(self)[1], urlsafe=True
)
return self._event_id
def room_version_to_event_format(room_version):
"""Converts a room version string to the event format
@ -376,6 +395,8 @@ def event_type_from_format_version(format_version):
return FrozenEvent
elif format_version == EventFormatVersions.V2:
return FrozenEventV2
elif format_version == EventFormatVersions.V3:
return FrozenEventV3
else:
raise Exception(
"No event format %r" % (format_version,)

View file

@ -19,7 +19,10 @@ from six import string_types
from frozendict import frozendict
from synapse.api.constants import EventTypes
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes
from synapse.util.async_helpers import yieldable_gather_results
from . import EventBase
@ -311,3 +314,92 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d = only_fields(d, only_event_fields)
return d
class EventClientSerializer(object):
"""Serializes events that are to be sent to clients.
This is used for bundling extra information with any events to be sent to
clients.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
self.experimental_msc1849_support_enabled = (
hs.config.experimental_msc1849_support_enabled
)
@defer.inlineCallbacks
def serialize_event(self, event, time_now, **kwargs):
"""Serializes a single event.
Args:
event (EventBase)
time_now (int): The current time in milliseconds
**kwargs: Arguments to pass to `serialize_event`
Returns:
Deferred[dict]: The serialized event
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
defer.returnValue(event)
event_id = event.event_id
serialized_event = serialize_event(event, time_now, **kwargs)
# If MSC1849 is enabled then we need to look if thre are any relations
# we need to bundle in with the event
if self.experimental_msc1849_support_enabled:
annotations = yield self.store.get_aggregation_groups_for_event(
event_id,
)
references = yield self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f",
)
if annotations.chunk:
r = serialized_event["unsigned"].setdefault("m.relations", {})
r[RelationTypes.ANNOTATION] = annotations.to_dict()
if references.chunk:
r = serialized_event["unsigned"].setdefault("m.relations", {})
r[RelationTypes.REFERENCE] = references.to_dict()
edit = None
if event.type == EventTypes.Message:
edit = yield self.store.get_applicable_edit(event_id)
if edit:
# If there is an edit replace the content, preserving existing
# relations.
relations = event.content.get("m.relates_to")
serialized_event["content"] = edit.content.get("m.new_content", {})
if relations:
serialized_event["content"]["m.relates_to"] = relations
else:
serialized_event["content"].pop("m.relates_to", None)
r = serialized_event["unsigned"].setdefault("m.relations", {})
r[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
}
defer.returnValue(serialized_event)
def serialize_events(self, events, time_now, **kwargs):
"""Serializes multiple events.
Args:
event (iter[EventBase])
time_now (int): The current time in milliseconds
**kwargs: Arguments to pass to `serialize_event`
Returns:
Deferred[list[dict]]: The list of serialized events
"""
return yieldable_gather_results(
self.serialize_event, events,
time_now=time_now, **kwargs
)

View file

@ -63,11 +63,7 @@ class TransportLayerServer(JsonResource):
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent,
config=hs.config.rc_federation,
)
self.register_servlets()

View file

@ -90,8 +90,8 @@ class BaseHandler(object):
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
messages_per_second = self.hs.config.rc_messages_per_second
burst_count = self.hs.config.rc_message_burst_count
messages_per_second = self.hs.config.rc_message.per_second
burst_count = self.hs.config.rc_message.burst_count
allowed, time_allowed = self.ratelimiter.can_do_action(
user_id, time_now,

View file

@ -21,7 +21,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.types import UserID
from synapse.util.logutils import log_function
from synapse.visibility import filter_events_for_client
@ -50,6 +49,7 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
@log_function
@ -120,9 +120,9 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()
chunks = [
serialize_event(e, time_now, as_client_event) for e in events
]
chunks = yield self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event,
)
chunk = {
"chunk": chunks,

View file

@ -19,7 +19,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig
@ -43,6 +42,7 @@ class InitialSyncHandler(BaseHandler):
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
self._event_serializer = hs.get_event_client_serializer()
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
@ -138,7 +138,9 @@ class InitialSyncHandler(BaseHandler):
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
d["invite"] = yield self._event_serializer.serialize_event(
invite_event, time_now, as_client_event,
)
rooms_ret.append(d)
@ -185,18 +187,21 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
d["messages"] = {
"chunk": [
serialize_event(m, time_now, as_client_event)
for m in messages
],
"chunk": (
yield self._event_serializer.serialize_events(
messages, time_now=time_now,
as_client_event=as_client_event,
)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
}
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
d["state"] = yield self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
as_client_event=as_client_event
)
account_data_events = []
tags = tags_by_room.get(event.room_id)
@ -337,11 +342,15 @@ class InitialSyncHandler(BaseHandler):
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"chunk": (yield self._event_serializer.serialize_events(
messages, time_now,
)),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"state": (yield self._event_serializer.serialize_events(
room_state.values(), time_now,
)),
"presence": [],
"receipts": [],
})
@ -355,10 +364,9 @@ class InitialSyncHandler(BaseHandler):
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
state = yield self._event_serializer.serialize_events(
current_state.values(), time_now,
)
now_token = yield self.hs.get_event_sources().get_current_token()
@ -425,7 +433,9 @@ class InitialSyncHandler(BaseHandler):
ret = {
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"chunk": (yield self._event_serializer.serialize_events(
messages, time_now,
)),
"start": start_token.to_string(),
"end": end_token.to_string(),
},

View file

@ -32,7 +32,6 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.state import StateFilter
@ -57,6 +56,7 @@ class MessageHandler(object):
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastore()
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None,
@ -164,9 +164,10 @@ class MessageHandler(object):
room_state = room_state[membership_event_id]
now = self.clock.time_msec()
defer.returnValue(
[serialize_event(c, now) for c in room_state.values()]
events = yield self._event_serializer.serialize_events(
room_state.values(), now,
)
defer.returnValue(events)
@defer.inlineCallbacks
def get_joined_members(self, requester, room_id):

View file

@ -20,7 +20,6 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
@ -78,6 +77,7 @@ class PaginationHandler(object):
self._purges_in_progress_by_room = set()
# map from purge id to PurgeStatus
self._purges_by_id = {}
self._event_serializer = hs.get_event_client_serializer()
def start_purge_history(self, room_id, token,
delete_local_events=False):
@ -278,18 +278,22 @@ class PaginationHandler(object):
time_now = self.clock.time_msec()
chunk = {
"chunk": [
serialize_event(e, time_now, as_client_event)
for e in events
],
"chunk": (
yield self._event_serializer.serialize_events(
events, time_now,
as_client_event=as_client_event,
)
),
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
}
if state:
chunk["state"] = [
serialize_event(e, time_now, as_client_event)
for e in state
]
chunk["state"] = (
yield self._event_serializer.serialize_events(
state, time_now,
as_client_event=as_client_event,
)
)
defer.returnValue(chunk)

View file

@ -19,7 +19,7 @@ import logging
from twisted.internet import defer
from synapse import types
from synapse.api.constants import LoginType
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
from synapse.api.errors import (
AuthError,
Codes,
@ -123,6 +123,15 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id)
if len(user_id) > MAX_USERID_LENGTH:
raise SynapseError(
400,
"User ID may not be longer than %s characters" % (
MAX_USERID_LENGTH,
),
Codes.INVALID_USERNAME
)
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:

View file

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -74,6 +75,7 @@ class RoomMemberHandler(object):
self.spam_checker = hs.get_spam_checker()
self._server_notices_mxid = self.config.server_notices_mxid
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
@ -377,6 +379,13 @@ class RoomMemberHandler(object):
# later on.
content = dict(content)
if not self.allow_per_room_profiles:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
content.pop("displayname", None)
content.pop("avatar_url", None)
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@ -955,7 +964,7 @@ class RoomMemberHandler(object):
}
if self.config.invite_3pid_guest:
guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest(
requester=requester,
medium=medium,
address=address,

View file

@ -23,7 +23,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import serialize_event
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@ -36,6 +35,7 @@ class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@ -401,14 +401,16 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
context["events_before"] = [
serialize_event(e, time_now)
for e in context["events_before"]
]
context["events_after"] = [
serialize_event(e, time_now)
for e in context["events_after"]
]
context["events_before"] = (
yield self._event_serializer.serialize_events(
context["events_before"], time_now,
)
)
context["events_after"] = (
yield self._event_serializer.serialize_events(
context["events_after"], time_now,
)
)
state_results = {}
if include_state:
@ -422,14 +424,13 @@ class SearchHandler(BaseHandler):
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
results = [
{
results = []
for e in allowed_events:
results.append({
"rank": rank_map[e.event_id],
"result": serialize_event(e, time_now),
"result": (yield self._event_serializer.serialize_event(e, time_now)),
"context": contexts.get(e.event_id, {}),
}
for e in allowed_events
]
})
rooms_cat_res = {
"results": results,
@ -438,10 +439,13 @@ class SearchHandler(BaseHandler):
}
if state_results:
rooms_cat_res["state"] = {
room_id: [serialize_event(e, time_now) for e in state]
for room_id, state in state_results.items()
}
s = {}
for room_id, state in state_results.items():
s[room_id] = yield self._event_serializer.serialize_events(
state, time_now,
)
rooms_cat_res["state"] = s
if room_groups and "room_id" in group_keys:
rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups

View file

@ -937,7 +937,7 @@ class SyncHandler(object):
res = yield self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
newly_joined_rooms, newly_joined_users, _, _ = res
newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
@ -946,7 +946,7 @@ class SyncHandler(object):
)
if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_users
sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
)
yield self._generate_sync_entry_for_to_device(sync_result_builder)
@ -954,7 +954,7 @@ class SyncHandler(object):
device_lists = yield self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
newly_joined_users=newly_joined_users,
newly_joined_or_invited_users=newly_joined_or_invited_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
@ -1039,7 +1039,8 @@ class SyncHandler(object):
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder,
newly_joined_rooms, newly_joined_users,
newly_joined_rooms,
newly_joined_or_invited_users,
newly_left_rooms, newly_left_users):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@ -1053,7 +1054,7 @@ class SyncHandler(object):
# share a room with?
for room_id in newly_joined_rooms:
joined_users = yield self.state.get_current_users_in_room(room_id)
newly_joined_users.update(joined_users)
newly_joined_or_invited_users.update(joined_users)
for room_id in newly_left_rooms:
left_users = yield self.state.get_current_users_in_room(room_id)
@ -1061,7 +1062,7 @@ class SyncHandler(object):
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
changed.update(newly_joined_users)
changed.update(newly_joined_or_invited_users)
if not changed and not newly_left_users:
defer.returnValue(DeviceLists(
@ -1179,7 +1180,7 @@ class SyncHandler(object):
@defer.inlineCallbacks
def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms,
newly_joined_users):
newly_joined_or_invited_users):
"""Generates the presence portion of the sync response. Populates the
`sync_result_builder` with the result.
@ -1187,8 +1188,9 @@ class SyncHandler(object):
sync_result_builder(SyncResultBuilder)
newly_joined_rooms(list): List of rooms that the user has joined
since the last sync (or empty if an initial sync)
newly_joined_users(list): List of users that have joined rooms
since the last sync (or empty if an initial sync)
newly_joined_or_invited_users(list): List of users that have joined
or been invited to rooms since the last sync (or empty if an initial
sync)
"""
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
@ -1214,7 +1216,7 @@ class SyncHandler(object):
"presence_key", presence_key
)
extra_users_ids = set(newly_joined_users)
extra_users_ids = set(newly_joined_or_invited_users)
for room_id in newly_joined_rooms:
users = yield self.state.get_current_users_in_room(room_id)
extra_users_ids.update(users)
@ -1246,7 +1248,8 @@ class SyncHandler(object):
Returns:
Deferred(tuple): Returns a 4-tuple of
`(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)`
`(newly_joined_rooms, newly_joined_or_invited_users,
newly_left_rooms, newly_left_users)`
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
@ -1317,8 +1320,8 @@ class SyncHandler(object):
sync_result_builder.invited.extend(invited)
# Now we want to get any newly joined users
newly_joined_users = set()
# Now we want to get any newly joined or invited users
newly_joined_or_invited_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
@ -1327,19 +1330,22 @@ class SyncHandler(object):
)
for event in it:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
newly_joined_users.add(event.state_key)
if (
event.membership == Membership.JOIN or
event.membership == Membership.INVITE
):
newly_joined_or_invited_users.add(event.state_key)
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
newly_left_users -= newly_joined_users
newly_left_users -= newly_joined_or_invited_users
defer.returnValue((
newly_joined_rooms,
newly_joined_users,
newly_joined_or_invited_users,
newly_left_rooms,
newly_left_users,
))
@ -1384,7 +1390,7 @@ class SyncHandler(object):
where:
room_entries is a list [RoomSyncResultBuilder]
invited_rooms is a list [InvitedSyncResult]
newly_joined rooms is a list[str] of room ids
newly_joined_rooms is a list[str] of room ids
newly_left_rooms is a list[str] of room ids
"""
user_id = sync_result_builder.sync_config.user.to_string()
@ -1425,7 +1431,7 @@ class SyncHandler(object):
if room_id in sync_result_builder.joined_room_ids and non_joins:
# Always include if the user (re)joined the room, especially
# important so that device list changes are calculated correctly.
# If there are non join member events, but we are still in the room,
# If there are non-join member events, but we are still in the room,
# then the user must have left and joined
newly_joined_rooms.append(room_id)

View file

@ -165,7 +165,8 @@ class BlacklistingAgentWrapper(Agent):
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
"Blocking access to %s because of blacklist" % (ip_address,)
"Blocking access to %s due to blacklist" %
(ip_address,)
)
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
@ -263,9 +264,6 @@ class SimpleHttpClient(object):
uri (str): URI to query.
data (bytes): Data to send in the request body, if applicable.
headers (t.w.http_headers.Headers): Request headers.
Raises:
SynapseError: If the IP is blacklisted.
"""
# A small wrapper around self.agent.request() so we can easily attach
# counters to it

View file

@ -27,9 +27,11 @@ import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
@ -44,6 +46,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable
@ -172,19 +175,44 @@ class MatrixFederationHttpClient(object):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
reactor = hs.get_reactor()
real_reactor = hs.get_reactor()
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
nameResolver = IPBlacklistingResolver(
real_reactor, None, hs.config.federation_ip_range_blacklist,
)
@implementer(IReactorPluggableNameResolver)
class Reactor(object):
def __getattr__(_self, attr):
if attr == "nameResolver":
return nameResolver
else:
return getattr(real_reactor, attr)
self.reactor = Reactor()
self.agent = MatrixFederationAgent(
hs.get_reactor(),
self.reactor,
tls_client_options_factory,
)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent, self.reactor,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
self.version_string_bytes = hs.version_string.encode('ascii')
self.default_timeout = 60
def schedule(x):
reactor.callLater(_EPSILON, x)
self.reactor.callLater(_EPSILON, x)
self._cooperator = Cooperator(scheduler=schedule)
@ -370,7 +398,7 @@ class MatrixFederationHttpClient(object):
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.hs.get_reactor(),
reactor=self.reactor,
)
response = yield request_deferred
@ -397,7 +425,7 @@ class MatrixFederationHttpClient(object):
d = timeout_deferred(
d,
timeout=_sec_timeout,
reactor=self.hs.get_reactor(),
reactor=self.reactor,
)
try:
@ -586,7 +614,7 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
self.hs.get_reactor(), self.default_timeout, request, response,
self.reactor, self.default_timeout, request, response,
)
defer.returnValue(body)
@ -645,7 +673,7 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
body = yield _handle_json_response(
self.hs.get_reactor(), _sec_timeout, request, response,
self.reactor, _sec_timeout, request, response,
)
defer.returnValue(body)
@ -704,7 +732,7 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
self.hs.get_reactor(), self.default_timeout, request, response,
self.reactor, self.default_timeout, request, response,
)
defer.returnValue(body)
@ -753,7 +781,7 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
self.hs.get_reactor(), self.default_timeout, request, response,
self.reactor, self.default_timeout, request, response,
)
defer.returnValue(body)
@ -801,7 +829,7 @@ class MatrixFederationHttpClient(object):
try:
d = _readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.hs.get_reactor())
d.addTimeout(self.default_timeout, self.reactor)
length = yield make_deferred_yieldable(d)
except Exception as e:
logger.warn(

View file

@ -53,7 +53,7 @@ REQUIREMENTS = [
"pyasn1-modules>=0.0.7",
"daemonize>=2.3.1",
"bcrypt>=3.1.0",
"pillow>=3.1.2",
"pillow>=4.3.0",
"sortedcontainers>=1.4.4",
"psutil>=2.0.0",
"pymacaroons>=0.13.0",

View file

@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
from synapse.storage.event_federation import EventFederationWorkerStore
from synapse.storage.event_push_actions import EventPushActionsWorkerStore
from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.relations import RelationsWorkerStore
from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
from synapse.storage.state import StateGroupWorkerStore
@ -52,6 +53,7 @@ class SlavedEventStore(EventFederationWorkerStore,
EventsWorkerStore,
SignatureWorkerStore,
UserErasureWorkerStore,
RelationsWorkerStore,
BaseSlavedStore):
def __init__(self, db_conn, hs):
@ -89,7 +91,7 @@ class SlavedEventStore(EventFederationWorkerStore,
for row in rows:
self.invalidate_caches_for_event(
-token, row.event_id, row.room_id, row.type, row.state_key,
row.redacts,
row.redacts, row.relates_to,
backfilled=True,
)
return super(SlavedEventStore, self).process_replication_rows(
@ -102,7 +104,7 @@ class SlavedEventStore(EventFederationWorkerStore,
if row.type == EventsStreamEventRow.TypeId:
self.invalidate_caches_for_event(
token, data.event_id, data.room_id, data.type, data.state_key,
data.redacts,
data.redacts, data.relates_to,
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
@ -114,7 +116,8 @@ class SlavedEventStore(EventFederationWorkerStore,
raise Exception("Unknown events stream row type %s" % (row.type, ))
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
etype, state_key, redacts, backfilled):
etype, state_key, redacts, relates_to,
backfilled):
self._invalidate_get_event_cache(event_id)
self.get_latest_event_ids_in_room.invalidate((room_id,))
@ -136,3 +139,8 @@ class SlavedEventStore(EventFederationWorkerStore,
state_key, stream_ordering
)
self.get_invited_rooms_for_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))
self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
self.get_applicable_edit.invalidate((relates_to,))

View file

@ -32,6 +32,7 @@ BackfillStreamRow = namedtuple("BackfillStreamRow", (
"type", # str
"state_key", # str, optional
"redacts", # str, optional
"relates_to", # str, optional
))
PresenceStreamRow = namedtuple("PresenceStreamRow", (
"user_id", # str

View file

@ -80,11 +80,12 @@ class BaseEventsStreamRow(object):
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
event_id = attr.ib() # str
room_id = attr.ib() # str
type = attr.ib() # str
state_key = attr.ib() # str, optional
redacts = attr.ib() # str, optional
event_id = attr.ib() # str
room_id = attr.ib() # str
type = attr.ib() # str
state_key = attr.ib() # str, optional
redacts = attr.ib() # str, optional
relates_to = attr.ib() # str, optional
@attr.s(slots=True, frozen=True)

View file

@ -44,6 +44,7 @@ from synapse.rest.client.v2_alpha import (
read_marker,
receipts,
register,
relations,
report_event,
room_keys,
room_upgrade_rest_servlet,
@ -115,6 +116,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
# moving to /_synapse/admin
synapse.rest.admin.register_servlets_for_client_rest_resource(

View file

@ -19,7 +19,7 @@
import logging
import re
from synapse.api.urls import CLIENT_PREFIX
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.servlet import RestServlet
from synapse.rest.client.transactions import HttpTransactionCache
@ -36,12 +36,12 @@ def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True):
Returns:
SRE_Pattern
"""
patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
patterns = [re.compile("^" + CLIENT_API_PREFIX + "/api/v1" + path_regex)]
if include_in_unstable:
unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_PREFIX.replace("/api/v1", "/r%d" % release)
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns

View file

@ -19,7 +19,6 @@ import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
@ -84,6 +83,7 @@ class EventRestServlet(ClientV1RestServlet):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request, event_id):
@ -92,7 +92,8 @@ class EventRestServlet(ClientV1RestServlet):
time_now = self.clock.time_msec()
if event:
defer.returnValue((200, serialize_event(event, time_now)))
event = yield self._event_serializer.serialize_event(event, time_now)
defer.returnValue((200, event))
else:
defer.returnValue((404, "Event not found."))

View file

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
assert_params_in_dict,
parse_integer,
@ -537,6 +537,7 @@ class RoomEventServlet(ClientV1RestServlet):
super(RoomEventServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@ -545,7 +546,8 @@ class RoomEventServlet(ClientV1RestServlet):
time_now = self.clock.time_msec()
if event:
defer.returnValue((200, serialize_event(event, time_now)))
event = yield self._event_serializer.serialize_event(event, time_now)
defer.returnValue((200, event))
else:
defer.returnValue((404, "Event not found."))
@ -559,6 +561,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
super(RoomEventContextServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@ -588,16 +591,18 @@ class RoomEventContextServlet(ClientV1RestServlet):
)
time_now = self.clock.time_msec()
results["events_before"] = [
serialize_event(event, time_now) for event in results["events_before"]
]
results["event"] = serialize_event(results["event"], time_now)
results["events_after"] = [
serialize_event(event, time_now) for event in results["events_after"]
]
results["state"] = [
serialize_event(event, time_now) for event in results["state"]
]
results["events_before"] = yield self._event_serializer.serialize_events(
results["events_before"], time_now,
)
results["event"] = yield self._event_serializer.serialize_event(
results["event"], time_now,
)
results["events_after"] = yield self._event_serializer.serialize_events(
results["events_after"], time_now,
)
results["state"] = yield self._event_serializer.serialize_events(
results["state"], time_now,
)
defer.returnValue((200, results))

View file

@ -21,13 +21,12 @@ import re
from twisted.internet import defer
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.api.urls import CLIENT_API_PREFIX
logger = logging.getLogger(__name__)
def client_v2_patterns(path_regex, releases=(0,),
v2_alpha=True,
unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
@ -39,13 +38,11 @@ def client_v2_patterns(path_regex, releases=(0,),
SRE_Pattern
"""
patterns = []
if v2_alpha:
patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
if unstable:
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns

View file

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string
@ -139,8 +139,8 @@ class AuthRestServlet(RestServlet):
if stagetype == LoginType.RECAPTCHA:
html = RECAPTCHA_TEMPLATE % {
'session': session,
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
'myurl': "%s/r0/auth/%s/fallback/web" % (
CLIENT_API_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
@ -159,8 +159,8 @@ class AuthRestServlet(RestServlet):
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS
'myurl': "%s/r0/auth/%s/fallback/web" % (
CLIENT_API_PREFIX, LoginType.TERMS
),
}
html_bytes = html.encode("utf8")
@ -203,8 +203,8 @@ class AuthRestServlet(RestServlet):
else:
html = RECAPTCHA_TEMPLATE % {
'session': session,
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
'myurl': "%s/r0/auth/%s/fallback/web" % (
CLIENT_API_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
@ -240,8 +240,8 @@ class AuthRestServlet(RestServlet):
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS
'myurl': "%s/r0/auth/%s/fallback/web" % (
CLIENT_API_PREFIX, LoginType.TERMS
),
}
html_bytes = html.encode("utf8")

View file

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
PATTERNS = client_v2_patterns("/devices$")
def __init__(self, hs):
"""
@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
PATTERNS = client_v2_patterns("/delete_devices")
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs):
"""

View file

@ -17,10 +17,7 @@ import logging
from twisted.internet import defer
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
serialize_event,
)
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from ._base import client_v2_patterns
@ -36,6 +33,7 @@ class NotificationsServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request):
@ -69,11 +67,11 @@ class NotificationsServlet(RestServlet):
"profile_tag": pa["profile_tag"],
"actions": pa["actions"],
"ts": pa["received_ts"],
"event": serialize_event(
"event": (yield self._event_serializer.serialize_event(
notif_events[pa["event_id"]],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
),
)),
}
if pa["room_id"] not in receipts_by_room:

View file

@ -31,6 +31,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import (
RestServlet,
@ -153,16 +154,18 @@ class UsernameAvailabilityRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
hs.get_clock(),
# Time window of 2s
window_size=2000,
# Artificially delay requests if rate > sleep_limit/window_size
sleep_limit=1,
# Amount of artificial delay to apply
sleep_msec=1000,
# Error with 429 if more than reject_limit requests are queued
reject_limit=1,
# Allow 1 request at a time
concurrent_requests=1,
FederationRateLimitConfig(
# Time window of 2s
window_size=2000,
# Artificially delay requests if rate > sleep_limit/window_size
sleep_limit=1,
# Amount of artificial delay to apply
sleep_msec=1000,
# Error with 429 if more than reject_limit requests are queued
reject_limit=1,
# Allow 1 request at a time
concurrent_requests=1,
)
)
@defer.inlineCallbacks
@ -345,18 +348,22 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.enable_registration_captcha:
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
flows.extend([[LoginType.RECAPTCHA]])
# Also add a dummy flow here, otherwise if a client completes
# recaptcha first we'll assume they were going for this flow
# and complete the request, when they could have been trying to
# complete one of the flows with email/msisdn auth.
flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]])
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]])
if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email:
flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
# always let users provide both MSISDN & email
flows.extend([
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
])
else:
# only support 3PIDless registration if no 3PIDs are required
@ -379,7 +386,15 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.user_consent_at_registration:
new_flows = []
for flow in flows:
flow.append(LoginType.TERMS)
inserted = False
# m.login.terms should go near the end but before msisdn or email auth
for i, stage in enumerate(flow):
if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN:
flow.insert(i, LoginType.TERMS)
inserted = True
break
if not inserted:
flow.append(LoginType.TERMS)
flows.extend(new_flows)
auth_result, params, session_id = yield self.auth_handler.check_auth(
@ -391,13 +406,6 @@ class RegisterRestServlet(RestServlet):
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
#
# Also check that we're not trying to register a 3pid that's already
# been registered.
#
# This has probably happened in /register/email/requestToken as well,
# but if a user hits this endpoint twice then clicks on each link from
# the two activation emails, they would register the same 3pid twice.
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
@ -413,17 +421,6 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
existingUid = yield self.store.get_user_id_by_threepid(
medium, address,
)
if existingUid is not None:
raise SynapseError(
400,
"%s is already in use" % medium,
Codes.THREEPID_IN_USE,
)
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
@ -446,6 +443,28 @@ class RegisterRestServlet(RestServlet):
if auth_result:
threepid = auth_result.get(LoginType.EMAIL_IDENTITY)
# Also check that we're not trying to register a 3pid that's already
# been registered.
#
# This has probably happened in /register/email/requestToken as well,
# but if a user hits this endpoint twice then clicks on each link from
# the two activation emails, they would register the same 3pid twice.
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
medium = auth_result[login_type]['medium']
address = auth_result[login_type]['address']
existingUid = yield self.store.get_user_id_by_threepid(
medium, address,
)
if existingUid is not None:
raise SynapseError(
400,
"%s is already in use" % medium,
Codes.THREEPID_IN_USE,
)
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,

View file

@ -0,0 +1,338 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This class implements the proposed relation APIs from MSC 1849.
Since the MSC has not been approved all APIs here are unstable and may change at
any time to reflect changes in the MSC.
"""
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class RelationSendServlet(RestServlet):
"""Helper API for sending events that have relation data.
Example API shape to send a 👍 reaction to a room:
POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D
{}
{
"event_id": "$foobar"
}
"""
PATTERN = (
"/rooms/(?P<room_id>[^/]*)/send_relation"
"/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
)
def __init__(self, hs):
super(RelationSendServlet, self).__init__()
self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs)
def register(self, http_server):
http_server.register_paths(
"POST",
client_v2_patterns(self.PATTERN + "$", releases=()),
self.on_PUT_or_POST,
)
http_server.register_paths(
"PUT",
client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
self.on_PUT,
)
def on_PUT(self, request, *args, **kwargs):
return self.txns.fetch_or_execute_request(
request, self.on_PUT_or_POST, request, *args, **kwargs
)
@defer.inlineCallbacks
def on_PUT_or_POST(
self, request, room_id, parent_id, relation_type, event_type, txn_id=None
):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member:
# Add relations to a membership is meaningless, so we just deny it
# at the CS API rather than trying to handle it correctly.
raise SynapseError(400, "Cannot send member events with relations")
content = parse_json_object_from_request(request)
aggregation_key = parse_string(request, "key", encoding="utf-8")
content["m.relates_to"] = {
"event_id": parent_id,
"key": aggregation_key,
"rel_type": relation_type,
}
event_dict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
}
event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
defer.returnValue((200, {"event_id": event.event_id}))
class RelationPaginationServlet(RestServlet):
"""API to paginate relations on an event by topological ordering, optionally
filtered by relation type and event type.
"""
PATTERNS = client_v2_patterns(
"/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(),
)
def __init__(self, hs):
super(RelationPaginationServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string()
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from")
to_token = parse_string(request, "to")
if from_token:
from_token = RelationPaginationToken.from_string(from_token)
if to_token:
to_token = RelationPaginationToken.from_string(to_token)
result = yield self.store.get_relations_for_event(
event_id=parent_id,
relation_type=relation_type,
event_type=event_type,
limit=limit,
from_token=from_token,
to_token=to_token,
)
events = yield self.store.get_events_as_list(
[c["event_id"] for c in result.chunk]
)
now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(events, now)
return_value = result.to_dict()
return_value["chunk"] = events
defer.returnValue((200, return_value))
class RelationAggregationPaginationServlet(RestServlet):
"""API to paginate aggregation groups of relations, e.g. paginate the
types and counts of the reactions on the events.
Example request and response:
GET /rooms/{room_id}/aggregations/{parent_id}
{
chunk: [
{
"type": "m.reaction",
"key": "👍",
"count": 3
}
]
}
"""
PATTERNS = client_v2_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(),
)
def __init__(self, hs):
super(RelationAggregationPaginationServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string()
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from")
to_token = parse_string(request, "to")
if from_token:
from_token = AggregationPaginationToken.from_string(from_token)
if to_token:
to_token = AggregationPaginationToken.from_string(to_token)
res = yield self.store.get_aggregation_groups_for_event(
event_id=parent_id,
event_type=event_type,
limit=limit,
from_token=from_token,
to_token=to_token,
)
defer.returnValue((200, res.to_dict()))
class RelationAggregationGroupPaginationServlet(RestServlet):
"""API to paginate within an aggregation group of relations, e.g. paginate
all the 👍 reactions on an event.
Example request and response:
GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍
{
chunk: [
{
"type": "m.reaction",
"content": {
"m.relates_to": {
"rel_type": "m.annotation",
"key": "👍"
}
}
},
...
]
}
"""
PATTERNS = client_v2_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
releases=(),
)
def __init__(self, hs):
super(RelationAggregationGroupPaginationServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string()
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from")
to_token = parse_string(request, "to")
if from_token:
from_token = RelationPaginationToken.from_string(from_token)
if to_token:
to_token = RelationPaginationToken.from_string(to_token)
result = yield self.store.get_relations_for_event(
event_id=parent_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=key,
limit=limit,
from_token=from_token,
to_token=to_token,
)
events = yield self.store.get_events_as_list(
[c["event_id"] for c in result.chunk]
)
now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(events, now)
return_value = result.to_dict()
return_value["chunk"] = events
defer.returnValue((200, return_value))
def register_servlets(hs, http_server):
RelationSendServlet(hs).register(http_server)
RelationPaginationServlet(hs).register(http_server)
RelationAggregationPaginationServlet(hs).register(http_server)
RelationAggregationGroupPaginationServlet(hs).register(http_server)

View file

@ -50,7 +50,6 @@ class RoomUpgradeRestServlet(RestServlet):
PATTERNS = client_v2_patterns(
# /rooms/$roomid/upgrade
"/rooms/(?P<room_id>[^/]*)/upgrade$",
v2_alpha=False,
)
def __init__(self, hs):

View file

@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
v2_alpha=False
)
def __init__(self, hs):

View file

@ -26,7 +26,6 @@ from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
serialize_event,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
@ -86,6 +85,7 @@ class SyncRestServlet(RestServlet):
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_presence_handler()
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request):
@ -168,14 +168,14 @@ class SyncRestServlet(RestServlet):
)
time_now = self.clock.time_msec()
response_content = self.encode_response(
response_content = yield self.encode_response(
time_now, sync_result, requester.access_token_id, filter
)
defer.returnValue((200, response_content))
@staticmethod
def encode_response(time_now, sync_result, access_token_id, filter):
@defer.inlineCallbacks
def encode_response(self, time_now, sync_result, access_token_id, filter):
if filter.event_format == 'client':
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == 'federation':
@ -183,24 +183,24 @@ class SyncRestServlet(RestServlet):
else:
raise Exception("Unknown event format %s" % (filter.event_format, ))
joined = SyncRestServlet.encode_joined(
joined = yield self.encode_joined(
sync_result.joined, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
invited = SyncRestServlet.encode_invited(
invited = yield self.encode_invited(
sync_result.invited, time_now, access_token_id,
event_formatter,
)
archived = SyncRestServlet.encode_archived(
archived = yield self.encode_archived(
sync_result.archived, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
return {
defer.returnValue({
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"device_lists": {
@ -222,7 +222,7 @@ class SyncRestServlet(RestServlet):
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
}
})
@staticmethod
def encode_presence(events, time_now):
@ -239,8 +239,8 @@ class SyncRestServlet(RestServlet):
]
}
@staticmethod
def encode_joined(rooms, time_now, token_id, event_fields, event_formatter):
@defer.inlineCallbacks
def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the joined rooms in a sync result
@ -261,15 +261,15 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room(
joined[room.room_id] = yield self.encode_room(
room, time_now, token_id, joined=True, only_fields=event_fields,
event_formatter=event_formatter,
)
return joined
defer.returnValue(joined)
@staticmethod
def encode_invited(rooms, time_now, token_id, event_formatter):
@defer.inlineCallbacks
def encode_invited(self, rooms, time_now, token_id, event_formatter):
"""
Encode the invited rooms in a sync result
@ -289,7 +289,7 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
invite = serialize_event(
invite = yield self._event_serializer.serialize_event(
room.invite, time_now, token_id=token_id,
event_format=event_formatter,
is_invite=True,
@ -302,10 +302,10 @@ class SyncRestServlet(RestServlet):
"invite_state": {"events": invited_state}
}
return invited
defer.returnValue(invited)
@staticmethod
def encode_archived(rooms, time_now, token_id, event_fields, event_formatter):
@defer.inlineCallbacks
def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the archived rooms in a sync result
@ -326,17 +326,17 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room(
joined[room.room_id] = yield self.encode_room(
room, time_now, token_id, joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
)
return joined
defer.returnValue(joined)
@staticmethod
@defer.inlineCallbacks
def encode_room(
room, time_now, token_id, joined,
self, room, time_now, token_id, joined,
only_fields, event_formatter,
):
"""
@ -355,9 +355,10 @@ class SyncRestServlet(RestServlet):
Returns:
dict[str, object]: the room, encoded in our response format
"""
def serialize(event):
return serialize_event(
event, time_now, token_id=token_id,
def serialize(events):
return self._event_serializer.serialize_events(
events, time_now=time_now,
token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,
)
@ -376,8 +377,8 @@ class SyncRestServlet(RestServlet):
event.event_id, room.room_id, event.room_id,
)
serialized_state = [serialize(e) for e in state_events]
serialized_timeline = [serialize(e) for e in timeline_events]
serialized_state = yield serialize(state_events)
serialized_timeline = yield serialize(timeline_events)
account_data = room.account_data
@ -397,7 +398,7 @@ class SyncRestServlet(RestServlet):
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
return result
defer.returnValue(result)
def register_servlets(hs, http_server):

View file

@ -444,6 +444,9 @@ class MediaRepository(object):
)
return
if thumbnailer.transpose_method is not None:
m_width, m_height = thumbnailer.transpose()
if t_method == "crop":
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
@ -578,6 +581,12 @@ class MediaRepository(object):
)
return
if thumbnailer.transpose_method is not None:
m_width, m_height = yield logcontext.defer_to_thread(
self.hs.get_reactor(),
thumbnailer.transpose
)
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails = {}

View file

@ -108,6 +108,7 @@ class FileStorageProviderBackend(StorageProvider):
"""
def __init__(self, hs, config):
self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config

View file

@ -20,6 +20,17 @@ import PIL.Image as Image
logger = logging.getLogger(__name__)
EXIF_ORIENTATION_TAG = 0x0112
EXIF_TRANSPOSE_MAPPINGS = {
2: Image.FLIP_LEFT_RIGHT,
3: Image.ROTATE_180,
4: Image.FLIP_TOP_BOTTOM,
5: Image.TRANSPOSE,
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
8: Image.ROTATE_90
}
class Thumbnailer(object):
@ -31,6 +42,30 @@ class Thumbnailer(object):
def __init__(self, input_path):
self.image = Image.open(input_path)
self.width, self.height = self.image.size
self.transpose_method = None
try:
# We don't use ImageOps.exif_transpose since it crashes with big EXIF
image_exif = self.image._getexif()
if image_exif is not None:
image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
except Exception as e:
# A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e)
def transpose(self):
"""Transpose the image using its EXIF Orientation tag
Returns:
Tuple[int, int]: (width, height) containing the new image size in pixels.
"""
if self.transpose_method is not None:
self.image = self.image.transpose(self.transpose_method)
self.width, self.height = self.image.size
self.transpose_method = None
# We don't need EXIF any more
self.image.info["exif"] = None
return self.image.size
def aspect(self, max_width, max_height):
"""Calculate the largest size that preserves aspect ratio which

View file

@ -35,6 +35,7 @@ from synapse.crypto import context_factory
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker
from synapse.events.utils import EventClientSerializer
from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import (
FederationHandlerRegistry,
@ -185,6 +186,7 @@ class HomeServer(object):
'sendmail',
'registration_handler',
'account_validity_handler',
'event_client_serializer',
]
REQUIRED_ON_MASTER_STARTUP = [
@ -511,6 +513,9 @@ class HomeServer(object):
def build_account_validity_handler(self):
return AccountValidityHandler(self)
def build_event_client_serializer(self):
return EventClientSerializer(self)
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View file

@ -49,6 +49,7 @@ from .pusher import PusherStore
from .receipts import ReceiptsStore
from .registration import RegistrationStore
from .rejections import RejectionsStore
from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
@ -99,6 +100,7 @@ class DataStore(
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
RelationsStore,
):
def __init__(self, db_conn, hs):
self.hs = hs

View file

@ -302,7 +302,7 @@ class ApplicationServiceTransactionWorkerStore(
event_ids = json.loads(entry["event_ids"])
events = yield self._get_events(event_ids)
events = yield self.get_events_as_list(event_ids)
defer.returnValue(
AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
events = yield self._get_events(event_ids)
events = yield self.get_events_as_list(event_ids)
defer.returnValue((upper_bound, events))

View file

@ -45,7 +45,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
return self.get_auth_chain_ids(
event_ids, include_given=include_given
).addCallback(self._get_events)
).addCallback(self.get_events_as_list)
def get_auth_chain_ids(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
@ -316,7 +316,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_list,
limit,
)
.addCallback(self._get_events)
.addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
)
@ -382,7 +382,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
events = yield self._get_events(ids)
events = yield self.get_events_as_list(ids)
defer.returnValue(events)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):

View file

@ -1325,6 +1325,9 @@ class EventsStore(
txn, event.room_id, event.redacts
)
# Remove from relations table.
self._handle_redaction(txn, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
@ -1351,6 +1354,8 @@ class EventsStore(
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
self._handle_event_relations(txn, event)
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@ -1655,10 +1660,11 @@ class EventsStore(
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@ -1673,11 +1679,12 @@ class EventsStore(
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
@ -1698,10 +1705,11 @@ class EventsStore(
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@ -1716,11 +1724,12 @@ class EventsStore(
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"

View file

@ -103,7 +103,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred : A FrozenEvent.
"""
events = yield self._get_events(
events = yield self.get_events_as_list(
[event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@ -142,7 +142,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self._get_events(
events = yield self.get_events_as_list(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@ -152,13 +152,32 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks
def _get_events(
def get_events_as_list(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
):
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
events are in the same order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
"""
if not event_ids:
defer.returnValue([])
@ -202,21 +221,22 @@ class EventsWorkerStore(SQLBaseStore):
#
# The problem is that we end up at this point when an event
# which has been redacted is pulled out of the database by
# _enqueue_events, because _enqueue_events needs to check the
# redaction before it can cache the redacted event. So obviously,
# calling get_event to get the redacted event out of the database
# gives us an infinite loop.
# _enqueue_events, because _enqueue_events needs to check
# the redaction before it can cache the redacted event. So
# obviously, calling get_event to get the redacted event out
# of the database gives us an infinite loop.
#
# For now (quick hack to fix during 0.99 release cycle), we just
# go and fetch the relevant row from the db, but it would be nice
# to think about how we can cache this rather than hit the db
# every time we access a redaction event.
# For now (quick hack to fix during 0.99 release cycle), we
# just go and fetch the relevant row from the db, but it
# would be nice to think about how we can cache this rather
# than hit the db every time we access a redaction event.
#
# One thought on how to do this:
# 1. split _get_events up so that it is divided into (a) get the
# rawish event from the db/cache, (b) do the redaction/rejection
# filtering
# 2. have _get_event_from_row just call the first half of that
# 1. split get_events_as_list up so that it is divided into
# (a) get the rawish event from the db/cache, (b) do the
# redaction/rejection filtering
# 2. have _get_event_from_row just call the first half of
# that
orig_sender = yield self._simple_select_one_onecol(
table="events",

View file

@ -0,0 +1,434 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import attr
from twisted.internet import defer
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.storage.stream import generate_pagination_where_clause
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
logger = logging.getLogger(__name__)
@attr.s
class PaginationChunk(object):
"""Returned by relation pagination APIs.
Attributes:
chunk (list): The rows returned by pagination
next_batch (Any|None): Token to fetch next set of results with, if
None then there are no more results.
prev_batch (Any|None): Token to fetch previous set of results with, if
None then there are no previous results.
"""
chunk = attr.ib()
next_batch = attr.ib(default=None)
prev_batch = attr.ib(default=None)
def to_dict(self):
d = {"chunk": self.chunk}
if self.next_batch:
d["next_batch"] = self.next_batch.to_string()
if self.prev_batch:
d["prev_batch"] = self.prev_batch.to_string()
return d
@attr.s(frozen=True, slots=True)
class RelationPaginationToken(object):
"""Pagination token for relation pagination API.
As the results are order by topological ordering, we can use the
`topological_ordering` and `stream_ordering` fields of the events at the
boundaries of the chunk as pagination tokens.
Attributes:
topological (int): The topological ordering of the boundary event
stream (int): The stream ordering of the boundary event.
"""
topological = attr.ib()
stream = attr.ib()
@staticmethod
def from_string(string):
try:
t, s = string.split("-")
return RelationPaginationToken(int(t), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
def to_string(self):
return "%d-%d" % (self.topological, self.stream)
def as_tuple(self):
return attr.astuple(self)
@attr.s(frozen=True, slots=True)
class AggregationPaginationToken(object):
"""Pagination token for relation aggregation pagination API.
As the results are order by count and then MAX(stream_ordering) of the
aggregation groups, we can just use them as our pagination token.
Attributes:
count (int): The count of relations in the boundar group.
stream (int): The MAX stream ordering in the boundary group.
"""
count = attr.ib()
stream = attr.ib()
@staticmethod
def from_string(string):
try:
c, s = string.split("-")
return AggregationPaginationToken(int(c), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
def to_string(self):
return "%d-%d" % (self.count, self.stream)
def as_tuple(self):
return attr.astuple(self)
class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
def get_relations_for_event(
self,
event_id,
relation_type=None,
event_type=None,
aggregation_key=None,
limit=5,
direction="b",
from_token=None,
to_token=None,
):
"""Get a list of relations for an event, ordered by topological ordering.
Args:
event_id (str): Fetch events that relate to this event ID.
relation_type (str|None): Only fetch events with this relation
type, if given.
event_type (str|None): Only fetch events with this event type, if
given.
aggregation_key (str|None): Only fetch events with this aggregation
key, if given.
limit (int): Only fetch the most recent `limit` events.
direction (str): Whether to fetch the most recent first (`"b"`) or
the oldest first (`"f"`).
from_token (RelationPaginationToken|None): Fetch rows from the given
token, or from the start if None.
to_token (RelationPaginationToken|None): Fetch rows up to the given
token, or up to the end if None.
Returns:
Deferred[PaginationChunk]: List of event IDs that match relations
requested. The rows are of the form `{"event_id": "..."}`.
"""
where_clause = ["relates_to_id = ?"]
where_args = [event_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
where_args.append(relation_type)
if event_type is not None:
where_clause.append("type = ?")
where_args.append(event_type)
if aggregation_key:
where_clause.append("aggregation_key = ?")
where_args.append(aggregation_key)
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=attr.astuple(from_token) if from_token else None,
to_token=attr.astuple(to_token) if to_token else None,
engine=self.database_engine,
)
if pagination_clause:
where_clause.append(pagination_clause)
if direction == "b":
order = "DESC"
else:
order = "ASC"
sql = """
SELECT event_id, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?
""" % (
" AND ".join(where_clause),
order,
order,
)
def _get_recent_references_for_event_txn(txn):
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
last_stream_id = None
events = []
for row in txn:
events.append({"event_id": row[0]})
last_topo_id = row[1]
last_stream_id = row[2]
next_batch = None
if len(events) > limit and last_topo_id and last_stream_id:
next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
return PaginationChunk(
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
return self.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
@cached(tree=True)
def get_aggregation_groups_for_event(
self,
event_id,
event_type=None,
limit=5,
direction="b",
from_token=None,
to_token=None,
):
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
This is used e.g. to get the what and how many reactions have happend
on an event.
Args:
event_id (str): Fetch events that relate to this event ID.
event_type (str|None): Only fetch events with this event type, if
given.
limit (int): Only fetch the `limit` groups.
direction (str): Whether to fetch the highest count first (`"b"`) or
the lowest count first (`"f"`).
from_token (AggregationPaginationToken|None): Fetch rows from the
given token, or from the start if None.
to_token (AggregationPaginationToken|None): Fetch rows up to the
given token, or up to the end if None.
Returns:
Deferred[PaginationChunk]: List of groups of annotations that
match. Each row is a dict with `type`, `key` and `count` fields.
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
where_args = [event_id, RelationTypes.ANNOTATION]
if event_type:
where_clause.append("type = ?")
where_args.append(event_type)
having_clause = generate_pagination_where_clause(
direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"),
from_token=attr.astuple(from_token) if from_token else None,
to_token=attr.astuple(to_token) if to_token else None,
engine=self.database_engine,
)
if direction == "b":
order = "DESC"
else:
order = "ASC"
if having_clause:
having_clause = "HAVING " + having_clause
else:
having_clause = ""
sql = """
SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE {where_clause}
GROUP BY relation_type, type, aggregation_key
{having_clause}
ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
LIMIT ?
""".format(
where_clause=" AND ".join(where_clause),
order=order,
having_clause=having_clause,
)
def _get_aggregation_groups_for_event_txn(txn):
txn.execute(sql, where_args + [limit + 1])
next_batch = None
events = []
for row in txn:
events.append({"type": row[0], "key": row[1], "count": row[2]})
next_batch = AggregationPaginationToken(row[2], row[3])
if len(events) <= limit:
next_batch = None
return PaginationChunk(
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
return self.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
@cachedInlineCallbacks()
def get_applicable_edit(self, event_id):
"""Get the most recent edit (if any) that has happened for the given
event.
Correctly handles checking whether edits were allowed to happen.
Args:
event_id (str): The original event ID
Returns:
Deferred[EventBase|None]: Returns the most recent edit, if any.
"""
# We only allow edits for `m.room.message` events that have the same sender
# and event type. We can't assert these things during regular event auth so
# we have to do the checks post hoc.
# Fetches latest edit that has the same type and sender as the
# original, and is an `m.room.message`.
sql = """
SELECT edit.event_id FROM events AS edit
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS original ON
original.event_id = relates_to_id
AND edit.type = original.type
AND edit.sender = original.sender
WHERE
relates_to_id = ?
AND relation_type = ?
AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1
"""
def _get_applicable_edit_txn(txn):
txn.execute(
sql, (event_id, RelationTypes.REPLACE,)
)
row = txn.fetchone()
if row:
return row[0]
edit_id = yield self.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
)
if not edit_id:
return
edit_event = yield self.get_event(edit_id, allow_none=True)
defer.returnValue(edit_event)
class RelationsStore(RelationsWorkerStore):
def _handle_event_relations(self, txn, event):
"""Handles inserting relation data during peristence of events
Args:
txn
event (EventBase)
"""
relation = event.content.get("m.relates_to")
if not relation:
# No relations
return
rel_type = relation.get("rel_type")
if rel_type not in (
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.REPLACE,
):
# Unknown relation type
return
parent_id = relation.get("event_id")
if not parent_id:
# Invalid relation
return
aggregation_key = relation.get("key")
self._simple_insert_txn(
txn,
table="event_relations",
values={
"event_id": event.event_id,
"relates_to_id": parent_id,
"relation_type": rel_type,
"aggregation_key": aggregation_key,
},
)
txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
txn.call_after(
self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
)
if rel_type == RelationTypes.REPLACE:
txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
def _handle_redaction(self, txn, redacted_event_id):
"""Handles receiving a redaction and checking whether we need to remove
any redacted relations from the database.
Args:
txn
redacted_event_id (str): The event that was redacted.
"""
self._simple_delete_txn(
txn,
table="event_relations",
keyvalues={
"event_id": redacted_event_id,
}
)

View file

@ -0,0 +1,27 @@
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Tracks related events, like reactions, replies, edits, etc. Note that things
-- in this table are not necessarily "valid", e.g. it may contain edits from
-- people who don't have power to edit other peoples events.
CREATE TABLE IF NOT EXISTS event_relations (
event_id TEXT NOT NULL,
relates_to_id TEXT NOT NULL,
relation_type TEXT NOT NULL,
aggregation_key TEXT
);
CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id);
CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key);

View file

@ -460,7 +460,7 @@ class SearchStore(BackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
events = yield self.get_events_as_list([r["event_id"] for r in results])
event_map = {ev.event_id: ev for ev in events}
@ -605,7 +605,7 @@ class SearchStore(BackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
events = yield self.get_events_as_list([r["event_id"] for r in results])
event_map = {ev.event_id: ev for ev in events}

View file

@ -64,59 +64,135 @@ _EventDictReturn = namedtuple(
)
def lower_bound(token, engine, inclusive=False):
inclusive = "=" if inclusive else ""
if token.topological is None:
return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) <%s (%s,%s))" % (
token.topological,
token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
def generate_pagination_where_clause(
direction, column_names, from_token, to_token, engine,
):
"""Creates an SQL expression to bound the columns by the pagination
tokens.
For example creates an SQL expression like:
(6, 7) >= (topological_ordering, stream_ordering)
AND (5, 3) < (topological_ordering, stream_ordering)
would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3).
Note that tokens are considered to be after the row they are in, e.g. if
a row A has a token T, then we consider A to be before T. This convention
is important when figuring out inequalities for the generated SQL, and
produces the following result:
- If paginating forwards then we exclude any rows matching the from
token, but include those that match the to token.
- If paginating backwards then we include any rows matching the from
token, but include those that match the to token.
Args:
direction (str): Whether we're paginating backwards("b") or
forwards ("f").
column_names (tuple[str, str]): The column names to bound. Must *not*
be user defined as these get inserted directly into the SQL
statement without escapes.
from_token (tuple[int, int]|None): The start point for the pagination.
This is an exclusive minimum bound if direction is "f", and an
inclusive maximum bound if direction is "b".
to_token (tuple[int, int]|None): The endpoint point for the pagination.
This is an inclusive maximum bound if direction is "f", and an
exclusive minimum bound if direction is "b".
engine: The database engine to generate the clauses for
Returns:
str: The sql expression
"""
assert direction in ("b", "f")
where_clause = []
if from_token:
where_clause.append(
_make_generic_sql_bound(
bound=">=" if direction == "b" else "<",
column_names=column_names,
values=from_token,
engine=engine,
)
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
)
def upper_bound(token, engine, inclusive=True):
inclusive = "=" if inclusive else ""
if token.topological is None:
return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) >%s (%s,%s))" % (
token.topological,
token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
if to_token:
where_clause.append(
_make_generic_sql_bound(
bound="<" if direction == "b" else ">=",
column_names=column_names,
values=to_token,
engine=engine,
)
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
)
return " AND ".join(where_clause)
def _make_generic_sql_bound(bound, column_names, values, engine):
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
Only works with two columns.
Older versions of SQLite don't support that syntax so we have to expand it
out manually.
Args:
bound (str): The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
names (tuple[str, str]): The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
values (tuple[int|None, int]): The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
str
"""
assert(bound in (">", "<", ">=", "<="))
name1, name2 = column_names
val1, val2 = values
if val1 is None:
val2 = int(val2)
return "(%d %s %s)" % (val2, bound, name2)
val1 = int(val1)
val2 = int(val2)
if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) %s (%s,%s))" % (
val1, val2,
bound,
name1, name2,
)
# We want to generate queries of e.g. the form:
#
# (val1 < name1 OR (val1 = name1 AND val2 <= name2))
#
# which is equivalent to (val1, val2) < (name1, name2)
return """(
{val1:d} {strict_bound} {name1}
OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
)""".format(
name1=name1,
val1=val1,
name2=name2,
val2=val2,
strict_bound=bound[0], # The first bound must always be strict equality here
bound=bound,
)
def filter_to_clause(event_filter):
# NB: This may create SQL clauses that don't optimise well (and we don't
@ -319,7 +395,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
ret = yield self.get_events_as_list([
r.event_id for r in rows], get_prev_content=True,
)
self._set_before_and_after(ret, rows, topo_order=from_id is None)
@ -367,7 +445,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True,
)
self._set_before_and_after(ret, rows, topo_order=False)
@ -394,7 +474,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
logger.debug("stream before")
events = yield self._get_events(
events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
logger.debug("stream after")
@ -580,11 +660,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
events_before = yield self._get_events(
events_before = yield self.get_events_as_list(
[e for e in results["before"]["event_ids"]], get_prev_content=True
)
events_after = yield self._get_events(
events_after = yield self.get_events_as_list(
[e for e in results["after"]["event_ids"]], get_prev_content=True
)
@ -697,7 +777,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_all_new_events_stream", get_all_new_events_stream_txn
)
events = yield self._get_events(event_ids)
events = yield self.get_events_as_list(event_ids)
defer.returnValue((upper_bound, events))
@ -758,20 +838,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
bounds = upper_bound(from_token, self.database_engine)
if to_token:
bounds = "%s AND %s" % (
bounds,
lower_bound(to_token, self.database_engine),
)
else:
order = "ASC"
bounds = lower_bound(from_token, self.database_engine)
if to_token:
bounds = "%s AND %s" % (
bounds,
upper_bound(to_token, self.database_engine),
)
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=from_token,
to_token=to_token,
engine=self.database_engine,
)
filter_clause, filter_args = filter_to_clause(event_filter)
@ -849,7 +925,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
events = yield self._get_events(
events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)

View file

@ -156,6 +156,25 @@ def concurrently_execute(func, args, limit):
], consumeErrors=True)).addErrback(unwrapFirstError)
def yieldable_gather_results(func, iter, *args, **kwargs):
"""Executes the function with each argument concurrently.
Args:
func (func): Function to execute that returns a Deferred
iter (iter): An iterable that yields items that get passed as the first
argument to the function
*args: Arguments to be passed to each call to func
Returns
Deferred[list]: Resolved when all functions have been invoked, or errors if
one of the function calls fails.
"""
return logcontext.make_deferred_yieldable(defer.gatherResults([
run_in_background(func, item, *args, **kwargs)
for item in iter
], consumeErrors=True)).addErrback(unwrapFirstError)
class Linearizer(object):
"""Limits concurrent access to resources based on a key. Useful to ensure
only a few things happen at a time on a given resource.

View file

@ -27,6 +27,7 @@ def user_left_room(distributor, user, room_id):
distributor.fire("user_left_room", user=user, room_id=room_id)
# XXX: this is no longer used. We should probably kill it.
def user_joined_room(distributor, user, room_id):
distributor.fire("user_joined_room", user=user, room_id=room_id)

View file

@ -30,31 +30,14 @@ logger = logging.getLogger(__name__)
class FederationRateLimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
def __init__(self, clock, config):
"""
Args:
clock (Clock)
window_size (int): The window size in milliseconds.
sleep_limit (int): The number of requests received in the last
`window_size` milliseconds before we artificially start
delaying processing of requests.
sleep_msec (int): The number of milliseconds to delay processing
of incoming requests by.
reject_limit (int): The maximum number of requests that are can be
queued for processing before we start rejecting requests with
a 429 Too Many Requests response.
concurrent_requests (int): The number of concurrent requests to
process.
config (FederationRateLimitConfig)
"""
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self._config = config
self.ratelimiters = {}
def ratelimit(self, host):
@ -76,25 +59,25 @@ class FederationRateLimiter(object):
host,
_PerHostRatelimiter(
clock=self.clock,
window_size=self.window_size,
sleep_limit=self.sleep_limit,
sleep_msec=self.sleep_msec,
reject_limit=self.reject_limit,
concurrent_requests=self.concurrent_requests,
config=self._config,
)
).ratelimit()
class _PerHostRatelimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
def __init__(self, clock, config):
"""
Args:
clock (Clock)
config (FederationRateLimitConfig)
"""
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_sec = sleep_msec / 1000.0
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.window_size = config.window_size
self.sleep_limit = config.sleep_limit
self.sleep_sec = config.sleep_delay / 1000.0
self.reject_limit = config.reject_limit
self.concurrent_requests = config.concurrent
# request_id objects for requests which have been slept
self.sleeping_requests = set()

View file

@ -37,8 +37,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config = self.default_config("test")
# some of the tests rely on us having a user consent version
hs_config.user_consent_version = "test_consent_version"
hs_config.max_mau_value = 50
hs_config["user_consent"] = {
"version": "test_consent_version",
"template_dir": ".",
}
hs_config["max_mau_value"] = 50
hs_config["limit_usage_by_mau"] = True
hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
return hs
@ -224,3 +228,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_not_support_user(self):
res = self.get_success(self.handler.register(localpart='user'))
self.assertFalse(self.store.is_support_user(res[0]))
def test_invalid_user_id_length(self):
invalid_user_id = "x" * 256
self.get_failure(
self.handler.register(localpart=invalid_user_id),
SynapseError
)

View file

@ -37,7 +37,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config.update_user_directory = True
config["update_user_directory"] = True
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
@ -333,7 +333,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config.update_user_directory = True
config["update_user_directory"] = True
hs = self.setup_test_homeserver(config=config)
self.config = hs.config

View file

@ -54,7 +54,9 @@ class MatrixFederationAgentTests(TestCase):
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(default_config("test")),
tls_client_options_factory=ClientTLSOptionsFactory(
default_config("test", parse=True)
),
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver,
_well_known_cache=self.well_known_cache,

View file

@ -15,6 +15,8 @@
from mock import Mock
from netaddr import IPSet
from twisted.internet import defer
from twisted.internet.defer import TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
@ -209,6 +211,75 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
def test_client_ip_range_blacklist(self):
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
# Set up the ip_range blacklist
self.hs.config.federation_ip_range_blacklist = IPSet([
"127.0.0.0/8",
"fe80::/64",
])
self.reactor.lookups["internal"] = "127.0.0.1"
self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
self.reactor.lookups["fine"] = "10.20.30.40"
cl = MatrixFederationHttpClient(self.hs, None)
# Try making a GET request to a blacklisted IPv4 address
# ------------------------------------------------------
# Make the request
d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
# Nothing happened yet
self.assertNoResult(d)
self.pump(1)
# Check that it was unable to resolve the address
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 0)
f = self.failureResultOf(d)
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
# Try making a POST request to a blacklisted IPv6 address
# -------------------------------------------------------
# Make the request
d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
# Nothing has happened yet
self.assertNoResult(d)
# Move the reactor forwards
self.pump(1)
# Check that it was unable to resolve the address
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 0)
# Check that it was due to a blacklisted DNS lookup
f = self.failureResultOf(d, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
# Try making a GET request to a non-blacklisted IPv4 address
# ----------------------------------------------------------
# Make the request
d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
# Nothing has happened yet
self.assertNoResult(d)
# Move the reactor forwards
self.pump(1)
# Check that it was able to resolve the address
clients = self.reactor.tcpClients
self.assertNotEqual(len(clients), 0)
# Connection will still fail as this IP address does not resolve to anything
f = self.failureResultOf(d, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)
def test_client_gets_headers(self):
"""
Once the client gets the headers, _request returns successfully.

View file

@ -52,22 +52,26 @@ class EmailPusherTests(HomeserverTestCase):
return d
config = self.default_config()
config.email_enable_notifs = True
config.start_pushers = True
config.email_template_dir = os.path.abspath(
pkg_resources.resource_filename('synapse', 'res/templates')
)
config.email_notif_template_html = "notif_mail.html"
config.email_notif_template_text = "notif_mail.txt"
config.email_smtp_host = "127.0.0.1"
config.email_smtp_port = 20
config.require_transport_security = False
config.email_smtp_user = None
config.email_smtp_pass = None
config.email_app_name = "Matrix"
config.email_notif_from = "test@example.com"
config.email_riot_base_url = None
config["email"] = {
"enable_notifs": True,
"template_dir": os.path.abspath(
pkg_resources.resource_filename('synapse', 'res/templates')
),
"expiry_template_html": "notice_expiry.html",
"expiry_template_text": "notice_expiry.txt",
"notif_template_html": "notif_mail.html",
"notif_template_text": "notif_mail.txt",
"smtp_host": "127.0.0.1",
"smtp_port": 20,
"require_transport_security": False,
"smtp_user": None,
"smtp_pass": None,
"app_name": "Matrix",
"notif_from": "test@example.com",
"riot_base_url": None,
}
config["public_baseurl"] = "aaa"
config["start_pushers"] = True
hs = self.setup_test_homeserver(config=config, sendmail=sendmail)

View file

@ -54,7 +54,7 @@ class HTTPPusherTests(HomeserverTestCase):
m.post_json_get_json = post_json_get_json
config = self.default_config()
config.start_pushers = True
config["start_pushers"] = True
hs = self.setup_test_homeserver(config=config, simple_http_client=m)

View file

@ -42,15 +42,18 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config.user_consent_version = "1"
config.public_baseurl = ""
config.form_secret = "123abc"
config["public_baseurl"] = "aaaa"
config["form_secret"] = "123abc"
# Make some temporary templates...
temp_consent_path = self.mktemp()
os.mkdir(temp_consent_path)
os.mkdir(os.path.join(temp_consent_path, 'en'))
config.user_consent_template_dir = os.path.abspath(temp_consent_path)
config["user_consent"] = {
"version": "1",
"template_dir": os.path.abspath(temp_consent_path),
}
with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f:
f.write("{{version}},{{has_consented}}")

View file

@ -32,7 +32,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config.enable_3pid_lookup = False
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs

View file

@ -34,7 +34,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config.require_membership_for_aliases = True
config["require_membership_for_aliases"] = True
self.hs = self.setup_test_homeserver(config=config)

View file

@ -36,9 +36,9 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config.enable_registration_captcha = False
config.enable_registration = True
config.auto_join_rooms = []
config["enable_registration_captcha"] = False
config["enable_registration"] = True
config["auto_join_rooms"] = []
hs = self.setup_test_homeserver(
config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])

View file

@ -171,7 +171,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config.require_auth_for_profile_requests = True
config["require_auth_for_profile_requests"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -24,7 +25,7 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import Membership
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v1 import login, profile, room
from tests import unittest
@ -919,7 +920,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
self.url = b"/_matrix/client/r0/publicRooms"
config = self.default_config()
config.restrict_public_rooms_to_local_users = True
config["restrict_public_rooms_to_local_users"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
@ -936,3 +937,70 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request("GET", self.url, access_token=tok)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
profile.register_servlets,
]
def make_homeserver(self, reactor, clock):
config = self.default_config()
config["allow_per_room_profiles"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
# Set a profile for the test user
self.displayname = "test user"
data = {
"displayname": self.displayname,
}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
request_data,
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_per_room_profile_forbidden(self):
data = {
"membership": "join",
"displayname": "other test user"
}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
self.room_id, self.user_id,
),
request_data,
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
request, channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
res_displayname = channel.json_body["content"]["displayname"]
self.assertEqual(res_displayname, self.displayname, channel.result)

View file

@ -36,9 +36,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
config = self.default_config()
config.enable_registration_captcha = True
config.recaptcha_public_key = "brokencake"
config.registrations_require_3pid = []
config["enable_registration_captcha"] = True
config["recaptcha_public_key"] = "brokencake"
config["registrations_require_3pid"] = []
hs = self.setup_test_homeserver(config=config)
return hs
@ -92,7 +92,14 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.recaptcha_attempts), 1)
self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
# Now we have fufilled the recaptcha fallback step, we can then send a
# also complete the dummy auth
request, channel = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
)
self.render(request)
# Now we should have fufilled a complete auth flow, including
# the recaptcha fallback step, we can then send a
# request to the register API with the session in the authdict.
request, channel = self.make_request(
"POST", "register", {"auth": {"session": session}}

View file

@ -201,9 +201,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
# Test for account expiring after a week.
config.enable_registration = True
config.account_validity.enabled = True
config.account_validity.period = 604800000 # Time in ms for 1 week
config["enable_registration"] = True
config["account_validity"] = {
"enabled": True,
"period": 604800000, # Time in ms for 1 week
}
self.hs = self.setup_test_homeserver(config=config)
return self.hs
@ -299,14 +301,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
# Test for account expiring after a week and renewal emails being sent 2
# days before expiry.
config.enable_registration = True
config.account_validity.enabled = True
config.account_validity.renew_by_email_enabled = True
config.account_validity.period = 604800000 # Time in ms for 1 week
config.account_validity.renew_at = 172800000 # Time in ms for 2 days
config.account_validity.renew_email_subject = "Renew your account"
config["enable_registration"] = True
config["account_validity"] = {
"enabled": True,
"period": 604800000, # Time in ms for 1 week
"renew_at": 172800000, # Time in ms for 2 days
"renew_by_email_enabled": True,
"renew_email_subject": "Renew your account",
}
# Email config.
self.email_attempts = []
@ -315,17 +320,23 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.email_attempts.append((args, kwargs))
return
config.email_template_dir = os.path.abspath(
pkg_resources.resource_filename('synapse', 'res/templates')
)
config.email_expiry_template_html = "notice_expiry.html"
config.email_expiry_template_text = "notice_expiry.txt"
config.email_smtp_host = "127.0.0.1"
config.email_smtp_port = 20
config.require_transport_security = False
config.email_smtp_user = None
config.email_smtp_pass = None
config.email_notif_from = "test@example.com"
config["email"] = {
"enable_notifs": True,
"template_dir": os.path.abspath(
pkg_resources.resource_filename('synapse', 'res/templates')
),
"expiry_template_html": "notice_expiry.html",
"expiry_template_text": "notice_expiry.txt",
"notif_template_html": "notif_mail.html",
"notif_template_text": "notif_mail.txt",
"smtp_host": "127.0.0.1",
"smtp_port": 20,
"require_transport_security": False,
"smtp_user": None,
"smtp_pass": None,
"notif_from": "test@example.com",
}
config["public_baseurl"] = "aaa"
self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)

View file

@ -0,0 +1,539 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import json
import six
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import register, relations
from tests import unittest
class RelationsTestCase(unittest.HomeserverTestCase):
servlets = [
relations.register_servlets,
room.register_servlets,
login.register_servlets,
register.register_servlets,
admin.register_servlets_for_client_rest_resource,
]
hijack_auth = False
def make_homeserver(self, reactor, clock):
# We need to enable msc1849 support for aggregations
config = self.default_config()
config["experimental_msc1849_support_enabled"] = True
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
self.user_id, self.user_token = self._create_user("alice")
self.user2_id, self.user2_token = self._create_user("bob")
self.room = self.helper.create_room_as(self.user_id, tok=self.user_token)
self.helper.join(self.room, user=self.user2_id, tok=self.user2_token)
res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
self.parent_id = res["event_id"]
def test_send_relation(self):
"""Tests that sending a relation using the new /send_relation works
creates the right shape of event.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍")
self.assertEquals(200, channel.code, channel.json_body)
event_id = channel.json_body["event_id"]
request, channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, event_id),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assert_dict(
{
"type": "m.reaction",
"sender": self.user_id,
"content": {
"m.relates_to": {
"event_id": self.parent_id,
"key": u"👍",
"rel_type": RelationTypes.ANNOTATION,
}
},
},
channel.json_body,
)
def test_deny_membership(self):
"""Test that we deny relations on membership events
"""
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)
def test_basic_paginate_relations(self):
"""Tests that calling pagination API corectly the latest relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
self.assertEquals(200, channel.code, channel.json_body)
annotation_id = channel.json_body["event_id"]
request, channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
% (self.room, self.parent_id),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
# We expect to get back a single pagination result, which is the full
# relation event we sent above.
self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
self.assert_dict(
{"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
channel.json_body["chunk"][0],
)
# Make sure next_batch has something in it that looks like it could be a
# valid token.
self.assertIsInstance(
channel.json_body.get("next_batch"), six.string_types, channel.json_body
)
def test_repeated_paginate_relations(self):
"""Test that if we paginate using a limit and tokens then we get the
expected events.
"""
expected_event_ids = []
for _ in range(10):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
prev_token = None
found_event_ids = []
for _ in range(20):
from_token = ""
if prev_token:
from_token = "&from=" + prev_token
request, channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
next_batch = channel.json_body.get("next_batch")
self.assertNotEquals(prev_token, next_batch)
prev_token = next_batch
if not prev_token:
break
# We paginated backwards, so reverse
found_event_ids.reverse()
self.assertEquals(found_event_ids, expected_event_ids)
def test_aggregation_pagination_groups(self):
"""Test that we can paginate annotation groups correctly.
"""
# We need to create ten separate users to send each reaction.
access_tokens = [self.user_token, self.user2_token]
idx = 0
while len(access_tokens) < 10:
user_id, token = self._create_user("test" + str(idx))
idx += 1
self.helper.join(self.room, user=user_id, tok=token)
access_tokens.append(token)
idx = 0
sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1}
for key in itertools.chain.from_iterable(
itertools.repeat(key, num) for key, num in sent_groups.items()
):
channel = self._send_relation(
RelationTypes.ANNOTATION,
"m.reaction",
key=key,
access_token=access_tokens[idx],
)
self.assertEquals(200, channel.code, channel.json_body)
idx += 1
idx %= len(access_tokens)
prev_token = None
found_groups = {}
for _ in range(20):
from_token = ""
if prev_token:
from_token = "&from=" + prev_token
request, channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
for groups in channel.json_body["chunk"]:
# We only expect reactions
self.assertEqual(groups["type"], "m.reaction", channel.json_body)
# We should only see each key once
self.assertNotIn(groups["key"], found_groups, channel.json_body)
found_groups[groups["key"]] = groups["count"]
next_batch = channel.json_body.get("next_batch")
self.assertNotEquals(prev_token, next_batch)
prev_token = next_batch
if not prev_token:
break
self.assertEquals(sent_groups, found_groups)
def test_aggregation_pagination_within_group(self):
"""Test that we can paginate within an annotation group.
"""
expected_event_ids = []
for _ in range(10):
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", key=u"👍"
)
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
# Also send a different type of reaction so that we test we don't see it
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
prev_token = None
found_event_ids = []
encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8"))
for _ in range(20):
from_token = ""
if prev_token:
from_token = "&from=" + prev_token
request, channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s"
"/aggregations/%s/%s/m.reaction/%s?limit=1%s"
% (
self.room,
self.parent_id,
RelationTypes.ANNOTATION,
encoded_key,
from_token,
),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
next_batch = channel.json_body.get("next_batch")
self.assertNotEquals(prev_token, next_batch)
prev_token = next_batch
if not prev_token:
break
# We paginated backwards, so reverse
found_event_ids.reverse()
self.assertEquals(found_event_ids, expected_event_ids)
def test_aggregation(self):
"""Test that annotations get correctly aggregated.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
self.assertEquals(200, channel.code, channel.json_body)
request, channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
channel.json_body,
{
"chunk": [
{"type": "m.reaction", "key": "a", "count": 2},
{"type": "m.reaction", "key": "b", "count": 1},
]
},
)
def test_aggregation_redactions(self):
"""Test that annotations get correctly aggregated after a redaction.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
self.assertEquals(200, channel.code, channel.json_body)
# Now lets redact one of the 'a' reactions
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
access_token=self.user_token,
content={},
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
request, channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
channel.json_body,
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
def test_aggregation_must_be_annotation(self):
"""Test that aggregations must be annotations.
"""
request, channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
% (self.room, self.parent_id, RelationTypes.REPLACE),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(400, channel.code, channel.json_body)
def test_aggregation_get_event(self):
"""Test that annotations and references get correctly bundled when
getting the parent event.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
self.assertEquals(200, channel.code, channel.json_body)
reply_1 = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
self.assertEquals(200, channel.code, channel.json_body)
reply_2 = channel.json_body["event_id"]
request, channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
channel.json_body["unsigned"].get("m.relations"),
{
RelationTypes.ANNOTATION: {
"chunk": [
{"type": "m.reaction", "key": "a", "count": 2},
{"type": "m.reaction", "key": "b", "count": 1},
]
},
RelationTypes.REFERENCE: {
"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
},
},
)
def test_edit(self):
"""Test that a simple edit works.
"""
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
self.assertEquals(200, channel.code, channel.json_body)
edit_event_id = channel.json_body["event_id"]
request, channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["content"], new_body)
self.assertEquals(
channel.json_body["unsigned"].get("m.relations"),
{RelationTypes.REPLACE: {"event_id": edit_event_id}},
)
def test_multi_edit(self):
"""Test that multiple edits, including attempts by people who
shouldn't be allowed, are correctly handled.
"""
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
"msgtype": "m.text",
"body": "Wibble",
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
self.assertEquals(200, channel.code, channel.json_body)
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
self.assertEquals(200, channel.code, channel.json_body)
edit_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message.WRONG_TYPE",
content={
"msgtype": "m.text",
"body": "Wibble",
"m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
},
)
self.assertEquals(200, channel.code, channel.json_body)
request, channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["content"], new_body)
self.assertEquals(
channel.json_body["unsigned"].get("m.relations"),
{RelationTypes.REPLACE: {"event_id": edit_event_id}},
)
def _send_relation(
self, relation_type, event_type, key=None, content={}, access_token=None
):
"""Helper function to send a relation pointing at `self.parent_id`
Args:
relation_type (str): One of `RelationTypes`
event_type (str): The type of the event to create
key (str|None): The aggregation key used for m.annotation relation
type.
content(dict|None): The content of the created event.
access_token (str|None): The access token used to send the relation,
defaults to `self.user_token`
Returns:
FakeChannel
"""
if not access_token:
access_token = self.user_token
query = ""
if key:
query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
request, channel = self.make_request(
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, self.parent_id, relation_type, event_type, query),
json.dumps(content).encode("utf-8"),
access_token=access_token,
)
self.render(request)
return channel
def _create_user(self, localpart):
user_id = self.register_user(localpart, "abc123")
access_token = self.login(localpart, "abc123")
return user_id, access_token

View file

@ -25,13 +25,11 @@ from six.moves.urllib import parse
from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
from synapse.config.repository import MediaStorageProviderConfig
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.module_loader import load_module
from tests import unittest
@ -120,12 +118,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
client.get_file = get_file
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config = self.default_config()
config.media_store_path = self.storage_path
config.thumbnail_requirements = {}
config.max_image_pixels = 2000000
config["media_store_path"] = self.media_store_path
config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
"module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
@ -134,12 +134,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"store_remote": True,
"config": {"directory": self.storage_path},
}
loaded = list(load_module(provider_config)) + [
MediaStorageProviderConfig(False, False, False)
]
config.media_storage_providers = [loaded]
config["media_storage_providers"] = [provider_config]
hs = self.setup_test_homeserver(config=config, http_client=client)

View file

@ -16,7 +16,6 @@
import os
import attr
from netaddr import IPSet
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
@ -25,9 +24,6 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web._newclient import ResponseDone
from synapse.config.repository import MediaStorageProviderConfig
from synapse.util.module_loader import load_module
from tests import unittest
from tests.server import FakeTransport
@ -67,23 +63,23 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.storage_path = self.mktemp()
os.mkdir(self.storage_path)
config = self.default_config()
config.url_preview_enabled = True
config.max_spider_size = 9999999
config.url_preview_ip_range_blacklist = IPSet(
(
"192.168.1.1",
"1.0.0.0/8",
"3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
"2001:800::/21",
)
config["url_preview_enabled"] = True
config["max_spider_size"] = 9999999
config["url_preview_ip_range_blacklist"] = (
"192.168.1.1",
"1.0.0.0/8",
"3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
"2001:800::/21",
)
config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",))
config.url_preview_url_blacklist = []
config.media_store_path = self.storage_path
config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
config["url_preview_url_blacklist"] = []
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config["media_store_path"] = self.media_store_path
provider_config = {
"module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
@ -93,11 +89,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"config": {"directory": self.storage_path},
}
loaded = list(load_module(provider_config)) + [
MediaStorageProviderConfig(False, False, False)
]
config.media_storage_providers = [loaded]
config["media_storage_providers"] = [provider_config]
hs = self.setup_test_homeserver(config=config)

Some files were not shown because too many files have changed in this diff Show more