Merge branch 'release-v1.43' of github.com:matrix-org/synapse into matrix-org-hotfixes

Andrew Morgan 2021-09-14 11:02:37 +01:00
commit 003c2ab629
119 changed files with 743 additions and 452 deletions


@ -1,10 +1,2 @@
# This file serves as a blacklist for SyTest tests that we expect will fail in
# Synapse when run under worker mode. For more details, see sytest-blacklist.
Can re-join room if re-invited
# new failures as of https://github.com/matrix-org/sytest/pull/732
Device list doesn't change if remote server is down
# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
Server correctly handles incoming m.device_list_update

changelog.d/10601.misc Normal file

@ -0,0 +1 @@
Add type annotations to the synapse.util package.

changelog.d/10774.bugfix Normal file

@ -0,0 +1 @@
Properly handle room upgrades of spaces.

changelog.d/10788.misc Normal file

@ -0,0 +1 @@
Remove fixed and flakey tests from the sytest-blacklist.

changelog.d/10789.misc Normal file

@ -0,0 +1 @@
Improve internal details of the user directory code.

changelog.d/10795.doc Normal file

@ -0,0 +1 @@
Correct 2 typographical errors in the *Log Contexts* documentation.

changelog.d/10798.misc Normal file

@ -0,0 +1 @@
Use direct references to config flags.

changelog.d/10799.misc Normal file

@ -0,0 +1 @@
Add a max version for the `jaeger-client` dependency for an incompatibility with the rust reporter.

changelog.d/10804.doc Normal file

@ -0,0 +1 @@
Fixed a wording mistake in the sample configuration. Contributed by @bramvdnheuvel:nltrix.net.


@ -10,7 +10,7 @@ Logcontexts are also used for CPU and database accounting, so that we
can track which requests were responsible for high CPU use or database
activity.
The `synapse.logging.context` module provides a facilities for managing
The `synapse.logging.context` module provides facilities for managing
the current log context (as well as providing the `LoggingContextFilter`
class).
@ -351,7 +351,7 @@ and the awaitable chain is now orphaned, and will be garbage-collected at
some point. Note that `await_something_interesting` is a coroutine,
which Python implements as a generator function. When Python
garbage-collects generator functions, it gives them a chance to
clean up by making the `async` (or `yield`) raise a `GeneratorExit`
clean up by making the `await` (or `yield`) raise a `GeneratorExit`
exception. In our case, that means that the `__exit__` handler of
`PreserveLoggingContext` will carefully restore the request context, but
there is now nothing waiting for its return, so the request context is

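The garbage-collection behaviour described above is easier to see with a small, runnable illustration. This is a minimal sketch in plain Python, not Synapse's actual `PreserveLoggingContext`; the `Ctx` context manager and `Suspend` awaitable are stand-ins invented for the example. When the only reference to a suspended coroutine is dropped, finalization raises `GeneratorExit` at the `await`, and the enclosing context manager's `__exit__` still runs:

import gc

class Suspend:
    # Minimal awaitable that suspends the coroutine once, without an event loop.
    def __await__(self):
        yield

class Ctx:
    # Stand-in for a context manager like PreserveLoggingContext.
    def __enter__(self):
        print("entered")
    def __exit__(self, exc_type, exc, tb):
        print("__exit__ called with:", exc_type.__name__ if exc_type else None)

async def orphaned():
    with Ctx():
        try:
            await Suspend()
        except GeneratorExit:
            print("GeneratorExit raised at the await")
            raise  # a coroutine must not swallow GeneratorExit

coro = orphaned()
coro.send(None)   # run up to the await, which suspends the coroutine
del coro          # drop the only reference: the coroutine is now orphaned
gc.collect()      # finalization closes it, raising GeneratorExit inside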

@ -2086,7 +2086,7 @@ password_config:
#
#require_lowercase: true
# Whether a password must contain at least one lowercase letter.
# Whether a password must contain at least one uppercase letter.
# Defaults to 'false'.
#
#require_uppercase: true


@ -10,3 +10,40 @@ DB corruption) get stale or out of sync. If this happens, for now the
solution to fix it is to execute the SQL [here](https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/main/delta/53/user_dir_populate.sql)
and then restart synapse. This should then start a background task to
flush the current tables and regenerate the directory.
Data model
----------
There are five relevant tables that collectively form the "user directory".
Three of them track a master list of all the users we could search for.
The last two (collectively called the "search tables") track who can
see who.
From all of these tables we exclude three types of local user:
- support users
- appservice users
- deactivated users
* `user_directory`. This contains the user_id, display name and avatar we'll
return when you search the directory.
- Because there's only one directory entry per user, it's important that we only
ever put publicly visible names here. Otherwise we might leak a private
nickname or avatar used in a private room.
- Indexed on rooms. Indexed on users.
* `user_directory_search`. To be joined to `user_directory`. It contains an extra
column that enables full text search based on user ids and display names.
Different schemas for SQLite and Postgres with different code paths to match.
- Indexed on the full text search data. Indexed on users.
* `user_directory_stream_pos`. When the initial background update to populate
the directory is complete, we record a stream position here. This indicates
that synapse should now listen for room changes and incrementally update
the directory where necessary.
* `users_in_public_rooms`. Contains associations between users and the public rooms they're in.
Used to determine which users are in public rooms and should be publicly visible in the directory.
* `users_who_share_private_rooms`. Rows are triples `(L, M, room id)` where `L`
is a local user and `M` is a local or remote user. `L` and `M` should be
different, but this isn't enforced by a constraint.
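
To make the relationship between these tables concrete, here is an illustrative query sketch. It is not the SQL Synapse actually generates (the real query, built in the user directory store's `search_user_dir`, is more involved and differs between SQLite and Postgres); it only shows how the two search tables gate which directory entries a given searcher may see:

# Illustrative only; parameters are the searching user's ID and a search term.
ILLUSTRATIVE_SEARCH_SQL = """
    SELECT d.user_id, d.display_name, d.avatar_url
    FROM user_directory_search AS s
    JOIN user_directory AS d USING (user_id)
    WHERE
        (
            -- visible to everyone via a public room
            EXISTS (SELECT 1 FROM users_in_public_rooms AS p
                    WHERE p.user_id = d.user_id)
            -- or visible to the searcher via a shared private room
            OR EXISTS (SELECT 1 FROM users_who_share_private_rooms AS r
                       WHERE r.user_id = ? AND r.other_user_id = d.user_id)
        )
        -- full-text match; the exact operator differs between the SQLite FTS
        -- and Postgres schemas mentioned above
        AND s.value MATCH ?
"""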


@ -74,17 +74,7 @@ files =
synapse/storage/util,
synapse/streams,
synapse/types.py,
synapse/util/async_helpers.py,
synapse/util/caches,
synapse/util/daemonize.py,
synapse/util/hash.py,
synapse/util/iterutils.py,
synapse/util/linked_list.py,
synapse/util/metrics.py,
synapse/util/macaroons.py,
synapse/util/module_loader.py,
synapse/util/msisdn.py,
synapse/util/stringutils.py,
synapse/util,
synapse/visibility.py,
tests/replication,
tests/test_event_auth.py,
@ -102,6 +92,69 @@ files =
[mypy-synapse.rest.client.*]
disallow_untyped_defs = True
[mypy-synapse.util.batching_queue]
disallow_untyped_defs = True
[mypy-synapse.util.caches.dictionary_cache]
disallow_untyped_defs = True
[mypy-synapse.util.file_consumer]
disallow_untyped_defs = True
[mypy-synapse.util.frozenutils]
disallow_untyped_defs = True
[mypy-synapse.util.hash]
disallow_untyped_defs = True
[mypy-synapse.util.httpresourcetree]
disallow_untyped_defs = True
[mypy-synapse.util.iterutils]
disallow_untyped_defs = True
[mypy-synapse.util.linked_list]
disallow_untyped_defs = True
[mypy-synapse.util.logcontext]
disallow_untyped_defs = True
[mypy-synapse.util.logformatter]
disallow_untyped_defs = True
[mypy-synapse.util.macaroons]
disallow_untyped_defs = True
[mypy-synapse.util.manhole]
disallow_untyped_defs = True
[mypy-synapse.util.module_loader]
disallow_untyped_defs = True
[mypy-synapse.util.msisdn]
disallow_untyped_defs = True
[mypy-synapse.util.ratelimitutils]
disallow_untyped_defs = True
[mypy-synapse.util.retryutils]
disallow_untyped_defs = True
[mypy-synapse.util.rlimit]
disallow_untyped_defs = True
[mypy-synapse.util.stringutils]
disallow_untyped_defs = True
[mypy-synapse.util.templates]
disallow_untyped_defs = True
[mypy-synapse.util.threepids]
disallow_untyped_defs = True
[mypy-synapse.util.wheel_timer]
disallow_untyped_defs = True
[mypy-pymacaroons.*]
ignore_missing_imports = True
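
For readers unfamiliar with the flag being switched on per module above: `disallow_untyped_defs = True` makes mypy reject any function definition in that module that lacks annotations. A hypothetical example (not Synapse code):

from typing import List

def stale_key(prefix, parts):          # mypy error: function is missing type annotations
    return prefix + ":" + "/".join(parts)

def fresh_key(prefix: str, parts: List[str]) -> str:   # accepted
    return prefix + ":" + "/".join(parts)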


@ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory):
def buildProtocol(self, addr) -> RedisProtocol: ...
class SubscriberFactory(RedisFactory):
def __init__(self): ...
def __init__(self) -> None: ...


@ -46,7 +46,7 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per request
self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict()
self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
async def can_do_action(
self,
@ -56,7 +56,7 @@ class Ratelimiter:
burst_count: Optional[int] = None,
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[int] = None,
_time_now_s: Optional[float] = None,
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
@ -160,7 +160,7 @@ class Ratelimiter:
return allowed, time_allowed
def _prune_message_counts(self, time_now_s: int):
def _prune_message_counts(self, time_now_s: float):
"""Remove message count entries that have not exceeded their defined
rate_hz limit
@ -188,7 +188,7 @@ class Ratelimiter:
burst_count: Optional[int] = None,
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[int] = None,
_time_now_s: Optional[float] = None,
):
"""Checks if an action can be performed. If not, raises a LimitExceededError

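The hunk above widens the stored action count and the `_time_now_s` parameters from `int` to `float`, because the limiter does fractional, time-based arithmetic. The following is a rough leaky-bucket sketch, standalone and not Synapse's `Ratelimiter` API, showing why both values need sub-second and fractional precision:

import time
from typing import Dict, Tuple

# key -> (action_count, last_update_ts); both floats, as in the change above
_state: Dict[str, Tuple[float, float]] = {}

def allowed(key: str, rate_hz: float, burst_count: int, n_actions: int = 1) -> bool:
    now = time.time()                       # float seconds
    action_count, last_ts = _state.get(key, (0.0, now))
    # Actions "leak" out of the bucket at rate_hz per second; truncating the
    # count or the timestamp to int would lose this sub-second precision.
    action_count = max(0.0, action_count - (now - last_ts) * rate_hz)
    if action_count + n_actions > burst_count:
        return False
    _state[key] = (action_count + n_actions, now)
    return True

print(allowed("@alice:example.com", rate_hz=0.17, burst_count=3))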

@ -41,11 +41,11 @@ class ConsentURIBuilder:
"""
if hs_config.form_secret is None:
raise ConfigError("form_secret not set in config")
if hs_config.public_baseurl is None:
if hs_config.server.public_baseurl is None:
raise ConfigError("public_baseurl not set in config")
self._hmac_secret = hs_config.form_secret.encode("utf-8")
self._public_baseurl = hs_config.public_baseurl
self._public_baseurl = hs_config.server.public_baseurl
def build_user_consent_uri(self, user_id):
"""Build a URI which we can give to the user to do their privacy


@ -82,7 +82,7 @@ def start_worker_reactor(appname, config, run_command=reactor.run):
run_command (Callable[]): callable that actually runs the reactor
"""
logger = logging.getLogger(config.worker_app)
logger = logging.getLogger(config.worker.worker_app)
start_reactor(
appname,
@ -398,7 +398,7 @@ async def start(hs: "HomeServer"):
# If background tasks are running on the main process, start collecting the
# phone home stats.
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
start_phone_stats_home(hs)
# We now freeze all allocated objects in the hopes that (almost)
@ -433,9 +433,13 @@ def setup_sentry(hs):
# We set some default tags that give some context to this instance
with sentry_sdk.configure_scope() as scope:
scope.set_tag("matrix_server_name", hs.config.server_name)
scope.set_tag("matrix_server_name", hs.config.server.server_name)
app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver"
app = (
hs.config.worker.worker_app
if hs.config.worker.worker_app
else "synapse.app.homeserver"
)
name = hs.get_instance_name()
scope.set_tag("worker_app", app)
scope.set_tag("worker_name", name)


@ -178,12 +178,12 @@ def start(config_options):
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
if config.worker_app is not None:
assert config.worker_app == "synapse.app.admin_cmd"
if config.worker.worker_app is not None:
assert config.worker.worker_app == "synapse.app.admin_cmd"
# Update the config with some basic overrides so that don't have to specify
# a full worker config.
config.worker_app = "synapse.app.admin_cmd"
config.worker.worker_app = "synapse.app.admin_cmd"
if (
not config.worker_daemonize
@ -196,7 +196,7 @@ def start(config_options):
# Explicitly disable background processes
config.update_user_directory = False
config.run_background_tasks = False
config.worker.run_background_tasks = False
config.start_pushers = False
config.pusher_shard_config.instances = []
config.send_federation = False
@ -205,7 +205,7 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
ss = AdminCmdServer(
config.server_name,
config.server.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)


@ -416,7 +416,7 @@ def start(config_options):
sys.exit(1)
# For backwards compatibility let any of the old app names.
assert config.worker_app in (
assert config.worker.worker_app in (
"synapse.app.appservice",
"synapse.app.client_reader",
"synapse.app.event_creator",
@ -430,7 +430,7 @@ def start(config_options):
"synapse.app.user_dir",
)
if config.worker_app == "synapse.app.appservice":
if config.worker.worker_app == "synapse.app.appservice":
if config.appservice.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
@ -446,7 +446,7 @@ def start(config_options):
# For other worker types we force this to off.
config.appservice.notify_appservices = False
if config.worker_app == "synapse.app.user_dir":
if config.worker.worker_app == "synapse.app.user_dir":
if config.server.update_user_directory:
sys.stderr.write(
"\nThe update_user_directory must be disabled in the main synapse process"
@ -469,7 +469,7 @@ def start(config_options):
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
hs = GenericWorkerServer(
config.server_name,
config.server.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)


@ -350,7 +350,7 @@ def setup(config_options):
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
hs = SynapseHomeServer(
config.server_name,
config.server.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)


@ -73,7 +73,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
store = hs.get_datastore()
stats["homeserver"] = hs.config.server_name
stats["homeserver"] = hs.config.server.server_name
stats["server_context"] = hs.config.server_context
stats["timestamp"] = now
stats["uptime_seconds"] = uptime


@ -88,7 +88,7 @@ class AuthConfig(Config):
#
#require_lowercase: true
# Whether a password must contain at least one lowercase letter.
# Whether a password must contain at least one uppercase letter.
# Defaults to 'false'.
#
#require_uppercase: true


@ -223,7 +223,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
# writes.
log_context_filter = LoggingContextFilter()
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
log_metadata_filter = MetadataFilter({"server_name": config.server.server_name})
old_factory = logging.getLogRecordFactory()
def factory(*args, **kwargs):
@ -335,5 +335,5 @@ def setup_logging(
# Log immediately so we can grep backwards.
logging.warning("***** STARTING SERVER *****")
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
logging.info("Server hostname: %s", config.server.server_name)
logging.info("Instance name: %s", hs.get_instance_name())


@ -14,6 +14,8 @@
from typing import Dict, Optional
import attr
from ._base import Config
@ -29,18 +31,13 @@ class RateLimitConfig:
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
@attr.s(auto_attribs=True)
class FederationRateLimitConfig:
_items_and_default = {
"window_size": 1000,
"sleep_limit": 10,
"sleep_delay": 500,
"reject_limit": 50,
"concurrent": 3,
}
def __init__(self, **kwargs):
for i in self._items_and_default.keys():
setattr(self, i, kwargs.get(i) or self._items_and_default[i])
window_size: int = 1000
sleep_limit: int = 10
sleep_delay: int = 500
reject_limit: int = 50
concurrent: int = 3
class RatelimitConfig(Config):
@ -69,11 +66,15 @@ class RatelimitConfig(Config):
else:
self.rc_federation = FederationRateLimitConfig(
**{
"window_size": config.get("federation_rc_window_size"),
"sleep_limit": config.get("federation_rc_sleep_limit"),
"sleep_delay": config.get("federation_rc_sleep_delay"),
"reject_limit": config.get("federation_rc_reject_limit"),
"concurrent": config.get("federation_rc_concurrent"),
k: v
for k, v in {
"window_size": config.get("federation_rc_window_size"),
"sleep_limit": config.get("federation_rc_sleep_limit"),
"sleep_delay": config.get("federation_rc_sleep_delay"),
"reject_limit": config.get("federation_rc_reject_limit"),
"concurrent": config.get("federation_rc_concurrent"),
}.items()
if v is not None
}
)
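
The hunk above converts `FederationRateLimitConfig` from a hand-rolled `__init__` to an `attrs` class and only forwards the settings the admin actually set. A standalone sketch of the same pattern, with an illustrative class name rather than the real Synapse class:

import attr

@attr.s(auto_attribs=True)
class FederationLimits:
    # Defaults live on the class, as in the attrs conversion above.
    window_size: int = 1000
    sleep_limit: int = 10
    sleep_delay: int = 500
    reject_limit: int = 50
    concurrent: int = 3

raw = {"federation_rc_sleep_limit": 20}   # e.g. the parsed YAML config

overrides = {
    "window_size": raw.get("federation_rc_window_size"),
    "sleep_limit": raw.get("federation_rc_sleep_limit"),
    "sleep_delay": raw.get("federation_rc_sleep_delay"),
    "reject_limit": raw.get("federation_rc_reject_limit"),
    "concurrent": raw.get("federation_rc_concurrent"),
}
# Drop unset keys so the attrs defaults apply instead of explicit `None`s,
# which is what the dict comprehension in the diff achieves.
limits = FederationLimits(**{k: v for k, v in overrides.items() if v is not None})
print(limits)  # FederationLimits(window_size=1000, sleep_limit=20, ...)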


@ -88,7 +88,7 @@ class EventValidator:
self._validate_retention(event)
if event.type == EventTypes.ServerACL:
if not server_matches_acl_event(config.server_name, event):
if not server_matches_acl_event(config.server.server_name, event):
raise SynapseError(
400, "Can't create an ACL event that denies the local server"
)


@ -22,6 +22,7 @@ from prometheus_client import Counter
from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall
import synapse.metrics
from synapse.api.presence import UserPresenceState
@ -280,11 +281,14 @@ class FederationSender(AbstractFederationSender):
self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {}
self._rr_txn_interval_per_room_ms = (
1000.0 / hs.config.federation_rr_transactions_per_room_per_second
1000.0
/ hs.config.ratelimiting.federation_rr_transactions_per_room_per_second
)
# wake up destinations that have outstanding PDUs to be caught up
self._catchup_after_startup_timer = self.clock.call_later(
self._catchup_after_startup_timer: Optional[
IDelayedCall
] = self.clock.call_later(
CATCH_UP_STARTUP_DELAY_SEC,
run_as_background_process,
"wake_destinations_needing_catchup",
@ -406,7 +410,7 @@ class FederationSender(AbstractFederationSender):
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels(
"federation_sender"
).observe((now - ts) / 1000)
@ -435,6 +439,7 @@ class FederationSender(AbstractFederationSender):
if events:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(events[-1].event_id)
assert ts is not None
synapse.metrics.event_processing_lag.labels(
"federation_sender"


@ -144,7 +144,7 @@ class GroupAttestionRenewer:
self.is_mine_id = hs.is_mine_id
self.attestations = hs.get_groups_attestation_signing()
if not hs.config.worker_app:
if not hs.config.worker.worker_app:
self._renew_attestations_loop = self.clock.looping_call(
self._start_renew_attestations, 30 * 60 * 1000
)


@ -45,16 +45,16 @@ class BaseHandler:
self.request_ratelimiter = Ratelimiter(
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
)
self._rc_message = self.hs.config.rc_message
self._rc_message = self.hs.config.ratelimiting.rc_message
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
if self.hs.config.ratelimiting.rc_admin_redaction:
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
rate_hz=self.hs.config.ratelimiting.rc_admin_redaction.per_second,
burst_count=self.hs.config.ratelimiting.rc_admin_redaction.burst_count,
)
else:
self.admin_redaction_ratelimiter = None


@ -78,7 +78,7 @@ class AccountValidityHandler:
)
# Check the renewal emails to send and send them every 30min.
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
@ -249,7 +249,7 @@ class AccountValidityHandler:
renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
self.hs.config.server.public_baseurl,
renewal_token,
)
@ -398,6 +398,7 @@ class AccountValidityHandler:
"""
now = self.clock.time_msec()
if expiration_ts is None:
assert self._account_validity_period is not None
expiration_ts = now + self._account_validity_period
await self.store.set_account_validity_for_user(


@ -131,6 +131,8 @@ class ApplicationServicesHandler:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels(
"appservice_sender"
).observe((now - ts) / 1000)
@ -166,6 +168,7 @@ class ApplicationServicesHandler:
if events:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(events[-1].event_id)
assert ts is not None
synapse.metrics.event_processing_lag.labels(
"appservice_sender"


@ -244,8 +244,8 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
)
# The number of seconds to keep a UI auth session active.
@ -255,14 +255,14 @@ class AuthHandler(BaseHandler):
self._failed_login_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
)
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
@ -289,7 +289,7 @@ class AuthHandler(BaseHandler):
hs.config.sso_account_deactivated_template
)
self._server_name = hs.config.server_name
self._server_name = hs.config.server.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
@ -749,7 +749,7 @@ class AuthHandler(BaseHandler):
"name": self.hs.config.user_consent_policy_name,
"url": "%s_matrix/consent?v=%s"
% (
self.hs.config.public_baseurl,
self.hs.config.server.public_baseurl,
self.hs.config.user_consent_version,
),
},
@ -1799,7 +1799,7 @@ class MacaroonGenerator:
def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
location=self.hs.config.server.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key,
)


@ -46,7 +46,7 @@ class DeactivateAccountHandler(BaseHandler):
# Start the user parter loop so it can resume parting users from rooms where
# it left off (if it has work left to do).
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
hs.get_reactor().callWhenRunning(self._start_user_parting)
self._account_validity_enabled = (
@ -131,7 +131,7 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.add_user_pending_deactivation(user_id)
# delete from user directory
await self.user_directory_handler.handle_user_deactivated(user_id)
await self.user_directory_handler.handle_local_user_deactivated(user_id)
# Mark the user as erased, if they asked for that
if erase_data:


@ -84,8 +84,8 @@ class DeviceMessageHandler:
self._ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.rc_key_requests.per_second,
burst_count=hs.config.rc_key_requests.burst_count,
rate_hz=hs.config.ratelimiting.rc_key_requests.per_second,
burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
)
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:


@ -57,7 +57,7 @@ class E2eKeysHandler:
federation_registry = hs.get_federation_registry()
self._is_master = hs.config.worker_app is None
self._is_master = hs.config.worker.worker_app is None
if not self._is_master:
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)


@ -101,7 +101,7 @@ class FederationHandler(BaseHandler):
hs
)
if hs.config.worker_app:
if hs.config.worker.worker_app:
self._maybe_store_room_on_outlier_membership = (
ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
)
@ -1614,7 +1614,7 @@ class FederationHandler(BaseHandler):
Args:
room_id
"""
if self.config.worker_app:
if self.config.worker.worker_app:
await self._clean_room_for_join_client(room_id)
else:
await self.store.clean_room_for_join(room_id)


@ -149,7 +149,7 @@ class FederationEventHandler:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
if hs.config.worker_app:
if hs.config.worker.worker_app:
self._user_device_resync = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
@ -1009,7 +1009,7 @@ class FederationEventHandler:
await self._store.mark_remote_user_device_cache_as_stale(sender)
# Immediately attempt a resync in the background
if self._config.worker_app:
if self._config.worker.worker_app:
await self._user_device_resync(user_id=sender)
else:
await self._device_list_updater.user_device_resync(sender)


@ -540,13 +540,13 @@ class IdentityHandler(BaseHandler):
# It is already checked that public_baseurl is configured since this code
# should only be used if account_threepid_delegate_msisdn is true.
assert self.hs.config.public_baseurl
assert self.hs.config.server.public_baseurl
# we need to tell the client to send the token back to us, since it doesn't
# otherwise know where to send it, so add submit_url response parameter
# (see also MSC2078)
data["submit_url"] = (
self.hs.config.public_baseurl
self.hs.config.server.public_baseurl
+ "_matrix/client/unstable/add_threepid/msisdn/submit_token"
)
return data


@ -84,7 +84,7 @@ class MessageHandler:
# scheduled.
self._scheduled_expiry: Optional[IDelayedCall] = None
if not hs.config.worker_app:
if not hs.config.worker.worker_app:
run_as_background_process(
"_schedule_next_expiry", self._schedule_next_expiry
)
@ -461,7 +461,7 @@ class EventCreationHandler:
self._dummy_events_threshold = hs.config.dummy_events_threshold
if (
self.config.run_background_tasks
self.config.worker.run_background_tasks
and self.config.cleanup_extremities_with_dummy_events
):
self.clock.looping_call(


@ -324,7 +324,7 @@ class OidcProvider:
self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
self._server_name: str = hs.config.server_name
self._server_name: str = hs.config.server.server_name
# identifier for the external_ids table
self.idp_id = provider.idp_id


@ -91,7 +91,7 @@ class PaginationHandler:
self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
if hs.config.run_background_tasks and hs.config.retention_enabled:
if hs.config.worker.run_background_tasks and hs.config.retention_enabled:
# Run the purge jobs described in the configuration file.
for job in hs.config.retention_purge_jobs:
logger.info("Setting up purge job with config: %s", job)


@ -28,6 +28,7 @@ from bisect import bisect
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
@ -615,7 +616,7 @@ class PresenceHandler(BasePresenceHandler):
super().__init__(hs)
self.hs = hs
self.server_name = hs.hostname
self.wheel_timer = WheelTimer()
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
self._presence_enabled = hs.config.use_presence
@ -924,7 +925,7 @@ class PresenceHandler(BasePresenceHandler):
prev_state = await self.current_state_for_user(user_id)
new_fields = {"last_active_ts": self.clock.time_msec()}
new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE


@ -63,7 +63,7 @@ class ProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
self.clock.looping_call(
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)


@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class ReadMarkerHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.config.server_name
self.server_name = hs.config.server.server_name
self.store = hs.get_datastore()
self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker")


@ -29,7 +29,7 @@ class ReceiptsHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.config.server_name
self.server_name = hs.config.server.server_name
self.store = hs.get_datastore()
self.event_auth_handler = hs.get_event_auth_handler()


@ -102,7 +102,7 @@ class RegistrationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
if hs.config.worker_app:
if hs.config.worker.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
hs
@ -696,7 +696,7 @@ class RegistrationHandler(BaseHandler):
address: the IP address used to perform the registration.
shadow_banned: Whether to shadow-ban the user
"""
if self.hs.config.worker_app:
if self.hs.config.worker.worker_app:
await self._register_client(
user_id=user_id,
password_hash=password_hash,
@ -786,7 +786,7 @@ class RegistrationHandler(BaseHandler):
Does the bits that need doing on the main process. Not for use outside this
class and RegisterDeviceReplicationServlet.
"""
assert not self.hs.config.worker_app
assert not self.hs.config.worker.worker_app
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
@ -843,7 +843,7 @@ class RegistrationHandler(BaseHandler):
"""
# TODO: 3pid registration can actually happen on the workers. Consider
# refactoring it.
if self.hs.config.worker_app:
if self.hs.config.worker.worker_app:
await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token
)


@ -33,6 +33,7 @@ from synapse.api.constants import (
Membership,
RoomCreationPreset,
RoomEncryptionAlgorithms,
RoomTypes,
)
from synapse.api.errors import (
AuthError,
@ -397,7 +398,7 @@ class RoomCreationHandler(BaseHandler):
initial_state = {}
# Replicate relevant room events
types_to_copy = (
types_to_copy: List[Tuple[str, Optional[str]]] = [
(EventTypes.JoinRules, ""),
(EventTypes.Name, ""),
(EventTypes.Topic, ""),
@ -408,7 +409,16 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.ServerACL, ""),
(EventTypes.RelatedGroups, ""),
(EventTypes.PowerLevels, ""),
)
]
# If the old room was a space, copy over the room type and the rooms in
# the space.
if (
old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
== RoomTypes.SPACE
):
creation_content[EventContentFields.ROOM_TYPE] = RoomTypes.SPACE
types_to_copy.append((EventTypes.SpaceChild, None))
old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy)
@ -419,6 +429,11 @@ class RoomCreationHandler(BaseHandler):
for k, old_event_id in old_room_state_ids.items():
old_event = old_room_state_events.get(old_event_id)
if old_event:
# If the event is a space child event with empty content, it was
# removed from the space and should be ignored.
if k[0] == EventTypes.SpaceChild and not old_event.content:
continue
initial_state[k] = old_event.content
# deep-copy the power-levels event before we start modifying it

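For context on the constants used in this hunk: `EventContentFields.ROOM_TYPE` is the `"type"` key of the `m.room.create` event content, and `RoomTypes.SPACE` is `"m.space"`, so the upgrade check reduces to something like this simplified sketch (constant values as in the Matrix spec):

ROOM_TYPE = "type"     # EventContentFields.ROOM_TYPE
SPACE = "m.space"      # RoomTypes.SPACE

def upgrading_a_space(old_room_create_content: dict) -> bool:
    # True when the room being upgraded is a space, in which case the room
    # type and the m.space.child events are carried over to the new room.
    return old_room_create_content.get(ROOM_TYPE) == SPACE

print(upgrading_a_space({"type": "m.space", "creator": "@alice:example.com"}))  # True
print(upgrading_a_space({"creator": "@alice:example.com"}))                     # False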

@ -13,6 +13,7 @@
# limitations under the License.
import logging
from enum import Enum, auto
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
@ -21,6 +22,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class MatchChange(Enum):
no_change = auto()
now_true = auto()
now_false = auto()
class StateDeltasHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -31,18 +38,12 @@ class StateDeltasHandler:
event_id: Optional[str],
key_name: str,
public_value: str,
) -> Optional[bool]:
) -> MatchChange:
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
For example, check if `history_visibility` (`key_name`) changed from
`shared` to `world_readable` (`public_value`).
Returns:
None if the field in the events either both match `public_value`
or if neither do, i.e. there has been no change.
True if it didn't match `public_value` but now does
False if it did match `public_value` but now doesn't
"""
prev_event = None
event = None
@ -54,7 +55,7 @@ class StateDeltasHandler:
if not event and not prev_event:
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
return None
return MatchChange.no_change
prev_value = None
value = None
@ -68,8 +69,8 @@ class StateDeltasHandler:
logger.debug("prev_value: %r -> value: %r", prev_value, value)
if value == public_value and prev_value != public_value:
return True
return MatchChange.now_true
elif value != public_value and prev_value == public_value:
return False
return MatchChange.now_false
else:
return None
return MatchChange.no_change
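
The motivation for replacing the `Optional[bool]` return with `MatchChange` is that a bare truthiness test silently conflates "no change" (`None`) with "stopped matching" (`False`), whereas enum members have to be compared explicitly. A small standalone sketch of the difference, mirroring the enum defined above:

from enum import Enum, auto
from typing import Optional

class MatchChange(Enum):
    no_change = auto()
    now_true = auto()
    now_false = auto()

def describe_tristate(change: Optional[bool]) -> str:
    # Old shape: `if change:` treats None and False the same way, an easy bug.
    if change:
        return "started matching"
    return "no change or stopped matching?"  # ambiguous

def describe_enum(change: MatchChange) -> str:
    # New shape: every case has to be named.
    if change is MatchChange.now_true:
        return "started matching"
    if change is MatchChange.now_false:
        return "stopped matching"
    return "no change"

print(describe_tristate(None), "/", describe_enum(MatchChange.no_change))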


@ -54,7 +54,7 @@ class StatsHandler:
# Guard to ensure we only process deltas one at a time
self._is_processing = False
if self.stats_enabled and hs.config.run_background_tasks:
if self.stats_enabled and hs.config.worker.run_background_tasks:
self.notifier.add_replication_callback(self.notify_new_event)
# We kick this off so that we don't have to wait for a change before


@ -53,7 +53,7 @@ class FollowerTypingHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.server_name = hs.config.server_name
self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@ -73,7 +73,7 @@ class FollowerTypingHandler:
self._room_typing: Dict[str, Set[str]] = {}
self._member_last_federation_poke: Dict[RoomMember, int] = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000)


@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import synapse.metrics
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict
@ -30,14 +30,26 @@ logger = logging.getLogger(__name__)
class UserDirectoryHandler(StateDeltasHandler):
"""Handles querying of and keeping updated the user_directory.
"""Handles queries and updates for the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
The user directory is filled with users who this server can see are joined to a
world_readable or publicly joinable room. We keep a database table up to date
by streaming changes of the current state and recalculating whether users should
be in the directory or not when necessary.
When a local user searches the user_directory, we report two kinds of users:
- users this server can see are joined to a world_readable or publicly
joinable room, and
- users belonging to a private room shared by that local user.
The two cases are tracked separately in the `users_in_public_rooms` and
`users_who_share_private_rooms` tables. Both kinds of users have their
username and avatar tracked in a `user_directory` table.
This handler has three responsibilities:
1. Forwarding requests to `/user_directory/search` to the UserDirectoryStore.
2. Providing hooks for the application to call when local users are added,
removed, or have their profile changed.
3. Listening for room state changes that indicate remote users have
joined or left a room, or that their profile has changed.
"""
def __init__(self, hs: "HomeServer"):
@ -130,7 +142,7 @@ class UserDirectoryHandler(StateDeltasHandler):
user_id, profile.display_name, profile.avatar_url
)
async def handle_user_deactivated(self, user_id: str) -> None:
async def handle_local_user_deactivated(self, user_id: str) -> None:
"""Called when a user ID is deactivated"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
@ -196,7 +208,7 @@ class UserDirectoryHandler(StateDeltasHandler):
public_value=Membership.JOIN,
)
if change is False:
if change is MatchChange.now_false:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
is_in_room = await self.store.is_host_joined(
@ -219,14 +231,14 @@ class UserDirectoryHandler(StateDeltasHandler):
is_support = await self.store.is_support_user(state_key)
if not is_support:
if change is None:
if change is MatchChange.no_change:
# Handle any profile changes
await self._handle_profile_change(
state_key, room_id, prev_event_id, event_id
)
continue
if change: # The user joined
if change is MatchChange.now_true: # The user joined
event = await self.store.get_event(event_id, allow_none=True)
# It isn't expected for this event to not exist, but we
# don't want the entire background process to break.
@ -263,14 +275,14 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Handling change for %s: %s", typ, room_id)
if typ == EventTypes.RoomHistoryVisibility:
change = await self._get_key_change(
publicness = await self._get_key_change(
prev_event_id,
event_id,
key_name="history_visibility",
public_value=HistoryVisibility.WORLD_READABLE,
)
elif typ == EventTypes.JoinRules:
change = await self._get_key_change(
publicness = await self._get_key_change(
prev_event_id,
event_id,
key_name="join_rule",
@ -278,9 +290,7 @@ class UserDirectoryHandler(StateDeltasHandler):
)
else:
raise Exception("Invalid event type")
# If change is None, no change. True => become world_readable/public,
# False => was world_readable/public
if change is None:
if publicness is MatchChange.no_change:
logger.debug("No change")
return
@ -290,13 +300,13 @@ class UserDirectoryHandler(StateDeltasHandler):
room_id
)
logger.debug("Change: %r, is_public: %r", change, is_public)
logger.debug("Change: %r, publicness: %r", publicness, is_public)
if change and not is_public:
if publicness is MatchChange.now_true and not is_public:
# If we became world readable but room isn't currently public then
# we ignore the change
return
elif not change and is_public:
elif publicness is MatchChange.now_false and is_public:
# If we stopped being world readable but are still public,
# ignore the change
return


@ -236,8 +236,17 @@ except ImportError:
try:
from rust_python_jaeger_reporter import Reporter
# jaeger-client 4.7.0 requires that reporters inherit from BaseReporter, which
# didn't exist before that version.
try:
from jaeger_client.reporter import BaseReporter
except ImportError:
class BaseReporter: # type: ignore[no-redef]
pass
@attr.s(slots=True, frozen=True)
class _WrappedRustReporter:
class _WrappedRustReporter(BaseReporter):
"""Wrap the reporter to ensure `report_span` never throws."""
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
@ -374,7 +383,7 @@ def init_tracer(hs: "HomeServer"):
config = JaegerConfig(
config=hs.config.jaeger_config,
service_name=f"{hs.config.server_name} {hs.get_instance_name()}",
service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}",
scope_manager=LogContextScopeManager(hs.config),
metrics_factory=PrometheusMetricsFactory(),
)
@ -382,6 +391,7 @@ def init_tracer(hs: "HomeServer"):
# If we have the rust jaeger reporter available let's use that.
if RustReporter:
logger.info("Using rust_python_jaeger_reporter library")
assert config.sampler is not None
tracer = config.create_tracer(RustReporter(), config.sampler)
opentracing.set_global_tracer(tracer)
else:


@ -178,7 +178,7 @@ class ModuleApi:
@property
def public_baseurl(self) -> str:
"""The configured public base URL for this homeserver."""
return self._hs.config.public_baseurl
return self._hs.config.server.public_baseurl
@property
def email_app_name(self) -> str:
@ -640,7 +640,7 @@ class ModuleApi:
if desc is None:
desc = f.__name__
if self._hs.config.run_background_tasks or run_on_all_instances:
if self._hs.config.worker.run_background_tasks or run_on_all_instances:
self._clock.looping_call(
run_as_background_process,
msec,


@ -130,7 +130,7 @@ class Mailer:
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
self.hs.config.public_baseurl
self.hs.config.server.public_baseurl
+ "_synapse/client/password_reset/email/submit_token?%s"
% urllib.parse.urlencode(params)
)
@ -140,7 +140,7 @@ class Mailer:
await self.send_email(
email_address,
self.email_subjects.password_reset
% {"server_name": self.hs.config.server_name},
% {"server_name": self.hs.config.server.server_name},
template_vars,
)
@ -160,7 +160,7 @@ class Mailer:
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
self.hs.config.public_baseurl
self.hs.config.server.public_baseurl
+ "_matrix/client/unstable/registration/email/submit_token?%s"
% urllib.parse.urlencode(params)
)
@ -170,7 +170,7 @@ class Mailer:
await self.send_email(
email_address,
self.email_subjects.email_validation
% {"server_name": self.hs.config.server_name},
% {"server_name": self.hs.config.server.server_name},
template_vars,
)
@ -191,7 +191,7 @@ class Mailer:
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
self.hs.config.public_baseurl
self.hs.config.server.public_baseurl
+ "_matrix/client/unstable/add_threepid/email/submit_token?%s"
% urllib.parse.urlencode(params)
)
@ -201,7 +201,7 @@ class Mailer:
await self.send_email(
email_address,
self.email_subjects.email_validation
% {"server_name": self.hs.config.server_name},
% {"server_name": self.hs.config.server.server_name},
template_vars,
)
@ -852,7 +852,7 @@ class Mailer:
# XXX: make r0 once API is stable
return "%s_matrix/client/unstable/pushers/remove?%s" % (
self.hs.config.public_baseurl,
self.hs.config.server.public_baseurl,
urllib.parse.urlencode(params),
)


@ -73,7 +73,7 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
):
self.client_name = client_name
self.command_handler = command_handler
self.server_name = hs.config.server_name
self.server_name = hs.config.server.server_name
self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class


@ -168,7 +168,7 @@ class ReplicationCommandHandler:
continue
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
if hs.config.worker.worker_app is not None:
continue
if stream.NAME == FederationStream.NAME and hs.config.send_federation:
@ -222,7 +222,7 @@ class ReplicationCommandHandler:
},
)
self._is_master = hs.config.worker_app is None
self._is_master = hs.config.worker.worker_app is None
self._federation_sender = None
if self._is_master and not hs.config.send_federation:


@ -40,7 +40,7 @@ class ReplicationStreamProtocolFactory(Factory):
def __init__(self, hs):
self.command_handler = hs.get_tcp_replication()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.server_name = hs.config.server.server_name
# If we've created a `ReplicationStreamProtocolFactory` then we're
# almost certainly registering a replication listener, so let's ensure


@ -42,7 +42,7 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs: "HomeServer"):
if hs.config.worker_app is None:
if hs.config.worker.worker_app is None:
# master process: get updates from the FederationRemoteSendQueue.
# (if the master is configured to send federation itself, federation_sender
# will be a real FederationSender, which has stubs for current_token and


@ -247,7 +247,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RegistrationTokenRestServlet(hs).register(http_server)
# Some servlets only get registered for the main process.
if hs.config.worker_app is None:
if hs.config.worker.worker_app is None:
SendServerNoticeServlet(hs).register(http_server)


@ -68,7 +68,10 @@ class AuthRestServlet(RestServlet):
html = self.terms_template.render(
session=session,
terms_url="%s_matrix/consent?v=%s"
% (self.hs.config.public_baseurl, self.hs.config.user_consent_version),
% (
self.hs.config.server.public_baseurl,
self.hs.config.user_consent_version,
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
)
@ -135,7 +138,7 @@ class AuthRestServlet(RestServlet):
session=session,
terms_url="%s_matrix/consent?v=%s"
% (
self.hs.config.public_baseurl,
self.hs.config.server.public_baseurl,
self.hs.config.user_consent_version,
),
myurl="%s/r0/auth/%s/fallback/web"


@ -93,14 +93,14 @@ class LoginRestServlet(RestServlet):
self._address_ratelimiter = Ratelimiter(
store=hs.get_datastore(),
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count,
)
self._account_ratelimiter = Ratelimiter(
store=hs.get_datastore(),
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
)
# ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
@ -486,7 +486,7 @@ class SsoRedirectServlet(RestServlet):
# register themselves with the main SSOHandler.
_load_sso_handlers(hs)
self._sso_handler = hs.get_sso_handler()
self._public_baseurl = hs.config.public_baseurl
self._public_baseurl = hs.config.server.public_baseurl
async def on_GET(
self, request: SynapseRequest, idp_id: Optional[str] = None


@ -69,7 +69,7 @@ class IdTokenServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.server_name = hs.config.server.server_name
async def on_POST(
self, request: SynapseRequest, user_id: str


@ -59,7 +59,7 @@ class PushRuleRestServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
self._is_worker = hs.config.worker.worker_app is not None
self._users_new_default_push_rules = hs.config.users_new_default_push_rules


@ -330,11 +330,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
# Artificially delay requests if rate > sleep_limit/window_size
sleep_limit=1,
# Amount of artificial delay to apply
sleep_msec=1000,
sleep_delay=1000,
# Error with 429 if more than reject_limit requests are queued
reject_limit=1,
# Allow 1 request at a time
concurrent_requests=1,
concurrent=1,
),
)
@ -763,7 +763,10 @@ class RegisterRestServlet(RestServlet):
Returns:
dictionary for response from /register
"""
result = {"user_id": user_id, "home_server": self.hs.hostname}
result: JsonDict = {
"user_id": user_id,
"home_server": self.hs.hostname,
}
if not params.get("inhibit_login", False):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
@ -814,7 +817,7 @@ class RegisterRestServlet(RestServlet):
user_id, device_id, initial_display_name, is_guest=True
)
result = {
result: JsonDict = {
"user_id": user_id,
"device_id": device_id,
"access_token": access_token,


@ -388,7 +388,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = None
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
if server and server != self.hs.config.server.server_name:
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
@ -438,7 +438,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = None
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
if server and server != self.hs.config.server.server_name:
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)


@ -86,12 +86,12 @@ class LocalKey(Resource):
json_object = {
"valid_until_ts": self.valid_until_ts,
"server_name": self.config.server_name,
"server_name": self.config.server.server_name,
"verify_keys": verify_keys,
"old_verify_keys": old_verify_keys,
}
for key in self.config.signing_key:
json_object = sign_json(json_object, self.config.server_name, key)
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
def render_GET(self, request):


@ -224,7 +224,9 @@ class RemoteKey(DirectServeJsonResource):
for key_json in json_results:
key_json = json_decoder.decode(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)
key_json = sign_json(
key_json, self.config.server.server_name, signing_key
)
signed_keys.append(key_json)


@ -52,7 +52,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
yield hs.config.sso.sso_template_dir
yield hs.config.sso.default_template_dir
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
async def _async_render_GET(self, request: Request) -> None:
try:


@ -80,7 +80,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
yield hs.config.sso.sso_template_dir
yield hs.config.sso.default_template_dir
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
async def _async_render_GET(self, request: Request) -> None:
try:


@ -34,10 +34,10 @@ class WellKnownBuilder:
def get_well_known(self):
# if we don't have a public_baseurl, we can't help much here.
if self._config.public_baseurl is None:
if self._config.server.public_baseurl is None:
return None
result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}}
if self._config.default_identity_server:
result["m.identity_server"] = {

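The change above only moves the config lookups under `config.server`. For reference, the builder is assembling the client well-known document served at `/.well-known/matrix/client`, roughly like the following (values are illustrative):

well_known = {
    "m.homeserver": {"base_url": "https://matrix.example.com/"},
    # Included only when default_identity_server is configured:
    "m.identity_server": {"base_url": "https://identity.example.com"},
}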

@ -313,7 +313,7 @@ class HomeServer(metaclass=abc.ABCMeta):
# Register background tasks required by this server. This must be done
# somewhat manually due to the background tasks not being registered
# unless handlers are instantiated.
if self.config.run_background_tasks:
if self.config.worker.run_background_tasks:
self.setup_background_tasks()
def start_listening(self) -> None:
@ -370,8 +370,8 @@ class HomeServer(metaclass=abc.ABCMeta):
return Ratelimiter(
store=self.get_datastore(),
clock=self.get_clock(),
rate_hz=self.config.rc_registration.per_second,
burst_count=self.config.rc_registration.burst_count,
rate_hz=self.config.ratelimiting.rc_registration.per_second,
burst_count=self.config.ratelimiting.rc_registration.burst_count,
)
@cache_in_self
@ -498,7 +498,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_device_handler(self):
if self.config.worker_app:
if self.config.worker.worker_app:
return DeviceWorkerHandler(self)
else:
return DeviceHandler(self)
@ -621,7 +621,7 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_federation_sender(self) -> AbstractFederationSender:
if self.should_send_federation():
return FederationSender(self)
elif not self.config.worker_app:
elif not self.config.worker.worker_app:
return FederationRemoteSendQueue(self)
else:
raise Exception("Workers cannot send federation traffic")
@ -650,14 +650,14 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_groups_local_handler(
self,
) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
if self.config.worker_app:
if self.config.worker.worker_app:
return GroupsLocalWorkerHandler(self)
else:
return GroupsLocalHandler(self)
@cache_in_self
def get_groups_server_handler(self):
if self.config.worker_app:
if self.config.worker.worker_app:
return GroupsServerWorkerHandler(self)
else:
return GroupsServerHandler(self)
@ -684,7 +684,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_room_member_handler(self) -> RoomMemberHandler:
if self.config.worker_app:
if self.config.worker.worker_app:
return RoomMemberWorkerHandler(self)
return RoomMemberMasterHandler(self)
@ -694,13 +694,13 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_server_notices_manager(self) -> ServerNoticesManager:
if self.config.worker_app:
if self.config.worker.worker_app:
raise Exception("Workers cannot send server notices")
return ServerNoticesManager(self)
@cache_in_self
def get_server_notices_sender(self) -> WorkerServerNoticesSender:
if self.config.worker_app:
if self.config.worker.worker_app:
return WorkerServerNoticesSender(self)
return ServerNoticesSender(self)
@ -766,7 +766,9 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_federation_ratelimiter(self) -> FederationRateLimiter:
return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
return FederationRateLimiter(
self.get_clock(), config=self.config.ratelimiting.rc_federation
)
@cache_in_self
def get_module_api(self) -> ModuleApi:


@ -271,7 +271,7 @@ class DataStore(
def get_users_paginate_txn(txn):
filters = []
args = [self.hs.config.server_name]
args = [self.hs.config.server.server_name]
# Set ordering
order_by_column = UserSortOrder(order_by).value
@ -356,13 +356,13 @@ def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig
return
user_domain = get_domain_from_id(rows[0][0])
if user_domain == config.server_name:
if user_domain == config.server.server_name:
return
raise Exception(
"Found users in database not native to %s!\n"
"You cannot change a synapse server_name after it's been configured"
% (config.server_name,)
% (config.server.server_name,)
)


@ -35,7 +35,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
super().__init__(database, db_conn, hs)
if (
hs.config.run_background_tasks
hs.config.worker.run_background_tasks
and self.hs.config.redaction_retention_period is not None
):
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)


@ -355,7 +355,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
self.user_ips_max_age = hs.config.user_ips_max_age
if hs.config.run_background_tasks and self.user_ips_max_age:
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@wrap_as_background_process("prune_old_user_ips")


@ -51,7 +51,7 @@ class DeviceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)


@ -62,7 +62,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)


@ -82,7 +82,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3
self._rotate_count = 10000
self._doing_notif_rotation = False
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
self._rotate_notif_loop = self._clock.looping_call(
self._rotate_notifs, 30 * 60 * 1000
)


@ -158,7 +158,7 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1
)
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
self._cleanup_old_transaction_ids,


@ -56,7 +56,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
# Used in _generate_user_daily_visits to keep track of progress


@ -132,14 +132,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
hs.config.account_validity.account_validity_startup_job_max_delta
)
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
self._clock.call_later(
0.0,
self._set_expiration_date_when_missing,
)
# Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks:
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
)
@ -1091,6 +1091,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
delta equal to 10% of the validity period.
"""
now_ms = self._clock.time_msec()
assert self._account_validity_period is not None
expiration_ts = now_ms + self._account_validity_period
if use_delta:


@ -815,7 +815,7 @@ class RoomWorkerStore(SQLBaseStore):
If it is `None` media will be removed from quarantine
"""
logger.info("Quarantining media: %s/%s", server_name, media_id)
-is_local = server_name == self.config.server_name
+is_local = server_name == self.config.server.server_name
def _quarantine_media_by_id_txn(txn):
local_mxcs = [media_id] if is_local else []


@ -81,7 +81,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.close()
if (
-self.hs.config.run_background_tasks
+self.hs.config.worker.run_background_tasks
and self.hs.config.metrics_flags.known_servers
):
self._known_servers_count = 1
@ -196,6 +196,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) -> Dict[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for all users in a given room.
+The profile information comes directly from this room's `m.room.member`
+events, and so may be specific to this room rather than part of a user's
+global profile. To avoid privacy leaks, the profile data should only be
+revealed to users who are already in this room.
Args:
room_id: The ID of the room to retrieve the users of.
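
The new docstring spells out the privacy rule: per-room member profiles should only be revealed to users already in that room. A hedged sketch of that rule in isolation, using an invented RoomProfile stand-in rather than Synapse's ProfileInfo:

from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class RoomProfile:  # invented stand-in for the per-room profile data
    display_name: Optional[str]
    avatar_url: Optional[str]

def profiles_visible_to(
    requester: str, room_members: Dict[str, RoomProfile]
) -> Dict[str, RoomProfile]:
    # Per-room profile data is only revealed to someone already in the room.
    if requester not in room_members:
        return {}
    return dict(room_members)

members = {"@alice:example.com": RoomProfile("Alice", None)}
assert profiles_visible_to("@bob:example.com", members) == {}
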


@ -48,7 +48,7 @@ class SessionStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
# Create a background job for culling expired sessions.
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
async def create_session(


@ -672,7 +672,7 @@ class StatsStore(StateDeltasStore):
def get_users_media_usage_paginate_txn(txn):
filters = []
-args = [self.hs.config.server_name]
+args = [self.hs.config.server.server_name]
if search_term:
filters.append("(lmr.user_id LIKE ? OR displayname LIKE ?)")


@ -60,7 +60,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
@wrap_as_background_process("cleanup_transactions")


@ -196,7 +196,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
-user_ids = set(users_with_profile)
# Update each user in the user directory.
for user_id, profile in users_with_profile.items():
@ -207,7 +206,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
to_insert = set()
if is_public:
-for user_id in user_ids:
+for user_id in users_with_profile:
if self.get_if_app_services_interested_in_user(user_id):
continue
@ -217,14 +216,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
await self.add_users_in_public_rooms(room_id, to_insert)
to_insert.clear()
else:
-for user_id in user_ids:
+for user_id in users_with_profile:
if not self.hs.is_mine_id(user_id):
continue
if self.get_if_app_services_interested_in_user(user_id):
continue
-for other_user_id in user_ids:
+for other_user_id in users_with_profile:
if user_id == other_user_id:
continue
@ -511,7 +510,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
self._prefer_local_users_in_search = (
hs.config.user_directory_search_prefer_local_users
)
-self._server_name = hs.config.server_name
+self._server_name = hs.config.server.server_name
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
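
The loops above now iterate users_with_profile directly when recording which local users share a private room. A rough sketch of that pairing logic with invented helper callables (is_mine_id and appservice_interested below are plain parameters, not Synapse's methods):

from typing import Callable, Dict, Set, Tuple

def private_room_pairs(
    users_with_profile: Dict[str, object],
    is_mine_id: Callable[[str], bool],
    appservice_interested: Callable[[str], bool],
) -> Set[Tuple[str, str]]:
    pairs: Set[Tuple[str, str]] = set()
    for user_id in users_with_profile:
        # Only pair up local users that no application service claims.
        if not is_mine_id(user_id) or appservice_interested(user_id):
            continue
        for other_user_id in users_with_profile:
            if user_id != other_user_id:
                pairs.add((user_id, other_user_id))
    return pairs

members = {"@a:local": {}, "@b:remote": {}}
print(private_room_pairs(members, lambda u: u.endswith(":local"), lambda u: False))
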


@ -134,7 +134,7 @@ def prepare_database(
# if it's a worker app, refuse to upgrade the database, to avoid multiple
# workers doing it at once.
if (
-config.worker_app is not None
+config.worker.worker_app is not None
and version_info.current_version != SCHEMA_VERSION
):
raise UpgradeDatabaseException(
@ -154,7 +154,7 @@ def prepare_database(
# if it's a worker app, refuse to upgrade the database, to avoid multiple
# workers doing it at once.
-if config and config.worker_app is not None:
+if config and config.worker.worker_app is not None:
raise UpgradeDatabaseException(EMPTY_DATABASE_ON_WORKER_ERROR)
_setup_new_database(cur, database_engine, databases=databases)
@ -355,7 +355,7 @@ def _upgrade_existing_database(
else:
assert config
-is_worker = config and config.worker_app is not None
+is_worker = config and config.worker.worker_app is not None
if (
current_schema_state.compat_version is not None


@ -38,7 +38,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
logger.warning("Could not get app_service_config_files from config")
pass
-appservices = load_appservices(config.server_name, config_files)
+appservices = load_appservices(config.server.server_name, config_files)
owned = {}


@ -67,7 +67,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
INNER JOIN room_memberships AS r USING (event_id)
WHERE type = 'm.room.member' AND state_key LIKE ?
"""
-cur.execute(sql, ("%:" + config.server_name,))
+cur.execute(sql, ("%:" + config.server.server_name,))
cur.execute(
"CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"


@ -38,6 +38,7 @@ from twisted.internet.interfaces import (
IReactorCore,
IReactorPluggableNameResolver,
IReactorTCP,
+IReactorThreads,
IReactorTime,
)
@ -63,7 +64,12 @@ JsonDict = Dict[str, Any]
# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
-IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
+IReactorTCP,
+IReactorPluggableNameResolver,
+IReactorTime,
+IReactorCore,
+IReactorThreads,
+Interface,
):
"""The interfaces necessary for Synapse to function."""


@ -15,27 +15,35 @@
import json
import logging
import re
-from typing import Pattern
+import typing
+from typing import Any, Callable, Dict, Generator, Pattern
import attr
from frozendict import frozendict
from twisted.internet import defer, task
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IDelayedCall, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.python.failure import Failure
from synapse.logging import context
if typing.TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
_WILDCARD_RUN = re.compile(r"([\?\*]+)")
-def _reject_invalid_json(val):
+def _reject_invalid_json(val: Any) -> None:
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise ValueError("Invalid JSON value: '%s'" % val)
-def _handle_frozendict(obj):
+def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
"""Helper for json_encoder. Makes frozendicts serializable by returning
the underlying dict
"""
@ -60,10 +68,10 @@ json_encoder = json.JSONEncoder(
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
-def unwrapFirstError(failure):
+def unwrapFirstError(failure: Failure) -> Failure:
# defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError)
-return failure.value.subFailure
+return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
@attr.s(slots=True)
@ -75,25 +83,25 @@ class Clock:
reactor: The Twisted reactor to use.
"""
-_reactor = attr.ib()
+_reactor: IReactorTime = attr.ib()
-@defer.inlineCallbacks
-def sleep(self, seconds):
-d = defer.Deferred()
+@defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations
+def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
+d: defer.Deferred[float] = defer.Deferred()
with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
return res
-def time(self):
+def time(self) -> float:
"""Returns the current system time in seconds since epoch."""
return self._reactor.seconds()
-def time_msec(self):
+def time_msec(self) -> int:
"""Returns the current system time in milliseconds since epoch."""
return int(self.time() * 1000)
-def looping_call(self, f, msec, *args, **kwargs):
+def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall:
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
@ -102,8 +110,8 @@ class Clock:
other than trivial, you probably want to wrap it in run_as_background_process.
Args:
-f(function): The function to call repeatedly.
-msec(float): How long to wait between calls in milliseconds.
+f: The function to call repeatedly.
+msec: How long to wait between calls in milliseconds.
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
@ -113,7 +121,7 @@ class Clock:
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call
-def call_later(self, delay, callback, *args, **kwargs):
+def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall:
"""Call something later
Note that the function will be called with no logcontext, so if it is anything
@ -133,7 +141,7 @@ class Clock:
with context.PreserveLoggingContext():
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
-def cancel_call_later(self, timer, ignore_errs=False):
+def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
try:
timer.cancel()
except Exception:
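
Clock is a thin wrapper over the Twisted reactor, so its looping_call and call_later map onto task.LoopingCall and reactor.callLater. A hedged, standalone Twisted equivalent of scheduling a one-second loop and a delayed stop (this does not use Synapse's Clock itself):

from twisted.internet import reactor, task

def tick():
    print("tick")

# Roughly what Clock.looping_call(tick, 1000) sets up: a LoopingCall that
# waits one second before the first call and then fires every second.
loop = task.LoopingCall(tick)
loop.start(1.0, now=False)

# Roughly Clock.call_later(5, reactor.stop): stop the reactor after five seconds.
reactor.callLater(5, reactor.stop)
reactor.run()
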


@ -37,6 +37,7 @@ import attr
from typing_extensions import ContextManager
from twisted.internet import defer
+from twisted.internet.base import ReactorBase
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
@ -268,6 +269,7 @@ class Linearizer:
if not clock:
from twisted.internet import reactor
+assert isinstance(reactor, ReactorBase)
clock = Clock(reactor)
self._clock = clock
self.max_count = max_count
@ -411,7 +413,7 @@ class ReadWriteLock:
# writers and readers have been resolved. The new writer replaces the latest
# writer.
-def __init__(self):
+def __init__(self) -> None:
# Latest readers queued
self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
@ -503,7 +505,7 @@ def timeout_deferred(
timed_out = [False]
-def time_it_out():
+def time_it_out() -> None:
timed_out[0] = True
try:
@ -550,19 +552,21 @@ def timeout_deferred(
return new_d
+# This class can't be generic because it uses slots with attrs.
+# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True)
-class DoneAwaitable:
+class DoneAwaitable: # should be: Generic[R]
"""Simple awaitable that returns the provided value."""
-value = attr.ib()
+value = attr.ib(type=Any) # should be: R
def __await__(self):
return self
-def __iter__(self):
+def __iter__(self) -> "DoneAwaitable":
return self
-def __next__(self):
+def __next__(self) -> None:
raise StopIteration(self.value)
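
DoneAwaitable makes an already-known value awaitable by acting as its own exhausted iterator. A minimal standalone demonstration of the same trick (the Done class below is an invented stand-in, not Synapse's implementation):

import asyncio

class Done:
    def __init__(self, value):
        self.value = value

    def __await__(self):
        return self

    def __iter__(self):
        return self

    def __next__(self):
        # Ending iteration immediately hands the value straight back to `await`.
        raise StopIteration(self.value)

async def main():
    print(await Done(42))  # prints 42 without ever yielding to the event loop

asyncio.run(main())
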


@ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]):
# First we create a defer and add it and the value to the list of
# pending items.
-d = defer.Deferred()
+d: defer.Deferred[R] = defer.Deferred()
self._next_values.setdefault(key, []).append((value, d))
# If we're not currently processing the key fire off a background


@ -64,32 +64,32 @@ class CacheMetric:
evicted_size = attr.ib(default=0)
memory_usage = attr.ib(default=None)
-def inc_hits(self):
+def inc_hits(self) -> None:
self.hits += 1
-def inc_misses(self):
+def inc_misses(self) -> None:
self.misses += 1
-def inc_evictions(self, size=1):
+def inc_evictions(self, size: int = 1) -> None:
self.evicted_size += size
-def inc_memory_usage(self, memory: int):
+def inc_memory_usage(self, memory: int) -> None:
if self.memory_usage is None:
self.memory_usage = 0
self.memory_usage += memory
-def dec_memory_usage(self, memory: int):
+def dec_memory_usage(self, memory: int) -> None:
self.memory_usage -= memory
-def clear_memory_usage(self):
+def clear_memory_usage(self) -> None:
if self.memory_usage is not None:
self.memory_usage = 0
def describe(self):
return []
-def collect(self):
+def collect(self) -> None:
try:
if self._cache_type == "response_cache":
response_cache_size.labels(self._cache_name).set(len(self._cache))


@ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]):
TreeCache, "MutableMapping[KT, CacheEntry]"
] = cache_type()
-def metrics_cb():
+def metrics_cb() -> None:
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
# cache is used for completed results and maps to the result itself, rather than
@ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
def max_entries(self):
return self.cache.max_size
-def check_thread(self):
+def check_thread(self) -> None:
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
@ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]):
self._pending_deferred_cache[key] = entry
-def compare_and_pop():
+def compare_and_pop() -> bool:
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
@ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]):
return False
-def cb(result):
+def cb(result) -> None:
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
@ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]):
# not have been. Either way, let's double-check now.
entry.invalidate()
-def eb(_fail):
+def eb(_fail) -> None:
compare_and_pop()
entry.invalidate()
@ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]):
for entry in iterate_tree_cache_entry(entry):
entry.invalidate()
-def invalidate_all(self):
+def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
@ -332,7 +332,7 @@ class CacheEntry:
self.callbacks = set(callbacks)
self.invalidated = False
-def invalidate(self):
+def invalidate(self) -> None:
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:


@ -27,10 +27,14 @@ logger = logging.getLogger(__name__)
KT = TypeVar("KT")
# The type of the dictionary keys.
DKT = TypeVar("DKT")
# The type of the dictionary values.
DV = TypeVar("DV")
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True)
-class DictionaryEntry:
+class DictionaryEntry: # should be: Generic[DKT, DV].
"""Returned when getting an entry from the cache
Attributes:
@ -43,10 +47,10 @@ class DictionaryEntry:
"""
full = attr.ib(type=bool)
-known_absent = attr.ib()
-value = attr.ib()
+known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT]
+value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV]
-def __len__(self):
+def __len__(self) -> int:
return len(self.value)
@ -56,7 +60,7 @@ class _Sentinel(enum.Enum):
sentinel = object()
-class DictionaryCache(Generic[KT, DKT]):
+class DictionaryCache(Generic[KT, DKT, DV]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
@ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]):
Args:
key
-dict_key: If given a set of keys then return only those keys
+dict_keys: If given a set of keys then return only those keys
that exist in the cache.
Returns:
@ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]):
self,
sequence: int,
key: KT,
-value: Dict[DKT, Any],
+value: Dict[DKT, DV],
fetched_keys: Optional[Set[DKT]] = None,
) -> None:
"""Updates the entry in the cache
@ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]):
self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(
-self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
+self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
-entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
+entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
entry.value.update(value)
entry.known_absent.update(known_absent)
self.cache[key] = entry
-def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
+def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
self.cache[key] = DictionaryEntry(True, known_absent, value)
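
DictionaryCache stores partial dictionaries: an entry tracks which sub-keys were fetched and which are known to be absent, so a cache miss can be distinguished from "never asked". A hedged sketch of that idea with an invented PartialEntry class and lookup helper (not Synapse's actual classes):

from dataclasses import dataclass, field
from typing import Any, Dict, Set

@dataclass
class PartialEntry:
    full: bool = False
    known_absent: Set[str] = field(default_factory=set)
    value: Dict[str, Any] = field(default_factory=dict)

entry = PartialEntry()
entry.value.update({"m.room.name": "A room"})
entry.known_absent.add("m.room.topic")  # we looked, and it isn't there

def lookup(entry: PartialEntry, dict_key: str):
    if dict_key in entry.value:
        return entry.value[dict_key]
    if entry.full or dict_key in entry.known_absent:
        return None           # definitely absent
    raise KeyError(dict_key)  # unknown: the caller must hit the database
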


@ -35,6 +35,7 @@ from typing import (
from typing_extensions import Literal
from twisted.internet import reactor
+from twisted.internet.interfaces import IReactorTime
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]):
# Default `clock` to something sensible. Note that we rename it to
# `real_clock` so that mypy doesn't think its still `Optional`.
if clock is None:
-real_clock = Clock(reactor)
+real_clock = Clock(cast(IReactorTime, reactor))
else:
real_clock = clock
@ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]):
lock = threading.Lock()
-def evict():
+def evict() -> None:
while cache_len() > self.max_size:
# Get the last node in the list (i.e. the oldest node).
todelete = list_root.prev_node


@ -195,7 +195,7 @@ class StreamChangeCache:
for entity in r:
del self._entity_to_key[entity]
-def _evict(self):
+def _evict(self) -> None:
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)


@ -35,17 +35,17 @@ class TreeCache:
root = {key_1: {key_2: _value}}
"""
-def __init__(self):
-self.size = 0
+def __init__(self) -> None:
+self.size: int = 0
self.root = TreeCacheNode()
-def __setitem__(self, key, value):
-return self.set(key, value)
+def __setitem__(self, key, value) -> None:
+self.set(key, value)
-def __contains__(self, key):
+def __contains__(self, key) -> bool:
return self.get(key, SENTINEL) is not SENTINEL
-def set(self, key, value):
+def set(self, key, value) -> None:
if isinstance(value, TreeCacheNode):
# this would mean we couldn't tell where our tree ended and the value
# started.
@ -73,7 +73,7 @@ class TreeCache:
return default
return node.get(key[-1], default)
-def clear(self):
+def clear(self) -> None:
self.size = 0
self.root = TreeCacheNode()
@ -128,7 +128,7 @@ class TreeCache:
def values(self):
return iterate_tree_cache_entry(self.root)
-def __len__(self):
+def __len__(self) -> int:
return self.size
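
As the docstring above describes, TreeCache keys are tuples stored as nested dictionaries, which makes it cheap to drop everything under a key prefix. A small standalone sketch of that layout (tree_set is an invented helper, not TreeCache's API):

from typing import Any, Dict, Tuple

def tree_set(root: Dict, key: Tuple, value: Any) -> None:
    # Walk/create nested dicts for all but the last key part.
    node = root
    for part in key[:-1]:
        node = node.setdefault(part, {})
    node[key[-1]] = value

root: Dict = {}
tree_set(root, ("room1", "alice"), 1)
tree_set(root, ("room1", "bob"), 2)
tree_set(root, ("room2", "carol"), 3)

# Invalidate everything under the "room1" prefix in one operation.
del root["room1"]
assert root == {"room2": {"carol": 3}}
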


@ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
signal.signal(signal.SIGTERM, sigterm)
# Cleanup pid file at exit.
-def exit():
+def exit() -> None:
logger.warning("Stopping daemon.")
os.remove(pid_file)
sys.exit(0)


@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Any, Callable, Dict, List
from twisted.internet import defer
@ -37,11 +38,11 @@ class Distributor:
model will do for today.
"""
-def __init__(self):
-self.signals = {}
-self.pre_registration = {}
+def __init__(self) -> None:
+self.signals: Dict[str, Signal] = {}
+self.pre_registration: Dict[str, List[Callable]] = {}
-def declare(self, name):
+def declare(self, name: str) -> None:
if name in self.signals:
raise KeyError("%r already has a signal named %s" % (self, name))
@ -52,7 +53,7 @@ class Distributor:
for observer in self.pre_registration[name]:
signal.observe(observer)
-def observe(self, name, observer):
+def observe(self, name: str, observer: Callable) -> None:
if name in self.signals:
self.signals[name].observe(observer)
else:
@ -62,7 +63,7 @@ class Distributor:
self.pre_registration[name] = []
self.pre_registration[name].append(observer)
-def fire(self, name, *args, **kwargs):
+def fire(self, name: str, *args, **kwargs) -> None:
"""Dispatches the given signal to the registered observers.
Runs the observers as a background process. Does not return a deferred.
@ -83,18 +84,18 @@ class Signal:
method into all of the observers.
"""
-def __init__(self, name):
-self.name = name
-self.observers = []
+def __init__(self, name: str):
+self.name: str = name
+self.observers: List[Callable] = []
-def observe(self, observer):
+def observe(self, observer: Callable) -> None:
"""Adds a new callable to the observer list which will be invoked by
the 'fire' method.
Each observer callable may return a Deferred."""
self.observers.append(observer)
-def fire(self, *args, **kwargs):
+def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers.
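
Distributor and Signal implement a simple observer pattern: observers register against a named signal, and fire invokes them all. A hedged, stripped-down sketch that drops the Deferred and background-process handling of the real classes (MiniDistributor is invented for illustration):

from typing import Callable, Dict, List

class MiniDistributor:
    def __init__(self) -> None:
        self.signals: Dict[str, List[Callable]] = {}

    def declare(self, name: str) -> None:
        self.signals.setdefault(name, [])

    def observe(self, name: str, observer: Callable) -> None:
        self.signals.setdefault(name, []).append(observer)

    def fire(self, name: str, *args, **kwargs) -> None:
        # The real code logs and ignores observer exceptions; here they propagate.
        for observer in self.signals.get(name, []):
            observer(*args, **kwargs)

d = MiniDistributor()
d.declare("user_logged_in")
d.observe("user_logged_in", lambda user_id: print("hello", user_id))
d.fire("user_logged_in", "@alice:example.com")
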

Some files were not shown because too many files have changed in this diff.