Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2020-05-26 09:43:21 +01:00
commit 8beca8e21f
79 changed files with 1450 additions and 578 deletions

1
changelog.d/7453.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug that would cause Synapse not to resync out-of-sync device lists.

1
changelog.d/7517.feature Normal file
View file

@ -0,0 +1 @@
Add option to move event persistence off master.

1
changelog.d/7534.misc Normal file
View file

@ -0,0 +1 @@
All endpoints now respond with a 200 OK for `OPTIONS` requests.

1
changelog.d/7536.misc Normal file
View file

@ -0,0 +1 @@
Synapse now exports [detailed allocator statistics](https://doc.pypy.org/en/latest/gc_info.html#gc-get-stats) and basic GC timings as Prometheus metrics (`pypy_gc_time_seconds_total` and `pypy_memory_bytes`) when run under PyPy. Contributed by Ivan Shapovalov.

1
changelog.d/7542.misc Normal file
View file

@ -0,0 +1 @@
Add ability to wait for replication streams.

1
changelog.d/7546.misc Normal file
View file

@ -0,0 +1 @@
Optimise some references to `hs.config`.

1
changelog.d/7547.misc Normal file
View file

@ -0,0 +1 @@
On upgrade room only send canonical alias once.

1
changelog.d/7550.misc Normal file
View file

@ -0,0 +1 @@
Fix some indentation inconsistencies in the sample config.

1
changelog.d/7552.bugfix Normal file
View file

@ -0,0 +1 @@
Fix "Missing RelayState parameter" error when using user interactive authentication with SAML for some SAML providers.

1
changelog.d/7553.misc Normal file
View file

@ -0,0 +1 @@
Include `synapse.http.site` in type checking.

1
changelog.d/7554.misc Normal file
View file

@ -0,0 +1 @@
Fix some test code to not mangle stacktraces, to make it easier to debug errors.

1
changelog.d/7555.misc Normal file
View file

@ -0,0 +1 @@
Refresh apt cache when building dh_virtualenv docker image.

1
changelog.d/7556.misc Normal file
View file

@ -0,0 +1 @@
Stop logging some expected HTTP request errors as exceptions.

1
changelog.d/7557.misc Normal file
View file

@ -0,0 +1 @@
Convert sending mail to async/await.

1
changelog.d/7558.misc Normal file
View file

@ -0,0 +1 @@
Simplify `reap_monthly_active_users`.

1
changelog.d/7560.misc Normal file
View file

@ -0,0 +1 @@
All endpoints now respond with a 200 OK for `OPTIONS` requests.

View file

@ -31,8 +31,10 @@ RUN mkdir /dh-virtualenv
RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/matrix-org/dh-virtualenv/archive/matrixorg-20200519.tar.gz
RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
# install its build deps
RUN cd /dh-virtualenv \
# install its build deps. We do another apt-cache-update here, because we might
# be using a stale cache from docker build.
RUN apt-get update -qq -o Acquire::Languages=none \
&& cd /dh-virtualenv \
&& env DEBIAN_FRONTEND=noninteractive mk-build-deps -ri -t "apt-get -y --no-install-recommends"
# build it

View file

@ -322,22 +322,27 @@ listeners:
# Used by phonehome stats to group together related servers.
#server_context: context
# Resource-constrained homeserver Settings
# Resource-constrained homeserver settings
#
# If limit_remote_rooms.enabled is True, the room complexity will be
# checked before a user joins a new remote room. If it is above
# limit_remote_rooms.complexity, it will disallow joining or
# instantly leave.
# When this is enabled, the room "complexity" will be checked before a user
# joins a new remote room. If it is above the complexity limit, the server will
# disallow joining, or will instantly leave.
#
# limit_remote_rooms.complexity_error can be set to customise the text
# displayed to the user when a room above the complexity threshold has
# its join cancelled.
# Room complexity is an arbitrary measure based on factors such as the number of
# users in the room.
#
# Uncomment the below lines to enable:
#limit_remote_rooms:
# enabled: true
# complexity: 1.0
# complexity_error: "This room is too complex."
limit_remote_rooms:
# Uncomment to enable room complexity checking.
#
#enabled: true
# the limit above which rooms cannot be joined. The default is 1.0.
#
#complexity: 0.5
# override the error which is returned when the room is too complex.
#
#complexity_error: "This room is too complex."
# Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'.
@ -942,25 +947,28 @@ url_preview_accept_language:
## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.
# See docs/CAPTCHA_SETUP.md for full details of configuring this.
# This homeserver's ReCAPTCHA public key.
# This homeserver's ReCAPTCHA public key. Must be specified if
# enable_registration_captcha is enabled.
#
#recaptcha_public_key: "YOUR_PUBLIC_KEY"
# This homeserver's ReCAPTCHA private key.
# This homeserver's ReCAPTCHA private key. Must be specified if
# enable_registration_captcha is enabled.
#
#recaptcha_private_key: "YOUR_PRIVATE_KEY"
# Enables ReCaptcha checks when registering, preventing signup
# Uncomment to enable ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
# public/private key.
# public/private key. Defaults to 'false'.
#
#enable_registration_captcha: false
#enable_registration_captcha: true
# The API endpoint to use for verifying m.login.recaptcha responses.
# Defaults to "https://www.recaptcha.net/recaptcha/api/siteverify".
#
#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
#recaptcha_siteverify_api: "https://my.recaptcha.site"
## TURN ##
@ -1104,7 +1112,7 @@ account_validity:
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
# registration_shared_secret: <PRIVATE STRING>
#registration_shared_secret: <PRIVATE STRING>
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
@ -1237,7 +1245,8 @@ metrics_flags:
#known_servers: true
# Whether or not to report anonymized homeserver usage statistics.
# report_stats: true|false
#
#report_stats: true|false
# The endpoint to report the anonymized homeserver usage statistics to.
# Defaults to https://matrix.org/report-usage-stats/push
@ -1273,13 +1282,13 @@ metrics_flags:
# the registration_shared_secret is used, if one is given; otherwise,
# a secret key is derived from the signing key.
#
# macaroon_secret_key: <PRIVATE STRING>
#macaroon_secret_key: <PRIVATE STRING>
# a secret which is used to calculate HMACs for form values, to stop
# falsification of values. Must be specified for the User Consent
# forms to work.
#
# form_secret: <PRIVATE STRING>
#form_secret: <PRIVATE STRING>
## Signing Keys ##
@ -1764,8 +1773,8 @@ email:
# Username/password for authentication to the SMTP server. By default, no
# authentication is attempted.
#
# smtp_user: "exampleusername"
# smtp_pass: "examplepassword"
#smtp_user: "exampleusername"
#smtp_pass: "examplepassword"
# Uncomment the following to require TLS transport security for SMTP.
# By default, Synapse will connect over plain text, and will then switch to

View file

@ -196,6 +196,9 @@ class MockHomeserver:
def get_reactor(self):
return reactor
def get_instance_name(self):
return "master"
class Porter(object):
def __init__(self, **kwargs):

View file

@ -22,11 +22,10 @@ from typing import Dict, Iterable
from typing_extensions import ContextManager
from twisted.internet import defer, reactor
from twisted.web.resource import NoResource
import synapse
import synapse.events
from synapse.api.errors import SynapseError
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.api.urls import (
CLIENT_API_PREFIX,
FEDERATION_PREFIX,
@ -40,14 +39,22 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.federation import send_queue
from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.presence import BasePresenceHandler, get_interested_parties
from synapse.http.server import JsonResource
from synapse.handlers.presence import (
BasePresenceHandler,
PresenceState,
get_interested_parties,
)
from synapse.http.server import JsonResource, OptionsResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.http.presence import (
ReplicationBumpPresenceActiveTime,
ReplicationPresenceSetState,
)
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@ -202,9 +209,14 @@ class KeyUploadServlet(RestServlet):
# is there.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
headers = {"Authorization": auth_headers}
result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers
)
try:
result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers
)
except HttpResponseException as e:
raise e.to_synapse() from e
except RequestSendFailed as e:
raise SynapseError(502, "Failed to talk to master") from e
return 200, result
else:
@ -243,6 +255,9 @@ class GenericWorkerPresence(BasePresenceHandler):
# but we haven't notified the master of that yet
self.users_going_offline = {}
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
self._set_state_client = ReplicationPresenceSetState.make_client(hs)
self._send_stop_syncing_loop = self.clock.looping_call(
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
@ -300,10 +315,6 @@ class GenericWorkerPresence(BasePresenceHandler):
self.users_going_offline.pop(user_id, None)
self.send_user_sync(user_id, False, last_sync_ms)
def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work?
return defer.succeed(None)
async def user_syncing(
self, user_id: str, affect_presence: bool
) -> ContextManager[None]:
@ -382,6 +393,42 @@ class GenericWorkerPresence(BasePresenceHandler):
if count > 0
]
async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
presence = state["presence"]
valid_presence = (
PresenceState.ONLINE,
PresenceState.UNAVAILABLE,
PresenceState.OFFLINE,
)
if presence not in valid_presence:
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
# If presence is disabled, no-op
if not self.hs.config.use_presence:
return
# Proxy request to master
await self._set_state_client(
user_id=user_id, state=state, ignore_status_msg=ignore_status_msg
)
async def bump_presence_active_time(self, user):
"""We've seen the user do something that indicates they're interacting
with the app.
"""
# If presence is disabled, no-op
if not self.hs.config.use_presence:
return
# Proxy request to master
user_id = user.to_string()
await self._bump_active_client(user_id=user_id)
class GenericWorkerTyping(object):
def __init__(self, hs):
@ -561,7 +608,7 @@ class GenericWorkerServer(HomeServer):
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
root_resource = create_resource_tree(resources, NoResource())
root_resource = create_resource_tree(resources, OptionsResource())
_base.listen_tcp(
bind_addresses,

View file

@ -31,7 +31,7 @@ from prometheus_client import Gauge
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from twisted.web.resource import EncodingResourceWrapper, IResource, NoResource
from twisted.web.resource import EncodingResourceWrapper, IResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@ -52,7 +52,11 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
from synapse.http.server import (
OptionsResource,
RootOptionsRedirectResource,
RootRedirect,
)
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
@ -121,11 +125,11 @@ class SynapseHomeServer(HomeServer):
# try to find something useful to redirect '/' to
if WEB_CLIENT_PREFIX in resources:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
elif STATIC_PREFIX in resources:
root_resource = RootRedirect(STATIC_PREFIX)
root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
else:
root_resource = NoResource()
root_resource = OptionsResource()
root_resource = create_resource_tree(resources, root_resource)

View file

@ -32,23 +32,26 @@ class CaptchaConfig(Config):
def generate_config_section(self, **kwargs):
return """\
## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.
# See docs/CAPTCHA_SETUP.md for full details of configuring this.
# This homeserver's ReCAPTCHA public key.
# This homeserver's ReCAPTCHA public key. Must be specified if
# enable_registration_captcha is enabled.
#
#recaptcha_public_key: "YOUR_PUBLIC_KEY"
# This homeserver's ReCAPTCHA private key.
# This homeserver's ReCAPTCHA private key. Must be specified if
# enable_registration_captcha is enabled.
#
#recaptcha_private_key: "YOUR_PRIVATE_KEY"
# Enables ReCaptcha checks when registering, preventing signup
# Uncomment to enable ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
# public/private key.
# public/private key. Defaults to 'false'.
#
#enable_registration_captcha: false
#enable_registration_captcha: true
# The API endpoint to use for verifying m.login.recaptcha responses.
# Defaults to "https://www.recaptcha.net/recaptcha/api/siteverify".
#
#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
#recaptcha_siteverify_api: "https://my.recaptcha.site"
"""

View file

@ -311,8 +311,8 @@ class EmailConfig(Config):
# Username/password for authentication to the SMTP server. By default, no
# authentication is attempted.
#
# smtp_user: "exampleusername"
# smtp_pass: "examplepassword"
#smtp_user: "exampleusername"
#smtp_pass: "examplepassword"
# Uncomment the following to require TLS transport security for SMTP.
# By default, Synapse will connect over plain text, and will then switch to

View file

@ -175,8 +175,8 @@ class KeyConfig(Config):
)
form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
else:
macaroon_secret_key = "# macaroon_secret_key: <PRIVATE STRING>"
form_secret = "# form_secret: <PRIVATE STRING>"
macaroon_secret_key = "#macaroon_secret_key: <PRIVATE STRING>"
form_secret = "#form_secret: <PRIVATE STRING>"
return (
"""\

View file

@ -257,5 +257,6 @@ def setup_logging(
logging.warning("***** STARTING SERVER *****")
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
logging.info("Instance name: %s", hs.get_instance_name())
return logger

View file

@ -93,10 +93,11 @@ class MetricsConfig(Config):
#known_servers: true
# Whether or not to report anonymized homeserver usage statistics.
#
"""
if report_stats is None:
res += "# report_stats: true|false\n"
res += "#report_stats: true|false\n"
else:
res += "report_stats: %s\n" % ("true" if report_stats else "false")

View file

@ -148,9 +148,7 @@ class RegistrationConfig(Config):
random_string_with_symbols(50),
)
else:
registration_shared_secret = (
"# registration_shared_secret: <PRIVATE STRING>"
)
registration_shared_secret = "#registration_shared_secret: <PRIVATE STRING>"
return (
"""\

View file

@ -434,7 +434,7 @@ class ServerConfig(Config):
)
self.limit_remote_rooms = LimitRemoteRoomsConfig(
**config.get("limit_remote_rooms", {})
**(config.get("limit_remote_rooms") or {})
)
bind_port = config.get("bind_port")
@ -895,22 +895,27 @@ class ServerConfig(Config):
# Used by phonehome stats to group together related servers.
#server_context: context
# Resource-constrained homeserver Settings
# Resource-constrained homeserver settings
#
# If limit_remote_rooms.enabled is True, the room complexity will be
# checked before a user joins a new remote room. If it is above
# limit_remote_rooms.complexity, it will disallow joining or
# instantly leave.
# When this is enabled, the room "complexity" will be checked before a user
# joins a new remote room. If it is above the complexity limit, the server will
# disallow joining, or will instantly leave.
#
# limit_remote_rooms.complexity_error can be set to customise the text
# displayed to the user when a room above the complexity threshold has
# its join cancelled.
# Room complexity is an arbitrary measure based on factors such as the number of
# users in the room.
#
# Uncomment the below lines to enable:
#limit_remote_rooms:
# enabled: true
# complexity: 1.0
# complexity_error: "This room is too complex."
limit_remote_rooms:
# Uncomment to enable room complexity checking.
#
#enabled: true
# the limit above which rooms cannot be joined. The default is 1.0.
#
#complexity: 0.5
# override the error which is returned when the room is too complex.
#
#complexity_error: "This room is too complex."
# Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'.

View file

@ -15,7 +15,7 @@
import attr
from ._base import Config
from ._base import Config, ConfigError
@attr.s
@ -27,6 +27,17 @@ class InstanceLocationConfig:
port = attr.ib(type=int)
@attr.s
class WriterLocations:
"""Specifies the instances that write various streams.
Attributes:
events: The instance that writes to the event and backfill streams.
"""
events = attr.ib(default="master", type=str)
class WorkerConfig(Config):
"""The workers are processes run separately to the main synapse process.
They have their own pid_file and listener configuration. They use the
@ -83,11 +94,26 @@ class WorkerConfig(Config):
bind_addresses.append("")
# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map", {}) or {}
instance_map = config.get("instance_map") or {}
self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
}
# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
# Check that the configured writer for events also appears in
# `instance_map`.
if (
self.writers.events != "master"
and self.writers.events not in self.instance_map
):
raise ConfigError(
"Instance %r is configured to write events but does not appear in `instance_map` config."
% (self.writers.events,)
)
def read_arguments(self, args):
# We support a bunch of command line arguments that override options in
# the config. A lot of these options have a worker_* prefix when running

View file

@ -29,6 +29,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
@ -535,6 +536,15 @@ class DeviceListUpdater(object):
iterable=True,
)
# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
self.clock.looping_call(
run_as_background_process,
30 * 1000,
func=self._maybe_retry_device_resync,
desc="_maybe_retry_device_resync",
)
@trace
@defer.inlineCallbacks
def incoming_device_list_update(self, origin, edu_content):
@ -679,11 +689,50 @@ class DeviceListUpdater(object):
return False
@defer.inlineCallbacks
def user_device_resync(self, user_id):
def _maybe_retry_device_resync(self):
"""Retry to resync device lists that are out of sync, except if another retry is
in progress.
"""
if self._resync_retry_in_progress:
return
try:
# Prevent another call of this function to retry resyncing device lists so
# we don't send too many requests.
self._resync_retry_in_progress = True
# Get all of the users that need resyncing.
need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
# Iterate over the set of user IDs.
for user_id in need_resync:
# Try to resync the current user's devices list. Exception handling
# isn't necessary here, since user_device_resync catches all instances
# of "Exception" that might be raised from the federation request. This
# means that if an exception is raised by this function, it must be
# because of a database issue, which means _maybe_retry_device_resync
# probably won't be able to go much further anyway.
result = yield self.user_device_resync(
user_id=user_id, mark_failed_as_stale=False,
)
# user_device_resync only returns a result if it managed to successfully
# resync and update the database. Updating the table of users requiring
# resync isn't necessary here as user_device_resync already does it
# (through self.store.update_remote_device_list_cache).
if result:
logger.debug(
"Successfully resynced the device list for %s" % user_id,
)
finally:
# Allow future calls to retry resyncinc out of sync device lists.
self._resync_retry_in_progress = False
@defer.inlineCallbacks
def user_device_resync(self, user_id, mark_failed_as_stale=True):
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id (str): The user's id whose device_list will be updated.
mark_failed_as_stale (bool): Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
Deferred[dict]: a dict with device info as under the "devices" in the result of this
request:
@ -694,10 +743,23 @@ class DeviceListUpdater(object):
origin = get_domain_from_id(user_id)
try:
result = yield self.federation.query_user_devices(origin, user_id)
except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
# TODO: Remember that we are now out of sync and try again
# later
logger.warning("Failed to handle device list update for %s", user_id)
except NotRetryingDestination:
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id)
return
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
"Failed to handle device list update for %s: %s", user_id, e,
)
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
@ -711,13 +773,17 @@ class DeviceListUpdater(object):
logger.info(e)
return
except Exception as e:
# TODO: Remember that we are now out of sync and try again
# later
set_tag("error", True)
log_kv(
{"message": "Exception raised by federation request", "exception": e}
)
logger.exception("Failed to handle device list update for %s", user_id)
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id)
return
log_kv({"result": result})
stream_id = result["stream_id"]

View file

