0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-17 11:31:58 +01:00

merge proper fix to bug 2969

This commit is contained in:
Matthew Hodgson 2018-03-13 22:11:58 +00:00
commit 12350e3f9a
46 changed files with 578 additions and 338 deletions

View file

@ -32,3 +32,30 @@ specified by including an event_id in the URI, or by setting a
id is given, that event (and others at the same graph depth) will be retained.
If ``purge_up_to_ts`` is given, it should be a timestamp since the unix epoch,
in milliseconds.
The API starts the purge running, and returns immediately with a JSON body with
a purge id:
.. code:: json
{
"purge_id": "<opaque id>"
}
Purge status query
------------------
It is possible to poll for updates on recent purges with a second API;
``GET /_matrix/client/r0/admin/purge_history_status/<purge_id>``
(again, with a suitable ``access_token``). This API returns a JSON body like
the following:
.. code:: json
{
"status": "active"
}
The status will be one of ``active``, ``complete``, or ``failed``.

View file

@ -279,9 +279,9 @@ Obviously that option means that the operations done in
that might be fixed by setting a different logcontext via a ``with
LoggingContext(...)`` in ``background_operation``).
The second option is to use ``logcontext.preserve_fn``, which wraps a function
so that it doesn't reset the logcontext even when it returns an incomplete
deferred, and adds a callback to the returned deferred to reset the
The second option is to use ``logcontext.run_in_background``, which wraps a
function so that it doesn't reset the logcontext even when it returns an
incomplete deferred, and adds a callback to the returned deferred to reset the
logcontext. In other words, it turns a function that follows the Synapse rules
about logcontexts and Deferreds into one which behaves more like an external
function — the opposite operation to that described in the previous section.
@ -293,7 +293,7 @@ It can be used like this:
def do_request_handling():
yield foreground_operation()
logcontext.preserve_fn(background_operation)()
logcontext.run_in_background(background_operation)
# this will now be logged against the request context
logger.debug("Request handling complete")

View file

@ -156,7 +156,6 @@ def start(config_options):
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def start():

View file

@ -161,7 +161,6 @@ def start(config_options):
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def start():

View file

@ -144,7 +144,6 @@ def start(config_options):
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def start():

View file

@ -211,7 +211,6 @@ def start(config_options):
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def start():

View file

@ -348,7 +348,7 @@ def setup(config_options):
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache()
hs.get_federation_client().start_get_pdu_cache()
register_memory_metrics(hs)

View file

@ -158,7 +158,6 @@ def start(config_options):
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def start():

View file

@ -15,11 +15,3 @@
""" This package includes all the federation specific logic.
"""
from .replication import ReplicationLayer
def initialize_http_replication(hs):
transport = hs.get_federation_transport_client()
return ReplicationLayer(hs, transport)

View file

@ -27,7 +27,13 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
def __init__(self, hs):
self.hs = hs
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.spam_checker = hs.get_spam_checker()
self.store = hs.get_datastore()
self._clock = hs.get_clock()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,

View file

@ -58,6 +58,7 @@ class FederationClient(FederationBase):
self._clear_tried_cache, 60 * 1000,
)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""

View file

@ -17,12 +17,14 @@ import logging
import simplejson as json
from twisted.internet import defer
from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import (
FederationBase,
event_from_pdu_json,
)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
import synapse.metrics
from synapse.types import get_domain_from_id
@ -52,50 +54,19 @@ class FederationServer(FederationBase):
super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
self._server_linearizer = async.Linearizer("fed_server")
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
self.transaction_actions = TransactionActions(self.store)
self.registry = hs.get_federation_registry()
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
@ -229,16 +200,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def received_edu(self, origin, edu_type, content):
received_edus_counter.inc()
if edu_type in self.edu_handlers:
try:
yield self.edu_handlers[edu_type](origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
yield self.registry.on_edu(edu_type, origin, content)
@defer.inlineCallbacks
@log_function
@ -328,14 +290,8 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
received_queries_counter.inc(query_type)
if query_type in self.query_handlers:
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
resp = yield self.registry.on_query(query_type, args)
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id):
@ -607,3 +563,66 @@ class FederationServer(FederationBase):
origin, room_id, event_dict
)
defer.returnValue(ret)
class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
"""
def __init__(self):
self.edu_handlers = {}
self.query_handlers = {}
def register_edu_handler(self, edu_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
Args:
edu_type (str): The type of the incoming EDU to register handler for
handler (Callable[[str, dict]]): A callable invoked on incoming EDU
of the given type. The arguments are the origin server name and
the EDU contents.
"""
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (Callable[[dict], Deferred[dict]]): Invoked to handle
incoming queries of this type. The return will be yielded
on and the result used as the response to the query request.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
def on_edu(self, edu_type, origin, content):
handler = self.edu_handlers.get(edu_type)
if not handler:
logger.warn("No handler registered for EDU type %s", edu_type)
try:
yield handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
if not handler:
logger.warn("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)

View file

@ -1,73 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This layer is responsible for replicating with remote home servers using
a given transport.
"""
from .federation_client import FederationClient
from .federation_server import FederationServer
from .persistence import TransactionActions
import logging
logger = logging.getLogger(__name__)
class ReplicationLayer(FederationClient, FederationServer):
"""This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to
remote home servers.
The layer communicates with the rest of the server via a registered
ReplicationHandler.
In more detail, the layer:
* Receives incoming data and processes it into transactions and pdus.
* Fetches any PDUs it thinks it might have missed.
* Keeps the current state for contexts up to date by applying the
suitable conflict resolution.
* Sends outgoing pdus wrapped in transactions.
* Fills out the references to previous pdus/transactions appropriately
for outgoing data.
"""
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer
self.federation_client = self
self.store = hs.get_datastore()
self.handler = None
self.edu_handlers = {}
self.query_handlers = {}
self._clock = hs.get_clock()
self.transaction_actions = TransactionActions(self.store)
self.hs = hs
super(ReplicationLayer, self).__init__(hs)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name

View file

@ -1190,7 +1190,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass(
handler=hs.get_replication_layer(),
handler=hs.get_federation_server(),
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,

View file

@ -37,14 +37,15 @@ class DeviceHandler(BaseHandler):
self.state = hs.get_state_handler()
self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._edu_updater = DeviceListEduUpdater(hs, self)
self.federation.register_edu_handler(
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update,
)
self.federation.register_query_handler(
federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
@ -430,7 +431,7 @@ class DeviceListEduUpdater(object):
def __init__(self, hs, device_handler):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.device_handler = device_handler

View file

@ -37,7 +37,7 @@ class DeviceMessageHandler(object):
self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler(
hs.get_federation_registry().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)

View file

@ -36,8 +36,8 @@ class DirectoryHandler(BaseHandler):
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.federation = hs.get_replication_layer()
self.federation.register_query_handler(
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
"directory", self.on_directory_query
)

View file

@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
@ -40,7 +40,7 @@ class E2eKeysHandler(object):
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
self.federation.register_query_handler(
hs.get_federation_registry().register_query_handler(
"client_keys", self.on_federation_query_client_keys
)

View file

@ -68,7 +68,7 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer()
self.replication_layer = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
@ -78,8 +78,6 @@ class FederationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.replication_layer.set_handler(self)
# When joining a room we need to queue any events for that room up
self.room_queues = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")

View file

@ -13,7 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
@ -24,9 +25,10 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken,
)
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
from synapse.util.logcontext import preserve_fn
from synapse.util.logcontext import preserve_fn, run_in_background
from synapse.util.metrics import measure_func
from synapse.util.frozenutils import unfreeze
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
from synapse.replication.http.send_event import send_event_to_master
@ -41,6 +43,36 @@ import ujson
logger = logging.getLogger(__name__)
class PurgeStatus(object):
"""Object tracking the status of a purge request
This class contains information on the progress of a purge request, for
return by get_purge_status.
Attributes:
status (int): Tracks whether this request has completed. One of
STATUS_{ACTIVE,COMPLETE,FAILED}
"""
STATUS_ACTIVE = 0
STATUS_COMPLETE = 1
STATUS_FAILED = 2
STATUS_TEXT = {
STATUS_ACTIVE: "active",
STATUS_COMPLETE: "complete",
STATUS_FAILED: "failed",
}
def __init__(self):
self.status = PurgeStatus.STATUS_ACTIVE
def asdict(self):
return {
"status": PurgeStatus.STATUS_TEXT[self.status]
}
class MessageHandler(BaseHandler):
def __init__(self, hs):
@ -50,14 +82,87 @@ class MessageHandler(BaseHandler):
self.clock = hs.get_clock()
self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set()
# map from purge id to PurgeStatus
self._purges_by_id = {}
def start_purge_history(self, room_id, topological_ordering,
delete_local_events=False):
"""Start off a history purge on a room.
Args:
room_id (str): The room to purge from
topological_ordering (int): minimum topo ordering to preserve
delete_local_events (bool): True to delete local events as well as
remote ones
Returns:
str: unique ID for this purge transaction.
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
400,
"History purge already in progress for %s" % (room_id, ),
)
purge_id = random_string(16)
# we log the purge_id here so that it can be tied back to the
# request id in the log lines.
logger.info("[purge] starting purge_id %s", purge_id)
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
self._purge_history,
purge_id, room_id, topological_ordering, delete_local_events,
)
return purge_id
@defer.inlineCallbacks
def purge_history(self, room_id, topological_ordering,
delete_local_events=False):
with (yield self.pagination_lock.write(room_id)):
yield self.store.purge_history(
room_id, topological_ordering, delete_local_events,
)
def _purge_history(self, purge_id, room_id, topological_ordering,
delete_local_events):
"""Carry out a history purge on a room.
Args:
purge_id (str): The id for this purge
room_id (str): The room to purge from
topological_ordering (int): minimum topo ordering to preserve
delete_local_events (bool): True to delete local events as well as
remote ones
Returns:
Deferred
"""
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)):
yield self.store.purge_history(
room_id, topological_ordering, delete_local_events,
)
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
except Exception:
logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
self._purges_in_progress_by_room.discard(room_id)
# remove the purge from the list 24 hours after it completes
def clear_purge():
del self._purges_by_id[purge_id]
reactor.callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
"""Get the current status of an active purge
Args:
purge_id (str): purge_id returned by start_purge_history
Returns:
PurgeStatus|None
"""
return self._purges_by_id.get(purge_id)
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
@ -562,7 +667,7 @@ class EventCreationHandler(object):
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(str)): Any extra users to notify about event
extra_users (list(UserID)): Any extra users to notify about event
"""
try:

View file

@ -93,29 +93,30 @@ class PresenceHandler(object):
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
self.replication = hs.get_replication_layer()
self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler()
self.replication.register_edu_handler(
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.presence", self.incoming_presence
)
self.replication.register_edu_handler(
federation_registry.register_edu_handler(
"m.presence_invite",
lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
self.replication.register_edu_handler(
federation_registry.register_edu_handler(
"m.presence_accept",
lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
self.replication.register_edu_handler(
federation_registry.register_edu_handler(
"m.presence_deny",
lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]),

View file

@ -31,8 +31,8 @@ class ProfileHandler(BaseHandler):
def __init__(self, hs):
super(ProfileHandler, self).__init__(hs)
self.federation = hs.get_replication_layer()
self.federation.register_query_handler(
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
"profile", self.on_profile_query
)

View file

@ -35,7 +35,7 @@ class ReceiptsHandler(BaseHandler):
self.store = hs.get_datastore()
self.hs = hs
self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler(
hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()

View file

@ -446,16 +446,34 @@ class RegistrationHandler(BaseHandler):
return self.hs.get_auth_handler()
@defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id):
def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if
one doesn't already exist.
Args:
medium (str)
address (str)
inviter_user_id (str): The user ID who is trying to invite the
3PID
Returns:
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
3PID guest account.
"""
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token:
defer.returnValue(access_token)
user_info = yield self.auth.get_user_by_access_token(
access_token
)
_, access_token = yield self.register(
defer.returnValue((user_info["user"].to_string(), access_token))
user_id, access_token = yield self.register(
generate_token=True,
make_guest=True
)
access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id
)
defer.returnValue(access_token)
defer.returnValue((user_id, access_token))

View file

@ -409,7 +409,7 @@ class RoomListHandler(BaseHandler):
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None, include_all_networks=False,
third_party_instance_id=None,):
repl_layer = self.hs.get_replication_layer()
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(

View file

@ -55,7 +55,6 @@ class RoomMemberHandler(object):
self.registration_handler = hs.get_handlers().registration_handler
self.profile_handler = hs.get_profile_handler()
self.event_creation_hander = hs.get_event_creation_handler()
self.replication_layer = hs.get_replication_layer()
self.member_linearizer = Linearizer(name="member")
@ -138,7 +137,20 @@ class RoomMemberHandler(object):
defer.returnValue(event)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
def _remote_join(self, remote_room_hosts, room_id, user, content):
"""Try and join a room that this server is not in
Args:
remote_room_hosts (list[str]): List of servers that can be used
to join via.
room_id (str): Room that we are trying to join
user (UserID): User who is trying to join
content (dict): A dict that should be used as the content of the
join event.
Returns:
Deferred
"""
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
@ -154,6 +166,43 @@ class RoomMemberHandler(object):
)
yield user_joined_room(self.distributor, user, room_id)
@defer.inlineCallbacks
def _remote_reject_invite(self, remote_room_hosts, room_id, target):
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
Args:
remote_room_hosts (list[str]): List of servers to use to try and
reject invite
room_id (str)
target (UserID): The user rejecting the invite
Returns:
Deferred[dict]: A dictionary to be returned to the client, may
include event_id etc, or nothing if we locally rejected
"""
fed_handler = self.federation_handler
try:
ret = yield fed_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
target.to_string(),
)
defer.returnValue(ret)
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
@defer.inlineCallbacks
def update_membership(
self,
@ -212,7 +261,7 @@ class RoomMemberHandler(object):
# if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join.
if third_party_signed is not None:
yield self.replication_layer.exchange_third_party_invite(
yield self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
@ -292,7 +341,7 @@ class RoomMemberHandler(object):
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id)
inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@ -306,7 +355,7 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
ret = yield self.remote_join(
ret = yield self._remote_join(
remote_room_hosts, room_id, target, content
)
defer.returnValue(ret)
@ -314,7 +363,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id)
inviter = yield self._get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
@ -328,28 +377,10 @@ class RoomMemberHandler(object):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
fed_handler = self.federation_handler
try:
ret = yield fed_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
target.to_string(),
)
defer.returnValue(ret)
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
res = yield self._remote_reject_invite(
remote_room_hosts, room_id, target,
)
defer.returnValue(res)
res = yield self._local_membership_update(
requester=requester,
@ -496,7 +527,7 @@ class RoomMemberHandler(object):
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
def get_inviter(self, user_id, room_id):
def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id,
room_id=room_id,
@ -573,7 +604,7 @@ class RoomMemberHandler(object):
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
yield self.verify_any_signature(data, id_server)
yield self._verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
@ -581,7 +612,7 @@ class RoomMemberHandler(object):
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
@ -735,20 +766,16 @@ class RoomMemberHandler(object):
}
if self.config.invite_3pid_guest:
registration_handler = self.registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
rh = self.registration_handler
guest_user_id, guest_access_token = yield rh.get_or_register_3pid_guest(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.auth.get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
"guest_user_id": guest_user_id,
})
data = yield self.simple_http_client.post_urlencoded_get_json(

View file

@ -56,7 +56,7 @@ class TypingHandler(object):
self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -59,6 +60,11 @@ response_count = metrics.register_counter(
)
)
requests_counter = metrics.register_counter(
"requests_received",
labels=["method", "servlet", ],
)
outgoing_responses_counter = metrics.register_counter(
"responses",
labels=["method", "code"],
@ -145,7 +151,8 @@ def wrap_request_handler(request_handler, include_metrics=False):
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
request_metrics.start(self.clock, name=self.__class__.__name__)
servlet_name = self.__class__.__name__
request_metrics.start(self.clock, name=servlet_name)
request_context.request = request_id
with request.processing():
@ -154,6 +161,7 @@ def wrap_request_handler(request_handler, include_metrics=False):
if include_metrics:
yield request_handler(self, request, request_metrics)
else:
requests_counter.inc(request.method, servlet_name)
yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
@ -229,7 +237,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This implements the HttpServer interface and provides JSON support for
Resources.
Register callbacks via register_path()
Register callbacks via register_paths()
Callbacks can return a tuple of status code and a dict in which case the
the dict will automatically be sent to the client as a JSON object.
@ -276,49 +284,59 @@ class JsonResource(HttpServer, resource.Resource):
This checks if anyone has registered a callback for that method and
path.
"""
callback, group_dict = self._get_handler_for_request(request)
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
request_metrics.name = servlet_classname
requests_counter.inc(request.method, servlet_classname)
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in group_dict.items()
})
callback_return = yield callback(request, **kwargs)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
def _get_handler_for_request(self, request):
"""Finds a callback method to handle the given request
Args:
request (twisted.web.http.Request):
Returns:
Tuple[Callable, dict[str, str]]: callback method, and the dict
mapping keys to path components as specified in the handler's
path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
"""
if request.method == "OPTIONS":
self._send_response(request, 200, {})
return
return _options_handler, {}
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path)
if not m:
continue
# We found a match! First update the metrics object to indicate
# which servlet is handling the request.
callback = path_entry.callback
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items()
})
callback_return = yield callback(request, **kwargs)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
return
if m:
# We found a match!
return path_entry.callback, m.groupdict()
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
raise UnrecognizedRequestError()
return _unrecognised_request_handler, {}
def _send_response(self, request, code, response_json_object,
response_code_message=None):
@ -335,6 +353,34 @@ class JsonResource(HttpServer, resource.Resource):
)
def _options_handler(request):
"""Request handler for OPTIONS requests
This is a request handler suitable for return from
_get_handler_for_request. It returns a 200 and an empty body.
Args:
request (twisted.web.http.Request):
Returns:
Tuple[int, dict]: http code, response body.
"""
return 200, {}
def _unrecognised_request_handler(request):
"""Request handler for unrecognised requests
This is a request handler suitable for return from
_get_handler_for_request. It actually just raises an
UnrecognizedRequestError.
Args:
request (twisted.web.http.Request):
"""
raise UnrecognizedRequestError()
class RequestMetrics(object):
def start(self, clock, name):
self.start = clock.time_msec()

View file

@ -57,15 +57,31 @@ class Metrics(object):
return metric
def register_counter(self, *args, **kwargs):
"""
Returns:
CounterMetric
"""
return self._register(CounterMetric, *args, **kwargs)
def register_callback(self, *args, **kwargs):
"""
Returns:
CallbackMetric
"""
return self._register(CallbackMetric, *args, **kwargs)
def register_distribution(self, *args, **kwargs):
"""
Returns:
DistributionMetric
"""
return self._register(DistributionMetric, *args, **kwargs)
def register_cache(self, *args, **kwargs):
"""
Returns:
CacheMetric
"""
return self._register(CacheMetric, *args, **kwargs)

View file

@ -25,7 +25,7 @@ from synapse.util.async import sleep
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.metrics import Measure
from synapse.types import Requester
from synapse.types import Requester, UserID
import logging
import re
@ -46,7 +46,7 @@ def send_event_to_master(client, host, port, requester, event, context,
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(str)): Any extra users to notify about event
extra_users (list(UserID)): Any extra users to notify about event
"""
uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
host, port, event.event_id,
@ -59,7 +59,7 @@ def send_event_to_master(client, host, port, requester, event, context,
"context": context.serialize(event),
"requester": requester.serialize(),
"ratelimit": ratelimit,
"extra_users": extra_users,
"extra_users": [u.to_string() for u in extra_users],
}
try:
@ -143,7 +143,7 @@ class ReplicationSendEventRestServlet(RestServlet):
context = yield EventContext.deserialize(self.store, content["context"])
ratelimit = content["ratelimit"]
extra_users = content["extra_users"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
if requester.user:
request.authenticated_entity = requester.user.to_string()

View file

@ -17,7 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
from synapse.types import UserID, create_requester
from synapse.http.servlet import parse_json_object_from_request
@ -185,12 +185,43 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
errcode=Codes.BAD_JSON,
)
yield self.handlers.message_handler.purge_history(
purge_id = yield self.handlers.message_handler.start_purge_history(
room_id, depth,
delete_local_events=delete_local_events,
)
defer.returnValue((200, {}))
defer.returnValue((200, {
"purge_id": purge_id,
}))
class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/admin/purge_history_status/(?P<purge_id>[^/]+)"
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer)
"""
super(PurgeHistoryStatusRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, purge_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
purge_status = self.handlers.message_handler.get_purge_status(purge_id)
if purge_status is None:
raise NotFoundError("purge id '%s' not found" % purge_id)
defer.returnValue((200, purge_status.asdict()))
class DeactivateAccountRestServlet(ClientV1RestServlet):
@ -561,6 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server)

View file

@ -599,7 +599,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
"(?P<membership_action>join|invite|leave|ban|unban|kick)")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks

View file

@ -32,8 +32,10 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker
from synapse.federation import initialize_http_replication
from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import FederationServer
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.federation_server import FederationHandlerRegistry
from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers
@ -99,7 +101,8 @@ class HomeServer(object):
DEPENDENCIES = [
'http_client',
'db_pool',
'replication_layer',
'federation_client',
'federation_server',
'handlers',
'v1auth',
'auth',
@ -147,6 +150,7 @@ class HomeServer(object):
'groups_attestation_renewer',
'spam_checker',
'room_member_handler',
'federation_registry',
]
def __init__(self, hostname, **kwargs):
@ -195,8 +199,11 @@ class HomeServer(object):
def get_ratelimiter(self):
return self.ratelimiter
def build_replication_layer(self):
return initialize_http_replication(self)
def build_federation_client(self):
return FederationClient(self)
def build_federation_server(self):
return FederationServer(self)
def build_handlers(self):
return Handlers(self)
@ -387,6 +394,9 @@ class HomeServer(object):
def build_room_member_handler(self):
return RoomMemberHandler(self)
def build_federation_registry(self):
return FederationHandlerRegistry()
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View file

@ -283,10 +283,11 @@ class EventsStore(EventsWorkerStore):
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
def persisting_queue(item):
yield self._persist_events(
item.events_and_contexts,
backfilled=item.backfilled,
)
with Measure(self._clock, "persist_events"):
yield self._persist_events(
item.events_and_contexts,
backfilled=item.backfilled,
)
self._event_persist_queue.handle_queue(room_id, persisting_queue)

