Compare commits

...

79 commits

Author SHA1 Message Date
Erik Johnston b044860e56 Merge branch 'erikj/smaller_events' into erikj/test_send 2021-05-05 16:46:41 +01:00
Erik Johnston 2e1a8878d5 Fix default 2021-05-05 16:46:33 +01:00
Erik Johnston a7b7770bef Merge branch 'erikj/smaller_events' into erikj/test_send 2021-05-05 16:41:14 +01:00
Erik Johnston 5d9bbca631 Make origin optional 2021-05-05 16:41:10 +01:00
Erik Johnston 015fdfe5bb Merge branch 'erikj/smaller_events' into erikj/test_send 2021-05-05 16:37:37 +01:00
Erik Johnston faa7d48930 More ensmalling 2021-05-05 16:35:16 +01:00
Erik Johnston f4bb01d41a Compress 2021-05-05 15:06:47 +01:00
Erik Johnston ec1c2c69a2 Encode json dict 2021-05-05 15:06:46 +01:00
Erik Johnston 7e7b99bca9 Slots 2021-05-05 15:06:28 +01:00
Erik Johnston 3d937d23fd Don't use DictProperty 2021-05-05 15:06:27 +01:00
Erik Johnston c856e29ccd Remove dictionary based access from events 2021-05-05 13:52:29 +01:00
Erik Johnston 941a0a76d3 Fix log contexts 2021-05-05 11:48:09 +01:00
Erik Johnston b0d014819f Fix log contexts 2021-05-05 11:36:24 +01:00
Erik Johnston eeafa29399 Merge branch 'erikj/better_backfill' into erikj/test_send 2021-05-05 11:29:29 +01:00
Erik Johnston 016d55b94b Use less memory when backfilling 2021-05-05 11:25:27 +01:00
Erik Johnston adef51ab98 Fix cache metrics 2021-05-05 10:33:23 +01:00
Erik Johnston cdeb6050ea Log contexts 2021-05-05 10:25:54 +01:00
Erik Johnston 88bd909a4a Merge branch 'erikj/jemalloc_stats' into erikj/test_send 2021-05-05 10:20:36 +01:00
Erik Johnston a94edad23b Merge branch 'erikj/cache_mem_size' into erikj/test_send 2021-05-05 10:14:07 +01:00
Erik Johnston fc17e4e62e fix logging contexts 2021-05-04 18:09:03 +01:00
Erik Johnston 0d9d84dac0 fix logging contexts 2021-05-04 18:06:23 +01:00
Erik Johnston 04db7b9581 Merge branch 'erikj/limit_how_often_gc' into erikj/test_send 2021-05-04 17:59:35 +01:00
Erik Johnston 24965fc073 Merge branch 'erikj/efficient_presence_join' into erikj/test_send 2021-05-04 17:59:31 +01:00
Erik Johnston 5b031e2da3 Merge branch 'erikj/fix_presence_joined' into erikj/test_send 2021-05-04 17:59:16 +01:00
Erik Johnston 14b70bbd9f Merge branch 'erikj/refactor_keyring' into erikj/test_send 2021-05-04 17:57:57 +01:00
Erik Johnston d4175abe52 Allow fetching events 2021-05-04 17:57:46 +01:00
Erik Johnston b76fe71627 Fix remote resource 2021-05-04 17:57:46 +01:00
Erik Johnston 7f237d5639 Remove key_ready 2021-05-04 17:57:46 +01:00
Erik Johnston f37c5843d3 Merge branch 'erikj/refactor_keyring' into erikj/test_send 2021-05-04 17:37:22 +01:00
Erik Johnston d6ae1aef46 Merge remote-tracking branch 'origin/develop' into erikj/test_send 2021-05-04 17:37:15 +01:00
Erik Johnston 3bfd3c55f9 Refactor keyring 2021-05-04 17:36:02 +01:00
Erik Johnston cad5a47621 Bugfix newsfile 2021-05-04 15:54:39 +01:00
Erik Johnston 7e3d333b28 Move newsfile 2021-05-04 14:51:28 +01:00
Erik Johnston aabc46f0f6 Merge remote-tracking branch 'origin/develop' into erikj/cache_mem_size 2021-05-04 14:49:16 +01:00
Erik Johnston 78e3502ada Always report memory usage metrics when TRACK_MEMORY_USAGE is True 2021-05-04 14:32:42 +01:00
Erik Johnston 8206069c63 Comment 2021-05-04 14:29:30 +01:00
Erik Johnston a99524f383
Apply suggestions from code review
Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
2021-05-04 14:29:26 +01:00
Erik Johnston b5169b68e9 Document default. Add type annotations. Correctly convert to seconds 2021-05-04 14:23:02 +01:00
Erik Johnston 4c9446c4cb isort 2021-05-04 14:14:08 +01:00
Erik Johnston 4a8a483060 Fix store.get_users_in_room_with_profiles 2021-05-04 14:09:44 +01:00
Erik Johnston d145ba6ccc Move jemalloc metrics to a separate file, and load from app to get proper logs 2021-05-04 14:00:12 +01:00
Erik Johnston dcb79da38a More descriptive log when failing to set up jemalloc collector 2021-05-04 13:40:38 +01:00
Erik Johnston 35c13c730c
Apply suggestions from code review
Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
2021-05-04 13:39:44 +01:00
Erik Johnston 8624333cd9 Correctly invalidate get_users_in_room_with_profiles cache 2021-05-04 13:27:28 +01:00
Erik Johnston 4caa84b279 Use lists instead of sets where appropriate 2021-05-04 13:16:15 +01:00
Erik Johnston 48cf260c7a Process state deltas in presence by room 2021-05-04 13:16:02 +01:00
Erik Johnston 7e5f78a698 Convert other uses of get_current_users_in_room and add warning 2021-05-04 13:02:18 +01:00
Erik Johnston 43c9acda4c Config 2021-05-04 11:49:13 +01:00
Erik Johnston bd04fb6308 Code review 2021-05-04 10:47:32 +01:00
Erik Johnston d3a6e38c96
Apply suggestions from code review
Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
2021-05-04 10:40:18 +01:00
Erik Johnston aa1a026509 Stuff 2021-04-30 15:19:31 +01:00
Erik Johnston 260c760d69 Don't log response 2021-04-30 15:18:17 +01:00
Erik Johnston 49da5e9ec4 Chunk _check_sigs_and_hash_and_fetch 2021-04-30 15:17:50 +01:00
Erik Johnston 3b2991e3fb Log memory usage 2021-04-30 15:00:22 +01:00
Erik Johnston aec80899ab Merge branch 'erikj/stream_deserealize' into erikj/test_send 2021-04-30 14:21:58 +01:00
Erik Johnston 68f1d258d9 Use ijson 2021-04-30 14:19:27 +01:00
Erik Johnston 8481bacc93 Merge branch 'erikj/efficient_presence_join' into erikj/test_send 2021-04-30 13:37:08 +01:00
Erik Johnston 68b6106ce5 Newsfile 2021-04-30 13:36:50 +01:00
Erik Johnston 0ed608cf56 Increase perf of handling presence when joining large rooms.
We ended up doing a *lot* of duplicate work, e.g. n^2
worth of `is_mine_id(..)` checks across all joined users.
2021-04-30 13:34:15 +01:00
Erik Johnston ac0143c4ac Record size of incoming bytes 2021-04-30 10:24:19 +01:00
Erik Johnston f5a25c7b53 Merge branch 'erikj/limit_how_often_gc' into erikj/test_send 2021-04-30 10:14:23 +01:00
Erik Johnston 5813719696 Merge branch 'erikj/fix_presence_joined' into erikj/test_send 2021-04-30 10:14:18 +01:00
Erik Johnston 6640fb467f Use correct name 2021-04-29 17:44:54 +01:00
Erik Johnston 0c8cd62149 Newsfile 2021-04-29 17:36:34 +01:00
Erik Johnston 996c0ce3d5 Use get_current_users_in_room from store and not StateHandler 2021-04-29 17:35:47 +01:00
Erik Johnston 938efeb595 Add some logging 2021-04-29 16:41:16 +01:00
Erik Johnston 4a3a9597f5 Merge remote-tracking branch 'origin/develop' into erikj/test_send 2021-04-29 16:41:04 +01:00
Erik Johnston 351f886bc8 Newsfile 2021-04-28 14:55:46 +01:00
Erik Johnston 79627b3a3c Limit how often GC happens by time.
Synapse can be quite memory intensive, and unless care is taken to tune
the GC thresholds, it can end up thrashing, causing noticeable performance
problems for large servers. We fix this by limiting how often we GC a
given generation, regardless of current counts/thresholds.

This does not help with the reverse problem where the thresholds are set
too high, but that should only happen in situations where they've been
manually configured.

Adds a `gc_min_seconds_between` config option to override the defaults.

Fixes #9890.
2021-04-28 14:51:31 +01:00
Erik Johnston 6237096e80 Newsfile 2021-04-26 14:23:15 +01:00
Erik Johnston 1b4ec8ef0e Export jemalloc stats to prometheus when used 2021-04-26 14:18:06 +01:00
Erik Johnston 5add13e05d Newsfile 2021-04-26 11:13:08 +01:00
Erik Johnston 2bf93f9b34 Fix 2021-04-26 10:58:04 +01:00
Erik Johnston bcf8858b67 Don't explode if memory has been twiddled 2021-04-26 10:56:42 +01:00
Erik Johnston 99fb72e63e Move TRACK_MEMORY_USAGE to root 2021-04-26 10:50:15 +01:00
Erik Johnston 567fe5e387 Make TRACK_MEMORY_USAGE configurable 2021-04-26 10:39:54 +01:00
Erik Johnston 0c9bab290f Ignore singletons 2021-04-26 10:29:26 +01:00
Erik Johnston 5003bd29d2 Don't have a global Asizer 2021-04-23 17:16:49 +01:00
Erik Johnston e9f5812eff Track memory usage of caches 2021-04-23 16:26:10 +01:00
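
The commit "Limit how often GC happens by time" above describes skipping a collection for a generation if one ran too recently, regardless of the allocation counters. A minimal standalone sketch of the idea, assuming illustrative defaults of [1s, 10s, 30s]; this is not Synapse's actual implementation, which hooks the check into its reactor metrics loop and reads the interval from the new config option:

import gc
import time

MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)  # seconds per generation (illustrative)
_last_gc = [0.0, 0.0, 0.0]

def maybe_collect() -> None:
    # Only collect a generation when its counter exceeds the threshold *and*
    # enough wall-clock time has passed since we last collected it.
    thresholds = gc.get_threshold()
    counts = gc.get_count()
    now = time.time()
    # Oldest generation first: collecting gen N also collects younger gens.
    for gen in (2, 1, 0):
        if thresholds[gen] and counts[gen] > thresholds[gen]:
            if now - _last_gc[gen] < MIN_TIME_BETWEEN_GCS[gen]:
                continue  # collected this generation too recently
            gc.collect(gen)
            _last_gc[gen] = now
            break
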
43 changed files with 1035 additions and 690 deletions

1
changelog.d/9881.feature Normal file
View file

@ -0,0 +1 @@
Add experimental option to track memory usage of the caches.

1
changelog.d/9882.misc Normal file
View file

@ -0,0 +1 @@
Export jemalloc stats to Prometheus if it is being used.

1
changelog.d/9902.feature Normal file
View file

@ -0,0 +1 @@
Add limits to how often Synapse will GC, ensuring that large servers do not end up GC thrashing if `gc_thresholds` has not been correctly set.

1
changelog.d/9910.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where user directory could get out of sync if room visibility and membership changed in quick succession.

1
changelog.d/9910.feature Normal file
View file

@ -0,0 +1 @@
Improve performance after joining a large room when presence is enabled.

1
changelog.d/9916.misc Normal file
View file

@ -0,0 +1 @@
Improve performance of handling presence when joining large rooms.
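
The presence change above ("Increase perf of handling presence when joining large rooms" / "Process state deltas in presence by room") groups incoming membership deltas by room so that per-user checks such as `is_mine_id(..)` run once per distinct user instead of once per delta. A rough standalone sketch of that grouping; the input shape here is hypothetical:

from collections import defaultdict
from typing import Callable, Dict, Iterable, Set

def local_users_by_room(
    deltas: Iterable[dict], is_mine_id: Callable[[str], bool]
) -> Dict[str, Set[str]]:
    # First bucket the deltas by room and deduplicate users...
    users_by_room: Dict[str, Set[str]] = defaultdict(set)
    for delta in deltas:
        users_by_room[delta["room_id"]].add(delta["state_key"])

    # ...then do the expensive per-user check once per distinct user.
    return {
        room_id: {user_id for user_id in users if is_mine_id(user_id)}
        for room_id, users in users_by_room.items()
    }
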

View file

@ -152,6 +152,16 @@ presence:
#
#gc_thresholds: [700, 10, 10]
# The minimum time in seconds between each GC for a generation, regardless of
# the GC thresholds. This ensures that we don't do GC too frequently.
#
# A value of `[1s, 10s, 30s]` indicates that a second must pass between consecutive
# generation 0 GCs, etc.
#
# Defaults to `[1s, 10s, 30s]`.
#
#gc_min_interval: [0.5s, 30s, 1m]
# Set the limit on the returned events in the timeline in the get
# and sync operations. The default value is 100. -1 means no upper limit.
#

View file

@ -171,3 +171,6 @@ ignore_missing_imports = True
[mypy-txacme.*]
ignore_missing_imports = True
[mypy-pympler.*]
ignore_missing_imports = True

View file

@ -23,6 +23,7 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.types import RoomID, UserID
FILTER_SCHEMA = {
@ -290,6 +291,13 @@ class Filter:
ev_type = "m.presence"
contains_url = False
labels = [] # type: List[str]
elif isinstance(event, EventBase):
sender = event.sender
room_id = event.room_id
ev_type = event.type
content = event.content
contains_url = isinstance(content.get("url"), str)
labels = content.get(EventContentFields.LABELS, [])
else:
sender = event.get("sender", None)
if not sender:

View file

@ -37,6 +37,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
from synapse.util.async_helpers import Linearizer
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
@ -115,6 +116,7 @@ def start_reactor(
def run():
logger.info("Running")
setup_jemalloc_stats()
change_resource_limit(soft_file_limit)
if gc_thresholds:
gc.set_threshold(*gc_thresholds)

View file

@ -454,6 +454,10 @@ def start(config_options):
config.server.update_user_directory = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
if config.server.gc_seconds:
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
hs = GenericWorkerServer(
config.server_name,

View file

@ -341,6 +341,10 @@ def setup(config_options):
sys.exit(0)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
if config.server.gc_seconds:
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
hs = SynapseHomeServer(
config.server_name,

View file

@ -17,6 +17,8 @@ import re
import threading
from typing import Callable, Dict
from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
# The prefix for all cache factor-related environment variables
@ -189,6 +191,15 @@ class CacheConfig(Config):
)
self.cache_factors[cache] = factor
self.track_memory_usage = cache_config.get("track_memory_usage", False)
if self.track_memory_usage:
try:
check_requirements("cache_memory")
except DependencyException as e:
raise ConfigError(
e.message # noqa: B306, DependencyException.message is a property
)
# Resize all caches (if necessary) with the new factors we've loaded
self.resize_all_caches()

View file

@ -19,7 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
from typing import Any, Dict, Iterable, List, Optional, Set
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
import attr
import yaml
@ -572,6 +572,7 @@ class ServerConfig(Config):
_warn_if_webclient_configured(self.listeners)
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None))
@attr.s
class LimitRemoteRoomsConfig:
@ -917,6 +918,16 @@ class ServerConfig(Config):
#
#gc_thresholds: [700, 10, 10]
# The minimum time in seconds between each GC for a generation, regardless of
# the GC thresholds. This ensures that we don't do GC too frequently.
#
# A value of `[1s, 10s, 30s]` indicates that a second must pass between consecutive
# generation 0 GCs, etc.
#
# Defaults to `[1s, 10s, 30s]`.
#
#gc_min_interval: [0.5s, 30s, 1m]
# Set the limit on the returned events in the timeline in the get
# and sync operations. The default value is 100. -1 means no upper limit.
#
@ -1305,6 +1316,24 @@ class ServerConfig(Config):
help="Turn on the twisted telnet manhole service on the given port.",
)
def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]:
"""Reads the three durations for the GC min interval option, returning seconds."""
if durations is None:
return None
try:
if len(durations) != 3:
raise ValueError()
return (
self.parse_duration(durations[0]) / 1000,
self.parse_duration(durations[1]) / 1000,
self.parse_duration(durations[2]) / 1000,
)
except Exception:
raise ConfigError(
"Value of `gc_min_interval` must be a list of three durations if set"
)
def is_threepid_reserved(reserved_threepids, threepid):
"""Check the threepid against the reserved threepid config

View file

@ -48,7 +48,7 @@ def check_event_content_hash(
# some malformed events lack a 'hashes'. Protect against it being missing
# or a weird type by basically treating it the same as an unhashed event.
hashes = event.get("hashes")
hashes = getattr(event, "hashes", None)
# nb it might be a frozendict or a dict
if not isinstance(hashes, collections.abc.Mapping):
raise SynapseError(

View file

@ -16,8 +16,7 @@
import abc
import logging
import urllib
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
import attr
from signedjson.key import (
@ -42,17 +41,18 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config.key import TrustedKeyServer
from synapse.events import EventBase
from synapse.events.utils import prune_event_dict
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
preserve_fn,
run_in_background,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure
from synapse.util.async_helpers import Linearizer, yieldable_gather_results
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@ -74,8 +74,6 @@ class VerifyJsonRequest:
minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
request_name: The name of the request.
key_ids: The set of key_ids that could be used to verify the JSON object
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
@ -88,20 +86,93 @@ class VerifyJsonRequest:
"""
server_name = attr.ib(type=str)
json_object = attr.ib(type=JsonDict)
json_object_callback = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
request_name = attr.ib(type=str)
key_ids = attr.ib(init=False, type=List[str])
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
key_ids = attr.ib(type=List[str])
def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name)
@staticmethod
def from_json_object(
server_name: str, minimum_valid_until_ms: int, json_object: JsonDict
):
key_ids = signature_ids(json_object, server_name)
return VerifyJsonRequest(
server_name, lambda: json_object, minimum_valid_until_ms, key_ids
)
@staticmethod
def from_event(
server_name: str,
minimum_valid_until_ms: int,
event: EventBase,
):
key_ids = list(event.signatures.get(server_name, []))
return VerifyJsonRequest(
server_name,
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
minimum_valid_until_ms,
key_ids,
)
class KeyLookupError(ValueError):
pass
@attr.s(slots=True)
class _QueueValue:
server_name = attr.ib(type=str)
minimum_valid_until_ts = attr.ib(type=int)
key_ids = attr.ib(type=List[str])
class _Queue:
def __init__(self, name, clock, process_items):
self._name = name
self._clock = clock
self._is_processing = False
self._next_values = []
self.process_items = process_items
async def add_to_queue(self, value: _QueueValue) -> Dict[str, FetchKeyResult]:
d = defer.Deferred()
self._next_values.append((value, d))
if not self._is_processing:
run_as_background_process(self._name, self._unsafe_process)
return await make_deferred_yieldable(d)
async def _unsafe_process(self):
try:
if self._is_processing:
return
self._is_processing = True
while self._next_values:
# We purposefully defer to the next loop.
await self._clock.sleep(0)
next_values = self._next_values
self._next_values = []
try:
values = [value for value, _ in next_values]
results = await self.process_items(values)
for value, deferred in next_values:
with PreserveLoggingContext():
deferred.callback(results.get(value.server_name, {}))
except Exception as e:
for _, deferred in next_values:
deferred.errback(e)
finally:
self._is_processing = False
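
The `_Queue` above lets many callers share one batched key fetch: each caller appends its value plus a Deferred, a single background task drains the pending list, and every caller is handed the slice of the batch result keyed by its server name. A self-contained asyncio sketch of the same pattern (illustrative only; the real class uses Twisted Deferreds, `run_as_background_process` and the Synapse clock):

import asyncio
from typing import Awaitable, Callable, Dict, List, Tuple

class BatchQueue:
    """Collects values while a batch is in flight and processes them together."""

    def __init__(self, process_items: Callable[[List[str]], Awaitable[Dict[str, str]]]):
        self._process_items = process_items
        self._pending: List[Tuple[str, asyncio.Future]] = []
        self._is_processing = False

    async def add_to_queue(self, value: str) -> str:
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending.append((value, fut))
        if not self._is_processing:
            asyncio.create_task(self._process())
        return await fut

    async def _process(self) -> None:
        if self._is_processing:
            return
        self._is_processing = True
        try:
            while self._pending:
                await asyncio.sleep(0)  # let more callers join the batch
                batch, self._pending = self._pending, []
                try:
                    results = await self._process_items([v for v, _ in batch])
                    for value, fut in batch:
                        fut.set_result(results.get(value, ""))
                except Exception as e:
                    for _, fut in batch:
                        fut.set_exception(e)
        finally:
            self._is_processing = False

A caller just awaits `queue.add_to_queue("matrix.org")` and receives only its own result, while `process_items` sees the whole batch and returns a dict keyed by the queued values.
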
class Keyring:
def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
@ -116,12 +187,7 @@ class Keyring:
)
self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
# completes.
#
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {} # type: Dict[str, defer.Deferred]
self._server_queue = Linearizer("keyring_server")
def verify_json_for_server(
self,
@ -130,365 +196,150 @@ class Keyring:
validity_time: int,
request_name: str,
) -> defer.Deferred:
"""Verify that a JSON object has been signed by a given server
Args:
server_name: name of the server which must have signed this object
json_object: object to be checked
validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care)
request_name: an identifier for this json object (eg, an event id)
for logging.
Returns:
Deferred[None]: completes if the object was correctly signed, otherwise
errbacks with an error
"""
req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
requests = (req,)
return make_deferred_yieldable(self._verify_objects(requests)[0])
request = VerifyJsonRequest.from_json_object(
server_name,
validity_time,
json_object,
)
return defer.ensureDeferred(self._verify_object(request))
def verify_json_objects_for_server(
self, server_and_json: Iterable[Tuple[str, dict, int, str]]
) -> List[defer.Deferred]:
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json:
Iterable of (server_name, json_object, validity_time, request_name)
tuples.
validity_time is a timestamp at which the signing key must be
valid.
request_name is an identifier for this json object (eg, an event id)
for logging.
Returns:
List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
return self._verify_objects(
VerifyJsonRequest(server_name, json_object, validity_time, request_name)
for server_name, json_object, validity_time, request_name in server_and_json
)
def _verify_objects(
self, verify_requests: Iterable[VerifyJsonRequest]
) -> List[defer.Deferred]:
"""Does the work of verify_json_[objects_]for_server
Args:
verify_requests: Iterable of verification requests.
Returns:
List<Deferred[None]>: for each input item, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
# a list of VerifyJsonRequests which are awaiting a key lookup
key_lookups = []
handle = preserve_fn(_handle_key_deferred)
def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
"""Process an entry in the request list
Adds a key request to key_lookups, and returns a deferred which
will complete or fail (in the sentinel context) when verification completes.
"""
if not verify_request.key_ids:
return defer.fail(
SynapseError(
400,
"Not signed by %s" % (verify_request.server_name,),
Codes.UNAUTHORIZED,
)
return [
defer.ensureDeferred(
run_in_background(
self._verify_object,
VerifyJsonRequest.from_json_object(
server_name,
validity_time,
json_object,
),
)
logger.debug(
"Verifying %s for %s with key_ids %s, min_validity %i",
verify_request.request_name,
verify_request.server_name,
verify_request.key_ids,
verify_request.minimum_valid_until_ts,
)
for server_name, json_object, validity_time, request_name in server_and_json
]
# add the key request to the queue, but don't start it off yet.
key_lookups.append(verify_request)
# now run _handle_key_deferred, which will wait for the key request
# to complete and then do the verification.
#
# We want _handle_key_request to log to the right context, so we
# wrap it with preserve_fn (aka run_in_background)
return handle(verify_request)
results = [process(r) for r in verify_requests]
if key_lookups:
run_in_background(self._start_key_lookups, key_lookups)
return results
async def _start_key_lookups(
self, verify_requests: List[VerifyJsonRequest]
) -> None:
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved.
Args:
verify_requests:
"""
try:
# map from server name to a set of outstanding request ids
server_to_request_ids = {} # type: Dict[str, Set[int]]
for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
# Wait for any previous lookups to complete before proceeding.
await self.wait_for_previous_lookups(server_to_request_ids.keys())
# take out a lock on each of the servers by sticking a Deferred in
# key_downloads
for server_name in server_to_request_ids.keys():
self.key_downloads[server_name] = defer.Deferred()
logger.debug("Got key lookup lock on %s", server_name)
# When we've finished fetching all the keys for a given server_name,
# drop the lock by resolving the deferred in key_downloads.
def drop_server_lock(server_name):
d = self.key_downloads.pop(server_name)
d.callback(None)
def lookup_done(res, verify_request):
server_name = verify_request.server_name
server_requests = server_to_request_ids[server_name]
server_requests.remove(id(verify_request))
# if there are no more requests for this server, we can drop the lock.
if not server_requests:
logger.debug("Releasing key lookup lock on %s", server_name)
drop_server_lock(server_name)
return res
for verify_request in verify_requests:
verify_request.key_ready.addBoth(lookup_done, verify_request)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
except Exception:
logger.exception("Error starting key lookups")
async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
"""Waits for any previous key lookups for the given servers to finish.
Args:
server_names: list of servers which we want to look up
Returns:
Resolves once all key lookups for the given servers have
completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
for server_name in server_names
if server_name in self.key_downloads
]
if not wait_on:
break
logger.info(
"Waiting for existing lookups for %s to complete [loop %i]",
[w[0] for w in wait_on],
loop_count,
def verify_events_for_server(
self, server_and_json: Iterable[Tuple[str, EventBase, int]]
) -> List[defer.Deferred]:
return [
run_in_background(
self._verify_object,
VerifyJsonRequest.from_event(
server_name,
validity_time,
event,
),
)
with PreserveLoggingContext():
await defer.DeferredList((w[1] for w in wait_on))
for server_name, event, validity_time in server_and_json
]
loop_count += 1
async def _verify_object(self, verify_request: VerifyJsonRequest):
# TODO: Use a batching thing.
with (await self._server_queue.queue(verify_request.server_name)):
found_keys: Dict[str, FetchKeyResult] = {}
missing_key_ids = set(verify_request.key_ids)
for fetcher in self._key_fetchers:
if not missing_key_ids:
break
def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.key_ready is called back with
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
with a SynapseError if none of the keys are found.
Args:
verify_requests: list of verify requests
"""
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
async def do_iterations():
try:
with Measure(self.clock, "get_server_verify_keys"):
for f in self._key_fetchers:
if not remaining_requests:
return
await self._attempt_key_fetches_with_fetcher(
f, remaining_requests
)
# look for any requests which weren't satisfied
while remaining_requests:
verify_request = remaining_requests.pop()
rq_str = (
"VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
% (
verify_request.server_name,
verify_request.key_ids,
verify_request.minimum_valid_until_ts,
)
)
# If we run the errback immediately, it may cancel our
# loggingcontext while we are still in it, so instead we
# schedule it for the next time round the reactor.
#
# (this also ensures that we don't get a stack overflow if we
# has a massive queue of lookups waiting for this server).
self.clock.call_later(
0,
verify_request.key_ready.errback,
SynapseError(
401,
"Failed to find any key to satisfy %s" % (rq_str,),
Codes.UNAUTHORIZED,
),
)
except Exception as err:
# we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make
# sure that all of the deferreds are resolved.
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
for verify_request in remaining_requests:
if not verify_request.key_ready.called:
verify_request.key_ready.errback(err)
run_in_background(do_iterations)
async def _attempt_key_fetches_with_fetcher(
self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
fetcher: fetcher to use to fetch the keys
remaining_requests: outstanding key requests.
Any successfully-completed requests will be removed from the list.
"""
# The keys to fetch.
# server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
for verify_request in remaining_requests:
# any completed requests should already have been removed
assert not verify_request.key_ready.called
keys_for_server = missing_keys[verify_request.server_name]
for key_id in verify_request.key_ids:
# If we have several requests for the same key, then we only need to
# request that key once, but we should do so with the greatest
# min_valid_until_ts of the requests, so that we can satisfy all of
# the requests.
keys_for_server[key_id] = max(
keys_for_server.get(key_id, -1),
keys = await fetcher.get_keys(
verify_request.server_name,
list(missing_key_ids),
verify_request.minimum_valid_until_ts,
)
results = await fetcher.get_keys(missing_keys)
for key_id, key in keys.items():
if not key:
continue
completed = []
for verify_request in remaining_requests:
server_name = verify_request.server_name
if key.valid_until_ts < verify_request.minimum_valid_until_ts:
continue
existing_key = found_keys.get(key_id)
if existing_key:
if key.valid_until_ts <= existing_key.valid_until_ts:
continue
found_keys[key_id] = key
missing_key_ids.difference_update(found_keys)
if missing_key_ids:
raise SynapseError(
400,
"Missing keys for %s: %s"
% (verify_request.server_name, missing_key_ids),
Codes.UNAUTHORIZED,
)
# see if any of the keys we got this time are sufficient to
# complete this VerifyJsonRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
fetch_key_result = result_keys.get(key_id)
if not fetch_key_result:
# we didn't get a result for this key
continue
if (
fetch_key_result.valid_until_ts
< verify_request.minimum_valid_until_ts
):
# key was not valid at this point
continue
# we have a valid key for this request. If we run the callback
# immediately, it may cancel our loggingcontext while we are still in
# it, so instead we schedule it for the next time round the reactor.
#
# (this also ensures that we don't get a stack overflow if we had
# a massive queue of lookups waiting for this server).
logger.debug(
"Found key %s:%s for %s",
server_name,
key_id,
verify_request.request_name,
)
self.clock.call_later(
0,
verify_request.key_ready.callback,
(server_name, key_id, fetch_key_result.verify_key),
)
completed.append(verify_request)
break
remaining_requests.difference_update(completed)
verify_key = found_keys[key_id].verify_key
try:
json_object = verify_request.json_object_callback()
verify_signed_json(
json_object,
verify_request.server_name,
verify_key,
)
except SignatureVerifyException as e:
logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s",
verify_request.server_name,
verify_key.alg,
verify_key.version,
encode_verify_key_base64(verify_key),
str(e),
)
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s: %s"
% (
verify_request.server_name,
verify_key.alg,
verify_key.version,
str(e),
),
Codes.UNAUTHORIZED,
)
class KeyFetcher(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
def __init__(self, hs: "HomeServer"):
self._queue = _Queue(self.__class__.__name__, hs.get_clock(), self._fetch_keys)
Returns:
Map from server_name -> key_id -> FetchKeyResult
"""
raise NotImplementedError
async def get_keys(
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
return await self._queue.add_to_queue(
_QueueValue(
server_name=server_name,
key_ids=key_ids,
minimum_valid_until_ts=minimum_valid_until_ts,
)
)
@abc.abstractmethod
async def _fetch_keys(
self, keys_to_fetch: List[_QueueValue]
) -> Dict[str, Dict[str, FetchKeyResult]]:
pass
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
async def _fetch_keys(self, keys_to_fetch: List[_QueueValue]):
key_ids_to_fetch = (
(server_name, key_id)
for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys()
(queue_value.server_name, key_id)
for queue_value in keys_to_fetch
for key_id in queue_value.key_ids
)
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
@ -500,6 +351,8 @@ class StoreKeyFetcher(KeyFetcher):
class BaseV2KeyFetcher(KeyFetcher):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
self.config = hs.config
@ -607,10 +460,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
async def _fetch_keys(
self, keys_to_fetch: List[_QueueValue]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
"""see KeyFetcher._fetch_keys"""
async def get_key(key_server: TrustedKeyServer) -> Dict:
try:
@ -646,12 +499,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys
async def get_server_verify_key_v2_indirect(
self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
self, keys_to_fetch: List[_QueueValue], key_server: TrustedKeyServer
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
the keys to be fetched.
key_server: notary server to query for the keys
@ -665,7 +518,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name = key_server.server_name
logger.info(
"Requesting keys %s from notary server %s",
keys_to_fetch.items(),
keys_to_fetch,
perspective_name,
)
@ -675,11 +528,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/query",
data={
"server_keys": {
server_name: {
key_id: {"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
queue_value.server_name: {
key_id: {
"minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
}
for key_id in queue_value.key_ids
}
for server_name, server_keys in keys_to_fetch.items()
for queue_value in keys_to_fetch
}
},
)
@ -779,8 +634,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
async def _fetch_keys(
self, keys_to_fetch: List[_QueueValue]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
@ -793,8 +648,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
results = {}
async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
server_name, key_ids = key_to_fetch_item
async def get_key(key_to_fetch_item: _QueueValue) -> None:
server_name = key_to_fetch_item.server_name
key_ids = key_to_fetch_item.key_ids
try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
@ -805,7 +662,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
await yieldable_gather_results(get_key, keys_to_fetch.items())
await yieldable_gather_results(get_key, keys_to_fetch)
return results
async def get_server_verify_key_v2_direct(
@ -877,37 +734,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
keys.update(response_keys)
return keys
async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
verify_request:
Raises:
SynapseError if there was a problem performing the verification
"""
server_name = verify_request.server_name
with PreserveLoggingContext():
_, key_id, verify_key = await verify_request.key_ready
json_object = verify_request.json_object
try:
verify_signed_json(json_object, server_name, verify_key)
except SignatureVerifyException as e:
logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s",
server_name,
verify_key.alg,
verify_key.version,
encode_verify_key_base64(verify_key),
str(e),
)
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s: %s"
% (server_name, verify_key.alg, verify_key.version, str(e)),
Codes.UNAUTHORIZED,
)

View file

@ -418,7 +418,9 @@ def get_send_level(
def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
power_levels_event = _get_power_level_event(auth_events)
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
send_level = get_send_level(
event.type, getattr(event, "state_key", None), power_levels_event
)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:

View file

@ -16,12 +16,15 @@
import abc
import os
from typing import Dict, Optional, Tuple, Type
import zlib
from typing import Dict, List, Optional, Tuple, Type, Union
from unpaddedbase64 import encode_base64
import attr
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.types import JsonDict, RoomStreamToken
from synapse.util import json_decoder, json_encoder
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool
@ -37,6 +40,26 @@ from synapse.util.stringutils import strtobool
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
_PRESET_ZDICT = b"""{"auth_events":[],"prev_events":[],"type":"m.room.member",m.room.message"room_id":,"sender":,"content":{"msgtype":"m.text","body":""room_version":"creator":"depth":"prev_state":"state_key":""origin":"origin_server_ts":"hashes":{"sha256":"signatures":,"unsigned":{"age_ts":"ed25519"""
def _encode_dict(d: JsonDict) -> bytes:
json_bytes = json_encoder.encode(d).encode("utf-8")
c = zlib.compressobj(1, zdict=_PRESET_ZDICT)
result_bytes = c.compress(json_bytes)
result_bytes += c.flush()
return result_bytes
def _decode_dict(b: bytes) -> JsonDict:
d = zlib.decompressobj(zdict=_PRESET_ZDICT)
result_bytes = d.decompress(b)
result_bytes += d.flush()
return json_decoder.decode(result_bytes.decode("utf-8"))
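
The `_encode_dict`/`_decode_dict` helpers above compress each event's JSON against a preset zlib dictionary, so the key names and values that appear in nearly every event cost almost nothing per stored copy. A standalone roundtrip sketch of the technique; the dictionary below is illustrative rather than `_PRESET_ZDICT`:

import json
import zlib

# Strings that recur in almost every payload. The compressor and decompressor
# must be constructed with exactly the same dictionary bytes.
PRESET = b'{"type":"m.room.message","room_id":"","sender":"","content":{"body":""}}'

def encode(d: dict) -> bytes:
    c = zlib.compressobj(1, zdict=PRESET)
    return c.compress(json.dumps(d).encode("utf-8")) + c.flush()

def decode(b: bytes) -> dict:
    dec = zlib.decompressobj(zdict=PRESET)
    return json.loads((dec.decompress(b) + dec.flush()).decode("utf-8"))

event = {
    "type": "m.room.message",
    "room_id": "!abc:example.org",
    "sender": "@alice:example.org",
    "content": {"body": "hi"},
}
assert decode(encode(event)) == event  # lossless roundtrip

The saving only shows up across many events sharing the dictionary's substrings, which is why the preset is built from common Matrix field names.
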
class DictProperty:
"""An object property which delegates to the `_dict` within its parent object."""
@ -205,7 +228,81 @@ class _EventInternalMetadata:
return self._dict.get("redacted", False)
@attr.s(slots=True, auto_attribs=True)
class _Signatures:
_signatures_bytes: bytes
@staticmethod
def from_dict(signature_dict: JsonDict) -> "_Signatures":
return _Signatures(_encode_dict(signature_dict))
def get_dict(self) -> JsonDict:
return _decode_dict(self._signatures_bytes)
def get(self, server_name, default=None):
return self.get_dict().get(server_name, default)
def update(self, other: Union[JsonDict, "_Signatures"]):
if isinstance(other, _Signatures):
other_dict = _decode_dict(other._signatures_bytes)
else:
other_dict = other
signatures = self.get_dict()
signatures.update(other_dict)
self._signatures_bytes = _encode_dict(signatures)
class _SmallListV1(str):
__slots__ = []
def get(self):
return self.split(",")
@staticmethod
def create(event_ids):
return _SmallListV1(",".join(event_ids))
class _SmallListV2_V3(bytes):
__slots__ = []
def get(self, url_safe):
i = 0
while i * 32 < len(self):
bit = self[i * 32 : (i + 1) * 32]
i += 1
yield "$" + encode_base64(bit, urlsafe=url_safe)
@staticmethod
def create(event_ids):
return _SmallListV2_V3(
b"".join(decode_base64(event_id[1:]) for event_id in event_ids)
)
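
`_SmallListV2_V3` above stores a list of room v2/v3 event IDs as the concatenation of their raw 32-byte reference hashes and re-derives `"$" + unpadded-base64(hash)` on demand (URL-safe base64 for v3, standard base64 for v2). A standalone sketch of the v3 case using the stdlib base64 module; the helper names are hypothetical:

import base64
import hashlib
from typing import List

def _b64(b: bytes) -> str:
    # unpadded URL-safe base64, as used for room v3 event IDs
    return base64.urlsafe_b64encode(b).decode("ascii").rstrip("=")

def pack_event_ids(event_ids: List[str]) -> bytes:
    # Strip the leading "$", re-pad, and keep only the raw 32-byte hashes.
    return b"".join(
        base64.urlsafe_b64decode(eid[1:] + "=" * (-len(eid[1:]) % 4))
        for eid in event_ids
    )

def unpack_event_ids(packed: bytes) -> List[str]:
    return ["$" + _b64(packed[i : i + 32]) for i in range(0, len(packed), 32)]

ids = ["$" + _b64(hashlib.sha256(x).digest()) for x in (b"a", b"b", b"c")]
packed = pack_event_ids(ids)
assert len(packed) == 32 * len(ids)  # 32 bytes per ID instead of a 44-char string
assert unpack_event_ids(packed) == ids
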
class EventBase(metaclass=abc.ABCMeta):
__slots__ = [
"room_version",
"signatures",
"unsigned",
"rejected_reason",
"_encoded_dict",
"_auth_event_ids",
"depth",
"_content",
"_hashes",
"origin",
"origin_server_ts",
"_prev_event_ids",
"redacts",
"room_id",
"sender",
"type",
"state_key",
"internal_metadata",
]
@property
@abc.abstractmethod
def format_version(self) -> int:
@ -224,32 +321,44 @@ class EventBase(metaclass=abc.ABCMeta):
assert room_version.event_format == self.format_version
self.room_version = room_version
self.signatures = signatures
self.signatures = _Signatures.from_dict(signatures)
self.unsigned = unsigned
self.rejected_reason = rejected_reason
self._dict = event_dict
self._encoded_dict = _encode_dict(event_dict)
self.depth = event_dict["depth"]
self.origin = event_dict.get("origin")
self.origin_server_ts = event_dict["origin_server_ts"]
self.redacts = event_dict.get("redacts")
self.room_id = event_dict["room_id"]
self.sender = event_dict["sender"]
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
auth_events = DictProperty("auth_events")
depth = DictProperty("depth")
content = DictProperty("content")
hashes = DictProperty("hashes")
origin = DictProperty("origin")
origin_server_ts = DictProperty("origin_server_ts")
prev_events = DictProperty("prev_events")
redacts = DefaultDictProperty("redacts", None)
room_id = DictProperty("room_id")
sender = DictProperty("sender")
state_key = DictProperty("state_key")
type = DictProperty("type")
user_id = DictProperty("sender")
@property
def content(self) -> JsonDict:
return self.get_dict()["content"]
@property
def hashes(self) -> JsonDict:
return self.get_dict()["hashes"]
@property
def prev_events(self) -> List[str]:
return list(self._prev_events)
@property
def event_id(self) -> str:
raise NotImplementedError()
@property
def user_id(self) -> str:
return self.sender
@property
def membership(self):
return self.content["membership"]
@ -258,17 +367,13 @@ class EventBase(metaclass=abc.ABCMeta):
return hasattr(self, "state_key") and self.state_key is not None
def get_dict(self) -> JsonDict:
d = dict(self._dict)
d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
d = _decode_dict(self._encoded_dict)
d.update(
{"signatures": self.signatures.get_dict(), "unsigned": dict(self.unsigned)}
)
return d
def get(self, key, default=None):
return self._dict.get(key, default)
def get_internal_metadata_dict(self):
return self.internal_metadata.get_dict()
def get_pdu_json(self, time_now=None) -> JsonDict:
pdu_json = self.get_dict()
@ -285,41 +390,11 @@ class EventBase(metaclass=abc.ABCMeta):
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))
def __getitem__(self, field):
return self._dict[field]
def __contains__(self, field):
return field in self._dict
def items(self):
return list(self._dict.items())
def keys(self):
return self._dict.keys()
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's prev_events
"""
return [e for e, _ in self.prev_events]
def auth_event_ids(self):
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's auth_events
"""
return [e for e, _ in self.auth_events]
def freeze(self):
"""'Freeze' the event dict, so it cannot be modified by accident"""
# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)
# self._dict = freeze(self._dict)
class FrozenEvent(EventBase):
@ -355,6 +430,12 @@ class FrozenEvent(EventBase):
frozen_dict = event_dict
self._event_id = event_dict["event_id"]
self._auth_event_ids = _SmallListV1.create(
e for e, _ in event_dict["auth_events"]
)
self._prev_event_ids = _SmallListV1.create(
e for e, _ in event_dict["prev_events"]
)
super().__init__(
frozen_dict,
@ -369,18 +450,26 @@ class FrozenEvent(EventBase):
def event_id(self) -> str:
return self._event_id
def auth_event_ids(self):
return list(self._auth_event_ids.get())
def prev_event_ids(self):
return list(self._prev_event_ids.get())
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
self.get("event_id", None),
self.get("type", None),
self.get("state_key", None),
self.event_id,
self.type,
getattr(self, "state_key", None),
)
class FrozenEventV2(EventBase):
__slots__ = ["_event_id"]
format_version = EventFormatVersions.V2 # All events of this type are V2
def __init__(
@ -415,6 +504,8 @@ class FrozenEventV2(EventBase):
frozen_dict = event_dict
self._event_id = None
self._auth_event_ids = _SmallListV2_V3.create(event_dict["auth_events"])
self._prev_event_ids = _SmallListV2_V3.create(event_dict["prev_events"])
super().__init__(
frozen_dict,
@ -436,24 +527,6 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's prev_events
"""
return self.prev_events
def auth_event_ids(self):
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
list[str]: The list of event IDs of this event's auth_events
"""
return self.auth_events
def __str__(self):
return self.__repr__()
@ -461,14 +534,22 @@ class FrozenEventV2(EventBase):
return "<%s event_id=%r, type=%r, state_key=%r>" % (
self.__class__.__name__,
self.event_id,
self.get("type", None),
self.get("state_key", None),
self.type,
self.state_key if self.is_state() else None,
)
def auth_event_ids(self):
return list(self._auth_event_ids.get(False))
def prev_event_ids(self):
return list(self._prev_event_ids.get(False))
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
__slots__ = ["_event_id"]
format_version = EventFormatVersions.V3 # All events of this type are V3
@property
@ -484,6 +565,12 @@ class FrozenEventV3(FrozenEventV2):
)
return self._event_id
def auth_event_ids(self):
return list(self._auth_event_ids.get(True))
def prev_event_ids(self):
return list(self._prev_event_ids.get(True))
def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
"""Returns the python type to use to construct an Event object for the

View file

@ -38,6 +38,8 @@ class EventValidator:
if event.format_version == EventFormatVersions.V1:
EventID.from_string(event.event_id)
event_dict = event.get_dict()
required = [
"auth_events",
"content",
@ -49,7 +51,7 @@ class EventValidator:
]
for k in required:
if not hasattr(event, k):
if k not in event_dict:
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values

View file

@ -73,10 +73,10 @@ class FederationBase:
* throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel
"""
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
ctx = current_context()
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
@defer.inlineCallbacks
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
@ -90,9 +90,7 @@ class FederationBase:
# received event was probably a redacted copy (but we then use our
# *actual* redacted copy to be on the safe side.)
redacted_event = prune_event(pdu)
if set(redacted_event.keys()) == set(pdu.keys()) and set(
redacted_event.content.keys()
) == set(pdu.content.keys()):
if set(redacted_event.content.keys()) == set(pdu.content.keys()):
logger.info(
"Event %s seems to have been redacted; using our redacted "
"copy",
@ -137,11 +135,7 @@ class FederationBase:
return deferreds
class PduToCheckSig(
namedtuple(
"PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
)
):
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
@ -184,7 +178,6 @@ def _check_sigs_on_pdus(
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
@ -195,13 +188,12 @@ def _check_sigs_on_pdus(
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
more_deferreds = keyring.verify_json_objects_for_server(
more_deferreds = keyring.verify_events_for_server(
[
(
p.sender_domain,
p.redacted_pdu_json,
p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
]
@ -230,13 +222,12 @@ def _check_sigs_on_pdus(
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
more_deferreds = keyring.verify_json_objects_for_server(
more_deferreds = keyring.verify_events_for_server(
[
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id
]

View file

@ -33,6 +33,7 @@ from typing import (
)
import attr
import ijson
from prometheus_client import Counter
from twisted.internet import defer
@ -55,11 +56,16 @@ from synapse.api.room_versions import (
)
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.context import (
get_thread_resource_usage,
make_deferred_yieldable,
preserve_fn,
)
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@ -385,7 +391,6 @@ class FederationClient(FederationBase):
Returns:
A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
async def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
@ -420,6 +425,7 @@ class FederationClient(FederationBase):
return res
handle = preserve_fn(handle_check_result)
deferreds = self._check_sigs_and_hashes(room_version, pdus)
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
valid_pdus = await make_deferred_yieldable(
@ -667,19 +673,37 @@ class FederationClient(FederationBase):
async def send_request(destination) -> Dict[str, Any]:
content = await self._do_send_join(destination, pdu)
logger.debug("Got content: %s", content)
# logger.debug("Got content: %s", content.getvalue())
state = [
event_from_pdu_json(p, room_version, outlier=True)
for p in content.get("state", [])
]
# logger.info("send_join content: %d", len(content))
auth_chain = [
event_from_pdu_json(p, room_version, outlier=True)
for p in content.get("auth_chain", [])
]
content.seek(0)
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
r = get_thread_resource_usage()
logger.info("Memory before state: %s", r.ru_maxrss)
state = []
for i, p in enumerate(ijson.items(content, "state.item")):
state.append(event_from_pdu_json(p, room_version, outlier=True))
if i % 1000 == 999:
await self._clock.sleep(0)
r = get_thread_resource_usage()
logger.info("Memory after state: %s", r.ru_maxrss)
logger.info("Parsed state: %d", len(state))
content.seek(0)
auth_chain = []
for i, p in enumerate(ijson.items(content, "auth_chain.item")):
auth_chain.append(event_from_pdu_json(p, room_version, outlier=True))
if i % 1000 == 999:
await self._clock.sleep(0)
r = get_thread_resource_usage()
logger.info("Memory after: %s", r.ru_maxrss)
logger.info("Parsed auth chain: %d", len(auth_chain))
create_event = None
for e in state:
@ -704,12 +728,19 @@ class FederationClient(FederationBase):
% (create_room_version,)
)
valid_pdus = await self._check_sigs_and_hash_and_fetch(
destination,
list(pdus.values()),
outlier=True,
room_version=room_version,
)
valid_pdus = []
for chunk in batch_iter(itertools.chain(state, auth_chain), 1000):
logger.info("Handling next _check_sigs_and_hash_and_fetch chunk")
new_valid_pdus = await self._check_sigs_and_hash_and_fetch(
destination,
chunk,
outlier=True,
room_version=room_version,
)
valid_pdus.extend(new_valid_pdus)
logger.info("_check_sigs_and_hash_and_fetch done")
valid_pdus_map = {p.event_id: p for p in valid_pdus}
@ -744,6 +775,8 @@ class FederationClient(FederationBase):
% (auth_chain_create_events,)
)
logger.info("Returning from send_join")
return {
"state": signed_state,
"auth_chain": signed_auth,
@ -769,6 +802,8 @@ class FederationClient(FederationBase):
if not self._is_unknown_endpoint(e):
raise
raise NotImplementedError()
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
resp = await self.transport_layer.send_join_v1(

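The send_join change above replaces a single json decode of the (potentially enormous) response with incremental parsing via ijson, pulling one state/auth_chain event at a time and periodically sleeping so the reactor can run. A standalone sketch of the ijson calls, assuming the body is available as a seekable file-like object:

import io
import ijson

body = io.BytesIO(
    b'{"state": [{"type": "m.room.create"}, {"type": "m.room.member"}],'
    b' "auth_chain": [{"type": "m.room.power_levels"}]}'
)

# Yield each array element as it is decoded instead of materialising the
# whole response dict in memory first.
state = [item for item in ijson.items(body, "state.item")]

body.seek(0)  # rewind to stream the second top-level key
auth_chain = [item for item in ijson.items(body, "auth_chain.item")]

print(len(state), len(auth_chain))  # 2 1
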
View file

@ -244,7 +244,10 @@ class TransportLayerClient:
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json(
destination=destination, path=path, data=content
destination=destination,
path=path,
data=content,
return_string_io=True,
)
return response
@ -254,7 +257,10 @@ class TransportLayerClient:
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json(
destination=destination, path=path, data=content
destination=destination,
path=path,
data=content,
return_string_io=True,
)
return response

View file

@ -78,7 +78,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
users = await self.state.get_current_users_in_room(room_id)
users = await self.store.get_users_in_room(room_id)
servers = {get_domain_from_id(u) for u in users}
if not servers:
@ -270,7 +270,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
users = await self.state.get_current_users_in_room(room_id)
users = await self.store.get_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers)

View file

@ -103,7 +103,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = await self.state.get_current_users_in_room(
users = await self.store.get_users_in_room(
event.room_id
) # type: Iterable[str]
else:

View file

@ -552,7 +552,7 @@ class FederationHandler(BaseHandler):
destination: str,
room_id: str,
event_id: str,
) -> Tuple[List[EventBase], List[EventBase]]:
) -> List[EventBase]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@ -573,11 +573,10 @@ class FederationHandler(BaseHandler):
desired_events = set(state_event_ids + auth_event_ids)
event_map = await self._get_events_from_store_or_dest(
failed_to_fetch = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
failed_to_fetch = desired_events - event_map.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state/auth events for %s %s",
@ -585,55 +584,12 @@ class FederationHandler(BaseHandler):
failed_to_fetch,
)
event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
remote_state = [
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
return remote_state, auth_chain
async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str]
) -> Dict[str, EventBase]:
"""Fetch events from a remote destination, checking if we already have them.
Persists any events we don't already have as outliers.
If we fail to fetch any of the events, a warning will be logged, and the event
will be omitted from the result. Likewise, any events which turn out not to
be in the given room.
This function *does not* automatically get missing auth events of the
newly fetched events. Callers must include the full auth chain
of the missing events in the `event_ids` argument, to ensure that any
missing auth events are correctly fetched.
Returns:
map from event_id to event
"""
fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
missing_events = set(event_ids) - fetched_events.keys()
if missing_events:
logger.debug(
"Fetching unknown state/auth events %s for room %s",
missing_events,
room_id,
)
await self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
(await self.store.get_events(missing_events, allow_rejected=True))
)
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
@ -641,7 +597,7 @@ class FederationHandler(BaseHandler):
bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
for idx, event in enumerate(remote_state)
if event.room_id != room_id
]
@ -658,9 +614,49 @@ class FederationHandler(BaseHandler):
room_id,
)
del fetched_events[bad_event_id]
if bad_events:
remote_state = [e for e in remote_state if e.room_id == room_id]
return fetched_events
return remote_state
async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
"""Fetch events from a remote destination, checking if we already have them.
Persists any events we don't already have as outliers.
If we fail to fetch any of the events, their IDs will be included in the
returned set so that the caller can log a warning and handle the gap.
This function *does not* automatically get missing auth events of the
newly fetched events. Callers must include the full auth chain of
the missing events in the `event_ids` argument, to ensure that any
missing auth events are correctly fetched.
Returns:
The set of event IDs that we failed to fetch (and so still do not have).
"""
have_events = await self.store.have_seen_events(event_ids)
missing_events = set(event_ids) - have_events
if not missing_events:
return set()
logger.debug(
"Fetching unknown state/auth events %s for room %s",
missing_events,
room_id,
)
await self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
)
new_events = await self.store.have_seen_events(missing_events)
return missing_events - new_events
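The helper above now returns only the set of event IDs it still could not obtain after asking the destination. A condensed sketch of the calling pattern used in `_get_state_for_room` above (same names as the diff, not runnable on its own):

    failed_to_fetch = await self._get_events_from_store_or_dest(
        destination, room_id, desired_events
    )
    if failed_to_fetch:
        logger.warning(
            "Failed to fetch missing state/auth events for %s: %s",
            room_id,
            failed_to_fetch,
        )

    # Whatever was fetched has been persisted, so re-read it from the store,
    # keeping rejected events so that state resolution still sees them.
    event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
    remote_state = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]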
async def _get_state_after_missing_prev_event(
self,
@ -963,27 +959,23 @@ class FederationHandler(BaseHandler):
# For each edge get the current state.
auth_events = {}
state_events = {}
events_to_state = {}
for e_id in edges:
state, auth = await self._get_state_for_room(
state = await self._get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id,
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
required_auth = {
a_id
for event in events
+ list(state_events.values())
+ list(auth_events.values())
for event in events + list(state_events.values())
for a_id in event.auth_event_ids()
}
auth_events = await self.store.get_events(required_auth, allow_rejected=True)
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
@ -1452,7 +1444,7 @@ class FederationHandler(BaseHandler):
# room stuff after join currently doesn't work on workers.
assert self.config.worker.worker_app is None
logger.debug("Joining %s to %s", joinee, room_id)
logger.info("Joining %s to %s", joinee, room_id)
origin, event, room_version_obj = await self._make_and_verify_event(
target_hosts,
@ -1463,6 +1455,8 @@ class FederationHandler(BaseHandler):
params={"ver": KNOWN_ROOM_VERSIONS},
)
logger.info("make_join done from %s", origin)
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
@ -1482,10 +1476,13 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
logger.info("Sending join")
ret = await self.federation_client.send_join(
host_list, event, room_version_obj
)
logger.info("send join done")
origin = ret["origin"]
state = ret["state"]
auth_chain = ret["auth_chain"]
@ -1510,10 +1507,14 @@ class FederationHandler(BaseHandler):
room_version=room_version_obj,
)
logger.info("Persisting auth true")
max_stream_id = await self._persist_auth_tree(
origin, room_id, auth_chain, state, event, room_version_obj
)
logger.info("Persisted auth true")
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
await self._replication.wait_for_stream_position(
@ -2166,6 +2167,8 @@ class FederationHandler(BaseHandler):
ctx = await self.state_handler.compute_event_context(e)
events_to_context[e.event_id] = ctx
logger.info("Computed contexts")
event_map = {
e.event_id: e for e in itertools.chain(auth_events, state, [event])
}
@ -2207,6 +2210,8 @@ class FederationHandler(BaseHandler):
else:
logger.info("Failed to find auth event %r", e_id)
logger.info("Got missing events")
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
@ -2231,6 +2236,8 @@ class FederationHandler(BaseHandler):
raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
logger.info("Authed events")
await self.persist_events_and_notify(
room_id,
[
@ -2239,10 +2246,14 @@ class FederationHandler(BaseHandler):
],
)
logger.info("Persisted events")
new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
logger.info("Computed context")
return await self.persist_events_and_notify(
room_id, [(event, new_event_context)]
)

View file

@ -258,7 +258,7 @@ class MessageHandler:
"Getting joined members after leaving is not implemented"
)
users_with_profile = await self.state.get_current_users_in_room(room_id)
users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there
@ -1108,7 +1108,7 @@ class EventCreationHandler:
# it's not a self-redaction (to avoid having to look up whether the
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
if event.type == EventTypes.Redaction and event.redacts:
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
@ -1195,7 +1195,7 @@ class EventCreationHandler:
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
if event.type == EventTypes.Redaction and event.redacts:
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
@ -1401,7 +1401,7 @@ class EventCreationHandler:
]
for k in immutable_fields:
if getattr(builder, k, None) != original_event.get(k):
if getattr(builder, k, None) != getattr(original_event, k, None):
raise Exception(
"Third party rules module created an invalid event: "
"cannot change field " + k

View file

@ -1183,7 +1183,16 @@ class PresenceHandler(BasePresenceHandler):
max_pos, deltas = await self.store.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
await self._handle_state_delta(deltas)
# We may get multiple deltas for different rooms, but we want to
# handle them on a room by room basis, so we batch them up by
# room.
deltas_by_room: Dict[str, List[JsonDict]] = {}
for delta in deltas:
deltas_by_room.setdefault(delta["room_id"], []).append(delta)
for room_id, deltas_for_room in deltas_by_room.items():
await self._handle_state_delta(room_id, deltas_for_room)
self._event_pos = max_pos
@ -1192,17 +1201,21 @@ class PresenceHandler(BasePresenceHandler):
max_pos
)
async def _handle_state_delta(self, deltas: List[JsonDict]) -> None:
"""Process current state deltas to find new joins that need to be
handled.
async def _handle_state_delta(self, room_id: str, deltas: List[JsonDict]) -> None:
"""Process current state deltas for the room to find new joins that need
to be handled.
"""
# A map of destination to a set of user state that they should receive
presence_destinations = {} # type: Dict[str, Set[UserPresenceState]]
# Sets of newly joined users. Note that if the local server is
# joining a remote room for the first time we'll see both the joining
# user and all remote users as newly joined.
newly_joined_users = set()
for delta in deltas:
assert room_id == delta["room_id"]
typ = delta["type"]
state_key = delta["state_key"]
room_id = delta["room_id"]
event_id = delta["event_id"]
prev_event_id = delta["prev_event_id"]
@ -1231,72 +1244,55 @@ class PresenceHandler(BasePresenceHandler):
# Ignore changes to join events.
continue
# Retrieve any user presence state updates that need to be sent as a result,
# and the destinations that need to receive it
destinations, user_presence_states = await self._on_user_joined_room(
room_id, state_key
)
newly_joined_users.add(state_key)
# Insert the destinations and respective updates into our destinations dict
for destination in destinations:
presence_destinations.setdefault(destination, set()).update(
user_presence_states
)
if not newly_joined_users:
# If nobody has joined then there's nothing to do.
return
# Send out user presence updates for each destination
for destination, user_state_set in presence_destinations.items():
self._federation_queue.send_presence_to_destinations(
destinations=[destination], states=user_state_set
)
# We want to send:
# 1. presence states of all local users in the room to newly joined
# remote servers
# 2. presence states of newly joined users to all remote servers in
# the room.
#
# TODO: Only send presence states to remote hosts that don't already
# have them (because they already share rooms).
async def _on_user_joined_room(
self, room_id: str, user_id: str
) -> Tuple[List[str], List[UserPresenceState]]:
"""Called when we detect a user joining the room via the current state
delta stream. Returns the destinations that need to be updated and the
presence updates to send to them.
# Get all the users who were already in the room, by fetching the
# current users in the room and removing the newly joined users.
users = await self.store.get_users_in_room(room_id)
prev_users = set(users) - newly_joined_users
Args:
room_id: The ID of the room that the user has joined.
user_id: The ID of the user that has joined the room.
# Construct sets for all the local users and remote hosts that were
# already in the room
prev_local_users = []
prev_remote_hosts = set()
for user_id in prev_users:
if self.is_mine_id(user_id):
prev_local_users.append(user_id)
else:
prev_remote_hosts.add(get_domain_from_id(user_id))
Returns:
A tuple of destinations and presence updates to send to them.
"""
if self.is_mine_id(user_id):
# If this is a local user then we need to send their presence
# out to hosts in the room (who don't already have it)
# Similarly, construct sets for all the local users and remote hosts
# that were *not* already in the room. Care needs to be taken when
# calculating the remote hosts, as a host may have already been in the
# room even if there is a newly joined user from that host.
newly_joined_local_users = []
newly_joined_remote_hosts = set()
for user_id in newly_joined_users:
if self.is_mine_id(user_id):
newly_joined_local_users.append(user_id)
else:
host = get_domain_from_id(user_id)
if host not in prev_remote_hosts:
newly_joined_remote_hosts.add(host)
# TODO: We should be able to filter the hosts down to those that
# haven't previously seen the user
remote_hosts = await self.state.get_current_hosts_in_room(room_id)
# Filter out ourselves.
filtered_remote_hosts = [
host for host in remote_hosts if host != self.server_name
]
state = await self.current_state_for_user(user_id)
return filtered_remote_hosts, [state]
else:
# A remote user has joined the room, so we need to:
# 1. Check if this is a new server in the room
# 2. If so send any presence they don't already have for
# local users in the room.
# TODO: We should be able to filter the users down to those that
# the server hasn't previously seen
# TODO: Check that this is actually a new server joining the
# room.
remote_host = get_domain_from_id(user_id)
users = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, users))
states_d = await self.current_state_for_users(user_ids)
# Send presence states of all local users in the room to newly joined
# remote servers. (We actually only send states for local users already
# in the room, as we'll send states for newly joined local users below.)
if prev_local_users and newly_joined_remote_hosts:
local_states = await self.current_state_for_users(prev_local_users)
# Filter out old presence, i.e. offline presence states where
# the user hasn't been active for a week. We can change this
@ -1306,13 +1302,27 @@ class PresenceHandler(BasePresenceHandler):
now = self.clock.time_msec()
states = [
state
for state in states_d.values()
for state in local_states.values()
if state.state != PresenceState.OFFLINE
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
or state.status_msg is not None
]
return [remote_host], states
self._federation_queue.send_presence_to_destinations(
destinations=newly_joined_remote_hosts,
states=states,
)
# Send presence states of newly joined users to all remote servers in
# the room
if newly_joined_local_users and (
prev_remote_hosts or newly_joined_remote_hosts
):
local_states = await self.current_state_for_users(newly_joined_local_users)
self._federation_queue.send_presence_to_destinations(
destinations=prev_remote_hosts | newly_joined_remote_hosts,
states=list(local_states.values()),
)
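To make the bookkeeping above concrete, here is a small standalone sketch (the user IDs and local server name are made up for the example) of how the previously-present and newly-joined users split into the four sets used above:

    # Standalone illustration of the set arithmetic above; hypothetical data.
    server_name = "example.org"

    def is_mine(user_id: str) -> bool:
        return user_id.endswith(":" + server_name)

    def domain(user_id: str) -> str:
        return user_id.split(":", 1)[1]

    users_now = {"@a:example.org", "@b:example.org", "@c:remote1.net", "@d:remote2.net"}
    newly_joined_users = {"@b:example.org", "@d:remote2.net"}

    prev_users = users_now - newly_joined_users
    prev_local_users = [u for u in prev_users if is_mine(u)]
    prev_remote_hosts = {domain(u) for u in prev_users if not is_mine(u)}

    newly_joined_local_users = [u for u in newly_joined_users if is_mine(u)]
    # A host only counts as newly joined if it was not already in the room.
    newly_joined_remote_hosts = {
        domain(u) for u in newly_joined_users if not is_mine(u)
    } - prev_remote_hosts

    # Presence of users already in the room goes to the newly joined hosts;
    # presence of newly joined local users goes to every remote host in the room.
    assert prev_local_users == ["@a:example.org"]
    assert newly_joined_remote_hosts == {"remote2.net"}
    assert prev_remote_hosts | newly_joined_remote_hosts == {"remote1.net", "remote2.net"}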
def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool:

View file

@ -475,7 +475,7 @@ class RoomCreationHandler(BaseHandler):
):
await self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event["state_key"]),
UserID.from_string(old_event.state_key),
new_room_id,
"ban",
ratelimit=False,
@ -1327,7 +1327,7 @@ class RoomShutdownHandler:
new_room_id = None
logger.info("Shutting down room %r", room_id)
users = await self.state.get_current_users_in_room(room_id)
users = await self.store.get_users_in_room(room_id)
kicked_users = []
failed_to_kick_users = []
for user_id in users:

View file

@ -1190,7 +1190,7 @@ class SyncHandler:
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
joined_users = await self.state.get_current_users_in_room(room_id)
joined_users = await self.store.get_users_in_room(room_id)
newly_joined_or_invited_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they
@ -1206,7 +1206,7 @@ class SyncHandler:
# Now find users that we no longer track
for room_id in newly_left_rooms:
left_users = await self.state.get_current_users_in_room(room_id)
left_users = await self.store.get_users_in_room(room_id)
newly_left_users.update(left_users)
# Remove any users that we still share a room with.
@ -1361,7 +1361,7 @@ class SyncHandler:
extra_users_ids = set(newly_joined_or_invited_users)
for room_id in newly_joined_rooms:
users = await self.state.get_current_users_in_room(room_id)
users = await self.store.get_users_in_room(room_id)
extra_users_ids.update(users)
extra_users_ids.discard(user.to_string())

View file

@ -154,6 +154,7 @@ async def _handle_json_response(
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
return_string_io=False,
) -> JsonDict:
"""
Reads the JSON body of a response, with a timeout
@ -175,12 +176,12 @@ async def _handle_json_response(
d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
def parse(_len: int):
return json_decoder.decode(buf.getvalue())
await make_deferred_yieldable(d)
d.addCallback(parse)
body = await make_deferred_yieldable(d)
if return_string_io:
body = buf
else:
body = json_decoder.decode(buf.getvalue())
except BodyExceededMaxSize as e:
# The response was too big.
logger.warning(
@ -225,12 +226,13 @@ async def _handle_json_response(
time_taken_secs = reactor.seconds() - start_ms / 1000
logger.info(
"{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
"{%s} [%s] Completed request: %d %s in %.2f secs got %dB - %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
time_taken_secs,
len(buf.getvalue()),
request.method,
request.uri.decode("ascii"),
)
@ -683,6 +685,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
return_string_io=False,
) -> Union[JsonDict, list]:
"""Sends the specified json data using PUT
@ -757,7 +760,12 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
body = await _handle_json_response(
self.reactor, _sec_timeout, request, response, start_ms
self.reactor,
_sec_timeout,
request,
response,
start_ms,
return_string_io=return_string_io,
)
return body
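The new `return_string_io` flag threads through to `_handle_json_response` so a caller can get the raw response buffer back instead of a decoded JSON dict. A hedged sketch of how a caller might use it; `client`, `txn_id` and `payload` are placeholders, not names from the diff:

    # Hypothetical caller: with return_string_io=True the helper skips JSON decoding
    # and hands back the buffer it read the body into.
    body = await client.put_json(
        destination="remote.example.org",
        path="/_matrix/federation/v1/send/%s" % (txn_id,),
        data=payload,
        return_string_io=True,
    )
    raw_text = body.getvalue()  # the response body, not JSON-decoded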

View file

@ -535,6 +535,13 @@ class ReactorLastSeenMetric:
REGISTRY.register(ReactorLastSeenMetric())
# The minimum time in seconds between GCs for each generation, regardless of the current GC
# thresholds and counts.
MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
# The time (in seconds since the epoch) of the last time we did a GC for each generation.
_last_gc = [0.0, 0.0, 0.0]
def runUntilCurrentTimer(reactor, func):
@functools.wraps(func)
@ -575,11 +582,16 @@ def runUntilCurrentTimer(reactor, func):
return ret
# Check if we need to do a manual GC (since it's been disabled), and do
# one if necessary.
# one if necessary. Note we go in reverse order as e.g. a gen 1 GC may
# promote an object into gen 2, and we don't want to handle the same
# object multiple times.
threshold = gc.get_threshold()
counts = gc.get_count()
for i in (2, 1, 0):
if threshold[i] < counts[i]:
# We check if we need to do one based on a straightforward
# comparison between the threshold and count. We also do an extra
# check to make sure that we don't do a GC too often.
if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]:
if i == 0:
logger.debug("Collecting gc %d", i)
else:
@ -589,6 +601,8 @@ def runUntilCurrentTimer(reactor, func):
unreachable = gc.collect(i)
end = time.time()
_last_gc[i] = end
gc_time.labels(i).observe(end - start)
gc_unreachable.labels(i).set(unreachable)
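The manual collection above only fires when gc's own per-generation counters exceed their thresholds and enough time has passed since the last collection of that generation. A standalone sketch of the same check using the stdlib gc module (assuming automatic collection has been disabled, as the comment above implies):

    import gc
    import time

    MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)  # per-generation minimum interval, as above
    _last_gc = [0.0, 0.0, 0.0]

    def maybe_collect() -> None:
        threshold = gc.get_threshold()  # e.g. (700, 10, 10)
        counts = gc.get_count()         # allocations/collections since the last GC
        now = time.time()
        # Highest generation first, so that a younger collection promoting objects
        # into an older generation doesn't cause the same objects to be handled twice.
        for i in (2, 1, 0):
            if threshold[i] < counts[i] and now - _last_gc[i] > MIN_TIME_BETWEEN_GCS[i]:
                gc.collect(i)
                _last_gc[i] = time.time()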
@ -615,6 +629,7 @@ try:
except AttributeError:
pass
__all__ = [
"MetricsResource",
"generate_latest",

synapse/metrics/jemalloc.py Normal file
View file

@ -0,0 +1,191 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import logging
import os
import re
from typing import Optional
from synapse.metrics import REGISTRY, GaugeMetricFamily
logger = logging.getLogger(__name__)
def _setup_jemalloc_stats():
"""Checks to see if jemalloc is loaded, and hooks up a collector to record
statistics exposed by jemalloc.
"""
# Try to find the loaded jemalloc shared library, if any. We need to
# introspect into what is loaded, rather than loading whatever is on the
# path, as if we load a *different* jemalloc version things will seg fault.
# We look in `/proc/self/maps`, which only exists on linux.
if not os.path.exists("/proc/self/maps"):
logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
return
# We're looking for a path at the end of the line that includes
# "libjemalloc".
regex = re.compile(r"/\S+/libjemalloc.*$")
jemalloc_path = None
with open("/proc/self/maps") as f:
for line in f:
match = regex.search(line.strip())
if match:
jemalloc_path = match.group()
if not jemalloc_path:
# No loaded jemalloc was found.
logger.debug("jemalloc not found")
return
jemalloc = ctypes.CDLL(jemalloc_path)
def _mallctl(
name: str, read: bool = True, write: Optional[int] = None
) -> Optional[int]:
"""Wrapper around `mallctl` for reading and writing integers to
jemalloc.
Args:
name: The name of the option to read from/write to.
read: Whether to try and read the value.
write: The value to write, if given.
Returns:
The value read if `read` is True, otherwise None.
Raises:
An exception if `mallctl` returns a non-zero error code.
"""
input_var = None
input_var_ref = None
input_len_ref = None
if read:
input_var = ctypes.c_size_t(0)
input_len = ctypes.c_size_t(ctypes.sizeof(input_var))
input_var_ref = ctypes.byref(input_var)
input_len_ref = ctypes.byref(input_len)
write_var_ref = None
write_len = ctypes.c_size_t(0)
if write is not None:
write_var = ctypes.c_size_t(write)
write_len = ctypes.c_size_t(ctypes.sizeof(write_var))
write_var_ref = ctypes.byref(write_var)
# The interface is:
#
# int mallctl(
# const char *name,
# void *oldp,
# size_t *oldlenp,
# void *newp,
# size_t newlen
# )
#
# Where oldp/oldlenp is a buffer where the old value will be written to
# (if not null), and newp/newlen is the buffer with the new value to set
# (if not null). Note that they're all references *except* newlen.
result = jemalloc.mallctl(
name.encode("ascii"),
input_var_ref,
input_len_ref,
write_var_ref,
write_len,
)
if result != 0:
raise Exception("Failed to call mallctl")
if input_var is None:
return None
return input_var.value
def _jemalloc_refresh_stats() -> None:
"""Request that jemalloc updates its internal statistics. This needs to
be called before querying for stats, otherwise it will return stale
values.
"""
try:
_mallctl("epoch", read=False, write=1)
except Exception:
pass
class JemallocCollector:
"""Metrics for internal jemalloc stats."""
def collect(self):
_jemalloc_refresh_stats()
g = GaugeMetricFamily(
"jemalloc_stats_app_memory_bytes",
"The stats reported by jemalloc",
labels=["type"],
)
# Read the relevant global stats from jemalloc. Note that these may
# not be accurate if python is configured to use its internal small
# object allocator (which is on by default, disable by setting the
# env `PYTHONMALLOC=malloc`).
#
# See the jemalloc manpage for details about what each value means,
# roughly:
# - allocated ─ Total number of bytes allocated by the app
# - active ─ Total number of bytes in active pages allocated by
# the application, this is bigger than `allocated`.
# - resident ─ Maximum number of bytes in physically resident data
# pages mapped by the allocator, comprising all pages dedicated
# to allocator metadata, pages backing active allocations, and
# unused dirty pages. This is bigger than `active`.
# - mapped ─ Total number of bytes in active extents mapped by the
# allocator.
# - metadata ─ Total number of bytes dedicated to jemalloc
# metadata.
for t in (
"allocated",
"active",
"resident",
"mapped",
"metadata",
):
try:
value = _mallctl(f"stats.{t}")
except Exception:
# There was an error fetching the value, skip.
continue
g.add_metric([t], value=value)
yield g
REGISTRY.register(JemallocCollector())
logger.debug("Added jemalloc stats")
def setup_jemalloc_stats():
"""Try to setup jemalloc stats, if jemalloc is loaded."""
try:
_setup_jemalloc_stats()
except Exception as e:
logger.info("Failed to setup collector to record jemalloc stats: %s", e)

View file

@ -277,7 +277,7 @@ class Notifier:
event_pos=event_pos,
room_id=event.room_id,
event_type=event.type,
state_key=event.get("state_key"),
state_key=getattr(event, "state_key", None),
membership=event.content.get("membership"),
max_room_stream_token=max_room_stream_token,
extra_users=extra_users or [],

View file

@ -125,7 +125,7 @@ class PushRuleEvaluatorForEvent:
self._power_levels = power_levels
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
self._value_cache = _flatten_dict(event.get_dict())
def matches(
self, condition: Dict[str, Any], user_id: str, display_name: str
@ -271,7 +271,7 @@ def _re_word_boundary(r: str) -> str:
def _flatten_dict(
d: Union[EventBase, dict],
d: dict,
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
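For context on what the value cache holds, a tiny standalone stand-in (not the real `_flatten_dict`) showing the 'content.body'-style keys the evaluator matches against:

    # Illustrative only: a minimal stand-in for _flatten_dict, to show the shape of
    # the flattened dict, not the real implementation.
    from typing import Any, Dict

    def flatten(d: Dict[str, Any], prefix: str = "") -> Dict[str, str]:
        out: Dict[str, str] = {}
        for key, value in d.items():
            path = f"{prefix}{key}"
            if isinstance(value, dict):
                out.update(flatten(value, prefix=path + "."))
            elif isinstance(value, str):
                out[path] = value
        return out

    event_dict = {"type": "m.room.message", "content": {"body": "hello", "msgtype": "m.text"}}
    assert flatten(event_dict)["content.body"] == "hello"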

View file

@ -116,6 +116,8 @@ CONDITIONAL_REQUIREMENTS = {
# hiredis is not a *strict* dependency, but it makes things much faster.
# (if it is not installed, we fall back to slow code.)
"redis": ["txredisapi>=1.4.7", "hiredis"],
# Required to use experimental `caches.track_memory_usage` config option.
"cache_memory": ["pympler"],
}
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]

View file

@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
logger = logging.getLogger(__name__)
@ -213,7 +214,13 @@ class RemoteKey(DirectServeJsonResource):
# If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
await self.fetcher.get_keys(cache_misses)
await yieldable_gather_results(
lambda t: self.fetcher.get_keys(*t),
(
(server_name, list(keys), 0)
for server_name, keys in cache_misses.items()
),
)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []

View file

@ -213,19 +213,23 @@ class StateHandler:
return ret.state
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None
self, room_id: str, latest_event_ids: List[str]
) -> Dict[str, ProfileInfo]:
"""
Get the users who are currently in a room.
Note: This is much slower than using the equivalent method
`DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
so it should only be used when the set of users at a particular point in
the room's history is needed.
Args:
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs.
Returns:
Dictionary of user IDs to their profileinfo.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_users_in_room")

View file

@ -69,6 +69,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,))
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))

View file

@ -205,8 +205,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
sql = """
SELECT user_id, display_name, avatar_url FROM room_memberships
WHERE room_id = ? AND membership = ?
SELECT state_key, display_name, avatar_url FROM room_memberships as m
INNER JOIN current_state_events as c
ON m.event_id = c.event_id
AND m.room_id = c.room_id
AND m.user_id = c.state_key
WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
"""
txn.execute(sql, (room_id, Membership.JOIN))

View file

@ -142,8 +142,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
batch_size (int): Maximum number of state events to process
per cycle.
"""
state = self.hs.get_state_handler()
# If we don't have a progress field, delete everything.
if not progress:
await self.delete_all_from_user_dir()
@ -197,7 +195,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
room_id
)
users_with_profile = await state.get_current_users_in_room(room_id)
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
user_ids = set(users_with_profile)
# Update each user in the user directory.

View file

@ -24,6 +24,11 @@ from synapse.config.cache import add_resizable_cache
logger = logging.getLogger(__name__)
# Whether to track estimated memory usage of the LruCaches.
TRACK_MEMORY_USAGE = False
caches_by_name = {} # type: Dict[str, Sized]
collectors_by_name = {} # type: Dict[str, CacheMetric]
@ -32,6 +37,11 @@ cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
cache_memory_usage = Gauge(
"synapse_util_caches_cache_size_bytes",
"Estimated memory usage of the caches",
["name"],
)
response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
@ -52,6 +62,7 @@ class CacheMetric:
hits = attr.ib(default=0)
misses = attr.ib(default=0)
evicted_size = attr.ib(default=0)
memory_usage = attr.ib(default=None)
def inc_hits(self):
self.hits += 1
@ -62,6 +73,19 @@ class CacheMetric:
def inc_evictions(self, size=1):
self.evicted_size += size
def inc_memory_usage(self, memory: int):
if self.memory_usage is None:
self.memory_usage = 0
self.memory_usage += memory
def dec_memory_usage(self, memory: int):
self.memory_usage -= memory
def clear_memory_usage(self):
if self.memory_usage is not None:
self.memory_usage = 0
def describe(self):
return []
@ -81,6 +105,13 @@ class CacheMetric:
cache_total.labels(self._cache_name).set(self.hits + self.misses)
if getattr(self._cache, "max_size", None):
cache_max_size.labels(self._cache_name).set(self._cache.max_size)
if TRACK_MEMORY_USAGE:
# self.memory_usage can be None if nothing has been inserted
# into the cache yet.
cache_memory_usage.labels(self._cache_name).set(
self.memory_usage or 0
)
if self._collect_callback:
self._collect_callback()
except Exception as e:

View file

@ -32,9 +32,36 @@ from typing import (
from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.util import caches
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache
try:
from pympler.asizeof import Asizer
def _get_size_of(val: Any, *, recurse=True) -> int:
"""Get an estimate of the size in bytes of the object.
Args:
val: The object to size.
recurse: If true will include referenced values in the size,
otherwise only sizes the given object.
"""
# Ignore singleton values when calculating memory usage.
if val in ((), None, ""):
return 0
sizer = Asizer()
sizer.exclude_refs((), None, "")
return sizer.asizeof(val, limit=100 if recurse else 0)
except ImportError:
def _get_size_of(val: Any, *, recurse=True) -> int:
return 0
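`_get_size_of` silently degrades to returning 0 when pympler is missing (it is only pulled in by the optional cache_memory extra added above). A small hedged usage sketch, assuming pympler is installed:

    # Assuming pympler is available: recurse=True (the default) also counts the
    # values referenced by the object, recurse=False sizes only the object itself.
    value = {"user": "@alice:example.org", "avatar_url": None}
    deep = _get_size_of(value)                    # dict plus the strings it references
    shallow = _get_size_of(value, recurse=False)  # just the dict object
    assert deep >= shallow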
# Function type: the type used for invalidation callbacks
FT = TypeVar("FT", bound=Callable[..., Any])
@ -56,7 +83,7 @@ def enumerate_leaves(node, depth):
class _Node:
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
def __init__(
self,
@ -84,6 +111,16 @@ class _Node:
self.add_callbacks(callbacks)
self.memory = 0
if caches.TRACK_MEMORY_USAGE:
self.memory = (
_get_size_of(key)
+ _get_size_of(value)
+ _get_size_of(self.callbacks, recurse=False)
+ _get_size_of(self, recurse=False)
)
self.memory += _get_size_of(self.memory, recurse=False)
def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
"""Add to stored list of callbacks, removing duplicates."""
@ -233,6 +270,9 @@ class LruCache(Generic[KT, VT]):
if size_callback:
cached_cache_len[0] += size_callback(node.value)
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.inc_memory_usage(node.memory)
def move_node_to_front(node):
prev_node = node.prev_node
next_node = node.next_node
@ -258,6 +298,9 @@ class LruCache(Generic[KT, VT]):
node.run_and_clear_callbacks()
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.dec_memory_usage(node.memory)
return deleted_len
@overload
@ -373,6 +416,9 @@ class LruCache(Generic[KT, VT]):
if size_callback:
cached_cache_len[0] = 0
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.clear_memory_usage()
@synchronized
def cache_contains(key: KT) -> bool:
return key in cache

View file

@ -729,7 +729,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
destinations=["server2"], states={expected_state}
destinations={"server2"}, states=[expected_state]
)
#
@ -740,7 +740,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self._add_new_user(room_id, "@bob:server3")
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
destinations=["server3"], states={expected_state}
destinations={"server3"}, states=[expected_state]
)
def test_remote_gets_presence_when_local_user_joins(self):
@ -788,14 +788,8 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.presence_handler.current_state_for_user("@test2:server")
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
self.assertEqual(
self.federation_sender.send_presence_to_destinations.call_count, 2
)
self.federation_sender.send_presence_to_destinations.assert_any_call(
destinations=["server3"], states={expected_state}
)
self.federation_sender.send_presence_to_destinations.assert_any_call(
destinations=["server2"], states={expected_state}
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
destinations={"server2", "server3"}, states=[expected_state]
)
def _add_new_user(self, room_id, user_id):