@ -40,6 +40,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
FederationError,
HttpResponseException,
RequestSendFailed,
SynapseError,
)
@ -125,10 +126,10 @@ class FederationHandler(BaseHandler):
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
hs
)
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
hs
)
@ -1036,6 +1037,12 @@ class FederationHandler(BaseHandler):
# TODO: We can probably do something more intelligent here.
return True
except SynapseError as e:
logger.info("Failed to backfill from %s because %s", dom, e)
continue
except HttpResponseException as e:
if 400 <= e.code < 500:
raise e.to_synapse_error()
logger.info("Failed to backfill from %s because %s", dom, e)
continue
except CodeMessageException as e:
@ -1214,7 +1221,7 @@ class FederationHandler(BaseHandler):
async def do_invite_join(
self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
) -> None:
) -> Tuple[str, int]:
""" Attempts to join the `joinee` to the room `room_id` via the
servers contained in `target_hosts`.
@ -1235,6 +1242,10 @@ class FederationHandler(BaseHandler):
content: The event content to use for the join event.
"""
# TODO: We should be able to call this on workers, but the upgrading of
# room stuff after join currently doesn't work on workers.
assert self.config.worker.worker_app is None
logger.debug("Joining %s to %s", joinee, room_id)
origin, event, room_version_obj = await self._make_and_verify_event(
@ -1297,15 +1308,23 @@ class FederationHandler(BaseHandler):
room_id=room_id, room_version=room_version_obj,
)
await self._persist_auth_tree(
max_stream_id = await self._persist_auth_tree(
origin, auth_chain, state, event, room_version_obj
)
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.config.worker.writers.events, "events", max_stream_id
)
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = await self.store.get_room_predecessor(room_id)
if not predecessor or not isinstance(predecessor.get("room_id"), str):
return
return event.event_id, max_stream_id
old_room_id = predecessor["room_id"]
logger.debug(
"Found predecessor for %s during remote join: %s", room_id, old_room_id
@ -1318,6 +1337,7 @@ class FederationHandler(BaseHandler):
)
logger.debug("Finished joining %s to %s", joinee, room_id)
return event.event_id, max_stream_id
finally:
room_queue = self.room_queues[room_id]
del self.room_queues[room_id]
@ -1547,7 +1567,7 @@ class FederationHandler(BaseHandler):
async def do_remotely_reject_invite(
self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict
) -> EventBase:
) -> Tuple[EventBase, int]:
origin, event, room_version = await self._make_and_verify_event(
target_hosts, room_id, user_id, "leave", content=content
)
@ -1567,9 +1587,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(target_hosts, event)
context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify([(event, context)])
stream_id = await self.persist_events_and_notify([(event, context)])
return event
return event, stream_id
async def _make_and_verify_event(
self,
@ -1881,7 +1901,7 @@ class FederationHandler(BaseHandler):
state: List[EventBase],
event: EventBase,
room_version: RoomVersion,
) -> None:
) -> int:
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event separately. Notifies about the persisted events
@ -1975,7 +1995,7 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
await self.persist_events_and_notify([(event, new_event_context)])
return await self.persist_events_and_notify([(event, new_event_context)])
async def _prep_event(
self,
@ -2828,7 +2848,7 @@ class FederationHandler(BaseHandler):
self,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> None:
) -> int:
"""Persists events and tells the notifier/pushers about them, if
necessary.
@ -2837,12 +2857,14 @@ class FederationHandler(BaseHandler):
backfilled: Whether these events are a result of
backfilling or not
"""
if self.config.worker_app:
await self._send_events_to_master(
if self.config.worker.writers.events != self._instance_name:
result = await self._send_events(
instance_name=self.config.worker.writers.events,
store=self.store,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)
return result["max_stream_id"]
else:
max_stream_id = await self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
@ -2857,6 +2879,8 @@ class FederationHandler(BaseHandler):
for event, _ in event_and_contexts:
await self._notify_persisted_event(event, max_stream_id)
return max_stream_id
async def _notify_persisted_event(
self, event: EventBase, max_stream_id: int
) -> None:

View file

@ -290,8 +290,7 @@ class IdentityHandler(BaseHandler):
return changed
@defer.inlineCallbacks
def send_threepid_validation(
async def send_threepid_validation(
self,
email_address,
client_secret,
@ -319,7 +318,7 @@ class IdentityHandler(BaseHandler):
"""
# Check that this email/client_secret/send_attempt combo is new or
# greater than what we've seen previously
session = yield self.store.get_threepid_validation_session(
session = await self.store.get_threepid_validation_session(
"email", client_secret, address=email_address, validated=False
)
@ -353,7 +352,7 @@ class IdentityHandler(BaseHandler):
# Send the mail with the link containing the token, client_secret
# and session_id
try:
yield send_email_func(email_address, token, client_secret, session_id)
await send_email_func(email_address, token, client_secret, session_id)
except Exception:
logger.exception(
"Error sending threepid validation email to %s", email_address
@ -364,7 +363,7 @@ class IdentityHandler(BaseHandler):
self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
)
yield self.store.start_or_continue_validation_session(
await self.store.start_or_continue_validation_session(
"email",
email_address,
session_id,

View file

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import Optional, Tuple
from six import iteritems, itervalues, string_types
@ -42,6 +42,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events import EventBase
from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@ -365,10 +366,13 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self._is_event_writer = (
self.config.worker.writers.events == hs.get_instance_name()
)
self.room_invite_state_types = self.hs.config.room_invite_state_types
self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
@ -632,7 +636,9 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
async def send_nonmember_event(self, requester, event, context, ratelimit=True):
async def send_nonmember_event(
self, requester, event, context, ratelimit=True
) -> int:
"""
Persists and notifies local clients and federation of an event.
@ -641,6 +647,9 @@ class EventCreationHandler(object):
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
Return:
The stream_id of the persisted event.
"""
if event.type == EventTypes.Member:
raise SynapseError(
@ -661,7 +670,7 @@ class EventCreationHandler(object):
)
return prev_state
await self.handle_new_client_event(
return await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@ -690,7 +699,7 @@ class EventCreationHandler(object):
async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None
):
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
@ -713,10 +722,10 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
await self.send_nonmember_event(
stream_id = await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
return event
return event, stream_id
@measure_func("create_new_client_event")
@defer.inlineCallbacks
@ -776,7 +785,7 @@ class EventCreationHandler(object):
@measure_func("handle_new_client_event")
async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
) -> int:
"""Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
@ -789,6 +798,9 @@ class EventCreationHandler(object):
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
Return:
The stream_id of the persisted event.
"""
if event.is_state() and (event.type, event.state_key) == (
@ -828,8 +840,9 @@ class EventCreationHandler(object):
success = False
try:
# If we're a worker we need to hit out to the master.
if self._is_worker_app:
await self.send_event_to_master(
if not self._is_event_writer:
result = await self.send_event(
instance_name=self.config.worker.writers.events,
event_id=event.event_id,
store=self.store,
requester=requester,
@ -838,14 +851,17 @@ class EventCreationHandler(object):
ratelimit=ratelimit,
extra_users=extra_users,
)
stream_id = result["stream_id"]
event.internal_metadata.stream_ordering = stream_id
success = True
return
return stream_id
await self.persist_and_notify_client_event(
stream_id = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
success = True
return stream_id
finally:
if not success:
# Ensure that we actually remove the entries in the push actions
@ -888,13 +904,13 @@ class EventCreationHandler(object):
async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
) -> int:
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
This should only be run on master.
This should only be run on the instance in charge of persisting events.
"""
assert not self._is_worker_app
assert self._is_event_writer
if ratelimit:
# We check if this is a room admin redacting an event so that we
@ -1078,6 +1094,8 @@ class EventCreationHandler(object):
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
return event_stream_id
async def _bump_active_time(self, user):
try:
presence = self.hs.get_presence_handler()

View file

@ -193,6 +193,12 @@ class BasePresenceHandler(abc.ABC):
) -> None:
"""Set the presence state of the user. """
@abc.abstractmethod
async def bump_presence_active_time(self, user: UserID):
"""We've seen the user do something that indicates they're interacting
with the app.
"""
class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "synapse.server.HomeServer"):

View file

@ -22,6 +22,7 @@ import logging
import math
import string
from collections import OrderedDict
from typing import Tuple
from six import iteritems, string_types
@ -88,6 +89,8 @@ class RoomCreationHandler(BaseHandler):
self.room_member_handler = hs.get_room_member_handler()
self.config = hs.config
self._replication = hs.get_replication_data_handler()
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
@ -439,73 +442,78 @@ class RoomCreationHandler(BaseHandler):
new_room_id: str,
old_room_state: StateMap[str],
):
directory_handler = self.hs.get_handlers().directory_handler
aliases = await self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
# first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end).
#
# Note that we'll only be able to remove aliases that (a) aren't owned by an AS,
# and (b) unless the user is a server admin, which the user created.
#
# This is probably correct - given we don't allow such aliases to be deleted
# normally, it would be odd to allow it in the case of doing a room upgrade -
# but it makes the upgrade less effective, and you have to wonder why a room
# admin can't remove aliases that point to that room anyway.
# (cf https://github.com/matrix-org/synapse/issues/2360)
#
removed_aliases = []
for alias_str in aliases:
alias = RoomAlias.from_string(alias_str)
try:
await directory_handler.delete_association(requester, alias)
removed_aliases.append(alias_str)
except SynapseError as e:
logger.warning("Unable to remove alias %s from old room: %s", alias, e)
await self.store.update_aliases_for_room(old_room_id, new_room_id)
# if we didn't find any aliases, or couldn't remove anyway, we can skip the rest
# of this.
if not removed_aliases:
if not canonical_alias_event:
return
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
await directory_handler.create_association(
requester,
RoomAlias.from_string(alias),
new_room_id,
servers=(self.hs.hostname,),
check_membership=False,
)
logger.info("Moved alias %s to new room", alias)
except SynapseError as e:
# I'm not really expecting this to happen, but it could if the spam
# checking module decides it shouldn't, or similar.
logger.error("Error adding alias %s to new room: %s", alias, e)
# If there is a canonical alias we need to update the one in the old
# room and set one in the new one.
old_canonical_alias_content = dict(canonical_alias_event.content)
new_canonical_alias_content = {}
canonical = canonical_alias_event.content.get("alias")
if canonical and self.hs.is_mine_id(canonical):
new_canonical_alias_content["alias"] = canonical
old_canonical_alias_content.pop("alias", None)
# We convert to a list as it will be a Tuple.
old_alt_aliases = list(old_canonical_alias_content.get("alt_aliases", []))
if old_alt_aliases:
old_canonical_alias_content["alt_aliases"] = old_alt_aliases
new_alt_aliases = new_canonical_alias_content.setdefault("alt_aliases", [])
for alias in canonical_alias_event.content.get("alt_aliases", []):
try:
if self.hs.is_mine_id(alias):
new_alt_aliases.append(alias)
old_alt_aliases.remove(alias)
except Exception:
logger.info(
"Invalid alias %s in canonical alias event %s",
alias,
canonical_alias_event_id,
)
if not old_alt_aliases:
old_canonical_alias_content.pop("alt_aliases")
# If a canonical alias event existed for the old room, fire a canonical
# alias event for the new room with a copy of the information.
try:
if canonical_alias_event:
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
"content": canonical_alias_event.content,
},
ratelimit=False,
)
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
"state_key": "",
"room_id": old_room_id,
"sender": requester.user.to_string(),
"content": old_canonical_alias_content,
},
ratelimit=False,
)
except SynapseError as e:
# again I'm not really expecting this to fail, but if it does, I'd rather
# we returned the new room to the client at this point.
logger.error("Unable to send updated alias events in old room: %s", e)
try:
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
"content": new_canonical_alias_content,
},
ratelimit=False,
)
except SynapseError as e:
# again I'm not really expecting this to fail, but if it does, I'd rather
# we returned the new room to the client at this point.
@ -513,7 +521,7 @@ class RoomCreationHandler(BaseHandler):
async def create_room(
self, requester, config, ratelimit=True, creator_join_profile=None
):
) -> Tuple[dict, int]:
""" Creates a new room.
Args:
@ -530,9 +538,9 @@ class RoomCreationHandler(BaseHandler):
`avatar_url` and/or `displayname`.
Returns:
Deferred[dict]:
a dict containing the keys `room_id` and, if an alias was
requested, `room_alias`.
First, a dict containing the keys `room_id` and, if an alias
was, requested, `room_alias`. Secondly, the stream_id of the
last persisted event.
Raises:
SynapseError if the room ID couldn't be stored, or something went
horribly wrong.
@ -664,7 +672,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
await self._send_events_for_new_room(
last_stream_id = await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@ -678,7 +686,10 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
await self.event_creation_handler.create_and_send_nonmember_event(
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@ -692,7 +703,10 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
await self.event_creation_handler.create_and_send_nonmember_event(
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@ -710,7 +724,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
await self.room_member_handler.update_membership(
_, last_stream_id = await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@ -724,7 +738,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
await self.hs.get_room_member_handler().do_3pid_invite(
last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@ -740,7 +754,12 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
return result
# Always wait for room creation to progate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", last_stream_id
)
return result, last_stream_id
async def _send_events_for_new_room(
self,
@ -753,7 +772,13 @@ class RoomCreationHandler(BaseHandler):
room_alias=None,
power_level_content_override=None, # Doesn't apply when initial state has power level state event content
creator_join_profile=None,
):
) -> int:
"""Sends the initial events into a new room.
Returns:
The stream_id of the last event persisted.
"""
def create(etype, content, **kwargs):
e = {"type": etype, "content": content}
@ -762,12 +787,16 @@ class RoomCreationHandler(BaseHandler):
return e
async def send(etype, content, **kwargs):
async def send(etype, content, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
await self.event_creation_handler.create_and_send_nonmember_event(
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
)
return last_stream_id
config = RoomCreationHandler.PRESETS_DICT[preset_config]
@ -792,7 +821,9 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
await send(etype=EventTypes.PowerLevels, content=pl_content)
last_sent_stream_id = await send(
etype=EventTypes.PowerLevels, content=pl_content
)
else:
power_level_content = {
"users": {creator_id: 100},
@ -825,33 +856,39 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
await send(etype=EventTypes.PowerLevels, content=power_level_content)
last_sent_stream_id = await send(
etype=EventTypes.PowerLevels, content=power_level_content
)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
await send(
last_sent_stream_id = await send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
if (EventTypes.JoinRules, "") not in initial_state:
await send(
last_sent_stream_id = await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
await send(
last_sent_stream_id = await send(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
await send(
last_sent_stream_id = await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
await send(etype=etype, state_key=state_key, content=content)
last_sent_stream_id = await send(
etype=etype, state_key=state_key, content=content
)
return last_sent_stream_id
async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,

View file

@ -17,7 +17,7 @@
import abc
import logging
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple
from six.moves import http_client
@ -26,6 +26,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.replication.http.membership import (
ReplicationLocallyRejectInviteRestServlet,
)
from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@ -44,11 +47,6 @@ class RoomMemberHandler(object):
__metaclass__ = abc.ABCMeta
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@ -72,6 +70,17 @@ class RoomMemberHandler(object):
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._event_stream_writer_instance = hs.config.worker.writers.events
self._is_on_event_persistence_instance = (
self._event_stream_writer_instance == hs.get_instance_name()
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
else:
self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client(
hs
)
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state.
@ -85,7 +94,7 @@ class RoomMemberHandler(object):
room_id: str,
user: UserID,
content: dict,
) -> Optional[dict]:
) -> Tuple[str, int]:
"""Try and join a room that this server is not in
Args:
@ -105,7 +114,7 @@ class RoomMemberHandler(object):
room_id: str,
target: UserID,
content: dict,
) -> dict:
) -> Tuple[Optional[str], int]:
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
@ -122,6 +131,22 @@ class RoomMemberHandler(object):
"""
raise NotImplementedError()
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
"""Mark the invite has having been rejected even though we failed to
create a leave event for it.
"""
if self._is_on_event_persistence_instance:
return await self.persist_event_storage.locally_reject_invite(
user_id, room_id
)
else:
result = await self._locally_reject_client(
instance_name=self._event_stream_writer_instance,
user_id=user_id,
room_id=room_id,
)
return result["stream_id"]
@abc.abstractmethod
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has joined the
@ -155,7 +180,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> EventBase:
) -> Tuple[str, int]:
user_id = target.to_string()
if content is None:
@ -188,9 +213,10 @@ class RoomMemberHandler(object):
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return duplicate
_, stream_id = await self.store.get_event_ordering(duplicate.event_id)
return duplicate.event_id, stream_id
await self.event_creation_handler.handle_new_client_event(
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
)
@ -214,7 +240,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target, room_id)
return event
return event.event_id, stream_id
async def copy_room_tags_and_direct_to_room(
self, old_room_id, new_room_id, user_id
@ -264,7 +290,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> Union[EventBase, Optional[dict]]:
) -> Tuple[Optional[str], int]:
key = (room_id,)
as_id = object()
@ -314,7 +340,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> Union[EventBase, Optional[dict]]:
) -> Tuple[Optional[str], int]:
content_specified = bool(content)
if content is None:
content = {}
@ -418,7 +444,13 @@ class RoomMemberHandler(object):
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
return old_state
_, stream_id = await self.store.get_event_ordering(
old_state.event_id
)
return (
old_state.event_id,
stream_id,
)
if old_membership in ["ban", "leave"] and action == "kick":
raise AuthError(403, "The target user is not in the room")
@ -725,7 +757,7 @@ class RoomMemberHandler(object):
requester: Requester,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> None:
) -> int:
if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
@ -757,11 +789,11 @@ class RoomMemberHandler(object):
)
if invitee:
await self.update_membership(
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
await self._make_and_store_3pid_invite(
stream_id = await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@ -772,6 +804,8 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
return stream_id
async def _make_and_store_3pid_invite(
self,
requester: Requester,
@ -782,7 +816,7 @@ class RoomMemberHandler(object):
user: UserID,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> None:
) -> int:
room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
@ -837,7 +871,10 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
await self.event_creation_handler.create_and_send_nonmember_event(
(
event,
stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@ -855,6 +892,7 @@ class RoomMemberHandler(object):
ratelimit=False,
txn_id=txn_id,
)
return stream_id
async def _is_host_in_room(
self, current_state_ids: Dict[Tuple[str, str], str]
@ -936,7 +974,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id: str,
user: UserID,
content: dict,
) -> None:
) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@ -965,7 +1003,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
await self.federation_handler.do_invite_join(
event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
await self._user_joined_room(user, room_id)
@ -975,14 +1013,14 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled:
if too_complex is False:
# We checked, and we're under the limit.
return
return event_id, stream_id
# Check again, but with the local state events
too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False:
# We're under the limit.
return
return event_id, stream_id
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
@ -995,6 +1033,8 @@ class RoomMemberMasterHandler(RoomMemberHandler):
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
)
return event_id, stream_id
async def _remote_reject_invite(
self,
requester: Requester,
@ -1002,15 +1042,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id: str,
target: UserID,
content: dict,
) -> dict:
) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite
"""
fed_handler = self.federation_handler
try:
ret = await fed_handler.do_remotely_reject_invite(
event, stream_id = await fed_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, target.to_string(), content=content,
)
return ret
return event.event_id, stream_id
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
@ -1020,8 +1060,8 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warning("Failed to reject invite: %s", e)
await self.store.locally_reject_invite(target.to_string(), room_id)
return {}
stream_id = await self.locally_reject_invite(target.to_string(), room_id)
return None, stream_id
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import List, Optional
from typing import List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
room_id: str,
user: UserID,
content: dict,
) -> Optional[dict]:
) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join
"""
if len(remote_room_hosts) == 0:
@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
await self._user_joined_room(user, room_id)
return ret
return ret["event_id"], ret["stream_id"]
async def _remote_reject_invite(
self,
@ -68,16 +68,17 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
room_id: str,
target: UserID,
content: dict,
) -> dict:
) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite
"""
return await self._remote_reject_client(
ret = await self._remote_reject_client(
requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user_id=target.to_string(),
content=content,
)
return ret["event_id"], ret["stream_id"]
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room

View file

@ -144,6 +144,11 @@ def _handle_json_response(reactor, timeout_sec, request, response):
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = yield make_deferred_yieldable(d)
except TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response", request.txn_id, request.destination,
)
raise RequestSendFailed(e, can_retry=True) from e
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
@ -424,6 +429,8 @@ class MatrixFederationHttpClient(object):
)
response = yield request_deferred
except TimeoutError as e:
raise RequestSendFailed(e, can_retry=True) from e
except DNSLookupError as e:
raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
except Exception as e:

View file

@ -350,9 +350,6 @@ class JsonResource(HttpServer, resource.Resource):
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
"""
if request.method == b"OPTIONS":
return _options_handler, "options_request_handler", {}
request_path = request.path.decode("ascii")
# Loop through all the registered callbacks to check if the method
@ -448,6 +445,26 @@ class RootRedirect(resource.Resource):
return resource.Resource.getChild(self, name, request)
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
def render_OPTIONS(self, request):
code, response_json_object = _options_handler(request)
return respond_with_json(
request, code, response_json_object, send_cors=True, canonical_json=False,
)
def getChildWithDefault(self, path, request):
if request.method == b"OPTIONS":
return self # select ourselves as the child to render
return resource.Resource.getChildWithDefault(self, path, request)
class RootOptionsRedirectResource(OptionsResource, RootRedirect):
pass
def respond_with_json(
request,
code,

View file

@ -14,6 +14,7 @@
import contextlib
import logging
import time
from typing import Optional
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
@ -45,7 +46,7 @@ class SynapseRequest(Request):
request even after the client has disconnected.
Attributes:
logcontext(LoggingContext) : the log context for this request
logcontext: the log context for this request
"""
def __init__(self, channel, *args, **kw):
@ -53,10 +54,10 @@ class SynapseRequest(Request):
self.site = channel.site
self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0
self.start_time = 0.0
# we can't yet create the logcontext, as we don't know the method.
self.logcontext = None
self.logcontext = None # type: Optional[LoggingContext]
global _next_request_seq
self.request_seq = _next_request_seq
@ -182,6 +183,7 @@ class SynapseRequest(Request):
self.finish_time = time.time()
Request.finish(self)
if not self._is_processing:
assert self.logcontext is not None
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
@ -249,6 +251,7 @@ class SynapseRequest(Request):
def _finished_processing(self):
"""Log the completion of this request and update the metrics
"""
assert self.logcontext is not None
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:

View file

@ -26,7 +26,12 @@ import six
import attr
from prometheus_client import Counter, Gauge, Histogram
from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricFamily
from prometheus_client.core import (
REGISTRY,
CounterMetricFamily,
GaugeMetricFamily,
HistogramMetricFamily,
)
from twisted.internet import reactor
@ -338,6 +343,78 @@ class GCCounts(object):
if not running_on_pypy:
REGISTRY.register(GCCounts())
#
# PyPy GC / memory metrics
#
class PyPyGCStats(object):
def collect(self):
# @stats is a pretty-printer object with __str__() returning a nice table,
# plus some fields that contain data from that table.
# unfortunately, fields are pretty-printed themselves (i. e. '4.5MB').
stats = gc.get_stats(memory_pressure=False) # type: ignore
# @s contains same fields as @stats, but as actual integers.
s = stats._s # type: ignore
# also note that field naming is completely braindead
# and only vaguely correlates with the pretty-printed table.
# >>>> gc.get_stats(False)
# Total memory consumed:
# GC used: 8.7MB (peak: 39.0MB) # s.total_gc_memory, s.peak_memory
# in arenas: 3.0MB # s.total_arena_memory
# rawmalloced: 1.7MB # s.total_rawmalloced_memory
# nursery: 4.0MB # s.nursery_size
# raw assembler used: 31.0kB # s.jit_backend_used
# -----------------------------
# Total: 8.8MB # stats.memory_used_sum
#
# Total memory allocated:
# GC allocated: 38.7MB (peak: 41.1MB) # s.total_allocated_memory, s.peak_allocated_memory
# in arenas: 30.9MB # s.peak_arena_memory
# rawmalloced: 4.1MB # s.peak_rawmalloced_memory
# nursery: 4.0MB # s.nursery_size
# raw assembler allocated: 1.0MB # s.jit_backend_allocated
# -----------------------------
# Total: 39.7MB # stats.memory_allocated_sum
#
# Total time spent in GC: 0.073 # s.total_gc_time
pypy_gc_time = CounterMetricFamily(
"pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[],
)
pypy_gc_time.add_metric([], s.total_gc_time / 1000)
yield pypy_gc_time
pypy_mem = GaugeMetricFamily(
"pypy_memory_bytes",
"Memory tracked by PyPy allocator",
labels=["state", "class", "kind"],
)
# memory used by JIT assembler
pypy_mem.add_metric(["used", "", "jit"], s.jit_backend_used)
pypy_mem.add_metric(["allocated", "", "jit"], s.jit_backend_allocated)
# memory used by GCed objects
pypy_mem.add_metric(["used", "", "arenas"], s.total_arena_memory)
pypy_mem.add_metric(["allocated", "", "arenas"], s.peak_arena_memory)
pypy_mem.add_metric(["used", "", "rawmalloced"], s.total_rawmalloced_memory)
pypy_mem.add_metric(["allocated", "", "rawmalloced"], s.peak_rawmalloced_memory)
pypy_mem.add_metric(["used", "", "nursery"], s.nursery_size)
pypy_mem.add_metric(["allocated", "", "nursery"], s.nursery_size)
# totals
pypy_mem.add_metric(["used", "totals", "gc"], s.total_gc_memory)
pypy_mem.add_metric(["allocated", "totals", "gc"], s.total_allocated_memory)
pypy_mem.add_metric(["used", "totals", "gc_peak"], s.peak_memory)
pypy_mem.add_metric(["allocated", "totals", "gc_peak"], s.peak_allocated_memory)
yield pypy_mem
if running_on_pypy:
REGISTRY.register(PyPyGCStats())
#
# Twisted reactor metrics
#

View file

@ -15,7 +15,6 @@
import logging
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
@ -132,8 +131,7 @@ class EmailPusher(object):
self._is_processing = False
self._start_processing()
@defer.inlineCallbacks
def _process(self):
async def _process(self):
# we should never get here if we are already processing
assert not self._is_processing
@ -142,7 +140,7 @@ class EmailPusher(object):
if self.throttle_params is None:
# this is our first loop: load up the throttle params
self.throttle_params = yield self.store.get_throttle_params_by_room(
self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id
)
@ -151,7 +149,7 @@ class EmailPusher(object):
while True:
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
await self._unsafe_process()
except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
@ -159,8 +157,7 @@ class EmailPusher(object):
finally:
self._is_processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
async def _unsafe_process(self):
"""
Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it
@ -168,12 +165,12 @@ class EmailPusher(object):
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
unprocessed = await fn(self.user_id, start, self.max_stream_ordering)
soonest_due_at = None
if not unprocessed:
yield self.save_last_stream_ordering_and_success(self.max_stream_ordering)
await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
return
for push_action in unprocessed:
@ -201,15 +198,15 @@ class EmailPusher(object):
"throttle_ms": self.get_room_throttle_ms(push_action["room_id"]),
}
yield self.send_notification(unprocessed, reason)
await self.send_notification(unprocessed, reason)
yield self.save_last_stream_ordering_and_success(
await self.save_last_stream_ordering_and_success(
max(ea["stream_ordering"] for ea in unprocessed)
)
# we update the throttle on all the possible unprocessed push actions
for ea in unprocessed:
yield self.sent_notif_update_throttle(ea["room_id"], ea)
await self.sent_notif_update_throttle(ea["room_id"], ea)
break
else:
if soonest_due_at is None or should_notify_at < soonest_due_at:
@ -227,14 +224,13 @@ class EmailPusher(object):
self.seconds_until(soonest_due_at), self.on_timer
)
@defer.inlineCallbacks
def save_last_stream_ordering_and_success(self, last_stream_ordering):
async def save_last_stream_ordering_and_success(self, last_stream_ordering):
if last_stream_ordering is None:
# This happens if we haven't yet processed anything
return
self.last_stream_ordering = last_stream_ordering
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.email,
self.user_id,
@ -275,13 +271,12 @@ class EmailPusher(object):
may_send_at = last_sent_ts + throttle_ms
return may_send_at
@defer.inlineCallbacks
def sent_notif_update_throttle(self, room_id, notified_push_action):
async def sent_notif_update_throttle(self, room_id, notified_push_action):
# We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
# notif, we release the throttle. Otherwise, the throttle is increased.
time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before(
time_of_previous_notifs = await self.store.get_time_of_last_push_action_before(
notified_push_action["stream_ordering"]
)
@ -310,14 +305,13 @@ class EmailPusher(object):
"last_sent_ts": self.clock.time_msec(),
"throttle_ms": new_throttle_ms,
}
yield self.store.set_throttle_params(
await self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
)
@defer.inlineCallbacks
def send_notification(self, push_actions, reason):
async def send_notification(self, push_actions, reason):
logger.info("Sending notif email for user %r", self.user_id)
yield self.mailer.send_notification_mail(
await self.mailer.send_notification_mail(
self.app_id, self.user_id, self.email, push_actions, reason
)

View file

@ -26,8 +26,6 @@ from six.moves import urllib
import bleach
import jinja2
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
@ -127,8 +125,7 @@ class Mailer(object):
logger.info("Created Mailer for app_name %s" % app_name)
@defer.inlineCallbacks
def send_password_reset_mail(self, email_address, token, client_secret, sid):
async def send_password_reset_mail(self, email_address, token, client_secret, sid):
"""Send an email with a password reset link to a user
Args:
@ -149,14 +146,13 @@ class Mailer(object):
template_vars = {"link": link}
yield self.send_email(
await self.send_email(
email_address,
"[%s] Password Reset" % self.hs.config.server_name,
template_vars,
)
@defer.inlineCallbacks
def send_registration_mail(self, email_address, token, client_secret, sid):
async def send_registration_mail(self, email_address, token, client_secret, sid):
"""Send an email with a registration confirmation link to a user
Args:
@ -177,14 +173,13 @@ class Mailer(object):
template_vars = {"link": link}
yield self.send_email(
await self.send_email(
email_address,
"[%s] Register your Email Address" % self.hs.config.server_name,
template_vars,
)
@defer.inlineCallbacks
def send_add_threepid_mail(self, email_address, token, client_secret, sid):
async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
"""Send an email with a validation link to a user for adding a 3pid to their account
Args:
@ -206,20 +201,19 @@ class Mailer(object):
template_vars = {"link": link}
yield self.send_email(
await self.send_email(
email_address,
"[%s] Validate Your Email" % self.hs.config.server_name,
template_vars,
)
@defer.inlineCallbacks
def send_notification_mail(
async def send_notification_mail(
self, app_id, user_id, email_address, push_actions, reason
):
"""Send email regarding a user's room notifications"""
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
notif_events = yield self.store.get_events(
notif_events = await self.store.get_events(
[pa["event_id"] for pa in push_actions]
)
@ -232,7 +226,7 @@ class Mailer(object):
state_by_room = {}
try:
user_display_name = yield self.store.get_profile_displayname(
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
@ -240,14 +234,13 @@ class Mailer(object):
except StoreError:
user_display_name = user_id
@defer.inlineCallbacks
def _fetch_room_state(room_id):
room_state = yield self.store.get_current_state_ids(room_id)
async def _fetch_room_state(room_id):
room_state = await self.store.get_current_state_ids(room_id)
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
# notifs are much less realtime than sync so we can afford to wait a bit.
yield concurrently_execute(_fetch_room_state, rooms_in_order, 3)
await concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
@ -255,19 +248,19 @@ class Mailer(object):
rooms = []
for r in rooms_in_order:
roomvars = yield self.get_room_vars(
roomvars = await self.get_room_vars(
r, user_id, notifs_by_room[r], notif_events, state_by_room[r]
)
rooms.append(roomvars)
reason["room_name"] = yield calculate_room_name(
reason["room_name"] = await calculate_room_name(
self.store,
state_by_room[reason["room_id"]],
user_id,
fallback_to_members=True,
)
summary_text = yield self.make_summary_text(
summary_text = await self.make_summary_text(
notifs_by_room, state_by_room, notif_events, user_id, reason
)
@ -282,12 +275,11 @@ class Mailer(object):
"reason": reason,
}
yield self.send_email(
await self.send_email(
email_address, "[%s] %s" % (self.app_name, summary_text), template_vars
)
@defer.inlineCallbacks
def send_email(self, email_address, subject, template_vars):
async def send_email(self, email_address, subject, template_vars):
"""Send an email with the given information and template text"""
try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@ -317,7 +309,7 @@ class Mailer(object):
logger.info("Sending email to %s" % email_address)
yield make_deferred_yieldable(
await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
raw_from,
@ -332,13 +324,14 @@ class Mailer(object):
)
)
@defer.inlineCallbacks
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
async def get_room_vars(
self, room_id, user_id, notifs, notif_events, room_state_ids
):
my_member_event_id = room_state_ids[("m.room.member", user_id)]
my_member_event = yield self.store.get_event(my_member_event_id)
my_member_event = await self.store.get_event(my_member_event_id)
is_invite = my_member_event.content["membership"] == "invite"
room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
room_name = await calculate_room_name(self.store, room_state_ids, user_id)
room_vars = {
"title": room_name,
@ -350,7 +343,7 @@ class Mailer(object):
if not is_invite:
for n in notifs:
notifvars = yield self.get_notif_vars(
notifvars = await self.get_notif_vars(
n, user_id, notif_events[n["event_id"]], room_state_ids
)
@ -377,9 +370,8 @@ class Mailer(object):
return room_vars
@defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
results = yield self.store.get_events_around(
async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
results = await self.store.get_events_around(
notif["room_id"],
notif["event_id"],
before_limit=CONTEXT_BEFORE,
@ -392,25 +384,24 @@ class Mailer(object):
"messages": [],
}
the_events = yield filter_events_for_client(
the_events = await filter_events_for_client(
self.storage, user_id, results["events_before"]
)
the_events.append(notif_event)
for event in the_events:
messagevars = yield self.get_message_vars(notif, event, room_state_ids)
messagevars = await self.get_message_vars(notif, event, room_state_ids)
if messagevars is not None:
ret["messages"].append(messagevars)
return ret
@defer.inlineCallbacks
def get_message_vars(self, notif, event, room_state_ids):
async def get_message_vars(self, notif, event, room_state_ids):
if event.type != EventTypes.Message:
return
sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
sender_state_event = yield self.store.get_event(sender_state_event_id)
sender_state_event = await self.store.get_event(sender_state_event_id)
sender_name = name_from_member_event(sender_state_event)
sender_avatar_url = sender_state_event.content.get("avatar_url")
@ -460,8 +451,7 @@ class Mailer(object):
return messagevars
@defer.inlineCallbacks
def make_summary_text(
async def make_summary_text(
self, notifs_by_room, room_state_ids, notif_events, user_id, reason
):
if len(notifs_by_room) == 1:
@ -471,17 +461,17 @@ class Mailer(object):
# If the room has some kind of name, use it, but we don't
# want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room"
room_name = yield calculate_room_name(
room_name = await calculate_room_name(
self.store, room_state_ids[room_id], user_id, fallback_to_members=False
)
my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)]
my_member_event = yield self.store.get_event(my_member_event_id)
my_member_event = await self.store.get_event(my_member_event_id)
if my_member_event.content["membership"] == "invite":
inviter_member_event_id = room_state_ids[room_id][
("m.room.member", my_member_event.sender)
]
inviter_member_event = yield self.store.get_event(
inviter_member_event = await self.store.get_event(
inviter_member_event_id
)
inviter_name = name_from_member_event(inviter_member_event)
@ -506,7 +496,7 @@ class Mailer(object):
state_event_id = room_state_ids[room_id][
("m.room.member", event.sender)
]
state_event = yield self.store.get_event(state_event_id)
state_event = await self.store.get_event(state_event_id)
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
@ -535,7 +525,7 @@ class Mailer(object):
}
)
member_events = yield self.store.get_events(
member_events = await self.store.get_events(
[
room_state_ids[room_id][("m.room.member", s)]
for s in sender_ids
@ -567,7 +557,7 @@ class Mailer(object):
}
)
member_events = yield self.store.get_events(
member_events = await self.store.get_events(
[room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
)

View file

@ -19,6 +19,7 @@ from synapse.replication.http import (
federation,
login,
membership,
presence,
register,
send_event,
streams,
@ -35,10 +36,11 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs):
send_event.register_servlets(hs, self)
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
membership.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)

View file

@ -142,6 +142,7 @@ class ReplicationEndpoint(object):
"""
clock = hs.get_clock()
client = hs.get_simple_http_client()
local_instance_name = hs.get_instance_name()
master_host = hs.config.worker_replication_host
master_port = hs.config.worker_replication_http_port
@ -151,6 +152,8 @@ class ReplicationEndpoint(object):
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(instance_name="master", **kwargs):
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":
host = master_host
port = master_port

View file

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""Handles events newly received from federation, including persisting and
notifying.
notifying. Returns the maximum stream ID of the persisted events.
The API looks like:
@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"context": { .. serialized event context .. },
}],
"backfilled": false
}
200 OK
{
"max_stream_id": 32443,
}
"""
NAME = "fed_send_events"
@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
await self.federation_handler.persist_events_and_notify(
max_stream_id = await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
)
return 200, {}
return 200, {"max_stream_id": max_stream_id}
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):

View file

@ -14,12 +14,16 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -76,11 +80,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
logger.info("remote_join: %s into room: %s", user_id, room_id)
await self.federation_handler.do_invite_join(
event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user_id, event_content
)
return 200, {}
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@ -106,6 +110,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.member_handler = hs.get_room_member_handler()
@staticmethod
def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
@ -136,10 +141,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
event = await self.federation_handler.do_remotely_reject_invite(
event, stream_id = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id, event_content,
)
ret = event.get_pdu_json()
event_id = event.event_id
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
@ -149,10 +154,42 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
#
logger.warning("Failed to reject invite: %s", e)
await self.store.locally_reject_invite(user_id, room_id)
ret = {}
stream_id = await self.member_handler.locally_reject_invite(
user_id, room_id
)
event_id = None
return 200, ret
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects the invite for the user and room locally.
Request format:
POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
{}
"""
NAME = "locally_reject_invite"
PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.member_handler = hs.get_room_member_handler()
@staticmethod
def _serialize_payload(room_id, user_id):
return {}
async def _handle_request(self, request, room_id, user_id):
logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
return 200, {"stream_id": stream_id}
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
@ -208,3 +245,4 @@ def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)

View file

@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
"""We've seen the user do something that indicates they're interacting
with the app.
The POST looks like:
POST /_synapse/replication/bump_presence_active_time/<user_id>
200 OK
{}
"""
NAME = "bump_presence_active_time"
PATH_ARGS = ("user_id",)
METHOD = "POST"
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._presence_handler = hs.get_presence_handler()
@staticmethod
def _serialize_payload(user_id):
return {}
async def _handle_request(self, request, user_id):
await self._presence_handler.bump_presence_active_time(
UserID.from_string(user_id)
)
return (
200,
{},
)
class ReplicationPresenceSetState(ReplicationEndpoint):
"""Set the presence state for a user.
The POST looks like:
POST /_synapse/replication/presence_set_state/<user_id>
{
"state": { ... },
"ignore_status_msg": false,
}
200 OK
{}
"""
NAME = "presence_set_state"
PATH_ARGS = ("user_id",)
METHOD = "POST"
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._presence_handler = hs.get_presence_handler()
@staticmethod
def _serialize_payload(user_id, state, ignore_status_msg=False):
return {
"state": state,
"ignore_status_msg": ignore_status_msg,
}
async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
await self._presence_handler.set_state(
UserID.from_string(user_id), content["state"], content["ignore_status_msg"]
)
return (
200,
{},
)
def register_servlets(hs, http_server):
ReplicationBumpPresenceActiveTime(hs).register(http_server)
ReplicationPresenceSetState(hs).register(http_server)

View file

@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
await self.event_creation_handler.persist_and_notify_client_event(
stream_id = await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
return 200, {}
return 200, {"stream_id": stream_id}
def register_servlets(hs, http_server):

View file

@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
super().__init__(hs)
self._instance_name = hs.get_instance_name()
# We pull the streams from the replication handler (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_tcp_replication().get_streams()
self.streams = hs.get_replication_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token):

View file

@ -14,19 +14,23 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import heapq
import logging
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple
from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.api.constants import EventTypes
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -35,6 +39,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# How long we allow callers to wait for replication updates before timing out.
_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.
@ -92,6 +100,16 @@ class ReplicationDataHandler:
self.store = hs.get_datastore()
self.pusher_pool = hs.get_pusherpool()
self.notifier = hs.get_notifier()
self._reactor = hs.get_reactor()
self._clock = hs.get_clock()
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
self._streams_to_waiters = (
{}
) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@ -131,8 +149,76 @@ class ReplicationDataHandler:
await self.pusher_pool.on_new_notifications(token, token)
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
waiting_list = self._streams_to_waiters.get(stream_name, [])
# Index of first item with a position after the current token, i.e we
# have called all deferreds before this index. If not overwritten by
# loop below means either a) no items in list so no-op or b) all items
# in list were called and so the list should be cleared. Setting it to
# `len(list)` works for both cases.
index_of_first_deferred_not_called = len(waiting_list)
for idx, (position, deferred) in enumerate(waiting_list):
if position <= token:
try:
with PreserveLoggingContext():
deferred.callback(None)
except Exception:
# The deferred has been cancelled or timed out.
pass
else:
# The list is sorted by position so we don't need to continue
# checking any futher entries in the list.
index_of_first_deferred_not_called = idx
break
# Drop all entries in the waiting list that were called in the above
# loop. (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
async def wait_for_stream_position(
self, instance_name: str, stream_name: str, position: int
):
"""Wait until this instance has received updates up to and including
the given stream position.
"""
if instance_name == self._instance_name:
# We don't get told about updates written by this process, and
# anyway in that case we don't need to wait.
return
current_position = self._streams[stream_name].current_token(self._instance_name)
if position <= current_position:
# We're already past the position
return
# Create a new deferred that times out after N seconds, as we don't want
# to wedge here forever.
deferred = Deferred()
deferred = timeout_deferred(
deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
)
waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
# We insert into the list using heapq as it is more efficient than
# pushing then resorting each time.
heapq.heappush(waiting_list, (position, deferred))
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
await make_deferred_yieldable(deferred)
logger.info(
"Finished waiting for repl stream %r to reach %s", stream_name, position
)

View file

@ -38,7 +38,9 @@ from synapse.replication.tcp.commands import (
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
BackfillStream,
CachesStream,
EventsStream,
FederationStream,
Stream,
)
@ -87,6 +89,14 @@ class ReplicationCommandHandler:
self._streams_to_replicate.append(stream)
continue
if isinstance(stream, (EventsStream, BackfillStream)):
# Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence.
if hs.config.worker.writers.events == hs.get_instance_name():
self._streams_to_replicate.append(stream)
continue
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue

View file

@ -59,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet):
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
self._replication = hs.get_replication_data_handler()
async def on_POST(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
@ -73,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
info = await self._room_creation_handler.create_room(
info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
@ -94,6 +95,15 @@ class ShutdownRoomRestServlet(RestServlet):
# desirable in case the first attempt at blocking the room failed below.
await self.store.block_room(room_id, requester_user_id)
# We now wait for the create room to come back in via replication so
# that we can assume that all the joins/invites have propogated before
# we try and auto join below.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
)
users = await self.state.get_current_users_in_room(room_id)
kicked_users = []
failed_to_kick_users = []
@ -105,7 +115,7 @@ class ShutdownRoomRestServlet(RestServlet):
try:
target_requester = create_requester(user_id)
await self.room_member_handler.update_membership(
_, stream_id = await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=room_id,
@ -115,6 +125,11 @@ class ShutdownRoomRestServlet(RestServlet):
require_consent=False,
)
# Wait for leave to come in over replication before trying to forget.
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
)
await self.room_member_handler.forget(target_requester.user, room_id)
await self.room_member_handler.update_membership(

View file

@ -93,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request)
info = await self._room_creation_handler.create_room(
info, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@ -202,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
event = await self.room_member_handler.update_membership(
event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@ -210,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content,
)
else:
event = await self.event_creation_handler.create_and_send_nonmember_event(
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
event_id = event.event_id
ret = {} # type: dict
if event:
set_tag("event_id", event.event_id)
ret = {"event_id": event.event_id}
if event_id:
set_tag("event_id", event_id)
ret = {"event_id": event_id}
return 200, ret
@ -247,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
event = await self.event_creation_handler.create_and_send_nonmember_event(
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
@ -781,7 +785,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
event = await self.event_creation_handler.create_and_send_nonmember_event(
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,

View file

@ -177,7 +177,10 @@ class AuthRestServlet(RestServlet):
)
elif self._saml_enabled:
client_redirect_url = b""
# Some SAML identity providers (e.g. Google) require a
# RelayState parameter on requests. It is not necessary here, so
# pass in a dummy redirect URL (which will never get used).
client_redirect_url = b"unused"
sso_redirect_url = self._saml_handler.handle_redirect_request(
client_redirect_url, session
)

View file

@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
event = await self.event_creation_handler.create_and_send_nonmember_event(
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)

View file

@ -90,6 +90,7 @@ from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
@ -210,6 +211,7 @@ class HomeServer(object):
"storage",
"replication_streamer",
"replication_data_handler",
"replication_streams",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@ -583,6 +585,9 @@ class HomeServer(object):
def build_replication_data_handler(self):
return ReplicationDataHandler(self)
def build_replication_streams(self):
return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View file

@ -1,3 +1,5 @@
from typing import Dict
import twisted.internet
import synapse.api.auth
@ -28,6 +30,7 @@ import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
from synapse.replication.tcp.streams import Stream
class HomeServer(object):
@property
@ -136,3 +139,5 @@ class HomeServer(object):
pass
def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool:
pass
def get_replication_streams(self) -> Dict[str, Stream]:
pass

View file

@ -83,10 +83,10 @@ class ServerNoticesManager(object):
if state_key is not None:
event_dict["state_key"] = state_key
res = await self._event_creation_handler.create_and_send_nonmember_event(
event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
)
return res
return event
@cached()
async def get_or_create_notice_room_for_user(self, user_id):
@ -143,7 +143,7 @@ class ServerNoticesManager(object):
}
requester = create_requester(self.server_notices_mxid)
info = await self._room_creation_handler.create_room(
info, _ = await self._room_creation_handler.create_room(
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,

View file

@ -66,9 +66,9 @@ class DataStores(object):
self.main = main_store_class(database, db_conn, hs)
# If we're on a process that can persist events (currently
# master), also instantiate a `PersistEventsStore`
if hs.config.worker.worker_app is None:
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
if hs.config.worker.writers.events == hs.get_instance_name():
self.persist_events = PersistEventsStore(
hs, database, self.main
)

View file

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Tuple
from typing import List, Optional, Set, Tuple
from six import iteritems
@ -649,21 +649,31 @@ class DeviceWorkerStore(SQLBaseStore):
return results
@defer.inlineCallbacks
def get_user_ids_requiring_device_list_resync(self, user_ids: Collection[str]):
def get_user_ids_requiring_device_list_resync(
self, user_ids: Optional[Collection[str]] = None,
) -> Set[str]:
"""Given a list of remote users return the list of users that we
should resync the device lists for.
should resync the device lists for. If None is given instead of a list,
return every user that we should resync the device lists for.
Returns:
Deferred[Set[str]]
The IDs of users whose device lists need resync.
"""
rows = yield self.db.simple_select_many_batch(
table="device_lists_remote_resync",
column="user_id",
iterable=user_ids,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync",
)
if user_ids:
rows = yield self.db.simple_select_many_batch(
table="device_lists_remote_resync",
column="user_id",
iterable=user_ids,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync_with_iterable",
)
else:
rows = yield self.db.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync",
)
return {row["user_id"] for row in rows}
@ -679,6 +689,25 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
self.db.simple_delete_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
)
self._invalidate_cache_and_stream(
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
return self.db.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
@ -959,17 +988,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
desc="update_device",
)
@defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
yield self.db.simple_delete(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
)
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
def update_remote_device_list_cache_entry(
self, user_id, device_id, content, stream_id
):

View file

@ -138,10 +138,10 @@ class PersistEventsStore:
self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator
# This should only exist on master for now
# This should only exist on instances that are configured to write
assert (
hs.config.worker.worker_app is None
), "Can only instantiate PersistEventsStore on master"
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
@_retry_on_integrity_error
@defer.inlineCallbacks
@ -1590,3 +1590,31 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier()
],
)
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
"""Mark the invite has having been rejected even though we failed to
create a leave event for it.
"""
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
# We also clear this entry from `local_current_membership`.
# Ideally we'd point to a leave event, but we don't have one, so
# nevermind.
self.db.simple_delete_txn(
txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
)
with self._stream_id_gen.get_next() as stream_ordering:
await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
return stream_ordering

View file

@ -76,7 +76,7 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker_app is None:
if hs.config.worker.writers.events == hs.get_instance_name():
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
@ -1289,12 +1289,12 @@ class EventsWorkerStore(SQLBaseStore):
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
to_1, so_1 = await self._get_event_ordering(event_id1)
to_2, so_2 = await self._get_event_ordering(event_id2)
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
def get_event_ordering(self, event_id):
res = yield self.db.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],

View file

@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.database import Database, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@ -146,6 +146,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
threepids (list[dict]): List of threepid dicts to reserve
"""
# XXX what is this function trying to achieve? It upserts into
# monthly_active_users for each *registered* reserved mau user, but why?
#
# - shouldn't there already be an entry for each reserved user (at least
# if they have been active recently)?
#
# - if it's important that the timestamp is kept up to date, why do we only
# run this at startup?
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
@ -178,75 +187,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
query_args = [thirty_days_ago]
base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
if len(reserved_users) > 0:
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
question_marks = ",".join("?" * len(reserved_users))
in_clause, in_clause_args = make_in_list_sql_clause(
self.database_engine, "user_id", reserved_users
)
query_args.extend(reserved_users)
sql = base_sql + " AND user_id NOT IN ({})".format(question_marks)
else:
sql = base_sql
txn.execute(sql, query_args)
txn.execute(
"DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s"
% (in_clause,),
[thirty_days_ago] + in_clause_args,
)
if self._limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
# Note it is not possible to write this query using OFFSET due to
# incompatibilities in how sqlite and postgres support the feature.
# sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
# While Postgres does not require 'LIMIT', but also does not support
# Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be presents,
# while Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
if len(reserved_users) == 0:
sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
ORDER BY timestamp DESC
LIMIT ?
)
"""
txn.execute(sql, ((self._max_mau_value),))
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
else:
# Must be >= 0 for postgres
num_of_non_reserved_users_to_remove = max(
self._max_mau_value - len(reserved_users), 0
# Limit must be >= 0 for postgres
num_of_non_reserved_users_to_remove = max(
self._max_mau_value - len(reserved_users), 0
)
# It is important to filter reserved users twice to guard
# against the case where the reserved user is present in the
# SELECT, meaning that a legitimate mau is deleted.
sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
WHERE NOT %s
ORDER BY timestamp DESC
LIMIT ?
)
AND NOT %s
""" % (
in_clause,
in_clause,
)
# It is important to filter reserved users twice to guard
# against the case where the reserved user is present in the
# SELECT, meaning that a legitmate mau is deleted.
sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
WHERE user_id NOT IN ({})
ORDER BY timestamp DESC
LIMIT ?
)
AND user_id NOT IN ({})
""".format(
question_marks, question_marks
)
query_args = (
in_clause_args
+ [num_of_non_reserved_users_to_remove]
+ in_clause_args
)
txn.execute(sql, query_args)
query_args = [
*reserved_users,
num_of_non_reserved_users_to_remove,
*reserved_users,
]
txn.execute(sql, query_args)
# It seems poor to invalidate the whole cache, Postgres supports
# It seems poor to invalidate the whole cache. Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
# I would need to SELECT and the DELETE which without locking

View file

@ -1046,29 +1046,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
# We also clear this entry from `local_current_membership`.
# Ideally we'd point to a leave event, but we don't have one, so
# nevermind.
self.db.simple_delete_txn(
txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
)
with self._stream_id_gen.get_next() as stream_ordering:
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""

View file

@ -786,3 +786,9 @@ class EventsPersistenceStorage(object):
for user_id in left_users:
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
"""Mark the invite has having been rejected even though we failed to
create a leave event for it.
"""
return await self.persist_events_store.locally_reject_invite(user_id, room_id)

View file

@ -186,10 +186,15 @@ def _check_yield_points(f: Callable, changes: List[str]):
)
raise Exception(err)
# the wrapped function yielded a Deferred: yield it back up to the parent
# inlineCallbacks().
try:
result = yield d
except Exception as e:
result = Failure(e)
except Exception:
# this will fish an earlier Failure out of the stack where possible, and
# thus is preferable to passing in an exeception to the Failure
# constructor, since it results in less stack-mangling.
result = Failure()
if current_context() != expected_context:

View file

@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
)
d = handler._remote_join(
None,
@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
)
# Artificially raise the complexity
self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(

View file

@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
reactor.pump((1000,))
hs = self.setup_test_homeserver(
notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
notifier=Mock(),
http_client=mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
hs.datastores = datastores

View file

@ -46,7 +46,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
self.email_attempts.append(msg)
return
@ -358,7 +358,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
self.email_attempts.append(msg)
config["email"] = {

View file

@ -27,20 +27,33 @@ from synapse.server_notices.resource_limits_server_notices import (
)
from tests import unittest
from tests.unittest import override_config
from tests.utils import default_config
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs_config = self.default_config()
hs_config["server_notices"] = {
"system_mxid_localpart": "server",
"system_mxid_display_name": "test display name",
"system_mxid_avatar_url": None,
"room_name": "Server Notices",
}
def default_config(self):
config = default_config("test")
hs = self.setup_test_homeserver(config=hs_config)
return hs
config.update(
{
"admin_contact": "mailto:user@test.com",
"limit_usage_by_mau": True,
"server_notices": {
"system_mxid_localpart": "server",
"system_mxid_display_name": "test display name",
"system_mxid_avatar_url": None,
"room_name": "Server Notices",
},
}
)
# apply any additional config which was specified via the override_config
# decorator.
if self._extra_config is not None:
config.update(self._extra_config)
return config
def prepare(self, reactor, clock, hs):
self.server_notices_sender = self.hs.get_server_notices_sender()
@ -60,7 +73,6 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
@ -68,21 +80,17 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
self.hs.config.admin_contact = "mailto:user@test.com"
def test_maybe_send_server_notice_to_user_flag_off(self):
"""Tests cases where the flags indicate nothing to do"""
# test hs disabled case
self.hs.config.hs_disabled = True
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
"""If the HS is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
# Test when mau limiting disabled
self.hs.config.hs_disabled = False
self.hs.config.limit_usage_by_mau = False
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@override_config({"limit_usage_by_mau": False})
def test_maybe_send_server_notice_to_user_flag_off(self):
"""If mau limiting is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
@ -153,13 +161,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
@override_config({"mau_limit_alerting": False})
def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
"""
Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
@ -170,12 +177,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 0)
@override_config({"mau_limit_alerting": False})
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
@ -187,12 +193,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
# Would be better to check contents, but 2 calls == set blocking event
self.assertEqual(self._send_notice.call_count, 2)
@override_config({"mau_limit_alerting": False})
def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
"""
When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state.
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(

View file

@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
self.requester = Requester(self.user, None, False, None, None)
info = self.get_success(self.room_creator.create_room(self.requester, {}))
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
def run_background_update(self):
@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
self.requester = Requester(self.user, None, False, None, None)
info = self.get_success(self.room_creator.create_room(self.requester, {}))
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.user_consent_version = self.CONSENT_VERSION

View file

@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
from tests.unittest import override_config
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
@ -137,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
def test_disabled_monthly_active_user(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
user_id = "@user:server"
self.get_success(
self.store.insert_client_ip(
@ -149,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_full(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
lots_of_users = 100
user_id = "@user:server"
@ -166,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@ -184,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_updating_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))

View file

@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
events = [(3, 2), (6, 2), (4, 6)]
for event_count, extrems in events:
info = self.get_success(room_creator.create_room(requester, {}))
info, _ = self.get_success(room_creator.create_room(requester, {}))
room_id = info["room_id"]
last_event = None

View file

@ -19,94 +19,106 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from tests import unittest
from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
def gen_3pids(count):
"""Generate `count` threepids as a list."""
return [
{"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count)
]
class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def default_config(self):
config = default_config("test")
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
hs.config.limit_usage_by_mau = True
hs.config.max_mau_value = 50
config.update({"limit_usage_by_mau": True, "max_mau_value": 50})
# apply any additional config which was specified via the override_config
# decorator.
if self._extra_config is not None:
config.update(self._extra_config)
return config
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
# Advance the clock a bit
reactor.advance(FORTY_DAYS)
return hs
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
def test_initialise_reserved_users(self):
self.hs.config.max_mau_value = 5
threepids = self.hs.config.mau_limits_reserved_threepids
# register three users, of which two have reserved 3pids, and a third
# which is a support user.
user1 = "@user1:server"
user1_email = "user1@matrix.org"
user1_email = threepids[0]["address"]
user2 = "@user2:server"
user2_email = "user2@matrix.org"
user2_email = threepids[1]["address"]
user3 = "@user3:server"
user3_email = "user3@matrix.org"
threepids = [
{"medium": "email", "address": user1_email},
{"medium": "email", "address": user2_email},
{"medium": "email", "address": user3_email},
]
self.hs.config.mau_limits_reserved_threepids = threepids
# -1 because user3 is a support user and does not count
user_num = len(threepids) - 1
self.store.register_user(user_id=user1, password_hash=None)
self.store.register_user(user_id=user2, password_hash=None)
self.store.register_user(
user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT
)
self.store.register_user(user_id=user1)
self.store.register_user(user_id=user2)
self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT)
self.pump()
now = int(self.hs.get_clock().time_msec())
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
# XXX why are we doing this here? this function is only run at startup
# so it is odd to re-run it here.
self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.pump()
active_count = self.store.get_monthly_active_count()
# the number of users we expect will be counted against the mau limit
# -1 because user3 is a support user and does not count
user_num = len(threepids) - 1
# Test total counts, ensure user3 (support user) is not counted
self.assertEquals(self.get_success(active_count), user_num)
# Check the number of active users. Ensure user3 (support user) is not counted
active_count = self.get_success(self.store.get_monthly_active_count())
self.assertEquals(active_count, user_num)
# Test user is marked as active
# Test each of the registered users is marked as active
timestamp = self.store.user_last_seen_monthly_active(user1)
self.assertTrue(self.get_success(timestamp))
timestamp = self.store.user_last_seen_monthly_active(user2)
self.assertTrue(self.get_success(timestamp))
# Test that users are never removed from the db.
# Test that users with reserved 3pids are not removed from the MAU table
# XXX some of this is redundant. poking things into the config shouldn't
# work, and in any case it's not obvious what we expect to happen when
# we advance the reactor.
self.hs.config.max_mau_value = 0
self.reactor.advance(FORTY_DAYS)
self.hs.config.max_mau_value = 5
self.store.reap_monthly_active_users()
self.pump()
active_count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(active_count), user_num)
# Test that regular users are removed from the db
# Add some more users and check they are counted as active
ru_count = 2
self.store.upsert_monthly_active_user("@ru1:server")
self.store.upsert_monthly_active_user("@ru2:server")
self.pump()
active_count = self.store.get_monthly_active_count()
self.assertEqual(self.get_success(active_count), user_num + ru_count)
self.hs.config.max_mau_value = user_num
# now run the reaper and check that the number of active users is reduced
# to max_mau_value
self.store.reap_monthly_active_users()
self.pump()
active_count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(active_count), user_num)
self.assertEquals(self.get_success(active_count), 3)
def test_can_insert_and_count_mau(self):
count = self.store.get_monthly_active_count()
@ -136,8 +148,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
result = self.store.user_last_seen_monthly_active(user_id3)
self.assertNotEqual(self.get_success(result), 0)
@override_config({"max_mau_value": 5})
def test_reap_monthly_active_users(self):
self.hs.config.max_mau_value = 5
initial_users = 10
for i in range(initial_users):
self.store.upsert_monthly_active_user("@user%d:server" % i)
@ -158,19 +170,19 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(count), 0)
# Note that below says mau_limit (no s), this is the name of the config
# value, although it gets stored on the config object as mau_limits.
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
def test_reap_monthly_active_users_reserved_users(self):
""" Tests that reaping correctly handles reaping where reserved users are
present"""
self.hs.config.max_mau_value = 5
initial_users = 5
threepids = self.hs.config.mau_limits_reserved_threepids
initial_users = len(threepids)
reserved_user_number = initial_users - 1
threepids = []
for i in range(initial_users):
user = "@user%d:server" % i
email = "user%d@example.com" % i
email = "user%d@matrix.org" % i
self.get_success(self.store.upsert_monthly_active_user(user))
threepids.append({"medium": "email", "address": email})
# Need to ensure that the most recent entries in the
# monthly_active_users table are reserved
now = int(self.hs.get_clock().time_msec())
@ -182,7 +194,6 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user, "email", email, now, now)
)
self.hs.config.mau_limits_reserved_threepids = threepids
self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
@ -279,11 +290,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(self.get_success(count), 0)
# Note that the max_mau_value setting should not matter.
@override_config(
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
)
def test_track_monthly_users_without_cap(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.mau_stats_only = True
self.hs.config.max_mau_value = 1 # should not matter
count = self.store.get_monthly_active_count()
self.assertEqual(0, self.get_success(count))
@ -294,9 +305,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.store.get_monthly_active_count()
self.assertEqual(2, self.get_success(count))
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.mau_stats_only = False
self.store.upsert_monthly_active_user = Mock()
self.store.populate_monthly_active_users("@user:sever")

View file

@ -6,12 +6,13 @@ from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
class MessageAcceptTests(unittest.TestCase):
class MessageAcceptTests(unittest.HomeserverTestCase):
def setUp(self):
self.http_client = Mock()
@ -27,13 +28,13 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
room = ensureDeferred(
room_deferred = ensureDeferred(
room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
)
)
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"]
self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
self.store = self.homeserver.get_datastore()
@ -145,3 +146,63 @@ class MessageAcceptTests(unittest.TestCase):
# Make sure the invalid event isn't there
extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
def test_retry_device_list_resync(self):
"""Tests that device lists are marked as stale if they couldn't be synced, and
that stale device lists are retried periodically.
"""
remote_user_id = "@john:test_remote"
remote_origin = "test_remote"
# Track the number of attempts to resync the user's device list.
self.resync_attempts = 0
# When this function is called, increment the number of resync attempts (only if
# we're querying devices for the right user ID), then raise a
# NotRetryingDestination error to fail the resync gracefully.
def query_user_devices(destination, user_id):
if user_id == remote_user_id:
self.resync_attempts += 1
raise NotRetryingDestination(0, 0, destination)
# Register the mock on the federation client.
federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock(side_effect=query_user_devices)
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
store = self.homeserver.get_datastore()
store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
device_list_updater = self.homeserver.get_device_handler().device_list_updater
self.get_success(
device_list_updater.incoming_device_list_update(
origin=remote_origin,
edu_content={
"deleted": False,
"device_display_name": "Mobile",
"device_id": "QBUAZIFURK",
"prev_id": [5],
"stream_id": 6,
"user_id": remote_user_id,
},
)
)
# Check that there was one resync attempt.
self.assertEqual(self.resync_attempts, 1)
# Check that the resync attempt failed and caused the user's device list to be
# marked as stale.
need_resync = self.get_success(
store.get_user_ids_requiring_device_list_resync()
)
self.assertIn(remote_user_id, need_resync)
# Check that waiting for 30 seconds caused Synapse to retry resyncing the device
# list.
self.reactor.advance(30)
self.assertEqual(self.resync_attempts, 2)

View file

@ -17,47 +17,44 @@
import json
from mock import Mock
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
from tests.unittest import override_config
from tests.utils import default_config
class TestMauLimit(unittest.HomeserverTestCase):
servlets = [register.register_servlets, sync.register_servlets]
def make_homeserver(self, reactor, clock):
def default_config(self):
config = default_config("test")
self.hs = self.setup_test_homeserver(
"red", http_client=None, federation_client=Mock()
config.update(
{
"registrations_require_3pid": [],
"limit_usage_by_mau": True,
"max_mau_value": 2,
"mau_trial_days": 0,
"server_notices": {
"system_mxid_localpart": "server",
"room_name": "Test Server Notice Room",
},
}
)
self.store = self.hs.get_datastore()
# apply any additional config which was specified via the override_config
# decorator.
if self._extra_config is not None:
config.update(self._extra_config)
self.hs.config.registrations_require_3pid = []
self.hs.config.enable_registration_captcha = False
self.hs.config.recaptcha_public_key = []
return config
self.hs.config.limit_usage_by_mau = True
self.hs.config.hs_disabled = False
self.hs.config.max_mau_value = 2
self.hs.config.server_notices_mxid = "@server:red"
self.hs.config.server_notices_mxid_display_name = None
self.hs.config.server_notices_mxid_avatar_url = None
self.hs.config.server_notices_room_name = "Test Server Notice Room"
self.hs.config.mau_trial_days = 0
# AuthBlocking reads config options during hs creation. Recreate the
# hs' copy of AuthBlocking after we've updated config values above
self.auth_blocking = AuthBlocking(self.hs)
self.hs.get_auth()._auth_blocking = self.auth_blocking
return self.hs
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
def test_simple_deny_mau(self):
# Create and sync so that the MAU counts get updated
@ -66,6 +63,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
# check we're testing what we think we are: there should be two active users
self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
# We've created and activated two users, we shouldn't be able to
# register new users
with self.assertRaises(SynapseError) as cm:
@ -93,9 +93,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
token3 = self.create_user("kermit3")
self.do_sync_for_user(token3)
@override_config({"mau_trial_days": 1})
def test_trial_delay(self):
self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@ -127,8 +126,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@override_config({"mau_trial_days": 1})
def test_trial_users_cant_come_back(self):
self.auth_blocking._mau_trial_days = 1
self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially
@ -176,11 +175,11 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@override_config(
# max_mau_value should not matter
{"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
)
def test_tracked_but_not_limited(self):
self.auth_blocking._max_mau_value = 1 # should not matter
self.auth_blocking._limit_usage_by_mau = False
self.hs.config.mau_stats_only = True
# Simply being able to create 2 users indicates that the
# limit was not reached.
token1 = self.create_user("kermit1")

View file

@ -27,6 +27,7 @@ from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.http.server import (
DirectServeResource,
JsonResource,
OptionsResource,
wrap_html_request_handler,
)
from synapse.http.site import SynapseSite, logger
@ -168,6 +169,86 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
class OptionsResourceTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
class DummyResource(Resource):
isLeaf = True
def render(self, request):
return request.path
# Setup a resource with some children.
self.resource = OptionsResource()
self.resource.putChild(b"res", DummyResource())
def _make_request(self, method, path):
"""Create a request from the method/path and return a channel with the response."""
request, channel = make_request(self.reactor, method, path, shorthand=False)
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource.
site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
request.site = site
resource = site.getResourceFor(request)
# Finally, render the resource and return the channel.
render(request, resource, self.reactor)
return channel
def test_unknown_options_request(self):
"""An OPTIONS requests to an unknown URL still returns 200 OK."""
channel = self._make_request(b"OPTIONS", b"/foo/")
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.result["body"], b"{}")
# Ensure the correct CORS headers have been added
self.assertTrue(
channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
"has CORS Origin header",
)
self.assertTrue(
channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
"has CORS Methods header",
)
self.assertTrue(
channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
"has CORS Headers header",
)
def test_known_options_request(self):
"""An OPTIONS requests to an known URL still returns 200 OK."""
channel = self._make_request(b"OPTIONS", b"/res/")
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.result["body"], b"{}")
# Ensure the correct CORS headers have been added
self.assertTrue(
channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
"has CORS Origin header",
)
self.assertTrue(
channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
"has CORS Methods header",
)
self.assertTrue(
channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
"has CORS Headers header",
)
def test_unknown_request(self):
"""A non-OPTIONS request to an unknown URL should 404."""
channel = self._make_request(b"GET", b"/foo/")
self.assertEqual(channel.result["code"], b"404")
def test_known_request(self):
"""A non-OPTIONS request to an known URL should query the proper resource."""
channel = self._make_request(b"GET", b"/res/")
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.result["body"], b"/res/")
class WrapHtmlRequestHandlerTests(unittest.TestCase):
class TestResource(DirectServeResource):
callback = None

View file

@ -193,6 +193,7 @@ commands = mypy \
synapse/handlers/saml_handler.py \
synapse/handlers/sync.py \
synapse/handlers/ui_auth \
synapse/http/site.py \
synapse/logging/ \
synapse/metrics \
synapse/module_api \