View file

@ -245,8 +245,11 @@ class StateGroupWorkerStore(SQLBaseStore):
if types:
clause_to_args = [
(
"AND type = ? AND state_key = ?" if state_key is not None else "AND type = ?",
(etype, state_key) if state_key is not None else (etype)
"AND type = ? AND state_key = ?",
(etype, state_key)
) if state_key is not None else (
"AND type = ?",
(etype,)
)
for etype, state_key in types
]
@ -277,22 +280,25 @@ class StateGroupWorkerStore(SQLBaseStore):
results[group][key] = event_id
else:
where_args = []
where_clauses = []
wildcard_types = False
if types is not None:
where_clause = "AND ("
for typ in types:
if typ[1] is None:
where_clause += "(type = ?) OR "
where_clauses.append("(type = ?)")
where_args.extend(typ[0])
wildcard_types = True
else:
where_clause += "(type = ? AND state_key = ?) OR "
where_clauses.append("(type = ? AND state_key = ?)")
where_args.extend([typ[0], typ[1]])
if include_other_types:
where_clause += "(%s) OR " % (
" AND ".join(["type <> ?"] * len(types)),
where_clauses.append(
"(" + " AND ".join(["type <> ?"] * len(types)) + ")"
)
where_args.extend(t for (t, _) in types)
where_clause += "0)" # 0 to terminate the last OR
where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else:
where_clause = ""
@ -322,9 +328,17 @@ class StateGroupWorkerStore(SQLBaseStore):
if (typ, state_key) not in results[group]
)
# If the lengths match then we must have all the types,
# so no need to go walk further down the tree.
if types is not None and len(results[group]) == len(types):
# If the number of entries in the (type,state_key)->event_id dict
# matches the number of (type,state_keys) types we were searching
# for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
types is not None and
not wildcard_types and
len(results[group]) == len(types)
):
break
next_group = self._simple_select_one_onecol_txn(

View file

@ -292,36 +292,41 @@ class PreserveLoggingContext(object):
def preserve_fn(f):
"""Wraps a function, to ensure that the current context is restored after
"""Function decorator which wraps the function with run_in_background"""
def g(*args, **kwargs):
return run_in_background(f, *args, **kwargs)
return g
def run_in_background(f, *args, **kwargs):
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the funtion completes.
Useful for wrapping functions that return a deferred which you don't yield
on.
"""
def g(*args, **kwargs):
current = LoggingContext.current_context()
res = f(*args, **kwargs)
if isinstance(res, defer.Deferred) and not res.called:
# The function will have reset the context before returning, so
# we need to restore it now.
LoggingContext.set_current_context(current)
current = LoggingContext.current_context()
res = f(*args, **kwargs)
if isinstance(res, defer.Deferred) and not res.called:
# The function will have reset the context before returning, so
# we need to restore it now.
LoggingContext.set_current_context(current)
# The original context will be restored when the deferred
# completes, but there is nothing waiting for it, so it will
# get leaked into the reactor or some other function which
# wasn't expecting it. We therefore need to reset the context
# here.
#
# (If this feels asymmetric, consider it this way: we are
# effectively forking a new thread of execution. We are
# probably currently within a ``with LoggingContext()`` block,
# which is supposed to have a single entry and exit point. But
# by spawning off another deferred, we are effectively
# adding a new exit point.)
res.addBoth(_set_context_cb, LoggingContext.sentinel)
return res
return g
# The original context will be restored when the deferred
# completes, but there is nothing waiting for it, so it will
# get leaked into the reactor or some other function which
# wasn't expecting it. We therefore need to reset the context
# here.
#
# (If this feels asymmetric, consider it this way: we are
# effectively forking a new thread of execution. We are
# probably currently within a ``with LoggingContext()`` block,
# which is supposed to have a single entry and exit point. But
# by spawning off another deferred, we are effectively
# adding a new exit point.)
res.addBoth(_set_context_cb, LoggingContext.sentinel)
return res
def make_deferred_yieldable(deferred):

View file

@ -35,21 +35,20 @@ class DirectoryTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
"register_edu_handler",
])
self.mock_federation = Mock()
self.mock_registry = Mock()
self.query_handlers = {}
def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler
self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver(
http_client=None,
resource_for_federation=Mock(),
replication_layer=self.mock_federation,
federation_client=self.mock_federation,
federation_registry=self.mock_registry,
)
hs.handlers = DirectoryHandlers(hs)

View file

@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
def setUp(self):
self.hs = yield utils.setup_test_homeserver(
handlers=None,
replication_layer=mock.Mock(),
federation_client=mock.Mock(),
)
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)

View file

@ -37,23 +37,23 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
"register_edu_handler",
])
self.mock_federation = Mock()
self.mock_registry = Mock()
self.query_handlers = {}
def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler
self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver(
http_client=None,
handlers=None,
resource_for_federation=Mock(),
replication_layer=self.mock_federation,
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
])

View file

@ -81,7 +81,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"get_current_state_deltas",
]),
state_handler=self.state_handler,
handlers=None,
handlers=Mock(),
notifier=mock_notifier,
resource_for_client=Mock(),
resource_for_federation=self.mock_federation_resource,

View file

@ -31,7 +31,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver(
"blue",
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),

View file

@ -114,7 +114,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
hs = yield setup_test_homeserver(
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),

View file

@ -45,7 +45,7 @@ class ProfileTestCase(unittest.TestCase):
http_client=None,
resource_for_client=self.mock_resource,
federation=Mock(),
replication_layer=Mock(),
federation_client=Mock(),
profile_handler=self.mock_handler
)

View file

@ -46,7 +46,7 @@ class RoomPermissionsTestCase(RestTestCase):
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.ratelimiter = hs.get_ratelimiter()
@ -409,7 +409,7 @@ class RoomsMemberListTestCase(RestTestCase):
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.ratelimiter = hs.get_ratelimiter()
@ -493,7 +493,7 @@ class RoomsCreateTestCase(RestTestCase):
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.ratelimiter = hs.get_ratelimiter()
@ -582,7 +582,7 @@ class RoomTopicTestCase(RestTestCase):
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.ratelimiter = hs.get_ratelimiter()
@ -697,7 +697,7 @@ class RoomMemberStateTestCase(RestTestCase):
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.ratelimiter = hs.get_ratelimiter()
@ -829,7 +829,7 @@ class RoomMessagesTestCase(RestTestCase):
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.ratelimiter = hs.get_ratelimiter()
@ -929,7 +929,7 @@ class RoomInitialSyncTestCase(RestTestCase):
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
@ -1003,7 +1003,7 @@ class RoomMessageListTestCase(RestTestCase):
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.ratelimiter = hs.get_ratelimiter()

View file

@ -47,7 +47,7 @@ class RoomTypingTestCase(RestTestCase):
"red",
clock=self.clock,
http_client=None,
replication_layer=Mock(),
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),

View file

@ -42,7 +42,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
federation_client=Mock(),
)
self.as_token = "token1"
@ -119,7 +119,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
federation_client=Mock(),
)
self.db_pool = hs.get_db_pool()
@ -455,7 +455,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
config=config,
datastore=Mock(),
federation_sender=Mock(),
replication_layer=Mock(),
federation_client=Mock(),
)
ApplicationServiceStore(None, hs)
@ -473,7 +473,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
config=config,
datastore=Mock(),
federation_sender=Mock(),
replication_layer=Mock(),
federation_client=Mock(),
)
with self.assertRaises(ConfigError) as cm:
@ -497,7 +497,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
config=config,
datastore=Mock(),
federation_sender=Mock(),
replication_layer=Mock(),
federation_client=Mock(),
)
with self.assertRaises(ConfigError) as cm: