Compare commits
7 commits
develop
...
anoa/e2e_a
Author | SHA1 | Date | |
---|---|---|---|
|
a7bf3f6f95 | ||
|
42f0ab1d84 | ||
|
6bcd8dee42 | ||
|
43bdfbbc85 | ||
|
59b5eeefe1 | ||
|
52c43d91f5 | ||
|
0772fc855b |
1
changelog.d/11215.feature
Normal file
1
changelog.d/11215.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). Disabled by default.
|
|
@ -1 +0,0 @@
|
|||
Add missing type hints to `synapse.app`.
|
|
@ -1 +0,0 @@
|
|||
Fix a long-standing bug where uploading extremely thin images (e.g. 1000x1) would fail. Contributed by @Neeeflix.
|
|
@ -1 +0,0 @@
|
|||
Add type hints to `synapse._scripts`.
|
|
@ -1 +0,0 @@
|
|||
Add Single Sign-On, SAML and CAS pages to the documentation.
|
|
@ -23,10 +23,10 @@
|
|||
- [Structured Logging](structured_logging.md)
|
||||
- [Templates](templates.md)
|
||||
- [User Authentication](usage/configuration/user_authentication/README.md)
|
||||
- [Single-Sign On](usage/configuration/user_authentication/single_sign_on/README.md)
|
||||
- [Single-Sign On]()
|
||||
- [OpenID Connect](openid.md)
|
||||
- [SAML](usage/configuration/user_authentication/single_sign_on/saml.md)
|
||||
- [CAS](usage/configuration/user_authentication/single_sign_on/cas.md)
|
||||
- [SAML]()
|
||||
- [CAS]()
|
||||
- [SSO Mapping Providers](sso_mapping_providers.md)
|
||||
- [Password Auth Providers](password_auth_providers.md)
|
||||
- [JSON Web Tokens](jwt.md)
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
# Single Sign-On
|
||||
|
||||
Synapse supports single sign-on through the SAML, Open ID Connect or CAS protocols.
|
||||
LDAP and other login methods are supported through first and third-party password
|
||||
auth provider modules.
|
|
@ -1,8 +0,0 @@
|
|||
# CAS
|
||||
|
||||
Synapse supports authenticating users via the [Central Authentication
|
||||
Service protocol](https://en.wikipedia.org/wiki/Central_Authentication_Service)
|
||||
(CAS) natively.
|
||||
|
||||
Please see the `cas_config` and `sso` sections of the [Synapse configuration
|
||||
file](../../../configuration/homeserver_sample_config.md) for more details.
|
|
@ -1,8 +0,0 @@
|
|||
# SAML
|
||||
|
||||
Synapse supports authenticating users via the [Security Assertion
|
||||
Markup Language](https://en.wikipedia.org/wiki/Security_Assertion_Markup_Language)
|
||||
(SAML) protocol natively.
|
||||
|
||||
Please see the `saml2_config` and `sso` sections of the [Synapse configuration
|
||||
file](../../../configuration/homeserver_sample_config.md) for more details.
|
21
mypy.ini
21
mypy.ini
|
@ -23,6 +23,24 @@ files =
|
|||
# https://docs.python.org/3/library/re.html#re.X
|
||||
exclude = (?x)
|
||||
^(
|
||||
|synapse/_scripts/register_new_matrix_user.py
|
||||
|synapse/_scripts/review_recent_signups.py
|
||||
|synapse/app/__init__.py
|
||||
|synapse/app/_base.py
|
||||
|synapse/app/admin_cmd.py
|
||||
|synapse/app/appservice.py
|
||||
|synapse/app/client_reader.py
|
||||
|synapse/app/event_creator.py
|
||||
|synapse/app/federation_reader.py
|
||||
|synapse/app/federation_sender.py
|
||||
|synapse/app/frontend_proxy.py
|
||||
|synapse/app/generic_worker.py
|
||||
|synapse/app/homeserver.py
|
||||
|synapse/app/media_repository.py
|
||||
|synapse/app/phone_stats_home.py
|
||||
|synapse/app/pusher.py
|
||||
|synapse/app/synchrotron.py
|
||||
|synapse/app/user_dir.py
|
||||
|synapse/storage/databases/__init__.py
|
||||
|synapse/storage/databases/main/__init__.py
|
||||
|synapse/storage/databases/main/account_data.py
|
||||
|
@ -163,9 +181,6 @@ exclude = (?x)
|
|||
[mypy-synapse.api.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.app.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.crypto.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
|
1
setup.py
1
setup.py
|
@ -110,7 +110,6 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
|
|||
"types-Pillow>=8.3.4",
|
||||
"types-pyOpenSSL>=20.0.7",
|
||||
"types-PyYAML>=5.4.10",
|
||||
"types-requests>=2.26.0",
|
||||
"types-setuptools>=57.4.0",
|
||||
]
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -20,23 +19,22 @@ import hashlib
|
|||
import hmac
|
||||
import logging
|
||||
import sys
|
||||
from typing import Callable, Optional
|
||||
|
||||
import requests as _requests
|
||||
import yaml
|
||||
|
||||
|
||||
def request_registration(
|
||||
user: str,
|
||||
password: str,
|
||||
server_location: str,
|
||||
shared_secret: str,
|
||||
admin: bool = False,
|
||||
user_type: Optional[str] = None,
|
||||
user,
|
||||
password,
|
||||
server_location,
|
||||
shared_secret,
|
||||
admin=False,
|
||||
user_type=None,
|
||||
requests=_requests,
|
||||
_print: Callable[[str], None] = print,
|
||||
exit: Callable[[int], None] = sys.exit,
|
||||
) -> None:
|
||||
_print=print,
|
||||
exit=sys.exit,
|
||||
):
|
||||
|
||||
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
|
||||
|
||||
|
@ -67,13 +65,13 @@ def request_registration(
|
|||
mac.update(b"\x00")
|
||||
mac.update(user_type.encode("utf8"))
|
||||
|
||||
hex_mac = mac.hexdigest()
|
||||
mac = mac.hexdigest()
|
||||
|
||||
data = {
|
||||
"nonce": nonce,
|
||||
"username": user,
|
||||
"password": password,
|
||||
"mac": hex_mac,
|
||||
"mac": mac,
|
||||
"admin": admin,
|
||||
"user_type": user_type,
|
||||
}
|
||||
|
@ -93,17 +91,10 @@ def request_registration(
|
|||
_print("Success!")
|
||||
|
||||
|
||||
def register_new_user(
|
||||
user: str,
|
||||
password: str,
|
||||
server_location: str,
|
||||
shared_secret: str,
|
||||
admin: Optional[bool],
|
||||
user_type: Optional[str],
|
||||
) -> None:
|
||||
def register_new_user(user, password, server_location, shared_secret, admin, user_type):
|
||||
if not user:
|
||||
try:
|
||||
default_user: Optional[str] = getpass.getuser()
|
||||
default_user = getpass.getuser()
|
||||
except Exception:
|
||||
default_user = None
|
||||
|
||||
|
@ -132,8 +123,8 @@ def register_new_user(
|
|||
sys.exit(1)
|
||||
|
||||
if admin is None:
|
||||
admin_inp = input("Make admin [no]: ")
|
||||
if admin_inp in ("y", "yes", "true"):
|
||||
admin = input("Make admin [no]: ")
|
||||
if admin in ("y", "yes", "true"):
|
||||
admin = True
|
||||
else:
|
||||
admin = False
|
||||
|
@ -143,7 +134,7 @@ def register_new_user(
|
|||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def main():
|
||||
|
||||
logging.captureWarnings(True)
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]:
|
|||
return user_infos
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
|
@ -142,8 +142,7 @@ def main() -> None:
|
|||
engine = create_engine(database_config.config)
|
||||
|
||||
with make_conn(database_config, engine, "review_recent_signups") as db_conn:
|
||||
# This generates a type of Cursor, not LoggingTransaction.
|
||||
user_infos = get_recent_users(db_conn.cursor(), since_ms) # type: ignore[arg-type]
|
||||
user_infos = get_recent_users(db_conn.cursor(), since_ms)
|
||||
|
||||
for user_info in user_infos:
|
||||
if exclude_users_with_email and user_info.emails:
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import sys
|
||||
from typing import Container
|
||||
|
||||
from synapse import python_dependencies # noqa: E402
|
||||
|
||||
|
@ -28,9 +27,7 @@ except python_dependencies.DependencyException as e:
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
def check_bind_error(
|
||||
e: Exception, address: str, bind_addresses: Container[str]
|
||||
) -> None:
|
||||
def check_bind_error(e, address, bind_addresses):
|
||||
"""
|
||||
This method checks an exception occurred while binding on 0.0.0.0.
|
||||
If :: is specified in the bind addresses a warning is shown.
|
||||
|
@ -41,9 +38,9 @@ def check_bind_error(
|
|||
When binding on 0.0.0.0 after :: this can safely be ignored.
|
||||
|
||||
Args:
|
||||
e: Exception that was caught.
|
||||
address: Address on which binding was attempted.
|
||||
bind_addresses: Addresses on which the service listens.
|
||||
e (Exception): Exception that was caught.
|
||||
address (str): Address on which binding was attempted.
|
||||
bind_addresses (list): Addresses on which the service listens.
|
||||
"""
|
||||
if address == "0.0.0.0" and "::" in bind_addresses:
|
||||
logger.warning(
|
||||
|
|
|
@ -22,27 +22,13 @@ import socket
|
|||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NoReturn,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Iterable
|
||||
|
||||
from cryptography.utils import CryptographyDeprecationWarning
|
||||
from typing_extensions import NoReturn
|
||||
|
||||
import twisted
|
||||
from twisted.internet import defer, error, reactor as _reactor
|
||||
from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorSSL, IReactorTCP
|
||||
from twisted.internet.protocol import ServerFactory
|
||||
from twisted.internet.tcp import Port
|
||||
from twisted.internet import defer, error, reactor
|
||||
from twisted.logger import LoggingFile, LogLevel
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
|
@ -62,7 +48,6 @@ from synapse.logging.context import PreserveLoggingContext
|
|||
from synapse.metrics import register_threadpool
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.metrics.jemalloc import setup_jemalloc_stats
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
|
||||
from synapse.util.daemonize import daemonize_process
|
||||
from synapse.util.gai_resolver import GAIResolver
|
||||
|
@ -72,44 +57,33 @@ from synapse.util.versionstring import get_version_string
|
|||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
# Twisted injects the global reactor to make it easier to import, this confuses
|
||||
# mypy which thinks it is a module. Tell it that it a more proper type.
|
||||
reactor = cast(ISynapseReactor, _reactor)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# list of tuples of function, args list, kwargs dict
|
||||
_sighup_callbacks: List[
|
||||
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
|
||||
] = []
|
||||
_sighup_callbacks = []
|
||||
|
||||
|
||||
def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
|
||||
def register_sighup(func, *args, **kwargs):
|
||||
"""
|
||||
Register a function to be called when a SIGHUP occurs.
|
||||
|
||||
Args:
|
||||
func: Function to be called when sent a SIGHUP signal.
|
||||
func (function): Function to be called when sent a SIGHUP signal.
|
||||
*args, **kwargs: args and kwargs to be passed to the target function.
|
||||
"""
|
||||
_sighup_callbacks.append((func, args, kwargs))
|
||||
|
||||
|
||||
def start_worker_reactor(
|
||||
appname: str,
|
||||
config: HomeServerConfig,
|
||||
run_command: Callable[[], None] = reactor.run,
|
||||
) -> None:
|
||||
def start_worker_reactor(appname, config, run_command=reactor.run):
|
||||
"""Run the reactor in the main process
|
||||
|
||||
Daemonizes if necessary, and then configures some resources, before starting
|
||||
the reactor. Pulls configuration from the 'worker' settings in 'config'.
|
||||
|
||||
Args:
|
||||
appname: application name which will be sent to syslog
|
||||
config: config object
|
||||
run_command: callable that actually runs the reactor
|
||||
appname (str): application name which will be sent to syslog
|
||||
config (synapse.config.Config): config object
|
||||
run_command (Callable[]): callable that actually runs the reactor
|
||||
"""
|
||||
|
||||
logger = logging.getLogger(config.worker.worker_app)
|
||||
|
@ -127,32 +101,32 @@ def start_worker_reactor(
|
|||
|
||||
|
||||
def start_reactor(
|
||||
appname: str,
|
||||
soft_file_limit: int,
|
||||
gc_thresholds: Tuple[int, int, int],
|
||||
pid_file: str,
|
||||
daemonize: bool,
|
||||
print_pidfile: bool,
|
||||
logger: logging.Logger,
|
||||
run_command: Callable[[], None] = reactor.run,
|
||||
) -> None:
|
||||
appname,
|
||||
soft_file_limit,
|
||||
gc_thresholds,
|
||||
pid_file,
|
||||
daemonize,
|
||||
print_pidfile,
|
||||
logger,
|
||||
run_command=reactor.run,
|
||||
):
|
||||
"""Run the reactor in the main process
|
||||
|
||||
Daemonizes if necessary, and then configures some resources, before starting
|
||||
the reactor
|
||||
|
||||
Args:
|
||||
appname: application name which will be sent to syslog
|
||||
soft_file_limit:
|
||||
appname (str): application name which will be sent to syslog
|
||||
soft_file_limit (int):
|
||||
gc_thresholds:
|
||||
pid_file: name of pid file to write to if daemonize is True
|
||||
daemonize: true to run the reactor in a background process
|
||||
print_pidfile: whether to print the pid file, if daemonize is True
|
||||
logger: logger instance to pass to Daemonize
|
||||
run_command: callable that actually runs the reactor
|
||||
pid_file (str): name of pid file to write to if daemonize is True
|
||||
daemonize (bool): true to run the reactor in a background process
|
||||
print_pidfile (bool): whether to print the pid file, if daemonize is True
|
||||
logger (logging.Logger): logger instance to pass to Daemonize
|
||||
run_command (Callable[]): callable that actually runs the reactor
|
||||
"""
|
||||
|
||||
def run() -> None:
|
||||
def run():
|
||||
logger.info("Running")
|
||||
setup_jemalloc_stats()
|
||||
change_resource_limit(soft_file_limit)
|
||||
|
@ -211,7 +185,7 @@ def redirect_stdio_to_logs() -> None:
|
|||
print("Redirected stdout/stderr to logs")
|
||||
|
||||
|
||||
def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None:
|
||||
def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
|
||||
"""Register a callback with the reactor, to be called once it is running
|
||||
|
||||
This can be used to initialise parts of the system which require an asynchronous
|
||||
|
@ -221,7 +195,7 @@ def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> N
|
|||
will exit.
|
||||
"""
|
||||
|
||||
async def wrapper() -> None:
|
||||
async def wrapper():
|
||||
try:
|
||||
await cb(*args, **kwargs)
|
||||
except Exception:
|
||||
|
@ -250,7 +224,7 @@ def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> N
|
|||
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
|
||||
|
||||
|
||||
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
|
||||
def listen_metrics(bind_addresses, port):
|
||||
"""
|
||||
Start Prometheus metrics server.
|
||||
"""
|
||||
|
@ -262,11 +236,11 @@ def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
|
|||
|
||||
|
||||
def listen_manhole(
|
||||
bind_addresses: Collection[str],
|
||||
bind_addresses: Iterable[str],
|
||||
port: int,
|
||||
manhole_settings: ManholeConfig,
|
||||
manhole_globals: dict,
|
||||
) -> None:
|
||||
):
|
||||
# twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
|
||||
# warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so
|
||||
# suppress the warning for now.
|
||||
|
@ -285,18 +259,12 @@ def listen_manhole(
|
|||
)
|
||||
|
||||
|
||||
def listen_tcp(
|
||||
bind_addresses: Collection[str],
|
||||
port: int,
|
||||
factory: ServerFactory,
|
||||
reactor: IReactorTCP = reactor,
|
||||
backlog: int = 50,
|
||||
) -> List[Port]:
|
||||
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
|
||||
"""
|
||||
Create a TCP socket for a port and several addresses
|
||||
|
||||
Returns:
|
||||
list of twisted.internet.tcp.Port listening for TCP connections
|
||||
list[twisted.internet.tcp.Port]: listening for TCP connections
|
||||
"""
|
||||
r = []
|
||||
for address in bind_addresses:
|
||||
|
@ -305,19 +273,12 @@ def listen_tcp(
|
|||
except error.CannotListenError as e:
|
||||
check_bind_error(e, address, bind_addresses)
|
||||
|
||||
# IReactorTCP returns an object implementing IListeningPort from listenTCP,
|
||||
# but we know it will be a Port instance.
|
||||
return r # type: ignore[return-value]
|
||||
return r
|
||||
|
||||
|
||||
def listen_ssl(
|
||||
bind_addresses: Collection[str],
|
||||
port: int,
|
||||
factory: ServerFactory,
|
||||
context_factory: IOpenSSLContextFactory,
|
||||
reactor: IReactorSSL = reactor,
|
||||
backlog: int = 50,
|
||||
) -> List[Port]:
|
||||
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
|
||||
):
|
||||
"""
|
||||
Create an TLS-over-TCP socket for a port and several addresses
|
||||
|
||||
|
@ -333,13 +294,10 @@ def listen_ssl(
|
|||
except error.CannotListenError as e:
|
||||
check_bind_error(e, address, bind_addresses)
|
||||
|
||||
# IReactorSSL incorrectly declares that an int is returned from listenSSL,
|
||||
# it actually returns an object implementing IListeningPort, but we know it
|
||||
# will be a Port instance.
|
||||
return r # type: ignore[return-value]
|
||||
return r
|
||||
|
||||
|
||||
def refresh_certificate(hs: "HomeServer") -> None:
|
||||
def refresh_certificate(hs: "HomeServer"):
|
||||
"""
|
||||
Refresh the TLS certificates that Synapse is using by re-reading them from
|
||||
disk and updating the TLS context factories to use them.
|
||||
|
@ -371,7 +329,7 @@ def refresh_certificate(hs: "HomeServer") -> None:
|
|||
logger.info("Context factories updated.")
|
||||
|
||||
|
||||
async def start(hs: "HomeServer") -> None:
|
||||
async def start(hs: "HomeServer"):
|
||||
"""
|
||||
Start a Synapse server or worker.
|
||||
|
||||
|
@ -402,7 +360,7 @@ async def start(hs: "HomeServer") -> None:
|
|||
if hasattr(signal, "SIGHUP"):
|
||||
|
||||
@wrap_as_background_process("sighup")
|
||||
def handle_sighup(*args: Any, **kwargs: Any) -> None:
|
||||
def handle_sighup(*args, **kwargs):
|
||||
# Tell systemd our state, if we're using it. This will silently fail if
|
||||
# we're not using systemd.
|
||||
sdnotify(b"RELOADING=1")
|
||||
|
@ -415,7 +373,7 @@ async def start(hs: "HomeServer") -> None:
|
|||
# We defer running the sighup handlers until next reactor tick. This
|
||||
# is so that we're in a sane state, e.g. flushing the logs may fail
|
||||
# if the sighup happens in the middle of writing a log entry.
|
||||
def run_sighup(*args: Any, **kwargs: Any) -> None:
|
||||
def run_sighup(*args, **kwargs):
|
||||
# `callFromThread` should be "signal safe" as well as thread
|
||||
# safe.
|
||||
reactor.callFromThread(handle_sighup, *args, **kwargs)
|
||||
|
@ -478,8 +436,12 @@ async def start(hs: "HomeServer") -> None:
|
|||
atexit.register(gc.freeze)
|
||||
|
||||
|
||||
def setup_sentry(hs: "HomeServer") -> None:
|
||||
"""Enable sentry integration, if enabled in configuration"""
|
||||
def setup_sentry(hs: "HomeServer"):
|
||||
"""Enable sentry integration, if enabled in configuration
|
||||
|
||||
Args:
|
||||
hs
|
||||
"""
|
||||
|
||||
if not hs.config.metrics.sentry_enabled:
|
||||
return
|
||||
|
@ -504,7 +466,7 @@ def setup_sentry(hs: "HomeServer") -> None:
|
|||
scope.set_tag("worker_name", name)
|
||||
|
||||
|
||||
def setup_sdnotify(hs: "HomeServer") -> None:
|
||||
def setup_sdnotify(hs: "HomeServer"):
|
||||
"""Adds process state hooks to tell systemd what we are up to."""
|
||||
|
||||
# Tell systemd our state, if we're using it. This will silently fail if
|
||||
|
@ -519,7 +481,7 @@ def setup_sdnotify(hs: "HomeServer") -> None:
|
|||
sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET")
|
||||
|
||||
|
||||
def sdnotify(state: bytes) -> None:
|
||||
def sdnotify(state):
|
||||
"""
|
||||
Send a notification to systemd, if the NOTIFY_SOCKET env var is set.
|
||||
|
||||
|
@ -528,7 +490,7 @@ def sdnotify(state: bytes) -> None:
|
|||
package which many OSes don't include as a matter of principle.
|
||||
|
||||
Args:
|
||||
state: notification to send
|
||||
state (bytes): notification to send
|
||||
"""
|
||||
if not isinstance(state, bytes):
|
||||
raise TypeError("sdnotify should be called with a bytes")
|
||||
|
|
|
@ -17,7 +17,6 @@ import logging
|
|||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import List, Optional
|
||||
|
||||
from twisted.internet import defer, task
|
||||
|
||||
|
@ -26,7 +25,6 @@ from synapse.app import _base
|
|||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.events import EventBase
|
||||
from synapse.handlers.admin import ExfiltrationWriter
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||
|
@ -42,7 +40,6 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
|||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main.room import RoomWorkerStore
|
||||
from synapse.types import StateMap
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
|
@ -68,11 +65,16 @@ class AdminCmdSlavedStore(
|
|||
|
||||
|
||||
class AdminCmdServer(HomeServer):
|
||||
DATASTORE_CLASS = AdminCmdSlavedStore # type: ignore
|
||||
DATASTORE_CLASS = AdminCmdSlavedStore
|
||||
|
||||
|
||||
async def export_data_command(hs: HomeServer, args: argparse.Namespace) -> None:
|
||||
"""Export data for a user."""
|
||||
async def export_data_command(hs: HomeServer, args):
|
||||
"""Export data for a user.
|
||||
|
||||
Args:
|
||||
hs
|
||||
args (argparse.Namespace)
|
||||
"""
|
||||
|
||||
user_id = args.user_id
|
||||
directory = args.output_directory
|
||||
|
@ -90,12 +92,12 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
|||
Note: This writes to disk on the main reactor thread.
|
||||
|
||||
Args:
|
||||
user_id: The user whose data is being exfiltrated.
|
||||
directory: The directory to write the data to, if None then will write
|
||||
to a temporary directory.
|
||||
user_id (str): The user whose data is being exfiltrated.
|
||||
directory (str|None): The directory to write the data to, if None then
|
||||
will write to a temporary directory.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str, directory: Optional[str] = None):
|
||||
def __init__(self, user_id, directory=None):
|
||||
self.user_id = user_id
|
||||
|
||||
if directory:
|
||||
|
@ -109,7 +111,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
|||
if list(os.listdir(self.base_directory)):
|
||||
raise Exception("Directory must be empty")
|
||||
|
||||
def write_events(self, room_id: str, events: List[EventBase]) -> None:
|
||||
def write_events(self, room_id, events):
|
||||
room_directory = os.path.join(self.base_directory, "rooms", room_id)
|
||||
os.makedirs(room_directory, exist_ok=True)
|
||||
events_file = os.path.join(room_directory, "events")
|
||||
|
@ -118,9 +120,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
|||
for event in events:
|
||||
print(json.dumps(event.get_pdu_json()), file=f)
|
||||
|
||||
def write_state(
|
||||
self, room_id: str, event_id: str, state: StateMap[EventBase]
|
||||
) -> None:
|
||||
def write_state(self, room_id, event_id, state):
|
||||
room_directory = os.path.join(self.base_directory, "rooms", room_id)
|
||||
state_directory = os.path.join(room_directory, "state")
|
||||
os.makedirs(state_directory, exist_ok=True)
|
||||
|
@ -131,9 +131,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
|||
for event in state.values():
|
||||
print(json.dumps(event.get_pdu_json()), file=f)
|
||||
|
||||
def write_invite(
|
||||
self, room_id: str, event: EventBase, state: StateMap[EventBase]
|
||||
) -> None:
|
||||
def write_invite(self, room_id, event, state):
|
||||
self.write_events(room_id, [event])
|
||||
|
||||
# We write the invite state somewhere else as they aren't full events
|
||||
|
@ -147,9 +145,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
|||
for event in state.values():
|
||||
print(json.dumps(event), file=f)
|
||||
|
||||
def write_knock(
|
||||
self, room_id: str, event: EventBase, state: StateMap[EventBase]
|
||||
) -> None:
|
||||
def write_knock(self, room_id, event, state):
|
||||
self.write_events(room_id, [event])
|
||||
|
||||
# We write the knock state somewhere else as they aren't full events
|
||||
|
@ -163,11 +159,11 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
|||
for event in state.values():
|
||||
print(json.dumps(event), file=f)
|
||||
|
||||
def finished(self) -> str:
|
||||
def finished(self):
|
||||
return self.base_directory
|
||||
|
||||
|
||||
def start(config_options: List[str]) -> None:
|
||||
def start(config_options):
|
||||
parser = argparse.ArgumentParser(description="Synapse Admin Command")
|
||||
HomeServerConfig.add_arguments_to_parser(parser)
|
||||
|
||||
|
@ -235,7 +231,7 @@ def start(config_options: List[str]) -> None:
|
|||
# We also make sure that `_base.start` gets run before we actually run the
|
||||
# command.
|
||||
|
||||
async def run() -> None:
|
||||
async def run():
|
||||
with LoggingContext("command"):
|
||||
await _base.start(ss)
|
||||
await args.func(ss, args)
|
||||
|
|
|
@ -14,10 +14,11 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, Optional
|
||||
|
||||
from twisted.internet import address
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.resource import IResource
|
||||
from twisted.web.server import Request
|
||||
|
||||
import synapse
|
||||
import synapse.events
|
||||
|
@ -43,7 +44,7 @@ from synapse.config.server import ListenerConfig
|
|||
from synapse.federation.transport.server import TransportLayerServer
|
||||
from synapse.http.server import JsonResource, OptionsResource
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest, SynapseSite
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.logging.context import LoggingContext
|
||||
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
|
||||
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
|
||||
|
@ -118,7 +119,6 @@ from synapse.storage.databases.main.stats import StatsStore
|
|||
from synapse.storage.databases.main.transactions import TransactionWorkerStore
|
||||
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
|
||||
from synapse.storage.databases.main.user_directory import UserDirectoryStore
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
|
@ -143,9 +143,7 @@ class KeyUploadServlet(RestServlet):
|
|||
self.http_client = hs.get_simple_http_client()
|
||||
self.main_uri = hs.config.worker.worker_main_http_uri
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, device_id: Optional[str]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(self, request: Request, device_id: Optional[str]):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
|
@ -189,8 +187,9 @@ class KeyUploadServlet(RestServlet):
|
|||
# If the header exists, add to the comma-separated list of the first
|
||||
# instance of the header. Otherwise, generate a new header.
|
||||
if x_forwarded_for:
|
||||
x_forwarded_for = [x_forwarded_for[0] + b", " + previous_host]
|
||||
x_forwarded_for.extend(x_forwarded_for[1:])
|
||||
x_forwarded_for = [
|
||||
x_forwarded_for[0] + b", " + previous_host
|
||||
] + x_forwarded_for[1:]
|
||||
else:
|
||||
x_forwarded_for = [previous_host]
|
||||
headers[b"X-Forwarded-For"] = x_forwarded_for
|
||||
|
@ -254,16 +253,13 @@ class GenericWorkerSlavedStore(
|
|||
SessionStore,
|
||||
BaseSlavedStore,
|
||||
):
|
||||
# Properties that multiple storage classes define. Tell mypy what the
|
||||
# expected type is.
|
||||
server_name: str
|
||||
config: HomeServerConfig
|
||||
pass
|
||||
|
||||
|
||||
class GenericWorkerServer(HomeServer):
|
||||
DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore
|
||||
DATASTORE_CLASS = GenericWorkerSlavedStore
|
||||
|
||||
def _listen_http(self, listener_config: ListenerConfig) -> None:
|
||||
def _listen_http(self, listener_config: ListenerConfig):
|
||||
port = listener_config.port
|
||||
bind_addresses = listener_config.bind_addresses
|
||||
|
||||
|
@ -271,10 +267,10 @@ class GenericWorkerServer(HomeServer):
|
|||
|
||||
site_tag = listener_config.http_options.tag
|
||||
if site_tag is None:
|
||||
site_tag = str(port)
|
||||
site_tag = port
|
||||
|
||||
# We always include a health resource.
|
||||
resources: Dict[str, Resource] = {"/health": HealthResource()}
|
||||
resources: Dict[str, IResource] = {"/health": HealthResource()}
|
||||
|
||||
for res in listener_config.http_options.resources:
|
||||
for name in res.names:
|
||||
|
@ -390,7 +386,7 @@ class GenericWorkerServer(HomeServer):
|
|||
|
||||
logger.info("Synapse worker now listening on port %d", port)
|
||||
|
||||
def start_listening(self) -> None:
|
||||
def start_listening(self):
|
||||
for listener in self.config.worker.worker_listeners:
|
||||
if listener.type == "http":
|
||||
self._listen_http(listener)
|
||||
|
@ -415,7 +411,7 @@ class GenericWorkerServer(HomeServer):
|
|||
self.get_tcp_replication().start_replication(self)
|
||||
|
||||
|
||||
def start(config_options: List[str]) -> None:
|
||||
def start(config_options):
|
||||
try:
|
||||
config = HomeServerConfig.load_config("Synapse worker", config_options)
|
||||
except ConfigError as e:
|
||||
|
|
|
@ -16,10 +16,10 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Iterable, Iterator, List
|
||||
from typing import Iterator
|
||||
|
||||
from twisted.internet.tcp import Port
|
||||
from twisted.web.resource import EncodingResourceWrapper, Resource
|
||||
from twisted.internet import reactor
|
||||
from twisted.web.resource import EncodingResourceWrapper, IResource
|
||||
from twisted.web.server import GzipEncoderFactory
|
||||
from twisted.web.static import File
|
||||
|
||||
|
@ -76,27 +76,23 @@ from synapse.util.versionstring import get_version_string
|
|||
logger = logging.getLogger("synapse.app.homeserver")
|
||||
|
||||
|
||||
def gz_wrap(r: Resource) -> Resource:
|
||||
def gz_wrap(r):
|
||||
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
||||
|
||||
|
||||
class SynapseHomeServer(HomeServer):
|
||||
DATASTORE_CLASS = DataStore # type: ignore
|
||||
DATASTORE_CLASS = DataStore
|
||||
|
||||
def _listener_http(
|
||||
self, config: HomeServerConfig, listener_config: ListenerConfig
|
||||
) -> Iterable[Port]:
|
||||
def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig):
|
||||
port = listener_config.port
|
||||
bind_addresses = listener_config.bind_addresses
|
||||
tls = listener_config.tls
|
||||
# Must exist since this is an HTTP listener.
|
||||
assert listener_config.http_options is not None
|
||||
site_tag = listener_config.http_options.tag
|
||||
if site_tag is None:
|
||||
site_tag = str(port)
|
||||
|
||||
# We always include a health resource.
|
||||
resources: Dict[str, Resource] = {"/health": HealthResource()}
|
||||
resources = {"/health": HealthResource()}
|
||||
|
||||
for res in listener_config.http_options.resources:
|
||||
for name in res.names:
|
||||
|
@ -115,7 +111,7 @@ class SynapseHomeServer(HomeServer):
|
|||
("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
|
||||
)
|
||||
handler = handler_cls(config, module_api)
|
||||
if isinstance(handler, Resource):
|
||||
if IResource.providedBy(handler):
|
||||
resource = handler
|
||||
elif hasattr(handler, "handle_request"):
|
||||
resource = AdditionalResource(self, handler.handle_request)
|
||||
|
@ -132,7 +128,7 @@ class SynapseHomeServer(HomeServer):
|
|||
|
||||
# try to find something useful to redirect '/' to
|
||||
if WEB_CLIENT_PREFIX in resources:
|
||||
root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
|
||||
root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
|
||||
elif STATIC_PREFIX in resources:
|
||||
root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
|
||||
else:
|
||||
|
@ -149,8 +145,6 @@ class SynapseHomeServer(HomeServer):
|
|||
)
|
||||
|
||||
if tls:
|
||||
# refresh_certificate should have been called before this.
|
||||
assert self.tls_server_context_factory is not None
|
||||
ports = listen_ssl(
|
||||
bind_addresses,
|
||||
port,
|
||||
|
@ -171,21 +165,20 @@ class SynapseHomeServer(HomeServer):
|
|||
|
||||
return ports
|
||||
|
||||
def _configure_named_resource(
|
||||
self, name: str, compress: bool = False
|
||||
) -> Dict[str, Resource]:
|
||||
def _configure_named_resource(self, name, compress=False):
|
||||
"""Build a resource map for a named resource
|
||||
|
||||
Args:
|
||||
name: named resource: one of "client", "federation", etc
|
||||
compress: whether to enable gzip compression for this resource
|
||||
name (str): named resource: one of "client", "federation", etc
|
||||
compress (bool): whether to enable gzip compression for this
|
||||
resource
|
||||
|
||||
Returns:
|
||||
map from path to HTTP resource
|
||||
dict[str, Resource]: map from path to HTTP resource
|
||||
"""
|
||||
resources: Dict[str, Resource] = {}
|
||||
resources = {}
|
||||
if name == "client":
|
||||
client_resource: Resource = ClientRestResource(self)
|
||||
client_resource = ClientRestResource(self)
|
||||
if compress:
|
||||
client_resource = gz_wrap(client_resource)
|
||||
|
||||
|
@ -214,7 +207,7 @@ class SynapseHomeServer(HomeServer):
|
|||
if name == "consent":
|
||||
from synapse.rest.consent.consent_resource import ConsentResource
|
||||
|
||||
consent_resource: Resource = ConsentResource(self)
|
||||
consent_resource = ConsentResource(self)
|
||||
if compress:
|
||||
consent_resource = gz_wrap(consent_resource)
|
||||
resources.update({"/_matrix/consent": consent_resource})
|
||||
|
@ -284,7 +277,7 @@ class SynapseHomeServer(HomeServer):
|
|||
|
||||
return resources
|
||||
|
||||
def start_listening(self) -> None:
|
||||
def start_listening(self):
|
||||
if self.config.redis.redis_enabled:
|
||||
# If redis is enabled we connect via the replication command handler
|
||||
# in the same way as the workers (since we're effectively a client
|
||||
|
@ -310,9 +303,7 @@ class SynapseHomeServer(HomeServer):
|
|||
ReplicationStreamProtocolFactory(self),
|
||||
)
|
||||
for s in services:
|
||||
self.get_reactor().addSystemEventTrigger(
|
||||
"before", "shutdown", s.stopListening
|
||||
)
|
||||
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
|
||||
elif listener.type == "metrics":
|
||||
if not self.config.metrics.enable_metrics:
|
||||
logger.warning(
|
||||
|
@ -327,13 +318,14 @@ class SynapseHomeServer(HomeServer):
|
|||
logger.warning("Unrecognized listener type: %s", listener.type)
|
||||
|
||||
|
||||
def setup(config_options: List[str]) -> SynapseHomeServer:
|
||||
def setup(config_options):
|
||||
"""
|
||||
Args:
|
||||
config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`.
|
||||
config_options_options: The options passed to Synapse. Usually
|
||||
`sys.argv[1:]`.
|
||||
|
||||
Returns:
|
||||
A homeserver instance.
|
||||
HomeServer
|
||||
"""
|
||||
try:
|
||||
config = HomeServerConfig.load_or_generate_config(
|
||||
|
@ -372,7 +364,7 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
|
|||
except Exception as e:
|
||||
handle_startup_exception(e)
|
||||
|
||||
async def start() -> None:
|
||||
async def start():
|
||||
# Load the OIDC provider metadatas, if OIDC is enabled.
|
||||
if hs.config.oidc.oidc_enabled:
|
||||
oidc = hs.get_oidc_handler()
|
||||
|
@ -412,15 +404,39 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
|
|||
|
||||
yield ":\n %s" % (e.msg,)
|
||||
|
||||
parent_e = e.__cause__
|
||||
e = e.__cause__
|
||||
indent = 1
|
||||
while parent_e:
|
||||
while e:
|
||||
indent += 1
|
||||
yield ":\n%s%s" % (" " * indent, str(parent_e))
|
||||
parent_e = parent_e.__cause__
|
||||
yield ":\n%s%s" % (" " * indent, str(e))
|
||||
e = e.__cause__
|
||||
|
||||
|
||||
def run(hs: HomeServer) -> None:
|
||||
def run(hs: HomeServer):
|
||||
PROFILE_SYNAPSE = False
|
||||
if PROFILE_SYNAPSE:
|
||||
|
||||
def profile(func):
|
||||
from cProfile import Profile
|
||||
from threading import current_thread
|
||||
|
||||
def profiled(*args, **kargs):
|
||||
profile = Profile()
|
||||
profile.enable()
|
||||
func(*args, **kargs)
|
||||
profile.disable()
|
||||
ident = current_thread().ident
|
||||
profile.dump_stats(
|
||||
"/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident)
|
||||
)
|
||||
|
||||
return profiled
|
||||
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
|
||||
ThreadPool._worker = profile(ThreadPool._worker)
|
||||
reactor.run = profile(reactor.run)
|
||||
|
||||
_base.start_reactor(
|
||||
"synapse-homeserver",
|
||||
soft_file_limit=hs.config.server.soft_file_limit,
|
||||
|
@ -432,7 +448,7 @@ def run(hs: HomeServer) -> None:
|
|||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def main():
|
||||
with LoggingContext("main"):
|
||||
# check base requirements
|
||||
check_requirements()
|
||||
|
|
|
@ -15,12 +15,11 @@ import logging
|
|||
import math
|
||||
import resource
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, List, Sized, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -29,7 +28,7 @@ logger = logging.getLogger("synapse.app.homeserver")
|
|||
|
||||
# Contains the list of processes we will be monitoring
|
||||
# currently either 0 or 1
|
||||
_stats_process: List[Tuple[int, "resource.struct_rusage"]] = []
|
||||
_stats_process = []
|
||||
|
||||
# Gauges to expose monthly active user control metrics
|
||||
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
|
||||
|
@ -46,15 +45,9 @@ registered_reserved_users_mau_gauge = Gauge(
|
|||
|
||||
|
||||
@wrap_as_background_process("phone_stats_home")
|
||||
async def phone_stats_home(
|
||||
hs: "HomeServer",
|
||||
stats: JsonDict,
|
||||
stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process,
|
||||
) -> None:
|
||||
async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process):
|
||||
logger.info("Gathering stats for reporting")
|
||||
now = int(hs.get_clock().time())
|
||||
# Ensure the homeserver has started.
|
||||
assert hs.start_time is not None
|
||||
uptime = int(now - hs.start_time)
|
||||
if uptime < 0:
|
||||
uptime = 0
|
||||
|
@ -153,15 +146,15 @@ async def phone_stats_home(
|
|||
logger.warning("Error reporting stats: %s", e)
|
||||
|
||||
|
||||
def start_phone_stats_home(hs: "HomeServer") -> None:
|
||||
def start_phone_stats_home(hs: "HomeServer"):
|
||||
"""
|
||||
Start the background tasks which report phone home stats.
|
||||
"""
|
||||
clock = hs.get_clock()
|
||||
|
||||
stats: JsonDict = {}
|
||||
stats = {}
|
||||
|
||||
def performance_stats_init() -> None:
|
||||
def performance_stats_init():
|
||||
_stats_process.clear()
|
||||
_stats_process.append(
|
||||
(int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
|
||||
|
@ -177,10 +170,10 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
|
|||
hs.get_datastore().reap_monthly_active_users()
|
||||
|
||||
@wrap_as_background_process("generate_monthly_active_users")
|
||||
async def generate_monthly_active_users() -> None:
|
||||
async def generate_monthly_active_users():
|
||||
current_mau_count = 0
|
||||
current_mau_count_by_service = {}
|
||||
reserved_users: Sized = ()
|
||||
reserved_users = ()
|
||||
store = hs.get_datastore()
|
||||
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
|
||||
current_mau_count = await store.get_monthly_active_count()
|
||||
|
|
|
@ -46,3 +46,11 @@ class ExperimentalConfig(Config):
|
|||
|
||||
# MSC3266 (room summary api)
|
||||
self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
|
||||
|
||||
# MSC2409 (this setting only relates to optionally sending to-device messages).
|
||||
# Presence, typing and read receipt EDUs are already sent to application services that
|
||||
# have opted in to receive them. This setting, if enabled, adds to-device messages
|
||||
# to that list.
|
||||
self.msc2409_to_device_messages_enabled: bool = experimental.get(
|
||||
"msc2409_to_device_messages_enabled", False
|
||||
)
|
||||
|
|
|
@ -234,7 +234,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
|
|||
|
||||
@abc.abstractmethod
|
||||
def write_invite(
|
||||
self, room_id: str, event: EventBase, state: StateMap[EventBase]
|
||||
self, room_id: str, event: EventBase, state: StateMap[dict]
|
||||
) -> None:
|
||||
"""Write an invite for the room, with associated invite state.
|
||||
|
||||
|
@ -248,7 +248,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
|
|||
|
||||
@abc.abstractmethod
|
||||
def write_knock(
|
||||
self, room_id: str, event: EventBase, state: StateMap[EventBase]
|
||||
self, room_id: str, event: EventBase, state: StateMap[dict]
|
||||
) -> None:
|
||||
"""Write a knock for the room, with associated knock state.
|
||||
|
||||
|
|
|
@ -12,7 +12,16 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
|
@ -42,6 +51,8 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_TO_DEVICE_MESSAGES_PER_AS_TRANSACTION = 100
|
||||
|
||||
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
|
||||
|
||||
|
||||
|
@ -55,6 +66,9 @@ class ApplicationServicesHandler:
|
|||
self.clock = hs.get_clock()
|
||||
self.notify_appservices = hs.config.appservice.notify_appservices
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.msc2409_to_device_messages_enabled = (
|
||||
hs.config.experimental.msc2409_to_device_messages_enabled
|
||||
)
|
||||
|
||||
self.current_max = 0
|
||||
self.is_processing = False
|
||||
|
@ -199,8 +213,9 @@ class ApplicationServicesHandler:
|
|||
Args:
|
||||
stream_key: The stream the event came from.
|
||||
|
||||
`stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other
|
||||
value for `stream_key` will cause this function to return early.
|
||||
`stream_key` can be "typing_key", "receipt_key", "presence_key" or
|
||||
"to_device_key". Any other value for `stream_key` will cause this function
|
||||
to return early.
|
||||
|
||||
Ephemeral events will only be pushed to appservices that have opted into
|
||||
receiving them by setting `push_ephemeral` to true in their registration
|
||||
|
@ -216,10 +231,6 @@ class ApplicationServicesHandler:
|
|||
if not self.notify_appservices:
|
||||
return
|
||||
|
||||
# Ignore any unsupported streams
|
||||
if stream_key not in ("typing_key", "receipt_key", "presence_key"):
|
||||
return
|
||||
|
||||
# Assert that new_token is an integer (and not a RoomStreamToken).
|
||||
# All of the supported streams that this function handles use an
|
||||
# integer to track progress (rather than a RoomStreamToken - a
|
||||
|
@ -233,6 +244,13 @@ class ApplicationServicesHandler:
|
|||
# Additional context: https://github.com/matrix-org/synapse/pull/11137
|
||||
assert isinstance(new_token, int)
|
||||
|
||||
# Ignore to-device messages if the feature flag is not enabled
|
||||
if (
|
||||
stream_key == "to_device_key"
|
||||
and not self.msc2409_to_device_messages_enabled
|
||||
):
|
||||
return
|
||||
|
||||
# Check whether there are any appservices which have registered to receive
|
||||
# ephemeral events.
|
||||
#
|
||||
|
@ -307,6 +325,30 @@ class ApplicationServicesHandler:
|
|||
service, "presence", new_token
|
||||
)
|
||||
|
||||
elif (
|
||||
stream_key == "to_device_key"
|
||||
and new_token is not None
|
||||
and self.msc2409_to_device_messages_enabled
|
||||
):
|
||||
# Retrieve an iterable of to-device message events, as well as the
|
||||
# maximum stream token of the messages we were able to retrieve.
|
||||
events, max_stream_token = await self._handle_to_device(
|
||||
service, new_token, users
|
||||
)
|
||||
if events:
|
||||
self.scheduler.submit_ephemeral_events_for_as(
|
||||
service, events
|
||||
)
|
||||
|
||||
# TODO: If max_stream_token != new_token, schedule another transaction immediately,
|
||||
# instead of waiting for another to-device to be sent?
|
||||
# https://github.com/matrix-org/synapse/issues/11150#issuecomment-960726449
|
||||
|
||||
# Persist the latest handled stream token for this appservice
|
||||
await self.store.set_type_stream_id_for_appservice(
|
||||
service, "to_device", max_stream_token
|
||||
)
|
||||
|
||||
async def _handle_typing(
|
||||
self, service: ApplicationService, new_token: int
|
||||
) -> List[JsonDict]:
|
||||
|
@ -440,6 +482,88 @@ class ApplicationServicesHandler:
|
|||
|
||||
return events
|
||||
|
||||
async def _handle_to_device(
|
||||
self,
|
||||
service: ApplicationService,
|
||||
new_token: int,
|
||||
users: Collection[Union[str, UserID]],
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
"""
|
||||
Given an application service, determine which events it should receive
|
||||
from those between the last-recorded typing event stream token for this
|
||||
appservice and the given stream token.
|
||||
|
||||
Args:
|
||||
service: The application service to check for which events it should receive.
|
||||
new_token: The latest to-device event stream token.
|
||||
users: The users that should receive new to-device messages.
|
||||
|
||||
Returns:
|
||||
A two-tuple containing the following:
|
||||
* A list of JSON dictionaries containing data derived from the typing events
|
||||
that should be sent to the given application service.
|
||||
* The last-processed to-device message's stream id. If this does not equal
|
||||
`new_token` then there are likely more events to send.
|
||||
"""
|
||||
# Get the stream token that this application service has processed up until
|
||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||
service, "to_device"
|
||||
)
|
||||
|
||||
# Filter out users that this appservice is not interested in
|
||||
users_appservice_is_interested_in: List[str] = []
|
||||
for user in users:
|
||||
if isinstance(user, UserID):
|
||||
user = user.to_string()
|
||||
|
||||
if service.is_interested_in_user(user):
|
||||
users_appservice_is_interested_in.append(user)
|
||||
|
||||
if not users_appservice_is_interested_in:
|
||||
# Return early if the AS was not interested in any of these users
|
||||
return [], new_token
|
||||
|
||||
# Retrieve the to-device messages for each user
|
||||
(
|
||||
recipient_user_id_device_id_to_messages,
|
||||
# Record the maximum stream token of the retrieved messages that
|
||||
# this function was able to pull before hitting the max to-device
|
||||
# message count limit.
|
||||
# We return this value later, to ensure we only record the stream
|
||||
# token we managed to get up to, so that the rest can be sent later.
|
||||
max_stream_token,
|
||||
) = await self.store.get_new_messages(
|
||||
users_appservice_is_interested_in,
|
||||
from_key,
|
||||
new_token,
|
||||
limit=MAX_TO_DEVICE_MESSAGES_PER_AS_TRANSACTION,
|
||||
)
|
||||
|
||||
# According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
|
||||
# to the event JSON so that the application service will know which user/device
|
||||
# combination this messages was intended for.
|
||||
#
|
||||
# So we mangle this dict into a flat list of to-device messages with the relevant
|
||||
# user ID and device ID embedded inside each message dict.
|
||||
message_payload: List[JsonDict] = []
|
||||
for (
|
||||
user_id,
|
||||
device_id,
|
||||
), messages in recipient_user_id_device_id_to_messages.items():
|
||||
for message_json in messages:
|
||||
# Remove 'message_id' from the to-device message, as it's an internal ID
|
||||
message_json.pop("message_id", None)
|
||||
|
||||
message_payload.append(
|
||||
{
|
||||
"to_user_id": user_id,
|
||||
"to_device_id": device_id,
|
||||
**message_json,
|
||||
}
|
||||
)
|
||||
|
||||
return message_payload, max_stream_token
|
||||
|
||||
async def query_user_exists(self, user_id: str) -> bool:
|
||||
"""Check if any application service knows this user_id exists.
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import time
|
|||
from logging import Handler, LogRecord
|
||||
from logging.handlers import MemoryHandler
|
||||
from threading import Thread
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
from twisted.internet.interfaces import IReactorCore
|
||||
|
||||
|
@ -56,7 +56,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
|
|||
if reactor is None:
|
||||
from twisted.internet import reactor as global_reactor
|
||||
|
||||
reactor_to_use = cast(IReactorCore, global_reactor)
|
||||
reactor_to_use = global_reactor # type: ignore[assignment]
|
||||
else:
|
||||
reactor_to_use = reactor
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ import attr
|
|||
import jinja2
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.resource import IResource
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase
|
||||
|
@ -196,7 +196,7 @@ class ModuleApi:
|
|||
"""
|
||||
return self._password_auth_provider.register_password_auth_provider_callbacks
|
||||
|
||||
def register_web_resource(self, path: str, resource: Resource):
|
||||
def register_web_resource(self, path: str, resource: IResource):
|
||||
"""Registers a web resource to be served at the given path.
|
||||
|
||||
This function should be called during initialisation of the module.
|
||||
|
|
|
@ -444,15 +444,23 @@ class Notifier:
|
|||
|
||||
self.notify_replication()
|
||||
|
||||
# Notify appservices.
|
||||
try:
|
||||
self.appservice_handler.notify_interested_services_ephemeral(
|
||||
stream_key,
|
||||
new_token,
|
||||
users,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error notifying application services of event")
|
||||
# Notify appservices of updates in ephemeral event streams.
|
||||
# Only the following streams are currently supported.
|
||||
if stream_key in (
|
||||
"typing_key",
|
||||
"receipt_key",
|
||||
"presence_key",
|
||||
"to_device_key",
|
||||
):
|
||||
# Notify appservices.
|
||||
try:
|
||||
self.appservice_handler.notify_interested_services_ephemeral(
|
||||
stream_key,
|
||||
new_token,
|
||||
users,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error notifying application services of event")
|
||||
|
||||
def on_new_replication_data(self) -> None:
|
||||
"""Used to inform replication listeners that something has happened
|
||||
|
|
|
@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet.protocol import ServerFactory
|
||||
from twisted.internet.protocol import Factory
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.commands import PositionCommand
|
||||
|
@ -38,7 +38,7 @@ stream_updates_counter = Counter(
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicationStreamProtocolFactory(ServerFactory):
|
||||
class ReplicationStreamProtocolFactory(Factory):
|
||||
"""Factory for new replication connections."""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
|
|
@ -101,8 +101,8 @@ class Thumbnailer:
|
|||
fits within the given rectangle::
|
||||
|
||||
(w_in / h_in) = (w_out / h_out)
|
||||
w_out = max(min(w_max, h_max * (w_in / h_in)), 1)
|
||||
h_out = max(min(h_max, w_max * (h_in / w_in)), 1)
|
||||
w_out = min(w_max, h_max * (w_in / h_in))
|
||||
h_out = min(h_max, w_max * (h_in / w_in))
|
||||
|
||||
Args:
|
||||
max_width: The largest possible width.
|
||||
|
@ -110,9 +110,9 @@ class Thumbnailer:
|
|||
"""
|
||||
|
||||
if max_width * self.height < max_height * self.width:
|
||||
return max_width, max((max_width * self.height) // self.width, 1)
|
||||
return max_width, (max_width * self.height) // self.width
|
||||
else:
|
||||
return max((max_height * self.width) // self.height, 1), max_height
|
||||
return (max_height * self.width) // self.height, max_height
|
||||
|
||||
def _resize(self, width: int, height: int) -> Image.Image:
|
||||
# 1-bit or 8-bit color palette images need converting to RGB
|
||||
|
|
|
@ -33,10 +33,9 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
from twisted.internet.interfaces import IOpenSSLContextFactory
|
||||
from twisted.internet.tcp import Port
|
||||
import twisted.internet.tcp
|
||||
from twisted.web.iweb import IPolicyForHTTPS
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.resource import IResource
|
||||
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.filtering import Filtering
|
||||
|
@ -207,7 +206,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
|
||||
Attributes:
|
||||
config (synapse.config.homeserver.HomeserverConfig):
|
||||
_listening_services (list[Port]): TCP ports that
|
||||
_listening_services (list[twisted.internet.tcp.Port]): TCP ports that
|
||||
we are listening on to provide HTTP services.
|
||||
"""
|
||||
|
||||
|
@ -226,8 +225,6 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
# instantiated during setup() for future return by get_datastore()
|
||||
DATASTORE_CLASS = abc.abstractproperty()
|
||||
|
||||
tls_server_context_factory: Optional[IOpenSSLContextFactory]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hostname: str,
|
||||
|
@ -250,7 +247,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
# the key we use to sign events and requests
|
||||
self.signing_key = config.key.signing_key[0]
|
||||
self.config = config
|
||||
self._listening_services: List[Port] = []
|
||||
self._listening_services: List[twisted.internet.tcp.Port] = []
|
||||
self.start_time: Optional[int] = None
|
||||
|
||||
self._instance_id = random_string(5)
|
||||
|
@ -260,10 +257,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
|
||||
self.datastores: Optional[Databases] = None
|
||||
|
||||
self._module_web_resources: Dict[str, Resource] = {}
|
||||
self._module_web_resources: Dict[str, IResource] = {}
|
||||
self._module_web_resources_consumed = False
|
||||
|
||||
def register_module_web_resource(self, path: str, resource: Resource):
|
||||
def register_module_web_resource(self, path: str, resource: IResource):
|
||||
"""Allows a module to register a web resource to be served at the given path.
|
||||
|
||||
If multiple modules register a resource for the same path, the module that
|
||||
|
|
|
@ -387,7 +387,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
async def get_type_stream_id_for_appservice(
|
||||
self, service: ApplicationService, type: str
|
||||
) -> int:
|
||||
if type not in ("read_receipt", "presence"):
|
||||
if type not in ("read_receipt", "presence", "to_device"):
|
||||
raise ValueError(
|
||||
"Expected type to be a valid application stream id type, got %s"
|
||||
% (type,)
|
||||
|
@ -414,7 +414,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
async def set_type_stream_id_for_appservice(
|
||||
self, service: ApplicationService, stream_type: str, pos: Optional[int]
|
||||
) -> None:
|
||||
if stream_type not in ("read_receipt", "presence"):
|
||||
if stream_type not in ("read_receipt", "presence", "to_device"):
|
||||
raise ValueError(
|
||||
"Expected type to be a valid application stream id type, got %s"
|
||||
% (stream_type,)
|
||||
|
|
|
@ -13,13 +13,17 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.logging import issue9533_logger
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.replication.tcp.streams import ToDeviceStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
|
@ -116,6 +120,97 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
def get_to_device_stream_token(self):
|
||||
return self._device_inbox_id_gen.get_current_token()
|
||||
|
||||
async def get_new_messages(
|
||||
self,
|
||||
user_ids: Collection[str],
|
||||
from_stream_id: int,
|
||||
to_stream_id: int,
|
||||
limit: int = 100,
|
||||
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
|
||||
"""
|
||||
Retrieve to-device messages for a given set of user IDs.
|
||||
|
||||
Only to-device messages with stream ids between the given boundaries
|
||||
(from < X <= to) are returned.
|
||||
|
||||
Note that multiple messages can have the same stream id. Stream ids are
|
||||
unique to *messages*, but there might be multiple recipients of a message, and
|
||||
thus multiple entries in the device_inbox table with the same stream id.
|
||||
|
||||
Args:
|
||||
user_ids: The users to retrieve to-device messages for.
|
||||
from_stream_id: The lower boundary of stream id to filter with (exclusive).
|
||||
to_stream_id: The upper boundary of stream id to filter with (inclusive).
|
||||
limit: The maximum number of to-device messages to return.
|
||||
|
||||
Returns:
|
||||
A tuple containing the following:
|
||||
* A list of to-device messages
|
||||
* The stream id that to-device messages were processed up to. The calling
|
||||
function can check this value against `to_stream_id` to see if we hit
|
||||
the given limit when fetching messages.
|
||||
"""
|
||||
# Bail out if none of these users have any messages
|
||||
for user_id in user_ids:
|
||||
if self._device_inbox_stream_cache.has_entity_changed(
|
||||
user_id, from_stream_id
|
||||
):
|
||||
break
|
||||
else:
|
||||
return {}, to_stream_id
|
||||
|
||||
def get_new_messages_txn(txn: LoggingTransaction):
|
||||
# Build a query to select messages from any of the given users that are between
|
||||
# the given stream id bounds
|
||||
|
||||
# Scope to only the given users. We need to use this method as doing so is
|
||||
# different across database engines.
|
||||
many_clause_sql, many_clause_args = make_in_list_sql_clause(
|
||||
self.database_engine, "user_id", user_ids
|
||||
)
|
||||
|
||||
sql = f"""
|
||||
SELECT stream_id, user_id, device_id, message_json FROM device_inbox
|
||||
WHERE {many_clause_sql}
|
||||
AND ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id, limit))
|
||||
|
||||
# Create a dictionary of (user ID, device ID) -> list of messages that
|
||||
# that device is meant to receive.
|
||||
recipient_user_id_device_id_to_messages = {}
|
||||
|
||||
stream_pos = to_stream_id
|
||||
total_messages_processed = 0
|
||||
for row in txn:
|
||||
# Record the last-processed stream position, to return later.
|
||||
# Note that we process messages here in order of ascending stream id.
|
||||
stream_pos = row[0]
|
||||
recipient_user_id = row[1]
|
||||
recipient_device_id = row[2]
|
||||
message_dict = db_to_json(row[3])
|
||||
|
||||
recipient_user_id_device_id_to_messages.setdefault(
|
||||
(recipient_user_id, recipient_device_id), []
|
||||
).append(message_dict)
|
||||
total_messages_processed += 1
|
||||
|
||||
# If we don't end up hitting the limit, we still want to return the equivalent
|
||||
# value of `to_stream_id` to the calling function. This is needed as we'll
|
||||
# be processing up to `to_stream_id`, without necessarily fetching a message
|
||||
# with a stream id of `to_stream_id`.
|
||||
if total_messages_processed < limit:
|
||||
stream_pos = to_stream_id
|
||||
|
||||
return recipient_user_id_device_id_to_messages, stream_pos
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_new_messages", get_new_messages_txn
|
||||
)
|
||||
|
||||
async def get_new_messages_for_device(
|
||||
self,
|
||||
user_id: str,
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Add a column to track what to_device stream id that this application
|
||||
-- service has been caught up to.
|
||||
ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT;
|
|
@ -38,7 +38,6 @@ from zope.interface import Interface
|
|||
from twisted.internet.interfaces import (
|
||||
IReactorCore,
|
||||
IReactorPluggableNameResolver,
|
||||
IReactorSSL,
|
||||
IReactorTCP,
|
||||
IReactorThreads,
|
||||
IReactorTime,
|
||||
|
@ -67,7 +66,6 @@ JsonDict = Dict[str, Any]
|
|||
# for mypy-zope to realize it is an interface.
|
||||
class ISynapseReactor(
|
||||
IReactorTCP,
|
||||
IReactorSSL,
|
||||
IReactorPluggableNameResolver,
|
||||
IReactorTime,
|
||||
IReactorCore,
|
||||
|
|
|
@ -31,13 +31,13 @@ from typing import (
|
|||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
from typing_extensions import ContextManager
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.base import ReactorBase
|
||||
from twisted.internet.defer import CancelledError
|
||||
from twisted.internet.interfaces import IReactorTime
|
||||
from twisted.python import failure
|
||||
|
@ -271,7 +271,8 @@ class Linearizer:
|
|||
if not clock:
|
||||
from twisted.internet import reactor
|
||||
|
||||
clock = Clock(cast(IReactorTime, reactor))
|
||||
assert isinstance(reactor, ReactorBase)
|
||||
clock = Clock(reactor)
|
||||
self._clock = clock
|
||||
self.max_count = max_count
|
||||
|
||||
|
|
|
@ -92,9 +92,9 @@ def _resource_id(resource: Resource, path_seg: bytes) -> str:
|
|||
the mapping should looks like _resource_id(A,C) = B.
|
||||
|
||||
Args:
|
||||
resource: The *parent* Resourceb
|
||||
path_seg: The name of the child Resource to be attached.
|
||||
resource (Resource): The *parent* Resourceb
|
||||
path_seg (str): The name of the child Resource to be attached.
|
||||
Returns:
|
||||
A unique string which can be a key to the child Resource.
|
||||
str: A unique string which can be a key to the child Resource.
|
||||
"""
|
||||
return "%s-%r" % (resource, path_seg)
|
||||
|
|
|
@ -23,7 +23,7 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter
|
|||
from twisted.conch.ssh.keys import Key
|
||||
from twisted.cred import checkers, portal
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import ServerFactory
|
||||
from twisted.internet.protocol import Factory
|
||||
|
||||
from synapse.config.server import ManholeConfig
|
||||
|
||||
|
@ -65,7 +65,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
|
|||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
|
||||
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
|
||||
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
|
||||
"""Starts a ssh listener with password authentication using
|
||||
the given username and password. Clients connecting to the ssh
|
||||
listener will find themselves in a colored python shell with
|
||||
|
@ -105,8 +105,7 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
|
|||
factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment]
|
||||
factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment]
|
||||
|
||||
# ConchFactory is a Factory, not a ServerFactory, but they are identical.
|
||||
return factory # type: ignore[return-value]
|
||||
return factory
|
||||
|
||||
|
||||
class SynapseManhole(ColoredManhole):
|
||||
|
|
|
@ -253,20 +253,59 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_notify_interested_services_ephemeral(self):
|
||||
def test_notify_interested_services_ephemeral_read_receipt(self):
|
||||
"""
|
||||
Test sending ephemeral events to the appservice handler are scheduled
|
||||
Test sending read receipts to the appservice handler are scheduled
|
||||
to be pushed out to interested appservices, and that the stream ID is
|
||||
updated accordingly.
|
||||
"""
|
||||
# Create an application service that is guaranteed to be interested in
|
||||
# any new events
|
||||
interested_service = self._mkservice(is_interested=True)
|
||||
services = [interested_service]
|
||||
|
||||
self.mock_store.get_app_services.return_value = services
|
||||
|
||||
# State that this application service has received up until stream ID 579
|
||||
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
|
||||
579
|
||||
)
|
||||
|
||||
# Set up a dummy event that should be sent to the application service
|
||||
event = Mock(event_id="event_1")
|
||||
self.event_source.sources.receipt.get_new_events_as.return_value = (
|
||||
make_awaitable(([event], None))
|
||||
)
|
||||
|
||||
self.handler.notify_interested_services_ephemeral(
|
||||
"receipt_key", 580, ["@fakerecipient:example.com"]
|
||||
)
|
||||
self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
|
||||
interested_service, [event]
|
||||
)
|
||||
self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with(
|
||||
interested_service,
|
||||
"read_receipt",
|
||||
580,
|
||||
)
|
||||
|
||||
def test_notify_interested_services_ephemeral_to_device(self):
|
||||
"""
|
||||
Test sending read receipts to the appservice handler are scheduled
|
||||
to be pushed out to interested appservices, and that the stream ID is
|
||||
updated accordingly.
|
||||
"""
|
||||
# Create an application service that is guaranteed to be interested in
|
||||
# any new events
|
||||
interested_service = self._mkservice(is_interested=True)
|
||||
services = [interested_service]
|
||||
self.mock_store.get_app_services.return_value = services
|
||||
|
||||
# State that this application service has received up until stream ID 579
|
||||
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
|
||||
579
|
||||
)
|
||||
|
||||
# Set up a dummy event that should be sent to the application service
|
||||
event = Mock(event_id="event_1")
|
||||
self.event_source.sources.receipt.get_new_events_as.return_value = (
|
||||
make_awaitable(([event], None))
|
||||
|
|
258
tests/handlers/test_appservice_ephemeral.py
Normal file
258
tests/handlers/test_appservice_ephemeral.py
Normal file
|
@ -0,0 +1,258 @@
|
|||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Iterable, Optional, Tuple, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
import synapse.rest.admin
|
||||
import synapse.storage
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.rest.client import login, receipts, room
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class ApplicationServiceEphemeralEventsTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
receipts.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# Mock the application service scheduler so that we can track any outgoing transactions
|
||||
self.mock_scheduler = Mock()
|
||||
self.mock_scheduler.submit_ephemeral_events_for_as = Mock()
|
||||
|
||||
hs.get_application_service_handler().scheduler = self.mock_scheduler
|
||||
|
||||
self.user1 = self.register_user("user1", "password")
|
||||
self.token1 = self.login("user1", "password")
|
||||
|
||||
self.user2 = self.register_user("user2", "password")
|
||||
self.token2 = self.login("user2", "password")
|
||||
|
||||
def test_application_services_receive_read_receipts(self):
|
||||
"""
|
||||
Test that when a user sends a read receipt in a room with another
|
||||
user, and that is in an application service's user namespace, that
|
||||
application service will receive that read receipt.
|
||||
"""
|
||||
(
|
||||
interested_service,
|
||||
_,
|
||||
) = self._register_interested_and_uninterested_application_services()
|
||||
|
||||
# Create a room with both user1 and user2
|
||||
room_id = self.helper.create_room_as(
|
||||
self.user1, tok=self.token1, is_public=True
|
||||
)
|
||||
self.helper.join(room_id, self.user2, tok=self.token2)
|
||||
|
||||
# Have user2 send a message into the room
|
||||
response_dict = self.helper.send(room_id, body="read me", tok=self.token2)
|
||||
|
||||
# Have user1 send a read receipt for the message with an empty body
|
||||
self._send_read_receipt(room_id, response_dict["event_id"], self.token1)
|
||||
|
||||
# user2 should have been the recipient of that read receipt.
|
||||
# Check if our application service - that is interested in user2 - received
|
||||
# the read receipt as part of an AS transaction.
|
||||
#
|
||||
# The uninterested application service should not have been notified.
|
||||
last_call = self.mock_scheduler.submit_ephemeral_events_for_as.call_args_list[
|
||||
0
|
||||
]
|
||||
service, events = last_call[0]
|
||||
self.assertEqual(service, interested_service)
|
||||
self.assertEqual(len(events), 1)
|
||||
self.assertEqual(events[0]["type"], "m.receipt")
|
||||
self.assertEqual(events[0]["room_id"], room_id)
|
||||
|
||||
# Assert that this was a read receipt from user1
|
||||
read_receipts = list(events[0]["content"].values())
|
||||
self.assertIn(self.user1, read_receipts[0]["m.read"])
|
||||
|
||||
# Clear mock stats
|
||||
self.mock_scheduler.submit_ephemeral_events_for_as.reset_mock()
|
||||
|
||||
# Send 2 pairs of messages + read receipts
|
||||
response_dict_1 = self.helper.send(room_id, body="read me1", tok=self.token2)
|
||||
response_dict_2 = self.helper.send(room_id, body="read me2", tok=self.token2)
|
||||
self._send_read_receipt(room_id, response_dict_1["event_id"], self.token1)
|
||||
self._send_read_receipt(room_id, response_dict_2["event_id"], self.token1)
|
||||
|
||||
# Assert each transaction that was sent to the application service is as expected
|
||||
self.assertEqual(2, self.mock_scheduler.submit_ephemeral_events_for_as.call_count)
|
||||
|
||||
first_call, second_call = self.mock_scheduler.submit_ephemeral_events_for_as.call_args_list
|
||||
service, events = first_call[0]
|
||||
self.assertEqual(service, interested_service)
|
||||
self.assertEqual(len(events), 1)
|
||||
self.assertEqual(events[0]["type"], "m.receipt")
|
||||
self.assertEqual(
|
||||
self._event_id_from_read_receipt(events[0]), response_dict_1["event_id"]
|
||||
)
|
||||
|
||||
service, events = second_call[0]
|
||||
self.assertEqual(service, interested_service)
|
||||
self.assertEqual(len(events), 1)
|
||||
self.assertEqual(events[0]["type"], "m.receipt")
|
||||
self.assertEqual(
|
||||
self._event_id_from_read_receipt(events[0]), response_dict_2["event_id"]
|
||||
)
|
||||
|
||||
def test_application_services_receive_to_device(self):
|
||||
"""
|
||||
Test that when a user sends a to-device message to another user, and
|
||||
that is in an application service's user namespace, that application
|
||||
service will receive it.
|
||||
"""
|
||||
(
|
||||
interested_service,
|
||||
_,
|
||||
) = self._register_interested_and_uninterested_application_services()
|
||||
|
||||
# Create a room with both user1 and user2
|
||||
room_id = self.helper.create_room_as(
|
||||
self.user1, tok=self.token1, is_public=True
|
||||
)
|
||||
self.helper.join(room_id, self.user2, tok=self.token2)
|
||||
|
||||
# Have user2 send a typing notification into the room
|
||||
response_dict = self.helper.send(room_id, body="read me", tok=self.token2)
|
||||
|
||||
# Have user1 send a read receipt for the message with an empty body
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/rooms/%s/receipt/m.read/%s" % (room_id, response_dict["event_id"]),
|
||||
access_token=self.token1,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# user2 should have been the recipient of that read receipt.
|
||||
# Check if our application service - that is interested in user2 - received
|
||||
# the read receipt as part of an AS transaction.
|
||||
#
|
||||
# The uninterested application service should not have been notified.
|
||||
service, events = self.mock_scheduler.submit_ephemeral_events_for_as.call_args[
|
||||
0
|
||||
]
|
||||
self.assertEqual(service, interested_service)
|
||||
self.assertEqual(events[0]["type"], "m.receipt")
|
||||
self.assertEqual(events[0]["room_id"], room_id)
|
||||
|
||||
# Assert that this was a read receipt from user1
|
||||
read_receipts = list(events[0]["content"].values())
|
||||
self.assertIn(self.user1, read_receipts[0]["m.read"])
|
||||
|
||||
def _register_interested_and_uninterested_application_services(
|
||||
self,
|
||||
) -> Tuple[ApplicationService, ApplicationService]:
|
||||
# Create an application service with exclusive interest in user2
|
||||
interested_service = self._make_application_service(
|
||||
namespaces={
|
||||
ApplicationService.NS_USERS: [
|
||||
{
|
||||
"regex": "@user2:.+",
|
||||
"exclusive": True,
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
uninterested_service = self._make_application_service()
|
||||
|
||||
# Register this application service, along with another, uninterested one
|
||||
services = [
|
||||
uninterested_service,
|
||||
interested_service,
|
||||
]
|
||||
self.hs.get_datastore().get_app_services = Mock(return_value=services)
|
||||
|
||||
return interested_service, uninterested_service
|
||||
|
||||
def _make_application_service(
|
||||
self,
|
||||
namespaces: Optional[
|
||||
Dict[
|
||||
Union[
|
||||
ApplicationService.NS_USERS,
|
||||
ApplicationService.NS_ALIASES,
|
||||
ApplicationService.NS_ROOMS,
|
||||
],
|
||||
Iterable[Dict],
|
||||
]
|
||||
] = None,
|
||||
supports_ephemeral: Optional[bool] = True,
|
||||
) -> ApplicationService:
|
||||
return ApplicationService(
|
||||
token=None,
|
||||
hostname="example.com",
|
||||
id=random_string(10),
|
||||
sender="@as:example.com",
|
||||
rate_limited=False,
|
||||
namespaces=namespaces,
|
||||
supports_ephemeral=supports_ephemeral,
|
||||
)
|
||||
|
||||
def _send_read_receipt(self, room_id: str, event_id_to_read: str, tok: str) -> None:
|
||||
"""
|
||||
Send a read receipt of an event into a room.
|
||||
|
||||
Args:
|
||||
room_id: The room to event is part of.
|
||||
event_id_to_read: The ID of the event being read.
|
||||
tok: The access token of the sender.
|
||||
"""
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/rooms/%s/receipt/m.read/%s" % (room_id, event_id_to_read),
|
||||
access_token=tok,
|
||||
content="{}",
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
def _event_id_from_read_receipt(self, read_receipt_dict: JsonDict):
|
||||
"""
|
||||
Extracts the first event ID from a read receipt. Read receipt dictionaries
|
||||
are in the form:
|
||||
|
||||
{
|
||||
'type': 'm.receipt',
|
||||
'room_id': '!PEzCqHyycBVxqMKIjI:test',
|
||||
'content': {
|
||||
'$DETIeTEH651c1N7sP_j-YZiaQqCaayHhYwmhZDVWDY8': { # We want this
|
||||
'm.read': {
|
||||
'@user1:test': {
|
||||
'ts': 1300,
|
||||
'hidden': False
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
read_receipt_dict: The dictionary returned from a POST read receipt call.
|
||||
|
||||
Returns:
|
||||
The (first) event ID the read receipt refers to.
|
||||
"""
|
||||
return list(read_receipt_dict["content"].keys())[0]
|
||||
|
||||
# TODO: Test that ephemeral messages aren't sent to application services that have
|
||||
# ephemeral: false
|
|
@ -230,7 +230,7 @@ class RestHelper:
|
|||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
):
|
||||
) -> JsonDict:
|
||||
if body is None:
|
||||
body = "body_text_here"
|
||||
|
||||
|
@ -257,7 +257,7 @@ class RestHelper:
|
|||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
):
|
||||
) -> JsonDict:
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
|
||||
|
|
|
@ -320,7 +320,7 @@ class HomeserverTestCase(TestCase):
|
|||
def wait_for_background_updates(self) -> None:
|
||||
"""Block until all background database updates have completed."""
|
||||
while not self.get_success(
|
||||
self.store.db_pool.updates.has_completed_background_updates()
|
||||
self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
|
||||
):
|
||||
self.get_success(
|
||||
self.store.db_pool.updates.do_next_background_update(100), by=0.1
|
||||
|
|
|
@ -236,7 +236,7 @@ def setup_test_homeserver(
|
|||
else:
|
||||
database_config = {
|
||||
"name": "sqlite3",
|
||||
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
|
||||
"args": {"database": "test.db", "cp_min": 1, "cp_max": 1},
|
||||
}
|
||||
|
||||
if "db_txn_limit" in kwargs:
|
||||
|
|
Loading…
Reference in a new issue