Compare commits
develop...erikj/test (79 commits)

Commits (SHA1):
b044860e56
2e1a8878d5
a7b7770bef
5d9bbca631
015fdfe5bb
faa7d48930
f4bb01d41a
ec1c2c69a2
7e7b99bca9
3d937d23fd
c856e29ccd
941a0a76d3
b0d014819f
eeafa29399
016d55b94b
adef51ab98
cdeb6050ea
88bd909a4a
a94edad23b
fc17e4e62e
0d9d84dac0
04db7b9581
24965fc073
5b031e2da3
14b70bbd9f
d4175abe52
b76fe71627
7f237d5639
f37c5843d3
d6ae1aef46
3bfd3c55f9
cad5a47621
7e3d333b28
aabc46f0f6
78e3502ada
8206069c63
a99524f383
b5169b68e9
4c9446c4cb
4a8a483060
d145ba6ccc
dcb79da38a
35c13c730c
8624333cd9
4caa84b279
48cf260c7a
7e5f78a698
43c9acda4c
bd04fb6308
d3a6e38c96
aa1a026509
260c760d69
49da5e9ec4
3b2991e3fb
aec80899ab
68f1d258d9
8481bacc93
68b6106ce5
0ed608cf56
ac0143c4ac
f5a25c7b53
5813719696
6640fb467f
0c8cd62149
996c0ce3d5
938efeb595
4a3a9597f5
351f886bc8
79627b3a3c
6237096e80
1b4ec8ef0e
5add13e05d
2bf93f9b34
bcf8858b67
99fb72e63e
567fe5e387
0c9bab290f
5003bd29d2
e9f5812eff
changelog.d/9881.feature (new file)
@@ -0,0 +1 @@
Add experimental option to track memory usage of the caches.

changelog.d/9882.misc (new file)
@@ -0,0 +1 @@
Export jemalloc stats to Prometheus if it is being used.

changelog.d/9902.feature (new file)
@@ -0,0 +1 @@
Add limits to how often Synapse will GC, ensuring that large servers do not end up GC thrashing if `gc_thresholds` has not been correctly set.

changelog.d/9910.bugfix (new file)
@@ -0,0 +1 @@
Fix bug where user directory could get out of sync if room visibility and membership changed in quick succession.

changelog.d/9910.feature (new file)
@@ -0,0 +1 @@
Improve performance after joining a large room when presence is enabled.

changelog.d/9916.misc (new file)
@@ -0,0 +1 @@
Improve performance of handling presence when joining large rooms.
@@ -152,6 +152,16 @@ presence:
#
#gc_thresholds: [700, 10, 10]

# The minimum time in seconds between each GC for a generation, regardless of
# the GC thresholds. This ensures that we don't do GC too frequently.
#
# A value of `[1s, 10s, 30s]` indicates that a second must pass between consecutive
# generation 0 GCs, etc.
#
# Defaults to `[1s, 10s, 30s]`.
#
#gc_min_interval: [0.5s, 30s, 1m]

# Set the limit on the returned events in the timeline in the get
# and sync operations. The default value is 100. -1 means no upper limit.
#
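For context, `gc_min_interval` feeds a per-generation minimum interval that is checked before each collection. A minimal sketch of how such a throttle can work (`MIN_TIME_BETWEEN_GCS` mirrors the knob set from config elsewhere in this branch; the loop itself is illustrative, not the exact code):

```python
import gc
import time

# Mirrors the default of [1s, 10s, 30s]: minimum seconds between
# collections of generations 0, 1 and 2 respectively.
MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)

# When we last collected each generation.
_last_gc = [0.0, 0.0, 0.0]

def maybe_gc() -> None:
    """Collect any generation over its threshold, unless it was
    collected more recently than its minimum interval allows."""
    threshold = gc.get_threshold()
    counts = gc.get_count()
    now = time.time()
    # Check the oldest generation first: collecting gen 2 implies 0 and 1.
    for generation in (2, 1, 0):
        if counts[generation] < threshold[generation]:
            continue
        if now - _last_gc[generation] < MIN_TIME_BETWEEN_GCS[generation]:
            continue
        gc.collect(generation)
        _last_gc[generation] = now
        break
```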
mypy.ini
@@ -171,3 +171,6 @@ ignore_missing_imports = True

[mypy-txacme.*]
ignore_missing_imports = True

[mypy-pympler.*]
ignore_missing_imports = True
@@ -23,6 +23,7 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.types import RoomID, UserID

FILTER_SCHEMA = {

@@ -290,6 +291,13 @@ class Filter:
            ev_type = "m.presence"
            contains_url = False
            labels = []  # type: List[str]
        elif isinstance(event, EventBase):
            sender = event.sender
            room_id = event.room_id
            ev_type = event.type
            content = event.content
            contains_url = isinstance(content.get("url"), str)
            labels = content.get(EventContentFields.LABELS, [])
        else:
            sender = event.get("sender", None)
            if not sender:
@@ -37,6 +37,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
from synapse.util.async_helpers import Linearizer
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit

@@ -115,6 +116,7 @@ def start_reactor(

    def run():
        logger.info("Running")
        setup_jemalloc_stats()
        change_resource_limit(soft_file_limit)
        if gc_thresholds:
            gc.set_threshold(*gc_thresholds)
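`setup_jemalloc_stats()` hooks jemalloc's `mallctl()` interface up to Prometheus. For reference, jemalloc statistics can be read from Python via ctypes roughly like this — a sketch assuming jemalloc was loaded into the process (e.g. via LD_PRELOAD); the real helper does a more careful soname lookup and registers gauges instead of printing:

```python
import ctypes

# Assumes libjemalloc is already loaded into the process.
jemalloc = ctypes.CDLL("libjemalloc.so.2", use_errno=True)

def mallctl_read(name: bytes) -> int:
    """Read a size_t-valued jemalloc statistic, e.g. b"stats.allocated"."""
    value = ctypes.c_size_t()
    length = ctypes.c_size_t(ctypes.sizeof(value))
    ret = jemalloc.mallctl(name, ctypes.byref(value), ctypes.byref(length), None, 0)
    if ret != 0:
        raise RuntimeError("mallctl(%r) failed with %d" % (name, ret))
    return value.value

# Statistics are cached; bump the epoch to refresh them before reading.
epoch = ctypes.c_uint64(1)
jemalloc.mallctl(b"epoch", None, None, ctypes.byref(epoch), ctypes.sizeof(epoch))

print("jemalloc stats.allocated:", mallctl_read(b"stats.allocated"))
```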
@@ -454,6 +454,10 @@ def start(config_options):
        config.server.update_user_directory = False

    synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
    synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage

    if config.server.gc_seconds:
        synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds

    hs = GenericWorkerServer(
        config.server_name,
@@ -341,6 +341,10 @@ def setup(config_options):
        sys.exit(0)

    events.USE_FROZEN_DICTS = config.use_frozen_dicts
    synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage

    if config.server.gc_seconds:
        synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds

    hs = SynapseHomeServer(
        config.server_name,
@@ -17,6 +17,8 @@ import re
import threading
from typing import Callable, Dict

from synapse.python_dependencies import DependencyException, check_requirements

from ._base import Config, ConfigError

# The prefix for all cache factor-related environment variables

@@ -189,6 +191,15 @@ class CacheConfig(Config):
            )
            self.cache_factors[cache] = factor

        self.track_memory_usage = cache_config.get("track_memory_usage", False)
        if self.track_memory_usage:
            try:
                check_requirements("cache_memory")
            except DependencyException as e:
                raise ConfigError(
                    e.message  # noqa: B306, DependencyException.message is a property
                )

        # Resize all caches (if necessary) with the new factors we've loaded
        self.resize_all_caches()
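`check_requirements("cache_memory")` pulls in pympler (hence the new mypy stanza above), whose `asizeof()` gives the deep size of a Python object in bytes. A toy sketch of the kind of accounting `track_memory_usage` enables — class and attribute names here are illustrative, not the ones used in `synapse.util.caches`:

```python
from pympler.asizeof import asizeof  # deep size of an object, in bytes

class MemoryTrackingCache:
    """Toy cache that keeps a running total of the deep size of its
    values, in the spirit of the track_memory_usage option."""

    def __init__(self) -> None:
        self._data = {}
        self.memory_usage = 0  # bytes currently attributed to cached values

    def set(self, key, value) -> None:
        old = self._data.get(key)
        if old is not None:
            self.memory_usage -= asizeof(old)
        self._data[key] = value
        self.memory_usage += asizeof(value)

    def pop(self, key) -> None:
        value = self._data.pop(key, None)
        if value is not None:
            self.memory_usage -= asizeof(value)
```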
@@ -19,7 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
from typing import Any, Dict, Iterable, List, Optional, Set
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import attr
import yaml

@@ -572,6 +572,7 @@ class ServerConfig(Config):
        _warn_if_webclient_configured(self.listeners)

        self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
        self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None))

    @attr.s
    class LimitRemoteRoomsConfig:

@@ -917,6 +918,16 @@ class ServerConfig(Config):
        #
        #gc_thresholds: [700, 10, 10]

        # The minimum time in seconds between each GC for a generation, regardless of
        # the GC thresholds. This ensures that we don't do GC too frequently.
        #
        # A value of `[1s, 10s, 30s]` indicates that a second must pass between consecutive
        # generation 0 GCs, etc.
        #
        # Defaults to `[1s, 10s, 30s]`.
        #
        #gc_min_interval: [0.5s, 30s, 1m]

        # Set the limit on the returned events in the timeline in the get
        # and sync operations. The default value is 100. -1 means no upper limit.
        #

@@ -1305,6 +1316,24 @@ class ServerConfig(Config):
            help="Turn on the twisted telnet manhole service on the given port.",
        )

    def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]:
        """Reads the three durations for the GC min interval option, returning seconds."""
        if durations is None:
            return None

        try:
            if len(durations) != 3:
                raise ValueError()
            return (
                self.parse_duration(durations[0]) / 1000,
                self.parse_duration(durations[1]) / 1000,
                self.parse_duration(durations[2]) / 1000,
            )
        except Exception:
            raise ConfigError(
                "Value of `gc_min_interval` must be a list of three durations if set"
            )


def is_threepid_reserved(reserved_threepids, threepid):
    """Check the threepid against the reserved threepid config
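`parse_duration` returns milliseconds, hence the division by 1000. A self-contained sketch of the parsing and the resulting tuple (the real `Config.parse_duration` supports more units; this stand-in is simplified):

```python
from typing import List, Optional, Tuple

_UNITS_MS = {"s": 1000, "m": 60_000, "h": 3_600_000}

def parse_duration(value: str) -> int:
    """Simplified stand-in: "10s" -> 10000 (milliseconds)."""
    if value and value[-1] in _UNITS_MS:
        return int(value[:-1]) * _UNITS_MS[value[-1]]
    return int(value)  # bare numbers are taken as milliseconds

def read_gc_intervals(durations: Optional[List[str]]) -> Optional[Tuple[float, ...]]:
    if durations is None:
        return None
    if len(durations) != 3:
        raise ValueError("gc_min_interval must be a list of three durations")
    return tuple(parse_duration(d) / 1000 for d in durations)

assert read_gc_intervals(["1s", "10s", "30s"]) == (1.0, 10.0, 30.0)
```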
@@ -48,7 +48,7 @@ def check_event_content_hash(

    # some malformed events lack a 'hashes'. Protect against it being missing
    # or a weird type by basically treating it the same as an unhashed event.
    hashes = event.get("hashes")
    hashes = getattr(event, "hashes", None)
    # nb it might be a frozendict or a dict
    if not isinstance(hashes, collections.abc.Mapping):
        raise SynapseError(
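The move from `event.get("hashes")` to `getattr(event, "hashes", None)` matches the slotted `EventBase` later in this branch, where `hashes` becomes a property rather than a dict key. For reference, a Matrix content hash is a SHA-256 over the event's canonical JSON with the volatile keys stripped; a rough sketch (the `json.dumps` call below is a simplified stand-in for true canonical JSON):

```python
import hashlib
import json

def compute_content_hash(event_dict: dict) -> bytes:
    """Strip the volatile keys, canonicalise, and SHA-256 the result."""
    stripped = {
        k: v
        for k, v in event_dict.items()
        if k not in ("hashes", "signatures", "unsigned")
    }
    canonical = json.dumps(
        stripped, sort_keys=True, separators=(",", ":"), ensure_ascii=False
    ).encode("utf-8")
    return hashlib.sha256(canonical).digest()
```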
@@ -16,8 +16,7 @@
import abc
import logging
import urllib
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple

import attr
from signedjson.key import (

@@ -42,17 +41,18 @@ from synapse.api.errors import (
    SynapseError,
)
from synapse.config.key import TrustedKeyServer
from synapse.events import EventBase
from synapse.events.utils import prune_event_dict
from synapse.logging.context import (
    PreserveLoggingContext,
    make_deferred_yieldable,
    preserve_fn,
    run_in_background,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure
from synapse.util.async_helpers import Linearizer, yieldable_gather_results
from synapse.util.retryutils import NotRetryingDestination

if TYPE_CHECKING:

@@ -74,8 +74,6 @@ class VerifyJsonRequest:
        minimum_valid_until_ts: time at which we require the signing key to
            be valid. (0 implies we don't care)

        request_name: The name of the request.

        key_ids: The set of key_ids that could be used to verify the JSON object

        key_ready (Deferred[str, str, nacl.signing.VerifyKey]):

@@ -88,20 +86,93 @@ class VerifyJsonRequest:
    """

    server_name = attr.ib(type=str)
    json_object = attr.ib(type=JsonDict)
    json_object_callback = attr.ib(type=Callable[[], JsonDict])
    minimum_valid_until_ts = attr.ib(type=int)
    request_name = attr.ib(type=str)
    key_ids = attr.ib(init=False, type=List[str])
    key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
    key_ids = attr.ib(type=List[str])

    def __attrs_post_init__(self):
        self.key_ids = signature_ids(self.json_object, self.server_name)
    @staticmethod
    def from_json_object(
        server_name: str, minimum_valid_until_ms: int, json_object: JsonDict
    ):
        key_ids = signature_ids(json_object, server_name)
        return VerifyJsonRequest(
            server_name, lambda: json_object, minimum_valid_until_ms, key_ids
        )

    @staticmethod
    def from_event(
        server_name: str,
        minimum_valid_until_ms: int,
        event: EventBase,
    ):
        key_ids = list(event.signatures.get(server_name, []))
        return VerifyJsonRequest(
            server_name,
            lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
            minimum_valid_until_ms,
            key_ids,
        )


class KeyLookupError(ValueError):
    pass


@attr.s(slots=True)
class _QueueValue:
    server_name = attr.ib(type=str)
    minimum_valid_until_ts = attr.ib(type=int)
    key_ids = attr.ib(type=List[str])


class _Queue:
    def __init__(self, name, clock, process_items):
        self._name = name
        self._clock = clock
        self._is_processing = False
        self._next_values = []

        self.process_items = process_items

    async def add_to_queue(self, value: _QueueValue) -> Dict[str, FetchKeyResult]:
        d = defer.Deferred()
        self._next_values.append((value, d))

        if not self._is_processing:
            run_as_background_process(self._name, self._unsafe_process)

        return await make_deferred_yieldable(d)

    async def _unsafe_process(self):
        try:
            if self._is_processing:
                return

            self._is_processing = True

            while self._next_values:
                # We purposefully defer to the next loop.
                await self._clock.sleep(0)

                next_values = self._next_values
                self._next_values = []

                try:
                    values = [value for value, _ in next_values]
                    results = await self.process_items(values)

                    for value, deferred in next_values:
                        with PreserveLoggingContext():
                            deferred.callback(results.get(value.server_name, {}))

                except Exception as e:
                    for _, deferred in next_values:
                        deferred.errback(e)

        finally:
            self._is_processing = False


class Keyring:
    def __init__(
        self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
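The `_Queue` above coalesces concurrent key requests: each caller appends a value plus a Deferred, and a single background task batches everything queued since the last iteration into one `process_items` call. The same pattern, sketched with plain asyncio for illustration (names here are ours, not Synapse's):

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Tuple

class BatchingQueue:
    """asyncio stand-in for keyring._Queue: coalesce concurrent
    requests into batched process_items() calls."""

    def __init__(
        self, process_items: Callable[[List[Any]], Awaitable[Dict[Any, Any]]]
    ) -> None:
        self._process_items = process_items
        self._pending: List[Tuple[Any, asyncio.Future]] = []
        self._is_processing = False

    async def add_to_queue(self, value: Any) -> Any:
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending.append((value, fut))
        if not self._is_processing:
            asyncio.ensure_future(self._process())
        return await fut

    async def _process(self) -> None:
        if self._is_processing:
            return
        self._is_processing = True
        try:
            while self._pending:
                await asyncio.sleep(0)  # let more callers queue up first
                batch, self._pending = self._pending, []
                try:
                    results = await self._process_items([v for v, _ in batch])
                    for value, fut in batch:
                        fut.set_result(results.get(value))
                except Exception as e:
                    for _, fut in batch:
                        fut.set_exception(e)
        finally:
            self._is_processing = False
```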
@@ -116,12 +187,7 @@ class Keyring:
        )
        self._key_fetchers = key_fetchers

        # map from server name to Deferred. Has an entry for each server with
        # an ongoing key download; the Deferred completes once the download
        # completes.
        #
        # These are regular, logcontext-agnostic Deferreds.
        self.key_downloads = {}  # type: Dict[str, defer.Deferred]
        self._server_queue = Linearizer("keyring_server")

    def verify_json_for_server(
        self,
@@ -130,365 +196,150 @@ class Keyring:
        validity_time: int,
        request_name: str,
    ) -> defer.Deferred:
        """Verify that a JSON object has been signed by a given server

        Args:
            server_name: name of the server which must have signed this object

            json_object: object to be checked

            validity_time: timestamp at which we require the signing key to
                be valid. (0 implies we don't care)

            request_name: an identifier for this json object (eg, an event id)
                for logging.

        Returns:
            Deferred[None]: completes if the object was correctly signed, otherwise
                errbacks with an error
        """
        req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
        requests = (req,)
        return make_deferred_yieldable(self._verify_objects(requests)[0])
        request = VerifyJsonRequest.from_json_object(
            server_name,
            validity_time,
            json_object,
        )
        return defer.ensureDeferred(self._verify_object(request))

    def verify_json_objects_for_server(
        self, server_and_json: Iterable[Tuple[str, dict, int, str]]
    ) -> List[defer.Deferred]:
        """Bulk verifies signatures of json objects, bulk fetching keys as
        necessary.

        Args:
            server_and_json:
                Iterable of (server_name, json_object, validity_time, request_name)
                tuples.

                validity_time is a timestamp at which the signing key must be
                valid.

                request_name is an identifier for this json object (eg, an event id)
                for logging.

        Returns:
            List<Deferred[None]>: for each input triplet, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        return self._verify_objects(
            VerifyJsonRequest(server_name, json_object, validity_time, request_name)
            for server_name, json_object, validity_time, request_name in server_and_json
        )

    def _verify_objects(
        self, verify_requests: Iterable[VerifyJsonRequest]
    ) -> List[defer.Deferred]:
        """Does the work of verify_json_[objects_]for_server

        Args:
            verify_requests: Iterable of verification requests.

        Returns:
            List<Deferred[None]>: for each input item, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        # a list of VerifyJsonRequests which are awaiting a key lookup
        key_lookups = []
        handle = preserve_fn(_handle_key_deferred)

        def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
            """Process an entry in the request list

            Adds a key request to key_lookups, and returns a deferred which
            will complete or fail (in the sentinel context) when verification completes.
            """
            if not verify_request.key_ids:
                return defer.fail(
                    SynapseError(
                        400,
                        "Not signed by %s" % (verify_request.server_name,),
                        Codes.UNAUTHORIZED,
                    )
        return [
            defer.ensureDeferred(
                run_in_background(
                    self._verify_object,
                    VerifyJsonRequest.from_json_object(
                        server_name,
                        validity_time,
                        json_object,
                    ),
                )

            logger.debug(
                "Verifying %s for %s with key_ids %s, min_validity %i",
                verify_request.request_name,
                verify_request.server_name,
                verify_request.key_ids,
                verify_request.minimum_valid_until_ts,
            )
            for server_name, json_object, validity_time, request_name in server_and_json
        ]

            # add the key request to the queue, but don't start it off yet.
            key_lookups.append(verify_request)

            # now run _handle_key_deferred, which will wait for the key request
            # to complete and then do the verification.
            #
            # We want _handle_key_request to log to the right context, so we
            # wrap it with preserve_fn (aka run_in_background)
            return handle(verify_request)

        results = [process(r) for r in verify_requests]

        if key_lookups:
            run_in_background(self._start_key_lookups, key_lookups)

        return results

    async def _start_key_lookups(
        self, verify_requests: List[VerifyJsonRequest]
    ) -> None:
        """Sets off the key fetches for each verify request

        Once each fetch completes, verify_request.key_ready will be resolved.

        Args:
            verify_requests:
        """

        try:
            # map from server name to a set of outstanding request ids
            server_to_request_ids = {}  # type: Dict[str, Set[int]]

            for verify_request in verify_requests:
                server_name = verify_request.server_name
                request_id = id(verify_request)
                server_to_request_ids.setdefault(server_name, set()).add(request_id)

            # Wait for any previous lookups to complete before proceeding.
            await self.wait_for_previous_lookups(server_to_request_ids.keys())

            # take out a lock on each of the servers by sticking a Deferred in
            # key_downloads
            for server_name in server_to_request_ids.keys():
                self.key_downloads[server_name] = defer.Deferred()
                logger.debug("Got key lookup lock on %s", server_name)

            # When we've finished fetching all the keys for a given server_name,
            # drop the lock by resolving the deferred in key_downloads.
            def drop_server_lock(server_name):
                d = self.key_downloads.pop(server_name)
                d.callback(None)

            def lookup_done(res, verify_request):
                server_name = verify_request.server_name
                server_requests = server_to_request_ids[server_name]
                server_requests.remove(id(verify_request))

                # if there are no more requests for this server, we can drop the lock.
                if not server_requests:
                    logger.debug("Releasing key lookup lock on %s", server_name)
                    drop_server_lock(server_name)

                return res

            for verify_request in verify_requests:
                verify_request.key_ready.addBoth(lookup_done, verify_request)

            # Actually start fetching keys.
            self._get_server_verify_keys(verify_requests)
        except Exception:
            logger.exception("Error starting key lookups")

    async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
        """Waits for any previous key lookups for the given servers to finish.

        Args:
            server_names: list of servers which we want to look up

        Returns:
            Resolves once all key lookups for the given servers have
            completed. Follows the synapse rules of logcontext preservation.
        """
        loop_count = 1
        while True:
            wait_on = [
                (server_name, self.key_downloads[server_name])
                for server_name in server_names
                if server_name in self.key_downloads
            ]
            if not wait_on:
                break
            logger.info(
                "Waiting for existing lookups for %s to complete [loop %i]",
                [w[0] for w in wait_on],
                loop_count,
    def verify_events_for_server(
        self, server_and_json: Iterable[Tuple[str, EventBase, int]]
    ) -> List[defer.Deferred]:
        return [
            run_in_background(
                self._verify_object,
                VerifyJsonRequest.from_event(
                    server_name,
                    validity_time,
                    event,
                ),
            )
            with PreserveLoggingContext():
                await defer.DeferredList((w[1] for w in wait_on))
            for server_name, event, validity_time in server_and_json
        ]

            loop_count += 1
    async def _verify_object(self, verify_request: VerifyJsonRequest):
        # TODO: Use a batching thing.
        with (await self._server_queue.queue(verify_request.server_name)):
            found_keys: Dict[str, FetchKeyResult] = {}
            missing_key_ids = set(verify_request.key_ids)
            for fetcher in self._key_fetchers:
                if not missing_key_ids:
                    break

    def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
        """Tries to find at least one key for each verify request

        For each verify_request, verify_request.key_ready is called back with
        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
        with a SynapseError if none of the keys are found.

        Args:
            verify_requests: list of verify requests
        """

        remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}

        async def do_iterations():
            try:
                with Measure(self.clock, "get_server_verify_keys"):
                    for f in self._key_fetchers:
                        if not remaining_requests:
                            return
                        await self._attempt_key_fetches_with_fetcher(
                            f, remaining_requests
                        )

                    # look for any requests which weren't satisfied
                    while remaining_requests:
                        verify_request = remaining_requests.pop()
                        rq_str = (
                            "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
                            % (
                                verify_request.server_name,
                                verify_request.key_ids,
                                verify_request.minimum_valid_until_ts,
                            )
                        )

                        # If we run the errback immediately, it may cancel our
                        # loggingcontext while we are still in it, so instead we
                        # schedule it for the next time round the reactor.
                        #
                        # (this also ensures that we don't get a stack overflow if we
                        # had a massive queue of lookups waiting for this server).
                        self.clock.call_later(
                            0,
                            verify_request.key_ready.errback,
                            SynapseError(
                                401,
                                "Failed to find any key to satisfy %s" % (rq_str,),
                                Codes.UNAUTHORIZED,
                            ),
                        )
            except Exception as err:
                # we don't really expect to get here, because any errors should already
                # have been caught and logged. But if we do, let's log the error and make
                # sure that all of the deferreds are resolved.
                logger.error("Unexpected error in _get_server_verify_keys: %s", err)
                with PreserveLoggingContext():
                    for verify_request in remaining_requests:
                        if not verify_request.key_ready.called:
                            verify_request.key_ready.errback(err)

        run_in_background(do_iterations)

    async def _attempt_key_fetches_with_fetcher(
        self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
    ):
        """Use a key fetcher to attempt to satisfy some key requests

        Args:
            fetcher: fetcher to use to fetch the keys
            remaining_requests: outstanding key requests.
                Any successfully-completed requests will be removed from the list.
        """
        # The keys to fetch.
        # server_name -> key_id -> min_valid_ts
        missing_keys = defaultdict(dict)  # type: Dict[str, Dict[str, int]]

        for verify_request in remaining_requests:
            # any completed requests should already have been removed
            assert not verify_request.key_ready.called
            keys_for_server = missing_keys[verify_request.server_name]

            for key_id in verify_request.key_ids:
                # If we have several requests for the same key, then we only need to
                # request that key once, but we should do so with the greatest
                # min_valid_until_ts of the requests, so that we can satisfy all of
                # the requests.
                keys_for_server[key_id] = max(
                    keys_for_server.get(key_id, -1),
                keys = await fetcher.get_keys(
                    verify_request.server_name,
                    list(missing_key_ids),
                    verify_request.minimum_valid_until_ts,
                )

        results = await fetcher.get_keys(missing_keys)
                for key_id, key in keys.items():
                    if not key:
                        continue

        completed = []
        for verify_request in remaining_requests:
            server_name = verify_request.server_name
                    if key.valid_until_ts < verify_request.minimum_valid_until_ts:
                        continue

                    existing_key = found_keys.get(key_id)
                    if existing_key:
                        if key.valid_until_ts <= existing_key.valid_until_ts:
                            continue

                    found_keys[key_id] = key

                missing_key_ids.difference_update(found_keys)

            if missing_key_ids:
                raise SynapseError(
                    400,
                    "Missing keys for %s: %s"
                    % (verify_request.server_name, missing_key_ids),
                    Codes.UNAUTHORIZED,
                )

            # see if any of the keys we got this time are sufficient to
            # complete this VerifyJsonRequest.
            result_keys = results.get(server_name, {})
            for key_id in verify_request.key_ids:
                fetch_key_result = result_keys.get(key_id)
                if not fetch_key_result:
                    # we didn't get a result for this key
                    continue

                if (
                    fetch_key_result.valid_until_ts
                    < verify_request.minimum_valid_until_ts
                ):
                    # key was not valid at this point
                    continue

                # we have a valid key for this request. If we run the callback
                # immediately, it may cancel our loggingcontext while we are still in
                # it, so instead we schedule it for the next time round the reactor.
                #
                # (this also ensures that we don't get a stack overflow if we had
                # a massive queue of lookups waiting for this server).
                logger.debug(
                    "Found key %s:%s for %s",
                    server_name,
                    key_id,
                    verify_request.request_name,
                )
                self.clock.call_later(
                    0,
                    verify_request.key_ready.callback,
                    (server_name, key_id, fetch_key_result.verify_key),
                )
                completed.append(verify_request)
                break

        remaining_requests.difference_update(completed)
            verify_key = found_keys[key_id].verify_key
            try:
                json_object = verify_request.json_object_callback()
                verify_signed_json(
                    json_object,
                    verify_request.server_name,
                    verify_key,
                )
            except SignatureVerifyException as e:
                logger.debug(
                    "Error verifying signature for %s:%s:%s with key %s: %s",
                    verify_request.server_name,
                    verify_key.alg,
                    verify_key.version,
                    encode_verify_key_base64(verify_key),
                    str(e),
                )
                raise SynapseError(
                    401,
                    "Invalid signature for server %s with key %s:%s: %s"
                    % (
                        verify_request.server_name,
                        verify_key.alg,
                        verify_key.version,
                        str(e),
                    ),
                    Codes.UNAUTHORIZED,
                )


class KeyFetcher(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    async def get_keys(
        self, keys_to_fetch: Dict[str, Dict[str, int]]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """
        Args:
            keys_to_fetch:
                the keys to be fetched. server_name -> key_id -> min_valid_ts
    def __init__(self, hs: "HomeServer"):
        self._queue = _Queue(self.__class__.__name__, hs.get_clock(), self._fetch_keys)

        Returns:
            Map from server_name -> key_id -> FetchKeyResult
        """
        raise NotImplementedError
    async def get_keys(
        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
    ) -> Dict[str, FetchKeyResult]:
        return await self._queue.add_to_queue(
            _QueueValue(
                server_name=server_name,
                key_ids=key_ids,
                minimum_valid_until_ts=minimum_valid_until_ts,
            )
        )

    @abc.abstractmethod
    async def _fetch_keys(
        self, keys_to_fetch: List[_QueueValue]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        pass


class StoreKeyFetcher(KeyFetcher):
    """KeyFetcher impl which fetches keys from our data store"""

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        self.store = hs.get_datastore()

    async def get_keys(
        self, keys_to_fetch: Dict[str, Dict[str, int]]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """see KeyFetcher.get_keys"""

    async def _fetch_keys(self, keys_to_fetch: List[_QueueValue]):
        key_ids_to_fetch = (
            (server_name, key_id)
            for server_name, keys_for_server in keys_to_fetch.items()
            for key_id in keys_for_server.keys()
            (queue_value.server_name, key_id)
            for queue_value in keys_to_fetch
            for key_id in queue_value.key_ids
        )

        res = await self.store.get_server_verify_keys(key_ids_to_fetch)
@@ -500,6 +351,8 @@ class StoreKeyFetcher(KeyFetcher):

class BaseV2KeyFetcher(KeyFetcher):
    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        self.store = hs.get_datastore()
        self.config = hs.config

@@ -607,10 +460,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
        self.client = hs.get_federation_http_client()
        self.key_servers = self.config.key_servers

    async def get_keys(
        self, keys_to_fetch: Dict[str, Dict[str, int]]
    async def _fetch_keys(
        self, keys_to_fetch: List[_QueueValue]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """see KeyFetcher.get_keys"""
        """see KeyFetcher._fetch_keys"""

        async def get_key(key_server: TrustedKeyServer) -> Dict:
            try:

@@ -646,12 +499,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
        return union_of_keys

    async def get_server_verify_key_v2_indirect(
        self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
        self, keys_to_fetch: List[_QueueValue], key_server: TrustedKeyServer
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """
        Args:
            keys_to_fetch:
                the keys to be fetched. server_name -> key_id -> min_valid_ts
                the keys to be fetched.

            key_server: notary server to query for the keys

@@ -665,7 +518,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
        perspective_name = key_server.server_name
        logger.info(
            "Requesting keys %s from notary server %s",
            keys_to_fetch.items(),
            keys_to_fetch,
            perspective_name,
        )

@@ -675,11 +528,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
            path="/_matrix/key/v2/query",
            data={
                "server_keys": {
                    server_name: {
                        key_id: {"minimum_valid_until_ts": min_valid_ts}
                        for key_id, min_valid_ts in server_keys.items()
                    queue_value.server_name: {
                        key_id: {
                            "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
                        }
                        for key_id in queue_value.key_ids
                    }
                    for server_name, server_keys in keys_to_fetch.items()
                    for queue_value in keys_to_fetch
                }
            },
        )

@@ -779,8 +634,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
        self.clock = hs.get_clock()
        self.client = hs.get_federation_http_client()

    async def get_keys(
        self, keys_to_fetch: Dict[str, Dict[str, int]]
    async def _fetch_keys(
        self, keys_to_fetch: List[_QueueValue]
    ) -> Dict[str, Dict[str, FetchKeyResult]]:
        """
        Args:

@@ -793,8 +648,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
        results = {}

        async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
            server_name, key_ids = key_to_fetch_item
        async def get_key(key_to_fetch_item: _QueueValue) -> None:
            server_name = key_to_fetch_item.server_name
            key_ids = key_to_fetch_item.key_ids

            try:
                keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
                results[server_name] = keys

@@ -805,7 +662,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
            except Exception:
                logger.exception("Error getting keys %s from %s", key_ids, server_name)

        await yieldable_gather_results(get_key, keys_to_fetch.items())
        await yieldable_gather_results(get_key, keys_to_fetch)
        return results

    async def get_server_verify_key_v2_direct(

@@ -877,37 +734,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
        keys.update(response_keys)

        return keys


async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
    """Waits for the key to become available, and then performs a verification

    Args:
        verify_request:

    Raises:
        SynapseError if there was a problem performing the verification
    """
    server_name = verify_request.server_name
    with PreserveLoggingContext():
        _, key_id, verify_key = await verify_request.key_ready

    json_object = verify_request.json_object

    try:
        verify_signed_json(json_object, server_name, verify_key)
    except SignatureVerifyException as e:
        logger.debug(
            "Error verifying signature for %s:%s:%s with key %s: %s",
            server_name,
            verify_key.alg,
            verify_key.version,
            encode_verify_key_base64(verify_key),
            str(e),
        )
        raise SynapseError(
            401,
            "Invalid signature for server %s with key %s:%s: %s"
            % (server_name, verify_key.alg, verify_key.version, str(e)),
            Codes.UNAUTHORIZED,
        )
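Under the reworked interface each fetcher exposes a per-server `get_keys(server_name, key_ids, minimum_valid_until_ts)`, and the `_Queue` batches concurrent callers behind it into one `_fetch_keys` call. A hedged usage sketch (the server name and key id are made up):

```python
async def example(fetcher: "KeyFetcher") -> None:
    # Concurrent callers asking about the same server are coalesced by
    # the _Queue into a single _fetch_keys() batch behind this call.
    keys = await fetcher.get_keys(
        "example.org",                # hypothetical server name
        ["ed25519:some_key_id"],      # hypothetical key id
        minimum_valid_until_ts=0,     # 0 means any validity is acceptable
    )
    for key_id, result in keys.items():
        print(key_id, result.valid_until_ts)
```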
@@ -418,7 +418,9 @@ def get_send_level(
def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
    power_levels_event = _get_power_level_event(auth_events)

    send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
    send_level = get_send_level(
        event.type, getattr(event, "state_key", None), power_levels_event
    )
    user_level = get_user_power_level(event.user_id, auth_events)

    if user_level < send_level:
@@ -16,12 +16,15 @@
import abc
import os
from typing import Dict, Optional, Tuple, Type
import zlib
from typing import Dict, List, Optional, Tuple, Type, Union

from unpaddedbase64 import encode_base64
import attr
from unpaddedbase64 import decode_base64, encode_base64

from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.types import JsonDict, RoomStreamToken
from synapse.util import json_decoder, json_encoder
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool

@@ -37,6 +40,26 @@ from synapse.util.stringutils import strtobool
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))


_PRESET_ZDICT = b"""{"auth_events":[],"prev_events":[],"type":"m.room.member",m.room.message"room_id":,"sender":,"content":{"msgtype":"m.text","body":""room_version":"creator":"depth":"prev_state":"state_key":""origin":"origin_server_ts":"hashes":{"sha256":"signatures":,"unsigned":{"age_ts":"ed25519"""


def _encode_dict(d: JsonDict) -> bytes:
    json_bytes = json_encoder.encode(d).encode("utf-8")
    c = zlib.compressobj(1, zdict=_PRESET_ZDICT)
    result_bytes = c.compress(json_bytes)
    result_bytes += c.flush()
    return result_bytes


def _decode_dict(b: bytes) -> JsonDict:
    d = zlib.decompressobj(zdict=_PRESET_ZDICT)

    result_bytes = d.decompress(b)
    result_bytes += d.flush()

    return json_decoder.decode(result_bytes.decode("utf-8"))


class DictProperty:
    """An object property which delegates to the `_dict` within its parent object."""
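`_PRESET_ZDICT` is a preset compression dictionary: zlib can reference its substrings even on very short inputs, which is what makes compressing individual event dicts worthwhile. A small self-contained demonstration of the effect (the sample event and dictionary are illustrative, not the branch's actual preset):

```python
import zlib

ZDICT = b'{"type":"m.room.message","content":{"msgtype":"m.text","body":"'

def compress(data: bytes, zdict: bytes = b"") -> bytes:
    c = zlib.compressobj(1, zdict=zdict) if zdict else zlib.compressobj(1)
    return c.compress(data) + c.flush()

event = b'{"type":"m.room.message","content":{"msgtype":"m.text","body":"hi"}}'
print(len(compress(event)))         # plain deflate of a short payload
print(len(compress(event, ZDICT)))  # smaller: shared keys come from the dict

# Decompression must supply the same dictionary:
d = zlib.decompressobj(zdict=ZDICT)
assert d.decompress(compress(event, ZDICT)) + d.flush() == event
```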
@@ -205,7 +228,81 @@ class _EventInternalMetadata:
        return self._dict.get("redacted", False)


@attr.s(slots=True, auto_attribs=True)
class _Signatures:
    _signatures_bytes: bytes

    @staticmethod
    def from_dict(signature_dict: JsonDict) -> "_Signatures":
        return _Signatures(_encode_dict(signature_dict))

    def get_dict(self) -> JsonDict:
        return _decode_dict(self._signatures_bytes)

    def get(self, server_name, default=None):
        return self.get_dict().get(server_name, default)

    def update(self, other: Union[JsonDict, "_Signatures"]):
        if isinstance(other, _Signatures):
            other_dict = _decode_dict(other._signatures_bytes)
        else:
            other_dict = other

        signatures = self.get_dict()
        signatures.update(other_dict)
        self._signatures_bytes = _encode_dict(signatures)


class _SmallListV1(str):
    __slots__ = []

    def get(self):
        return self.split(",")

    @staticmethod
    def create(event_ids):
        return _SmallListV1(",".join(event_ids))


class _SmallListV2_V3(bytes):
    __slots__ = []

    def get(self, url_safe):
        i = 0
        while i * 32 < len(self):
            bit = self[i * 32 : (i + 1) * 32]
            i += 1
            yield "$" + encode_base64(bit, urlsafe=url_safe)

    @staticmethod
    def create(event_ids):
        return _SmallListV2_V3(
            b"".join(decode_base64(event_id[1:]) for event_id in event_ids)
        )


class EventBase(metaclass=abc.ABCMeta):
    __slots__ = [
        "room_version",
        "signatures",
        "unsigned",
        "rejected_reason",
        "_encoded_dict",
        "_auth_event_ids",
        "depth",
        "_content",
        "_hashes",
        "origin",
        "origin_server_ts",
        "_prev_event_ids",
        "redacts",
        "room_id",
        "sender",
        "type",
        "state_key",
        "internal_metadata",
    ]

    @property
    @abc.abstractmethod
    def format_version(self) -> int:
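`_SmallListV2_V3` exploits the fact that a v2/v3 event ID is just `$` plus unpadded base64 of a 32-byte hash, so a list of IDs can be stored as the concatenated raw bytes, 32 per ID, instead of a list of Python strings. A quick round-trip demonstration with made-up IDs:

```python
import os
from unpaddedbase64 import decode_base64, encode_base64

fake_ids = ["$" + encode_base64(os.urandom(32)) for _ in range(3)]

# Pack: strip the "$", decode, and concatenate the raw 32-byte hashes.
packed = b"".join(decode_base64(event_id[1:]) for event_id in fake_ids)
assert len(packed) == 3 * 32

# Unpack: slice 32 bytes at a time and re-encode.
unpacked = [
    "$" + encode_base64(packed[i : i + 32]) for i in range(0, len(packed), 32)
]
assert unpacked == fake_ids
```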
@@ -224,32 +321,44 @@ class EventBase(metaclass=abc.ABCMeta):
        assert room_version.event_format == self.format_version

        self.room_version = room_version
        self.signatures = signatures
        self.signatures = _Signatures.from_dict(signatures)
        self.unsigned = unsigned
        self.rejected_reason = rejected_reason

        self._dict = event_dict
        self._encoded_dict = _encode_dict(event_dict)

        self.depth = event_dict["depth"]
        self.origin = event_dict.get("origin")
        self.origin_server_ts = event_dict["origin_server_ts"]
        self.redacts = event_dict.get("redacts")
        self.room_id = event_dict["room_id"]
        self.sender = event_dict["sender"]
        self.type = event_dict["type"]
        if "state_key" in event_dict:
            self.state_key = event_dict["state_key"]

        self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)

    auth_events = DictProperty("auth_events")
    depth = DictProperty("depth")
    content = DictProperty("content")
    hashes = DictProperty("hashes")
    origin = DictProperty("origin")
    origin_server_ts = DictProperty("origin_server_ts")
    prev_events = DictProperty("prev_events")
    redacts = DefaultDictProperty("redacts", None)
    room_id = DictProperty("room_id")
    sender = DictProperty("sender")
    state_key = DictProperty("state_key")
    type = DictProperty("type")
    user_id = DictProperty("sender")
    @property
    def content(self) -> JsonDict:
        return self.get_dict()["content"]

    @property
    def hashes(self) -> JsonDict:
        return self.get_dict()["hashes"]

    @property
    def prev_events(self) -> List[str]:
        return list(self._prev_events)

    @property
    def event_id(self) -> str:
        raise NotImplementedError()

    @property
    def user_id(self) -> str:
        return self.sender

    @property
    def membership(self):
        return self.content["membership"]

@@ -258,17 +367,13 @@
        return hasattr(self, "state_key") and self.state_key is not None

    def get_dict(self) -> JsonDict:
        d = dict(self._dict)
        d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
        d = _decode_dict(self._encoded_dict)
        d.update(
            {"signatures": self.signatures.get_dict(), "unsigned": dict(self.unsigned)}
        )

        return d

    def get(self, key, default=None):
        return self._dict.get(key, default)

    def get_internal_metadata_dict(self):
        return self.internal_metadata.get_dict()

    def get_pdu_json(self, time_now=None) -> JsonDict:
        pdu_json = self.get_dict()

@@ -285,41 +390,11 @@
    def __set__(self, instance, value):
        raise AttributeError("Unrecognized attribute %s" % (instance,))

    def __getitem__(self, field):
        return self._dict[field]

    def __contains__(self, field):
        return field in self._dict

    def items(self):
        return list(self._dict.items())

    def keys(self):
        return self._dict.keys()

    def prev_event_ids(self):
        """Returns the list of prev event IDs. The order matches the order
        specified in the event, though there is no meaning to it.

        Returns:
            list[str]: The list of event IDs of this event's prev_events
        """
        return [e for e, _ in self.prev_events]

    def auth_event_ids(self):
        """Returns the list of auth event IDs. The order matches the order
        specified in the event, though there is no meaning to it.

        Returns:
            list[str]: The list of event IDs of this event's auth_events
        """
        return [e for e, _ in self.auth_events]

    def freeze(self):
        """'Freeze' the event dict, so it cannot be modified by accident"""

        # this will be a no-op if the event dict is already frozen.
        self._dict = freeze(self._dict)
        # self._dict = freeze(self._dict)


class FrozenEvent(EventBase):

@@ -355,6 +430,12 @@ class FrozenEvent(EventBase):
            frozen_dict = event_dict

        self._event_id = event_dict["event_id"]
        self._auth_event_ids = _SmallListV1.create(
            e for e, _ in event_dict["auth_events"]
        )
        self._prev_event_ids = _SmallListV1.create(
            e for e, _ in event_dict["prev_events"]
        )

        super().__init__(
            frozen_dict,

@@ -369,18 +450,26 @@
    def event_id(self) -> str:
        return self._event_id

    def auth_event_ids(self):
        return list(self._auth_event_ids.get())

    def prev_event_ids(self):
        return list(self._prev_event_ids.get())

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
            self.get("event_id", None),
            self.get("type", None),
            self.get("state_key", None),
            self.event_id,
            self.type,
            getattr(self, "state_key", None),
        )


class FrozenEventV2(EventBase):
    __slots__ = ["_event_id"]

    format_version = EventFormatVersions.V2  # All events of this type are V2

    def __init__(

@@ -415,6 +504,8 @@ class FrozenEventV2(EventBase):
            frozen_dict = event_dict

        self._event_id = None
        self._auth_event_ids = _SmallListV2_V3.create(event_dict["auth_events"])
        self._prev_event_ids = _SmallListV2_V3.create(event_dict["prev_events"])

        super().__init__(
            frozen_dict,

@@ -436,24 +527,6 @@
        self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
        return self._event_id

    def prev_event_ids(self):
        """Returns the list of prev event IDs. The order matches the order
        specified in the event, though there is no meaning to it.

        Returns:
            list[str]: The list of event IDs of this event's prev_events
        """
        return self.prev_events

    def auth_event_ids(self):
        """Returns the list of auth event IDs. The order matches the order
        specified in the event, though there is no meaning to it.

        Returns:
            list[str]: The list of event IDs of this event's auth_events
        """
        return self.auth_events

    def __str__(self):
        return self.__repr__()

@@ -461,14 +534,22 @@
        return "<%s event_id=%r, type=%r, state_key=%r>" % (
            self.__class__.__name__,
            self.event_id,
            self.get("type", None),
            self.get("state_key", None),
            self.type,
            self.state_key if self.is_state() else None,
        )

    def auth_event_ids(self):
        return list(self._auth_event_ids.get(False))

    def prev_event_ids(self):
        return list(self._prev_event_ids.get(False))


class FrozenEventV3(FrozenEventV2):
    """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""

    __slots__ = ["_event_id"]

    format_version = EventFormatVersions.V3  # All events of this type are V3

    @property

@@ -484,6 +565,12 @@ class FrozenEventV3(FrozenEventV2):
        )
        return self._event_id

    def auth_event_ids(self):
        return list(self._auth_event_ids.get(True))

    def prev_event_ids(self):
        return list(self._prev_event_ids.get(True))


def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
    """Returns the python type to use to construct an Event object for the
@@ -38,6 +38,8 @@ class EventValidator:
        if event.format_version == EventFormatVersions.V1:
            EventID.from_string(event.event_id)

        event_dict = event.get_dict()

        required = [
            "auth_events",
            "content",

@@ -49,7 +51,7 @@
        ]

        for k in required:
            if not hasattr(event, k):
            if k not in event_dict:
                raise SynapseError(400, "Event does not have key %s" % (k,))

        # Check that the following keys have string values
@@ -73,10 +73,10 @@ class FederationBase:
            * throws a SynapseError if the signature check failed.
        The deferreds run their callbacks in the sentinel logcontext.
        """
        deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)

        ctx = current_context()

        deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)

        @defer.inlineCallbacks
        def callback(_, pdu: EventBase):
            with PreserveLoggingContext(ctx):

@@ -90,9 +90,7 @@
            # received event was probably a redacted copy (but we then use our
            # *actual* redacted copy to be on the safe side.)
            redacted_event = prune_event(pdu)
            if set(redacted_event.keys()) == set(pdu.keys()) and set(
                redacted_event.content.keys()
            ) == set(pdu.content.keys()):
            if set(redacted_event.content.keys()) == set(pdu.content.keys()):
                logger.info(
                    "Event %s seems to have been redacted; using our redacted "
                    "copy",

@@ -137,11 +135,7 @@
    return deferreds


class PduToCheckSig(
    namedtuple(
        "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
    )
):
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
    pass


@@ -184,7 +178,6 @@ def _check_sigs_on_pdus(
    pdus_to_check = [
        PduToCheckSig(
            pdu=p,
            redacted_pdu_json=prune_event(p).get_pdu_json(),
            sender_domain=get_domain_from_id(p.sender),
            deferreds=[],
        )

@@ -195,13 +188,12 @@
    # (except if its a 3pid invite, in which case it may be sent by any server)
    pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]

    more_deferreds = keyring.verify_json_objects_for_server(
    more_deferreds = keyring.verify_events_for_server(
        [
            (
                p.sender_domain,
                p.redacted_pdu_json,
                p.pdu,
                p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
                p.pdu.event_id,
            )
            for p in pdus_to_check_sender
        ]

@@ -230,13 +222,12 @@
        if p.sender_domain != get_domain_from_id(p.pdu.event_id)
    ]

    more_deferreds = keyring.verify_json_objects_for_server(
    more_deferreds = keyring.verify_events_for_server(
        [
            (
                get_domain_from_id(p.pdu.event_id),
                p.redacted_pdu_json,
                p.pdu,
                p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
                p.pdu.event_id,
            )
            for p in pdus_to_check_event_id
        ]
@@ -33,6 +33,7 @@ from typing import (
)

import attr
import ijson
from prometheus_client import Counter

from twisted.internet import defer

@@ -55,11 +56,16 @@ from synapse.api.room_versions import (
)
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.context import (
    get_thread_resource_usage,
    make_deferred_yieldable,
    preserve_fn,
)
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination

if TYPE_CHECKING:

@@ -385,7 +391,6 @@ class FederationClient(FederationBase):
        Returns:
            A list of PDUs that have valid signatures and hashes.
        """
        deferreds = self._check_sigs_and_hashes(room_version, pdus)

        async def handle_check_result(pdu: EventBase, deferred: Deferred):
            try:

@@ -420,6 +425,7 @@ class FederationClient(FederationBase):
            return res

        handle = preserve_fn(handle_check_result)
        deferreds = self._check_sigs_and_hashes(room_version, pdus)
        deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]

        valid_pdus = await make_deferred_yieldable(

@@ -667,19 +673,37 @@ class FederationClient(FederationBase):
        async def send_request(destination) -> Dict[str, Any]:
            content = await self._do_send_join(destination, pdu)

            logger.debug("Got content: %s", content)
            # logger.debug("Got content: %s", content.getvalue())

            state = [
                event_from_pdu_json(p, room_version, outlier=True)
                for p in content.get("state", [])
            ]
            # logger.info("send_join content: %d", len(content))

            auth_chain = [
                event_from_pdu_json(p, room_version, outlier=True)
                for p in content.get("auth_chain", [])
            ]
            content.seek(0)

            pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
            r = get_thread_resource_usage()
            logger.info("Memory before state: %s", r.ru_maxrss)

            state = []
            for i, p in enumerate(ijson.items(content, "state.item")):
                state.append(event_from_pdu_json(p, room_version, outlier=True))
                if i % 1000 == 999:
                    await self._clock.sleep(0)

            r = get_thread_resource_usage()
            logger.info("Memory after state: %s", r.ru_maxrss)

            logger.info("Parsed state: %d", len(state))
            content.seek(0)

            auth_chain = []
            for i, p in enumerate(ijson.items(content, "auth_chain.item")):
                auth_chain.append(event_from_pdu_json(p, room_version, outlier=True))
                if i % 1000 == 999:
                    await self._clock.sleep(0)

            r = get_thread_resource_usage()
            logger.info("Memory after: %s", r.ru_maxrss)

            logger.info("Parsed auth chain: %d", len(auth_chain))

            create_event = None
            for e in state:

@@ -704,12 +728,19 @@ class FederationClient(FederationBase):
                % (create_room_version,)
            )

            valid_pdus = await self._check_sigs_and_hash_and_fetch(
                destination,
                list(pdus.values()),
                outlier=True,
                room_version=room_version,
            )
            valid_pdus = []

            for chunk in batch_iter(itertools.chain(state, auth_chain), 1000):
                logger.info("Handling next _check_sigs_and_hash_and_fetch chunk")
                new_valid_pdus = await self._check_sigs_and_hash_and_fetch(
                    destination,
                    chunk,
                    outlier=True,
                    room_version=room_version,
                )
                valid_pdus.extend(new_valid_pdus)

            logger.info("_check_sigs_and_hash_and_fetch done")

            valid_pdus_map = {p.event_id: p for p in valid_pdus}

@@ -744,6 +775,8 @@ class FederationClient(FederationBase):
                % (auth_chain_create_events,)
            )

            logger.info("Returning from send_join")

            return {
                "state": signed_state,
                "auth_chain": signed_auth,

@@ -769,6 +802,8 @@ class FederationClient(FederationBase):
            if not self._is_unknown_endpoint(e):
                raise

            raise NotImplementedError()

        logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")

        resp = await self.transport_layer.send_join_v1(
@@ -244,7 +244,10 @@ class TransportLayerClient:
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)

         response = await self.client.put_json(
-            destination=destination, path=path, data=content
+            destination=destination,
+            path=path,
+            data=content,
+            return_string_io=True,
         )

         return response

@@ -254,7 +257,10 @@ class TransportLayerClient:
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)

         response = await self.client.put_json(
-            destination=destination, path=path, data=content
+            destination=destination,
+            path=path,
+            data=content,
+            return_string_io=True,
         )

         return response
@@ -78,7 +78,7 @@ class DirectoryHandler(BaseHandler):
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            users = await self.state.get_current_users_in_room(room_id)
+            users = await self.store.get_users_in_room(room_id)
             servers = {get_domain_from_id(u) for u in users}

         if not servers:

@@ -270,7 +270,7 @@ class DirectoryHandler(BaseHandler):
                 Codes.NOT_FOUND,
             )

-        users = await self.state.get_current_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         extra_servers = {get_domain_from_id(u) for u in users}
         servers = set(extra_servers) | set(servers)
@@ -103,7 +103,7 @@ class EventStreamHandler(BaseHandler):
             # Send down presence.
             if event.state_key == auth_user_id:
                 # Send down presence for everyone in the room.
-                users = await self.state.get_current_users_in_room(
+                users = await self.store.get_users_in_room(
                     event.room_id
                 )  # type: Iterable[str]
             else:
@@ -552,7 +552,7 @@ class FederationHandler(BaseHandler):
         destination: str,
         room_id: str,
         event_id: str,
-    ) -> Tuple[List[EventBase], List[EventBase]]:
+    ) -> List[EventBase]:
         """Requests all of the room state at a given event from a remote homeserver.

         Args:

@@ -573,11 +573,10 @@ class FederationHandler(BaseHandler):

         desired_events = set(state_event_ids + auth_event_ids)

-        event_map = await self._get_events_from_store_or_dest(
+        failed_to_fetch = await self._get_events_from_store_or_dest(
             destination, room_id, desired_events
         )

-        failed_to_fetch = desired_events - event_map.keys()
         if failed_to_fetch:
             logger.warning(
                 "Failed to fetch missing state/auth events for %s %s",
@@ -585,55 +584,12 @@ class FederationHandler(BaseHandler):
                 failed_to_fetch,
             )

+        event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
+
         remote_state = [
             event_map[e_id] for e_id in state_event_ids if e_id in event_map
         ]

-        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-        auth_chain.sort(key=lambda e: e.depth)
-
-        return remote_state, auth_chain
-
-    async def _get_events_from_store_or_dest(
-        self, destination: str, room_id: str, event_ids: Iterable[str]
-    ) -> Dict[str, EventBase]:
-        """Fetch events from a remote destination, checking if we already have them.
-
-        Persists any events we don't already have as outliers.
-
-        If we fail to fetch any of the events, a warning will be logged, and the
-        event will be omitted from the result. Likewise, any events which turn out
-        not to be in the given room.
-
-        This function *does not* automatically get missing auth events of the
-        newly fetched events. Callers must include the full auth chain of
-        the missing events in the `event_ids` argument, to ensure that any
-        missing auth events are correctly fetched.
-
-        Returns:
-            map from event_id to event
-        """
-        fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
-
-        missing_events = set(event_ids) - fetched_events.keys()
-
-        if missing_events:
-            logger.debug(
-                "Fetching unknown state/auth events %s for room %s",
-                missing_events,
-                room_id,
-            )
-
-            await self._get_events_and_persist(
-                destination=destination, room_id=room_id, events=missing_events
-            )
-
-            # we need to make sure we re-load from the database to get the rejected
-            # state correct.
-            fetched_events.update(
-                (await self.store.get_events(missing_events, allow_rejected=True))
-            )
-
         # check for events which were in the wrong room.
         #
         # this can happen if a remote server claims that the state or
@@ -641,7 +597,7 @@ class FederationHandler(BaseHandler):

         bad_events = [
             (event_id, event.room_id)
-            for event_id, event in fetched_events.items()
+            for idx, event in enumerate(remote_state)
             if event.room_id != room_id
         ]
@@ -658,9 +614,49 @@ class FederationHandler(BaseHandler):
                 room_id,
             )

-            del fetched_events[bad_event_id]
+        if bad_events:
+            remote_state = [e for e in remote_state if e.room_id == room_id]

-        return fetched_events
+        return remote_state
+
+    async def _get_events_from_store_or_dest(
+        self, destination: str, room_id: str, event_ids: Iterable[str]
+    ) -> Set[str]:
+        """Fetch events from a remote destination, checking if we already have them.
+
+        Persists any events we don't already have as outliers.
+
+        If we fail to fetch any of the events, a warning will be logged, and the
+        event will be omitted from the result. Likewise, any events which turn out
+        not to be in the given room.
+
+        This function *does not* automatically get missing auth events of the
+        newly fetched events. Callers must include the full auth chain of
+        the missing events in the `event_ids` argument, to ensure that any
+        missing auth events are correctly fetched.
+
+        Returns:
+            The set of event IDs that we failed to fetch.
+        """
+        have_events = await self.store.have_seen_events(event_ids)
+
+        missing_events = set(event_ids) - have_events
+
+        if not missing_events:
+            return set()
+
+        logger.debug(
+            "Fetching unknown state/auth events %s for room %s",
+            missing_events,
+            room_id,
+        )
+
+        await self._get_events_and_persist(
+            destination=destination, room_id=room_id, events=missing_events
+        )
+
+        new_events = await self.store.have_seen_events(missing_events)
+        return missing_events - new_events

     async def _get_state_after_missing_prev_event(
         self,
@@ -963,27 +959,23 @@ class FederationHandler(BaseHandler):

         # For each edge get the current state.

-        auth_events = {}
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = await self._get_state_for_room(
+            state = await self._get_state_for_room(
                 destination=dest,
                 room_id=room_id,
                 event_id=e_id,
             )
-            auth_events.update({a.event_id: a for a in auth})
-            auth_events.update({s.event_id: s for s in state})
             state_events.update({s.event_id: s for s in state})
             events_to_state[e_id] = state

         required_auth = {
             a_id
-            for event in events
-            + list(state_events.values())
-            + list(auth_events.values())
+            for event in events + list(state_events.values())
             for a_id in event.auth_event_ids()
         }
         auth_events = await self.store.get_events(required_auth, allow_rejected=True)
         auth_events.update(
             {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
         )
@@ -1452,7 +1444,7 @@ class FederationHandler(BaseHandler):
         # room stuff after join currently doesn't work on workers.
         assert self.config.worker.worker_app is None

-        logger.debug("Joining %s to %s", joinee, room_id)
+        logger.info("Joining %s to %s", joinee, room_id)

         origin, event, room_version_obj = await self._make_and_verify_event(
             target_hosts,

@@ -1463,6 +1455,8 @@ class FederationHandler(BaseHandler):
             params={"ver": KNOWN_ROOM_VERSIONS},
         )

+        logger.info("make_join done from %s", origin)
+
         # This shouldn't happen, because the RoomMemberHandler has a
         # linearizer lock which only allows one operation per user per room
         # at a time - so this is just paranoia.

@@ -1482,10 +1476,13 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass

+        logger.info("Sending join")
         ret = await self.federation_client.send_join(
             host_list, event, room_version_obj
         )

+        logger.info("send join done")
+
         origin = ret["origin"]
         state = ret["state"]
         auth_chain = ret["auth_chain"]

@@ -1510,10 +1507,14 @@ class FederationHandler(BaseHandler):
             room_version=room_version_obj,
         )

+        logger.info("Persisting auth tree")
+
         max_stream_id = await self._persist_auth_tree(
             origin, room_id, auth_chain, state, event, room_version_obj
         )

+        logger.info("Persisted auth tree")
+
         # We wait here until this instance has seen the events come down
         # replication (if we're using replication) as the below uses caches.
         await self._replication.wait_for_stream_position(
@@ -2166,6 +2167,8 @@ class FederationHandler(BaseHandler):
             ctx = await self.state_handler.compute_event_context(e)
             events_to_context[e.event_id] = ctx

+        logger.info("Computed contexts")
+
         event_map = {
             e.event_id: e for e in itertools.chain(auth_events, state, [event])
         }

@@ -2207,6 +2210,8 @@ class FederationHandler(BaseHandler):
             else:
                 logger.info("Failed to find auth event %r", e_id)

+        logger.info("Got missing events")
+
         for e in itertools.chain(auth_events, state, [event]):
             auth_for_e = {
                 (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]

@@ -2231,6 +2236,8 @@ class FederationHandler(BaseHandler):
                     raise
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR

+        logger.info("Authed events")
+
         await self.persist_events_and_notify(
             room_id,
             [

@@ -2239,10 +2246,14 @@ class FederationHandler(BaseHandler):
             ],
         )

+        logger.info("Persisted events")
+
         new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )

+        logger.info("Computed context")
+
         return await self.persist_events_and_notify(
             room_id, [(event, new_event_context)]
         )
@@ -258,7 +258,7 @@ class MessageHandler:
                 "Getting joined members after leaving is not implemented"
             )

-        users_with_profile = await self.state.get_current_users_in_room(room_id)
+        users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)

         # If this is an AS, double check that they are allowed to see the members.
         # This can either be because the AS user is in the room or because there
@@ -1108,7 +1108,7 @@ class EventCreationHandler:
             # it's not a self-redaction (to avoid having to look up whether the
             # user is actually admin or not).
             is_admin_redaction = False
-            if event.type == EventTypes.Redaction:
+            if event.type == EventTypes.Redaction and event.redacts:
                 original_event = await self.store.get_event(
                     event.redacts,
                     redact_behaviour=EventRedactBehaviour.AS_IS,

@@ -1195,7 +1195,7 @@ class EventCreationHandler:
             # TODO: Make sure the signatures actually are correct.
             event.signatures.update(returned_invite.signatures)

-        if event.type == EventTypes.Redaction:
+        if event.type == EventTypes.Redaction and event.redacts:
             original_event = await self.store.get_event(
                 event.redacts,
                 redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -1401,7 +1401,7 @@ class EventCreationHandler:
         ]

         for k in immutable_fields:
-            if getattr(builder, k, None) != original_event.get(k):
+            if getattr(builder, k, None) != getattr(original_event, k, None):
                 raise Exception(
                     "Third party rules module created an invalid event: "
                     "cannot change field " + k
@@ -1183,7 +1183,16 @@ class PresenceHandler(BasePresenceHandler):
             max_pos, deltas = await self.store.get_current_state_deltas(
                 self._event_pos, room_max_stream_ordering
             )
-            await self._handle_state_delta(deltas)
+
+            # We may get multiple deltas for different rooms, but we want to
+            # handle them on a room by room basis, so we batch them up by
+            # room.
+            deltas_by_room: Dict[str, List[JsonDict]] = {}
+            for delta in deltas:
+                deltas_by_room.setdefault(delta["room_id"], []).append(delta)
+
+            for room_id, deltas_for_room in deltas_by_room.items():
+                await self._handle_state_delta(room_id, deltas_for_room)

             self._event_pos = max_pos
@@ -1192,17 +1201,21 @@ class PresenceHandler(BasePresenceHandler):
                 max_pos
             )

-    async def _handle_state_delta(self, deltas: List[JsonDict]) -> None:
-        """Process current state deltas to find new joins that need to be
-        handled.
+    async def _handle_state_delta(self, room_id: str, deltas: List[JsonDict]) -> None:
+        """Process current state deltas for the room to find new joins that need
+        to be handled.
         """
-        # A map of destination to a set of user state that they should receive
-        presence_destinations = {}  # type: Dict[str, Set[UserPresenceState]]
+        # Sets of newly joined users. Note that if the local server is
+        # joining a remote room for the first time we'll see both the joining
+        # user and all remote users as newly joined.
+        newly_joined_users = set()

         for delta in deltas:
+            assert room_id == delta["room_id"]
+
             typ = delta["type"]
             state_key = delta["state_key"]
-            room_id = delta["room_id"]
             event_id = delta["event_id"]
             prev_event_id = delta["prev_event_id"]

@@ -1231,72 +1244,55 @@ class PresenceHandler(BasePresenceHandler):
                 # Ignore changes to join events.
                 continue

-            # Retrieve any user presence state updates that need to be sent as a result,
-            # and the destinations that need to receive it
-            destinations, user_presence_states = await self._on_user_joined_room(
-                room_id, state_key
-            )
-
-            # Insert the destinations and respective updates into our destinations dict
-            for destination in destinations:
-                presence_destinations.setdefault(destination, set()).update(
-                    user_presence_states
-                )
-
-        # Send out user presence updates for each destination
-        for destination, user_state_set in presence_destinations.items():
-            self._federation_queue.send_presence_to_destinations(
-                destinations=[destination], states=user_state_set
-            )
-
-    async def _on_user_joined_room(
-        self, room_id: str, user_id: str
-    ) -> Tuple[List[str], List[UserPresenceState]]:
-        """Called when we detect a user joining the room via the current state
-        delta stream. Returns the destinations that need to be updated and the
-        presence updates to send to them.
-
-        Args:
-            room_id: The ID of the room that the user has joined.
-            user_id: The ID of the user that has joined the room.
-
-        Returns:
-            A tuple of destinations and presence updates to send to them.
-        """
-        if self.is_mine_id(user_id):
-            # If this is a local user then we need to send their presence
-            # out to hosts in the room (who don't already have it)
-
-            # TODO: We should be able to filter the hosts down to those that
-            # haven't previously seen the user
-
-            remote_hosts = await self.state.get_current_hosts_in_room(room_id)
-
-            # Filter out ourselves.
-            filtered_remote_hosts = [
-                host for host in remote_hosts if host != self.server_name
-            ]
-
-            state = await self.current_state_for_user(user_id)
-            return filtered_remote_hosts, [state]
-        else:
-            # A remote user has joined the room, so we need to:
-            #   1. Check if this is a new server in the room
-            #   2. If so send any presence they don't already have for
-            #      local users in the room.
-
-            # TODO: We should be able to filter the users down to those that
-            # the server hasn't previously seen
-
-            # TODO: Check that this is actually a new server joining the
-            # room.
-
-            remote_host = get_domain_from_id(user_id)
-
-            users = await self.state.get_current_users_in_room(room_id)
-            user_ids = list(filter(self.is_mine_id, users))
-
-            states_d = await self.current_state_for_users(user_ids)
+            newly_joined_users.add(state_key)
+
+        if not newly_joined_users:
+            # If nobody has joined then there's nothing to do.
+            return
+
+        # We want to send:
+        #   1. presence states of all local users in the room to newly joined
+        #      remote servers
+        #   2. presence states of newly joined users to all remote servers in
+        #      the room.
+        #
+        # TODO: Only send presence states to remote hosts that don't already
+        # have them (because they already share rooms).
+
+        # Get all the users who were already in the room, by fetching the
+        # current users in the room and removing the newly joined users.
+        users = await self.store.get_users_in_room(room_id)
+        prev_users = set(users) - newly_joined_users
+
+        # Construct sets for all the local users and remote hosts that were
+        # already in the room
+        prev_local_users = []
+        prev_remote_hosts = set()
+        for user_id in prev_users:
+            if self.is_mine_id(user_id):
+                prev_local_users.append(user_id)
+            else:
+                prev_remote_hosts.add(get_domain_from_id(user_id))
+
+        # Similarly, construct sets for all the local users and remote hosts
+        # that were *not* already in the room. Care needs to be taken when
+        # calculating the remote hosts, as a host may have already been in the
+        # room even if there is a newly joined user from that host.
+        newly_joined_local_users = []
+        newly_joined_remote_hosts = set()
+        for user_id in newly_joined_users:
+            if self.is_mine_id(user_id):
+                newly_joined_local_users.append(user_id)
+            else:
+                host = get_domain_from_id(user_id)
+                if host not in prev_remote_hosts:
+                    newly_joined_remote_hosts.add(host)
+
+        # Send presence states of all local users in the room to newly joined
+        # remote servers. (We actually only send states for local users already
+        # in the room, as we'll send states for newly joined local users below.)
+        if prev_local_users and newly_joined_remote_hosts:
+            local_states = await self.current_state_for_users(prev_local_users)

             # Filter out old presence, i.e. offline presence states where
             # the user hasn't been active for a week. We can change this
@@ -1306,13 +1302,27 @@ class PresenceHandler(BasePresenceHandler):
             now = self.clock.time_msec()
             states = [
                 state
-                for state in states_d.values()
+                for state in local_states.values()
                 if state.state != PresenceState.OFFLINE
                 or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
                 or state.status_msg is not None
             ]

-            return [remote_host], states
+            self._federation_queue.send_presence_to_destinations(
+                destinations=newly_joined_remote_hosts,
+                states=states,
+            )
+
+        # Send presence states of newly joined users to all remote servers in
+        # the room
+        if newly_joined_local_users and (
+            prev_remote_hosts or newly_joined_remote_hosts
+        ):
+            local_states = await self.current_state_for_users(newly_joined_local_users)
+            self._federation_queue.send_presence_to_destinations(
+                destinations=prev_remote_hosts | newly_joined_remote_hosts,
+                states=list(local_states.values()),
+            )


 def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool:
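The prev/newly-joined bookkeeping above boils down to splitting room members into local users and remote hosts by the domain part of the user ID. An illustrative sketch of that split (the helper name and user IDs are made up):

from typing import Iterable, List, Set, Tuple


def split_members(user_ids: Iterable[str], my_server: str) -> Tuple[List[str], Set[str]]:
    """Partition members into local user IDs and remote host names."""
    local_users: List[str] = []
    remote_hosts: Set[str] = set()
    for user_id in user_ids:
        # Matrix user IDs look like "@alice:example.com".
        domain = user_id.split(":", 1)[1]
        if domain == my_server:
            local_users.append(user_id)
        else:
            remote_hosts.add(domain)
    return local_users, remote_hosts


locals_, hosts = split_members(["@a:here.org", "@b:there.org"], "here.org")
assert locals_ == ["@a:here.org"] and hosts == {"there.org"}

With those sets in hand, the handler only needs to send existing local users' presence to newly joined hosts, and newly joined users' presence to the remaining hosts in the room.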
@@ -475,7 +475,7 @@ class RoomCreationHandler(BaseHandler):
         ):
             await self.room_member_handler.update_membership(
                 requester,
-                UserID.from_string(old_event["state_key"]),
+                UserID.from_string(old_event.state_key),
                 new_room_id,
                 "ban",
                 ratelimit=False,

@@ -1327,7 +1327,7 @@ class RoomShutdownHandler:
         new_room_id = None
         logger.info("Shutting down room %r", room_id)

-        users = await self.state.get_current_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         kicked_users = []
         failed_to_kick_users = []
         for user_id in users:
@@ -1190,7 +1190,7 @@ class SyncHandler:

         # Step 1b, check for newly joined rooms
         for room_id in newly_joined_rooms:
-            joined_users = await self.state.get_current_users_in_room(room_id)
+            joined_users = await self.store.get_users_in_room(room_id)
             newly_joined_or_invited_users.update(joined_users)

         # TODO: Check that these users are actually new, i.e. either they

@@ -1206,7 +1206,7 @@ class SyncHandler:

         # Now find users that we no longer track
         for room_id in newly_left_rooms:
-            left_users = await self.state.get_current_users_in_room(room_id)
+            left_users = await self.store.get_users_in_room(room_id)
             newly_left_users.update(left_users)

         # Remove any users that we still share a room with.

@@ -1361,7 +1361,7 @@ class SyncHandler:

         extra_users_ids = set(newly_joined_or_invited_users)
         for room_id in newly_joined_rooms:
-            users = await self.state.get_current_users_in_room(room_id)
+            users = await self.store.get_users_in_room(room_id)
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())
@@ -154,6 +154,7 @@ async def _handle_json_response(
     request: MatrixFederationRequest,
     response: IResponse,
     start_ms: int,
+    return_string_io=False,
 ) -> JsonDict:
     """
     Reads the JSON body of a response, with a timeout

@@ -175,12 +176,12 @@ async def _handle_json_response(
         d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)

-        def parse(_len: int):
-            return json_decoder.decode(buf.getvalue())
-
-        d.addCallback(parse)
-
-        body = await make_deferred_yieldable(d)
+        await make_deferred_yieldable(d)
+
+        if return_string_io:
+            body = buf
+        else:
+            body = json_decoder.decode(buf.getvalue())
     except BodyExceededMaxSize as e:
         # The response was too big.
         logger.warning(

@@ -225,12 +226,13 @@ async def _handle_json_response(
     time_taken_secs = reactor.seconds() - start_ms / 1000

     logger.info(
-        "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
+        "{%s} [%s] Completed request: %d %s in %.2f secs got %dB - %s %s",
         request.txn_id,
         request.destination,
         response.code,
         response.phrase.decode("ascii", errors="replace"),
         time_taken_secs,
+        len(buf.getvalue()),
         request.method,
         request.uri.decode("ascii"),
     )

@@ -683,6 +685,7 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
+        return_string_io=False,
     ) -> Union[JsonDict, list]:
         """Sends the specified json data using PUT


@@ -757,7 +760,12 @@ class MatrixFederationHttpClient:
             _sec_timeout = self.default_timeout

         body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
+            return_string_io=return_string_io,
         )

         return body
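The return_string_io switch above amounts to: read the body into a buffer, then either hand the buffer to the caller (so it can be fed to a streaming parser such as ijson) or decode it eagerly. A toy sketch of that branch, independent of Twisted and with invented names:

import io
import json


def handle_body(raw: bytes, return_string_io: bool = False):
    """Either return the raw buffer or the decoded JSON."""
    buf = io.BytesIO(raw)
    if return_string_io:
        return buf  # caller can stream-parse without a full decode
    return json.loads(buf.getvalue())


assert handle_body(b'{"ok": true}') == {"ok": True}
assert handle_body(b'{"ok": true}', return_string_io=True).read() == b'{"ok": true}'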
@@ -535,6 +535,13 @@ class ReactorLastSeenMetric:

 REGISTRY.register(ReactorLastSeenMetric())

+# The minimum time in seconds between GCs for each generation, regardless of the current GC
+# thresholds and counts.
+MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
+
+# The time (in seconds since the epoch) of the last time we did a GC for each generation.
+_last_gc = [0.0, 0.0, 0.0]
+

 def runUntilCurrentTimer(reactor, func):
     @functools.wraps(func)

@@ -575,11 +582,16 @@ def runUntilCurrentTimer(reactor, func):
             return ret

         # Check if we need to do a manual GC (since it's been disabled), and do
-        # one if necessary.
+        # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may
+        # promote an object into gen 2, and we don't want to handle the same
+        # object multiple times.
         threshold = gc.get_threshold()
         counts = gc.get_count()
         for i in (2, 1, 0):
-            if threshold[i] < counts[i]:
+            # We check if we need to do one based on a straightforward
+            # comparison between the threshold and count. We also do an extra
+            # check to make sure that we don't do a GC too often.
+            if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]:
                 if i == 0:
                     logger.debug("Collecting gc %d", i)
                 else:

@@ -589,6 +601,8 @@ def runUntilCurrentTimer(reactor, func):
                 unreachable = gc.collect(i)
                 end = time.time()

+                _last_gc[i] = end
+
                 gc_time.labels(i).observe(end - start)
                 gc_unreachable.labels(i).set(unreachable)

@@ -615,6 +629,7 @@ try:
 except AttributeError:
     pass

+
 __all__ = [
     "MetricsResource",
     "generate_latest",
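A condensed sketch of the rate-limited manual GC introduced above: a generation is collected only when its allocation count exceeds the threshold and the per-generation minimum interval has elapsed. This mirrors the diff's logic but is not a drop-in replacement for the reactor hook.

import gc
import time

MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
_last_gc = [0.0, 0.0, 0.0]


def maybe_collect() -> None:
    threshold = gc.get_threshold()
    counts = gc.get_count()
    # Reverse order: a gen-1 collection may promote objects into gen 2,
    # and we don't want to handle the same object twice.
    for i in (2, 1, 0):
        now = time.time()
        if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < now - _last_gc[i]:
            gc.collect(i)
            _last_gc[i] = time.time()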
191  synapse/metrics/jemalloc.py  Normal file

@@ -0,0 +1,191 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import logging
import os
import re
from typing import Optional

from synapse.metrics import REGISTRY, GaugeMetricFamily

logger = logging.getLogger(__name__)


def _setup_jemalloc_stats():
    """Checks to see if jemalloc is loaded, and hooks up a collector to record
    statistics exposed by jemalloc.
    """

    # Try to find the loaded jemalloc shared library, if any. We need to
    # introspect into what is loaded, rather than loading whatever is on the
    # path, as if we load a *different* jemalloc version things will seg fault.

    # We look in `/proc/self/maps`, which only exists on linux.
    if not os.path.exists("/proc/self/maps"):
        logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
        return

    # We're looking for a path at the end of the line that includes
    # "libjemalloc".
    regex = re.compile(r"/\S+/libjemalloc.*$")

    jemalloc_path = None
    with open("/proc/self/maps") as f:
        for line in f:
            match = regex.search(line.strip())
            if match:
                jemalloc_path = match.group()

    if not jemalloc_path:
        # No loaded jemalloc was found.
        logger.debug("jemalloc not found")
        return

    jemalloc = ctypes.CDLL(jemalloc_path)

    def _mallctl(
        name: str, read: bool = True, write: Optional[int] = None
    ) -> Optional[int]:
        """Wrapper around `mallctl` for reading and writing integers to
        jemalloc.

        Args:
            name: The name of the option to read from/write to.
            read: Whether to try and read the value.
            write: The value to write, if given.

        Returns:
            The value read if `read` is True, otherwise None.

        Raises:
            An exception if `mallctl` returns a non-zero error code.
        """

        input_var = None
        input_var_ref = None
        input_len_ref = None
        if read:
            input_var = ctypes.c_size_t(0)
            input_len = ctypes.c_size_t(ctypes.sizeof(input_var))

            input_var_ref = ctypes.byref(input_var)
            input_len_ref = ctypes.byref(input_len)

        write_var_ref = None
        write_len = ctypes.c_size_t(0)
        if write is not None:
            write_var = ctypes.c_size_t(write)
            write_len = ctypes.c_size_t(ctypes.sizeof(write_var))

            write_var_ref = ctypes.byref(write_var)

        # The interface is:
        #
        #   int mallctl(
        #       const char *name,
        #       void *oldp,
        #       size_t *oldlenp,
        #       void *newp,
        #       size_t newlen
        #   )
        #
        # Where oldp/oldlenp is a buffer where the old value will be written to
        # (if not null), and newp/newlen is the buffer with the new value to set
        # (if not null). Note that they're all references *except* newlen.
        result = jemalloc.mallctl(
            name.encode("ascii"),
            input_var_ref,
            input_len_ref,
            write_var_ref,
            write_len,
        )

        if result != 0:
            raise Exception("Failed to call mallctl")

        if input_var is None:
            return None

        return input_var.value

    def _jemalloc_refresh_stats() -> None:
        """Request that jemalloc updates its internal statistics. This needs to
        be called before querying for stats, otherwise it will return stale
        values.
        """
        try:
            _mallctl("epoch", read=False, write=1)
        except Exception:
            pass

    class JemallocCollector:
        """Metrics for internal jemalloc stats."""

        def collect(self):
            _jemalloc_refresh_stats()

            g = GaugeMetricFamily(
                "jemalloc_stats_app_memory_bytes",
                "The stats reported by jemalloc",
                labels=["type"],
            )

            # Read the relevant global stats from jemalloc. Note that these may
            # not be accurate if python is configured to use its internal small
            # object allocator (which is on by default, disable by setting the
            # env `PYTHONMALLOC=malloc`).
            #
            # See the jemalloc manpage for details about what each value means,
            # roughly:
            #   - allocated ─ Total number of bytes allocated by the app
            #   - active ─ Total number of bytes in active pages allocated by
            #     the application, this is bigger than `allocated`.
            #   - resident ─ Maximum number of bytes in physically resident data
            #     pages mapped by the allocator, comprising all pages dedicated
            #     to allocator metadata, pages backing active allocations, and
            #     unused dirty pages. This is bigger than `active`.
            #   - mapped ─ Total number of bytes in active extents mapped by the
            #     allocator.
            #   - metadata ─ Total number of bytes dedicated to jemalloc
            #     metadata.
            for t in (
                "allocated",
                "active",
                "resident",
                "mapped",
                "metadata",
            ):
                try:
                    value = _mallctl(f"stats.{t}")
                except Exception:
                    # There was an error fetching the value, skip.
                    continue

                g.add_metric([t], value=value)

            yield g

    REGISTRY.register(JemallocCollector())

    logger.debug("Added jemalloc stats")


def setup_jemalloc_stats():
    """Try to setup jemalloc stats, if jemalloc is loaded."""

    try:
        _setup_jemalloc_stats()
    except Exception as e:
        logger.info("Failed to setup collector to record jemalloc stats: %s", e)
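For completeness, a hypothetical usage sketch (the LD_PRELOAD path varies by distribution and is an assumption here): run the process under jemalloc and call setup_jemalloc_stats() at startup; the call is a logged no-op when jemalloc is not loaded.

# LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2 python -m synapse.app.homeserver ...
from synapse.metrics.jemalloc import setup_jemalloc_stats

setup_jemalloc_stats()  # registers the collector only if jemalloc is mapped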
@@ -277,7 +277,7 @@ class Notifier:
             event_pos=event_pos,
             room_id=event.room_id,
             event_type=event.type,
-            state_key=event.get("state_key"),
+            state_key=getattr(event, "state_key", None),
             membership=event.content.get("membership"),
             max_room_stream_token=max_room_stream_token,
             extra_users=extra_users or [],
@@ -125,7 +125,7 @@ class PushRuleEvaluatorForEvent:
         self._power_levels = power_levels

         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
-        self._value_cache = _flatten_dict(event)
+        self._value_cache = _flatten_dict(event.get_dict())

     def matches(
         self, condition: Dict[str, Any], user_id: str, display_name: str

@@ -271,7 +271,7 @@ def _re_word_boundary(r: str) -> str:


 def _flatten_dict(
-    d: Union[EventBase, dict],
+    d: dict,
     prefix: Optional[List[str]] = None,
     result: Optional[Dict[str, str]] = None,
 ) -> Dict[str, str]:
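The value cache mentioned in the comment is a flattened view of the event, so push-rule conditions can look up dotted keys like 'content.body' directly. A sketch of the flattening idea (not Synapse's exact implementation):

from typing import Dict, List, Optional


def flatten_dict(
    d: dict,
    prefix: Optional[List[str]] = None,
    result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Turn {'content': {'body': 'Hi'}} into {'content.body': 'hi'}."""
    prefix = prefix or []
    result = {} if result is None else result
    for key, value in d.items():
        if isinstance(value, str):
            result[".".join(prefix + [key])] = value.lower()
        elif isinstance(value, dict):
            flatten_dict(value, prefix=prefix + [key], result=result)
    return result


assert flatten_dict({"content": {"body": "Hi"}}) == {"content.body": "hi"}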
@@ -116,6 +116,8 @@ CONDITIONAL_REQUIREMENTS = {
     # hiredis is not a *strict* dependency, but it makes things much faster.
     # (if it is not installed, we fall back to slow code.)
     "redis": ["txredisapi>=1.4.7", "hiredis"],
+    # Required to use experimental `caches.track_memory_usage` config option.
+    "cache_memory": ["pympler"],
 }

 ALL_OPTIONAL_REQUIREMENTS = set()  # type: Set[str]
@@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 from synapse.util import json_decoder
+from synapse.util.async_helpers import yieldable_gather_results

 logger = logging.getLogger(__name__)


@@ -213,7 +214,13 @@ class RemoteKey(DirectServeJsonResource):
         # If there is a cache miss, request the missing keys, then recurse (and
         # ensure the result is sent).
         if cache_misses and query_remote_on_cache_miss:
-            await self.fetcher.get_keys(cache_misses)
+            await yieldable_gather_results(
+                lambda t: self.fetcher.get_keys(*t),
+                (
+                    (server_name, list(keys), 0)
+                    for server_name, keys in cache_misses.items()
+                ),
+            )
             await self.query_keys(request, query, query_remote_on_cache_miss=False)
         else:
             signed_keys = []
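yieldable_gather_results is Synapse's helper for running one coroutine per item concurrently under Twisted while keeping log contexts intact; here it puts one get_keys call per cache-missed server in flight at once instead of a single combined call. As a rough, illustrative analogue in plain asyncio (not the real helper):

import asyncio
from typing import Awaitable, Callable, Iterable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


async def gather_results(
    func: Callable[[T], Awaitable[R]], items: Iterable[T]
) -> List[R]:
    """Run func(item) for every item concurrently and collect the results."""
    return list(await asyncio.gather(*(func(item) for item in items)))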
@@ -213,19 +213,23 @@ class StateHandler:
         return ret.state

     async def get_current_users_in_room(
-        self, room_id: str, latest_event_ids: Optional[List[str]] = None
+        self, room_id: str, latest_event_ids: List[str]
     ) -> Dict[str, ProfileInfo]:
         """
         Get the users who are currently in a room.

+        Note: This is much slower than using the equivalent method
+        `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
+        so this should only be used when wanting the users at a particular point
+        in the room.
+
         Args:
             room_id: The ID of the room.
             latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
         Returns:
             Dictionary of user IDs to their profileinfo.
         """
-        if not latest_event_ids:
-            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
-
         assert latest_event_ids is not None

         logger.debug("calling resolve_state_groups from get_current_users_in_room")
@@ -69,6 +69,7 @@ class SQLBaseStore(metaclass=ABCMeta):
             self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))

         self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+        self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,))
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
         self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
@@ -205,8 +205,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):

         def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
             sql = """
-                SELECT user_id, display_name, avatar_url FROM room_memberships
-                WHERE room_id = ? AND membership = ?
+                SELECT state_key, display_name, avatar_url FROM room_memberships as m
+                INNER JOIN current_state_events as c
+                ON m.event_id = c.event_id
+                AND m.room_id = c.room_id
+                AND m.user_id = c.state_key
+                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
             """
             txn.execute(sql, (room_id, Membership.JOIN))
@@ -142,8 +142,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             batch_size (int): Maximum number of state events to process
                 per cycle.
         """
-        state = self.hs.get_state_handler()
-
         # If we don't have a progress field, delete everything.
         if not progress:
             await self.delete_all_from_user_dir()

@@ -197,7 +195,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 room_id
             )

-            users_with_profile = await state.get_current_users_in_room(room_id)
+            users_with_profile = await self.get_users_in_room_with_profiles(room_id)
             user_ids = set(users_with_profile)

             # Update each user in the user directory.
@@ -24,6 +24,11 @@ from synapse.config.cache import add_resizable_cache

 logger = logging.getLogger(__name__)

+
+# Whether to track estimated memory usage of the LruCaches.
+TRACK_MEMORY_USAGE = False
+
+
 caches_by_name = {}  # type: Dict[str, Sized]
 collectors_by_name = {}  # type: Dict[str, CacheMetric]


@@ -32,6 +37,11 @@ cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
 cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
 cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
 cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
+cache_memory_usage = Gauge(
+    "synapse_util_caches_cache_size_bytes",
+    "Estimated memory usage of the caches",
+    ["name"],
+)

 response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
 response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
@@ -52,6 +62,7 @@ class CacheMetric:
     hits = attr.ib(default=0)
     misses = attr.ib(default=0)
     evicted_size = attr.ib(default=0)
+    memory_usage = attr.ib(default=None)

     def inc_hits(self):
         self.hits += 1

@@ -62,6 +73,19 @@ class CacheMetric:
     def inc_evictions(self, size=1):
         self.evicted_size += size

+    def inc_memory_usage(self, memory: int):
+        if self.memory_usage is None:
+            self.memory_usage = 0
+
+        self.memory_usage += memory
+
+    def dec_memory_usage(self, memory: int):
+        self.memory_usage -= memory
+
+    def clear_memory_usage(self):
+        if self.memory_usage is not None:
+            self.memory_usage = 0
+
     def describe(self):
         return []


@@ -81,6 +105,13 @@ class CacheMetric:
             cache_total.labels(self._cache_name).set(self.hits + self.misses)
             if getattr(self._cache, "max_size", None):
                 cache_max_size.labels(self._cache_name).set(self._cache.max_size)
+
+            if TRACK_MEMORY_USAGE:
+                # self.memory_usage can be None if nothing has been inserted
+                # into the cache yet.
+                cache_memory_usage.labels(self._cache_name).set(
+                    self.memory_usage or 0
+                )
             if self._collect_callback:
                 self._collect_callback()
         except Exception as e:
@@ -32,9 +32,36 @@ from typing import (
 from typing_extensions import Literal

 from synapse.config import cache as cache_config
+from synapse.util import caches
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache

+try:
+    from pympler.asizeof import Asizer
+
+    def _get_size_of(val: Any, *, recurse=True) -> int:
+        """Get an estimate of the size in bytes of the object.
+
+        Args:
+            val: The object to size.
+            recurse: If true will include referenced values in the size,
+                otherwise only sizes the given object.
+        """
+        # Ignore singleton values when calculating memory usage.
+        if val in ((), None, ""):
+            return 0
+
+        sizer = Asizer()
+        sizer.exclude_refs((), None, "")
+        return sizer.asizeof(val, limit=100 if recurse else 0)
+
+
+except ImportError:
+
+    def _get_size_of(val: Any, *, recurse=True) -> int:
+        return 0
+
+
 # Function type: the type used for invalidation callbacks
 FT = TypeVar("FT", bound=Callable[..., Any])
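The pympler-based sizing above can be tried in isolation; pympler is optional and pulled in via the new "cache_memory" extra. A small sketch:

from pympler.asizeof import Asizer

sizer = Asizer()
# Don't charge caches for singletons that are shared process-wide.
sizer.exclude_refs((), None, "")

entry = {"event_id": "$abc", "content": {"body": "hello"}}
# `limit` bounds the recursion depth when following references.
print(sizer.asizeof(entry, limit=100))  # prints an estimate in bytes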
@@ -56,7 +83,7 @@ def enumerate_leaves(node, depth):


 class _Node:
-    __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
+    __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]

     def __init__(
         self,

@@ -84,6 +111,16 @@ class _Node:

         self.add_callbacks(callbacks)

+        self.memory = 0
+        if caches.TRACK_MEMORY_USAGE:
+            self.memory = (
+                _get_size_of(key)
+                + _get_size_of(value)
+                + _get_size_of(self.callbacks, recurse=False)
+                + _get_size_of(self, recurse=False)
+            )
+            self.memory += _get_size_of(self.memory, recurse=False)
+
     def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
         """Add to stored list of callbacks, removing duplicates."""

@@ -233,6 +270,9 @@ class LruCache(Generic[KT, VT]):
             if size_callback:
                 cached_cache_len[0] += size_callback(node.value)

+            if caches.TRACK_MEMORY_USAGE and metrics:
+                metrics.inc_memory_usage(node.memory)
+
         def move_node_to_front(node):
             prev_node = node.prev_node
             next_node = node.next_node

@@ -258,6 +298,9 @@ class LruCache(Generic[KT, VT]):

             node.run_and_clear_callbacks()

+            if caches.TRACK_MEMORY_USAGE and metrics:
+                metrics.dec_memory_usage(node.memory)
+
             return deleted_len

         @overload

@@ -373,6 +416,9 @@ class LruCache(Generic[KT, VT]):
             if size_callback:
                 cached_cache_len[0] = 0

+            if caches.TRACK_MEMORY_USAGE and metrics:
+                metrics.clear_memory_usage()
+
         @synchronized
         def cache_contains(key: KT) -> bool:
             return key in cache
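Taken together, the hooks above keep a running byte count: add a node's estimated size on insert, subtract on eviction or deletion, reset on clear. A toy cache demonstrating the same accounting (illustrative only, not Synapse's LruCache):

from collections import OrderedDict


class TinyLru:
    def __init__(self, max_items: int):
        self._data = OrderedDict()
        self._max = max_items
        self.memory_usage = 0  # running estimate, like CacheMetric's counter

    def set(self, key, value: bytes) -> None:
        if key in self._data:
            self.memory_usage -= len(self._data.pop(key))
        self._data[key] = value
        self.memory_usage += len(value)
        while len(self._data) > self._max:
            _, evicted = self._data.popitem(last=False)  # evict oldest entry
            self.memory_usage -= len(evicted)


cache = TinyLru(2)
cache.set("a", b"xxxx")
cache.set("b", b"yy")
cache.set("c", b"zzz")  # evicts "a", so 4 bytes are subtracted
assert cache.memory_usage == 5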
@@ -729,7 +729,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(expected_state.state, PresenceState.ONLINE)
         self.federation_sender.send_presence_to_destinations.assert_called_once_with(
-            destinations=["server2"], states={expected_state}
+            destinations={"server2"}, states=[expected_state]
         )

         #

@@ -740,7 +740,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         self._add_new_user(room_id, "@bob:server3")

         self.federation_sender.send_presence_to_destinations.assert_called_once_with(
-            destinations=["server3"], states={expected_state}
+            destinations={"server3"}, states=[expected_state]
         )

     def test_remote_gets_presence_when_local_user_joins(self):

@@ -788,14 +788,8 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
             self.presence_handler.current_state_for_user("@test2:server")
         )
         self.assertEqual(expected_state.state, PresenceState.ONLINE)
-        self.assertEqual(
-            self.federation_sender.send_presence_to_destinations.call_count, 2
-        )
-        self.federation_sender.send_presence_to_destinations.assert_any_call(
-            destinations=["server3"], states={expected_state}
-        )
-        self.federation_sender.send_presence_to_destinations.assert_any_call(
-            destinations=["server2"], states={expected_state}
+        self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+            destinations={"server2", "server3"}, states=[expected_state]
         )

     def _add_new_user(self, room_id, user_id):