diff --git a/.ci/worker-blacklist b/.ci/worker-blacklist index 5975cb98cf..cb8eae5d2a 100644 --- a/.ci/worker-blacklist +++ b/.ci/worker-blacklist @@ -1,10 +1,2 @@ # This file serves as a blacklist for SyTest tests that we expect will fail in # Synapse when run under worker mode. For more details, see sytest-blacklist. - -Can re-join room if re-invited - -# new failures as of https://github.com/matrix-org/sytest/pull/732 -Device list doesn't change if remote server is down - -# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5 -Server correctly handles incoming m.device_list_update diff --git a/changelog.d/10601.misc b/changelog.d/10601.misc new file mode 100644 index 0000000000..1227113ff3 --- /dev/null +++ b/changelog.d/10601.misc @@ -0,0 +1 @@ +Add type annotations to the synapse.util package. diff --git a/changelog.d/10774.bugfix b/changelog.d/10774.bugfix new file mode 100644 index 0000000000..5c2f6f8ade --- /dev/null +++ b/changelog.d/10774.bugfix @@ -0,0 +1 @@ +Properly handle room upgrades of spaces. diff --git a/changelog.d/10788.misc b/changelog.d/10788.misc new file mode 100644 index 0000000000..568a85ac52 --- /dev/null +++ b/changelog.d/10788.misc @@ -0,0 +1 @@ +Remove fixed and flakey tests from the sytest-blacklist. diff --git a/changelog.d/10789.misc b/changelog.d/10789.misc new file mode 100644 index 0000000000..8a0b54e32a --- /dev/null +++ b/changelog.d/10789.misc @@ -0,0 +1 @@ +Improve internal details of the user directory code. \ No newline at end of file diff --git a/changelog.d/10795.doc b/changelog.d/10795.doc new file mode 100644 index 0000000000..3a0b622825 --- /dev/null +++ b/changelog.d/10795.doc @@ -0,0 +1 @@ +Correct 2 typographical errors in the *Log Contexts* documentation. diff --git a/changelog.d/10798.misc b/changelog.d/10798.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10798.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/changelog.d/10799.misc b/changelog.d/10799.misc new file mode 100644 index 0000000000..91f7ede096 --- /dev/null +++ b/changelog.d/10799.misc @@ -0,0 +1 @@ +Add a max version for the `jaeger-client` dependency for an incompatibility with the rust reporter. diff --git a/changelog.d/10804.doc b/changelog.d/10804.doc new file mode 100644 index 0000000000..5d57af3b5f --- /dev/null +++ b/changelog.d/10804.doc @@ -0,0 +1 @@ +Fixed a wording mistake in the sample configuration. Contributed by @bramvdnheuvel:nltrix.net. diff --git a/docs/log_contexts.md b/docs/log_contexts.md index d49dce8830..cb15dbe158 100644 --- a/docs/log_contexts.md +++ b/docs/log_contexts.md @@ -10,7 +10,7 @@ Logcontexts are also used for CPU and database accounting, so that we can track which requests were responsible for high CPU use or database activity. -The `synapse.logging.context` module provides a facilities for managing +The `synapse.logging.context` module provides facilities for managing the current log context (as well as providing the `LoggingContextFilter` class). @@ -351,7 +351,7 @@ and the awaitable chain is now orphaned, and will be garbage-collected at some point. Note that `await_something_interesting` is a coroutine, which Python implements as a generator function. When Python garbage-collects generator functions, it gives them a chance to -clean up by making the `async` (or `yield`) raise a `GeneratorExit` +clean up by making the `await` (or `yield`) raise a `GeneratorExit` exception. In our case, that means that the `__exit__` handler of `PreserveLoggingContext` will carefully restore the request context, but there is now nothing waiting for its return, so the request context is diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index e15a832220..95cca16552 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2086,7 +2086,7 @@ password_config: # #require_lowercase: true - # Whether a password must contain at least one lowercase letter. + # Whether a password must contain at least one uppercase letter. # Defaults to 'false'. # #require_uppercase: true diff --git a/docs/user_directory.md b/docs/user_directory.md index d4f38d2cf1..07fe954891 100644 --- a/docs/user_directory.md +++ b/docs/user_directory.md @@ -10,3 +10,40 @@ DB corruption) get stale or out of sync. If this happens, for now the solution to fix it is to execute the SQL [here](https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/main/delta/53/user_dir_populate.sql) and then restart synapse. This should then start a background task to flush the current tables and regenerate the directory. + +Data model +---------- + +There are five relevant tables that collectively form the "user directory". +Three of them track a master list of all the users we could search for. +The last two (collectively called the "search tables") track who can +see who. + +From all of these tables we exclude three types of local user: + - support users + - appservice users + - deactivated users + +* `user_directory`. This contains the user_id, display name and avatar we'll + return when you search the directory. + - Because there's only one directory entry per user, it's important that we only + ever put publicly visible names here. Otherwise we might leak a private + nickname or avatar used in a private room. + - Indexed on rooms. Indexed on users. + +* `user_directory_search`. To be joined to `user_directory`. It contains an extra + column that enables full text search based on user ids and display names. + Different schemas for SQLite and Postgres with different code paths to match. + - Indexed on the full text search data. Indexed on users. + +* `user_directory_stream_pos`. When the initial background update to populate + the directory is complete, we record a stream position here. This indicates + that synapse should now listen for room changes and incrementally update + the directory where necessary. + +* `users_in_public_rooms`. Contains associations between users and the public rooms they're in. + Used to determine which users are in public rooms and should be publicly visible in the directory. + +* `users_who_share_private_rooms`. Rows are triples `(L, M, room id)` where `L` + is a local user and `M` is a local or remote user. `L` and `M` should be + different, but this isn't enforced by a constraint. diff --git a/mypy.ini b/mypy.ini index 4096f72241..09ffdda1b9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -74,17 +74,7 @@ files = synapse/storage/util, synapse/streams, synapse/types.py, - synapse/util/async_helpers.py, - synapse/util/caches, - synapse/util/daemonize.py, - synapse/util/hash.py, - synapse/util/iterutils.py, - synapse/util/linked_list.py, - synapse/util/metrics.py, - synapse/util/macaroons.py, - synapse/util/module_loader.py, - synapse/util/msisdn.py, - synapse/util/stringutils.py, + synapse/util, synapse/visibility.py, tests/replication, tests/test_event_auth.py, @@ -102,6 +92,69 @@ files = [mypy-synapse.rest.client.*] disallow_untyped_defs = True +[mypy-synapse.util.batching_queue] +disallow_untyped_defs = True + +[mypy-synapse.util.caches.dictionary_cache] +disallow_untyped_defs = True + +[mypy-synapse.util.file_consumer] +disallow_untyped_defs = True + +[mypy-synapse.util.frozenutils] +disallow_untyped_defs = True + +[mypy-synapse.util.hash] +disallow_untyped_defs = True + +[mypy-synapse.util.httpresourcetree] +disallow_untyped_defs = True + +[mypy-synapse.util.iterutils] +disallow_untyped_defs = True + +[mypy-synapse.util.linked_list] +disallow_untyped_defs = True + +[mypy-synapse.util.logcontext] +disallow_untyped_defs = True + +[mypy-synapse.util.logformatter] +disallow_untyped_defs = True + +[mypy-synapse.util.macaroons] +disallow_untyped_defs = True + +[mypy-synapse.util.manhole] +disallow_untyped_defs = True + +[mypy-synapse.util.module_loader] +disallow_untyped_defs = True + +[mypy-synapse.util.msisdn] +disallow_untyped_defs = True + +[mypy-synapse.util.ratelimitutils] +disallow_untyped_defs = True + +[mypy-synapse.util.retryutils] +disallow_untyped_defs = True + +[mypy-synapse.util.rlimit] +disallow_untyped_defs = True + +[mypy-synapse.util.stringutils] +disallow_untyped_defs = True + +[mypy-synapse.util.templates] +disallow_untyped_defs = True + +[mypy-synapse.util.threepids] +disallow_untyped_defs = True + +[mypy-synapse.util.wheel_timer] +disallow_untyped_defs = True + [mypy-pymacaroons.*] ignore_missing_imports = True diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index c1a06ae022..4ff3c6de5f 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory): def buildProtocol(self, addr) -> RedisProtocol: ... class SubscriberFactory(RedisFactory): - def __init__(self): ... + def __init__(self) -> None: ... diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 3e3d09bbd2..cbdd74025b 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -46,7 +46,7 @@ class Ratelimiter: # * How many times an action has occurred since a point in time # * The point in time # * The rate_hz of this particular entry. This can vary per request - self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict() + self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict() async def can_do_action( self, @@ -56,7 +56,7 @@ class Ratelimiter: burst_count: Optional[int] = None, update: bool = True, n_actions: int = 1, - _time_now_s: Optional[int] = None, + _time_now_s: Optional[float] = None, ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? @@ -160,7 +160,7 @@ class Ratelimiter: return allowed, time_allowed - def _prune_message_counts(self, time_now_s: int): + def _prune_message_counts(self, time_now_s: float): """Remove message count entries that have not exceeded their defined rate_hz limit @@ -188,7 +188,7 @@ class Ratelimiter: burst_count: Optional[int] = None, update: bool = True, n_actions: int = 1, - _time_now_s: Optional[int] = None, + _time_now_s: Optional[float] = None, ): """Checks if an action can be performed. If not, raises a LimitExceededError diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 4b1f213c75..d3270cd6d2 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -41,11 +41,11 @@ class ConsentURIBuilder: """ if hs_config.form_secret is None: raise ConfigError("form_secret not set in config") - if hs_config.public_baseurl is None: + if hs_config.server.public_baseurl is None: raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") - self._public_baseurl = hs_config.public_baseurl + self._public_baseurl = hs_config.server.public_baseurl def build_user_consent_uri(self, user_id): """Build a URI which we can give to the user to do their privacy diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 89bda00090..d1aa2e7fb5 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -82,7 +82,7 @@ def start_worker_reactor(appname, config, run_command=reactor.run): run_command (Callable[]): callable that actually runs the reactor """ - logger = logging.getLogger(config.worker_app) + logger = logging.getLogger(config.worker.worker_app) start_reactor( appname, @@ -398,7 +398,7 @@ async def start(hs: "HomeServer"): # If background tasks are running on the main process, start collecting the # phone home stats. - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: start_phone_stats_home(hs) # We now freeze all allocated objects in the hopes that (almost) @@ -433,9 +433,13 @@ def setup_sentry(hs): # We set some default tags that give some context to this instance with sentry_sdk.configure_scope() as scope: - scope.set_tag("matrix_server_name", hs.config.server_name) + scope.set_tag("matrix_server_name", hs.config.server.server_name) - app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver" + app = ( + hs.config.worker.worker_app + if hs.config.worker.worker_app + else "synapse.app.homeserver" + ) name = hs.get_instance_name() scope.set_tag("worker_app", app) scope.set_tag("worker_name", name) diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 7396db93c6..5e956b1e27 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -178,12 +178,12 @@ def start(config_options): sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) - if config.worker_app is not None: - assert config.worker_app == "synapse.app.admin_cmd" + if config.worker.worker_app is not None: + assert config.worker.worker_app == "synapse.app.admin_cmd" # Update the config with some basic overrides so that don't have to specify # a full worker config. - config.worker_app = "synapse.app.admin_cmd" + config.worker.worker_app = "synapse.app.admin_cmd" if ( not config.worker_daemonize @@ -196,7 +196,7 @@ def start(config_options): # Explicitly disable background processes config.update_user_directory = False - config.run_background_tasks = False + config.worker.run_background_tasks = False config.start_pushers = False config.pusher_shard_config.instances = [] config.send_federation = False @@ -205,7 +205,7 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts ss = AdminCmdServer( - config.server_name, + config.server.server_name, config=config, version_string="Synapse/" + get_version_string(synapse), ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 7d2cd6a904..33afd59c72 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -416,7 +416,7 @@ def start(config_options): sys.exit(1) # For backwards compatibility let any of the old app names. - assert config.worker_app in ( + assert config.worker.worker_app in ( "synapse.app.appservice", "synapse.app.client_reader", "synapse.app.event_creator", @@ -430,7 +430,7 @@ def start(config_options): "synapse.app.user_dir", ) - if config.worker_app == "synapse.app.appservice": + if config.worker.worker_app == "synapse.app.appservice": if config.appservice.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" @@ -446,7 +446,7 @@ def start(config_options): # For other worker types we force this to off. config.appservice.notify_appservices = False - if config.worker_app == "synapse.app.user_dir": + if config.worker.worker_app == "synapse.app.user_dir": if config.server.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" @@ -469,7 +469,7 @@ def start(config_options): synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds hs = GenericWorkerServer( - config.server_name, + config.server.server_name, config=config, version_string="Synapse/" + get_version_string(synapse), ) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 708db86f5d..b909f8db8d 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -350,7 +350,7 @@ def setup(config_options): synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds hs = SynapseHomeServer( - config.server_name, + config.server.server_name, config=config, version_string="Synapse/" + get_version_string(synapse), ) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 86ad7337a9..4a95da90f9 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -73,7 +73,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): store = hs.get_datastore() - stats["homeserver"] = hs.config.server_name + stats["homeserver"] = hs.config.server.server_name stats["server_context"] = hs.config.server_context stats["timestamp"] = now stats["uptime_seconds"] = uptime diff --git a/synapse/config/auth.py b/synapse/config/auth.py index 53809cee2e..ba8bf9cbe7 100644 --- a/synapse/config/auth.py +++ b/synapse/config/auth.py @@ -88,7 +88,7 @@ class AuthConfig(Config): # #require_lowercase: true - # Whether a password must contain at least one lowercase letter. + # Whether a password must contain at least one uppercase letter. # Defaults to 'false'. # #require_uppercase: true diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 4a398a7932..aca9d467e6 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -223,7 +223,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> # writes. log_context_filter = LoggingContextFilter() - log_metadata_filter = MetadataFilter({"server_name": config.server_name}) + log_metadata_filter = MetadataFilter({"server_name": config.server.server_name}) old_factory = logging.getLogRecordFactory() def factory(*args, **kwargs): @@ -335,5 +335,5 @@ def setup_logging( # Log immediately so we can grep backwards. logging.warning("***** STARTING SERVER *****") logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) - logging.info("Server hostname: %s", config.server_name) + logging.info("Server hostname: %s", config.server.server_name) logging.info("Instance name: %s", hs.get_instance_name()) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index f856327bd8..36636ab07e 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -14,6 +14,8 @@ from typing import Dict, Optional +import attr + from ._base import Config @@ -29,18 +31,13 @@ class RateLimitConfig: self.burst_count = int(config.get("burst_count", defaults["burst_count"])) +@attr.s(auto_attribs=True) class FederationRateLimitConfig: - _items_and_default = { - "window_size": 1000, - "sleep_limit": 10, - "sleep_delay": 500, - "reject_limit": 50, - "concurrent": 3, - } - - def __init__(self, **kwargs): - for i in self._items_and_default.keys(): - setattr(self, i, kwargs.get(i) or self._items_and_default[i]) + window_size: int = 1000 + sleep_limit: int = 10 + sleep_delay: int = 500 + reject_limit: int = 50 + concurrent: int = 3 class RatelimitConfig(Config): @@ -69,11 +66,15 @@ class RatelimitConfig(Config): else: self.rc_federation = FederationRateLimitConfig( **{ - "window_size": config.get("federation_rc_window_size"), - "sleep_limit": config.get("federation_rc_sleep_limit"), - "sleep_delay": config.get("federation_rc_sleep_delay"), - "reject_limit": config.get("federation_rc_reject_limit"), - "concurrent": config.get("federation_rc_concurrent"), + k: v + for k, v in { + "window_size": config.get("federation_rc_window_size"), + "sleep_limit": config.get("federation_rc_sleep_limit"), + "sleep_delay": config.get("federation_rc_sleep_delay"), + "reject_limit": config.get("federation_rc_reject_limit"), + "concurrent": config.get("federation_rc_concurrent"), + }.items() + if v is not None } ) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 33954b4f62..6eb6544c4c 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -88,7 +88,7 @@ class EventValidator: self._validate_retention(event) if event.type == EventTypes.ServerACL: - if not server_matches_acl_event(config.server_name, event): + if not server_matches_acl_event(config.server.server_name, event): raise SynapseError( 400, "Can't create an ACL event that denies the local server" ) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index d980e0d986..4671ac0242 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -22,6 +22,7 @@ from prometheus_client import Counter from typing_extensions import Literal from twisted.internet import defer +from twisted.internet.interfaces import IDelayedCall import synapse.metrics from synapse.api.presence import UserPresenceState @@ -280,11 +281,14 @@ class FederationSender(AbstractFederationSender): self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {} self._rr_txn_interval_per_room_ms = ( - 1000.0 / hs.config.federation_rr_transactions_per_room_per_second + 1000.0 + / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second ) # wake up destinations that have outstanding PDUs to be caught up - self._catchup_after_startup_timer = self.clock.call_later( + self._catchup_after_startup_timer: Optional[ + IDelayedCall + ] = self.clock.call_later( CATCH_UP_STARTUP_DELAY_SEC, run_as_background_process, "wake_destinations_needing_catchup", @@ -406,7 +410,7 @@ class FederationSender(AbstractFederationSender): now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) - + assert ts is not None synapse.metrics.event_processing_lag_by_event.labels( "federation_sender" ).observe((now - ts) / 1000) @@ -435,6 +439,7 @@ class FederationSender(AbstractFederationSender): if events: now = self.clock.time_msec() ts = await self.store.get_received_ts(events[-1].event_id) + assert ts is not None synapse.metrics.event_processing_lag.labels( "federation_sender" diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index ff8372c4e9..53f99031b1 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -144,7 +144,7 @@ class GroupAttestionRenewer: self.is_mine_id = hs.is_mine_id self.attestations = hs.get_groups_attestation_signing() - if not hs.config.worker_app: + if not hs.config.worker.worker_app: self._renew_attestations_loop = self.clock.looping_call( self._start_renew_attestations, 30 * 60 * 1000 ) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 955cfa2207..c23ccd6dd9 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -45,16 +45,16 @@ class BaseHandler: self.request_ratelimiter = Ratelimiter( store=self.store, clock=self.clock, rate_hz=0, burst_count=0 ) - self._rc_message = self.hs.config.rc_message + self._rc_message = self.hs.config.ratelimiting.rc_message # Check whether ratelimiting room admin message redaction is enabled # by the presence of rate limits in the config - if self.hs.config.rc_admin_redaction: + if self.hs.config.ratelimiting.rc_admin_redaction: self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=self.hs.config.rc_admin_redaction.per_second, - burst_count=self.hs.config.rc_admin_redaction.burst_count, + rate_hz=self.hs.config.ratelimiting.rc_admin_redaction.per_second, + burst_count=self.hs.config.ratelimiting.rc_admin_redaction.burst_count, ) else: self.admin_redaction_ratelimiter = None diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 078accd634..a9c2222f46 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -78,7 +78,7 @@ class AccountValidityHandler: ) # Check the renewal emails to send and send them every 30min. - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] @@ -249,7 +249,7 @@ class AccountValidityHandler: renewal_token = await self._get_renewal_token(user_id) url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( - self.hs.config.public_baseurl, + self.hs.config.server.public_baseurl, renewal_token, ) @@ -398,6 +398,7 @@ class AccountValidityHandler: """ now = self.clock.time_msec() if expiration_ts is None: + assert self._account_validity_period is not None expiration_ts = now + self._account_validity_period await self.store.set_account_validity_for_user( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 4ab4046650..a7b5a4e9c9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -131,6 +131,8 @@ class ApplicationServicesHandler: now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) + assert ts is not None + synapse.metrics.event_processing_lag_by_event.labels( "appservice_sender" ).observe((now - ts) / 1000) @@ -166,6 +168,7 @@ class ApplicationServicesHandler: if events: now = self.clock.time_msec() ts = await self.store.get_received_ts(events[-1].event_id) + assert ts is not None synapse.metrics.event_processing_lag.labels( "appservice_sender" diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 34725324a6..fbbf6fd834 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -244,8 +244,8 @@ class AuthHandler(BaseHandler): self._failed_uia_attempts_ratelimiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, ) # The number of seconds to keep a UI auth session active. @@ -255,14 +255,14 @@ class AuthHandler(BaseHandler): self._failed_login_attempts_ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, ) self._clock = self.hs.get_clock() # Expire old UI auth sessions after a period of time. - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self._clock.looping_call( run_as_background_process, 5 * 60 * 1000, @@ -289,7 +289,7 @@ class AuthHandler(BaseHandler): hs.config.sso_account_deactivated_template ) - self._server_name = hs.config.server_name + self._server_name = hs.config.server.server_name # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) @@ -749,7 +749,7 @@ class AuthHandler(BaseHandler): "name": self.hs.config.user_consent_policy_name, "url": "%s_matrix/consent?v=%s" % ( - self.hs.config.public_baseurl, + self.hs.config.server.public_baseurl, self.hs.config.user_consent_version, ), }, @@ -1799,7 +1799,7 @@ class MacaroonGenerator: def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon: macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, + location=self.hs.config.server.server_name, identifier="key", key=self.hs.config.macaroon_secret_key, ) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 45d2404dde..dcd320c555 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -46,7 +46,7 @@ class DeactivateAccountHandler(BaseHandler): # Start the user parter loop so it can resume parting users from rooms where # it left off (if it has work left to do). - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: hs.get_reactor().callWhenRunning(self._start_user_parting) self._account_validity_enabled = ( @@ -131,7 +131,7 @@ class DeactivateAccountHandler(BaseHandler): await self.store.add_user_pending_deactivation(user_id) # delete from user directory - await self.user_directory_handler.handle_user_deactivated(user_id) + await self.user_directory_handler.handle_local_user_deactivated(user_id) # Mark the user as erased, if they asked for that if erase_data: diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 679b47f081..b6a2a34ab7 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -84,8 +84,8 @@ class DeviceMessageHandler: self._ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.rc_key_requests.per_second, - burst_count=hs.config.rc_key_requests.burst_count, + rate_hz=hs.config.ratelimiting.rc_key_requests.per_second, + burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d92370859f..08a137561f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -57,7 +57,7 @@ class E2eKeysHandler: federation_registry = hs.get_federation_registry() - self._is_master = hs.config.worker_app is None + self._is_master = hs.config.worker.worker_app is None if not self._is_master: self._user_device_resync_client = ( ReplicationUserDevicesResyncRestServlet.make_client(hs) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 77df9185f6..6754c64c31 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -101,7 +101,7 @@ class FederationHandler(BaseHandler): hs ) - if hs.config.worker_app: + if hs.config.worker.worker_app: self._maybe_store_room_on_outlier_membership = ( ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs) ) @@ -1614,7 +1614,7 @@ class FederationHandler(BaseHandler): Args: room_id """ - if self.config.worker_app: + if self.config.worker.worker_app: await self._clean_room_for_join_client(room_id) else: await self.store.clean_room_for_join(room_id) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9ec90ac8c1..946343fa25 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -149,7 +149,7 @@ class FederationEventHandler: self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) - if hs.config.worker_app: + if hs.config.worker.worker_app: self._user_device_resync = ( ReplicationUserDevicesResyncRestServlet.make_client(hs) ) @@ -1009,7 +1009,7 @@ class FederationEventHandler: await self._store.mark_remote_user_device_cache_as_stale(sender) # Immediately attempt a resync in the background - if self._config.worker_app: + if self._config.worker.worker_app: await self._user_device_resync(user_id=sender) else: await self._device_list_updater.user_device_resync(sender) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 8ffeabacf9..8b8f1f41ca 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -540,13 +540,13 @@ class IdentityHandler(BaseHandler): # It is already checked that public_baseurl is configured since this code # should only be used if account_threepid_delegate_msisdn is true. - assert self.hs.config.public_baseurl + assert self.hs.config.server.public_baseurl # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) data["submit_url"] = ( - self.hs.config.public_baseurl + self.hs.config.server.public_baseurl + "_matrix/client/unstable/add_threepid/msisdn/submit_token" ) return data diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 75d4e27723..60673cd4b8 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -84,7 +84,7 @@ class MessageHandler: # scheduled. self._scheduled_expiry: Optional[IDelayedCall] = None - if not hs.config.worker_app: + if not hs.config.worker.worker_app: run_as_background_process( "_schedule_next_expiry", self._schedule_next_expiry ) @@ -461,7 +461,7 @@ class EventCreationHandler: self._dummy_events_threshold = hs.config.dummy_events_threshold if ( - self.config.run_background_tasks + self.config.worker.run_background_tasks and self.config.cleanup_extremities_with_dummy_events ): self.clock.looping_call( diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 648fcf76f8..dfc251b2a5 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -324,7 +324,7 @@ class OidcProvider: self._allow_existing_users = provider.allow_existing_users self._http_client = hs.get_proxied_http_client() - self._server_name: str = hs.config.server_name + self._server_name: str = hs.config.server.server_name # identifier for the external_ids table self.idp_id = provider.idp_id diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1dbafd253d..7dc0ee4bef 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -91,7 +91,7 @@ class PaginationHandler: self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max - if hs.config.run_background_tasks and hs.config.retention_enabled: + if hs.config.worker.run_background_tasks and hs.config.retention_enabled: # Run the purge jobs described in the configuration file. for job in hs.config.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4418d63df7..39b39cd3e2 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -28,6 +28,7 @@ from bisect import bisect from contextlib import contextmanager from typing import ( TYPE_CHECKING, + Any, Callable, Collection, Dict, @@ -615,7 +616,7 @@ class PresenceHandler(BasePresenceHandler): super().__init__(hs) self.hs = hs self.server_name = hs.hostname - self.wheel_timer = WheelTimer() + self.wheel_timer: WheelTimer[str] = WheelTimer() self.notifier = hs.get_notifier() self._presence_enabled = hs.config.use_presence @@ -924,7 +925,7 @@ class PresenceHandler(BasePresenceHandler): prev_state = await self.current_state_for_user(user_id) - new_fields = {"last_active_ts": self.clock.time_msec()} + new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 20a033d0ba..51adf8762d 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -63,7 +63,7 @@ class ProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self.clock.looping_call( self._update_remote_profile_cache, self.PROFILE_UPDATE_MS ) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index c679a8303e..bd8160e7ed 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class ReadMarkerHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.server_name = hs.config.server_name + self.server_name = hs.config.server.server_name self.store = hs.get_datastore() self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index fb495229a7..a49b8ee4b1 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -29,7 +29,7 @@ class ReceiptsHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.server_name = hs.config.server_name + self.server_name = hs.config.server.server_name self.store = hs.get_datastore() self.event_auth_handler = hs.get_event_auth_handler() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c374a1fbc2..38c4993da0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -102,7 +102,7 @@ class RegistrationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() - if hs.config.worker_app: + if hs.config.worker.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( hs @@ -696,7 +696,7 @@ class RegistrationHandler(BaseHandler): address: the IP address used to perform the registration. shadow_banned: Whether to shadow-ban the user """ - if self.hs.config.worker_app: + if self.hs.config.worker.worker_app: await self._register_client( user_id=user_id, password_hash=password_hash, @@ -786,7 +786,7 @@ class RegistrationHandler(BaseHandler): Does the bits that need doing on the main process. Not for use outside this class and RegisterDeviceReplicationServlet. """ - assert not self.hs.config.worker_app + assert not self.hs.config.worker.worker_app valid_until_ms = None if self.session_lifetime is not None: if is_guest: @@ -843,7 +843,7 @@ class RegistrationHandler(BaseHandler): """ # TODO: 3pid registration can actually happen on the workers. Consider # refactoring it. - if self.hs.config.worker_app: + if self.hs.config.worker.worker_app: await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2932ed8a94..9345ae02e0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -33,6 +33,7 @@ from synapse.api.constants import ( Membership, RoomCreationPreset, RoomEncryptionAlgorithms, + RoomTypes, ) from synapse.api.errors import ( AuthError, @@ -397,7 +398,7 @@ class RoomCreationHandler(BaseHandler): initial_state = {} # Replicate relevant room events - types_to_copy = ( + types_to_copy: List[Tuple[str, Optional[str]]] = [ (EventTypes.JoinRules, ""), (EventTypes.Name, ""), (EventTypes.Topic, ""), @@ -408,7 +409,16 @@ class RoomCreationHandler(BaseHandler): (EventTypes.ServerACL, ""), (EventTypes.RelatedGroups, ""), (EventTypes.PowerLevels, ""), - ) + ] + + # If the old room was a space, copy over the room type and the rooms in + # the space. + if ( + old_room_create_event.content.get(EventContentFields.ROOM_TYPE) + == RoomTypes.SPACE + ): + creation_content[EventContentFields.ROOM_TYPE] = RoomTypes.SPACE + types_to_copy.append((EventTypes.SpaceChild, None)) old_room_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types(types_to_copy) @@ -419,6 +429,11 @@ class RoomCreationHandler(BaseHandler): for k, old_event_id in old_room_state_ids.items(): old_event = old_room_state_events.get(old_event_id) if old_event: + # If the event is an space child event with empty content, it was + # removed from the space and should be ignored. + if k[0] == EventTypes.SpaceChild and not old_event.content: + continue + initial_state[k] = old_event.content # deep-copy the power-levels event before we start modifying it diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index 077c7c0649..d30ba2b724 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from enum import Enum, auto from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: @@ -21,6 +22,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class MatchChange(Enum): + no_change = auto() + now_true = auto() + now_false = auto() + + class StateDeltasHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -31,18 +38,12 @@ class StateDeltasHandler: event_id: Optional[str], key_name: str, public_value: str, - ) -> Optional[bool]: + ) -> MatchChange: """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. For example, check if `history_visibility` (`key_name`) changed from `shared` to `world_readable` (`public_value`). - - Returns: - None if the field in the events either both match `public_value` - or if neither do, i.e. there has been no change. - True if it didn't match `public_value` but now does - False if it did match `public_value` but now doesn't """ prev_event = None event = None @@ -54,7 +55,7 @@ class StateDeltasHandler: if not event and not prev_event: logger.debug("Neither event exists: %r %r", prev_event_id, event_id) - return None + return MatchChange.no_change prev_value = None value = None @@ -68,8 +69,8 @@ class StateDeltasHandler: logger.debug("prev_value: %r -> value: %r", prev_value, value) if value == public_value and prev_value != public_value: - return True + return MatchChange.now_true elif value != public_value and prev_value == public_value: - return False + return MatchChange.now_false else: - return None + return MatchChange.no_change diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 47f2e2a0c1..b64ce8cab8 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -54,7 +54,7 @@ class StatsHandler: # Guard to ensure we only process deltas one at a time self._is_processing = False - if self.stats_enabled and hs.config.run_background_tasks: + if self.stats_enabled and hs.config.worker.run_background_tasks: self.notifier.add_replication_callback(self.notify_new_event) # We kick this off so that we don't have to wait for a change before diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a97c448595..9cea011e62 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -53,7 +53,7 @@ class FollowerTypingHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.server_name = hs.config.server_name + self.server_name = hs.config.server.server_name self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -73,7 +73,7 @@ class FollowerTypingHandler: self._room_typing: Dict[str, Set[str]] = {} self._member_last_federation_poke: Dict[RoomMember, int] = {} - self.wheel_timer = WheelTimer(bucket_size=5000) + self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 self.clock.looping_call(self._handle_timeouts, 5000) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 6edb1da50a..6faa1d84be 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional import synapse.metrics from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership -from synapse.handlers.state_deltas import StateDeltasHandler +from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict @@ -30,14 +30,26 @@ logger = logging.getLogger(__name__) class UserDirectoryHandler(StateDeltasHandler): - """Handles querying of and keeping updated the user_directory. + """Handles queries and updates for the user_directory. N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY - The user directory is filled with users who this server can see are joined to a - world_readable or publicly joinable room. We keep a database table up to date - by streaming changes of the current state and recalculating whether users should - be in the directory or not when necessary. + When a local user searches the user_directory, we report two kinds of users: + + - users this server can see are joined to a world_readable or publicly + joinable room, and + - users belonging to a private room shared by that local user. + + The two cases are tracked separately in the `users_in_public_rooms` and + `users_who_share_private_rooms` tables. Both kinds of users have their + username and avatar tracked in a `user_directory` table. + + This handler has three responsibilities: + 1. Forwarding requests to `/user_directory/search` to the UserDirectoryStore. + 2. Providing hooks for the application to call when local users are added, + removed, or have their profile changed. + 3. Listening for room state changes that indicate remote users have + joined or left a room, or that their profile has changed. """ def __init__(self, hs: "HomeServer"): @@ -130,7 +142,7 @@ class UserDirectoryHandler(StateDeltasHandler): user_id, profile.display_name, profile.avatar_url ) - async def handle_user_deactivated(self, user_id: str) -> None: + async def handle_local_user_deactivated(self, user_id: str) -> None: """Called when a user ID is deactivated""" # FIXME(#3714): We should probably do this in the same worker as all # the other changes. @@ -196,7 +208,7 @@ class UserDirectoryHandler(StateDeltasHandler): public_value=Membership.JOIN, ) - if change is False: + if change is MatchChange.now_false: # Need to check if the server left the room entirely, if so # we might need to remove all the users in that room is_in_room = await self.store.is_host_joined( @@ -219,14 +231,14 @@ class UserDirectoryHandler(StateDeltasHandler): is_support = await self.store.is_support_user(state_key) if not is_support: - if change is None: + if change is MatchChange.no_change: # Handle any profile changes await self._handle_profile_change( state_key, room_id, prev_event_id, event_id ) continue - if change: # The user joined + if change is MatchChange.now_true: # The user joined event = await self.store.get_event(event_id, allow_none=True) # It isn't expected for this event to not exist, but we # don't want the entire background process to break. @@ -263,14 +275,14 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Handling change for %s: %s", typ, room_id) if typ == EventTypes.RoomHistoryVisibility: - change = await self._get_key_change( + publicness = await self._get_key_change( prev_event_id, event_id, key_name="history_visibility", public_value=HistoryVisibility.WORLD_READABLE, ) elif typ == EventTypes.JoinRules: - change = await self._get_key_change( + publicness = await self._get_key_change( prev_event_id, event_id, key_name="join_rule", @@ -278,9 +290,7 @@ class UserDirectoryHandler(StateDeltasHandler): ) else: raise Exception("Invalid event type") - # If change is None, no change. True => become world_readable/public, - # False => was world_readable/public - if change is None: + if publicness is MatchChange.no_change: logger.debug("No change") return @@ -290,13 +300,13 @@ class UserDirectoryHandler(StateDeltasHandler): room_id ) - logger.debug("Change: %r, is_public: %r", change, is_public) + logger.debug("Change: %r, publicness: %r", publicness, is_public) - if change and not is_public: + if publicness is MatchChange.now_true and not is_public: # If we became world readable but room isn't currently public then # we ignore the change return - elif not change and is_public: + elif publicness is MatchChange.now_false and is_public: # If we stopped being world readable but are still public, # ignore the change return diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index ecd51f1b4a..c6c4d3bd29 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -236,8 +236,17 @@ except ImportError: try: from rust_python_jaeger_reporter import Reporter + # jaeger-client 4.7.0 requires that reporters inherit from BaseReporter, which + # didn't exist before that version. + try: + from jaeger_client.reporter import BaseReporter + except ImportError: + + class BaseReporter: # type: ignore[no-redef] + pass + @attr.s(slots=True, frozen=True) - class _WrappedRustReporter: + class _WrappedRustReporter(BaseReporter): """Wrap the reporter to ensure `report_span` never throws.""" _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter)) @@ -374,7 +383,7 @@ def init_tracer(hs: "HomeServer"): config = JaegerConfig( config=hs.config.jaeger_config, - service_name=f"{hs.config.server_name} {hs.get_instance_name()}", + service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}", scope_manager=LogContextScopeManager(hs.config), metrics_factory=PrometheusMetricsFactory(), ) @@ -382,6 +391,7 @@ def init_tracer(hs: "HomeServer"): # If we have the rust jaeger reporter available let's use that. if RustReporter: logger.info("Using rust_python_jaeger_reporter library") + assert config.sampler is not None tracer = config.create_tracer(RustReporter(), config.sampler) opentracing.set_global_tracer(tracer) else: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b11fa6393b..2d403532fa 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -178,7 +178,7 @@ class ModuleApi: @property def public_baseurl(self) -> str: """The configured public base URL for this homeserver.""" - return self._hs.config.public_baseurl + return self._hs.config.server.public_baseurl @property def email_app_name(self) -> str: @@ -640,7 +640,7 @@ class ModuleApi: if desc is None: desc = f.__name__ - if self._hs.config.run_background_tasks or run_on_all_instances: + if self._hs.config.worker.run_background_tasks or run_on_all_instances: self._clock.looping_call( run_as_background_process, msec, diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index b0834720ad..b89c6e6f2b 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -130,7 +130,7 @@ class Mailer: """ params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( - self.hs.config.public_baseurl + self.hs.config.server.public_baseurl + "_synapse/client/password_reset/email/submit_token?%s" % urllib.parse.urlencode(params) ) @@ -140,7 +140,7 @@ class Mailer: await self.send_email( email_address, self.email_subjects.password_reset - % {"server_name": self.hs.config.server_name}, + % {"server_name": self.hs.config.server.server_name}, template_vars, ) @@ -160,7 +160,7 @@ class Mailer: """ params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( - self.hs.config.public_baseurl + self.hs.config.server.public_baseurl + "_matrix/client/unstable/registration/email/submit_token?%s" % urllib.parse.urlencode(params) ) @@ -170,7 +170,7 @@ class Mailer: await self.send_email( email_address, self.email_subjects.email_validation - % {"server_name": self.hs.config.server_name}, + % {"server_name": self.hs.config.server.server_name}, template_vars, ) @@ -191,7 +191,7 @@ class Mailer: """ params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( - self.hs.config.public_baseurl + self.hs.config.server.public_baseurl + "_matrix/client/unstable/add_threepid/email/submit_token?%s" % urllib.parse.urlencode(params) ) @@ -201,7 +201,7 @@ class Mailer: await self.send_email( email_address, self.email_subjects.email_validation - % {"server_name": self.hs.config.server_name}, + % {"server_name": self.hs.config.server.server_name}, template_vars, ) @@ -852,7 +852,7 @@ class Mailer: # XXX: make r0 once API is stable return "%s_matrix/client/unstable/pushers/remove?%s" % ( - self.hs.config.public_baseurl, + self.hs.config.server.public_baseurl, urllib.parse.urlencode(params), ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3fd2811713..37769ace48 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -73,7 +73,7 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory): ): self.client_name = client_name self.command_handler = command_handler - self.server_name = hs.config.server_name + self.server_name = hs.config.server.server_name self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index eae4515363..509ed7fb13 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -168,7 +168,7 @@ class ReplicationCommandHandler: continue # Only add any other streams if we're on master. - if hs.config.worker_app is not None: + if hs.config.worker.worker_app is not None: continue if stream.NAME == FederationStream.NAME and hs.config.send_federation: @@ -222,7 +222,7 @@ class ReplicationCommandHandler: }, ) - self._is_master = hs.config.worker_app is None + self._is_master = hs.config.worker.worker_app is None self._federation_sender = None if self._is_master and not hs.config.send_federation: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index bd47d84258..030852cb5b 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -40,7 +40,7 @@ class ReplicationStreamProtocolFactory(Factory): def __init__(self, hs): self.command_handler = hs.get_tcp_replication() self.clock = hs.get_clock() - self.server_name = hs.config.server_name + self.server_name = hs.config.server.server_name # If we've created a `ReplicationStreamProtocolFactory` then we're # almost certainly registering a replication listener, so let's ensure diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index c445af9bd9..0600cdbf36 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -42,7 +42,7 @@ class FederationStream(Stream): ROW_TYPE = FederationStreamRow def __init__(self, hs: "HomeServer"): - if hs.config.worker_app is None: + if hs.config.worker.worker_app is None: # master process: get updates from the FederationRemoteSendQueue. # (if the master is configured to send federation itself, federation_sender # will be a real FederationSender, which has stubs for current_token and diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index b2514d9d0d..a03774c98a 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -247,7 +247,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RegistrationTokenRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. - if hs.config.worker_app is None: + if hs.config.worker.worker_app is None: SendServerNoticeServlet(hs).register(http_server) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index df8cc4ac7a..7bb7801472 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -68,7 +68,10 @@ class AuthRestServlet(RestServlet): html = self.terms_template.render( session=session, terms_url="%s_matrix/consent?v=%s" - % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), + % ( + self.hs.config.server.public_baseurl, + self.hs.config.user_consent_version, + ), myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), ) @@ -135,7 +138,7 @@ class AuthRestServlet(RestServlet): session=session, terms_url="%s_matrix/consent?v=%s" % ( - self.hs.config.public_baseurl, + self.hs.config.server.public_baseurl, self.hs.config.user_consent_version, ), myurl="%s/r0/auth/%s/fallback/web" diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index bcba106bdd..a6ede7e2f3 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -93,14 +93,14 @@ class LoginRestServlet(RestServlet): self._address_ratelimiter = Ratelimiter( store=hs.get_datastore(), clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, + rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( store=hs.get_datastore(), clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, + rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, + burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, ) # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. @@ -486,7 +486,7 @@ class SsoRedirectServlet(RestServlet): # register themselves with the main SSOHandler. _load_sso_handlers(hs) self._sso_handler = hs.get_sso_handler() - self._public_baseurl = hs.config.public_baseurl + self._public_baseurl = hs.config.server.public_baseurl async def on_GET( self, request: SynapseRequest, idp_id: Optional[str] = None diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index 4dda6dce4b..add56d6998 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -69,7 +69,7 @@ class IdTokenServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() - self.server_name = hs.config.server_name + self.server_name = hs.config.server.server_name async def on_POST( self, request: SynapseRequest, user_id: str diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index fb3211bf3a..ecebc46e8d 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -59,7 +59,7 @@ class PushRuleRestServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self._is_worker = hs.config.worker.worker_app is not None self._users_new_default_push_rules = hs.config.users_new_default_push_rules diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 8f3dd2a101..abe4d7e205 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -330,11 +330,11 @@ class UsernameAvailabilityRestServlet(RestServlet): # Artificially delay requests if rate > sleep_limit/window_size sleep_limit=1, # Amount of artificial delay to apply - sleep_msec=1000, + sleep_delay=1000, # Error with 429 if more than reject_limit requests are queued reject_limit=1, # Allow 1 request at a time - concurrent_requests=1, + concurrent=1, ), ) @@ -763,7 +763,10 @@ class RegisterRestServlet(RestServlet): Returns: dictionary for response from /register """ - result = {"user_id": user_id, "home_server": self.hs.hostname} + result: JsonDict = { + "user_id": user_id, + "home_server": self.hs.hostname, + } if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") @@ -814,7 +817,7 @@ class RegisterRestServlet(RestServlet): user_id, device_id, initial_display_name, is_guest=True ) - result = { + result: JsonDict = { "user_id": user_id, "device_id": device_id, "access_token": access_token, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 9b0c546505..bf46dc60f2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -388,7 +388,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = None handler = self.hs.get_room_list_handler() - if server and server != self.hs.config.server_name: + if server and server != self.hs.config.server.server_name: # Ensure the server is valid. try: parse_and_validate_server_name(server) @@ -438,7 +438,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = None handler = self.hs.get_room_list_handler() - if server and server != self.hs.config.server_name: + if server and server != self.hs.config.server.server_name: # Ensure the server is valid. try: parse_and_validate_server_name(server) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index a5fcd15e3a..25f6eb842f 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -86,12 +86,12 @@ class LocalKey(Resource): json_object = { "valid_until_ts": self.valid_until_ts, - "server_name": self.config.server_name, + "server_name": self.config.server.server_name, "verify_keys": verify_keys, "old_verify_keys": old_verify_keys, } for key in self.config.signing_key: - json_object = sign_json(json_object, self.config.server_name, key) + json_object = sign_json(json_object, self.config.server.server_name, key) return json_object def render_GET(self, request): diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 63a40b1852..744360e5fd 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -224,7 +224,9 @@ class RemoteKey(DirectServeJsonResource): for key_json in json_results: key_json = json_decoder.decode(key_json.decode("utf-8")) for signing_key in self.config.key_server_signing_keys: - key_json = sign_json(key_json, self.config.server_name, signing_key) + key_json = sign_json( + key_json, self.config.server.server_name, signing_key + ) signed_keys.append(key_json) diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 3869d18003..67c1ed1f5f 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -52,7 +52,7 @@ class NewUserConsentResource(DirectServeHtmlResource): yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir - self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config) async def _async_render_GET(self, request: Request) -> None: try: diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index c15b83c387..d30b478b98 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -80,7 +80,7 @@ class AccountDetailsResource(DirectServeHtmlResource): yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir - self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config) async def _async_render_GET(self, request: Request) -> None: try: diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 19ac3af337..6a66a88c53 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,10 +34,10 @@ class WellKnownBuilder: def get_well_known(self): # if we don't have a public_baseurl, we can't help much here. - if self._config.public_baseurl is None: + if self._config.server.public_baseurl is None: return None - result = {"m.homeserver": {"base_url": self._config.public_baseurl}} + result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} if self._config.default_identity_server: result["m.identity_server"] = { diff --git a/synapse/server.py b/synapse/server.py index 5adeeff61a..4777ef585d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -313,7 +313,7 @@ class HomeServer(metaclass=abc.ABCMeta): # Register background tasks required by this server. This must be done # somewhat manually due to the background tasks not being registered # unless handlers are instantiated. - if self.config.run_background_tasks: + if self.config.worker.run_background_tasks: self.setup_background_tasks() def start_listening(self) -> None: @@ -370,8 +370,8 @@ class HomeServer(metaclass=abc.ABCMeta): return Ratelimiter( store=self.get_datastore(), clock=self.get_clock(), - rate_hz=self.config.rc_registration.per_second, - burst_count=self.config.rc_registration.burst_count, + rate_hz=self.config.ratelimiting.rc_registration.per_second, + burst_count=self.config.ratelimiting.rc_registration.burst_count, ) @cache_in_self @@ -498,7 +498,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_device_handler(self): - if self.config.worker_app: + if self.config.worker.worker_app: return DeviceWorkerHandler(self) else: return DeviceHandler(self) @@ -621,7 +621,7 @@ class HomeServer(metaclass=abc.ABCMeta): def get_federation_sender(self) -> AbstractFederationSender: if self.should_send_federation(): return FederationSender(self) - elif not self.config.worker_app: + elif not self.config.worker.worker_app: return FederationRemoteSendQueue(self) else: raise Exception("Workers cannot send federation traffic") @@ -650,14 +650,14 @@ class HomeServer(metaclass=abc.ABCMeta): def get_groups_local_handler( self, ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]: - if self.config.worker_app: + if self.config.worker.worker_app: return GroupsLocalWorkerHandler(self) else: return GroupsLocalHandler(self) @cache_in_self def get_groups_server_handler(self): - if self.config.worker_app: + if self.config.worker.worker_app: return GroupsServerWorkerHandler(self) else: return GroupsServerHandler(self) @@ -684,7 +684,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_room_member_handler(self) -> RoomMemberHandler: - if self.config.worker_app: + if self.config.worker.worker_app: return RoomMemberWorkerHandler(self) return RoomMemberMasterHandler(self) @@ -694,13 +694,13 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_server_notices_manager(self) -> ServerNoticesManager: - if self.config.worker_app: + if self.config.worker.worker_app: raise Exception("Workers cannot send server notices") return ServerNoticesManager(self) @cache_in_self def get_server_notices_sender(self) -> WorkerServerNoticesSender: - if self.config.worker_app: + if self.config.worker.worker_app: return WorkerServerNoticesSender(self) return ServerNoticesSender(self) @@ -766,7 +766,9 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_federation_ratelimiter(self) -> FederationRateLimiter: - return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation) + return FederationRateLimiter( + self.get_clock(), config=self.config.ratelimiting.rc_federation + ) @cache_in_self def get_module_api(self) -> ModuleApi: diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 00a644e8f7..1dc347f0c9 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -271,7 +271,7 @@ class DataStore( def get_users_paginate_txn(txn): filters = [] - args = [self.hs.config.server_name] + args = [self.hs.config.server.server_name] # Set ordering order_by_column = UserSortOrder(order_by).value @@ -356,13 +356,13 @@ def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig return user_domain = get_domain_from_id(rows[0][0]) - if user_domain == config.server_name: + if user_domain == config.server.server_name: return raise Exception( "Found users in database not native to %s!\n" "You cannot change a synapse server_name after it's been configured" - % (config.server_name,) + % (config.server.server_name,) ) diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index f22c1f241b..6305414e3d 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -35,7 +35,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase super().__init__(database, db_conn, hs) if ( - hs.config.run_background_tasks + hs.config.worker.run_background_tasks and self.hs.config.redaction_retention_period is not None ): hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index b04867fedf..2712514145 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -355,7 +355,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): self.user_ips_max_age = hs.config.user_ips_max_age - if hs.config.run_background_tasks and self.user_ips_max_age: + if hs.config.worker.run_background_tasks and self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) @wrap_as_background_process("prune_old_user_ips") diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3816a0ca53..6464520386 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -51,7 +51,7 @@ class DeviceWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self._clock.looping_call( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index bddf5ef192..047782eb06 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -62,7 +62,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: hs.get_clock().looping_call( self._delete_old_forward_extrem_cache, 60 * 60 * 1000 ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 55caa6bbe7..97b3e92d3f 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -82,7 +82,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): self._rotate_delay = 3 self._rotate_count = 10000 self._doing_notif_rotation = False - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self._rotate_notif_loop = self._clock.looping_call( self._rotate_notifs, 30 * 60 * 1000 ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 9501f00f3b..d72e716b5c 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -158,7 +158,7 @@ class EventsWorkerStore(SQLBaseStore): db_conn, "events", "stream_ordering", step=-1 ) - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: # We periodically clean out old transaction ID mappings self._clock.looping_call( self._cleanup_old_transaction_ids, diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index dc0bbc56ac..dac3d14da8 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -56,7 +56,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) # Used in _generate_user_daily_visits to keep track of progress diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index a6517962f6..fafadb88fc 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -132,14 +132,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): hs.config.account_validity.account_validity_startup_job_max_delta ) - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self._clock.call_later( 0.0, self._set_expiration_date_when_missing, ) # Create a background job for culling expired 3PID validity tokens - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self._clock.looping_call( self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @@ -1091,6 +1091,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): delta equal to 10% of the validity period. """ now_ms = self._clock.time_msec() + assert self._account_validity_period is not None expiration_ts = now_ms + self._account_validity_period if use_delta: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 6e7312266d..118b390e93 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -815,7 +815,7 @@ class RoomWorkerStore(SQLBaseStore): If it is `None` media will be removed from quarantine """ logger.info("Quarantining media: %s/%s", server_name, media_id) - is_local = server_name == self.config.server_name + is_local = server_name == self.config.server.server_name def _quarantine_media_by_id_txn(txn): local_mxcs = [media_id] if is_local else [] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index c58a4b8690..9beeb96aa9 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -81,7 +81,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.close() if ( - self.hs.config.run_background_tasks + self.hs.config.worker.run_background_tasks and self.hs.config.metrics_flags.known_servers ): self._known_servers_count = 1 @@ -196,6 +196,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) -> Dict[str, ProfileInfo]: """Get a mapping from user ID to profile information for all users in a given room. + The profile information comes directly from this room's `m.room.member` + events, and so may be specific to this room rather than part of a user's + global profile. To avoid privacy leaks, the profile data should only be + revealed to users who are already in this room. + Args: room_id: The ID of the room to retrieve the users of. diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py index 172f27d109..5a97120437 100644 --- a/synapse/storage/databases/main/session.py +++ b/synapse/storage/databases/main/session.py @@ -48,7 +48,7 @@ class SessionStore(SQLBaseStore): super().__init__(database, db_conn, hs) # Create a background job for culling expired sessions. - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000) async def create_session( diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 4245fa1a3c..343d6efc92 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -672,7 +672,7 @@ class StatsStore(StateDeltasStore): def get_users_media_usage_paginate_txn(txn): filters = [] - args = [self.hs.config.server_name] + args = [self.hs.config.server.server_name] if search_term: filters.append("(lmr.user_id LIKE ? OR displayname LIKE ?)") diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 7728d5f102..860146cd1b 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -60,7 +60,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - if hs.config.run_background_tasks: + if hs.config.worker.run_background_tasks: self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) @wrap_as_background_process("cleanup_transactions") diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 65dde67ae9..8aebdc2817 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -196,7 +196,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) users_with_profile = await self.get_users_in_room_with_profiles(room_id) - user_ids = set(users_with_profile) # Update each user in the user directory. for user_id, profile in users_with_profile.items(): @@ -207,7 +206,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): to_insert = set() if is_public: - for user_id in user_ids: + for user_id in users_with_profile: if self.get_if_app_services_interested_in_user(user_id): continue @@ -217,14 +216,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): await self.add_users_in_public_rooms(room_id, to_insert) to_insert.clear() else: - for user_id in user_ids: + for user_id in users_with_profile: if not self.hs.is_mine_id(user_id): continue if self.get_if_app_services_interested_in_user(user_id): continue - for other_user_id in user_ids: + for other_user_id in users_with_profile: if user_id == other_user_id: continue @@ -511,7 +510,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): self._prefer_local_users_in_search = ( hs.config.user_directory_search_prefer_local_users ) - self._server_name = hs.config.server_name + self._server_name = hs.config.server.server_name async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 61392b9639..d4754c904c 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -134,7 +134,7 @@ def prepare_database( # if it's a worker app, refuse to upgrade the database, to avoid multiple # workers doing it at once. if ( - config.worker_app is not None + config.worker.worker_app is not None and version_info.current_version != SCHEMA_VERSION ): raise UpgradeDatabaseException( @@ -154,7 +154,7 @@ def prepare_database( # if it's a worker app, refuse to upgrade the database, to avoid multiple # workers doing it at once. - if config and config.worker_app is not None: + if config and config.worker.worker_app is not None: raise UpgradeDatabaseException(EMPTY_DATABASE_ON_WORKER_ERROR) _setup_new_database(cur, database_engine, databases=databases) @@ -355,7 +355,7 @@ def _upgrade_existing_database( else: assert config - is_worker = config and config.worker_app is not None + is_worker = config and config.worker.worker_app is not None if ( current_schema_state.compat_version is not None diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py index 7f08fabe9f..8a1f340083 100644 --- a/synapse/storage/schema/main/delta/30/as_users.py +++ b/synapse/storage/schema/main/delta/30/as_users.py @@ -38,7 +38,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): logger.warning("Could not get app_service_config_files from config") pass - appservices = load_appservices(config.server_name, config_files) + appservices = load_appservices(config.server.server_name, config_files) owned = {} diff --git a/synapse/storage/schema/main/delta/57/local_current_membership.py b/synapse/storage/schema/main/delta/57/local_current_membership.py index 66989222e6..d25093c19f 100644 --- a/synapse/storage/schema/main/delta/57/local_current_membership.py +++ b/synapse/storage/schema/main/delta/57/local_current_membership.py @@ -67,7 +67,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): INNER JOIN room_memberships AS r USING (event_id) WHERE type = 'm.room.member' AND state_key LIKE ? """ - cur.execute(sql, ("%:" + config.server_name,)) + cur.execute(sql, ("%:" + config.server.server_name,)) cur.execute( "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" diff --git a/synapse/types.py b/synapse/types.py index 80fa903c4b..d4759b2dfd 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -38,6 +38,7 @@ from twisted.internet.interfaces import ( IReactorCore, IReactorPluggableNameResolver, IReactorTCP, + IReactorThreads, IReactorTime, ) @@ -63,7 +64,12 @@ JsonDict = Dict[str, Any] # Note that this seems to require inheriting *directly* from Interface in order # for mypy-zope to realize it is an interface. class ISynapseReactor( - IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface + IReactorTCP, + IReactorPluggableNameResolver, + IReactorTime, + IReactorCore, + IReactorThreads, + Interface, ): """The interfaces necessary for Synapse to function.""" diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b69f562ca5..bd234549bd 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,27 +15,35 @@ import json import logging import re -from typing import Pattern +import typing +from typing import Any, Callable, Dict, Generator, Pattern import attr from frozendict import frozendict from twisted.internet import defer, task +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IDelayedCall, IReactorTime +from twisted.internet.task import LoopingCall +from twisted.python.failure import Failure from synapse.logging import context +if typing.TYPE_CHECKING: + pass + logger = logging.getLogger(__name__) _WILDCARD_RUN = re.compile(r"([\?\*]+)") -def _reject_invalid_json(val): +def _reject_invalid_json(val: Any) -> None: """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) -def _handle_frozendict(obj): +def _handle_frozendict(obj: Any) -> Dict[Any, Any]: """Helper for json_encoder. Makes frozendicts serializable by returning the underlying dict """ @@ -60,10 +68,10 @@ json_encoder = json.JSONEncoder( json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) -def unwrapFirstError(failure): +def unwrapFirstError(failure: Failure) -> Failure: # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) - return failure.value.subFailure + return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations @attr.s(slots=True) @@ -75,25 +83,25 @@ class Clock: reactor: The Twisted reactor to use. """ - _reactor = attr.ib() + _reactor: IReactorTime = attr.ib() - @defer.inlineCallbacks - def sleep(self, seconds): - d = defer.Deferred() + @defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations + def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]": + d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) res = yield d return res - def time(self): + def time(self) -> float: """Returns the current system time in seconds since epoch.""" return self._reactor.seconds() - def time_msec(self): + def time_msec(self) -> int: """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) - def looping_call(self, f, msec, *args, **kwargs): + def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall: """Call a function repeatedly. Waits `msec` initially before calling `f` for the first time. @@ -102,8 +110,8 @@ class Clock: other than trivial, you probably want to wrap it in run_as_background_process. Args: - f(function): The function to call repeatedly. - msec(float): How long to wait between calls in milliseconds. + f: The function to call repeatedly. + msec: How long to wait between calls in milliseconds. *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. """ @@ -113,7 +121,7 @@ class Clock: d.addErrback(log_failure, "Looping call died", consumeErrors=False) return call - def call_later(self, delay, callback, *args, **kwargs): + def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall: """Call something later Note that the function will be called with no logcontext, so if it is anything @@ -133,7 +141,7 @@ class Clock: with context.PreserveLoggingContext(): return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) - def cancel_call_later(self, timer, ignore_errs=False): + def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None: try: timer.cancel() except Exception: diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a3b65aee27..82d918a05f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -37,6 +37,7 @@ import attr from typing_extensions import ContextManager from twisted.internet import defer +from twisted.internet.base import ReactorBase from twisted.internet.defer import CancelledError from twisted.internet.interfaces import IReactorTime from twisted.python import failure @@ -268,6 +269,7 @@ class Linearizer: if not clock: from twisted.internet import reactor + assert isinstance(reactor, ReactorBase) clock = Clock(reactor) self._clock = clock self.max_count = max_count @@ -411,7 +413,7 @@ class ReadWriteLock: # writers and readers have been resolved. The new writer replaces the latest # writer. - def __init__(self): + def __init__(self) -> None: # Latest readers queued self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {} @@ -503,7 +505,7 @@ def timeout_deferred( timed_out = [False] - def time_it_out(): + def time_it_out() -> None: timed_out[0] = True try: @@ -550,19 +552,21 @@ def timeout_deferred( return new_d +# This class can't be generic because it uses slots with attrs. +# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True, frozen=True) -class DoneAwaitable: +class DoneAwaitable: # should be: Generic[R] """Simple awaitable that returns the provided value.""" - value = attr.ib() + value = attr.ib(type=Any) # should be: R def __await__(self): return self - def __iter__(self): + def __iter__(self) -> "DoneAwaitable": return self - def __next__(self): + def __next__(self) -> None: raise StopIteration(self.value) diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 274cea7eb7..2a903004a9 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]): # First we create a defer and add it and the value to the list of # pending items. - d = defer.Deferred() + d: defer.Deferred[R] = defer.Deferred() self._next_values.setdefault(key, []).append((value, d)) # If we're not currently processing the key fire off a background diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 9012034b7a..cab1bf0c15 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -64,32 +64,32 @@ class CacheMetric: evicted_size = attr.ib(default=0) memory_usage = attr.ib(default=None) - def inc_hits(self): + def inc_hits(self) -> None: self.hits += 1 - def inc_misses(self): + def inc_misses(self) -> None: self.misses += 1 - def inc_evictions(self, size=1): + def inc_evictions(self, size: int = 1) -> None: self.evicted_size += size - def inc_memory_usage(self, memory: int): + def inc_memory_usage(self, memory: int) -> None: if self.memory_usage is None: self.memory_usage = 0 self.memory_usage += memory - def dec_memory_usage(self, memory: int): + def dec_memory_usage(self, memory: int) -> None: self.memory_usage -= memory - def clear_memory_usage(self): + def clear_memory_usage(self) -> None: if self.memory_usage is not None: self.memory_usage = 0 def describe(self): return [] - def collect(self): + def collect(self) -> None: try: if self._cache_type == "response_cache": response_cache_size.labels(self._cache_name).set(len(self._cache)) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index b6456392cd..f05590da0d 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]): TreeCache, "MutableMapping[KT, CacheEntry]" ] = cache_type() - def metrics_cb(): + def metrics_cb() -> None: cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) # cache is used for completed results and maps to the result itself, rather than @@ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]): def max_entries(self): return self.cache.max_size - def check_thread(self): + def check_thread(self) -> None: expected_thread = self.thread if expected_thread is None: self.thread = threading.current_thread() @@ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]): self._pending_deferred_cache[key] = entry - def compare_and_pop(): + def compare_and_pop() -> bool: """Check if our entry is still the one in _pending_deferred_cache, and if so, pop it. @@ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]): return False - def cb(result): + def cb(result) -> None: if compare_and_pop(): self.cache.set(key, result, entry.callbacks) else: @@ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]): # not have been. Either way, let's double-check now. entry.invalidate() - def eb(_fail): + def eb(_fail) -> None: compare_and_pop() entry.invalidate() @@ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]): for entry in iterate_tree_cache_entry(entry): entry.invalidate() - def invalidate_all(self): + def invalidate_all(self) -> None: self.check_thread() self.cache.clear() for entry in self._pending_deferred_cache.values(): @@ -332,7 +332,7 @@ class CacheEntry: self.callbacks = set(callbacks) self.invalidated = False - def invalidate(self): + def invalidate(self) -> None: if not self.invalidated: self.invalidated = True for callback in self.callbacks: diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 3f852edd7f..ade088aae2 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -27,10 +27,14 @@ logger = logging.getLogger(__name__) KT = TypeVar("KT") # The type of the dictionary keys. DKT = TypeVar("DKT") +# The type of the dictionary values. +DV = TypeVar("DV") +# This class can't be generic because it uses slots with attrs. +# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True) -class DictionaryEntry: +class DictionaryEntry: # should be: Generic[DKT, DV]. """Returned when getting an entry from the cache Attributes: @@ -43,10 +47,10 @@ class DictionaryEntry: """ full = attr.ib(type=bool) - known_absent = attr.ib() - value = attr.ib() + known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT] + value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV] - def __len__(self): + def __len__(self) -> int: return len(self.value) @@ -56,7 +60,7 @@ class _Sentinel(enum.Enum): sentinel = object() -class DictionaryCache(Generic[KT, DKT]): +class DictionaryCache(Generic[KT, DKT, DV]): """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ @@ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]): Args: key - dict_key: If given a set of keys then return only those keys + dict_keys: If given a set of keys then return only those keys that exist in the cache. Returns: @@ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]): self, sequence: int, key: KT, - value: Dict[DKT, Any], + value: Dict[DKT, DV], fetched_keys: Optional[Set[DKT]] = None, ) -> None: """Updates the entry in the cache @@ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]): self._update_or_insert(key, value, fetched_keys) def _update_or_insert( - self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT] + self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT] ) -> None: # We pop and reinsert as we need to tell the cache the size may have # changed - entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) + entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {})) entry.value.update(value) entry.known_absent.update(known_absent) self.cache[key] = entry - def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None: + def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None: self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 5c65d187b6..39dce9dd41 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -35,6 +35,7 @@ from typing import ( from typing_extensions import Literal from twisted.internet import reactor +from twisted.internet.interfaces import IReactorTime from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]): # Default `clock` to something sensible. Note that we rename it to # `real_clock` so that mypy doesn't think its still `Optional`. if clock is None: - real_clock = Clock(reactor) + real_clock = Clock(cast(IReactorTime, reactor)) else: real_clock = clock @@ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]): lock = threading.Lock() - def evict(): + def evict() -> None: while cache_len() > self.max_size: # Get the last node in the list (i.e. the oldest node). todelete = list_root.prev_node diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 3a41a8baa6..27b1da235e 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -195,7 +195,7 @@ class StreamChangeCache: for entity in r: del self._entity_to_key[entity] - def _evict(self): + def _evict(self) -> None: while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 4138931e7b..563845f867 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -35,17 +35,17 @@ class TreeCache: root = {key_1: {key_2: _value}} """ - def __init__(self): - self.size = 0 + def __init__(self) -> None: + self.size: int = 0 self.root = TreeCacheNode() - def __setitem__(self, key, value): - return self.set(key, value) + def __setitem__(self, key, value) -> None: + self.set(key, value) - def __contains__(self, key): + def __contains__(self, key) -> bool: return self.get(key, SENTINEL) is not SENTINEL - def set(self, key, value): + def set(self, key, value) -> None: if isinstance(value, TreeCacheNode): # this would mean we couldn't tell where our tree ended and the value # started. @@ -73,7 +73,7 @@ class TreeCache: return default return node.get(key[-1], default) - def clear(self): + def clear(self) -> None: self.size = 0 self.root = TreeCacheNode() @@ -128,7 +128,7 @@ class TreeCache: def values(self): return iterate_tree_cache_entry(self.root) - def __len__(self): + def __len__(self) -> int: return self.size diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index d8532411c2..f1a351cfd4 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - signal.signal(signal.SIGTERM, sigterm) # Cleanup pid file at exit. - def exit(): + def exit() -> None: logger.warning("Stopping daemon.") os.remove(pid_file) sys.exit(0) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 1f803aef6d..31097d6439 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Any, Callable, Dict, List from twisted.internet import defer @@ -37,11 +38,11 @@ class Distributor: model will do for today. """ - def __init__(self): - self.signals = {} - self.pre_registration = {} + def __init__(self) -> None: + self.signals: Dict[str, Signal] = {} + self.pre_registration: Dict[str, List[Callable]] = {} - def declare(self, name): + def declare(self, name: str) -> None: if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) @@ -52,7 +53,7 @@ class Distributor: for observer in self.pre_registration[name]: signal.observe(observer) - def observe(self, name, observer): + def observe(self, name: str, observer: Callable) -> None: if name in self.signals: self.signals[name].observe(observer) else: @@ -62,7 +63,7 @@ class Distributor: self.pre_registration[name] = [] self.pre_registration[name].append(observer) - def fire(self, name, *args, **kwargs): + def fire(self, name: str, *args, **kwargs) -> None: """Dispatches the given signal to the registered observers. Runs the observers as a background process. Does not return a deferred. @@ -83,18 +84,18 @@ class Signal: method into all of the observers. """ - def __init__(self, name): - self.name = name - self.observers = [] + def __init__(self, name: str): + self.name: str = name + self.observers: List[Callable] = [] - def observe(self, observer): + def observe(self, observer: Callable) -> None: """Adds a new callable to the observer list which will be invoked by the 'fire' method. Each observer callable may return a Deferred.""" self.observers.append(observer) - def fire(self, *args, **kwargs): + def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]": """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is not an error to fire a signal with no observers. diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index e946189f9a..de2adacd70 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,10 +13,14 @@ # limitations under the License. import queue +from typing import BinaryIO, Optional, Union, cast from twisted.internet import threads +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IPullProducer, IPushProducer from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import ISynapseReactor class BackgroundFileConsumer: @@ -24,9 +28,9 @@ class BackgroundFileConsumer: and pull producers Args: - file_obj (file): The file like object to write to. Closed when + file_obj: The file like object to write to. Closed when finished. - reactor (twisted.internet.reactor): the Twisted reactor to use + reactor: the Twisted reactor to use """ # For PushProducers pause if we have this many unwritten slices @@ -34,13 +38,13 @@ class BackgroundFileConsumer: # And resume once the size of the queue is less than this _RESUME_ON_QUEUE_SIZE = 2 - def __init__(self, file_obj, reactor): - self._file_obj = file_obj + def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None: + self._file_obj: BinaryIO = file_obj - self._reactor = reactor + self._reactor: ISynapseReactor = reactor # Producer we're registered with - self._producer = None + self._producer: Optional[Union[IPushProducer, IPullProducer]] = None # True if PushProducer, false if PullProducer self.streaming = False @@ -51,20 +55,22 @@ class BackgroundFileConsumer: # Queue of slices of bytes to be written. When producer calls # unregister a final None is sent. - self._bytes_queue = queue.Queue() + self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() # Deferred that is resolved when finished writing - self._finished_deferred = None + self._finished_deferred: Optional[Deferred[None]] = None # If the _writer thread throws an exception it gets stored here. - self._write_exception = None + self._write_exception: Optional[Exception] = None - def registerProducer(self, producer, streaming): + def registerProducer( + self, producer: Union[IPushProducer, IPullProducer], streaming: bool + ) -> None: """Part of IConsumer interface Args: - producer (IProducer) - streaming (bool): True if push based producer, False if pull + producer + streaming: True if push based producer, False if pull based. """ if self._producer: @@ -81,29 +87,33 @@ class BackgroundFileConsumer: if not streaming: self._producer.resumeProducing() - def unregisterProducer(self): + def unregisterProducer(self) -> None: """Part of IProducer interface""" self._producer = None + assert self._finished_deferred is not None if not self._finished_deferred.called: self._bytes_queue.put_nowait(None) - def write(self, bytes): + def write(self, write_bytes: bytes) -> None: """Part of IProducer interface""" if self._write_exception: raise self._write_exception + assert self._finished_deferred is not None if self._finished_deferred.called: raise Exception("consumer has closed") - self._bytes_queue.put_nowait(bytes) + self._bytes_queue.put_nowait(write_bytes) # If this is a PushProducer and the queue is getting behind # then we pause the producer. if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: self._paused_producer = True - self._producer.pauseProducing() + assert self._producer is not None + # cast safe because `streaming` means this is an IPushProducer + cast(IPushProducer, self._producer).pauseProducing() - def _writer(self): + def _writer(self) -> None: """This is run in a background thread to write to the file.""" try: while self._producer or not self._bytes_queue.empty(): @@ -130,11 +140,11 @@ class BackgroundFileConsumer: finally: self._file_obj.close() - def wait(self): + def wait(self) -> "Deferred[None]": """Returns a deferred that resolves when finished writing to file""" return make_deferred_yieldable(self._finished_deferred) - def _resume_paused_producer(self): + def _resume_paused_producer(self) -> None: """Gets called if we should resume producing after being paused""" if self._paused_producer and self._producer: self._paused_producer = False diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 2ac7c2913c..9c405eb4d7 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from frozendict import frozendict -def freeze(o): +def freeze(o: Any) -> Any: if isinstance(o, dict): return frozendict({k: freeze(v) for k, v in o.items()}) @@ -33,7 +34,7 @@ def freeze(o): return o -def unfreeze(o): +def unfreeze(o: Any) -> Any: if isinstance(o, (dict, frozendict)): return {k: unfreeze(v) for k, v in o.items()} diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 3c0e8469f3..b163643ca3 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -13,42 +13,43 @@ # limitations under the License. import logging +from typing import Dict -from twisted.web.resource import NoResource +from twisted.web.resource import NoResource, Resource logger = logging.getLogger(__name__) -def create_resource_tree(desired_tree, root_resource): +def create_resource_tree( + desired_tree: Dict[str, Resource], root_resource: Resource +) -> Resource: """Create the resource tree for this homeserver. This in unduly complicated because Twisted does not support putting child resources more than 1 level deep at a time. Args: - web_client (bool): True to enable the web client. - root_resource (twisted.web.resource.Resource): The root - resource to add the tree to. + desired_tree: Dict from desired paths to desired resources. + root_resource: The root resource to add the tree to. Returns: - twisted.web.resource.Resource: the ``root_resource`` with a tree of - child resources added to it. + The ``root_resource`` with a tree of child resources added to it. """ # ideally we'd just use getChild and putChild but getChild doesn't work # unless you give it a Request object IN ADDITION to the name :/ So # instead, we'll store a copy of this mapping so we can actually add # extra resources to existing nodes. See self._resource_id for the key. - resource_mappings = {} - for full_path, res in desired_tree.items(): + resource_mappings: Dict[str, Resource] = {} + for full_path_str, res in desired_tree.items(): # twisted requires all resources to be bytes - full_path = full_path.encode("utf-8") + full_path = full_path_str.encode("utf-8") logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource = NoResource() + child_resource: Resource = NoResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource @@ -83,7 +84,7 @@ def create_resource_tree(desired_tree, root_resource): return root_resource -def _resource_id(resource, path_seg): +def _resource_id(resource: Resource, path_seg: bytes) -> str: """Construct an arbitrary resource ID so you can retrieve the mapping later. @@ -96,4 +97,4 @@ def _resource_id(resource, path_seg): Returns: str: A unique string which can be a key to the child Resource. """ - return "%s-%s" % (resource, path_seg) + return "%s-%r" % (resource, path_seg) diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py index a456b136f0..9f4be757ba 100644 --- a/synapse/util/linked_list.py +++ b/synapse/util/linked_list.py @@ -74,7 +74,7 @@ class ListNode(Generic[P]): new_node._refs_insert_after(node) return new_node - def remove_from_list(self): + def remove_from_list(self) -> None: """Remove this node from the list.""" with self._LOCK: self._refs_remove_node_from_list() @@ -84,7 +84,7 @@ class ListNode(Generic[P]): # immediately rather than at the next GC. self.cache_entry = None - def move_after(self, node: "ListNode"): + def move_after(self, node: "ListNode") -> None: """Move this node from its current location in the list to after the given node. """ @@ -103,7 +103,7 @@ class ListNode(Generic[P]): # Insert self back into the list, after target node self._refs_insert_after(node) - def _refs_remove_node_from_list(self): + def _refs_remove_node_from_list(self) -> None: """Internal method to *just* remove the node from the list, without e.g. clearing out the cache entry. """ @@ -122,7 +122,7 @@ class ListNode(Generic[P]): self.prev_node = None self.next_node = None - def _refs_insert_after(self, node: "ListNode"): + def _refs_insert_after(self, node: "ListNode") -> None: """Internal method to insert the node after the given node.""" # This method should only be called when we're not already in the list. diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index d1f76e3dc5..84e4f6ff55 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -77,7 +77,7 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N should be considered expired. Normally the current time. """ - def verify_expiry_caveat(caveat: str): + def verify_expiry_caveat(caveat: str) -> bool: time_msec = get_time_ms() prefix = "time < " if not caveat.startswith(prefix): diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index cfb5b94ca9..f8b2d7bea9 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -15,6 +15,7 @@ import inspect import sys import traceback +from typing import Any, Dict, Optional from twisted.conch import manhole_ssh from twisted.conch.insults import insults @@ -22,6 +23,9 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal from twisted.internet import defer +from twisted.internet.protocol import Factory + +from synapse.config.server import ManholeConfig PUBLIC_KEY = ( "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" @@ -61,22 +65,22 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= -----END RSA PRIVATE KEY-----""" -def manhole(settings, globals): +def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: """Starts a ssh listener with password authentication using the given username and password. Clients connecting to the ssh listener will find themselves in a colored python shell with the supplied globals. Args: - username(str): The username ssh clients should auth with. - password(str): The password ssh clients should auth with. - globals(dict): The variables to expose in the shell. + username: The username ssh clients should auth with. + password: The password ssh clients should auth with. + globals: The variables to expose in the shell. Returns: - twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` + A factory to pass to ``listenTCP`` """ username = settings.username - password = settings.password + password = settings.password.encode("ascii") priv_key = settings.priv_key if priv_key is None: priv_key = Key.fromString(PRIVATE_KEY) @@ -84,19 +88,22 @@ def manhole(settings, globals): if pub_key is None: pub_key = Key.fromString(PUBLIC_KEY) - if not isinstance(password, bytes): - password = password.encode("ascii") - checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) rlm = manhole_ssh.TerminalRealm() - rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( + # mypy ignored here because: + # - can't deduce types of lambdas + # - variable is Type[ServerProtocol], expr is Callable[[], ServerProtocol] + rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( # type: ignore[misc,assignment] SynapseManhole, dict(globals, __name__="__console__") ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.privateKeys[b"ssh-rsa"] = priv_key - factory.publicKeys[b"ssh-rsa"] = pub_key + + # conch has the wrong type on these dicts (says bytes to bytes, + # should be bytes to Keys judging by how it's used). + factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment] + factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment] return factory @@ -104,7 +111,7 @@ def manhole(settings, globals): class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" - def connectionMade(self): + def connectionMade(self) -> None: super().connectionMade() # replace the manhole interpreter with our own impl @@ -114,13 +121,14 @@ class SynapseManhole(ColoredManhole): class SynapseManholeInterpreter(ManholeInterpreter): - def showsyntaxerror(self, filename=None): + def showsyntaxerror(self, filename: Optional[str] = None) -> None: """Display the syntax error that just occurred. Overrides the base implementation, ignoring sys.excepthook. We always want any syntax errors to be sent to the terminal, rather than sentry. """ type, value, tb = sys.exc_info() + assert value is not None sys.last_type = type sys.last_value = value sys.last_traceback = tb @@ -138,7 +146,7 @@ class SynapseManholeInterpreter(ManholeInterpreter): lines = traceback.format_exception_only(type, value) self.write("".join(lines)) - def showtraceback(self): + def showtraceback(self) -> None: """Display the exception that just occurred. Overrides the base implementation, ignoring sys.excepthook. We always want @@ -146,14 +154,22 @@ class SynapseManholeInterpreter(ManholeInterpreter): """ sys.last_type, sys.last_value, last_tb = ei = sys.exc_info() sys.last_traceback = last_tb + assert last_tb is not None + try: # We remove the first stack item because it is our own code. lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) self.write("".join(lines)) finally: - last_tb = ei = None + # On the line below, last_tb and ei appear to be dead. + # It's unclear whether there is a reason behind this line. + # It conceivably could be because an exception raised in this block + # will keep the local frame (containing these local variables) around. + # This was adapted taken from CPython's Lib/code.py; see here: + # https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150 + last_tb = ei = None # type: ignore - def displayhook(self, obj): + def displayhook(self, obj: Any) -> None: """ We override the displayhook so that we automatically convert coroutines into Deferreds. (Our superclass' displayhook will take care of the rest, diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 99f01e325c..9dd010af3b 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -24,7 +24,7 @@ from twisted.python.failure import Failure _already_patched = False -def do_patch(): +def do_patch() -> None: """ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit """ @@ -107,7 +107,7 @@ def do_patch(): _already_patched = True -def _check_yield_points(f: Callable, changes: List[str]): +def _check_yield_points(f: Callable, changes: List[str]) -> Callable: """Wraps a generator that is about to be passed to defer.inlineCallbacks checking that after every yield the log contexts are correct. diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index a654c69684..dfe628c97e 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -15,33 +15,36 @@ import collections import contextlib import logging +import typing +from typing import Any, DefaultDict, Iterator, List, Set from twisted.internet import defer from synapse.api.errors import LimitExceededError +from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background, ) +from synapse.util import Clock + +if typing.TYPE_CHECKING: + from contextlib import _GeneratorContextManager logger = logging.getLogger(__name__) class FederationRateLimiter: - def __init__(self, clock, config): - """ - Args: - clock (Clock) - config (FederationRateLimitConfig) - """ - - def new_limiter(): + def __init__(self, clock: Clock, config: FederationRateLimitConfig): + def new_limiter() -> "_PerHostRatelimiter": return _PerHostRatelimiter(clock=clock, config=config) - self.ratelimiters = collections.defaultdict(new_limiter) + self.ratelimiters: DefaultDict[ + str, "_PerHostRatelimiter" + ] = collections.defaultdict(new_limiter) - def ratelimit(self, host): + def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]": """Used to ratelimit an incoming request from a given host Example usage: @@ -60,11 +63,11 @@ class FederationRateLimiter: class _PerHostRatelimiter: - def __init__(self, clock, config): + def __init__(self, clock: Clock, config: FederationRateLimitConfig): """ Args: - clock (Clock) - config (FederationRateLimitConfig) + clock + config """ self.clock = clock @@ -75,21 +78,23 @@ class _PerHostRatelimiter: self.concurrent_requests = config.concurrent # request_id objects for requests which have been slept - self.sleeping_requests = set() + self.sleeping_requests: Set[object] = set() # map from request_id object to Deferred for requests which are ready # for processing but have been queued - self.ready_request_queue = collections.OrderedDict() + self.ready_request_queue: collections.OrderedDict[ + object, defer.Deferred[None] + ] = collections.OrderedDict() # request id objects for requests which are in progress - self.current_processing = set() + self.current_processing: Set[object] = set() # times at which we have recently (within the last window_size ms) # received requests. - self.request_times = [] + self.request_times: List[int] = [] @contextlib.contextmanager - def ratelimit(self): + def ratelimit(self) -> "Iterator[defer.Deferred[None]]": # `contextlib.contextmanager` takes a generator and turns it into a # context manager. The generator should only yield once with a value # to be returned by manager. @@ -102,7 +107,7 @@ class _PerHostRatelimiter: finally: self._on_exit(request_id) - def _on_enter(self, request_id): + def _on_enter(self, request_id: object) -> "defer.Deferred[None]": time_now = self.clock.time_msec() # remove any entries from request_times which aren't within the window @@ -120,9 +125,9 @@ class _PerHostRatelimiter: self.request_times.append(time_now) - def queue_request(): + def queue_request() -> "defer.Deferred[None]": if len(self.current_processing) >= self.concurrent_requests: - queue_defer = defer.Deferred() + queue_defer: defer.Deferred[None] = defer.Deferred() self.ready_request_queue[request_id] = queue_defer logger.info( "Ratelimiter: queueing request (queue now %i items)", @@ -145,7 +150,7 @@ class _PerHostRatelimiter: self.sleeping_requests.add(request_id) - def on_wait_finished(_): + def on_wait_finished(_: Any) -> "defer.Deferred[None]": logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) self.sleeping_requests.discard(request_id) queue_defer = queue_request() @@ -155,19 +160,19 @@ class _PerHostRatelimiter: else: ret_defer = queue_request() - def on_start(r): + def on_start(r: object) -> object: logger.debug("Ratelimit [%s]: Processing req", id(request_id)) self.current_processing.add(request_id) return r - def on_err(r): + def on_err(r: object) -> object: # XXX: why is this necessary? this is called before we start # processing the request so why would the request be in # current_processing? self.current_processing.discard(request_id) return r - def on_both(r): + def on_both(r: object) -> object: # Ensure that we've properly cleaned up. self.sleeping_requests.discard(request_id) self.ready_request_queue.pop(request_id, None) @@ -177,7 +182,7 @@ class _PerHostRatelimiter: ret_defer.addBoth(on_both) return make_deferred_yieldable(ret_defer) - def _on_exit(self, request_id): + def _on_exit(self, request_id: object) -> None: logger.debug("Ratelimit [%s]: Processed req", id(request_id)) self.current_processing.discard(request_id) try: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 129b47cd49..648d9a95a7 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -13,9 +13,13 @@ # limitations under the License. import logging import random +from types import TracebackType +from typing import Any, Optional, Type import synapse.logging.context from synapse.api.errors import CodeMessageException +from synapse.storage import DataStore +from synapse.util import Clock logger = logging.getLogger(__name__) @@ -30,17 +34,17 @@ MAX_RETRY_INTERVAL = 2 ** 62 class NotRetryingDestination(Exception): - def __init__(self, retry_last_ts, retry_interval, destination): + def __init__(self, retry_last_ts: int, retry_interval: int, destination: str): """Raised by the limiter (and federation client) to indicate that we are are deliberately not attempting to contact a given server. Args: - retry_last_ts (int): the unix ts in milliseconds of our last attempt + retry_last_ts: the unix ts in milliseconds of our last attempt to contact the server. 0 indicates that the last attempt was successful or that we've never actually attempted to connect. - retry_interval (int): the time in milliseconds to wait until the next + retry_interval: the time in milliseconds to wait until the next attempt. - destination (str): the domain in question + destination: the domain in question """ msg = "Not retrying server %s." % (destination,) @@ -51,7 +55,13 @@ class NotRetryingDestination(Exception): self.destination = destination -async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): +async def get_retry_limiter( + destination: str, + clock: Clock, + store: DataStore, + ignore_backoff: bool = False, + **kwargs: Any, +) -> "RetryDestinationLimiter": """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. If we are not ready to retry the destination, this will raise a @@ -60,10 +70,10 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k CodeMessageException with code < 500) Args: - destination (str): name of homeserver - clock (synapse.util.clock): timing source - store (synapse.storage.transactions.TransactionStore): datastore - ignore_backoff (bool): true to ignore the historical backoff data and + destination: name of homeserver + clock: timing source + store: datastore + ignore_backoff: true to ignore the historical backoff data and try the request anyway. We will still reset the retry_interval on success. Example usage: @@ -114,13 +124,13 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k class RetryDestinationLimiter: def __init__( self, - destination, - clock, - store, - failure_ts, - retry_interval, - backoff_on_404=False, - backoff_on_failure=True, + destination: str, + clock: Clock, + store: DataStore, + failure_ts: Optional[int], + retry_interval: int, + backoff_on_404: bool = False, + backoff_on_failure: bool = True, ): """Marks the destination as "down" if an exception is thrown in the context, except for CodeMessageException with code < 500. @@ -128,17 +138,17 @@ class RetryDestinationLimiter: If no exception is raised, marks the destination as "up". Args: - destination (str) - clock (Clock) - store (DataStore) - failure_ts (int|None): when this destination started failing (in ms since + destination + clock + store + failure_ts: when this destination started failing (in ms since the epoch), or zero if the last request was successful - retry_interval (int): The next retry interval taken from the + retry_interval: The next retry interval taken from the database in milliseconds, or zero if the last request was successful. - backoff_on_404 (bool): Back off if we get a 404 + backoff_on_404: Back off if we get a 404 - backoff_on_failure (bool): set to False if we should not increase the + backoff_on_failure: set to False if we should not increase the retry interval on a failure. """ self.clock = clock @@ -150,10 +160,15 @@ class RetryDestinationLimiter: self.backoff_on_404 = backoff_on_404 self.backoff_on_failure = backoff_on_failure - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: valid_err_code = False if exc_type is None: valid_err_code = True @@ -161,7 +176,7 @@ class RetryDestinationLimiter: # avoid treating exceptions which don't derive from Exception as # failures; this is mostly so as not to catch defer._DefGen. valid_err_code = True - elif issubclass(exc_type, CodeMessageException): + elif isinstance(exc_val, CodeMessageException): # Some error codes are perfectly fine for some APIs, whereas other # APIs may expect to never received e.g. a 404. It's important to # handle 404 as some remote servers will return a 404 when the HS @@ -216,7 +231,7 @@ class RetryDestinationLimiter: if self.failure_ts is None: self.failure_ts = retry_last_ts - async def store_retry_timings(): + async def store_retry_timings() -> None: try: await self.store.set_destination_retry_timings( self.destination, diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index bf812ab516..06651e956d 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -18,7 +18,7 @@ import resource logger = logging.getLogger("synapse.app.homeserver") -def change_resource_limit(soft_file_no): +def change_resource_limit(soft_file_no: int) -> None: try: soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 38543dd1ea..12941065ca 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -16,7 +16,7 @@ import time import urllib.parse -from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union import jinja2 @@ -25,9 +25,9 @@ if TYPE_CHECKING: def build_jinja_env( - template_search_directories: Iterable[str], + template_search_directories: Sequence[str], config: "HomeServerConfig", - autoescape: Union[bool, Callable[[str], bool], None] = None, + autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None, ) -> jinja2.Environment: """Set up a Jinja2 environment to load templates from the given search path @@ -63,12 +63,12 @@ def build_jinja_env( env.filters.update( { "format_ts": _format_ts_filter, - "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl), + "mxc_to_http": _create_mxc_to_http_filter(config.server.public_baseurl), } ) # common variables for all templates - env.globals.update({"server_name": config.server_name}) + env.globals.update({"server_name": config.server.server_name}) return env @@ -110,5 +110,5 @@ def _create_mxc_to_http_filter( return mxc_to_http_filter -def _format_ts_filter(value: int, format: str): +def _format_ts_filter(value: int, format: str) -> str: return time.strftime(format, time.localtime(value / 1000)) diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index a1cf1960b0..baa9190a9a 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -14,6 +14,10 @@ import logging import re +import typing + +if typing.TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -28,13 +32,13 @@ logger = logging.getLogger(__name__) MAX_EMAIL_ADDRESS_LENGTH = 500 -def check_3pid_allowed(hs, medium, address): +def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: """Checks whether a given format of 3PID is allowed to be used on this HS Args: - hs (synapse.server.HomeServer): server - medium (str): 3pid medium - e.g. email, msisdn - address (str): address within that medium (e.g. "wotan@matrix.org") + hs: server + medium: 3pid medium - e.g. email, msisdn + address: address within that medium (e.g. "wotan@matrix.org") msisdns need to first have been canonicalised Returns: bool: whether the 3PID medium/address is allowed to be added to this HS diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index cb08af7385..1c20b24bbe 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -19,7 +19,7 @@ import subprocess logger = logging.getLogger(__name__) -def get_version_string(module): +def get_version_string(module) -> str: """Given a module calculate a git-aware version string for it. If called on a module not in a git checkout will return `__verison__`. diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 61814aff24..e108adc460 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -11,38 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Generic, List, TypeVar + +T = TypeVar("T") -class _Entry: +class _Entry(Generic[T]): __slots__ = ["end_key", "queue"] - def __init__(self, end_key): - self.end_key = end_key - self.queue = [] + def __init__(self, end_key: int) -> None: + self.end_key: int = end_key + self.queue: List[T] = [] -class WheelTimer: +class WheelTimer(Generic[T]): """Stores arbitrary objects that will be returned after their timers have expired. """ - def __init__(self, bucket_size=5000): + def __init__(self, bucket_size: int = 5000) -> None: """ Args: - bucket_size (int): Size of buckets in ms. Corresponds roughly to the + bucket_size: Size of buckets in ms. Corresponds roughly to the accuracy of the timer. """ - self.bucket_size = bucket_size - self.entries = [] - self.current_tick = 0 + self.bucket_size: int = bucket_size + self.entries: List[_Entry[T]] = [] + self.current_tick: int = 0 - def insert(self, now, obj, then): + def insert(self, now: int, obj: T, then: int) -> None: """Inserts object into timer. Args: - now (int): Current time in msec - obj (object): Object to be inserted - then (int): When to return the object strictly after. + now: Current time in msec + obj: Object to be inserted + then: When to return the object strictly after. """ then_key = int(then / self.bucket_size) + 1 @@ -70,7 +73,7 @@ class WheelTimer: self.entries[-1].queue.append(obj) - def fetch(self, now): + def fetch(self, now: int) -> List[T]: """Fetch any objects that have timed out Args: @@ -87,5 +90,5 @@ class WheelTimer: return ret - def __len__(self): + def __len__(self) -> int: return sum(len(entry.queue) for entry in self.entries) diff --git a/sytest-blacklist b/sytest-blacklist index de9986357b..65bf1774e3 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,5 +1,5 @@ # This file serves as a blacklist for SyTest tests that we expect will fail in -# Synapse. +# Synapse. This doesn't include flakey tests---better to deflake them instead. # # Each line of this file is scanned by sytest during a run and if the line # exactly matches the name of a test, it will be marked as "expected fail", @@ -9,9 +9,6 @@ # Test names are encouraged to have a bug accompanied with them, serving as an # explanation for why the test has been excluded. -# Blacklisted due to https://github.com/matrix-org/synapse/issues/1679 -Remote room members also see posted message events - # Blacklisted due to https://github.com/matrix-org/synapse/issues/2065 Guest users can accept invites to private rooms over federation @@ -24,12 +21,6 @@ Newly created users see their own presence in /initialSync (SYT-34) # Blacklisted due to https://github.com/matrix-org/synapse/issues/1396 Should reject keys claiming to belong to a different user -# Blacklisted due to https://github.com/matrix-org/synapse/issues/1531 -Enabling an unknown default rule fails with 404 - -# Blacklisted due to https://github.com/matrix-org/synapse/issues/1663 -New federated private chats get full presence information (SYN-115) - # Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing # this requirement from the spec Inbound federation of state requires event_id as a mandatory paramater diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index a91d31ce61..ae88ed89aa 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -94,7 +94,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # deactivate user self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) - self.get_success(self.handler.handle_user_deactivated(r_user_id)) + self.get_success(self.handler.handle_local_user_deactivated(r_user_id)) # profile is not in directory profile = self.get_success(self.store.get_user_in_directory(r_user_id)) @@ -118,7 +118,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) - self.get_success(self.handler.handle_user_deactivated(s_user_id)) + self.get_success(self.handler.handle_local_user_deactivated(s_user_id)) self.store.remove_from_user_dir.not_called() def test_handle_user_deactivated_regular_user(self): @@ -127,7 +127,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store.register_user(user_id=r_user_id, password_hash=None) ) self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) - self.get_success(self.handler.handle_user_deactivated(r_user_id)) + self.get_success(self.handler.handle_local_user_deactivated(r_user_id)) self.store.remove_from_user_dir.called_once_with(r_user_id) def test_private_room(self): diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index 72f976d8e2..a42388b26f 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -13,9 +13,11 @@ # limitations under the License. from typing import Optional +from synapse.api.constants import EventContentFields, EventTypes, RoomTypes from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin from synapse.rest.client import login, room, room_upgrade_rest_servlet +from synapse.server import HomeServer from tests import unittest from tests.server import FakeChannel @@ -29,9 +31,8 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): room_upgrade_rest_servlet.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor, clock, hs: "HomeServer"): self.store = hs.get_datastore() - self.handler = hs.get_user_directory_handler() self.creator = self.register_user("creator", "pass") self.creator_token = self.login(self.creator, "pass") @@ -42,13 +43,18 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token) self.helper.join(self.room_id, self.other, tok=self.other_token) - def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel: + def _upgrade_room( + self, token: Optional[str] = None, room_id: Optional[str] = None + ) -> FakeChannel: # We never want a cached response. self.reactor.advance(5 * 60 + 1) + if room_id is None: + room_id = self.room_id + return self.make_request( "POST", - "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id, + f"/_matrix/client/r0/rooms/{room_id}/upgrade", # This will upgrade a room to the same version, but that's fine. content={"new_version": DEFAULT_ROOM_VERSION}, access_token=token or self.creator_token, @@ -157,3 +163,56 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): tok=self.creator_token, ) self.assertNotIn(self.other, power_levels["users"]) + + def test_space(self): + """Test upgrading a space.""" + + # Create a space. + space_id = self.helper.create_room_as( + self.creator, + tok=self.creator_token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + + # Add the room as a child room. + self.helper.send_state( + space_id, + event_type=EventTypes.SpaceChild, + body={"via": [self.hs.hostname]}, + tok=self.creator_token, + state_key=self.room_id, + ) + + # Also add a room that was removed. + old_room_id = "!notaroom:" + self.hs.hostname + self.helper.send_state( + space_id, + event_type=EventTypes.SpaceChild, + body={}, + tok=self.creator_token, + state_key=old_room_id, + ) + + # Upgrade the room! + channel = self._upgrade_room(room_id=space_id) + self.assertEquals(200, channel.code, channel.result) + self.assertIn("replacement_room", channel.json_body) + + new_space_id = channel.json_body["replacement_room"] + + state_ids = self.get_success(self.store.get_current_state_ids(new_space_id)) + + # Ensure the new room is still a space. + create_event = self.get_success( + self.store.get_event(state_ids[(EventTypes.Create, "")]) + ) + self.assertEqual( + create_event.content.get(EventContentFields.ROOM_TYPE), RoomTypes.SPACE + ) + + # The child link should have been copied over. + self.assertIn((EventTypes.SpaceChild, self.room_id), state_ids) + # The child that was removed should not be copied over. + self.assertNotIn((EventTypes.SpaceChild, old_room_id), state_ids) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index ac0e427752..b2c0279ba0 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -23,10 +23,13 @@ class WellKnownTests(unittest.HomeserverTestCase): # replace the JsonResource with a WellKnownResource return WellKnownResource(self.hs) + @unittest.override_config( + { + "public_baseurl": "https://tesths", + "default_identity_server": "https://testis", + } + ) def test_well_known(self): - self.hs.config.public_baseurl = "https://tesths" - self.hs.config.default_identity_server = "https://testis" - channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) @@ -35,14 +38,17 @@ class WellKnownTests(unittest.HomeserverTestCase): self.assertEqual( channel.json_body, { - "m.homeserver": {"base_url": "https://tesths"}, + "m.homeserver": {"base_url": "https://tesths/"}, "m.identity_server": {"base_url": "https://testis"}, }, ) + @unittest.override_config( + { + "public_baseurl": None, + } + ) def test_well_known_no_public_baseurl(self): - self.hs.config.public_baseurl = None - channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) diff --git a/tests/unittest.py b/tests/unittest.py index f2c90cc47b..7a6f5954d0 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -734,9 +734,9 @@ class TestTransportLayerServer(JsonResource): FederationRateLimitConfig( window_size=1, sleep_limit=1, - sleep_msec=1, + sleep_delay=1, reject_limit=1000, - concurrent_requests=1000, + concurrent=1000, ), )