Compare commits

7 commits

Author SHA1 Message Date
Andrew Morgan a7bf3f6f95 wip tests 2021-11-11 07:21:28 +00:00
Andrew Morgan 42f0ab1d84 Changelog 2021-11-10 16:11:01 +00:00
Andrew Morgan 6bcd8dee42 Add a to_device_stream_id column to the application_services_state table
This is for tracking the stream id that each application service has
been sent up to. In other words, there shouldn't be any need to process
stream ids below the recorded one here as the AS should have already
received them.

Note that there is no reliability built-in here. Reliability of delivery
is intended for a separate PR.
2021-11-10 16:11:01 +00:00
Andrew Morgan 43bdfbbc85 Add database method to fetch to-device messages by user_ids from db
This method is quite similar to the one below, except that it doesn't
support device ids, and supports querying with more than one user id,
both of which are relevant to application services.

The results are also formatted in a different data structure, so I'm
not sure how much we could really share here between the two methods.
2021-11-10 16:10:57 +00:00
Andrew Morgan 59b5eeefe1 Allow setting/getting stream id per appservice for to-device messages 2021-11-10 16:10:57 +00:00
Andrew Morgan 52c43d91f5 Add a new ephemeral AS handler for to_device message edus
Here we add the ability for the application service ephemeral message
processor to handle new events on the "to_device" stream.

We keep track of a stream id (token) per application service, and every
time a new to-device message comes in, for each appservice we pull the
messages between the last-recorded and current stream id and check
whether any of the messages are for a user in that appservice's user
namespace.

get_new_messages is implemented in the next commit.
2021-11-10 16:10:41 +00:00
Andrew Morgan 0772fc855b Add experimental config option to send to-device messages to AS's 2021-11-10 16:06:58 +00:00
40 changed files with 795 additions and 310 deletions
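Taken together, the commits above implement the following per-appservice loop. A minimal sketch in Python, assuming the storage and scheduler methods introduced in the diffs below (get_type_stream_id_for_appservice, get_new_messages, set_type_stream_id_for_appservice, submit_ephemeral_events_for_as); services, user_ids and scheduler stand in for the real handler state, and the per-AS interest filtering and MSC2409 payload mangling are omitted:

async def notify_appservices_of_to_device(store, scheduler, services, user_ids, new_token):
    for service in services:
        # The stream id this appservice has already been sent up to
        # (an exclusive lower bound).
        from_key = await store.get_type_stream_id_for_appservice(service, "to_device")
        # Pull the messages in (from_key, new_token] for the given users,
        # up to the per-transaction limit.
        messages, max_stream_token = await store.get_new_messages(
            user_ids, from_key, new_token, limit=100
        )
        if messages:
            scheduler.submit_ephemeral_events_for_as(service, messages)
        # Record only the stream id we actually reached, so anything past
        # the limit is picked up on a later pass.
        await store.set_type_stream_id_for_appservice(
            service, "to_device", max_stream_token
        )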

View file

@@ -0,0 +1 @@
Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). Disabled by default.

View file

@@ -1 +0,0 @@
Add missing type hints to `synapse.app`.

View file

@@ -1 +0,0 @@
Fix a long-standing bug where uploading extremely thin images (e.g. 1000x1) would fail. Contributed by @Neeeflix.

View file

@@ -1 +0,0 @@
Add type hints to `synapse._scripts`.

View file

@@ -1 +0,0 @@
Add Single Sign-On, SAML and CAS pages to the documentation.

View file

@@ -23,10 +23,10 @@
- [Structured Logging](structured_logging.md)
- [Templates](templates.md)
- [User Authentication](usage/configuration/user_authentication/README.md)
- [Single-Sign On](usage/configuration/user_authentication/single_sign_on/README.md)
- [Single-Sign On]()
- [OpenID Connect](openid.md)
- [SAML](usage/configuration/user_authentication/single_sign_on/saml.md)
- [CAS](usage/configuration/user_authentication/single_sign_on/cas.md)
- [SAML]()
- [CAS]()
- [SSO Mapping Providers](sso_mapping_providers.md)
- [Password Auth Providers](password_auth_providers.md)
- [JSON Web Tokens](jwt.md)

View file

@@ -1,5 +0,0 @@
# Single Sign-On
Synapse supports single sign-on through the SAML, Open ID Connect or CAS protocols.
LDAP and other login methods are supported through first and third-party password
auth provider modules.

View file

@@ -1,8 +0,0 @@
# CAS
Synapse supports authenticating users via the [Central Authentication
Service protocol](https://en.wikipedia.org/wiki/Central_Authentication_Service)
(CAS) natively.
Please see the `cas_config` and `sso` sections of the [Synapse configuration
file](../../../configuration/homeserver_sample_config.md) for more details.

View file

@@ -1,8 +0,0 @@
# SAML
Synapse supports authenticating users via the [Security Assertion
Markup Language](https://en.wikipedia.org/wiki/Security_Assertion_Markup_Language)
(SAML) protocol natively.
Please see the `saml2_config` and `sso` sections of the [Synapse configuration
file](../../../configuration/homeserver_sample_config.md) for more details.

View file

@@ -23,6 +23,24 @@ files =
# https://docs.python.org/3/library/re.html#re.X
exclude = (?x)
^(
|synapse/_scripts/register_new_matrix_user.py
|synapse/_scripts/review_recent_signups.py
|synapse/app/__init__.py
|synapse/app/_base.py
|synapse/app/admin_cmd.py
|synapse/app/appservice.py
|synapse/app/client_reader.py
|synapse/app/event_creator.py
|synapse/app/federation_reader.py
|synapse/app/federation_sender.py
|synapse/app/frontend_proxy.py
|synapse/app/generic_worker.py
|synapse/app/homeserver.py
|synapse/app/media_repository.py
|synapse/app/phone_stats_home.py
|synapse/app/pusher.py
|synapse/app/synchrotron.py
|synapse/app/user_dir.py
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/account_data.py
@@ -163,9 +181,6 @@ exclude = (?x)
[mypy-synapse.api.*]
disallow_untyped_defs = True
[mypy-synapse.app.*]
disallow_untyped_defs = True
[mypy-synapse.crypto.*]
disallow_untyped_defs = True

View file

@@ -110,7 +110,6 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
"types-Pillow>=8.3.4",
"types-pyOpenSSL>=20.0.7",
"types-PyYAML>=5.4.10",
"types-requests>=2.26.0",
"types-setuptools>=57.4.0",
]

View file

@@ -1,6 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,23 +19,22 @@ import hashlib
import hmac
import logging
import sys
from typing import Callable, Optional
import requests as _requests
import yaml
def request_registration(
user: str,
password: str,
server_location: str,
shared_secret: str,
admin: bool = False,
user_type: Optional[str] = None,
user,
password,
server_location,
shared_secret,
admin=False,
user_type=None,
requests=_requests,
_print: Callable[[str], None] = print,
exit: Callable[[int], None] = sys.exit,
) -> None:
_print=print,
exit=sys.exit,
):
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
@@ -67,13 +65,13 @@ def request_registration(
mac.update(b"\x00")
mac.update(user_type.encode("utf8"))
hex_mac = mac.hexdigest()
mac = mac.hexdigest()
data = {
"nonce": nonce,
"username": user,
"password": password,
"mac": hex_mac,
"mac": mac,
"admin": admin,
"user_type": user_type,
}
@@ -93,17 +91,10 @@ def request_registration(
_print("Success!")
def register_new_user(
user: str,
password: str,
server_location: str,
shared_secret: str,
admin: Optional[bool],
user_type: Optional[str],
) -> None:
def register_new_user(user, password, server_location, shared_secret, admin, user_type):
if not user:
try:
default_user: Optional[str] = getpass.getuser()
default_user = getpass.getuser()
except Exception:
default_user = None
@@ -132,8 +123,8 @@ def register_new_user(
sys.exit(1)
if admin is None:
admin_inp = input("Make admin [no]: ")
if admin_inp in ("y", "yes", "true"):
admin = input("Make admin [no]: ")
if admin in ("y", "yes", "true"):
admin = True
else:
admin = False
@@ -143,7 +134,7 @@
)
def main() -> None:
def main():
logging.captureWarnings(True)

View file

@@ -92,7 +92,7 @@ def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]:
return user_infos
def main() -> None:
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-c",
@@ -142,8 +142,7 @@ def main() -> None:
engine = create_engine(database_config.config)
with make_conn(database_config, engine, "review_recent_signups") as db_conn:
# This generates a type of Cursor, not LoggingTransaction.
user_infos = get_recent_users(db_conn.cursor(), since_ms) # type: ignore[arg-type]
user_infos = get_recent_users(db_conn.cursor(), since_ms)
for user_info in user_infos:
if exclude_users_with_email and user_info.emails:

View file

@@ -13,7 +13,6 @@
# limitations under the License.
import logging
import sys
from typing import Container
from synapse import python_dependencies # noqa: E402
@@ -28,9 +27,7 @@ except python_dependencies.DependencyException as e:
sys.exit(1)
def check_bind_error(
e: Exception, address: str, bind_addresses: Container[str]
) -> None:
def check_bind_error(e, address, bind_addresses):
"""
This method checks an exception occurred while binding on 0.0.0.0.
If :: is specified in the bind addresses a warning is shown.
@@ -41,9 +38,9 @@ def check_bind_error(
When binding on 0.0.0.0 after :: this can safely be ignored.
Args:
e: Exception that was caught.
address: Address on which binding was attempted.
bind_addresses: Addresses on which the service listens.
e (Exception): Exception that was caught.
address (str): Address on which binding was attempted.
bind_addresses (list): Addresses on which the service listens.
"""
if address == "0.0.0.0" and "::" in bind_addresses:
logger.warning(

View file

@@ -22,27 +22,13 @@ import socket
import sys
import traceback
import warnings
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
NoReturn,
Tuple,
cast,
)
from typing import TYPE_CHECKING, Awaitable, Callable, Iterable
from cryptography.utils import CryptographyDeprecationWarning
from typing_extensions import NoReturn
import twisted
from twisted.internet import defer, error, reactor as _reactor
from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorSSL, IReactorTCP
from twisted.internet.protocol import ServerFactory
from twisted.internet.tcp import Port
from twisted.internet import defer, error, reactor
from twisted.logger import LoggingFile, LogLevel
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.python.threadpool import ThreadPool
@@ -62,7 +48,6 @@ from synapse.logging.context import PreserveLoggingContext
from synapse.metrics import register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
from synapse.types import ISynapseReactor
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
from synapse.util.daemonize import daemonize_process
from synapse.util.gai_resolver import GAIResolver
@@ -72,44 +57,33 @@ from synapse.util.versionstring import get_version_string
if TYPE_CHECKING:
from synapse.server import HomeServer
# Twisted injects the global reactor to make it easier to import, this confuses
# mypy which thinks it is a module. Tell it that it a more proper type.
reactor = cast(ISynapseReactor, _reactor)
logger = logging.getLogger(__name__)
# list of tuples of function, args list, kwargs dict
_sighup_callbacks: List[
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
] = []
_sighup_callbacks = []
def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
def register_sighup(func, *args, **kwargs):
"""
Register a function to be called when a SIGHUP occurs.
Args:
func: Function to be called when sent a SIGHUP signal.
func (function): Function to be called when sent a SIGHUP signal.
*args, **kwargs: args and kwargs to be passed to the target function.
"""
_sighup_callbacks.append((func, args, kwargs))
def start_worker_reactor(
appname: str,
config: HomeServerConfig,
run_command: Callable[[], None] = reactor.run,
) -> None:
def start_worker_reactor(appname, config, run_command=reactor.run):
"""Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
the reactor. Pulls configuration from the 'worker' settings in 'config'.
Args:
appname: application name which will be sent to syslog
config: config object
run_command: callable that actually runs the reactor
appname (str): application name which will be sent to syslog
config (synapse.config.Config): config object
run_command (Callable[]): callable that actually runs the reactor
"""
logger = logging.getLogger(config.worker.worker_app)
@@ -127,32 +101,32 @@ def start_worker_reactor(
def start_reactor(
appname: str,
soft_file_limit: int,
gc_thresholds: Tuple[int, int, int],
pid_file: str,
daemonize: bool,
print_pidfile: bool,
logger: logging.Logger,
run_command: Callable[[], None] = reactor.run,
) -> None:
appname,
soft_file_limit,
gc_thresholds,
pid_file,
daemonize,
print_pidfile,
logger,
run_command=reactor.run,
):
"""Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
the reactor
Args:
appname: application name which will be sent to syslog
soft_file_limit:
appname (str): application name which will be sent to syslog
soft_file_limit (int):
gc_thresholds:
pid_file: name of pid file to write to if daemonize is True
daemonize: true to run the reactor in a background process
print_pidfile: whether to print the pid file, if daemonize is True
logger: logger instance to pass to Daemonize
run_command: callable that actually runs the reactor
pid_file (str): name of pid file to write to if daemonize is True
daemonize (bool): true to run the reactor in a background process
print_pidfile (bool): whether to print the pid file, if daemonize is True
logger (logging.Logger): logger instance to pass to Daemonize
run_command (Callable[]): callable that actually runs the reactor
"""
def run() -> None:
def run():
logger.info("Running")
setup_jemalloc_stats()
change_resource_limit(soft_file_limit)
@@ -211,7 +185,7 @@ def redirect_stdio_to_logs() -> None:
print("Redirected stdout/stderr to logs")
def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None:
def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
"""Register a callback with the reactor, to be called once it is running
This can be used to initialise parts of the system which require an asynchronous
@@ -221,7 +195,7 @@ def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> N
will exit.
"""
async def wrapper() -> None:
async def wrapper():
try:
await cb(*args, **kwargs)
except Exception:
@@ -250,7 +224,7 @@ def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> N
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
def listen_metrics(bind_addresses, port):
"""
Start Prometheus metrics server.
"""
@@ -262,11 +236,11 @@ def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
def listen_manhole(
bind_addresses: Collection[str],
bind_addresses: Iterable[str],
port: int,
manhole_settings: ManholeConfig,
manhole_globals: dict,
) -> None:
):
# twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
# warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so
# suppress the warning for now.
@@ -285,18 +259,12 @@ def listen_manhole(
)
def listen_tcp(
bind_addresses: Collection[str],
port: int,
factory: ServerFactory,
reactor: IReactorTCP = reactor,
backlog: int = 50,
) -> List[Port]:
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
"""
Create a TCP socket for a port and several addresses
Returns:
list of twisted.internet.tcp.Port listening for TCP connections
list[twisted.internet.tcp.Port]: listening for TCP connections
"""
r = []
for address in bind_addresses:
@@ -305,19 +273,12 @@ def listen_tcp(
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
# IReactorTCP returns an object implementing IListeningPort from listenTCP,
# but we know it will be a Port instance.
return r # type: ignore[return-value]
return r
def listen_ssl(
bind_addresses: Collection[str],
port: int,
factory: ServerFactory,
context_factory: IOpenSSLContextFactory,
reactor: IReactorSSL = reactor,
backlog: int = 50,
) -> List[Port]:
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
):
"""
Create an TLS-over-TCP socket for a port and several addresses
@@ -333,13 +294,10 @@ def listen_ssl(
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
# IReactorSSL incorrectly declares that an int is returned from listenSSL,
# it actually returns an object implementing IListeningPort, but we know it
# will be a Port instance.
return r # type: ignore[return-value]
return r
def refresh_certificate(hs: "HomeServer") -> None:
def refresh_certificate(hs: "HomeServer"):
"""
Refresh the TLS certificates that Synapse is using by re-reading them from
disk and updating the TLS context factories to use them.
@@ -371,7 +329,7 @@ def refresh_certificate(hs: "HomeServer") -> None:
logger.info("Context factories updated.")
async def start(hs: "HomeServer") -> None:
async def start(hs: "HomeServer"):
"""
Start a Synapse server or worker.
@@ -402,7 +360,7 @@ async def start(hs: "HomeServer") -> None:
if hasattr(signal, "SIGHUP"):
@wrap_as_background_process("sighup")
def handle_sighup(*args: Any, **kwargs: Any) -> None:
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")
@@ -415,7 +373,7 @@ async def start(hs: "HomeServer") -> None:
# We defer running the sighup handlers until next reactor tick. This
# is so that we're in a sane state, e.g. flushing the logs may fail
# if the sighup happens in the middle of writing a log entry.
def run_sighup(*args: Any, **kwargs: Any) -> None:
def run_sighup(*args, **kwargs):
# `callFromThread` should be "signal safe" as well as thread
# safe.
reactor.callFromThread(handle_sighup, *args, **kwargs)
@@ -478,8 +436,12 @@ async def start(hs: "HomeServer") -> None:
atexit.register(gc.freeze)
def setup_sentry(hs: "HomeServer") -> None:
"""Enable sentry integration, if enabled in configuration"""
def setup_sentry(hs: "HomeServer"):
"""Enable sentry integration, if enabled in configuration
Args:
hs
"""
if not hs.config.metrics.sentry_enabled:
return
@@ -504,7 +466,7 @@ def setup_sentry(hs: "HomeServer") -> None:
scope.set_tag("worker_name", name)
def setup_sdnotify(hs: "HomeServer") -> None:
def setup_sdnotify(hs: "HomeServer"):
"""Adds process state hooks to tell systemd what we are up to."""
# Tell systemd our state, if we're using it. This will silently fail if
@@ -519,7 +481,7 @@ def setup_sdnotify(hs: "HomeServer") -> None:
sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET")
def sdnotify(state: bytes) -> None:
def sdnotify(state):
"""
Send a notification to systemd, if the NOTIFY_SOCKET env var is set.
@@ -528,7 +490,7 @@ def sdnotify(state: bytes) -> None:
package which many OSes don't include as a matter of principle.
Args:
state: notification to send
state (bytes): notification to send
"""
if not isinstance(state, bytes):
raise TypeError("sdnotify should be called with a bytes")

View file

@@ -17,7 +17,6 @@ import logging
import os
import sys
import tempfile
from typing import List, Optional
from twisted.internet import defer, task
@@ -26,7 +25,6 @@ from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.events import EventBase
from synapse.handlers.admin import ExfiltrationWriter
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
@@ -42,7 +40,6 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.server import HomeServer
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.types import StateMap
from synapse.util.logcontext import LoggingContext
from synapse.util.versionstring import get_version_string
@@ -68,11 +65,16 @@ class AdminCmdSlavedStore(
class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdSlavedStore # type: ignore
DATASTORE_CLASS = AdminCmdSlavedStore
async def export_data_command(hs: HomeServer, args: argparse.Namespace) -> None:
"""Export data for a user."""
async def export_data_command(hs: HomeServer, args):
"""Export data for a user.
Args:
hs
args (argparse.Namespace)
"""
user_id = args.user_id
directory = args.output_directory
@@ -90,12 +92,12 @@ class FileExfiltrationWriter(ExfiltrationWriter):
Note: This writes to disk on the main reactor thread.
Args:
user_id: The user whose data is being exfiltrated.
directory: The directory to write the data to, if None then will write
to a temporary directory.
user_id (str): The user whose data is being exfiltrated.
directory (str|None): The directory to write the data to, if None then
will write to a temporary directory.
"""
def __init__(self, user_id: str, directory: Optional[str] = None):
def __init__(self, user_id, directory=None):
self.user_id = user_id
if directory:
@@ -109,7 +111,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
if list(os.listdir(self.base_directory)):
raise Exception("Directory must be empty")
def write_events(self, room_id: str, events: List[EventBase]) -> None:
def write_events(self, room_id, events):
room_directory = os.path.join(self.base_directory, "rooms", room_id)
os.makedirs(room_directory, exist_ok=True)
events_file = os.path.join(room_directory, "events")
@@ -118,9 +120,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
for event in events:
print(json.dumps(event.get_pdu_json()), file=f)
def write_state(
self, room_id: str, event_id: str, state: StateMap[EventBase]
) -> None:
def write_state(self, room_id, event_id, state):
room_directory = os.path.join(self.base_directory, "rooms", room_id)
state_directory = os.path.join(room_directory, "state")
os.makedirs(state_directory, exist_ok=True)
@@ -131,9 +131,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
for event in state.values():
print(json.dumps(event.get_pdu_json()), file=f)
def write_invite(
self, room_id: str, event: EventBase, state: StateMap[EventBase]
) -> None:
def write_invite(self, room_id, event, state):
self.write_events(room_id, [event])
# We write the invite state somewhere else as they aren't full events
@@ -147,9 +145,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
for event in state.values():
print(json.dumps(event), file=f)
def write_knock(
self, room_id: str, event: EventBase, state: StateMap[EventBase]
) -> None:
def write_knock(self, room_id, event, state):
self.write_events(room_id, [event])
# We write the knock state somewhere else as they aren't full events
@@ -163,11 +159,11 @@ class FileExfiltrationWriter(ExfiltrationWriter):
for event in state.values():
print(json.dumps(event), file=f)
def finished(self) -> str:
def finished(self):
return self.base_directory
def start(config_options: List[str]) -> None:
def start(config_options):
parser = argparse.ArgumentParser(description="Synapse Admin Command")
HomeServerConfig.add_arguments_to_parser(parser)
@@ -235,7 +231,7 @@ def start(config_options: List[str]) -> None:
# We also make sure that `_base.start` gets run before we actually run the
# command.
async def run() -> None:
async def run():
with LoggingContext("command"):
await _base.start(ss)
await args.func(ss, args)

View file

@@ -14,10 +14,11 @@
# limitations under the License.
import logging
import sys
from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional
from twisted.internet import address
from twisted.web.resource import Resource
from twisted.web.resource import IResource
from twisted.web.server import Request
import synapse
import synapse.events
@@ -43,7 +44,7 @@ from synapse.config.server import ListenerConfig
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.server import JsonResource, OptionsResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
@@ -118,7 +119,6 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import JsonDict
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.versionstring import get_version_string
@@ -143,9 +143,7 @@ class KeyUploadServlet(RestServlet):
self.http_client = hs.get_simple_http_client()
self.main_uri = hs.config.worker.worker_main_http_uri
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
async def on_POST(self, request: Request, device_id: Optional[str]):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -189,8 +187,9 @@ class KeyUploadServlet(RestServlet):
# If the header exists, add to the comma-separated list of the first
# instance of the header. Otherwise, generate a new header.
if x_forwarded_for:
x_forwarded_for = [x_forwarded_for[0] + b", " + previous_host]
x_forwarded_for.extend(x_forwarded_for[1:])
x_forwarded_for = [
x_forwarded_for[0] + b", " + previous_host
] + x_forwarded_for[1:]
else:
x_forwarded_for = [previous_host]
headers[b"X-Forwarded-For"] = x_forwarded_for
@@ -254,16 +253,13 @@ class GenericWorkerSlavedStore(
SessionStore,
BaseSlavedStore,
):
# Properties that multiple storage classes define. Tell mypy what the
# expected type is.
server_name: str
config: HomeServerConfig
pass
class GenericWorkerServer(HomeServer):
DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore
DATASTORE_CLASS = GenericWorkerSlavedStore
def _listen_http(self, listener_config: ListenerConfig) -> None:
def _listen_http(self, listener_config: ListenerConfig):
port = listener_config.port
bind_addresses = listener_config.bind_addresses
@@ -271,10 +267,10 @@ class GenericWorkerServer(HomeServer):
site_tag = listener_config.http_options.tag
if site_tag is None:
site_tag = str(port)
site_tag = port
# We always include a health resource.
resources: Dict[str, Resource] = {"/health": HealthResource()}
resources: Dict[str, IResource] = {"/health": HealthResource()}
for res in listener_config.http_options.resources:
for name in res.names:
@@ -390,7 +386,7 @@ class GenericWorkerServer(HomeServer):
logger.info("Synapse worker now listening on port %d", port)
def start_listening(self) -> None:
def start_listening(self):
for listener in self.config.worker.worker_listeners:
if listener.type == "http":
self._listen_http(listener)
@@ -415,7 +411,7 @@ class GenericWorkerServer(HomeServer):
self.get_tcp_replication().start_replication(self)
def start(config_options: List[str]) -> None:
def start(config_options):
try:
config = HomeServerConfig.load_config("Synapse worker", config_options)
except ConfigError as e:

View file

@@ -16,10 +16,10 @@
import logging
import os
import sys
from typing import Dict, Iterable, Iterator, List
from typing import Iterator
from twisted.internet.tcp import Port
from twisted.web.resource import EncodingResourceWrapper, Resource
from twisted.internet import reactor
from twisted.web.resource import EncodingResourceWrapper, IResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@@ -76,27 +76,23 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.homeserver")
def gz_wrap(r: Resource) -> Resource:
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore # type: ignore
DATASTORE_CLASS = DataStore
def _listener_http(
self, config: HomeServerConfig, listener_config: ListenerConfig
) -> Iterable[Port]:
def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig):
port = listener_config.port
bind_addresses = listener_config.bind_addresses
tls = listener_config.tls
# Must exist since this is an HTTP listener.
assert listener_config.http_options is not None
site_tag = listener_config.http_options.tag
if site_tag is None:
site_tag = str(port)
# We always include a health resource.
resources: Dict[str, Resource] = {"/health": HealthResource()}
resources = {"/health": HealthResource()}
for res in listener_config.http_options.resources:
for name in res.names:
@@ -115,7 +111,7 @@ class SynapseHomeServer(HomeServer):
("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
)
handler = handler_cls(config, module_api)
if isinstance(handler, Resource):
if IResource.providedBy(handler):
resource = handler
elif hasattr(handler, "handle_request"):
resource = AdditionalResource(self, handler.handle_request)
@@ -132,7 +128,7 @@ class SynapseHomeServer(HomeServer):
# try to find something useful to redirect '/' to
if WEB_CLIENT_PREFIX in resources:
root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
elif STATIC_PREFIX in resources:
root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
else:
@@ -149,8 +145,6 @@ class SynapseHomeServer(HomeServer):
)
if tls:
# refresh_certificate should have been called before this.
assert self.tls_server_context_factory is not None
ports = listen_ssl(
bind_addresses,
port,
@@ -171,21 +165,20 @@ class SynapseHomeServer(HomeServer):
return ports
def _configure_named_resource(
self, name: str, compress: bool = False
) -> Dict[str, Resource]:
def _configure_named_resource(self, name, compress=False):
"""Build a resource map for a named resource
Args:
name: named resource: one of "client", "federation", etc
compress: whether to enable gzip compression for this resource
name (str): named resource: one of "client", "federation", etc
compress (bool): whether to enable gzip compression for this
resource
Returns:
map from path to HTTP resource
dict[str, Resource]: map from path to HTTP resource
"""
resources: Dict[str, Resource] = {}
resources = {}
if name == "client":
client_resource: Resource = ClientRestResource(self)
client_resource = ClientRestResource(self)
if compress:
client_resource = gz_wrap(client_resource)
@@ -214,7 +207,7 @@ class SynapseHomeServer(HomeServer):
if name == "consent":
from synapse.rest.consent.consent_resource import ConsentResource
consent_resource: Resource = ConsentResource(self)
consent_resource = ConsentResource(self)
if compress:
consent_resource = gz_wrap(consent_resource)
resources.update({"/_matrix/consent": consent_resource})
@@ -284,7 +277,7 @@ class SynapseHomeServer(HomeServer):
return resources
def start_listening(self) -> None:
def start_listening(self):
if self.config.redis.redis_enabled:
# If redis is enabled we connect via the replication command handler
# in the same way as the workers (since we're effectively a client
@@ -310,9 +303,7 @@ class SynapseHomeServer(HomeServer):
ReplicationStreamProtocolFactory(self),
)
for s in services:
self.get_reactor().addSystemEventTrigger(
"before", "shutdown", s.stopListening
)
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener.type == "metrics":
if not self.config.metrics.enable_metrics:
logger.warning(
@@ -327,13 +318,14 @@ class SynapseHomeServer(HomeServer):
logger.warning("Unrecognized listener type: %s", listener.type)
def setup(config_options: List[str]) -> SynapseHomeServer:
def setup(config_options):
"""
Args:
config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`.
config_options_options: The options passed to Synapse. Usually
`sys.argv[1:]`.
Returns:
A homeserver instance.
HomeServer
"""
try:
config = HomeServerConfig.load_or_generate_config(
@@ -372,7 +364,7 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
except Exception as e:
handle_startup_exception(e)
async def start() -> None:
async def start():
# Load the OIDC provider metadatas, if OIDC is enabled.
if hs.config.oidc.oidc_enabled:
oidc = hs.get_oidc_handler()
@@ -412,15 +404,39 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
yield ":\n %s" % (e.msg,)
parent_e = e.__cause__
e = e.__cause__
indent = 1
while parent_e:
while e:
indent += 1
yield ":\n%s%s" % (" " * indent, str(parent_e))
parent_e = parent_e.__cause__
yield ":\n%s%s" % (" " * indent, str(e))
e = e.__cause__
def run(hs: HomeServer) -> None:
def run(hs: HomeServer):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
def profile(func):
from cProfile import Profile
from threading import current_thread
def profiled(*args, **kargs):
profile = Profile()
profile.enable()
func(*args, **kargs)
profile.disable()
ident = current_thread().ident
profile.dump_stats(
"/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident)
)
return profiled
from twisted.python.threadpool import ThreadPool
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
_base.start_reactor(
"synapse-homeserver",
soft_file_limit=hs.config.server.soft_file_limit,
@@ -432,7 +448,7 @@ def run(hs: HomeServer) -> None:
)
def main() -> None:
def main():
with LoggingContext("main"):
# check base requirements
check_requirements()

View file

@@ -15,12 +15,11 @@ import logging
import math
import resource
import sys
from typing import TYPE_CHECKING, List, Sized, Tuple
from typing import TYPE_CHECKING
from prometheus_client import Gauge
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -29,7 +28,7 @@ logger = logging.getLogger("synapse.app.homeserver")
# Contains the list of processes we will be monitoring
# currently either 0 or 1
_stats_process: List[Tuple[int, "resource.struct_rusage"]] = []
_stats_process = []
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
@@ -46,15 +45,9 @@ registered_reserved_users_mau_gauge = Gauge(
@wrap_as_background_process("phone_stats_home")
async def phone_stats_home(
hs: "HomeServer",
stats: JsonDict,
stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process,
) -> None:
async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
# Ensure the homeserver has started.
assert hs.start_time is not None
uptime = int(now - hs.start_time)
if uptime < 0:
uptime = 0
@@ -153,15 +146,15 @@ async def phone_stats_home(
logger.warning("Error reporting stats: %s", e)
def start_phone_stats_home(hs: "HomeServer") -> None:
def start_phone_stats_home(hs: "HomeServer"):
"""
Start the background tasks which report phone home stats.
"""
clock = hs.get_clock()
stats: JsonDict = {}
stats = {}
def performance_stats_init() -> None:
def performance_stats_init():
_stats_process.clear()
_stats_process.append(
(int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
@@ -177,10 +170,10 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
hs.get_datastore().reap_monthly_active_users()
@wrap_as_background_process("generate_monthly_active_users")
async def generate_monthly_active_users() -> None:
async def generate_monthly_active_users():
current_mau_count = 0
current_mau_count_by_service = {}
reserved_users: Sized = ()
reserved_users = ()
store = hs.get_datastore()
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
current_mau_count = await store.get_monthly_active_count()

View file

@@ -46,3 +46,11 @@ class ExperimentalConfig(Config):
# MSC3266 (room summary api)
self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
# MSC2409 (this setting only relates to optionally sending to-device messages).
# Presence, typing and read receipt EDUs are already sent to application services that
# have opted in to receive them. This setting, if enabled, adds to-device messages
# to that list.
self.msc2409_to_device_messages_enabled: bool = experimental.get(
"msc2409_to_device_messages_enabled", False
)
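For reference, enabling the flag would look something like this in homeserver.yaml (a sketch, assuming the standard experimental_features section that this class reads its settings from):

experimental_features:
  # Opt in to sending to-device messages to appservices (MSC2409); off by default.
  msc2409_to_device_messages_enabled: true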

View file

@@ -234,7 +234,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
@abc.abstractmethod
def write_invite(
self, room_id: str, event: EventBase, state: StateMap[EventBase]
self, room_id: str, event: EventBase, state: StateMap[dict]
) -> None:
"""Write an invite for the room, with associated invite state.
@@ -248,7 +248,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
@abc.abstractmethod
def write_knock(
self, room_id: str, event: EventBase, state: StateMap[EventBase]
self, room_id: str, event: EventBase, state: StateMap[dict]
) -> None:
"""Write a knock for the room, with associated knock state.

View file

@@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
from prometheus_client import Counter
@@ -42,6 +51,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
MAX_TO_DEVICE_MESSAGES_PER_AS_TRANSACTION = 100
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
@@ -55,6 +66,9 @@ class ApplicationServicesHandler:
self.clock = hs.get_clock()
self.notify_appservices = hs.config.appservice.notify_appservices
self.event_sources = hs.get_event_sources()
self.msc2409_to_device_messages_enabled = (
hs.config.experimental.msc2409_to_device_messages_enabled
)
self.current_max = 0
self.is_processing = False
@@ -199,8 +213,9 @@
Args:
stream_key: The stream the event came from.
`stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other
value for `stream_key` will cause this function to return early.
`stream_key` can be "typing_key", "receipt_key", "presence_key" or
"to_device_key". Any other value for `stream_key` will cause this function
to return early.
Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration
@@ -216,10 +231,6 @@
if not self.notify_appservices:
return
# Ignore any unsupported streams
if stream_key not in ("typing_key", "receipt_key", "presence_key"):
return
# Assert that new_token is an integer (and not a RoomStreamToken).
# All of the supported streams that this function handles use an
# integer to track progress (rather than a RoomStreamToken - a
@@ -233,6 +244,13 @@
# Additional context: https://github.com/matrix-org/synapse/pull/11137
assert isinstance(new_token, int)
# Ignore to-device messages if the feature flag is not enabled
if (
stream_key == "to_device_key"
and not self.msc2409_to_device_messages_enabled
):
return
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
@@ -307,6 +325,30 @@
service, "presence", new_token
)
elif (
stream_key == "to_device_key"
and new_token is not None
and self.msc2409_to_device_messages_enabled
):
# Retrieve an iterable of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
events, max_stream_token = await self._handle_to_device(
service, new_token, users
)
if events:
self.scheduler.submit_ephemeral_events_for_as(
service, events
)
# TODO: If max_stream_token != new_token, schedule another transaction immediately,
# instead of waiting for another to-device to be sent?
# https://github.com/matrix-org/synapse/issues/11150#issuecomment-960726449
# Persist the latest handled stream token for this appservice
await self.store.set_type_stream_id_for_appservice(
service, "to_device", max_stream_token
)
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
@@ -440,6 +482,88 @@ class ApplicationServicesHandler:
return events
async def _handle_to_device(
self,
service: ApplicationService,
new_token: int,
users: Collection[Union[str, UserID]],
) -> Tuple[List[JsonDict], int]:
"""
Given an application service, determine which events it should receive
from those between the last-recorded to-device stream token for this
appservice and the given stream token.
Args:
service: The application service to check for which events it should receive.
new_token: The latest to-device event stream token.
users: The users that should receive new to-device messages.
Returns:
A two-tuple containing the following:
* A list of JSON dictionaries containing data derived from the to-device messages
that should be sent to the given application service.
* The last-processed to-device message's stream id. If this does not equal
`new_token` then there are likely more events to send.
"""
# Get the stream token that this application service has processed up until
from_key = await self.store.get_type_stream_id_for_appservice(
service, "to_device"
)
# Filter out users that this appservice is not interested in
users_appservice_is_interested_in: List[str] = []
for user in users:
if isinstance(user, UserID):
user = user.to_string()
if service.is_interested_in_user(user):
users_appservice_is_interested_in.append(user)
if not users_appservice_is_interested_in:
# Return early if the AS was not interested in any of these users
return [], new_token
# Retrieve the to-device messages for each user
(
recipient_user_id_device_id_to_messages,
# Record the maximum stream token of the retrieved messages that
# this function was able to pull before hitting the max to-device
# message count limit.
# We return this value later, to ensure we only record the stream
# token we managed to get up to, so that the rest can be sent later.
max_stream_token,
) = await self.store.get_new_messages(
users_appservice_is_interested_in,
from_key,
new_token,
limit=MAX_TO_DEVICE_MESSAGES_PER_AS_TRANSACTION,
)
# According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
# to the event JSON so that the application service will know which user/device
# combination this messages was intended for.
#
# So we mangle this dict into a flat list of to-device messages with the relevant
# user ID and device ID embedded inside each message dict.
message_payload: List[JsonDict] = []
for (
user_id,
device_id,
), messages in recipient_user_id_device_id_to_messages.items():
for message_json in messages:
# Remove 'message_id' from the to-device message, as it's an internal ID
message_json.pop("message_id", None)
message_payload.append(
{
"to_user_id": user_id,
"to_device_id": device_id,
**message_json,
}
)
return message_payload, max_stream_token
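For illustration, a single entry of message_payload produced by the mangling above might look like the following, where the user ID, device ID and event fields are hypothetical and the internal 'message_id' has already been stripped:

{
    "to_user_id": "@alice:example.org",
    "to_device_id": "ABCDEFGHIJ",
    "type": "m.room_key_request",
    "sender": "@bob:example.org",
    "content": {"action": "request"},
}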
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.

View file

@@ -3,7 +3,7 @@ import time
from logging import Handler, LogRecord
from logging.handlers import MemoryHandler
from threading import Thread
from typing import Optional, cast
from typing import Optional
from twisted.internet.interfaces import IReactorCore
@@ -56,7 +56,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
if reactor is None:
from twisted.internet import reactor as global_reactor
reactor_to_use = cast(IReactorCore, global_reactor)
reactor_to_use = global_reactor # type: ignore[assignment]
else:
reactor_to_use = reactor

View file

@@ -31,7 +31,7 @@ import attr
import jinja2
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.resource import IResource
from synapse.api.errors import SynapseError
from synapse.events import EventBase
@@ -196,7 +196,7 @@ class ModuleApi:
"""
return self._password_auth_provider.register_password_auth_provider_callbacks
def register_web_resource(self, path: str, resource: Resource):
def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path.
This function should be called during initialisation of the module.

View file

@@ -444,15 +444,23 @@ class Notifier:
self.notify_replication()
# Notify appservices.
try:
self.appservice_handler.notify_interested_services_ephemeral(
stream_key,
new_token,
users,
)
except Exception:
logger.exception("Error notifying application services of event")
# Notify appservices of updates in ephemeral event streams.
# Only the following streams are currently supported.
if stream_key in (
"typing_key",
"receipt_key",
"presence_key",
"to_device_key",
):
# Notify appservices.
try:
self.appservice_handler.notify_interested_services_ephemeral(
stream_key,
new_token,
users,
)
except Exception:
logger.exception("Error notifying application services of event")
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happened

View file

@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
from prometheus_client import Counter
from twisted.internet.protocol import ServerFactory
from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import PositionCommand
@@ -38,7 +38,7 @@ stream_updates_counter = Counter(
logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(ServerFactory):
class ReplicationStreamProtocolFactory(Factory):
"""Factory for new replication connections."""
def __init__(self, hs: "HomeServer"):

View file

@@ -101,8 +101,8 @@ class Thumbnailer:
fits within the given rectangle::
(w_in / h_in) = (w_out / h_out)
w_out = max(min(w_max, h_max * (w_in / h_in)), 1)
h_out = max(min(h_max, w_max * (h_in / w_in)), 1)
w_out = min(w_max, h_max * (w_in / h_in))
h_out = min(h_max, w_max * (h_in / w_in))
Args:
max_width: The largest possible width.
@@ -110,9 +110,9 @@
"""
if max_width * self.height < max_height * self.width:
return max_width, max((max_width * self.height) // self.width, 1)
return max_width, (max_width * self.height) // self.width
else:
return max((max_height * self.width) // self.height, 1), max_height
return (max_height * self.width) // self.height, max_height
def _resize(self, width: int, height: int) -> Image.Image:
# 1-bit or 8-bit color palette images need converting to RGB
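A worked example of why the max(..., 1) clamp above matters, using the 1000x1 image from the changelog entry earlier with max_width=800 and max_height=600: since 800 * 1 < 600 * 1000, the old code returned (800, (800 * 1) // 1000) = (800, 0), a zero-height thumbnail that then fails to generate; the clamped version returns (800, max(0, 1)) = (800, 1).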

View file

@@ -33,10 +33,9 @@ from typing import (
cast,
)
from twisted.internet.interfaces import IOpenSSLContextFactory
from twisted.internet.tcp import Port
import twisted.internet.tcp
from twisted.web.iweb import IPolicyForHTTPS
from twisted.web.resource import Resource
from twisted.web.resource import IResource
from synapse.api.auth import Auth
from synapse.api.filtering import Filtering
@@ -207,7 +206,7 @@ class HomeServer(metaclass=abc.ABCMeta):
Attributes:
config (synapse.config.homeserver.HomeserverConfig):
_listening_services (list[Port]): TCP ports that
_listening_services (list[twisted.internet.tcp.Port]): TCP ports that
we are listening on to provide HTTP services.
"""
@@ -226,8 +225,6 @@
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
tls_server_context_factory: Optional[IOpenSSLContextFactory]
def __init__(
self,
hostname: str,
@@ -250,7 +247,7 @@
# the key we use to sign events and requests
self.signing_key = config.key.signing_key[0]
self.config = config
self._listening_services: List[Port] = []
self._listening_services: List[twisted.internet.tcp.Port] = []
self.start_time: Optional[int] = None
self._instance_id = random_string(5)
@@ -260,10 +257,10 @@
self.datastores: Optional[Databases] = None
self._module_web_resources: Dict[str, Resource] = {}
self._module_web_resources: Dict[str, IResource] = {}
self._module_web_resources_consumed = False
def register_module_web_resource(self, path: str, resource: Resource):
def register_module_web_resource(self, path: str, resource: IResource):
"""Allows a module to register a web resource to be served at the given path.
If multiple modules register a resource for the same path, the module that

View file

@@ -387,7 +387,7 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
if type not in ("read_receipt", "presence"):
if type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
@@ -414,7 +414,7 @@
async def set_type_stream_id_for_appservice(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
if stream_type not in ("read_receipt", "presence"):
if stream_type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)

View file

@@ -13,13 +13,17 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
@@ -116,6 +120,97 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
async def get_new_messages(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
limit: int = 100,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
"""
Retrieve to-device messages for a given set of user IDs.
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
Note that multiple messages can have the same stream id. Stream ids are
unique to *messages*, but there might be multiple recipients of a message, and
thus multiple entries in the device_inbox table with the same stream id.
Args:
user_ids: The users to retrieve to-device messages for.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
limit: The maximum number of to-device messages to return.
Returns:
A tuple containing the following:
* A list of to-device messages
* The stream id that to-device messages were processed up to. The calling
function can check this value against `to_stream_id` to see if we hit
the given limit when fetching messages.
"""
# Bail out if none of these users have any messages
for user_id in user_ids:
if self._device_inbox_stream_cache.has_entity_changed(
user_id, from_stream_id
):
break
else:
return {}, to_stream_id
def get_new_messages_txn(txn: LoggingTransaction):
# Build a query to select messages from any of the given users that are between
# the given stream id bounds
# Scope to only the given users. We need to use this method as doing so is
# different across database engines.
many_clause_sql, many_clause_args = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids
)
sql = f"""
SELECT stream_id, user_id, device_id, message_json FROM device_inbox
WHERE {many_clause_sql}
AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id, limit))
# Create a dictionary of (user ID, device ID) -> list of messages that
# that device is meant to receive.
recipient_user_id_device_id_to_messages = {}
stream_pos = to_stream_id
total_messages_processed = 0
for row in txn:
# Record the last-processed stream position, to return later.
# Note that we process messages here in order of ascending stream id.
stream_pos = row[0]
recipient_user_id = row[1]
recipient_device_id = row[2]
message_dict = db_to_json(row[3])
recipient_user_id_device_id_to_messages.setdefault(
(recipient_user_id, recipient_device_id), []
).append(message_dict)
total_messages_processed += 1
# If we don't end up hitting the limit, we still want to return the equivalent
# value of `to_stream_id` to the calling function. This is needed as we'll
# be processing up to `to_stream_id`, without necessarily fetching a message
# with a stream id of `to_stream_id`.
if total_messages_processed < limit:
stream_pos = to_stream_id
return recipient_user_id_device_id_to_messages, stream_pos
return await self.db_pool.runInteraction(
"get_new_messages", get_new_messages_txn
)
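A brief usage sketch of the method above; the stream id bounds and user IDs are hypothetical:

# Fetch to-device messages with stream ids in (579, 580] for two AS-namespace users.
messages_by_recipient, stream_pos = await store.get_new_messages(
    ["@as_user1:example.org", "@as_user2:example.org"],
    from_stream_id=579,
    to_stream_id=580,
    limit=100,
)
# stream_pos != 580 would mean the limit was hit and more messages remain.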
async def get_new_messages_for_device(
self,
user_id: str,

View file

@@ -0,0 +1,18 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a column to track what to_device stream id that this application
-- service has been caught up to.
ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT;

View file

@@ -38,7 +38,6 @@ from zope.interface import Interface
from twisted.internet.interfaces import (
IReactorCore,
IReactorPluggableNameResolver,
IReactorSSL,
IReactorTCP,
IReactorThreads,
IReactorTime,
@@ -67,7 +66,6 @@ JsonDict = Dict[str, Any]
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
IReactorTCP,
IReactorSSL,
IReactorPluggableNameResolver,
IReactorTime,
IReactorCore,

View file

@@ -31,13 +31,13 @@ from typing import (
Set,
TypeVar,
Union,
cast,
)
import attr
from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.base import ReactorBase
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
@@ -271,7 +271,8 @@ class Linearizer:
if not clock:
from twisted.internet import reactor
clock = Clock(cast(IReactorTime, reactor))
assert isinstance(reactor, ReactorBase)
clock = Clock(reactor)
self._clock = clock
self.max_count = max_count

View file

@@ -92,9 +92,9 @@ def _resource_id(resource: Resource, path_seg: bytes) -> str:
the mapping should looks like _resource_id(A,C) = B.
Args:
resource: The *parent* Resourceb
path_seg: The name of the child Resource to be attached.
resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached.
Returns:
A unique string which can be a key to the child Resource.
str: A unique string which can be a key to the child Resource.
"""
return "%s-%r" % (resource, path_seg)

View file

@@ -23,7 +23,7 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter
from twisted.conch.ssh.keys import Key
from twisted.cred import checkers, portal
from twisted.internet import defer
from twisted.internet.protocol import ServerFactory
from twisted.internet.protocol import Factory
from synapse.config.server import ManholeConfig
@@ -65,7 +65,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
-----END RSA PRIVATE KEY-----"""
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
"""Starts a ssh listener with password authentication using
the given username and password. Clients connecting to the ssh
listener will find themselves in a colored python shell with
@@ -105,8 +105,7 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment]
factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment]
# ConchFactory is a Factory, not a ServerFactory, but they are identical.
return factory # type: ignore[return-value]
return factory
class SynapseManhole(ColoredManhole):

View file

@@ -253,20 +253,59 @@ class AppServiceHandlerTestCase(unittest.TestCase):
},
)
def test_notify_interested_services_ephemeral(self):
def test_notify_interested_services_ephemeral_read_receipt(self):
"""
Test sending ephemeral events to the appservice handler are scheduled
Test sending read receipts to the appservice handler are scheduled
to be pushed out to interested appservices, and that the stream ID is
updated accordingly.
"""
# Create an application service that is guaranteed to be interested in
# any new events
interested_service = self._mkservice(is_interested=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
# State that this application service has received up until stream ID 579
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
579
)
# Set up a dummy event that should be sent to the application service
event = Mock(event_id="event_1")
self.event_source.sources.receipt.get_new_events_as.return_value = (
make_awaitable(([event], None))
)
self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
)
self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
interested_service, [event]
)
self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with(
interested_service,
"read_receipt",
580,
)
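Reading the mocks above together, the handler-side flow being asserted is roughly the following sketch; the method and attribute names on `self` are hypothetical, and this is not the actual implementation:

# A rough per-service sketch of the flow the mocks above assert.
async def _notify_service_of_receipts(self, service, new_token):
    # Where this appservice has been caught up to on the receipt stream.
    from_key = await self.store.get_type_stream_id_for_appservice(
        service, "read_receipt"
    )
    # Pull receipt events between from_key and new_token for this service.
    events, _ = await self.event_sources.sources.receipt.get_new_events_as(
        from_key=from_key, service=service
    )
    if events:
        self.scheduler.submit_ephemeral_events_for_as(service, events)
    # Advance the recorded position so these events aren't re-sent.
    await self.store.set_type_stream_id_for_appservice(
        service, "read_receipt", new_token
    )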
def test_notify_interested_services_ephemeral_to_device(self):
"""
Test that to-device messages sent to the appservice handler are scheduled
to be pushed out to interested appservices, and that the stream ID is
updated accordingly.
"""
# Create an application service that is guaranteed to be interested in
# any new events
interested_service = self._mkservice(is_interested=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
# State that this application service has received up until stream ID 579
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
579
)
# Set up a dummy event that should be sent to the application service
event = Mock(event_id="event_1")
self.event_source.sources.receipt.get_new_events_as.return_value = (
make_awaitable(([event], None))

View file

@ -0,0 +1,258 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, Optional, Tuple, Union
from unittest.mock import Mock
import synapse.rest.admin
import synapse.storage
from synapse.appservice import ApplicationService
from synapse.rest.client import login, receipts, room
from synapse.util.stringutils import random_string
from synapse.types import JsonDict
from tests import unittest
class ApplicationServiceEphemeralEventsTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
room.register_servlets,
receipts.register_servlets,
]
def prepare(self, reactor, clock, hs):
# Mock the application service scheduler so that we can track any outgoing transactions
self.mock_scheduler = Mock()
self.mock_scheduler.submit_ephemeral_events_for_as = Mock()
hs.get_application_service_handler().scheduler = self.mock_scheduler
self.user1 = self.register_user("user1", "password")
self.token1 = self.login("user1", "password")
self.user2 = self.register_user("user2", "password")
self.token2 = self.login("user2", "password")
def test_application_services_receive_read_receipts(self):
"""
Test that when a user sends a read receipt in a room with another
user who is in an application service's user namespace, that
application service will receive the read receipt.
"""
(
interested_service,
_,
) = self._register_interested_and_uninterested_application_services()
# Create a room with both user1 and user2
room_id = self.helper.create_room_as(
self.user1, tok=self.token1, is_public=True
)
self.helper.join(room_id, self.user2, tok=self.token2)
# Have user2 send a message into the room
response_dict = self.helper.send(room_id, body="read me", tok=self.token2)
# Have user1 send a read receipt for the message with an empty body
self._send_read_receipt(room_id, response_dict["event_id"], self.token1)
# user2 should have been the recipient of that read receipt.
# Check if our application service - that is interested in user2 - received
# the read receipt as part of an AS transaction.
#
# The uninterested application service should not have been notified.
last_call = self.mock_scheduler.submit_ephemeral_events_for_as.call_args_list[
0
]
service, events = last_call[0]
self.assertEqual(service, interested_service)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]["type"], "m.receipt")
self.assertEqual(events[0]["room_id"], room_id)
# Assert that this was a read receipt from user1
read_receipts = list(events[0]["content"].values())
self.assertIn(self.user1, read_receipts[0]["m.read"])
# Clear mock stats
self.mock_scheduler.submit_ephemeral_events_for_as.reset_mock()
# Send 2 pairs of messages + read receipts
response_dict_1 = self.helper.send(room_id, body="read me1", tok=self.token2)
response_dict_2 = self.helper.send(room_id, body="read me2", tok=self.token2)
self._send_read_receipt(room_id, response_dict_1["event_id"], self.token1)
self._send_read_receipt(room_id, response_dict_2["event_id"], self.token1)
# Assert each transaction that was sent to the application service is as expected
self.assertEqual(2, self.mock_scheduler.submit_ephemeral_events_for_as.call_count)
first_call, second_call = self.mock_scheduler.submit_ephemeral_events_for_as.call_args_list
service, events = first_call[0]
self.assertEqual(service, interested_service)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]["type"], "m.receipt")
self.assertEqual(
self._event_id_from_read_receipt(events[0]), response_dict_1["event_id"]
)
service, events = second_call[0]
self.assertEqual(service, interested_service)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]["type"], "m.receipt")
self.assertEqual(
self._event_id_from_read_receipt(events[0]), response_dict_2["event_id"]
)
def test_application_services_receive_to_device(self):
"""
Test that when a user sends a to-device message to another user
who is in an application service's user namespace, that application
service will receive it.
"""
(
interested_service,
_,
) = self._register_interested_and_uninterested_application_services()
# Create a room with both user1 and user2
room_id = self.helper.create_room_as(
self.user1, tok=self.token1, is_public=True
)
self.helper.join(room_id, self.user2, tok=self.token2)
# Have user2 send a message into the room
response_dict = self.helper.send(room_id, body="read me", tok=self.token2)
# Have user1 send a read receipt for the message
channel = self.make_request(
"POST",
"/rooms/%s/receipt/m.read/%s" % (room_id, response_dict["event_id"]),
access_token=self.token1,
)
self.assertEqual(channel.code, 200)
# user2 should have been the recipient of that read receipt.
# Check if our application service - that is interested in user2 - received
# the read receipt as part of an AS transaction.
#
# The uninterested application service should not have been notified.
service, events = self.mock_scheduler.submit_ephemeral_events_for_as.call_args[
0
]
self.assertEqual(service, interested_service)
self.assertEqual(events[0]["type"], "m.receipt")
self.assertEqual(events[0]["room_id"], room_id)
# Assert that this was a read receipt from user1
read_receipts = list(events[0]["content"].values())
self.assertIn(self.user1, read_receipts[0]["m.read"])
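Note that the body above still exercises read receipts (a copy of the previous test); a to-device variant would presumably drive the standard Matrix sendToDevice endpoint instead, roughly as below, assuming the sendtodevice servlet were also registered in `servlets`:

# Hypothetical sketch of sending an actual to-device message in this
# test. The event type and message content are illustrative; the
# sendToDevice endpoint itself is standard Matrix client API.
chan = self.make_request(
    "PUT",
    "/_matrix/client/r0/sendToDevice/m.room_key_request/1",
    content={"messages": {self.user2: {"*": {"dummy": "content"}}}},
    access_token=self.token1,
)
self.assertEqual(chan.code, 200, chan.json_body)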
def _register_interested_and_uninterested_application_services(
self,
) -> Tuple[ApplicationService, ApplicationService]:
# Create an application service with exclusive interest in user2
interested_service = self._make_application_service(
namespaces={
ApplicationService.NS_USERS: [
{
"regex": "@user2:.+",
"exclusive": True,
}
],
},
)
uninterested_service = self._make_application_service()
# Register this application service, along with another, uninterested one
services = [
uninterested_service,
interested_service,
]
self.hs.get_datastore().get_app_services = Mock(return_value=services)
return interested_service, uninterested_service
def _make_application_service(
self,
namespaces: Optional[
Dict[
Union[
ApplicationService.NS_USERS,
ApplicationService.NS_ALIASES,
ApplicationService.NS_ROOMS,
],
Iterable[Dict],
]
] = None,
supports_ephemeral: Optional[bool] = True,
) -> ApplicationService:
return ApplicationService(
token=None,
hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
namespaces=namespaces,
supports_ephemeral=supports_ephemeral,
)
def _send_read_receipt(self, room_id: str, event_id_to_read: str, tok: str) -> None:
"""
Send a read receipt of an event into a room.
Args:
room_id: The room the event is part of.
event_id_to_read: The ID of the event being read.
tok: The access token of the sender.
"""
channel = self.make_request(
"POST",
"/rooms/%s/receipt/m.read/%s" % (room_id, event_id_to_read),
access_token=tok,
content="{}",
)
self.assertEqual(channel.code, 200, channel.json_body)
def _event_id_from_read_receipt(self, read_receipt_dict: JsonDict) -> str:
"""
Extracts the first event ID from a read receipt. Read receipt dictionaries
are in the form:
{
'type': 'm.receipt',
'room_id': '!PEzCqHyycBVxqMKIjI:test',
'content': {
'$DETIeTEH651c1N7sP_j-YZiaQqCaayHhYwmhZDVWDY8': { # We want this
'm.read': {
'@user1:test': {
'ts': 1300,
'hidden': False
}
}
}
}
}
Args:
read_receipt_dict: The dictionary returned from a POST read receipt call.
Returns:
The (first) event ID the read receipt refers to.
"""
return list(read_receipt_dict["content"].keys())[0]
# TODO: Test that ephemeral messages aren't sent to application services that have
# ephemeral: false
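A possible shape for that TODO, reusing this file's helpers; the test name and exact assertion are assumptions, not the final test:

# Hypothetical sketch for the TODO above; reuses helpers from this file.
def test_ephemeral_not_sent_when_ephemeral_unsupported(self):
    # An appservice interested in user2, but with ephemeral support disabled.
    service = self._make_application_service(
        namespaces={
            ApplicationService.NS_USERS: [
                {"regex": "@user2:.+", "exclusive": True},
            ],
        },
        supports_ephemeral=False,
    )
    self.hs.get_datastore().get_app_services = Mock(return_value=[service])

    room_id = self.helper.create_room_as(self.user1, tok=self.token1, is_public=True)
    self.helper.join(room_id, self.user2, tok=self.token2)
    res = self.helper.send(room_id, body="read me", tok=self.token2)
    self._send_read_receipt(room_id, res["event_id"], self.token1)

    # No ephemeral events should have been scheduled for this service.
    self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()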

View file

@ -230,7 +230,7 @@ class RestHelper:
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
):
) -> JsonDict:
if body is None:
body = "body_text_here"
@ -257,7 +257,7 @@ class RestHelper:
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
):
) -> JsonDict:
if txn_id is None:
txn_id = "m%s" % (str(time.time()))

View file

@ -320,7 +320,7 @@ class HomeserverTestCase(TestCase):
def wait_for_background_updates(self) -> None:
"""Block until all background database updates have completed."""
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
):
self.get_success(
self.store.db_pool.updates.do_next_background_update(100), by=0.1

View file

@ -236,7 +236,7 @@ def setup_test_homeserver(
else:
database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
"args": {"database": "test.db", "cp_min": 1, "cp_max": 1},
}
if "db_txn_limit" in kwargs: