mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-15 22:42:23 +01:00
Merge branch 'develop' of github.com:matrix-org/synapse into neilj/server_notices_on_blocking
This commit is contained in:
commit
fc5d937550
163 changed files with 2945 additions and 2849 deletions
48
.circleci/config.yml
Normal file
48
.circleci/config.yml
Normal file
|
@ -0,0 +1,48 @@
|
|||
version: 2
|
||||
jobs:
|
||||
sytestpy2:
|
||||
machine: true
|
||||
steps:
|
||||
- checkout
|
||||
- run: docker pull matrixdotorg/sytest-synapsepy2
|
||||
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs matrixdotorg/sytest-synapsepy2
|
||||
- store_artifacts:
|
||||
path: ~/project/logs
|
||||
destination: logs
|
||||
sytestpy2postgres:
|
||||
machine: true
|
||||
steps:
|
||||
- checkout
|
||||
- run: docker pull matrixdotorg/sytest-synapsepy2
|
||||
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy2
|
||||
- store_artifacts:
|
||||
path: ~/project/logs
|
||||
destination: logs
|
||||
sytestpy3:
|
||||
machine: true
|
||||
steps:
|
||||
- checkout
|
||||
- run: docker pull matrixdotorg/sytest-synapsepy3
|
||||
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs hawkowl/sytestpy3
|
||||
- store_artifacts:
|
||||
path: ~/project/logs
|
||||
destination: logs
|
||||
sytestpy3postgres:
|
||||
machine: true
|
||||
steps:
|
||||
- checkout
|
||||
- run: docker pull matrixdotorg/sytest-synapsepy3
|
||||
- run: docker run --rm -it -v $(pwd)\:/src -v $(pwd)/logs\:/logs -e POSTGRES=1 matrixdotorg/sytest-synapsepy3
|
||||
- store_artifacts:
|
||||
path: ~/project/logs
|
||||
destination: logs
|
||||
|
||||
workflows:
|
||||
version: 2
|
||||
build:
|
||||
jobs:
|
||||
- sytestpy2
|
||||
- sytestpy2postgres
|
||||
# Currently broken while the Python 3 port is incomplete
|
||||
# - sytestpy3
|
||||
# - sytestpy3postgres
|
|
@ -3,3 +3,6 @@ Dockerfile
|
|||
.gitignore
|
||||
demo/etc
|
||||
tox.ini
|
||||
synctl
|
||||
.git/*
|
||||
.tox/*
|
||||
|
|
10
.travis.yml
10
.travis.yml
|
@ -8,6 +8,9 @@ before_script:
|
|||
- git remote set-branches --add origin develop
|
||||
- git fetch origin develop
|
||||
|
||||
services:
|
||||
- postgresql
|
||||
|
||||
matrix:
|
||||
fast_finish: true
|
||||
include:
|
||||
|
@ -20,6 +23,9 @@ matrix:
|
|||
- python: 2.7
|
||||
env: TOX_ENV=py27
|
||||
|
||||
- python: 2.7
|
||||
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
|
||||
|
||||
- python: 3.6
|
||||
env: TOX_ENV=py36
|
||||
|
||||
|
@ -29,6 +35,10 @@ matrix:
|
|||
- python: 3.6
|
||||
env: TOX_ENV=check-newsfragment
|
||||
|
||||
allow_failures:
|
||||
- python: 2.7
|
||||
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
|
||||
|
||||
install:
|
||||
- pip install tox
|
||||
|
||||
|
|
|
@ -36,3 +36,4 @@ recursive-include changelog.d *
|
|||
prune .github
|
||||
prune demo/etc
|
||||
prune docker
|
||||
prune .circleci
|
||||
|
|
1
changelog.d/1491.feature
Normal file
1
changelog.d/1491.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add support for the SNI extension to federation TLS connections
|
1
changelog.d/3423.misc
Normal file
1
changelog.d/3423.misc
Normal file
|
@ -0,0 +1 @@
|
|||
The test suite now can run under PostgreSQL.
|
1
changelog.d/3653.feature
Normal file
1
changelog.d/3653.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Support more federation endpoints on workers
|
1
changelog.d/3660.misc
Normal file
1
changelog.d/3660.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Sytests can now be run inside a Docker container.
|
1
changelog.d/3661.bugfix
Normal file
1
changelog.d/3661.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix bug on deleting 3pid when using identity servers that don't support unbind API
|
1
changelog.d/3669.misc
Normal file
1
changelog.d/3669.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Update docker base image from alpine 3.7 to 3.8.
|
1
changelog.d/3676.bugfix
Normal file
1
changelog.d/3676.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Make the tests pass on Twisted < 18.7.0
|
1
changelog.d/3677.bugfix
Normal file
1
changelog.d/3677.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Don’t ship recaptcha_ajax.js, use it directly from Google
|
1
changelog.d/3678.misc
Normal file
1
changelog.d/3678.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7.
|
1
changelog.d/3679.misc
Normal file
1
changelog.d/3679.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Synapse's tests are now formatted with the black autoformatter.
|
1
changelog.d/3681.bugfix
Normal file
1
changelog.d/3681.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fixes test_reap_monthly_active_users so it passes under postgres
|
1
changelog.d/3684.misc
Normal file
1
changelog.d/3684.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Implemented a new testing base class to reduce test boilerplate.
|
1
changelog.d/3687.feature
Normal file
1
changelog.d/3687.feature
Normal file
|
@ -0,0 +1 @@
|
|||
set admin uri via config, to be used in error messages where the user should contact the administrator
|
1
changelog.d/3690.misc
Normal file
1
changelog.d/3690.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Rename MAU prometheus metrics
|
1
changelog.d/3692.bugfix
Normal file
1
changelog.d/3692.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix missing yield in synapse.storage.monthly_active_users.initialise_reserved_users
|
|
@ -1,4 +1,4 @@
|
|||
FROM docker.io/python:2-alpine3.7
|
||||
FROM docker.io/python:2-alpine3.8
|
||||
|
||||
RUN apk add --no-cache --virtual .nacl_deps \
|
||||
build-base \
|
||||
|
|
|
@ -173,10 +173,23 @@ endpoints matching the following regular expressions::
|
|||
^/_matrix/federation/v1/backfill/
|
||||
^/_matrix/federation/v1/get_missing_events/
|
||||
^/_matrix/federation/v1/publicRooms
|
||||
^/_matrix/federation/v1/query/
|
||||
^/_matrix/federation/v1/make_join/
|
||||
^/_matrix/federation/v1/make_leave/
|
||||
^/_matrix/federation/v1/send_join/
|
||||
^/_matrix/federation/v1/send_leave/
|
||||
^/_matrix/federation/v1/invite/
|
||||
^/_matrix/federation/v1/query_auth/
|
||||
^/_matrix/federation/v1/event_auth/
|
||||
^/_matrix/federation/v1/exchange_third_party_invite/
|
||||
^/_matrix/federation/v1/send/
|
||||
|
||||
The above endpoints should all be routed to the federation_reader worker by the
|
||||
reverse-proxy configuration.
|
||||
|
||||
The `^/_matrix/federation/v1/send/` endpoint must only be handled by a single
|
||||
instance.
|
||||
|
||||
``synapse.app.federation_sender``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -780,11 +780,14 @@ class Auth(object):
|
|||
such as monthly active user limiting or global disable flag
|
||||
|
||||
Args:
|
||||
user_id(str): If present, checks for presence against existing MAU cohort
|
||||
user_id(str|None): If present, checks for presence against existing
|
||||
MAU cohort
|
||||
"""
|
||||
if self.hs.config.hs_disabled:
|
||||
raise AuthError(
|
||||
403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED
|
||||
403, self.hs.config.hs_disabled_message,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEED,
|
||||
admin_uri=self.hs.config.admin_uri,
|
||||
)
|
||||
if self.hs.config.limit_usage_by_mau is True:
|
||||
# If the user is already part of the MAU cohort
|
||||
|
@ -796,5 +799,7 @@ class Auth(object):
|
|||
current_mau = yield self.store.get_monthly_active_count()
|
||||
if current_mau >= self.hs.config.max_mau_value:
|
||||
raise AuthError(
|
||||
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||
403, "Monthly Active User Limits AU Limit Exceeded",
|
||||
admin_uri=self.hs.config.admin_uri,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEED
|
||||
)
|
||||
|
|
|
@ -56,8 +56,7 @@ class Codes(object):
|
|||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
|
||||
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
|
||||
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
|
||||
HS_DISABLED = "M_HS_DISABLED"
|
||||
RESOURCE_LIMIT_EXCEED = "M_RESOURCE_LIMIT_EXCEED"
|
||||
UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION"
|
||||
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
|
||||
|
||||
|
@ -225,11 +224,16 @@ class NotFoundError(SynapseError):
|
|||
|
||||
class AuthError(SynapseError):
|
||||
"""An error raised when there was a problem authorising an event."""
|
||||
def __init__(self, code, msg, errcode=Codes.FORBIDDEN, admin_uri=None):
|
||||
self.admin_uri = admin_uri
|
||||
super(AuthError, self).__init__(code, msg, errcode=errcode)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if "errcode" not in kwargs:
|
||||
kwargs["errcode"] = Codes.FORBIDDEN
|
||||
super(AuthError, self).__init__(*args, **kwargs)
|
||||
def error_dict(self):
|
||||
return cs_error(
|
||||
self.msg,
|
||||
self.errcode,
|
||||
admin_uri=self.admin_uri,
|
||||
)
|
||||
|
||||
|
||||
class EventSizeError(SynapseError):
|
||||
|
|
|
@ -39,7 +39,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
|
|||
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.client.v1.room import (
|
||||
JoinedRoomMemberListRestServlet,
|
||||
|
@ -66,7 +66,7 @@ class ClientReaderSlavedStore(
|
|||
DirectoryStore,
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
TransactionStore,
|
||||
SlavedTransactionStore,
|
||||
SlavedClientIpStore,
|
||||
BaseSlavedStore,
|
||||
):
|
||||
|
@ -168,11 +168,13 @@ def start(config_options):
|
|||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||
|
||||
ss = ClientReaderServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
tls_client_options_factory=tls_client_options_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
|
|
|
@ -43,7 +43,7 @@ from synapse.replication.slave.storage.pushers import SlavedPusherStore
|
|||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.client.v1.room import (
|
||||
JoinRoomAliasServlet,
|
||||
|
@ -63,7 +63,7 @@ logger = logging.getLogger("synapse.app.event_creator")
|
|||
|
||||
class EventCreatorSlavedStore(
|
||||
DirectoryStore,
|
||||
TransactionStore,
|
||||
SlavedTransactionStore,
|
||||
SlavedProfileStore,
|
||||
SlavedAccountDataStore,
|
||||
SlavedPusherStore,
|
||||
|
@ -174,11 +174,13 @@ def start(config_options):
|
|||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||
|
||||
ss = EventCreatorServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
tls_client_options_factory=tls_client_options_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
|
|
|
@ -32,11 +32,16 @@ from synapse.http.site import SynapseSite
|
|||
from synapse.metrics import RegistryProxy
|
||||
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||
from synapse.replication.slave.storage.profile import SlavedProfileStore
|
||||
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||
from synapse.replication.slave.storage.pushers import SlavedPusherStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
|
@ -49,11 +54,16 @@ logger = logging.getLogger("synapse.app.federation_reader")
|
|||
|
||||
|
||||
class FederationReaderSlavedStore(
|
||||
SlavedProfileStore,
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedPusherStore,
|
||||
SlavedPushRuleStore,
|
||||
SlavedReceiptsStore,
|
||||
SlavedEventStore,
|
||||
SlavedKeyStore,
|
||||
RoomStore,
|
||||
DirectoryStore,
|
||||
TransactionStore,
|
||||
SlavedTransactionStore,
|
||||
BaseSlavedStore,
|
||||
):
|
||||
pass
|
||||
|
@ -143,11 +153,13 @@ def start(config_options):
|
|||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||
|
||||
ss = FederationReaderServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
tls_client_options_factory=tls_client_options_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
|
|
|
@ -36,11 +36,11 @@ from synapse.replication.slave.storage.events import SlavedEventStore
|
|||
from synapse.replication.slave.storage.presence import SlavedPresenceStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext, run_in_background
|
||||
from synapse.util.manhole import manhole
|
||||
|
@ -50,7 +50,7 @@ logger = logging.getLogger("synapse.app.federation_sender")
|
|||
|
||||
|
||||
class FederationSenderSlaveStore(
|
||||
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
|
||||
SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore,
|
||||
SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
|
||||
):
|
||||
def __init__(self, db_conn, hs):
|
||||
|
@ -186,11 +186,13 @@ def start(config_options):
|
|||
config.send_federation = True
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||
|
||||
ps = FederationSenderServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
tls_client_options_factory=tls_client_options_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
|
|
|
@ -208,11 +208,13 @@ def start(config_options):
|
|||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||
|
||||
ss = FrontendProxyServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
tls_client_options_factory=tls_client_options_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
|
|
|
@ -303,8 +303,8 @@ class SynapseHomeServer(HomeServer):
|
|||
|
||||
|
||||
# Gauges to expose monthly active user control metrics
|
||||
current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
|
||||
max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
|
||||
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
|
||||
max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
|
@ -338,6 +338,7 @@ def setup(config_options):
|
|||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
|
||||
|
@ -346,6 +347,7 @@ def setup(config_options):
|
|||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
tls_client_options_factory=tls_client_options_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
|
@ -530,7 +532,7 @@ def run(hs):
|
|||
if hs.config.limit_usage_by_mau:
|
||||
count = yield hs.get_datastore().get_monthly_active_count()
|
||||
current_mau_gauge.set(float(count))
|
||||
max_mau_value_gauge.set(float(hs.config.max_mau_value))
|
||||
max_mau_gauge.set(float(hs.config.max_mau_value))
|
||||
|
||||
hs.get_datastore().initialise_reserved_users(
|
||||
hs.config.mau_limits_reserved_threepids
|
||||
|
|
|
@ -34,7 +34,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
|
|||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.server import HomeServer
|
||||
|
@ -52,7 +52,7 @@ class MediaRepositorySlavedStore(
|
|||
SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
SlavedClientIpStore,
|
||||
TransactionStore,
|
||||
SlavedTransactionStore,
|
||||
BaseSlavedStore,
|
||||
MediaRepositoryStore,
|
||||
):
|
||||
|
@ -155,11 +155,13 @@ def start(config_options):
|
|||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||
|
||||
ss = MediaRepositoryServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
tls_client_options_factory=tls_client_options_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
|
|
|
@ -214,11 +214,13 @@ def start(config_options):
|
|||
config.update_user_directory = True
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
|
||||
|
||||
ps = UserDirectoryServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
tls_client_options_factory=tls_client_options_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
|
|
|
@ -193,9 +193,8 @@ def setup_logging(config, use_worker_options=False):
|
|||
|
||||
def sighup(signum, stack):
|
||||
# it might be better to use a file watcher or something for this.
|
||||
logging.info("Reloading log config from %s due to SIGHUP",
|
||||
log_config)
|
||||
load_log_config()
|
||||
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
|
||||
|
||||
load_log_config()
|
||||
|
||||
|
|
|
@ -82,6 +82,10 @@ class ServerConfig(Config):
|
|||
self.hs_disabled = config.get("hs_disabled", False)
|
||||
self.hs_disabled_message = config.get("hs_disabled_message", "")
|
||||
|
||||
# Admin uri to direct users at should their instance become blocked
|
||||
# due to resource constraints
|
||||
self.admin_uri = config.get("admin_uri", None)
|
||||
|
||||
# FIXME: federation_domain_whitelist needs sytests
|
||||
self.federation_domain_whitelist = None
|
||||
federation_domain_whitelist = config.get(
|
||||
|
|
|
@ -11,19 +11,22 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from OpenSSL import SSL, crypto
|
||||
from twisted.internet import ssl
|
||||
from twisted.internet._sslverify import _defaultCurveName
|
||||
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
|
||||
from twisted.internet.ssl import CertificateOptions, ContextFactory
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServerContextFactory(ssl.ContextFactory):
|
||||
class ServerContextFactory(ContextFactory):
|
||||
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
|
||||
connections and to make connections to remote servers."""
|
||||
connections."""
|
||||
|
||||
def __init__(self, config):
|
||||
self._context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
|
@ -48,3 +51,78 @@ class ServerContextFactory(ssl.ContextFactory):
|
|||
|
||||
def getContext(self):
|
||||
return self._context
|
||||
|
||||
|
||||
def _idnaBytes(text):
|
||||
"""
|
||||
Convert some text typed by a human into some ASCII bytes. This is a
|
||||
copy of twisted.internet._idna._idnaBytes. For documentation, see the
|
||||
twisted documentation.
|
||||
"""
|
||||
try:
|
||||
import idna
|
||||
except ImportError:
|
||||
return text.encode("idna")
|
||||
else:
|
||||
return idna.encode(text)
|
||||
|
||||
|
||||
def _tolerateErrors(wrapped):
|
||||
"""
|
||||
Wrap up an info_callback for pyOpenSSL so that if something goes wrong
|
||||
the error is immediately logged and the connection is dropped if possible.
|
||||
This is a copy of twisted.internet._sslverify._tolerateErrors. For
|
||||
documentation, see the twisted documentation.
|
||||
"""
|
||||
|
||||
def infoCallback(connection, where, ret):
|
||||
try:
|
||||
return wrapped(connection, where, ret)
|
||||
except: # noqa: E722, taken from the twisted implementation
|
||||
f = Failure()
|
||||
logger.exception("Error during info_callback")
|
||||
connection.get_app_data().failVerification(f)
|
||||
|
||||
return infoCallback
|
||||
|
||||
|
||||
@implementer(IOpenSSLClientConnectionCreator)
|
||||
class ClientTLSOptions(object):
|
||||
"""
|
||||
Client creator for TLS without certificate identity verification. This is a
|
||||
copy of twisted.internet._sslverify.ClientTLSOptions with the identity
|
||||
verification left out. For documentation, see the twisted documentation.
|
||||
"""
|
||||
|
||||
def __init__(self, hostname, ctx):
|
||||
self._ctx = ctx
|
||||
self._hostname = hostname
|
||||
self._hostnameBytes = _idnaBytes(hostname)
|
||||
ctx.set_info_callback(
|
||||
_tolerateErrors(self._identityVerifyingInfoCallback)
|
||||
)
|
||||
|
||||
def clientConnectionForTLS(self, tlsProtocol):
|
||||
context = self._ctx
|
||||
connection = SSL.Connection(context, None)
|
||||
connection.set_app_data(tlsProtocol)
|
||||
return connection
|
||||
|
||||
def _identityVerifyingInfoCallback(self, connection, where, ret):
|
||||
if where & SSL.SSL_CB_HANDSHAKE_START:
|
||||
connection.set_tlsext_host_name(self._hostnameBytes)
|
||||
|
||||
|
||||
class ClientTLSOptionsFactory(object):
|
||||
"""Factory for Twisted ClientTLSOptions that are used to make connections
|
||||
to remote servers for federation."""
|
||||
|
||||
def __init__(self, config):
|
||||
# We don't use config options yet
|
||||
pass
|
||||
|
||||
def get_options(self, host):
|
||||
return ClientTLSOptions(
|
||||
host.decode('utf-8'),
|
||||
CertificateOptions(verify=False).getContext()
|
||||
)
|
||||
|
|
|
@ -30,14 +30,14 @@ KEY_API_V1 = b"/_matrix/key/v1/"
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
|
||||
def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
|
||||
"""Fetch the keys for a remote server."""
|
||||
|
||||
factory = SynapseKeyClientFactory()
|
||||
factory.path = path
|
||||
factory.host = server_name
|
||||
endpoint = matrix_federation_endpoint(
|
||||
reactor, server_name, ssl_context_factory, timeout=30
|
||||
reactor, server_name, tls_client_options_factory, timeout=30
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
|
|
|
@ -512,7 +512,7 @@ class Keyring(object):
|
|||
continue
|
||||
|
||||
(response, tls_certificate) = yield fetch_server_key(
|
||||
server_name, self.hs.tls_server_context_factory,
|
||||
server_name, self.hs.tls_client_options_factory,
|
||||
path=(b"/_matrix/key/v2/server/%s" % (
|
||||
urllib.quote(requested_key_id),
|
||||
)).encode("ascii"),
|
||||
|
@ -655,7 +655,7 @@ class Keyring(object):
|
|||
# Try to fetch the key from the remote server.
|
||||
|
||||
(response, tls_certificate) = yield fetch_server_key(
|
||||
server_name, self.hs.tls_server_context_factory
|
||||
server_name, self.hs.tls_client_options_factory
|
||||
)
|
||||
|
||||
# Check the response.
|
||||
|
|
|
@ -39,8 +39,12 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
|
|||
from synapse.federation.persistence import TransactionActions
|
||||
from synapse.federation.units import Edu, Transaction
|
||||
from synapse.http.endpoint import parse_server_name
|
||||
from synapse.replication.http.federation import (
|
||||
ReplicationFederationSendEduRestServlet,
|
||||
ReplicationGetQueryRestServlet,
|
||||
)
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util import async
|
||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
|
@ -67,8 +71,8 @@ class FederationServer(FederationBase):
|
|||
self.auth = hs.get_auth()
|
||||
self.handler = hs.get_handlers().federation_handler
|
||||
|
||||
self._server_linearizer = async.Linearizer("fed_server")
|
||||
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
|
||||
self._server_linearizer = Linearizer("fed_server")
|
||||
self._transaction_linearizer = Linearizer("fed_txn_handler")
|
||||
|
||||
self.transaction_actions = TransactionActions(self.store)
|
||||
|
||||
|
@ -200,7 +204,7 @@ class FederationServer(FederationBase):
|
|||
event_id, f.getTraceback().rstrip(),
|
||||
)
|
||||
|
||||
yield async.concurrently_execute(
|
||||
yield concurrently_execute(
|
||||
process_pdus_for_room, pdus_by_room.keys(),
|
||||
TRANSACTION_CONCURRENCY_LIMIT,
|
||||
)
|
||||
|
@ -760,6 +764,8 @@ class FederationHandlerRegistry(object):
|
|||
if edu_type in self.edu_handlers:
|
||||
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
|
||||
|
||||
logger.info("Registering federation EDU handler for %r", edu_type)
|
||||
|
||||
self.edu_handlers[edu_type] = handler
|
||||
|
||||
def register_query_handler(self, query_type, handler):
|
||||
|
@ -778,6 +784,8 @@ class FederationHandlerRegistry(object):
|
|||
"Already have a Query handler for %s" % (query_type,)
|
||||
)
|
||||
|
||||
logger.info("Registering federation query handler for %r", query_type)
|
||||
|
||||
self.query_handlers[query_type] = handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -800,3 +808,49 @@ class FederationHandlerRegistry(object):
|
|||
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
|
||||
|
||||
return handler(args)
|
||||
|
||||
|
||||
class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
|
||||
"""A FederationHandlerRegistry for worker processes.
|
||||
|
||||
When receiving EDU or queries it will check if an appropriate handler has
|
||||
been registered on the worker, if there isn't one then it calls off to the
|
||||
master process.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.config = hs.config
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
|
||||
self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
|
||||
|
||||
super(ReplicationFederationHandlerRegistry, self).__init__()
|
||||
|
||||
def on_edu(self, edu_type, origin, content):
|
||||
"""Overrides FederationHandlerRegistry
|
||||
"""
|
||||
handler = self.edu_handlers.get(edu_type)
|
||||
if handler:
|
||||
return super(ReplicationFederationHandlerRegistry, self).on_edu(
|
||||
edu_type, origin, content,
|
||||
)
|
||||
|
||||
return self._send_edu(
|
||||
edu_type=edu_type,
|
||||
origin=origin,
|
||||
content=content,
|
||||
)
|
||||
|
||||
def on_query(self, query_type, args):
|
||||
"""Overrides FederationHandlerRegistry
|
||||
"""
|
||||
handler = self.query_handlers.get(query_type)
|
||||
if handler:
|
||||
return handler(args)
|
||||
|
||||
return self._get_query_client(
|
||||
query_type=query_type,
|
||||
args=args,
|
||||
)
|
||||
|
|
|
@ -828,12 +828,26 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def delete_threepid(self, user_id, medium, address):
|
||||
"""Attempts to unbind the 3pid on the identity servers and deletes it
|
||||
from the local database.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
medium (str)
|
||||
address (str)
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: Returns True if successfully unbound the 3pid on
|
||||
the identity server, False if identity server doesn't support the
|
||||
unbind API.
|
||||
"""
|
||||
|
||||
# 'Canonicalise' email addresses as per above
|
||||
if medium == 'email':
|
||||
address = address.lower()
|
||||
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
yield identity_handler.unbind_threepid(
|
||||
result = yield identity_handler.try_unbind_threepid(
|
||||
user_id,
|
||||
{
|
||||
'medium': medium,
|
||||
|
@ -841,10 +855,10 @@ class AuthHandler(BaseHandler):
|
|||
},
|
||||
)
|
||||
|
||||
ret = yield self.store.user_delete_threepid(
|
||||
yield self.store.user_delete_threepid(
|
||||
user_id, medium, address,
|
||||
)
|
||||
defer.returnValue(ret)
|
||||
defer.returnValue(result)
|
||||
|
||||
def _save_session(self, session):
|
||||
# TODO: Persistent storage
|
||||
|
|
|
@ -51,7 +51,8 @@ class DeactivateAccountHandler(BaseHandler):
|
|||
erase_data (bool): whether to GDPR-erase the user's data
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
Deferred[bool]: True if identity server supports removing
|
||||
threepids, otherwise False.
|
||||
"""
|
||||
# FIXME: Theoretically there is a race here wherein user resets
|
||||
# password using threepid.
|
||||
|
@ -60,16 +61,22 @@ class DeactivateAccountHandler(BaseHandler):
|
|||
# leave the user still active so they can try again.
|
||||
# Ideally we would prevent password resets and then do this in the
|
||||
# background thread.
|
||||
|
||||
# This will be set to false if the identity server doesn't support
|
||||
# unbinding
|
||||
identity_server_supports_unbinding = True
|
||||
|
||||
threepids = yield self.store.user_get_threepids(user_id)
|
||||
for threepid in threepids:
|
||||
try:
|
||||
yield self._identity_handler.unbind_threepid(
|
||||
result = yield self._identity_handler.try_unbind_threepid(
|
||||
user_id,
|
||||
{
|
||||
'medium': threepid['medium'],
|
||||
'address': threepid['address'],
|
||||
},
|
||||
)
|
||||
identity_server_supports_unbinding &= result
|
||||
except Exception:
|
||||
# Do we want this to be a fatal error or should we carry on?
|
||||
logger.exception("Failed to remove threepid from ID server")
|
||||
|
@ -103,6 +110,8 @@ class DeactivateAccountHandler(BaseHandler):
|
|||
# parts users from rooms (if it isn't already running)
|
||||
self._start_user_parting()
|
||||
|
||||
defer.returnValue(identity_server_supports_unbinding)
|
||||
|
||||
def _start_user_parting(self):
|
||||
"""
|
||||
Start the process that goes through the table of users
|
||||
|
|
|
@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes
|
|||
from synapse.api.errors import FederationDeniedError
|
||||
from synapse.types import RoomStreamToken, get_domain_from_id
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
|
|
@ -49,10 +49,15 @@ from synapse.crypto.event_signing import (
|
|||
compute_event_signature,
|
||||
)
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.replication.http.federation import (
|
||||
ReplicationCleanRoomRestServlet,
|
||||
ReplicationFederationSendEventsRestServlet,
|
||||
)
|
||||
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
|
||||
from synapse.state import resolve_events_with_factory
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.distributor import user_joined_room
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
from synapse.util.logutils import log_function
|
||||
|
@ -91,6 +96,18 @@ class FederationHandler(BaseHandler):
|
|||
self.spam_checker = hs.get_spam_checker()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||
self.config = hs.config
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
|
||||
self._send_events_to_master = (
|
||||
ReplicationFederationSendEventsRestServlet.make_client(hs)
|
||||
)
|
||||
self._notify_user_membership_change = (
|
||||
ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
|
||||
)
|
||||
self._clean_room_for_join_client = (
|
||||
ReplicationCleanRoomRestServlet.make_client(hs)
|
||||
)
|
||||
|
||||
# When joining a room we need to queue any events for that room up
|
||||
self.room_queues = {}
|
||||
|
@ -1158,7 +1175,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
yield self._persist_events([(event, context)])
|
||||
yield self.persist_events_and_notify([(event, context)])
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
|
@ -1189,7 +1206,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
yield self._persist_events([(event, context)])
|
||||
yield self.persist_events_and_notify([(event, context)])
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
|
@ -1432,7 +1449,7 @@ class FederationHandler(BaseHandler):
|
|||
event, context
|
||||
)
|
||||
|
||||
yield self._persist_events(
|
||||
yield self.persist_events_and_notify(
|
||||
[(event, context)],
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
@ -1470,7 +1487,7 @@ class FederationHandler(BaseHandler):
|
|||
], consumeErrors=True,
|
||||
))
|
||||
|
||||
yield self._persist_events(
|
||||
yield self.persist_events_and_notify(
|
||||
[
|
||||
(ev_info["event"], context)
|
||||
for ev_info, context in zip(event_infos, contexts)
|
||||
|
@ -1558,7 +1575,7 @@ class FederationHandler(BaseHandler):
|
|||
raise
|
||||
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
yield self._persist_events(
|
||||
yield self.persist_events_and_notify(
|
||||
[
|
||||
(e, events_to_context[e.event_id])
|
||||
for e in itertools.chain(auth_events, state)
|
||||
|
@ -1569,7 +1586,7 @@ class FederationHandler(BaseHandler):
|
|||
event, old_state=state
|
||||
)
|
||||
|
||||
yield self._persist_events(
|
||||
yield self.persist_events_and_notify(
|
||||
[(event, new_event_context)],
|
||||
)
|
||||
|
||||
|
@ -2297,7 +2314,7 @@ class FederationHandler(BaseHandler):
|
|||
for revocation.
|
||||
"""
|
||||
try:
|
||||
response = yield self.hs.get_simple_http_client().get_json(
|
||||
response = yield self.http_client.get_json(
|
||||
url,
|
||||
{"public_key": public_key}
|
||||
)
|
||||
|
@ -2310,7 +2327,7 @@ class FederationHandler(BaseHandler):
|
|||
raise AuthError(403, "Third party certificate was invalid")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _persist_events(self, event_and_contexts, backfilled=False):
|
||||
def persist_events_and_notify(self, event_and_contexts, backfilled=False):
|
||||
"""Persists events and tells the notifier/pushers about them, if
|
||||
necessary.
|
||||
|
||||
|
@ -2322,14 +2339,21 @@ class FederationHandler(BaseHandler):
|
|||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
max_stream_id = yield self.store.persist_events(
|
||||
event_and_contexts,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
if self.config.worker_app:
|
||||
yield self._send_events_to_master(
|
||||
store=self.store,
|
||||
event_and_contexts=event_and_contexts,
|
||||
backfilled=backfilled
|
||||
)
|
||||
else:
|
||||
max_stream_id = yield self.store.persist_events(
|
||||
event_and_contexts,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
||||
if not backfilled: # Never notify for backfilled events
|
||||
for event, _ in event_and_contexts:
|
||||
self._notify_persisted_event(event, max_stream_id)
|
||||
if not backfilled: # Never notify for backfilled events
|
||||
for event, _ in event_and_contexts:
|
||||
self._notify_persisted_event(event, max_stream_id)
|
||||
|
||||
def _notify_persisted_event(self, event, max_stream_id):
|
||||
"""Checks to see if notifier/pushers should be notified about the
|
||||
|
@ -2368,9 +2392,25 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
def _clean_room_for_join(self, room_id):
|
||||
return self.store.clean_room_for_join(room_id)
|
||||
"""Called to clean up any data in DB for a given room, ready for the
|
||||
server to join the room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
"""
|
||||
if self.config.worker_app:
|
||||
return self._clean_room_for_join_client(room_id)
|
||||
else:
|
||||
return self.store.clean_room_for_join(room_id)
|
||||
|
||||
def user_joined_room(self, user, room_id):
|
||||
"""Called when a new user has joined the room
|
||||
"""
|
||||
return user_joined_room(self.distributor, user, room_id)
|
||||
if self.config.worker_app:
|
||||
return self._notify_user_membership_change(
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
change="joined",
|
||||
)
|
||||
else:
|
||||
return user_joined_room(self.distributor, user, room_id)
|
||||
|
|
|
@ -137,15 +137,19 @@ class IdentityHandler(BaseHandler):
|
|||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def unbind_threepid(self, mxid, threepid):
|
||||
"""
|
||||
Removes a binding from an identity server
|
||||
def try_unbind_threepid(self, mxid, threepid):
|
||||
"""Removes a binding from an identity server
|
||||
|
||||
Args:
|
||||
mxid (str): Matrix user ID of binding to be removed
|
||||
threepid (dict): Dict with medium & address of binding to be removed
|
||||
|
||||
Raises:
|
||||
SynapseError: If we failed to contact the identity server
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True on success, otherwise False
|
||||
Deferred[bool]: True on success, otherwise False if the identity
|
||||
server doesn't support unbinding
|
||||
"""
|
||||
logger.debug("unbinding threepid %r from %s", threepid, mxid)
|
||||
if not self.trusted_id_servers:
|
||||
|
@ -175,11 +179,21 @@ class IdentityHandler(BaseHandler):
|
|||
content=content,
|
||||
destination_is=id_server,
|
||||
)
|
||||
yield self.http_client.post_json_get_json(
|
||||
url,
|
||||
content,
|
||||
headers,
|
||||
)
|
||||
try:
|
||||
yield self.http_client.post_json_get_json(
|
||||
url,
|
||||
content,
|
||||
headers,
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
if e.code in (400, 404, 501,):
|
||||
# The remote server probably doesn't support unbinding (yet)
|
||||
logger.warn("Received %d response while unbinding threepid", e.code)
|
||||
defer.returnValue(False)
|
||||
else:
|
||||
logger.error("Failed to unbind threepid on identity server: %s", e)
|
||||
raise SynapseError(502, "Failed to contact identity server")
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
|
|||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import StreamToken, UserID
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
|
|
@ -32,7 +32,7 @@ from synapse.events.utils import serialize_event
|
|||
from synapse.events.validator import EventValidator
|
||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||
from synapse.types import RoomAlias, UserID
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util.logcontext import run_in_background
|
||||
from synapse.util.metrics import measure_func
|
||||
|
|
|
@ -22,7 +22,7 @@ from synapse.api.constants import Membership
|
|||
from synapse.api.errors import SynapseError
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.util.async import ReadWriteLock
|
||||
from synapse.util.async_helpers import ReadWriteLock
|
||||
from synapse.util.logcontext import run_in_background
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
|
|
@ -36,7 +36,7 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.metrics import LaterGauge
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.util.logcontext import run_in_background
|
||||
from synapse.util.logutils import log_function
|
||||
|
@ -95,6 +95,7 @@ class PresenceHandler(object):
|
|||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
self.hs = hs
|
||||
self.is_mine = hs.is_mine
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.clock = hs.get_clock()
|
||||
|
@ -230,6 +231,10 @@ class PresenceHandler(object):
|
|||
earlier than they should when synapse is restarted. This affect of this
|
||||
is some spurious presence changes that will self-correct.
|
||||
"""
|
||||
# If the DB pool has already terminated, don't try updating
|
||||
if not self.hs.get_db_pool().running:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Performing _on_shutdown. Persisting %d unpersisted changes",
|
||||
len(self.user_to_current_state)
|
||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from synapse.api.errors import (
|
|||
)
|
||||
from synapse.http.client import CaptchaServerHttpClient
|
||||
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler):
|
|||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
"""
|
||||
yield self._check_mau_limits()
|
||||
|
||||
yield self.auth.check_auth_blocking()
|
||||
password_hash = None
|
||||
if password:
|
||||
password_hash = yield self.auth_handler().hash(password)
|
||||
|
@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler):
|
|||
400,
|
||||
"User ID can only contain characters a-z, 0-9, or '=_-./'",
|
||||
)
|
||||
yield self._check_mau_limits()
|
||||
yield self.auth.check_auth_blocking()
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
|
@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler):
|
|||
"""
|
||||
if localpart is None:
|
||||
raise SynapseError(400, "Request must include user id")
|
||||
yield self._check_mau_limits()
|
||||
yield self.auth.check_auth_blocking()
|
||||
need_register = True
|
||||
|
||||
try:
|
||||
|
@ -533,14 +534,3 @@ class RegistrationHandler(BaseHandler):
|
|||
remote_room_hosts=remote_room_hosts,
|
||||
action="join",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_mau_limits(self):
|
||||
"""
|
||||
Do not accept registrations if monthly active user limits exceeded
|
||||
and limiting is enabled
|
||||
"""
|
||||
try:
|
||||
yield self.auth.check_auth_blocking()
|
||||
except AuthError as e:
|
||||
raise RegistrationError(e.code, str(e), e.errcode)
|
||||
|
|
|
@ -26,7 +26,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ import synapse.types
|
|||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||
from synapse.types import RoomID, UserID
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.distributor import user_joined_room, user_left_room
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -25,7 +25,7 @@ from twisted.internet import defer
|
|||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
|
|
@ -42,7 +42,7 @@ from twisted.web.http_headers import Headers
|
|||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.http import cancelled_to_request_timed_out_error, redact_uri
|
||||
from synapse.http.endpoint import SpiderEndpoint
|
||||
from synapse.util.async import add_timeout_to_deferred
|
||||
from synapse.util.async_helpers import add_timeout_to_deferred
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
|
|
|
@ -26,7 +26,6 @@ from twisted.names.error import DNSNameError, DomainError
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SERVER_CACHE = {}
|
||||
|
||||
# our record of an individual server which can be tried to reach a destination.
|
||||
|
@ -103,15 +102,16 @@ def parse_and_validate_server_name(server_name):
|
|||
return host, port
|
||||
|
||||
|
||||
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
|
||||
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
|
||||
timeout=None):
|
||||
"""Construct an endpoint for the given matrix destination.
|
||||
|
||||
Args:
|
||||
reactor: Twisted reactor.
|
||||
destination (bytes): The name of the server to connect to.
|
||||
ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory
|
||||
which generates SSL contexts to use for TLS.
|
||||
tls_client_options_factory
|
||||
(synapse.crypto.context_factory.ClientTLSOptionsFactory):
|
||||
Factory which generates TLS options for client connections.
|
||||
timeout (int): connection timeout in seconds
|
||||
"""
|
||||
|
||||
|
@ -122,13 +122,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
|
|||
if timeout is not None:
|
||||
endpoint_kw_args.update(timeout=timeout)
|
||||
|
||||
if ssl_context_factory is None:
|
||||
if tls_client_options_factory is None:
|
||||
transport_endpoint = HostnameEndpoint
|
||||
default_port = 8008
|
||||
else:
|
||||
def transport_endpoint(reactor, host, port, timeout):
|
||||
return wrapClientTLS(
|
||||
ssl_context_factory,
|
||||
tls_client_options_factory.get_options(host),
|
||||
HostnameEndpoint(reactor, host, port, timeout=timeout))
|
||||
default_port = 8448
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ from synapse.api.errors import (
|
|||
from synapse.http import cancelled_to_request_timed_out_error
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.util import logcontext
|
||||
from synapse.util.async import add_timeout_to_deferred
|
||||
from synapse.util.async_helpers import add_timeout_to_deferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -61,14 +61,14 @@ MAX_SHORT_RETRIES = 3
|
|||
|
||||
class MatrixFederationEndpointFactory(object):
|
||||
def __init__(self, hs):
|
||||
self.tls_server_context_factory = hs.tls_server_context_factory
|
||||
self.tls_client_options_factory = hs.tls_client_options_factory
|
||||
|
||||
def endpointForURI(self, uri):
|
||||
destination = uri.netloc
|
||||
|
||||
return matrix_federation_endpoint(
|
||||
reactor, destination, timeout=10,
|
||||
ssl_context_factory=self.tls_server_context_factory
|
||||
tls_client_options_factory=self.tls_client_options_factory
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from synapse.api.errors import AuthError
|
|||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.types import StreamToken
|
||||
from synapse.util.async import (
|
||||
from synapse.util.async_helpers import (
|
||||
DeferredTimeoutError,
|
||||
ObservableDeferred,
|
||||
add_timeout_to_deferred,
|
||||
|
|
|
@ -26,7 +26,7 @@ from twisted.internet import defer
|
|||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.event_auth import get_user_power_level
|
||||
from synapse.state import POWER_KEY
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import register_cache
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ from synapse.push.presentable_names import (
|
|||
name_from_member_event,
|
||||
)
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.replication.http import membership, send_event
|
||||
from synapse.replication.http import federation, membership, send_event
|
||||
|
||||
REPLICATION_PREFIX = "/_synapse/replication"
|
||||
|
||||
|
@ -27,3 +27,4 @@ class ReplicationRestResource(JsonResource):
|
|||
def register_servlets(self, hs):
|
||||
send_event.register_servlets(hs, self)
|
||||
membership.register_servlets(hs, self)
|
||||
federation.register_servlets(hs, self)
|
||||
|
|
259
synapse/replication/http/federation.py
Normal file
259
synapse/replication/http/federation.py
Normal file
|
@ -0,0 +1,259 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||
"""Handles events newly received from federation, including persisting and
|
||||
notifying.
|
||||
|
||||
The API looks like:
|
||||
|
||||
POST /_synapse/replication/fed_send_events/:txn_id
|
||||
|
||||
{
|
||||
"events": [{
|
||||
"event": { .. serialized event .. },
|
||||
"internal_metadata": { .. serialized internal_metadata .. },
|
||||
"rejected_reason": .., // The event.rejected_reason field
|
||||
"context": { .. serialized event context .. },
|
||||
}],
|
||||
"backfilled": false
|
||||
"""
|
||||
|
||||
NAME = "fed_send_events"
|
||||
PATH_ARGS = ()
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.federation_handler = hs.get_handlers().federation_handler
|
||||
|
||||
@staticmethod
|
||||
@defer.inlineCallbacks
|
||||
def _serialize_payload(store, event_and_contexts, backfilled):
|
||||
"""
|
||||
Args:
|
||||
store
|
||||
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
|
||||
backfilled (bool): Whether or not the events are the result of
|
||||
backfilling
|
||||
"""
|
||||
event_payloads = []
|
||||
for event, context in event_and_contexts:
|
||||
serialized_context = yield context.serialize(event, store)
|
||||
|
||||
event_payloads.append({
|
||||
"event": event.get_pdu_json(),
|
||||
"internal_metadata": event.internal_metadata.get_dict(),
|
||||
"rejected_reason": event.rejected_reason,
|
||||
"context": serialized_context,
|
||||
})
|
||||
|
||||
payload = {
|
||||
"events": event_payloads,
|
||||
"backfilled": backfilled,
|
||||
}
|
||||
|
||||
defer.returnValue(payload)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_request(self, request):
|
||||
with Measure(self.clock, "repl_fed_send_events_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
backfilled = content["backfilled"]
|
||||
|
||||
event_payloads = content["events"]
|
||||
|
||||
event_and_contexts = []
|
||||
for event_payload in event_payloads:
|
||||
event_dict = event_payload["event"]
|
||||
internal_metadata = event_payload["internal_metadata"]
|
||||
rejected_reason = event_payload["rejected_reason"]
|
||||
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
|
||||
|
||||
context = yield EventContext.deserialize(
|
||||
self.store, event_payload["context"],
|
||||
)
|
||||
|
||||
event_and_contexts.append((event, context))
|
||||
|
||||
logger.info(
|
||||
"Got %d events from federation",
|
||||
len(event_and_contexts),
|
||||
)
|
||||
|
||||
yield self.federation_handler.persist_events_and_notify(
|
||||
event_and_contexts, backfilled,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
||||
"""Handles EDUs newly received from federation, including persisting and
|
||||
notifying.
|
||||
|
||||
Request format:
|
||||
|
||||
POST /_synapse/replication/fed_send_edu/:edu_type/:txn_id
|
||||
|
||||
{
|
||||
"origin": ...,
|
||||
"content: { ... }
|
||||
}
|
||||
"""
|
||||
|
||||
NAME = "fed_send_edu"
|
||||
PATH_ARGS = ("edu_type",)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.registry = hs.get_federation_registry()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(edu_type, origin, content):
|
||||
return {
|
||||
"origin": origin,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_request(self, request, edu_type):
|
||||
with Measure(self.clock, "repl_fed_send_edu_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
origin = content["origin"]
|
||||
edu_content = content["content"]
|
||||
|
||||
logger.info(
|
||||
"Got %r edu from $s",
|
||||
edu_type, origin,
|
||||
)
|
||||
|
||||
result = yield self.registry.on_edu(edu_type, origin, edu_content)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
||||
"""Handle responding to queries from federation.
|
||||
|
||||
Request format:
|
||||
|
||||
POST /_synapse/replication/fed_query/:query_type
|
||||
|
||||
{
|
||||
"args": { ... }
|
||||
}
|
||||
"""
|
||||
|
||||
NAME = "fed_query"
|
||||
PATH_ARGS = ("query_type",)
|
||||
|
||||
# This is a query, so let's not bother caching
|
||||
CACHE = False
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ReplicationGetQueryRestServlet, self).__init__(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.registry = hs.get_federation_registry()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(query_type, args):
|
||||
"""
|
||||
Args:
|
||||
query_type (str)
|
||||
args (dict): The arguments received for the given query type
|
||||
"""
|
||||
return {
|
||||
"args": args,
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_request(self, request, query_type):
|
||||
with Measure(self.clock, "repl_fed_query_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
args = content["args"]
|
||||
|
||||
logger.info(
|
||||
"Got %r query",
|
||||
query_type,
|
||||
)
|
||||
|
||||
result = yield self.registry.on_query(query_type, args)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
|
||||
"""Called to clean up any data in DB for a given room, ready for the
|
||||
server to join the room.
|
||||
|
||||
Request format:
|
||||
|
||||
POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
|
||||
|
||||
{}
|
||||
"""
|
||||
|
||||
NAME = "fed_cleanup_room"
|
||||
PATH_ARGS = ("room_id",)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ReplicationCleanRoomRestServlet, self).__init__(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(room_id, args):
|
||||
"""
|
||||
Args:
|
||||
room_id (str)
|
||||
"""
|
||||
return {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_request(self, request, room_id):
|
||||
yield self.store.clean_room_for_join(room_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
|
||||
ReplicationFederationSendEduRestServlet(hs).register(http_server)
|
||||
ReplicationGetQueryRestServlet(hs).register(http_server)
|
||||
ReplicationCleanRoomRestServlet(hs).register(http_server)
|
|
@ -13,19 +13,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.transactions import TransactionStore
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
|
||||
|
||||
class TransactionStore(BaseSlavedStore):
|
||||
get_destination_retry_timings = TransactionStore.__dict__[
|
||||
"get_destination_retry_timings"
|
||||
]
|
||||
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
|
||||
set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
|
||||
_set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
|
||||
|
||||
prep_send_transaction = DataStore.prep_send_transaction.__func__
|
||||
delivered_txn = DataStore.delivered_txn.__func__
|
||||
class SlavedTransactionStore(TransactionStore, BaseSlavedStore):
|
||||
pass
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
to ensure idempotency when performing PUTs using the REST API."""
|
||||
import logging
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -391,10 +391,17 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
|||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
yield self._deactivate_account_handler.deactivate_account(
|
||||
result = yield self._deactivate_account_handler.deactivate_account(
|
||||
target_user_id, erase,
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
if result:
|
||||
id_server_unbind_result = "success"
|
||||
else:
|
||||
id_server_unbind_result = "no-support"
|
||||
|
||||
defer.returnValue((200, {
|
||||
"id_server_unbind_result": id_server_unbind_result,
|
||||
}))
|
||||
|
||||
|
||||
class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||
|
|
|
@ -209,10 +209,17 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||
yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
yield self._deactivate_account_handler.deactivate_account(
|
||||
result = yield self._deactivate_account_handler.deactivate_account(
|
||||
requester.user.to_string(), erase,
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
if result:
|
||||
id_server_unbind_result = "success"
|
||||
else:
|
||||
id_server_unbind_result = "no-support"
|
||||
|
||||
defer.returnValue((200, {
|
||||
"id_server_unbind_result": id_server_unbind_result,
|
||||
}))
|
||||
|
||||
|
||||
class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
|
@ -364,7 +371,7 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||
user_id = requester.user.to_string()
|
||||
|
||||
try:
|
||||
yield self.auth_handler.delete_threepid(
|
||||
ret = yield self.auth_handler.delete_threepid(
|
||||
user_id, body['medium'], body['address']
|
||||
)
|
||||
except Exception:
|
||||
|
@ -374,7 +381,14 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||
logger.exception("Failed to remove threepid")
|
||||
raise SynapseError(500, "Failed to remove threepid")
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
if ret:
|
||||
id_server_unbind_result = "success"
|
||||
else:
|
||||
id_server_unbind_result = "no-support"
|
||||
|
||||
defer.returnValue((200, {
|
||||
"id_server_unbind_result": id_server_unbind_result,
|
||||
}))
|
||||
|
||||
|
||||
class WhoamiRestServlet(RestServlet):
|
||||
|
|
|
@ -36,7 +36,7 @@ from synapse.api.errors import (
|
|||
)
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
from synapse.util.stringutils import is_ascii, random_string
|
||||
|
|
|
@ -42,7 +42,7 @@ from synapse.http.server import (
|
|||
)
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||
from synapse.util.stringutils import is_ascii, random_string
|
||||
|
|
|
@ -36,6 +36,7 @@ from synapse.federation.federation_client import FederationClient
|
|||
from synapse.federation.federation_server import (
|
||||
FederationHandlerRegistry,
|
||||
FederationServer,
|
||||
ReplicationFederationHandlerRegistry,
|
||||
)
|
||||
from synapse.federation.send_queue import FederationRemoteSendQueue
|
||||
from synapse.federation.transaction_queue import TransactionQueue
|
||||
|
@ -423,7 +424,10 @@ class HomeServer(object):
|
|||
return RoomMemberMasterHandler(self)
|
||||
|
||||
def build_federation_registry(self):
|
||||
return FederationHandlerRegistry()
|
||||
if self.config.worker_app:
|
||||
return ReplicationFederationHandlerRegistry(self)
|
||||
else:
|
||||
return FederationHandlerRegistry()
|
||||
|
||||
def build_server_notices_manager(self):
|
||||
if self.config.worker_app:
|
||||
|
|
|
@ -28,7 +28,7 @@ from synapse import event_auth
|
|||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||
<link rel="stylesheet" href="style.css">
|
||||
<script src="js/jquery-2.1.3.min.js"></script>
|
||||
<script src="js/recaptcha_ajax.js"></script>
|
||||
<script src="https://www.google.com/recaptcha/api/js/recaptcha_ajax.js"></script>
|
||||
<script src="register_config.js"></script>
|
||||
<script src="js/register.js"></script>
|
||||
</head>
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -96,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||
|
||||
def _update_client_ips_batch(self):
|
||||
|
||||
# If the DB pool has already terminated, don't try updating
|
||||
if not self.hs.get_db_pool().running:
|
||||
return
|
||||
|
||||
def update():
|
||||
to_update = self._batch_row_update
|
||||
self._batch_row_update = {}
|
||||
|
|
|
@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
|
|||
from synapse.storage.event_federation import EventFederationStore
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
from synapse.types import RoomStreamToken, get_domain_from_id
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
|
||||
|
@ -1435,88 +1435,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
|||
(event.event_id, event.redacts)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def have_events_in_timeline(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed and
|
||||
stored them as non outliers.
|
||||
"""
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
)
|
||||
|
||||
defer.returnValue(set(r["event_id"] for r in rows))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def have_seen_events(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed them.
|
||||
|
||||
Args:
|
||||
event_ids (iterable[str]):
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]: The events we have already seen.
|
||||
"""
|
||||
results = set()
|
||||
|
||||
def have_seen_events_txn(txn, chunk):
|
||||
sql = (
|
||||
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
|
||||
% (",".join("?" * len(chunk)), )
|
||||
)
|
||||
txn.execute(sql, chunk)
|
||||
for (event_id, ) in txn:
|
||||
results.add(event_id)
|
||||
|
||||
# break the input up into chunks of 100
|
||||
input_iterator = iter(event_ids)
|
||||
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
|
||||
[]):
|
||||
yield self.runInteraction(
|
||||
"have_seen_events",
|
||||
have_seen_events_txn,
|
||||
chunk,
|
||||
)
|
||||
defer.returnValue(results)
|
||||
|
||||
def get_seen_events_with_rejections(self, event_ids):
|
||||
"""Given a list of event ids, check if we rejected them.
|
||||
|
||||
Args:
|
||||
event_ids (list[str])
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, str|None):
|
||||
Has an entry for each event id we already have seen. Maps to
|
||||
the rejected reason string if we rejected the event, else maps
|
||||
to None.
|
||||
"""
|
||||
if not event_ids:
|
||||
return defer.succeed({})
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.event_id, reason FROM events as e "
|
||||
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
||||
"WHERE e.event_id = ?"
|
||||
)
|
||||
|
||||
res = {}
|
||||
for event_id in event_ids:
|
||||
txn.execute(sql, (event_id,))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
_, rejected = row
|
||||
res[event_id] = rejected
|
||||
|
||||
return res
|
||||
|
||||
return self.runInteraction("get_rejection_reasons", f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def count_daily_messages(self):
|
||||
"""
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
|
@ -442,3 +443,85 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
|
||||
|
||||
defer.returnValue(cache_entry)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def have_events_in_timeline(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed and
|
||||
stored them as non outliers.
|
||||
"""
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
)
|
||||
|
||||
defer.returnValue(set(r["event_id"] for r in rows))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def have_seen_events(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed them.
|
||||
|
||||
Args:
|
||||
event_ids (iterable[str]):
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]: The events we have already seen.
|
||||
"""
|
||||
results = set()
|
||||
|
||||
def have_seen_events_txn(txn, chunk):
|
||||
sql = (
|
||||
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
|
||||
% (",".join("?" * len(chunk)), )
|
||||
)
|
||||
txn.execute(sql, chunk)
|
||||
for (event_id, ) in txn:
|
||||
results.add(event_id)
|
||||
|
||||
# break the input up into chunks of 100
|
||||
input_iterator = iter(event_ids)
|
||||
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
|
||||
[]):
|
||||
yield self.runInteraction(
|
||||
"have_seen_events",
|
||||
have_seen_events_txn,
|
||||
chunk,
|
||||
)
|
||||
defer.returnValue(results)
|
||||
|
||||
def get_seen_events_with_rejections(self, event_ids):
|
||||
"""Given a list of event ids, check if we rejected them.
|
||||
|
||||
Args:
|
||||
event_ids (list[str])
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, str|None):
|
||||
Has an entry for each event id we already have seen. Maps to
|
||||
the rejected reason string if we rejected the event, else maps
|
||||
to None.
|
||||
"""
|
||||
if not event_ids:
|
||||
return defer.succeed({})
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.event_id, reason FROM events as e "
|
||||
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
||||
"WHERE e.event_id = ?"
|
||||
)
|
||||
|
||||
res = {}
|
||||
for event_id in event_ids:
|
||||
txn.execute(sql, (event_id,))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
_, rejected = row
|
||||
res[event_id] = rejected
|
||||
|
||||
return res
|
||||
|
||||
return self.runInteraction("get_rejection_reasons", f)
|
||||
|
|
|
@ -46,7 +46,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
|||
tp["medium"], tp["address"]
|
||||
)
|
||||
if user_id:
|
||||
self.upsert_monthly_active_user(user_id)
|
||||
yield self.upsert_monthly_active_user(user_id)
|
||||
reserved_user_list.append(user_id)
|
||||
else:
|
||||
logger.warning(
|
||||
|
@ -64,23 +64,27 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
|||
Deferred[]
|
||||
"""
|
||||
def _reap_users(txn):
|
||||
# Purge stale users
|
||||
|
||||
thirty_days_ago = (
|
||||
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||
)
|
||||
# Purge stale users
|
||||
|
||||
# questionmarks is a hack to overcome sqlite not supporting
|
||||
# tuples in 'WHERE IN %s'
|
||||
questionmarks = '?' * len(self.reserved_users)
|
||||
query_args = [thirty_days_ago]
|
||||
query_args.extend(self.reserved_users)
|
||||
base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
|
||||
|
||||
sql = """
|
||||
DELETE FROM monthly_active_users
|
||||
WHERE timestamp < ?
|
||||
AND user_id NOT IN ({})
|
||||
""".format(','.join(questionmarks))
|
||||
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
|
||||
# when len(reserved_users) == 0. Works fine on sqlite.
|
||||
if len(self.reserved_users) > 0:
|
||||
# questionmarks is a hack to overcome sqlite not supporting
|
||||
# tuples in 'WHERE IN %s'
|
||||
questionmarks = '?' * len(self.reserved_users)
|
||||
|
||||
query_args.extend(self.reserved_users)
|
||||
sql = base_sql + """ AND user_id NOT IN ({})""".format(
|
||||
','.join(questionmarks)
|
||||
)
|
||||
else:
|
||||
sql = base_sql
|
||||
|
||||
txn.execute(sql, query_args)
|
||||
|
||||
|
@ -93,16 +97,24 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
|||
# negative LIMIT values. So there is no way to write it that both can
|
||||
# support
|
||||
query_args = [self.hs.config.max_mau_value]
|
||||
query_args.extend(self.reserved_users)
|
||||
sql = """
|
||||
|
||||
base_sql = """
|
||||
DELETE FROM monthly_active_users
|
||||
WHERE user_id NOT IN (
|
||||
SELECT user_id FROM monthly_active_users
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
)
|
||||
AND user_id NOT IN ({})
|
||||
""".format(','.join(questionmarks))
|
||||
"""
|
||||
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
|
||||
# when len(reserved_users) == 0. Works fine on sqlite.
|
||||
if len(self.reserved_users) > 0:
|
||||
query_args.extend(self.reserved_users)
|
||||
sql = base_sql + """ AND user_id NOT IN ({})""".format(
|
||||
','.join(questionmarks)
|
||||
)
|
||||
else:
|
||||
sql = base_sql
|
||||
txn.execute(sql, query_args)
|
||||
|
||||
yield self.runInteraction("reap_monthly_active_users", _reap_users)
|
||||
|
|
|
@ -41,6 +41,22 @@ RatelimitOverride = collections.namedtuple(
|
|||
|
||||
|
||||
class RoomWorkerStore(SQLBaseStore):
|
||||
def get_room(self, room_id):
|
||||
"""Retrieve a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The ID of the room to retrieve.
|
||||
Returns:
|
||||
A namedtuple containing the room information, or an empty list.
|
||||
"""
|
||||
return self._simple_select_one(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("room_id", "is_public", "creator"),
|
||||
desc="get_room",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
def get_public_room_ids(self):
|
||||
return self._simple_select_onecol(
|
||||
table="rooms",
|
||||
|
@ -215,22 +231,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
|||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||
raise StoreError(500, "Problem creating room.")
|
||||
|
||||
def get_room(self, room_id):
|
||||
"""Retrieve a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The ID of the room to retrieve.
|
||||
Returns:
|
||||
A namedtuple containing the room information, or an empty list.
|
||||
"""
|
||||
return self._simple_select_one(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("room_id", "is_public", "creator"),
|
||||
desc="get_room",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_room_is_public(self, room_id, is_public):
|
||||
def set_room_is_public_txn(txn, next_id):
|
||||
|
|
|
@ -26,7 +26,7 @@ from twisted.internet import defer
|
|||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
from synapse.util.stringutils import to_ascii
|
||||
|
|
|
@ -188,62 +188,30 @@ class Linearizer(object):
|
|||
# things blocked from executing.
|
||||
self.key_to_defer = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def queue(self, key):
|
||||
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
|
||||
# (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
|
||||
# propagated inside inlineCallbacks until Twisted 18.7)
|
||||
entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
|
||||
|
||||
# If the number of things executing is greater than the maximum
|
||||
# then add a deferred to the list of blocked items
|
||||
# When on of the things currently executing finishes it will callback
|
||||
# When one of the things currently executing finishes it will callback
|
||||
# this item so that it can continue executing.
|
||||
if entry[0] >= self.max_count:
|
||||
new_defer = defer.Deferred()
|
||||
entry[1][new_defer] = 1
|
||||
|
||||
logger.info(
|
||||
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
|
||||
)
|
||||
try:
|
||||
yield make_deferred_yieldable(new_defer)
|
||||
except Exception as e:
|
||||
if isinstance(e, CancelledError):
|
||||
logger.info(
|
||||
"Cancelling wait for linearizer lock %r for key %r",
|
||||
self.name, key,
|
||||
)
|
||||
else:
|
||||
logger.warn(
|
||||
"Unexpected exception waiting for linearizer lock %r for key %r",
|
||||
self.name, key,
|
||||
)
|
||||
|
||||
# we just have to take ourselves back out of the queue.
|
||||
del entry[1][new_defer]
|
||||
raise
|
||||
|
||||
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
||||
entry[0] += 1
|
||||
|
||||
# if the code holding the lock completes synchronously, then it
|
||||
# will recursively run the next claimant on the list. That can
|
||||
# relatively rapidly lead to stack exhaustion. This is essentially
|
||||
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
|
||||
#
|
||||
# In order to break the cycle, we add a cheeky sleep(0) here to
|
||||
# ensure that we fall back to the reactor between each iteration.
|
||||
#
|
||||
# (This needs to happen while we hold the lock, and the context manager's exit
|
||||
# code must be synchronous, so this is the only sensible place.)
|
||||
yield self._clock.sleep(0)
|
||||
|
||||
res = self._await_lock(key)
|
||||
else:
|
||||
logger.info(
|
||||
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
|
||||
)
|
||||
entry[0] += 1
|
||||
res = defer.succeed(None)
|
||||
|
||||
# once we successfully get the lock, we need to return a context manager which
|
||||
# will release the lock.
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
def _ctx_manager(_):
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
@ -264,7 +232,64 @@ class Linearizer(object):
|
|||
# map.
|
||||
del self.key_to_defer[key]
|
||||
|
||||
defer.returnValue(_ctx_manager())
|
||||
res.addCallback(_ctx_manager)
|
||||
return res
|
||||
|
||||
def _await_lock(self, key):
|
||||
"""Helper for queue: adds a deferred to the queue
|
||||
|
||||
Assumes that we've already checked that we've reached the limit of the number
|
||||
of lock-holders we allow. Creates a new deferred which is added to the list, and
|
||||
adds some management around cancellations.
|
||||
|
||||
Returns the deferred, which will callback once we have secured the lock.
|
||||
|
||||
"""
|
||||
entry = self.key_to_defer[key]
|
||||
|
||||
logger.info(
|
||||
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
|
||||
)
|
||||
|
||||
new_defer = make_deferred_yieldable(defer.Deferred())
|
||||
entry[1][new_defer] = 1
|
||||
|
||||
def cb(_r):
|
||||
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
||||
entry[0] += 1
|
||||
|
||||
# if the code holding the lock completes synchronously, then it
|
||||
# will recursively run the next claimant on the list. That can
|
||||
# relatively rapidly lead to stack exhaustion. This is essentially
|
||||
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
|
||||
#
|
||||
# In order to break the cycle, we add a cheeky sleep(0) here to
|
||||
# ensure that we fall back to the reactor between each iteration.
|
||||
#
|
||||
# (This needs to happen while we hold the lock, and the context manager's exit
|
||||
# code must be synchronous, so this is the only sensible place.)
|
||||
return self._clock.sleep(0)
|
||||
|
||||
def eb(e):
|
||||
logger.info("defer %r got err %r", new_defer, e)
|
||||
if isinstance(e, CancelledError):
|
||||
logger.info(
|
||||
"Cancelling wait for linearizer lock %r for key %r",
|
||||
self.name, key,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warn(
|
||||
"Unexpected exception waiting for linearizer lock %r for key %r",
|
||||
self.name, key,
|
||||
)
|
||||
|
||||
# we just have to take ourselves back out of the queue.
|
||||
del entry[1][new_defer]
|
||||
return e
|
||||
|
||||
new_defer.addCallbacks(cb, eb)
|
||||
return new_defer
|
||||
|
||||
|
||||
class ReadWriteLock(object):
|
|
@ -25,7 +25,7 @@ from six import itervalues, string_types
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.caches import get_cache_factor_for
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||
|
|
|
@ -16,7 +16,7 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.caches import register_cache
|
||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
|
||||
|
||||
class SnapshotCache(object):
|
||||
|
|
|
@ -526,7 +526,7 @@ _to_ignore = [
|
|||
"synapse.util.logcontext",
|
||||
"synapse.http.server",
|
||||
"synapse.storage._base",
|
||||
"synapse.util.async",
|
||||
"synapse.util.async_helpers",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -15,4 +15,7 @@
|
|||
|
||||
from twisted.trial import util
|
||||
|
||||
from tests import utils
|
||||
|
||||
util.DEFAULT_TIMEOUT_DURATION = 10
|
||||
utils.setupdb()
|
||||
|
|
|
@ -34,13 +34,12 @@ class TestHandlers(object):
|
|||
|
||||
|
||||
class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.state_handler = Mock()
|
||||
self.store = Mock()
|
||||
|
||||
self.hs = yield setup_test_homeserver(handlers=None)
|
||||
self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
|
||||
self.hs.get_datastore = Mock(return_value=self.store)
|
||||
self.hs.handlers = TestHandlers(self.hs)
|
||||
self.auth = Auth(self.hs)
|
||||
|
@ -53,11 +52,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_user_by_req_user_valid_token(self):
|
||||
user_info = {
|
||||
"name": self.test_user,
|
||||
"token_id": "ditto",
|
||||
"device_id": "device",
|
||||
}
|
||||
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
|
||||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||
|
||||
request = Mock(args={})
|
||||
|
@ -76,10 +71,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.failureResultOf(d, AuthError)
|
||||
|
||||
def test_get_user_by_req_user_missing_token(self):
|
||||
user_info = {
|
||||
"name": self.test_user,
|
||||
"token_id": "ditto",
|
||||
}
|
||||
user_info = {"name": self.test_user, "token_id": "ditto"}
|
||||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||
|
||||
request = Mock(args={})
|
||||
|
@ -90,8 +82,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_get_user_by_req_appservice_valid_token(self):
|
||||
app_service = Mock(
|
||||
token="foobar", url="a_url", sender=self.test_user,
|
||||
ip_range_whitelist=None,
|
||||
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||
)
|
||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||
|
@ -106,8 +97,11 @@ class AuthTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_get_user_by_req_appservice_valid_token_good_ip(self):
|
||||
from netaddr import IPSet
|
||||
|
||||
app_service = Mock(
|
||||
token="foobar", url="a_url", sender=self.test_user,
|
||||
token="foobar",
|
||||
url="a_url",
|
||||
sender=self.test_user,
|
||||
ip_range_whitelist=IPSet(["192.168/16"]),
|
||||
)
|
||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||
|
@ -122,8 +116,11 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
|
||||
from netaddr import IPSet
|
||||
|
||||
app_service = Mock(
|
||||
token="foobar", url="a_url", sender=self.test_user,
|
||||
token="foobar",
|
||||
url="a_url",
|
||||
sender=self.test_user,
|
||||
ip_range_whitelist=IPSet(["192.168/16"]),
|
||||
)
|
||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||
|
@ -160,8 +157,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
||||
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||
app_service = Mock(
|
||||
token="foobar", url="a_url", sender=self.test_user,
|
||||
ip_range_whitelist=None,
|
||||
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||
)
|
||||
app_service.is_interested_in_user = Mock(return_value=True)
|
||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||
|
@ -174,15 +170,13 @@ class AuthTestCase(unittest.TestCase):
|
|||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(
|
||||
requester.user.to_string(),
|
||||
masquerading_user_id.decode('utf8')
|
||||
requester.user.to_string(), masquerading_user_id.decode('utf8')
|
||||
)
|
||||
|
||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||
app_service = Mock(
|
||||
token="foobar", url="a_url", sender=self.test_user,
|
||||
ip_range_whitelist=None,
|
||||
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||
)
|
||||
app_service.is_interested_in_user = Mock(return_value=False)
|
||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||
|
@ -201,17 +195,15 @@ class AuthTestCase(unittest.TestCase):
|
|||
# TODO(danielwh): Remove this mock when we remove the
|
||||
# get_user_by_access_token fallback.
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
return_value={
|
||||
"name": "@baldrick:matrix.org",
|
||||
"device_id": "device",
|
||||
}
|
||||
return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
|
||||
)
|
||||
|
||||
user_id = "@baldrick:matrix.org"
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
key=self.hs.config.macaroon_secret_key,
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
|
@ -225,15 +217,14 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_guest_user_from_macaroon(self):
|
||||
self.store.get_user_by_id = Mock(return_value={
|
||||
"is_guest": True,
|
||||
})
|
||||
self.store.get_user_by_id = Mock(return_value={"is_guest": True})
|
||||
|
||||
user_id = "@baldrick:matrix.org"
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
key=self.hs.config.macaroon_secret_key,
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
|
@ -257,7 +248,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
key=self.hs.config.macaroon_secret_key,
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||
|
@ -277,7 +269,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
key=self.hs.config.macaroon_secret_key,
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
|
||||
|
@ -298,7 +291,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key + "wrong")
|
||||
key=self.hs.config.macaroon_secret_key + "wrong",
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||
|
@ -320,7 +314,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
key=self.hs.config.macaroon_secret_key,
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||
|
@ -347,7 +342,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
key=self.hs.config.macaroon_secret_key,
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||
|
@ -380,7 +376,8 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
key=self.hs.config.macaroon_secret_key,
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
|
@ -401,9 +398,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
token = yield self.hs.handlers.auth_handler.issue_access_token(
|
||||
USER_ID, "DEVICE"
|
||||
)
|
||||
self.store.add_access_token_to_user.assert_called_with(
|
||||
USER_ID, token, "DEVICE"
|
||||
)
|
||||
self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE")
|
||||
|
||||
def get_user(tok):
|
||||
if token != tok:
|
||||
|
@ -414,10 +409,9 @@ class AuthTestCase(unittest.TestCase):
|
|||
"token_id": 1234,
|
||||
"device_id": "DEVICE",
|
||||
}
|
||||
|
||||
self.store.get_user_by_access_token = get_user
|
||||
self.store.get_user_by_id = Mock(return_value={
|
||||
"is_guest": False,
|
||||
})
|
||||
self.store.get_user_by_id = Mock(return_value={"is_guest": False})
|
||||
|
||||
# check the token works
|
||||
request = Mock(args={})
|
||||
|
@ -461,8 +455,11 @@ class AuthTestCase(unittest.TestCase):
|
|||
return_value=defer.succeed(lots_of_users)
|
||||
)
|
||||
|
||||
with self.assertRaises(AuthError):
|
||||
with self.assertRaises(AuthError) as e:
|
||||
yield self.auth.check_auth_blocking()
|
||||
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||
self.assertEquals(e.exception.code, 403)
|
||||
|
||||
# Ensure does not throw an error
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
|
@ -476,5 +473,6 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||
with self.assertRaises(AuthError) as e:
|
||||
yield self.auth.check_auth_blocking()
|
||||
self.assertEquals(e.exception.errcode, Codes.HS_DISABLED)
|
||||
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||
self.assertEquals(e.exception.code, 403)
|
||||
|
|
|
@ -38,7 +38,6 @@ def MockEvent(**kwargs):
|
|||
|
||||
|
||||
class FilteringTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_federation_resource = MockHttpResource()
|
||||
|
@ -47,6 +46,7 @@ class FilteringTestCase(unittest.TestCase):
|
|||
self.mock_http_client.put_json = DeferredMockCallable()
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
handlers=None,
|
||||
http_client=self.mock_http_client,
|
||||
keyring=Mock(),
|
||||
|
@ -64,7 +64,7 @@ class FilteringTestCase(unittest.TestCase):
|
|||
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
|
||||
{"event_format": "other"},
|
||||
{"room": {"not_rooms": ["#foo:pik-test"]}},
|
||||
{"presence": {"senders": ["@bar;pik.test.com"]}}
|
||||
{"presence": {"senders": ["@bar;pik.test.com"]}},
|
||||
]
|
||||
for filter in invalid_filters:
|
||||
with self.assertRaises(SynapseError) as check_filter_error:
|
||||
|
@ -81,34 +81,34 @@ class FilteringTestCase(unittest.TestCase):
|
|||
"include_leave": False,
|
||||
"rooms": ["!dee:pik-test"],
|
||||
"not_rooms": ["!gee:pik-test"],
|
||||
"account_data": {"limit": 0, "types": ["*"]}
|
||||
"account_data": {"limit": 0, "types": ["*"]},
|
||||
}
|
||||
},
|
||||
{
|
||||
"room": {
|
||||
"state": {
|
||||
"types": ["m.room.*"],
|
||||
"not_rooms": ["!726s6s6q:example.com"]
|
||||
"not_rooms": ["!726s6s6q:example.com"],
|
||||
},
|
||||
"timeline": {
|
||||
"limit": 10,
|
||||
"types": ["m.room.message"],
|
||||
"not_rooms": ["!726s6s6q:example.com"],
|
||||
"not_senders": ["@spam:example.com"]
|
||||
"not_senders": ["@spam:example.com"],
|
||||
},
|
||||
"ephemeral": {
|
||||
"types": ["m.receipt", "m.typing"],
|
||||
"not_rooms": ["!726s6s6q:example.com"],
|
||||
"not_senders": ["@spam:example.com"]
|
||||
}
|
||||
"not_senders": ["@spam:example.com"],
|
||||
},
|
||||
},
|
||||
"presence": {
|
||||
"types": ["m.presence"],
|
||||
"not_senders": ["@alice:example.com"]
|
||||
"not_senders": ["@alice:example.com"],
|
||||
},
|
||||
"event_format": "client",
|
||||
"event_fields": ["type", "content", "sender"]
|
||||
}
|
||||
"event_fields": ["type", "content", "sender"],
|
||||
},
|
||||
]
|
||||
for filter in valid_filters:
|
||||
try:
|
||||
|
@ -121,229 +121,131 @@ class FilteringTestCase(unittest.TestCase):
|
|||
pass
|
||||
|
||||
def test_definition_types_works_with_literals(self):
|
||||
definition = {
|
||||
"types": ["m.room.message", "org.matrix.foo.bar"]
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||
|
||||
self.assertTrue(
|
||||
Filter(definition).check(event)
|
||||
)
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
def test_definition_types_works_with_wildcards(self):
|
||||
definition = {
|
||||
"types": ["m.*", "org.matrix.foo.bar"]
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertTrue(
|
||||
Filter(definition).check(event)
|
||||
)
|
||||
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
|
||||
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
def test_definition_types_works_with_unknowns(self):
|
||||
definition = {
|
||||
"types": ["m.room.message", "org.matrix.foo.bar"]
|
||||
}
|
||||
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="now.for.something.completely.different",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
room_id="!foo:bar",
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_types_works_with_literals(self):
|
||||
definition = {
|
||||
"not_types": ["m.room.message", "org.matrix.foo.bar"]
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
)
|
||||
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_types_works_with_wildcards(self):
|
||||
definition = {
|
||||
"not_types": ["m.room.message", "org.matrix.*"]
|
||||
}
|
||||
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="org.matrix.custom.event",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_types_works_with_unknowns(self):
|
||||
definition = {
|
||||
"not_types": ["m.*", "org.*"]
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="com.nom.nom.nom",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertTrue(
|
||||
Filter(definition).check(event)
|
||||
)
|
||||
definition = {"not_types": ["m.*", "org.*"]}
|
||||
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_types_takes_priority_over_types(self):
|
||||
definition = {
|
||||
"not_types": ["m.*", "org.*"],
|
||||
"types": ["m.room.message", "m.room.topic"]
|
||||
"types": ["m.room.message", "m.room.topic"],
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.topic",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
)
|
||||
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_senders_works_with_literals(self):
|
||||
definition = {
|
||||
"senders": ["@flibble:wibble"]
|
||||
}
|
||||
definition = {"senders": ["@flibble:wibble"]}
|
||||
event = MockEvent(
|
||||
sender="@flibble:wibble",
|
||||
type="com.nom.nom.nom",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertTrue(
|
||||
Filter(definition).check(event)
|
||||
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||
)
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
def test_definition_senders_works_with_unknowns(self):
|
||||
definition = {
|
||||
"senders": ["@flibble:wibble"]
|
||||
}
|
||||
definition = {"senders": ["@flibble:wibble"]}
|
||||
event = MockEvent(
|
||||
sender="@challenger:appears",
|
||||
type="com.nom.nom.nom",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_senders_works_with_literals(self):
|
||||
definition = {
|
||||
"not_senders": ["@flibble:wibble"]
|
||||
}
|
||||
definition = {"not_senders": ["@flibble:wibble"]}
|
||||
event = MockEvent(
|
||||
sender="@flibble:wibble",
|
||||
type="com.nom.nom.nom",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_senders_works_with_unknowns(self):
|
||||
definition = {
|
||||
"not_senders": ["@flibble:wibble"]
|
||||
}
|
||||
definition = {"not_senders": ["@flibble:wibble"]}
|
||||
event = MockEvent(
|
||||
sender="@challenger:appears",
|
||||
type="com.nom.nom.nom",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertTrue(
|
||||
Filter(definition).check(event)
|
||||
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||
)
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_senders_takes_priority_over_senders(self):
|
||||
definition = {
|
||||
"not_senders": ["@misspiggy:muppets"],
|
||||
"senders": ["@kermit:muppets", "@misspiggy:muppets"]
|
||||
"senders": ["@kermit:muppets", "@misspiggy:muppets"],
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@misspiggy:muppets",
|
||||
type="m.room.topic",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_rooms_works_with_literals(self):
|
||||
definition = {
|
||||
"rooms": ["!secretbase:unknown"]
|
||||
}
|
||||
definition = {"rooms": ["!secretbase:unknown"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown"
|
||||
)
|
||||
self.assertTrue(
|
||||
Filter(definition).check(event)
|
||||
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
||||
)
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
def test_definition_rooms_works_with_unknowns(self):
|
||||
definition = {
|
||||
"rooms": ["!secretbase:unknown"]
|
||||
}
|
||||
definition = {"rooms": ["!secretbase:unknown"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!anothersecretbase:unknown"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
room_id="!anothersecretbase:unknown",
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_rooms_works_with_literals(self):
|
||||
definition = {
|
||||
"not_rooms": ["!anothersecretbase:unknown"]
|
||||
}
|
||||
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!anothersecretbase:unknown"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
room_id="!anothersecretbase:unknown",
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_rooms_works_with_unknowns(self):
|
||||
definition = {
|
||||
"not_rooms": ["!secretbase:unknown"]
|
||||
}
|
||||
definition = {"not_rooms": ["!secretbase:unknown"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!anothersecretbase:unknown"
|
||||
)
|
||||
self.assertTrue(
|
||||
Filter(definition).check(event)
|
||||
room_id="!anothersecretbase:unknown",
|
||||
)
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
def test_definition_not_rooms_takes_priority_over_rooms(self):
|
||||
definition = {
|
||||
"not_rooms": ["!secretbase:unknown"],
|
||||
"rooms": ["!secretbase:unknown"]
|
||||
"rooms": ["!secretbase:unknown"],
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown"
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_combined_event(self):
|
||||
definition = {
|
||||
|
@ -352,16 +254,14 @@ class FilteringTestCase(unittest.TestCase):
|
|||
"rooms": ["!stage:unknown"],
|
||||
"not_rooms": ["!piggyshouse:muppets"],
|
||||
"types": ["m.room.message", "muppets.kermit.*"],
|
||||
"not_types": ["muppets.misspiggy.*"]
|
||||
"not_types": ["muppets.misspiggy.*"],
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@kermit:muppets", # yup
|
||||
type="m.room.message", # yup
|
||||
room_id="!stage:unknown" # yup
|
||||
)
|
||||
self.assertTrue(
|
||||
Filter(definition).check(event)
|
||||
room_id="!stage:unknown", # yup
|
||||
)
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
def test_definition_combined_event_bad_sender(self):
|
||||
definition = {
|
||||
|
@ -370,16 +270,14 @@ class FilteringTestCase(unittest.TestCase):
|
|||
"rooms": ["!stage:unknown"],
|
||||
"not_rooms": ["!piggyshouse:muppets"],
|
||||
"types": ["m.room.message", "muppets.kermit.*"],
|
||||
"not_types": ["muppets.misspiggy.*"]
|
||||
"not_types": ["muppets.misspiggy.*"],
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@misspiggy:muppets", # nope
|
||||
type="m.room.message", # yup
|
||||
room_id="!stage:unknown" # yup
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
room_id="!stage:unknown", # yup
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_combined_event_bad_room(self):
|
||||
definition = {
|
||||
|
@ -388,16 +286,14 @@ class FilteringTestCase(unittest.TestCase):
|
|||
"rooms": ["!stage:unknown"],
|
||||
"not_rooms": ["!piggyshouse:muppets"],
|
||||
"types": ["m.room.message", "muppets.kermit.*"],
|
||||
"not_types": ["muppets.misspiggy.*"]
|
||||
"not_types": ["muppets.misspiggy.*"],
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@kermit:muppets", # yup
|
||||
type="m.room.message", # yup
|
||||
room_id="!piggyshouse:muppets" # nope
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
room_id="!piggyshouse:muppets", # nope
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_definition_combined_event_bad_type(self):
|
||||
definition = {
|
||||
|
@ -406,37 +302,26 @@ class FilteringTestCase(unittest.TestCase):
|
|||
"rooms": ["!stage:unknown"],
|
||||
"not_rooms": ["!piggyshouse:muppets"],
|
||||
"types": ["m.room.message", "muppets.kermit.*"],
|
||||
"not_types": ["muppets.misspiggy.*"]
|
||||
"not_types": ["muppets.misspiggy.*"],
|
||||
}
|
||||
event = MockEvent(
|
||||
sender="@kermit:muppets", # yup
|
||||
type="muppets.misspiggy.kisses", # nope
|
||||
room_id="!stage:unknown" # yup
|
||||
)
|
||||
self.assertFalse(
|
||||
Filter(definition).check(event)
|
||||
room_id="!stage:unknown", # yup
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_filter_presence_match(self):
|
||||
user_filter_json = {
|
||||
"presence": {
|
||||
"types": ["m.*"]
|
||||
}
|
||||
}
|
||||
user_filter_json = {"presence": {"types": ["m.*"]}}
|
||||
filter_id = yield self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
user_filter=user_filter_json,
|
||||
)
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.profile",
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
)
|
||||
event = MockEvent(sender="@foo:bar", type="m.profile")
|
||||
events = [event]
|
||||
|
||||
user_filter = yield self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
filter_id=filter_id,
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
|
||||
results = user_filter.filter_presence(events=events)
|
||||
|
@ -444,15 +329,10 @@ class FilteringTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_filter_presence_no_match(self):
|
||||
user_filter_json = {
|
||||
"presence": {
|
||||
"types": ["m.*"]
|
||||
}
|
||||
}
|
||||
user_filter_json = {"presence": {"types": ["m.*"]}}
|
||||
|
||||
filter_id = yield self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart + "2",
|
||||
user_filter=user_filter_json,
|
||||
user_localpart=user_localpart + "2", user_filter=user_filter_json
|
||||
)
|
||||
event = MockEvent(
|
||||
event_id="$asdasd:localhost",
|
||||
|
@ -462,8 +342,7 @@ class FilteringTestCase(unittest.TestCase):
|
|||
events = [event]
|
||||
|
||||
user_filter = yield self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart + "2",
|
||||
filter_id=filter_id,
|
||||
user_localpart=user_localpart + "2", filter_id=filter_id
|
||||
)
|
||||
|
||||
results = user_filter.filter_presence(events=events)
|
||||
|
@ -471,27 +350,15 @@ class FilteringTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_filter_room_state_match(self):
|
||||
user_filter_json = {
|
||||
"room": {
|
||||
"state": {
|
||||
"types": ["m.*"]
|
||||
}
|
||||
}
|
||||
}
|
||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||
filter_id = yield self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
user_filter=user_filter_json,
|
||||
)
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.topic",
|
||||
room_id="!foo:bar"
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
)
|
||||
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
||||
events = [event]
|
||||
|
||||
user_filter = yield self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
filter_id=filter_id,
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
|
||||
results = user_filter.filter_room_state(events=events)
|
||||
|
@ -499,27 +366,17 @@ class FilteringTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_filter_room_state_no_match(self):
|
||||
user_filter_json = {
|
||||
"room": {
|
||||
"state": {
|
||||
"types": ["m.*"]
|
||||
}
|
||||
}
|
||||
}
|
||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||
filter_id = yield self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
user_filter=user_filter_json,
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
)
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="org.matrix.custom.event",
|
||||
room_id="!foo:bar"
|
||||
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
|
||||
)
|
||||
events = [event]
|
||||
|
||||
user_filter = yield self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
filter_id=filter_id,
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
|
||||
results = user_filter.filter_room_state(events)
|
||||
|
@ -543,45 +400,32 @@ class FilteringTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_filter(self):
|
||||
user_filter_json = {
|
||||
"room": {
|
||||
"state": {
|
||||
"types": ["m.*"]
|
||||
}
|
||||
}
|
||||
}
|
||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||
|
||||
filter_id = yield self.filtering.add_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
user_filter=user_filter_json,
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
)
|
||||
|
||||
self.assertEquals(filter_id, 0)
|
||||
self.assertEquals(user_filter_json, (
|
||||
yield self.datastore.get_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
filter_id=0,
|
||||
)
|
||||
))
|
||||
self.assertEquals(
|
||||
user_filter_json,
|
||||
(
|
||||
yield self.datastore.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=0
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter(self):
|
||||
user_filter_json = {
|
||||
"room": {
|
||||
"state": {
|
||||
"types": ["m.*"]
|
||||
}
|
||||
}
|
||||
}
|
||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||
|
||||
filter_id = yield self.datastore.add_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
user_filter=user_filter_json,
|
||||
user_localpart=user_localpart, user_filter=user_filter_json
|
||||
)
|
||||
|
||||
filter = yield self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart,
|
||||
filter_id=filter_id,
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
|
||||
self.assertEquals(filter.get_filter_json(), user_filter_json)
|
||||
|
|
|
@ -4,17 +4,16 @@ from tests import unittest
|
|||
|
||||
|
||||
class TestRatelimiter(unittest.TestCase):
|
||||
|
||||
def test_allowed(self):
|
||||
limiter = Ratelimiter()
|
||||
allowed, time_allowed = limiter.send_message(
|
||||
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
|
||||
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(10., time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.send_message(
|
||||
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1,
|
||||
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1
|
||||
)
|
||||
self.assertFalse(allowed)
|
||||
self.assertEquals(10., time_allowed)
|
||||
|
@ -28,7 +27,7 @@ class TestRatelimiter(unittest.TestCase):
|
|||
def test_pruning(self):
|
||||
limiter = Ratelimiter()
|
||||
allowed, time_allowed = limiter.send_message(
|
||||
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
|
||||
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1
|
||||
)
|
||||
|
||||
self.assertIn("test_id_1", limiter.message_counts)
|
||||
|
|
|
@ -24,14 +24,10 @@ from tests import unittest
|
|||
|
||||
|
||||
def _regex(regex, exclusive=True):
|
||||
return {
|
||||
"regex": re.compile(regex),
|
||||
"exclusive": exclusive
|
||||
}
|
||||
return {"regex": re.compile(regex), "exclusive": exclusive}
|
||||
|
||||
|
||||
class ApplicationServiceTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.service = ApplicationService(
|
||||
id="unique_identifier",
|
||||
|
@ -41,8 +37,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
namespaces={
|
||||
ApplicationService.NS_USERS: [],
|
||||
ApplicationService.NS_ROOMS: [],
|
||||
ApplicationService.NS_ALIASES: []
|
||||
}
|
||||
ApplicationService.NS_ALIASES: [],
|
||||
},
|
||||
)
|
||||
self.event = Mock(
|
||||
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
||||
|
@ -52,25 +48,19 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_user_id_prefix_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_user_id_prefix_no_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.event.sender = "@someone_else:matrix.org"
|
||||
self.assertFalse((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_room_member_is_checked(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.event.sender = "@someone_else:matrix.org"
|
||||
self.event.type = "m.room.member"
|
||||
self.event.state_key = "@irc_foobar:matrix.org"
|
||||
|
@ -98,60 +88,47 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
_regex("#irc_.*:matrix.org")
|
||||
)
|
||||
self.store.get_aliases_for_room.return_value = [
|
||||
"#irc_foobar:matrix.org", "#athing:matrix.org"
|
||||
"#irc_foobar:matrix.org",
|
||||
"#athing:matrix.org",
|
||||
]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertTrue((yield self.service.is_interested(
|
||||
self.event, self.store
|
||||
)))
|
||||
self.assertTrue((yield self.service.is_interested(self.event, self.store)))
|
||||
|
||||
def test_non_exclusive_alias(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org", exclusive=False)
|
||||
)
|
||||
self.assertFalse(self.service.is_exclusive_alias(
|
||||
"#irc_foobar:matrix.org"
|
||||
))
|
||||
self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
|
||||
|
||||
def test_non_exclusive_room(self):
|
||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||
_regex("!irc_.*:matrix.org", exclusive=False)
|
||||
)
|
||||
self.assertFalse(self.service.is_exclusive_room(
|
||||
"!irc_foobar:matrix.org"
|
||||
))
|
||||
self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
|
||||
|
||||
def test_non_exclusive_user(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*:matrix.org", exclusive=False)
|
||||
)
|
||||
self.assertFalse(self.service.is_exclusive_user(
|
||||
"@irc_foobar:matrix.org"
|
||||
))
|
||||
self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
|
||||
|
||||
def test_exclusive_alias(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org", exclusive=True)
|
||||
)
|
||||
self.assertTrue(self.service.is_exclusive_alias(
|
||||
"#irc_foobar:matrix.org"
|
||||
))
|
||||
self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
|
||||
|
||||
def test_exclusive_user(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*:matrix.org", exclusive=True)
|
||||
)
|
||||
self.assertTrue(self.service.is_exclusive_user(
|
||||
"@irc_foobar:matrix.org"
|
||||
))
|
||||
self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
|
||||
|
||||
def test_exclusive_room(self):
|
||||
self.service.namespaces[ApplicationService.NS_ROOMS].append(
|
||||
_regex("!irc_.*:matrix.org", exclusive=True)
|
||||
)
|
||||
self.assertTrue(self.service.is_exclusive_room(
|
||||
"!irc_foobar:matrix.org"
|
||||
))
|
||||
self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_alias_no_match(self):
|
||||
|
@ -159,47 +136,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
_regex("#irc_.*:matrix.org")
|
||||
)
|
||||
self.store.get_aliases_for_room.return_value = [
|
||||
"#xmpp_foobar:matrix.org", "#athing:matrix.org"
|
||||
"#xmpp_foobar:matrix.org",
|
||||
"#athing:matrix.org",
|
||||
]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertFalse((yield self.service.is_interested(
|
||||
self.event, self.store
|
||||
)))
|
||||
self.assertFalse((yield self.service.is_interested(self.event, self.store)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_regex_multiple_matches(self):
|
||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||
_regex("#irc_.*:matrix.org")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.event.sender = "@irc_foobar:matrix.org"
|
||||
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
|
||||
self.store.get_users_in_room.return_value = []
|
||||
self.assertTrue((yield self.service.is_interested(
|
||||
self.event, self.store
|
||||
)))
|
||||
self.assertTrue((yield self.service.is_interested(self.event, self.store)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_interested_in_self(self):
|
||||
# make sure invites get through
|
||||
self.service.sender = "@appservice:name"
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.event.type = "m.room.member"
|
||||
self.event.content = {
|
||||
"membership": "invite"
|
||||
}
|
||||
self.event.content = {"membership": "invite"}
|
||||
self.event.state_key = self.service.sender
|
||||
self.assertTrue((yield self.service.is_interested(self.event)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_member_list_match(self):
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(
|
||||
_regex("@irc_.*")
|
||||
)
|
||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||
self.store.get_users_in_room.return_value = [
|
||||
"@alice:here",
|
||||
"@irc_fo:here", # AS user
|
||||
|
@ -208,6 +174,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
self.store.get_aliases_for_room.return_value = []
|
||||
|
||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||
self.assertTrue((yield self.service.is_interested(
|
||||
event=self.event, store=self.store
|
||||
)))
|
||||
self.assertTrue(
|
||||
(yield self.service.is_interested(event=self.event, store=self.store))
|
||||
)
|
||||
|
|
|
@ -30,7 +30,6 @@ from ..utils import MockClock
|
|||
|
||||
|
||||
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.clock = MockClock()
|
||||
self.store = Mock()
|
||||
|
@ -38,8 +37,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||
self.recoverer = Mock()
|
||||
self.recoverer_fn = Mock(return_value=self.recoverer)
|
||||
self.txnctrl = _TransactionController(
|
||||
clock=self.clock, store=self.store, as_api=self.as_api,
|
||||
recoverer_fn=self.recoverer_fn
|
||||
clock=self.clock,
|
||||
store=self.store,
|
||||
as_api=self.as_api,
|
||||
recoverer_fn=self.recoverer_fn,
|
||||
)
|
||||
|
||||
def test_single_service_up_txn_sent(self):
|
||||
|
@ -54,9 +55,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||
return_value=defer.succeed(ApplicationServiceState.UP)
|
||||
)
|
||||
txn.send = Mock(return_value=defer.succeed(True))
|
||||
self.store.create_appservice_txn = Mock(
|
||||
return_value=defer.succeed(txn)
|
||||
)
|
||||
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
|
||||
|
||||
# actual call
|
||||
self.txnctrl.send(service, events)
|
||||
|
@ -77,9 +76,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||
self.store.get_appservice_state = Mock(
|
||||
return_value=defer.succeed(ApplicationServiceState.DOWN)
|
||||
)
|
||||
self.store.create_appservice_txn = Mock(
|
||||
return_value=defer.succeed(txn)
|
||||
)
|
||||
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
|
||||
|
||||
# actual call
|
||||
self.txnctrl.send(service, events)
|
||||
|
@ -104,9 +101,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||
)
|
||||
self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
|
||||
txn.send = Mock(return_value=defer.succeed(False)) # fails to send
|
||||
self.store.create_appservice_txn = Mock(
|
||||
return_value=defer.succeed(txn)
|
||||
)
|
||||
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
|
||||
|
||||
# actual call
|
||||
self.txnctrl.send(service, events)
|
||||
|
@ -124,7 +119,6 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
|||
|
||||
|
||||
class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.clock = MockClock()
|
||||
self.as_api = Mock()
|
||||
|
@ -146,6 +140,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||
|
||||
def take_txn(*args, **kwargs):
|
||||
return defer.succeed(txns.pop(0))
|
||||
|
||||
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
|
||||
|
||||
self.recoverer.recover()
|
||||
|
@ -171,6 +166,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||
return defer.succeed(txns.pop(0))
|
||||
else:
|
||||
return defer.succeed(txn)
|
||||
|
||||
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
|
||||
|
||||
self.recoverer.recover()
|
||||
|
@ -197,7 +193,6 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
|||
|
||||
|
||||
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.txn_ctrl = Mock()
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
|
||||
|
@ -211,9 +206,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||
|
||||
def test_send_single_event_with_queue(self):
|
||||
d = defer.Deferred()
|
||||
self.txn_ctrl.send = Mock(
|
||||
side_effect=lambda x, y: make_deferred_yieldable(d),
|
||||
)
|
||||
self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d))
|
||||
service = Mock(id=4)
|
||||
event = Mock(event_id="first")
|
||||
event2 = Mock(event_id="second")
|
||||
|
@ -247,6 +240,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
|||
|
||||
def do_send(x, y):
|
||||
return make_deferred_yieldable(send_return_list.pop(0))
|
||||
|
||||
self.txn_ctrl.send = Mock(side_effect=do_send)
|
||||
|
||||
# send events for different ASes and make sure they are sent
|
||||
|
|
|
@ -24,7 +24,6 @@ from tests import unittest
|
|||
|
||||
|
||||
class ConfigGenerationTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.dir = tempfile.mkdtemp()
|
||||
self.file = os.path.join(self.dir, "homeserver.yaml")
|
||||
|
@ -33,23 +32,30 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
|||
shutil.rmtree(self.dir)
|
||||
|
||||
def test_generate_config_generates_files(self):
|
||||
HomeServerConfig.load_or_generate_config("", [
|
||||
"--generate-config",
|
||||
"-c", self.file,
|
||||
"--report-stats=yes",
|
||||
"-H", "lemurs.win"
|
||||
])
|
||||
HomeServerConfig.load_or_generate_config(
|
||||
"",
|
||||
[
|
||||
"--generate-config",
|
||||
"-c",
|
||||
self.file,
|
||||
"--report-stats=yes",
|
||||
"-H",
|
||||
"lemurs.win",
|
||||
],
|
||||
)
|
||||
|
||||
self.assertSetEqual(
|
||||
set([
|
||||
"homeserver.yaml",
|
||||
"lemurs.win.log.config",
|
||||
"lemurs.win.signing.key",
|
||||
"lemurs.win.tls.crt",
|
||||
"lemurs.win.tls.dh",
|
||||
"lemurs.win.tls.key",
|
||||
]),
|
||||
set(os.listdir(self.dir))
|
||||
set(
|
||||
[
|
||||
"homeserver.yaml",
|
||||
"lemurs.win.log.config",
|
||||
"lemurs.win.signing.key",
|
||||
"lemurs.win.tls.crt",
|
||||
"lemurs.win.tls.dh",
|
||||
"lemurs.win.tls.key",
|
||||
]
|
||||
),
|
||||
set(os.listdir(self.dir)),
|
||||
)
|
||||
|
||||
self.assert_log_filename_is(
|
||||
|
|
|
@ -24,7 +24,6 @@ from tests import unittest
|
|||
|
||||
|
||||
class ConfigLoadingTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.dir = tempfile.mkdtemp()
|
||||
print(self.dir)
|
||||
|
@ -43,15 +42,14 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
|||
def test_generates_and_loads_macaroon_secret_key(self):
|
||||
self.generate_config()
|
||||
|
||||
with open(self.file,
|
||||
"r") as f:
|
||||
with open(self.file, "r") as f:
|
||||
raw = yaml.load(f)
|
||||
self.assertIn("macaroon_secret_key", raw)
|
||||
|
||||
config = HomeServerConfig.load_config("", ["-c", self.file])
|
||||
self.assertTrue(
|
||||
hasattr(config, "macaroon_secret_key"),
|
||||
"Want config to have attr macaroon_secret_key"
|
||||
"Want config to have attr macaroon_secret_key",
|
||||
)
|
||||
if len(config.macaroon_secret_key) < 5:
|
||||
self.fail(
|
||||
|
@ -62,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
|||
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
|
||||
self.assertTrue(
|
||||
hasattr(config, "macaroon_secret_key"),
|
||||
"Want config to have attr macaroon_secret_key"
|
||||
"Want config to have attr macaroon_secret_key",
|
||||
)
|
||||
if len(config.macaroon_secret_key) < 5:
|
||||
self.fail(
|
||||
|
@ -80,10 +78,9 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
|||
|
||||
def test_disable_registration(self):
|
||||
self.generate_config()
|
||||
self.add_lines_to_config([
|
||||
"enable_registration: true",
|
||||
"disable_registration: true",
|
||||
])
|
||||
self.add_lines_to_config(
|
||||
["enable_registration: true", "disable_registration: true"]
|
||||
)
|
||||
# Check that disable_registration clobbers enable_registration.
|
||||
config = HomeServerConfig.load_config("", ["-c", self.file])
|
||||
self.assertFalse(config.enable_registration)
|
||||
|
@ -92,18 +89,23 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
|||
self.assertFalse(config.enable_registration)
|
||||
|
||||
# Check that either config value is clobbered by the command line.
|
||||
config = HomeServerConfig.load_or_generate_config("", [
|
||||
"-c", self.file, "--enable-registration"
|
||||
])
|
||||
config = HomeServerConfig.load_or_generate_config(
|
||||
"", ["-c", self.file, "--enable-registration"]
|
||||
)
|
||||
self.assertTrue(config.enable_registration)
|
||||
|
||||
def generate_config(self):
|
||||
HomeServerConfig.load_or_generate_config("", [
|
||||
"--generate-config",
|
||||
"-c", self.file,
|
||||
"--report-stats=yes",
|
||||
"-H", "lemurs.win"
|
||||
])
|
||||
HomeServerConfig.load_or_generate_config(
|
||||
"",
|
||||
[
|
||||
"--generate-config",
|
||||
"-c",
|
||||
self.file,
|
||||
"--report-stats=yes",
|
||||
"-H",
|
||||
"lemurs.win",
|
||||
],
|
||||
)
|
||||
|
||||
def generate_config_and_remove_lines_containing(self, needle):
|
||||
self.generate_config()
|
||||
|
|
|
@ -24,9 +24,7 @@ from tests import unittest
|
|||
|
||||
# Perform these tests using given secret key so we get entirely deterministic
|
||||
# signatures output that we can test against.
|
||||
SIGNING_KEY_SEED = decode_base64(
|
||||
"YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
|
||||
)
|
||||
SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
|
||||
|
||||
KEY_ALG = "ed25519"
|
||||
KEY_VER = 1
|
||||
|
@ -36,7 +34,6 @@ HOSTNAME = "domain"
|
|||
|
||||
|
||||
class EventSigningTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED)
|
||||
self.signing_key.alg = KEY_ALG
|
||||
|
@ -51,7 +48,7 @@ class EventSigningTestCase(unittest.TestCase):
|
|||
'signatures': {},
|
||||
'type': "X",
|
||||
'unsigned': {'age_ts': 1000000},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
|
||||
|
@ -61,8 +58,7 @@ class EventSigningTestCase(unittest.TestCase):
|
|||
self.assertTrue(hasattr(event, 'hashes'))
|
||||
self.assertIn('sha256', event.hashes)
|
||||
self.assertEquals(
|
||||
event.hashes['sha256'],
|
||||
"6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI",
|
||||
event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
|
||||
)
|
||||
|
||||
self.assertTrue(hasattr(event, 'signatures'))
|
||||
|
@ -77,9 +73,7 @@ class EventSigningTestCase(unittest.TestCase):
|
|||
def test_sign_message(self):
|
||||
builder = EventBuilder(
|
||||
{
|
||||
'content': {
|
||||
'body': "Here is the message content",
|
||||
},
|
||||
'content': {'body': "Here is the message content"},
|
||||
'event_id': "$0:domain",
|
||||
'origin': "domain",
|
||||
'origin_server_ts': 1000000,
|
||||
|
@ -98,8 +92,7 @@ class EventSigningTestCase(unittest.TestCase):
|
|||
self.assertTrue(hasattr(event, 'hashes'))
|
||||
self.assertIn('sha256', event.hashes)
|
||||
self.assertEquals(
|
||||
event.hashes['sha256'],
|
||||
"onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g",
|
||||
event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
|
||||
)
|
||||
|
||||
self.assertTrue(hasattr(event, 'signatures'))
|
||||
|
@ -108,5 +101,5 @@ class EventSigningTestCase(unittest.TestCase):
|
|||
self.assertEquals(
|
||||
event.signatures[HOSTNAME][KEY_NAME],
|
||||
"Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw"
|
||||
"u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA"
|
||||
"u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA",
|
||||
)
|
||||
|
|
|
@ -36,9 +36,7 @@ class MockPerspectiveServer(object):
|
|||
|
||||
def get_verify_keys(self):
|
||||
vk = signedjson.key.get_verify_key(self.key)
|
||||
return {
|
||||
"%s:%s" % (vk.alg, vk.version): vk,
|
||||
}
|
||||
return {"%s:%s" % (vk.alg, vk.version): vk}
|
||||
|
||||
def get_signed_key(self, server_name, verify_key):
|
||||
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
|
||||
|
@ -47,10 +45,8 @@ class MockPerspectiveServer(object):
|
|||
"old_verify_keys": {},
|
||||
"valid_until_ts": time.time() * 1000 + 3600,
|
||||
"verify_keys": {
|
||||
key_id: {
|
||||
"key": signedjson.key.encode_verify_key_base64(verify_key)
|
||||
}
|
||||
}
|
||||
key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
|
||||
},
|
||||
}
|
||||
signedjson.sign.sign_json(res, self.server_name, self.key)
|
||||
return res
|
||||
|
@ -62,18 +58,14 @@ class KeyringTestCase(unittest.TestCase):
|
|||
self.mock_perspective_server = MockPerspectiveServer()
|
||||
self.http_client = Mock()
|
||||
self.hs = yield utils.setup_test_homeserver(
|
||||
handlers=None,
|
||||
http_client=self.http_client,
|
||||
self.addCleanup, handlers=None, http_client=self.http_client
|
||||
)
|
||||
self.hs.config.perspectives = {
|
||||
self.mock_perspective_server.server_name:
|
||||
self.mock_perspective_server.get_verify_keys()
|
||||
}
|
||||
keys = self.mock_perspective_server.get_verify_keys()
|
||||
self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
|
||||
|
||||
def check_context(self, _, expected):
|
||||
self.assertEquals(
|
||||
getattr(LoggingContext.current_context(), "request", None),
|
||||
expected
|
||||
getattr(LoggingContext.current_context(), "request", None), expected
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -89,8 +81,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
context_one.request = "one"
|
||||
|
||||
wait_1_deferred = kr.wait_for_previous_lookups(
|
||||
["server1"],
|
||||
{"server1": lookup_1_deferred},
|
||||
["server1"], {"server1": lookup_1_deferred}
|
||||
)
|
||||
|
||||
# there were no previous lookups, so the deferred should be ready
|
||||
|
@ -105,8 +96,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
# set off another wait. It should block because the first lookup
|
||||
# hasn't yet completed.
|
||||
wait_2_deferred = kr.wait_for_previous_lookups(
|
||||
["server1"],
|
||||
{"server1": lookup_2_deferred},
|
||||
["server1"], {"server1": lookup_2_deferred}
|
||||
)
|
||||
self.assertFalse(wait_2_deferred.called)
|
||||
# ... so we should have reset the LoggingContext.
|
||||
|
@ -132,21 +122,19 @@ class KeyringTestCase(unittest.TestCase):
|
|||
persp_resp = {
|
||||
"server_keys": [
|
||||
self.mock_perspective_server.get_signed_key(
|
||||
"server10",
|
||||
signedjson.key.get_verify_key(key1)
|
||||
),
|
||||
"server10", signedjson.key.get_verify_key(key1)
|
||||
)
|
||||
]
|
||||
}
|
||||
persp_deferred = defer.Deferred()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_perspectives(**kwargs):
|
||||
self.assertEquals(
|
||||
LoggingContext.current_context().request, "11",
|
||||
)
|
||||
self.assertEquals(LoggingContext.current_context().request, "11")
|
||||
with logcontext.PreserveLoggingContext():
|
||||
yield persp_deferred
|
||||
defer.returnValue(persp_resp)
|
||||
|
||||
self.http_client.post_json.side_effect = get_perspectives
|
||||
|
||||
with LoggingContext("11") as context_11:
|
||||
|
@ -154,9 +142,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
|
||||
# start off a first set of lookups
|
||||
res_deferreds = kr.verify_json_objects_for_server(
|
||||
[("server10", json1),
|
||||
("server11", {})
|
||||
]
|
||||
[("server10", json1), ("server11", {})]
|
||||
)
|
||||
|
||||
# the unsigned json should be rejected pretty quickly
|
||||
|
@ -172,7 +158,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
|
||||
# wait a tick for it to send the request to the perspectives server
|
||||
# (it first tries the datastore)
|
||||
yield clock.sleep(1) # XXX find out why this takes so long!
|
||||
yield clock.sleep(1) # XXX find out why this takes so long!
|
||||
self.http_client.post_json.assert_called_once()
|
||||
|
||||
self.assertIs(LoggingContext.current_context(), context_11)
|
||||
|
@ -186,7 +172,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
self.http_client.post_json.return_value = defer.Deferred()
|
||||
|
||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||
[("server10", json1)],
|
||||
[("server10", json1)]
|
||||
)
|
||||
yield clock.sleep(1)
|
||||
self.http_client.post_json.assert_not_called()
|
||||
|
@ -207,8 +193,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||
|
||||
key1 = signedjson.key.generate_signing_key(1)
|
||||
yield self.hs.datastore.store_server_verify_key(
|
||||
"server9", "", time.time() * 1000,
|
||||
signedjson.key.get_verify_key(key1),
|
||||
"server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
|
||||
)
|
||||
json1 = {}
|
||||
signedjson.sign.sign_json(json1, "server9", key1)
|
||||
|
|
|
@ -31,25 +31,20 @@ def MockEvent(**kwargs):
|
|||
class PruneEventTestCase(unittest.TestCase):
|
||||
""" Asserts that a new event constructed with `evdict` will look like
|
||||
`matchdict` when it is redacted. """
|
||||
|
||||
def run_test(self, evdict, matchdict):
|
||||
self.assertEquals(
|
||||
prune_event(FrozenEvent(evdict)).get_dict(),
|
||||
matchdict
|
||||
)
|
||||
self.assertEquals(prune_event(FrozenEvent(evdict)).get_dict(), matchdict)
|
||||
|
||||
def test_minimal(self):
|
||||
self.run_test(
|
||||
{
|
||||
'type': 'A',
|
||||
'event_id': '$test:domain',
|
||||
},
|
||||
{'type': 'A', 'event_id': '$test:domain'},
|
||||
{
|
||||
'type': 'A',
|
||||
'event_id': '$test:domain',
|
||||
'content': {},
|
||||
'signatures': {},
|
||||
'unsigned': {},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def test_basic_keys(self):
|
||||
|
@ -70,23 +65,19 @@ class PruneEventTestCase(unittest.TestCase):
|
|||
'content': {},
|
||||
'signatures': {},
|
||||
'unsigned': {},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def test_unsigned_age_ts(self):
|
||||
self.run_test(
|
||||
{
|
||||
'type': 'B',
|
||||
'event_id': '$test:domain',
|
||||
'unsigned': {'age_ts': 20},
|
||||
},
|
||||
{'type': 'B', 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}},
|
||||
{
|
||||
'type': 'B',
|
||||
'event_id': '$test:domain',
|
||||
'content': {},
|
||||
'signatures': {},
|
||||
'unsigned': {'age_ts': 20},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self.run_test(
|
||||
|
@ -101,23 +92,19 @@ class PruneEventTestCase(unittest.TestCase):
|
|||
'content': {},
|
||||
'signatures': {},
|
||||
'unsigned': {},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def test_content(self):
|
||||
self.run_test(
|
||||
{
|
||||
'type': 'C',
|
||||
'event_id': '$test:domain',
|
||||
'content': {'things': 'here'},
|
||||
},
|
||||
{'type': 'C', 'event_id': '$test:domain', 'content': {'things': 'here'}},
|
||||
{
|
||||
'type': 'C',
|
||||
'event_id': '$test:domain',
|
||||
'content': {},
|
||||
'signatures': {},
|
||||
'unsigned': {},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self.run_test(
|
||||
|
@ -132,27 +119,20 @@ class PruneEventTestCase(unittest.TestCase):
|
|||
'content': {'creator': '@2:domain'},
|
||||
'signatures': {},
|
||||
'unsigned': {},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SerializeEventTestCase(unittest.TestCase):
|
||||
|
||||
def serialize(self, ev, fields):
|
||||
return serialize_event(ev, 1479807801915, only_event_fields=fields)
|
||||
|
||||
def test_event_fields_works_with_keys(self):
|
||||
self.assertEquals(
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar"
|
||||
),
|
||||
["room_id"]
|
||||
MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"]
|
||||
),
|
||||
{
|
||||
"room_id": "!foo:bar",
|
||||
}
|
||||
{"room_id": "!foo:bar"},
|
||||
)
|
||||
|
||||
def test_event_fields_works_with_nested_keys(self):
|
||||
|
@ -161,17 +141,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"body": "A message",
|
||||
},
|
||||
content={"body": "A message"},
|
||||
),
|
||||
["content.body"]
|
||||
["content.body"],
|
||||
),
|
||||
{
|
||||
"content": {
|
||||
"body": "A message",
|
||||
}
|
||||
}
|
||||
{"content": {"body": "A message"}},
|
||||
)
|
||||
|
||||
def test_event_fields_works_with_dot_keys(self):
|
||||
|
@ -180,17 +154,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"key.with.dots": {},
|
||||
},
|
||||
content={"key.with.dots": {}},
|
||||
),
|
||||
["content.key\.with\.dots"]
|
||||
["content.key\.with\.dots"],
|
||||
),
|
||||
{
|
||||
"content": {
|
||||
"key.with.dots": {},
|
||||
}
|
||||
}
|
||||
{"content": {"key.with.dots": {}}},
|
||||
)
|
||||
|
||||
def test_event_fields_works_with_nested_dot_keys(self):
|
||||
|
@ -201,21 +169,12 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||
room_id="!foo:bar",
|
||||
content={
|
||||
"not_me": 1,
|
||||
"nested.dot.key": {
|
||||
"leaf.key": 42,
|
||||
"not_me_either": 1,
|
||||
},
|
||||
"nested.dot.key": {"leaf.key": 42, "not_me_either": 1},
|
||||
},
|
||||
),
|
||||
["content.nested\.dot\.key.leaf\.key"]
|
||||
["content.nested\.dot\.key.leaf\.key"],
|
||||
),
|
||||
{
|
||||
"content": {
|
||||
"nested.dot.key": {
|
||||
"leaf.key": 42,
|
||||
},
|
||||
}
|
||||
}
|
||||
{"content": {"nested.dot.key": {"leaf.key": 42}}},
|
||||
)
|
||||
|
||||
def test_event_fields_nops_with_unknown_keys(self):
|
||||
|
@ -224,17 +183,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": "bar",
|
||||
},
|
||||
content={"foo": "bar"},
|
||||
),
|
||||
["content.foo", "content.notexists"]
|
||||
["content.foo", "content.notexists"],
|
||||
),
|
||||
{
|
||||
"content": {
|
||||
"foo": "bar",
|
||||
}
|
||||
}
|
||||
{"content": {"foo": "bar"}},
|
||||
)
|
||||
|
||||
def test_event_fields_nops_with_non_dict_keys(self):
|
||||
|
@ -243,13 +196,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": ["I", "am", "an", "array"],
|
||||
},
|
||||
content={"foo": ["I", "am", "an", "array"]},
|
||||
),
|
||||
["content.foo.am"]
|
||||
["content.foo.am"],
|
||||
),
|
||||
{}
|
||||
{},
|
||||
)
|
||||
|
||||
def test_event_fields_nops_with_array_keys(self):
|
||||
|
@ -258,13 +209,11 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||
MockEvent(
|
||||
sender="@alice:localhost",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": ["I", "am", "an", "array"],
|
||||
},
|
||||
content={"foo": ["I", "am", "an", "array"]},
|
||||
),
|
||||
["content.foo.1"]
|
||||
["content.foo.1"],
|
||||
),
|
||||
{}
|
||||
{},
|
||||
)
|
||||
|
||||
def test_event_fields_all_fields_if_empty(self):
|
||||
|
@ -274,31 +223,21 @@ class SerializeEventTestCase(unittest.TestCase):
|
|||
type="foo",
|
||||
event_id="test",
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": "bar",
|
||||
},
|
||||
content={"foo": "bar"},
|
||||
),
|
||||
[]
|
||||
[],
|
||||
),
|
||||
{
|
||||
"type": "foo",
|
||||
"event_id": "test",
|
||||
"room_id": "!foo:bar",
|
||||
"content": {
|
||||
"foo": "bar",
|
||||
},
|
||||
"unsigned": {}
|
||||
}
|
||||
"content": {"foo": "bar"},
|
||||
"unsigned": {},
|
||||
},
|
||||
)
|
||||
|
||||
def test_event_fields_fail_if_fields_not_str(self):
|
||||
with self.assertRaises(TypeError):
|
||||
self.serialize(
|
||||
MockEvent(
|
||||
room_id="!foo:bar",
|
||||
content={
|
||||
"foo": "bar",
|
||||
},
|
||||
),
|
||||
["room_id", 4]
|
||||
MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]
|
||||
)
|
||||
|
|
|
@ -23,10 +23,7 @@ from tests import unittest
|
|||
@unittest.DEBUG
|
||||
class ServerACLsTestCase(unittest.TestCase):
|
||||
def test_blacklisted_server(self):
|
||||
e = _create_acl_event({
|
||||
"allow": ["*"],
|
||||
"deny": ["evil.com"],
|
||||
})
|
||||
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
|
||||
logging.info("ACL event: %s", e.content)
|
||||
|
||||
self.assertFalse(server_matches_acl_event("evil.com", e))
|
||||
|
@ -36,10 +33,7 @@ class ServerACLsTestCase(unittest.TestCase):
|
|||
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
|
||||
|
||||
def test_block_ip_literals(self):
|
||||
e = _create_acl_event({
|
||||
"allow_ip_literals": False,
|
||||
"allow": ["*"],
|
||||
})
|
||||
e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
|
||||
logging.info("ACL event: %s", e.content)
|
||||
|
||||
self.assertFalse(server_matches_acl_event("1.2.3.4", e))
|
||||
|
@ -49,10 +43,12 @@ class ServerACLsTestCase(unittest.TestCase):
|
|||
|
||||
|
||||
def _create_acl_event(content):
|
||||
return FrozenEvent({
|
||||
"room_id": "!a:b",
|
||||
"event_id": "$a:b",
|
||||
"type": "m.room.server_acls",
|
||||
"sender": "@a:b",
|
||||
"content": content
|
||||
})
|
||||
return FrozenEvent(
|
||||
{
|
||||
"room_id": "!a:b",
|
||||
"event_id": "$a:b",
|
||||
"type": "m.room.server_acls",
|
||||
"sender": "@a:b",
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
|
|
@ -45,20 +45,18 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
services = [
|
||||
self._mkservice(is_interested=False),
|
||||
interested_service,
|
||||
self._mkservice(is_interested=False)
|
||||
self._mkservice(is_interested=False),
|
||||
]
|
||||
|
||||
self.mock_store.get_app_services = Mock(return_value=services)
|
||||
self.mock_store.get_user_by_id = Mock(return_value=[])
|
||||
|
||||
event = Mock(
|
||||
sender="@someone:anywhere",
|
||||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
|
||||
)
|
||||
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||
(0, [event]),
|
||||
(0, [])
|
||||
(0, []),
|
||||
]
|
||||
self.mock_as_api.push = Mock()
|
||||
yield self.handler.notify_interested_services(0)
|
||||
|
@ -74,21 +72,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
self.mock_store.get_app_services = Mock(return_value=services)
|
||||
self.mock_store.get_user_by_id = Mock(return_value=None)
|
||||
|
||||
event = Mock(
|
||||
sender=user_id,
|
||||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||
(0, [event]),
|
||||
(0, [])
|
||||
(0, []),
|
||||
]
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.mock_as_api.query_user.assert_called_once_with(
|
||||
services[0], user_id
|
||||
)
|
||||
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_query_user_exists_known_user(self):
|
||||
|
@ -96,25 +88,19 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
services = [self._mkservice(is_interested=True)]
|
||||
services[0].is_interested_in_user = Mock(return_value=True)
|
||||
self.mock_store.get_app_services = Mock(return_value=services)
|
||||
self.mock_store.get_user_by_id = Mock(return_value={
|
||||
"name": user_id
|
||||
})
|
||||
self.mock_store.get_user_by_id = Mock(return_value={"name": user_id})
|
||||
|
||||
event = Mock(
|
||||
sender=user_id,
|
||||
type="m.room.message",
|
||||
room_id="!foo:bar"
|
||||
)
|
||||
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
|
||||
self.mock_as_api.push = Mock()
|
||||
self.mock_as_api.query_user = Mock()
|
||||
self.mock_store.get_new_events_for_appservice.side_effect = [
|
||||
(0, [event]),
|
||||
(0, [])
|
||||
(0, []),
|
||||
]
|
||||
yield self.handler.notify_interested_services(0)
|
||||
self.assertFalse(
|
||||
self.mock_as_api.query_user.called,
|
||||
"query_user called when it shouldn't have been."
|
||||
"query_user called when it shouldn't have been.",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -129,7 +115,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
services = [
|
||||
self._mkservice_alias(is_interested_in_alias=False),
|
||||
interested_service,
|
||||
self._mkservice_alias(is_interested_in_alias=False)
|
||||
self._mkservice_alias(is_interested_in_alias=False),
|
||||
]
|
||||
|
||||
self.mock_store.get_app_services = Mock(return_value=services)
|
||||
|
@ -140,8 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
result = yield self.handler.query_room_alias_exists(room_alias)
|
||||
|
||||
self.mock_as_api.query_alias.assert_called_once_with(
|
||||
interested_service,
|
||||
room_alias_str
|
||||
interested_service, room_alias_str
|
||||
)
|
||||
self.assertEquals(result.room_id, room_id)
|
||||
self.assertEquals(result.servers, servers)
|
||||
|
|
|
@ -35,7 +35,7 @@ class AuthHandlers(object):
|
|||
class AuthTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hs = yield setup_test_homeserver(handlers=None)
|
||||
self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
|
||||
self.hs.handlers = AuthHandlers(self.hs)
|
||||
self.auth_handler = self.hs.handlers.auth_handler
|
||||
self.macaroon_generator = self.hs.get_macaroon_generator()
|
||||
|
@ -81,9 +81,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
def test_short_term_login_token_gives_user_id(self):
|
||||
self.hs.clock.now = 1000
|
||||
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"a_user", 5000
|
||||
)
|
||||
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
||||
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
token
|
||||
)
|
||||
|
@ -98,17 +96,13 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"a_user", 5000
|
||||
)
|
||||
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
|
||||
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
macaroon.serialize()
|
||||
)
|
||||
self.assertEqual(
|
||||
"a_user", user_id
|
||||
)
|
||||
self.assertEqual("a_user", user_id)
|
||||
|
||||
# add another "user_id" caveat, which might allow us to override the
|
||||
# user_id.
|
||||
|
@ -165,7 +159,5 @@ class AuthTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
def _get_macaroon(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"user_a", 5000
|
||||
)
|
||||
token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
|
||||
return pymacaroons.Macaroon.deserialize(token)
|
||||
|
|
|
@ -28,13 +28,13 @@ user2 = "@theresa:bbb"
|
|||
class DeviceTestCase(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DeviceTestCase, self).__init__(*args, **kwargs)
|
||||
self.store = None # type: synapse.storage.DataStore
|
||||
self.store = None # type: synapse.storage.DataStore
|
||||
self.handler = None # type: synapse.handlers.device.DeviceHandler
|
||||
self.clock = None # type: utils.MockClock
|
||||
self.clock = None # type: utils.MockClock
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
hs = yield utils.setup_test_homeserver()
|
||||
hs = yield utils.setup_test_homeserver(self.addCleanup)
|
||||
self.handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
|
@ -44,7 +44,7 @@ class DeviceTestCase(unittest.TestCase):
|
|||
res = yield self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
device_id="fco",
|
||||
initial_device_display_name="display name"
|
||||
initial_device_display_name="display name",
|
||||
)
|
||||
self.assertEqual(res, "fco")
|
||||
|
||||
|
@ -56,14 +56,14 @@ class DeviceTestCase(unittest.TestCase):
|
|||
res1 = yield self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
device_id="fco",
|
||||
initial_device_display_name="display name"
|
||||
initial_device_display_name="display name",
|
||||
)
|
||||
self.assertEqual(res1, "fco")
|
||||
|
||||
res2 = yield self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
device_id="fco",
|
||||
initial_device_display_name="new display name"
|
||||
initial_device_display_name="new display name",
|
||||
)
|
||||
self.assertEqual(res2, "fco")
|
||||
|
||||
|
@ -75,7 +75,7 @@ class DeviceTestCase(unittest.TestCase):
|
|||
device_id = yield self.handler.check_device_registered(
|
||||
user_id="@theresa:foo",
|
||||
device_id=None,
|
||||
initial_device_display_name="display"
|
||||
initial_device_display_name="display",
|
||||
)
|
||||
|
||||
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
|
||||
|
@ -87,43 +87,53 @@ class DeviceTestCase(unittest.TestCase):
|
|||
|
||||
res = yield self.handler.get_devices_by_user(user1)
|
||||
self.assertEqual(3, len(res))
|
||||
device_map = {
|
||||
d["device_id"]: d for d in res
|
||||
}
|
||||
self.assertDictContainsSubset({
|
||||
"user_id": user1,
|
||||
"device_id": "xyz",
|
||||
"display_name": "display 0",
|
||||
"last_seen_ip": None,
|
||||
"last_seen_ts": None,
|
||||
}, device_map["xyz"])
|
||||
self.assertDictContainsSubset({
|
||||
"user_id": user1,
|
||||
"device_id": "fco",
|
||||
"display_name": "display 1",
|
||||
"last_seen_ip": "ip1",
|
||||
"last_seen_ts": 1000000,
|
||||
}, device_map["fco"])
|
||||
self.assertDictContainsSubset({
|
||||
"user_id": user1,
|
||||
"device_id": "abc",
|
||||
"display_name": "display 2",
|
||||
"last_seen_ip": "ip3",
|
||||
"last_seen_ts": 3000000,
|
||||
}, device_map["abc"])
|
||||
device_map = {d["device_id"]: d for d in res}
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
"user_id": user1,
|
||||
"device_id": "xyz",
|
||||
"display_name": "display 0",
|
||||
"last_seen_ip": None,
|
||||
"last_seen_ts": None,
|
||||
},
|
||||
device_map["xyz"],
|
||||
)
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
"user_id": user1,
|
||||
"device_id": "fco",
|
||||
"display_name": "display 1",
|
||||
"last_seen_ip": "ip1",
|
||||
"last_seen_ts": 1000000,
|
||||
},
|
||||
device_map["fco"],
|
||||
)
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
"user_id": user1,
|
||||
"device_id": "abc",
|
||||
"display_name": "display 2",
|
||||
"last_seen_ip": "ip3",
|
||||
"last_seen_ts": 3000000,
|
||||
},
|
||||
device_map["abc"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_device(self):
|
||||
yield self._record_users()
|
||||
|
||||
res = yield self.handler.get_device(user1, "abc")
|
||||
self.assertDictContainsSubset({
|
||||
"user_id": user1,
|
||||
"device_id": "abc",
|
||||
"display_name": "display 2",
|
||||
"last_seen_ip": "ip3",
|
||||
"last_seen_ts": 3000000,
|
||||
}, res)
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
"user_id": user1,
|
||||
"device_id": "abc",
|
||||
"display_name": "display 2",
|
||||
"last_seen_ip": "ip3",
|
||||
"last_seen_ts": 3000000,
|
||||
},
|
||||
res,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_delete_device(self):
|
||||
|
@ -153,8 +163,7 @@ class DeviceTestCase(unittest.TestCase):
|
|||
def test_update_unknown_device(self):
|
||||
update = {"display_name": "new_display"}
|
||||
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||
yield self.handler.update_device("user_id", "unknown_device_id",
|
||||
update)
|
||||
yield self.handler.update_device("user_id", "unknown_device_id", update)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _record_users(self):
|
||||
|
@ -168,16 +177,17 @@ class DeviceTestCase(unittest.TestCase):
|
|||
yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _record_user(self, user_id, device_id, display_name,
|
||||
access_token=None, ip=None):
|
||||
def _record_user(
|
||||
self, user_id, device_id, display_name, access_token=None, ip=None
|
||||
):
|
||||
device_id = yield self.handler.check_device_registered(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
initial_device_display_name=display_name
|
||||
initial_device_display_name=display_name,
|
||||
)
|
||||
|
||||
if ip is not None:
|
||||
yield self.store.insert_client_ip(
|
||||
user_id,
|
||||
access_token, ip, "user_agent", device_id)
|
||||
user_id, access_token, ip, "user_agent", device_id
|
||||
)
|
||||
self.clock.advance_time(1000)
|
||||
|
|
|
@ -42,9 +42,11 @@ class DirectoryTestCase(unittest.TestCase):
|
|||
|
||||
def register_query_handler(query_type, handler):
|
||||
self.query_handlers[query_type] = handler
|
||||
|
||||
self.mock_registry.register_query_handler = register_query_handler
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
http_client=None,
|
||||
resource_for_federation=Mock(),
|
||||
federation_client=self.mock_federation,
|
||||
|
@ -68,10 +70,7 @@ class DirectoryTestCase(unittest.TestCase):
|
|||
|
||||
result = yield self.handler.get_association(self.my_room)
|
||||
|
||||
self.assertEquals({
|
||||
"room_id": "!8765qwer:test",
|
||||
"servers": ["test"],
|
||||
}, result)
|
||||
self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_remote_association(self):
|
||||
|
@ -81,16 +80,13 @@ class DirectoryTestCase(unittest.TestCase):
|
|||
|
||||
result = yield self.handler.get_association(self.remote_room)
|
||||
|
||||
self.assertEquals({
|
||||
"room_id": "!8765qwer:test",
|
||||
"servers": ["test", "remote"],
|
||||
}, result)
|
||||
self.assertEquals(
|
||||
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result
|
||||
)
|
||||
self.mock_federation.make_query.assert_called_with(
|
||||
destination="remote",
|
||||
query_type="directory",
|
||||
args={
|
||||
"room_alias": "#another:remote",
|
||||
},
|
||||
args={"room_alias": "#another:remote"},
|
||||
retry_on_dns_fail=False,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
@ -105,7 +101,4 @@ class DirectoryTestCase(unittest.TestCase):
|
|||
{"room_alias": "#your-room:test"}
|
||||
)
|
||||
|
||||
self.assertEquals({
|
||||
"room_id": "!8765asdf:test",
|
||||
"servers": ["test"],
|
||||
}, response)
|
||||
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
|
||||
|
|
|
@ -28,14 +28,13 @@ from tests import unittest, utils
|
|||
class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
|
||||
self.hs = None # type: synapse.server.HomeServer
|
||||
self.hs = None # type: synapse.server.HomeServer
|
||||
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hs = yield utils.setup_test_homeserver(
|
||||
handlers=None,
|
||||
federation_client=mock.Mock(),
|
||||
self.addCleanup, handlers=None, federation_client=mock.Mock()
|
||||
)
|
||||
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
|
||||
|
||||
|
@ -54,30 +53,21 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||
device_id = "xyz"
|
||||
keys = {
|
||||
"alg1:k1": "key1",
|
||||
"alg2:k2": {
|
||||
"key": "key2",
|
||||
"signatures": {"k1": "sig1"}
|
||||
},
|
||||
"alg2:k3": {
|
||||
"key": "key3",
|
||||
},
|
||||
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||
"alg2:k3": {"key": "key3"},
|
||||
}
|
||||
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys},
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
)
|
||||
self.assertDictEqual(res, {
|
||||
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
||||
})
|
||||
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
|
||||
|
||||
# we should be able to change the signature without a problem
|
||||
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys},
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
)
|
||||
self.assertDictEqual(res, {
|
||||
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
||||
})
|
||||
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_change_one_time_keys(self):
|
||||
|
@ -87,25 +77,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||
device_id = "xyz"
|
||||
keys = {
|
||||
"alg1:k1": "key1",
|
||||
"alg2:k2": {
|
||||
"key": "key2",
|
||||
"signatures": {"k1": "sig1"}
|
||||
},
|
||||
"alg2:k3": {
|
||||
"key": "key3",
|
||||
},
|
||||
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||
"alg2:k3": {"key": "key3"},
|
||||
}
|
||||
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys},
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
)
|
||||
self.assertDictEqual(res, {
|
||||
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
||||
})
|
||||
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
|
||||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}},
|
||||
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
|
||||
)
|
||||
self.fail("No error when changing string key")
|
||||
except errors.SynapseError:
|
||||
|
@ -113,7 +96,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}},
|
||||
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
|
||||
)
|
||||
self.fail("No error when replacing dict key with string")
|
||||
except errors.SynapseError:
|
||||
|
@ -121,9 +104,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {
|
||||
"one_time_keys": {"alg1:k1": {"key": "key"}}
|
||||
},
|
||||
local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
|
||||
)
|
||||
self.fail("No error when replacing string key with dict")
|
||||
except errors.SynapseError:
|
||||
|
@ -131,13 +112,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {
|
||||
local_user,
|
||||
device_id,
|
||||
{
|
||||
"one_time_keys": {
|
||||
"alg2:k2": {
|
||||
"key": "key3",
|
||||
"signatures": {"k1": "sig1"},
|
||||
}
|
||||
},
|
||||
"alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
|
||||
}
|
||||
},
|
||||
)
|
||||
self.fail("No error when replacing dict key")
|
||||
|
@ -148,31 +128,20 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||
def test_claim_one_time_key(self):
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys = {
|
||||
"alg1:k1": "key1",
|
||||
}
|
||||
keys = {"alg1:k1": "key1"}
|
||||
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys},
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
)
|
||||
self.assertDictEqual(res, {
|
||||
"one_time_key_counts": {"alg1": 1}
|
||||
})
|
||||
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
|
||||
|
||||
res2 = yield self.handler.claim_one_time_keys({
|
||||
"one_time_keys": {
|
||||
local_user: {
|
||||
device_id: "alg1"
|
||||
}
|
||||
}
|
||||
}, timeout=None)
|
||||
self.assertEqual(res2, {
|
||||
"failures": {},
|
||||
"one_time_keys": {
|
||||
local_user: {
|
||||
device_id: {
|
||||
"alg1:k1": "key1"
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
res2 = yield self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
)
|
||||
self.assertEqual(
|
||||
res2,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
|
||||
},
|
||||
)
|
||||
|
|
|
@ -39,8 +39,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
|
||||
prev_state = UserPresenceState.default(user_id)
|
||||
new_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=now,
|
||||
state=PresenceState.ONLINE, last_active_ts=now
|
||||
)
|
||||
|
||||
state, persist_and_notify, federation_ping = handle_update(
|
||||
|
@ -54,23 +53,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
self.assertEquals(state.last_federation_update_ts, now)
|
||||
|
||||
self.assertEquals(wheel_timer.insert.call_count, 3)
|
||||
wheel_timer.insert.assert_has_calls([
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + IDLE_TIMER
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
|
||||
),
|
||||
], any_order=True)
|
||||
wheel_timer.insert.assert_has_calls(
|
||||
[
|
||||
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
|
||||
),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
def test_online_to_online(self):
|
||||
wheel_timer = Mock()
|
||||
|
@ -79,14 +77,11 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
|
||||
prev_state = UserPresenceState.default(user_id)
|
||||
prev_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=now,
|
||||
currently_active=True,
|
||||
state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
|
||||
)
|
||||
|
||||
new_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=now,
|
||||
state=PresenceState.ONLINE, last_active_ts=now
|
||||
)
|
||||
|
||||
state, persist_and_notify, federation_ping = handle_update(
|
||||
|
@ -101,23 +96,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
self.assertEquals(state.last_federation_update_ts, now)
|
||||
|
||||
self.assertEquals(wheel_timer.insert.call_count, 3)
|
||||
wheel_timer.insert.assert_has_calls([
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + IDLE_TIMER
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
|
||||
),
|
||||
], any_order=True)
|
||||
wheel_timer.insert.assert_has_calls(
|
||||
[
|
||||
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
|
||||
),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
def test_online_to_online_last_active_noop(self):
|
||||
wheel_timer = Mock()
|
||||
|
@ -132,8 +126,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
new_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=now,
|
||||
state=PresenceState.ONLINE, last_active_ts=now
|
||||
)
|
||||
|
||||
state, persist_and_notify, federation_ping = handle_update(
|
||||
|
@ -148,23 +141,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
self.assertEquals(state.last_federation_update_ts, now)
|
||||
|
||||
self.assertEquals(wheel_timer.insert.call_count, 3)
|
||||
wheel_timer.insert.assert_has_calls([
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + IDLE_TIMER
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
|
||||
),
|
||||
], any_order=True)
|
||||
wheel_timer.insert.assert_has_calls(
|
||||
[
|
||||
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
|
||||
),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
def test_online_to_online_last_active(self):
|
||||
wheel_timer = Mock()
|
||||
|
@ -178,9 +170,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
currently_active=True,
|
||||
)
|
||||
|
||||
new_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
)
|
||||
new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
|
||||
|
||||
state, persist_and_notify, federation_ping = handle_update(
|
||||
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
||||
|
@ -193,18 +183,17 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
self.assertEquals(state.last_federation_update_ts, now)
|
||||
|
||||
self.assertEquals(wheel_timer.insert.call_count, 2)
|
||||
wheel_timer.insert.assert_has_calls([
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_active_ts + IDLE_TIMER
|
||||
),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
||||
)
|
||||
], any_order=True)
|
||||
wheel_timer.insert.assert_has_calls(
|
||||
[
|
||||
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||
),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
def test_remote_ping_timer(self):
|
||||
wheel_timer = Mock()
|
||||
|
@ -213,13 +202,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
|
||||
prev_state = UserPresenceState.default(user_id)
|
||||
prev_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=now,
|
||||
state=PresenceState.ONLINE, last_active_ts=now
|
||||
)
|
||||
|
||||
new_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
)
|
||||
new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
|
||||
|
||||
state, persist_and_notify, federation_ping = handle_update(
|
||||
prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now
|
||||
|
@ -232,13 +218,16 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
self.assertEquals(new_state.status_msg, state.status_msg)
|
||||
|
||||
self.assertEquals(wheel_timer.insert.call_count, 1)
|
||||
wheel_timer.insert.assert_has_calls([
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT
|
||||
),
|
||||
], any_order=True)
|
||||
wheel_timer.insert.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
|
||||
)
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
def test_online_to_offline(self):
|
||||
wheel_timer = Mock()
|
||||
|
@ -247,14 +236,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
|
||||
prev_state = UserPresenceState.default(user_id)
|
||||
prev_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=now,
|
||||
currently_active=True,
|
||||
state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
|
||||
)
|
||||
|
||||
new_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.OFFLINE,
|
||||
)
|
||||
new_state = prev_state.copy_and_replace(state=PresenceState.OFFLINE)
|
||||
|
||||
state, persist_and_notify, federation_ping = handle_update(
|
||||
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
||||
|
@ -273,14 +258,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
|
||||
prev_state = UserPresenceState.default(user_id)
|
||||
prev_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=now,
|
||||
currently_active=True,
|
||||
state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
|
||||
)
|
||||
|
||||
new_state = prev_state.copy_and_replace(
|
||||
state=PresenceState.UNAVAILABLE,
|
||||
)
|
||||
new_state = prev_state.copy_and_replace(state=PresenceState.UNAVAILABLE)
|
||||
|
||||
state, persist_and_notify, federation_ping = handle_update(
|
||||
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
||||
|
@ -293,13 +274,16 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||
self.assertEquals(new_state.status_msg, state.status_msg)
|
||||
|
||||
self.assertEquals(wheel_timer.insert.call_count, 1)
|
||||
wheel_timer.insert.assert_has_calls([
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
||||
)
|
||||
], any_order=True)
|
||||
wheel_timer.insert.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
now=now,
|
||||
obj=user_id,
|
||||
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||
)
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
|
||||
class PresenceTimeoutTestCase(unittest.TestCase):
|
||||
|
@ -314,9 +298,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
last_user_sync_ts=now,
|
||||
)
|
||||
|
||||
new_state = handle_timeout(
|
||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
||||
)
|
||||
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||
|
||||
self.assertIsNotNone(new_state)
|
||||
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
|
||||
|
@ -332,9 +314,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
|
||||
)
|
||||
|
||||
new_state = handle_timeout(
|
||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
||||
)
|
||||
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||
|
||||
self.assertIsNotNone(new_state)
|
||||
self.assertEquals(new_state.state, PresenceState.OFFLINE)
|
||||
|
@ -369,9 +349,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1,
|
||||
)
|
||||
|
||||
new_state = handle_timeout(
|
||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
||||
)
|
||||
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||
|
||||
self.assertIsNotNone(new_state)
|
||||
self.assertEquals(new_state, new_state)
|
||||
|
@ -388,9 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
last_federation_update_ts=now,
|
||||
)
|
||||
|
||||
new_state = handle_timeout(
|
||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
||||
)
|
||||
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||
|
||||
self.assertIsNone(new_state)
|
||||
|
||||
|
@ -425,9 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
last_federation_update_ts=now,
|
||||
)
|
||||
|
||||
new_state = handle_timeout(
|
||||
state, is_mine=True, syncing_user_ids=set(), now=now
|
||||
)
|
||||
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
|
||||
|
||||
self.assertIsNotNone(new_state)
|
||||
self.assertEquals(state, new_state)
|
||||
|
|
|
@ -48,15 +48,14 @@ class ProfileTestCase(unittest.TestCase):
|
|||
self.mock_registry.register_query_handler = register_query_handler
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
http_client=None,
|
||||
handlers=None,
|
||||
resource_for_federation=Mock(),
|
||||
federation_client=self.mock_federation,
|
||||
federation_server=Mock(),
|
||||
federation_registry=self.mock_registry,
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
])
|
||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||
)
|
||||
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
|
@ -74,9 +73,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_my_name(self):
|
||||
yield self.store.set_profile_displayname(
|
||||
self.frank.localpart, "Frank"
|
||||
)
|
||||
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
|
||||
displayname = yield self.handler.get_displayname(self.frank)
|
||||
|
||||
|
@ -85,22 +82,18 @@ class ProfileTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_set_my_name(self):
|
||||
yield self.handler.set_displayname(
|
||||
self.frank,
|
||||
synapse.types.create_requester(self.frank),
|
||||
"Frank Jr."
|
||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
(yield self.store.get_profile_displayname(self.frank.localpart)),
|
||||
"Frank Jr."
|
||||
"Frank Jr.",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_set_my_name_noauth(self):
|
||||
d = self.handler.set_displayname(
|
||||
self.frank,
|
||||
synapse.types.create_requester(self.bob),
|
||||
"Frank Jr."
|
||||
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
||||
)
|
||||
|
||||
yield self.assertFailure(d, AuthError)
|
||||
|
@ -145,11 +138,12 @@ class ProfileTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_set_my_avatar(self):
|
||||
yield self.handler.set_avatar_url(
|
||||
self.frank, synapse.types.create_requester(self.frank),
|
||||
"http://my.server/pic.gif"
|
||||
self.frank,
|
||||
synapse.types.create_requester(self.frank),
|
||||
"http://my.server/pic.gif",
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
|
||||
"http://my.server/pic.gif"
|
||||
"http://my.server/pic.gif",
|
||||
)
|
||||
|
|
|
@ -17,7 +17,7 @@ from mock import Mock
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import RegistrationError
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.handlers.register import RegistrationHandler
|
||||
from synapse.types import UserID, create_requester
|
||||
|
||||
|
@ -40,13 +40,15 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
self.mock_distributor.declare("registered_user")
|
||||
self.mock_captcha_client = Mock()
|
||||
self.hs = yield setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
handlers=None,
|
||||
http_client=None,
|
||||
expire_access_token=True,
|
||||
profile_handler=Mock(),
|
||||
)
|
||||
self.macaroon_generator = Mock(
|
||||
generate_access_token=Mock(return_value='secret'))
|
||||
generate_access_token=Mock(return_value='secret')
|
||||
)
|
||||
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||
self.handler = self.hs.get_handlers().registration_handler
|
||||
|
@ -62,7 +64,8 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
user_id = "@someone:test"
|
||||
requester = create_requester("@as:test")
|
||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||
requester, local_part, display_name)
|
||||
requester, local_part, display_name
|
||||
)
|
||||
self.assertEquals(result_user_id, user_id)
|
||||
self.assertEquals(result_token, 'secret')
|
||||
|
||||
|
@ -73,13 +76,15 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
yield store.register(
|
||||
user_id=frank.to_string(),
|
||||
token="jkv;g498752-43gj['eamb!-5",
|
||||
password_hash=None)
|
||||
password_hash=None,
|
||||
)
|
||||
local_part = "frank"
|
||||
display_name = "Frank"
|
||||
user_id = "@frank:test"
|
||||
requester = create_requester("@as:test")
|
||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||
requester, local_part, display_name)
|
||||
requester, local_part, display_name
|
||||
)
|
||||
self.assertEquals(result_user_id, user_id)
|
||||
self.assertEquals(result_token, 'secret')
|
||||
|
||||
|
@ -104,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.lots_of_users)
|
||||
)
|
||||
with self.assertRaises(RegistrationError):
|
||||
with self.assertRaises(AuthError):
|
||||
yield self.handler.get_or_create_user("requester", 'b', "display_name")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -113,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.lots_of_users)
|
||||
)
|
||||
with self.assertRaises(RegistrationError):
|
||||
with self.assertRaises(AuthError):
|
||||
yield self.handler.register(localpart="local_part")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -122,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.lots_of_users)
|
||||
)
|
||||
with self.assertRaises(RegistrationError):
|
||||
with self.assertRaises(AuthError):
|
||||
yield self.handler.register_saml2(localpart="local_part")
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue