Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Patrick Cloke 2021-06-02 11:38:54 -04:00
commit 09361655d2
43 changed files with 945 additions and 615 deletions

View file

@ -1,5 +1,12 @@
Synapse 1.35.0rc3 (2021-05-28)
==============================
Synapse 1.35.0 (2021-06-01)
===========================
Note that [the tag](https://github.com/matrix-org/synapse/releases/tag/v1.35.0rc3) and [docker images](https://hub.docker.com/layers/matrixdotorg/synapse/v1.35.0rc3/images/sha256-34ccc87bd99a17e2cbc0902e678b5937d16bdc1991ead097eee6096481ecf2c4?context=explore) for `v1.35.0rc3` were incorrectly built. If you are experiencing issues with either, it is recommended to upgrade to the equivalent tag or docker image for the `v1.35.0` release.
Deprecations and Removals
-------------------------
- The core Synapse development team plan to drop support for the [unstable API of MSC2858](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2858-Multiple-SSO-Identity-Providers.md#unstable-prefix), including the undocumented `experimental.msc2858_enabled` config option, in August 2021. Client authors should ensure that their clients are updated to use the stable API (which has been supported since Synapse 1.30) well before that time, to give their users time to upgrade. ([\#10101](https://github.com/matrix-org/synapse/issues/10101))
Bugfixes
--------

View file

@ -0,0 +1 @@
Rewrite logic around verifying JSON object and fetching server keys to be more performant and use less memory.

1
changelog.d/10048.misc Normal file
View file

@ -0,0 +1 @@
Add `parse_strings_from_args` for parsing an array from query parameters.

1
changelog.d/10074.misc Normal file
View file

@ -0,0 +1 @@
Update opentracing to inject the right context into the carrier.

View file

@ -0,0 +1 @@
Make reason and score parameters optional for reporting content. Implements [MSC2414](https://github.com/matrix-org/matrix-doc/pull/2414). Contributed by Callum Brown.

View file

@ -0,0 +1 @@
Add support for routing more requests to workers.

1
changelog.d/10091.misc Normal file
View file

@ -0,0 +1 @@
Log method and path when dropping request due to size limit.

1
changelog.d/10092.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug in the `force_tracing_for_users` option introduced in Synapse v1.35 which meant that the OpenTracing spans produced were missing most tags.

1
changelog.d/10102.misc Normal file
View file

@ -0,0 +1 @@
Make `/sync` do fewer state resolutions.

1
changelog.d/10109.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug introduced in v1.35.0 where invite-only rooms would be shown to users in a space who were not invited.

1
changelog.d/9953.feature Normal file
View file

@ -0,0 +1 @@
Improve performance of incoming federation transactions in large rooms.

1
changelog.d/9973.feature Normal file
View file

@ -0,0 +1 @@
Improve performance of incoming federation transactions in large rooms.

View file

@ -1 +0,0 @@
Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`.

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.35.0) stable; urgency=medium
* New synapse release 1.35.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 01 Jun 2021 13:23:35 +0100
matrix-synapse-py3 (1.34.0) stable; urgency=medium
* New synapse release 1.34.0.

View file

@ -75,9 +75,9 @@ The following fields are returned in the JSON response body:
* `name`: string - The name of the room.
* `event_id`: string - The ID of the reported event.
* `user_id`: string - This is the user who reported the event and wrote the reason.
* `reason`: string - Comment made by the `user_id` in this report. May be blank.
* `reason`: string - Comment made by the `user_id` in this report. May be blank or `null`.
* `score`: integer - Content is reported based upon a negative score, where -100 is
"most offensive" and 0 is "inoffensive".
"most offensive" and 0 is "inoffensive". May be `null`.
* `sender`: string - This is the ID of the user who sent the original message/event that
was reported.
* `canonical_alias`: string - The canonical alias of the room. `null` if the room does not

View file

@ -228,6 +228,9 @@ expressions:
^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/event/
^/_matrix/client/(api/v1|r0|unstable)/joined_rooms$
^/_matrix/client/(api/v1|r0|unstable)/search$
# Registration/login requests
^/_matrix/client/(api/v1|r0|unstable)/login$

View file

@ -47,7 +47,7 @@ try:
except ImportError:
pass
__version__ = "1.35.0rc3"
__version__ = "1.35.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -206,11 +206,11 @@ class Auth:
requester = create_requester(user_id, app_service=app_service)
request.requester = user_id
if user_id in self._force_tracing_for_users:
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("user_id", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if user_id in self._force_tracing_for_users:
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
return requester
@ -259,12 +259,12 @@ class Auth:
)
request.requester = requester
if user_info.token_owner in self._force_tracing_for_users:
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
opentracing.set_tag("authenticated_entity", user_info.token_owner)
opentracing.set_tag("user_id", user_info.user_id)
if device_id:
opentracing.set_tag("device_id", device_id)
if user_info.token_owner in self._force_tracing_for_users:
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
return requester
except KeyError:

View file

@ -109,7 +109,7 @@ from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.databases.main.presence import PresenceStore
from synapse.storage.databases.main.search import SearchWorkerStore
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@ -242,7 +242,7 @@ class GenericWorkerSlavedStore(
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
ServerMetricsStore,
SearchWorkerStore,
SearchStore,
TransactionWorkerStore,
BaseSlavedStore,
):

View file

@ -16,8 +16,7 @@
import abc
import logging
import urllib
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
import attr
from signedjson.key import (
@ -44,17 +43,12 @@ from synapse.api.errors import (
from synapse.config.key import TrustedKeyServer
from synapse.events import EventBase
from synapse.events.utils import prune_event_dict
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
preserve_fn,
run_in_background,
)
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure
from synapse.util.batching_queue import BatchingQueue
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@ -80,32 +74,19 @@ class VerifyJsonRequest:
minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
request_name: The name of the request.
key_ids: The set of key_ids to that could be used to verify the JSON object
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
logcontext.
If we are unable to find a key which satisfies the request, the deferred
errbacks with an M_UNAUTHORIZED SynapseError.
"""
server_name = attr.ib(type=str)
get_json_object = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
request_name = attr.ib(type=str)
key_ids = attr.ib(type=List[str])
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
@staticmethod
def from_json_object(
server_name: str,
json_object: JsonDict,
minimum_valid_until_ms: int,
request_name: str,
):
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
object for the given server.
@ -115,7 +96,6 @@ class VerifyJsonRequest:
server_name,
lambda: json_object,
minimum_valid_until_ms,
request_name=request_name,
key_ids=key_ids,
)
@ -135,16 +115,48 @@ class VerifyJsonRequest:
# memory than the Event object itself.
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
minimum_valid_until_ms,
request_name=event.event_id,
key_ids=key_ids,
)
def to_fetch_key_request(self) -> "_FetchKeyRequest":
"""Create a key fetch request for all keys needed to satisfy the
verification request.
"""
return _FetchKeyRequest(
server_name=self.server_name,
minimum_valid_until_ts=self.minimum_valid_until_ts,
key_ids=self.key_ids,
)
class KeyLookupError(ValueError):
pass
@attr.s(slots=True)
class _FetchKeyRequest:
"""A request for keys for a given server.
We will continue to try and fetch until we have all the keys listed under
`key_ids` (with an appropriate `valid_until_ts` property) or we run out of
places to fetch keys from.
Attributes:
server_name: The name of the server that owns the keys.
minimum_valid_until_ts: The timestamp which the keys must be valid until.
key_ids: The IDs of the keys to attempt to fetch
"""
server_name = attr.ib(type=str)
minimum_valid_until_ts = attr.ib(type=int)
key_ids = attr.ib(type=List[str])
class Keyring:
"""Handles verifying signed JSON objects and fetching the keys needed to do
so.
"""
def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
):
@ -158,22 +170,22 @@ class Keyring:
)
self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
# completes.
#
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {} # type: Dict[str, defer.Deferred]
self._server_queue = BatchingQueue(
"keyring_server",
clock=hs.get_clock(),
process_batch_callback=self._inner_fetch_key_requests,
) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]]
def verify_json_for_server(
async def verify_json_for_server(
self,
server_name: str,
json_object: JsonDict,
validity_time: int,
request_name: str,
) -> defer.Deferred:
) -> None:
"""Verify that a JSON object has been signed by a given server
Completes if the the object was correctly signed, otherwise raises.
Args:
server_name: name of the server which must have signed this object
@ -181,52 +193,45 @@ class Keyring:
validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care)
request_name: an identifier for this json object (eg, an event id)
for logging.
Returns:
Deferred[None]: completes if the the object was correctly signed, otherwise
errbacks with an error
"""
request = VerifyJsonRequest.from_json_object(
server_name,
json_object,
validity_time,
request_name,
)
requests = (request,)
return make_deferred_yieldable(self._verify_objects(requests)[0])
return await self.process_request(request)
def verify_json_objects_for_server(
self, server_and_json: Iterable[Tuple[str, dict, int, str]]
self, server_and_json: Iterable[Tuple[str, dict, int]]
) -> List[defer.Deferred]:
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json:
Iterable of (server_name, json_object, validity_time, request_name)
Iterable of (server_name, json_object, validity_time)
tuples.
validity_time is a timestamp at which the signing key must be
valid.
request_name is an identifier for this json object (eg, an event id)
for logging.
Returns:
List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
return self._verify_objects(
VerifyJsonRequest.from_json_object(
server_name, json_object, validity_time, request_name
return [
run_in_background(
self.process_request,
VerifyJsonRequest.from_json_object(
server_name,
json_object,
validity_time,
),
)
for server_name, json_object, validity_time, request_name in server_and_json
)
for server_name, json_object, validity_time in server_and_json
]
def verify_events_for_server(
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
@ -252,321 +257,223 @@ class Keyring:
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
return self._verify_objects(
VerifyJsonRequest.from_event(server_name, event, validity_time)
return [
run_in_background(
self.process_request,
VerifyJsonRequest.from_event(
server_name,
event,
validity_time,
),
)
for server_name, event, validity_time in server_and_events
]
async def process_request(self, verify_request: VerifyJsonRequest) -> None:
"""Processes the `VerifyJsonRequest`. Raises if the object is not signed
by the server, the signatures don't match or we failed to fetch the
necessary keys.
"""
if not verify_request.key_ids:
raise SynapseError(
400,
f"Not signed by {verify_request.server_name}",
Codes.UNAUTHORIZED,
)
# Add the keys we need to verify to the queue for retrieval. We queue
# up requests for the same server so we don't end up with many in flight
# requests for the same keys.
key_request = verify_request.to_fetch_key_request()
found_keys_by_server = await self._server_queue.add_to_queue(
key_request, key=verify_request.server_name
)
def _verify_objects(
self, verify_requests: Iterable[VerifyJsonRequest]
) -> List[defer.Deferred]:
"""Does the work of verify_json_[objects_]for_server
# Since we batch up requests the returned set of keys may contain keys
# from other servers, so we pull out only the ones we care about.s
found_keys = found_keys_by_server.get(verify_request.server_name, {})
# Verify each signature we got valid keys for, raising if we can't
# verify any of them.
verified = False
for key_id in verify_request.key_ids:
key_result = found_keys.get(key_id)
if not key_result:
continue
Args:
verify_requests: Iterable of verification requests.
if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
continue
Returns:
List<Deferred[None]>: for each input item, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
# a list of VerifyJsonRequests which are awaiting a key lookup
key_lookups = []
handle = preserve_fn(_handle_key_deferred)
def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
"""Process an entry in the request list
Adds a key request to key_lookups, and returns a deferred which
will complete or fail (in the sentinel context) when verification completes.
"""
if not verify_request.key_ids:
return defer.fail(
SynapseError(
400,
"Not signed by %s" % (verify_request.server_name,),
Codes.UNAUTHORIZED,
)
verify_key = key_result.verify_key
json_object = verify_request.get_json_object()
try:
verify_signed_json(
json_object,
verify_request.server_name,
verify_key,
)
verified = True
except SignatureVerifyException as e:
logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s",
verify_request.server_name,
verify_key.alg,
verify_key.version,
encode_verify_key_base64(verify_key),
str(e),
)
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s: %s"
% (
verify_request.server_name,
verify_key.alg,
verify_key.version,
str(e),
),
Codes.UNAUTHORIZED,
)
logger.debug(
"Verifying %s for %s with key_ids %s, min_validity %i",
verify_request.request_name,
if not verified:
raise SynapseError(
401,
f"Failed to find any key to satisfy: {key_request}",
Codes.UNAUTHORIZED,
)
async def _inner_fetch_key_requests(
self, requests: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""Processing function for the queue of `_FetchKeyRequest`."""
logger.debug("Starting fetch for %s", requests)
# First we need to deduplicate requests for the same key. We do this by
# taking the *maximum* requested `minimum_valid_until_ts` for each pair
# of server name/key ID.
server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]]
for request in requests:
by_server = server_to_key_to_ts.setdefault(request.server_name, {})
for key_id in request.key_ids:
existing_ts = by_server.get(key_id, 0)
by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)
deduped_requests = [
_FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
for server_name, by_server in server_to_key_to_ts.items()
for key_id, minimum_valid_ts in by_server.items()
]
logger.debug("Deduplicated key requests to %s", deduped_requests)
# For each key we call `_inner_verify_request` which will handle
# fetching each key. Note these shouldn't throw if we fail to contact
# other servers etc.
results_per_request = await yieldable_gather_results(
self._inner_fetch_key_request,
deduped_requests,
)
# We now convert the returned list of results into a map from server
# name to key ID to FetchKeyResult, to return.
to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for (request, results) in zip(deduped_requests, results_per_request):
to_return_by_server = to_return.setdefault(request.server_name, {})
for key_id, key_result in results.items():
existing = to_return_by_server.get(key_id)
if not existing or existing.valid_until_ts < key_result.valid_until_ts:
to_return_by_server[key_id] = key_result
return to_return
async def _inner_fetch_key_request(
self, verify_request: _FetchKeyRequest
) -> Dict[str, FetchKeyResult]:
"""Attempt to fetch the given key by calling each key fetcher one by
one.
"""
logger.debug("Starting fetch for %s", verify_request)
found_keys: Dict[str, FetchKeyResult] = {}
missing_key_ids = set(verify_request.key_ids)
for fetcher in self._key_fetchers:
if not missing_key_ids:
break
logger.debug("Getting keys from %s for %s", fetcher, verify_request)
keys = await fetcher.get_keys(
verify_request.server_name,
verify_request.key_ids,
list(missing_key_ids),
verify_request.minimum_valid_until_ts,
)
# add the key request to the queue, but don't start it off yet.
key_lookups.append(verify_request)
# now run _handle_key_deferred, which will wait for the key request
# to complete and then do the verification.
#
# We want _handle_key_request to log to the right context, so we
# wrap it with preserve_fn (aka run_in_background)
return handle(verify_request)
results = [process(r) for r in verify_requests]
if key_lookups:
run_in_background(self._start_key_lookups, key_lookups)
return results
async def _start_key_lookups(
self, verify_requests: List[VerifyJsonRequest]
) -> None:
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved.
Args:
verify_requests:
"""
try:
# map from server name to a set of outstanding request ids
server_to_request_ids = {} # type: Dict[str, Set[int]]
for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
# Wait for any previous lookups to complete before proceeding.
await self.wait_for_previous_lookups(server_to_request_ids.keys())
# take out a lock on each of the servers by sticking a Deferred in
# key_downloads
for server_name in server_to_request_ids.keys():
self.key_downloads[server_name] = defer.Deferred()
logger.debug("Got key lookup lock on %s", server_name)
# When we've finished fetching all the keys for a given server_name,
# drop the lock by resolving the deferred in key_downloads.
def drop_server_lock(server_name):
d = self.key_downloads.pop(server_name)
d.callback(None)
def lookup_done(res, verify_request):
server_name = verify_request.server_name
server_requests = server_to_request_ids[server_name]
server_requests.remove(id(verify_request))
# if there are no more requests for this server, we can drop the lock.
if not server_requests:
logger.debug("Releasing key lookup lock on %s", server_name)
drop_server_lock(server_name)
return res
for verify_request in verify_requests:
verify_request.key_ready.addBoth(lookup_done, verify_request)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
except Exception:
logger.exception("Error starting key lookups")
async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
"""Waits for any previous key lookups for the given servers to finish.
Args:
server_names: list of servers which we want to look up
Returns:
Resolves once all key lookups for the given servers have
completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
for server_name in server_names
if server_name in self.key_downloads
]
if not wait_on:
break
logger.info(
"Waiting for existing lookups for %s to complete [loop %i]",
[w[0] for w in wait_on],
loop_count,
)
with PreserveLoggingContext():
await defer.DeferredList((w[1] for w in wait_on))
loop_count += 1
def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.key_ready is called back with
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
with a SynapseError if none of the keys are found.
Args:
verify_requests: list of verify requests
"""
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
async def do_iterations():
try:
with Measure(self.clock, "get_server_verify_keys"):
for f in self._key_fetchers:
if not remaining_requests:
return
await self._attempt_key_fetches_with_fetcher(
f, remaining_requests
)
# look for any requests which weren't satisfied
while remaining_requests:
verify_request = remaining_requests.pop()
rq_str = (
"VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
% (
verify_request.server_name,
verify_request.key_ids,
verify_request.minimum_valid_until_ts,
)
)
# If we run the errback immediately, it may cancel our
# loggingcontext while we are still in it, so instead we
# schedule it for the next time round the reactor.
#
# (this also ensures that we don't get a stack overflow if we
# has a massive queue of lookups waiting for this server).
self.clock.call_later(
0,
verify_request.key_ready.errback,
SynapseError(
401,
"Failed to find any key to satisfy %s" % (rq_str,),
Codes.UNAUTHORIZED,
),
)
except Exception as err:
# we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make
# sure that all of the deferreds are resolved.
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
for verify_request in remaining_requests:
if not verify_request.key_ready.called:
verify_request.key_ready.errback(err)
run_in_background(do_iterations)
async def _attempt_key_fetches_with_fetcher(
self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
fetcher: fetcher to use to fetch the keys
remaining_requests: outstanding key requests.
Any successfully-completed requests will be removed from the list.
"""
# The keys to fetch.
# server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
for verify_request in remaining_requests:
# any completed requests should already have been removed
assert not verify_request.key_ready.called
keys_for_server = missing_keys[verify_request.server_name]
for key_id in verify_request.key_ids:
# If we have several requests for the same key, then we only need to
# request that key once, but we should do so with the greatest
# min_valid_until_ts of the requests, so that we can satisfy all of
# the requests.
keys_for_server[key_id] = max(
keys_for_server.get(key_id, -1),
verify_request.minimum_valid_until_ts,
)
results = await fetcher.get_keys(missing_keys)
completed = []
for verify_request in remaining_requests:
server_name = verify_request.server_name
# see if any of the keys we got this time are sufficient to
# complete this VerifyJsonRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
fetch_key_result = result_keys.get(key_id)
if not fetch_key_result:
# we didn't get a result for this key
for key_id, key in keys.items():
if not key:
continue
if (
fetch_key_result.valid_until_ts
< verify_request.minimum_valid_until_ts
):
# key was not valid at this point
continue
# If we already have a result for the given key ID we keep the
# one with the highest `valid_until_ts`.
existing_key = found_keys.get(key_id)
if existing_key:
if key.valid_until_ts <= existing_key.valid_until_ts:
continue
# we have a valid key for this request. If we run the callback
# immediately, it may cancel our loggingcontext while we are still in
# it, so instead we schedule it for the next time round the reactor.
# We always store the returned key even if it doesn't the
# `minimum_valid_until_ts` requirement, as some verification
# requests may still be able to be satisfied by it.
#
# (this also ensures that we don't get a stack overflow if we had
# a massive queue of lookups waiting for this server).
logger.debug(
"Found key %s:%s for %s",
server_name,
key_id,
verify_request.request_name,
)
self.clock.call_later(
0,
verify_request.key_ready.callback,
(server_name, key_id, fetch_key_result.verify_key),
)
completed.append(verify_request)
break
# We still keep looking for the key from other fetchers in that
# case though.
found_keys[key_id] = key
remaining_requests.difference_update(completed)
if key.valid_until_ts < verify_request.minimum_valid_until_ts:
continue
missing_key_ids.discard(key_id)
return found_keys
class KeyFetcher(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
def __init__(self, hs: "HomeServer"):
self._queue = BatchingQueue(
self.__class__.__name__, hs.get_clock(), self._fetch_keys
)
Returns:
Map from server_name -> key_id -> FetchKeyResult
"""
raise NotImplementedError
async def get_keys(
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
results = await self._queue.add_to_queue(
_FetchKeyRequest(
server_name=server_name,
key_ids=key_ids,
minimum_valid_until_ts=minimum_valid_until_ts,
)
)
return results.get(server_name, {})
@abc.abstractmethod
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
pass
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]):
key_ids_to_fetch = (
(server_name, key_id)
for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys()
(queue_value.server_name, key_id)
for queue_value in keys_to_fetch
for key_id in queue_value.key_ids
)
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
@ -578,6 +485,8 @@ class StoreKeyFetcher(KeyFetcher):
class BaseV2KeyFetcher(KeyFetcher):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
self.config = hs.config
@ -685,10 +594,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
"""see KeyFetcher._fetch_keys"""
async def get_key(key_server: TrustedKeyServer) -> Dict:
try:
@ -724,12 +633,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys
async def get_server_verify_key_v2_indirect(
self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
the keys to be fetched.
key_server: notary server to query for the keys
@ -743,7 +652,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name = key_server.server_name
logger.info(
"Requesting keys %s from notary server %s",
keys_to_fetch.items(),
keys_to_fetch,
perspective_name,
)
@ -753,11 +662,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/query",
data={
"server_keys": {
server_name: {
key_id: {"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
queue_value.server_name: {
key_id: {
"minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
}
for key_id in queue_value.key_ids
}
for server_name, server_keys in keys_to_fetch.items()
for queue_value in keys_to_fetch
}
},
)
@ -858,7 +769,20 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_federation_http_client()
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
results = await self._queue.add_to_queue(
_FetchKeyRequest(
server_name=server_name,
key_ids=key_ids,
minimum_valid_until_ts=minimum_valid_until_ts,
),
key=server_name,
)
return results.get(server_name, {})
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
@ -871,8 +795,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
results = {}
async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
server_name, key_ids = key_to_fetch_item
async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
server_name = key_to_fetch_item.server_name
key_ids = key_to_fetch_item.key_ids
try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
@ -883,7 +809,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
await yieldable_gather_results(get_key, keys_to_fetch.items())
await yieldable_gather_results(get_key, keys_to_fetch)
return results
async def get_server_verify_key_v2_direct(
@ -955,37 +881,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
keys.update(response_keys)
return keys
async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
verify_request:
Raises:
SynapseError if there was a problem performing the verification
"""
server_name = verify_request.server_name
with PreserveLoggingContext():
_, key_id, verify_key = await verify_request.key_ready
json_object = verify_request.get_json_object()
try:
verify_signed_json(json_object, server_name, verify_key)
except SignatureVerifyException as e:
logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s",
server_name,
verify_key.alg,
verify_key.version,
encode_verify_key_base64(verify_key),
str(e),
)
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s: %s"
% (server_name, verify_key.alg, verify_key.version, str(e)),
Codes.UNAUTHORIZED,
)

View file

@ -37,6 +37,7 @@ from synapse.http.servlet import (
)
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
SynapseTags,
start_active_span,
start_active_span_from_request,
tags,
@ -151,7 +152,9 @@ class Authenticator:
)
await self.keyring.verify_json_for_server(
origin, json_request, now, "Incoming request"
origin,
json_request,
now,
)
logger.debug("Request from %s", origin)
@ -314,7 +317,7 @@ class BaseFederationServlet:
raise
request_tags = {
"request_id": request.get_request_id(),
SynapseTags.REQUEST_ID: request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),

View file

@ -108,7 +108,9 @@ class GroupAttestationSigning:
assert server_name is not None
await self.keyring.verify_json_for_server(
server_name, attestation, now, "Group attestation"
server_name,
attestation,
now,
)
def create_attestation(self, group_id: str, user_id: str) -> JsonDict:

View file

@ -577,7 +577,9 @@ class FederationHandler(BaseHandler):
# Fetch the state events from the DB, and check we have the auth events.
event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
auth_events_in_store = await self.store.have_seen_events(auth_event_ids)
auth_events_in_store = await self.store.have_seen_events(
room_id, auth_event_ids
)
# Check for missing events. We handle state and auth event seperately,
# as we want to pull the state from the DB, but we don't for the auth
@ -610,7 +612,7 @@ class FederationHandler(BaseHandler):
if missing_auth_events:
auth_events_in_store = await self.store.have_seen_events(
missing_auth_events
room_id, missing_auth_events
)
missing_auth_events.difference_update(auth_events_in_store)
@ -710,7 +712,7 @@ class FederationHandler(BaseHandler):
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
missing_auth_events.difference_update(
await self.store.have_seen_events(missing_auth_events)
await self.store.have_seen_events(room_id, missing_auth_events)
)
logger.debug("We are also missing %i auth events", len(missing_auth_events))
@ -2475,7 +2477,7 @@ class FederationHandler(BaseHandler):
#
# we start by checking if they are in the store, and then try calling /event_auth/.
if missing_auth:
have_events = await self.store.have_seen_events(missing_auth)
have_events = await self.store.have_seen_events(event.room_id, missing_auth)
logger.debug("Events %s are in the store", have_events)
missing_auth.difference_update(have_events)
@ -2494,7 +2496,7 @@ class FederationHandler(BaseHandler):
return context
seen_remotes = await self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
event.room_id, [e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:

View file

@ -26,7 +26,6 @@ from synapse.api.constants import (
HistoryVisibility,
Membership,
)
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict
@ -456,16 +455,16 @@ class SpaceSummaryHandler:
return True
# Otherwise, check if they should be allowed access via membership in a space.
try:
await self._event_auth_handler.check_restricted_join_rules(
state_ids, room_version, requester, member_event
if self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
allowed_spaces = (
await self._event_auth_handler.get_spaces_that_allow_join(state_ids)
)
except AuthError:
# The user doesn't have access due to spaces, but might have access
# another way. Keep trying.
pass
else:
return True
if await self._event_auth_handler.is_user_in_rooms(
allowed_spaces, requester
):
return True
# If this is a request over federation, check if the host is in the room or
# is in one of the spaces specified via the join rules.

View file

@ -464,7 +464,7 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids = frozenset() # type: FrozenSet[str]
if any(e.is_state() for e in recents):
current_state_ids_map = await self.state.get_current_state_ids(
current_state_ids_map = await self.store.get_current_state_ids(
room_id
)
current_state_ids = frozenset(current_state_ids_map.values())
@ -524,7 +524,7 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
current_state_ids_map = await self.state.get_current_state_ids(
current_state_ids_map = await self.store.get_current_state_ids(
room_id
)
current_state_ids = frozenset(current_state_ids_map.values())

View file

@ -15,6 +15,9 @@
""" This module contains base REST classes for constructing REST servlets. """
import logging
from typing import Iterable, List, Optional, Union, overload
from typing_extensions import Literal
from synapse.api.errors import Codes, SynapseError
from synapse.util import json_decoder
@ -107,12 +110,11 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_string(
request,
name,
default=None,
required=False,
allowed_values=None,
param_type="string",
encoding="ascii",
name: Union[bytes, str],
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Optional[str] = "ascii",
):
"""
Parse a string parameter from the request query string.
@ -122,18 +124,17 @@ def parse_string(
Args:
request: the twisted HTTP request.
name (bytes|unicode): the name of the query parameter.
default (bytes|unicode|None): value to use if the parameter is absent,
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values (list[bytes|unicode]): List of allowed values for the
allowed_values: List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding (str|None): The encoding to decode the string content with.
encoding : The encoding to decode the string content with.
Returns:
bytes/unicode|None: A string value or the default. Unicode if encoding
A string value or the default. Unicode if encoding
was given, bytes otherwise.
Raises:
@ -142,45 +143,105 @@ def parse_string(
is not one of those allowed values.
"""
return parse_string_from_args(
request.args, name, default, required, allowed_values, param_type, encoding
request.args, name, default, required, allowed_values, encoding
)
def parse_string_from_args(
args,
name,
default=None,
required=False,
allowed_values=None,
param_type="string",
encoding="ascii",
):
def _parse_string_value(
value: Union[str, bytes],
allowed_values: Optional[Iterable[str]],
name: str,
encoding: Optional[str],
) -> Union[str, bytes]:
if encoding:
try:
value = value.decode(encoding)
except ValueError:
raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name,
", ".join(repr(v) for v in allowed_values),
)
raise SynapseError(400, message)
else:
return value
@overload
def parse_strings_from_args(
args: List[str],
name: Union[bytes, str],
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Literal[None] = None,
) -> Optional[List[bytes]]:
...
@overload
def parse_strings_from_args(
args: List[str],
name: Union[bytes, str],
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
def parse_strings_from_args(
args: List[str],
name: Union[bytes, str],
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Optional[str] = "ascii",
) -> Optional[List[Union[bytes, str]]]:
"""
Parse a string parameter from the request query string list.
If encoding is not None, the content of the query param will be
decoded to Unicode using the encoding, otherwise it will be encoded
Args:
args: the twisted HTTP request.args list.
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required : whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values (list[bytes|unicode]): List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding: The encoding to decode the string content with.
Returns:
A string value or the default. Unicode if encoding
was given, bytes otherwise.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
if not isinstance(name, bytes):
name = name.encode("ascii")
if name in args:
value = args[name][0]
values = args[name]
if encoding:
try:
value = value.decode(encoding)
except ValueError:
raise SynapseError(
400, "Query parameter %r must be %s" % (name, encoding)
)
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name,
", ".join(repr(v) for v in allowed_values),
)
raise SynapseError(400, message)
else:
return value
return [
_parse_string_value(value, allowed_values, name=name, encoding=encoding)
for value in values
]
else:
if required:
message = "Missing %s query parameter %r" % (param_type, name)
message = "Missing string query parameter %r" % (name)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else:
@ -190,6 +251,55 @@ def parse_string_from_args(
return default
def parse_string_from_args(
args: List[str],
name: Union[bytes, str],
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Optional[str] = "ascii",
) -> Optional[Union[bytes, str]]:
"""
Parse the string parameter from the request query string list
and return the first result.
If encoding is not None, the content of the query param will be
decoded to Unicode using the encoding, otherwise it will be encoded
Args:
args: the twisted HTTP request.args list.
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values: List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding: The encoding to decode the string content with.
Returns:
A string value or the default. Unicode if encoding
was given, bytes otherwise.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
strings = parse_strings_from_args(
args,
name,
default=[default],
required=required,
allowed_values=allowed_values,
encoding=encoding,
)
return strings[0]
def parse_json_value_from_request(request, allow_empty_body=False):
"""Parse a JSON value from the body of a twisted HTTP request.
@ -215,7 +325,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
try:
content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e:
logger.warning("Unable to parse JSON: %s", e)
logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content

View file

@ -265,6 +265,12 @@ class SynapseTags:
# Whether the sync response has new data to be returned to the client.
SYNC_RESULT = "sync.new_data"
# incoming HTTP request ID (as written in the logs)
REQUEST_ID = "request_id"
# HTTP request tag (used to distinguish full vs incremental syncs, etc)
REQUEST_TAG = "request_tag"
# Block everything by default
# A regex which matches the server_names to expose traces for.
@ -588,7 +594,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
span = opentracing.tracer.active_span
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
headers.addRawHeaders(key, value)
@ -625,7 +631,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
headers[key.encode()] = [value.encode()]
@ -659,7 +665,7 @@ def inject_active_span_text_map(carrier, destination, check_destination=True):
return
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
@ -681,7 +687,7 @@ def get_active_span_text_map(destination=None):
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
return carrier
@ -696,7 +702,7 @@ def active_span_context_as_string():
carrier = {} # type: Dict[str, str]
if opentracing:
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
)
return json_encoder.encode(carrier)
@ -824,7 +830,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
return
request_tags = {
"request_id": request.get_request_id(),
SynapseTags.REQUEST_ID: request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
@ -833,9 +839,9 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
request_name = request.request_metrics.name
if extract_context:
scope = start_active_span_from_request(request, request_name, tags=request_tags)
scope = start_active_span_from_request(request, request_name)
else:
scope = start_active_span(request_name, tags=request_tags)
scope = start_active_span(request_name)
with scope:
try:
@ -845,4 +851,11 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
# with JsonResource).
scope.span.set_operation_name(request.request_metrics.name)
scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
# set the tags *after* the servlet completes, in case it decided to
# prioritise the span (tags will get dropped on unprioritised spans)
request_tags[
SynapseTags.REQUEST_TAG
] = request.request_metrics.start_context.tag
for k, v in request_tags.items():
scope.span.set_tag(k, v)

View file

@ -22,7 +22,11 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import noop_context_manager, start_active_span
from synapse.logging.opentracing import (
SynapseTags,
noop_context_manager,
start_active_span,
)
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
@ -202,7 +206,9 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
try:
ctx = noop_context_manager()
if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": str(context)})
ctx = start_active_span(
desc, tags={SynapseTags.REQUEST_ID: str(context)}
)
with ctx:
return await maybe_awaitable(func(*args, **kwargs))
except Exception:

View file

@ -1061,15 +1061,15 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomTypingRestServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)
RoomSpaceSummaryRestServlet(hs).register(http_server)
RoomEventServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
RoomAliasListServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
# Some servlets only get registered for the main process.
if not is_worker:
RoomCreateRestServlet(hs).register(http_server)
RoomForgetRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
RoomEventServlet(hs).register(http_server)
RoomAliasListServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server):

View file

@ -16,11 +16,7 @@ import logging
from http import HTTPStatus
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_patterns
@ -42,15 +38,14 @@ class ReportEventRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("reason", "score"))
if not isinstance(body["reason"], str):
if not isinstance(body.get("reason", ""), str):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'reason' must be a string",
Codes.BAD_JSON,
)
if not isinstance(body["score"], int):
if not isinstance(body.get("score", 0), int):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'score' must be an integer",
@ -61,7 +56,7 @@ class ReportEventRestServlet(RestServlet):
room_id=room_id,
event_id=event_id,
user_id=user_id,
reason=body["reason"],
reason=body.get("reason"),
content=body,
received_ts=self.clock.time_msec(),
)

View file

@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
logger = logging.getLogger(__name__)
@ -210,7 +211,13 @@ class RemoteKey(DirectServeJsonResource):
# If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
await self.fetcher.get_keys(cache_misses)
await yieldable_gather_results(
lambda t: self.fetcher.get_keys(*t),
(
(server_name, list(keys), 0)
for server_name, keys in cache_misses.items()
),
)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []

View file

@ -168,6 +168,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled,
):
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
self.get_latest_event_ids_in_room.invalidate((room_id,))

View file

@ -22,6 +22,7 @@ from typing import (
Iterable,
List,
Optional,
Set,
Tuple,
overload,
)
@ -55,7 +56,7 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@ -1045,32 +1046,74 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
async def have_seen_events(self, event_ids):
async def have_seen_events(
self, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
"""Given a list of event ids, check if we have already processed them.
The room_id is only used to structure the cache (so that it can later be
invalidated by room_id) - there is no guarantee that the events are actually
in the room in question.
Args:
event_ids (iterable[str]):
room_id: Room we are polling
event_ids: events we are looking for
Returns:
set[str]: The events we have already seen.
"""
# if the event cache contains the event, obviously we've seen it.
results = {x for x in event_ids if self._get_event_cache.contains(x)}
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
)
return {eid for ((_rid, eid), have_event) in res.items() if have_event}
def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE "
@cachedList("have_seen_event", "keys")
async def _have_seen_events_dict(
self, keys: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], bool]:
"""Helper for have_seen_events
Returns:
a dict {(room_id, event_id)-> bool}
"""
# if the event cache contains the event, obviously we've seen it.
cache_results = {
(rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
}
results = {x: True for x in cache_results}
def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
# We therefore pull the events from the database into a set...
sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", chunk
txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk]
)
txn.execute(sql + clause, args)
results.update(row[0] for row in txn)
found_events = {eid for eid, in txn}
for chunk in batch_iter((x for x in event_ids if x not in results), 100):
# ... and then we can update the results for each row in the batch
results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk})
# each batch requires its own index scan, so we make the batches as big as
# possible.
for chunk in batch_iter((k for k in keys if k not in cache_results), 500):
await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str):
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.

View file

@ -16,14 +16,14 @@ import logging
from typing import Any, List, Set, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken
logger = logging.getLogger(__name__)
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
async def purge_history(
self, room_id: str, token: str, delete_local_events: bool
) -> Set[int]:
@ -203,8 +203,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"DELETE FROM event_to_state_groups "
"WHERE event_id IN (SELECT event_id from events_to_purge)"
)
for event_id, _ in event_rows:
txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
# Delete all remote non-state events
for table in (
@ -283,6 +281,20 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# so make sure to keep this actually last.
txn.execute("DROP TABLE events_to_purge")
for event_id, should_delete in event_rows:
self._invalidate_cache_and_stream(
txn, self._get_state_group_for_event, (event_id,)
)
# XXX: This is racy, since have_seen_events could be called between the
# transaction completing and the invalidation running. On the other hand,
# that's no different to calling `have_seen_events` just before the
# event is deleted from the database.
if should_delete:
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
logger.info("[purge] done")
return referenced_state_groups
@ -422,7 +434,11 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# index on them. In any case we should be clearing out 'stream' tables
# periodically anyway (#5888)
# TODO: we could probably usefully do a bunch of cache invalidation here
# TODO: we could probably usefully do a bunch more cache invalidation here
# XXX: as with purge_history, this is racy, but no worse than other races
# that already exist.
self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
logger.info("[purge] done")

View file

@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
room_id: str,
event_id: str,
user_id: str,
reason: str,
reason: Optional[str],
content: JsonDict,
received_ts: int,
) -> None:

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Dict, List
from unittest.mock import Mock
import attr
@ -21,7 +22,6 @@ import signedjson.sign
from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
@ -92,23 +92,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# deferred completes.
first_lookup_deferred = Deferred()
async def first_lookup_fetch(keys_to_fetch):
self.assertEquals(current_context().request.id, "context_11")
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
async def first_lookup_fetch(
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
# self.assertEquals(current_context().request.id, "context_11")
self.assertEqual(server_name, "server10")
self.assertEqual(key_ids, [get_key_id(key1)])
self.assertEqual(minimum_valid_until_ts, 0)
await make_deferred_yieldable(first_lookup_deferred)
return {
"server10": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
}
}
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
mock_fetcher.get_keys.side_effect = first_lookup_fetch
async def first_lookup():
with LoggingContext("context_11", request=FakeRequest("context_11")):
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
[("server10", json1, 0), ("server11", {}, 0)]
)
# the unsigned json should be rejected pretty quickly
@ -126,18 +126,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
d0 = ensureDeferred(first_lookup())
self.pump()
mock_fetcher.get_keys.assert_called_once()
# a second request for a server with outstanding requests
# should block rather than start a second call
async def second_lookup_fetch(keys_to_fetch):
self.assertEquals(current_context().request.id, "context_12")
return {
"server10": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
}
}
async def second_lookup_fetch(
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
# self.assertEquals(current_context().request.id, "context_12")
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
mock_fetcher.get_keys.reset_mock()
mock_fetcher.get_keys.side_effect = second_lookup_fetch
@ -146,7 +146,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
async def second_lookup():
with LoggingContext("context_12", request=FakeRequest("context_12")):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
[
(
"server10",
json1,
0,
)
]
)
res_deferreds_2[0].addBoth(self.check_context, None)
second_lookup_state[0] = 1
@ -183,11 +189,11 @@ class KeyringTestCase(unittest.HomeserverTestCase):
signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
d = kr.verify_json_for_server("server9", {}, 0)
self.get_failure(d, SynapseError)
# should succeed on a signed object
d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
d = kr.verify_json_for_server("server9", json1, 500)
# self.assertFalse(d.called)
self.get_success(d)
@ -214,24 +220,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
d = kr.verify_json_for_server("server9", {}, 0)
self.get_failure(d, SynapseError)
# should fail on a signed object with a non-zero minimum_valid_until_ms,
# as it tries to refetch the keys and fails.
d = _verify_json_for_server(
kr, "server9", json1, 500, "test signed non-zero min"
)
d = kr.verify_json_for_server("server9", json1, 500)
self.get_failure(d, SynapseError)
# We expect the keyring tried to refetch the key once.
mock_fetcher.get_keys.assert_called_once_with(
{"server9": {get_key_id(key1): 500}}
"server9", [get_key_id(key1)], 500
)
# should succeed on a signed object with a 0 minimum_valid_until_ms
d = _verify_json_for_server(
kr, "server9", json1, 0, "test signed with zero min"
d = kr.verify_json_for_server(
"server9",
json1,
0,
)
self.get_success(d)
@ -239,15 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1)
async def get_keys(keys_to_fetch):
async def get_keys(
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
# there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
self.assertEqual(server_name, "server1")
self.assertEqual(key_ids, [get_key_id(key1)])
self.assertEqual(minimum_valid_until_ts, 1500)
return {
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
@ -259,7 +265,14 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# the first request should succeed; the second should fail because the key
# has expired
results = kr.verify_json_objects_for_server(
[("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
[
(
"server1",
json1,
500,
),
("server1", json1, 1500),
]
)
self.assertEqual(len(results), 2)
self.get_success(results[0])
@ -274,19 +287,21 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1)
async def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return {
"server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
}
async def get_keys1(
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
self.assertEqual(server_name, "server1")
self.assertEqual(key_ids, [get_key_id(key1)])
self.assertEqual(minimum_valid_until_ts, 1500)
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
async def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return {
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
async def get_keys2(
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
self.assertEqual(server_name, "server1")
self.assertEqual(key_ids, [get_key_id(key1)])
self.assertEqual(minimum_valid_until_ts, 1500)
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
mock_fetcher1 = Mock()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
@ -298,7 +313,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
signedjson.sign.sign_json(json1, "server1", key1)
results = kr.verify_json_objects_for_server(
[("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
[
(
"server1",
json1,
1200,
),
(
"server1",
json1,
1500,
),
]
)
self.assertEqual(len(results), 2)
self.get_success(results[0])
@ -349,9 +375,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
k = keys[SERVER_NAME][testverifykey_id]
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
k = keys[testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.verify_key.alg, "ed25519")
@ -378,7 +403,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
# change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
self.assertEqual(keys, {})
@ -465,10 +490,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
self.assertIn(testverifykey_id, keys)
k = keys[testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.verify_key.alg, "ed25519")
@ -515,10 +539,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
self.assertIn(testverifykey_id, keys)
k = keys[testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.verify_key.alg, "ed25519")
@ -559,14 +582,13 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs)
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
return self.get_success(fetcher.get_keys(keys_to_fetch))
return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
# start with a valid response so we can check we are testing the right thing
response = build_response()
keys = get_key_from_perspectives(response)
k = keys[SERVER_NAME][testverifykey_id]
k = keys[testverifykey_id]
self.assertEqual(k.verify_key, testverifykey)
# remove the perspectives server's signature
@ -585,23 +607,3 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def get_key_id(key):
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
return "%s:%s" % (key.alg, key.version)
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx"):
rv = yield f(*args, **kwargs)
return rv
def _verify_json_for_server(kr, *args):
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
with the patched defer.inlineCallbacks.
"""
@defer.inlineCallbacks
def v():
rv1 = yield kr.verify_json_for_server(*args)
return rv1
return run_in_context(v)

View file

@ -64,7 +64,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
user_tok=self.admin_user_tok,
)
for _ in range(5):
self._create_event_and_report(
self._create_event_and_report_without_parameters(
room_id=self.room_id2,
user_tok=self.admin_user_tok,
)
@ -378,6 +378,19 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _create_event_and_report_without_parameters(self, room_id, user_tok):
"""Create and report an event, but omit reason and score"""
resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"]
channel = self.make_request(
"POST",
"rooms/%s/report/%s" % (room_id, event_id),
json.dumps({}),
access_token=user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _check_fields(self, content):
"""Checks that all attributes are present in an event report"""
for c in content:

View file

@ -0,0 +1,83 @@
# Copyright 2021 Callum Brown
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import report_event
from tests import unittest
class ReportEventTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
report_event.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.other_user_tok = self.login("user", "pass")
self.room_id = self.helper.create_room_as(
self.other_user, tok=self.other_user_tok, is_public=True
)
self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok)
resp = self.helper.send(self.room_id, tok=self.admin_user_tok)
self.event_id = resp["event_id"]
self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id)
def test_reason_str_and_score_int(self):
data = {"reason": "this makes me sad", "score": -100}
self._assert_status(200, data)
def test_no_reason(self):
data = {"score": 0}
self._assert_status(200, data)
def test_no_score(self):
data = {"reason": "this makes me sad"}
self._assert_status(200, data)
def test_no_reason_and_no_score(self):
data = {}
self._assert_status(200, data)
def test_reason_int_and_score_str(self):
data = {"reason": 10, "score": "string"}
self._assert_status(400, data)
def test_reason_zero_and_score_blank(self):
data = {"reason": 0, "score": ""}
self._assert_status(400, data)
def test_reason_and_score_null(self):
data = {"reason": None, "score": None}
self._assert_status(400, data)
def _assert_status(self, response_status, data):
channel = self.make_request(
"POST",
self.report_path,
json.dumps(data),
access_token=self.other_user_tok,
)
self.assertEqual(
response_status, int(channel.result["code"]), msg=channel.result["body"]
)

View file

@ -208,10 +208,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
keyid = "ed25519:%s" % (testkey.version,)
fetcher = PerspectivesKeyFetcher(self.hs2)
d = fetcher.get_keys({"targetserver": {keyid: 1000}})
d = fetcher.get_keys("targetserver", [keyid], 1000)
res = self.get_success(d)
self.assertIn("targetserver", res)
keyres = res["targetserver"][keyid]
self.assertIn(keyid, res)
keyres = res[keyid]
assert isinstance(keyres, FetchKeyResult)
self.assertEqual(
signedjson.key.encode_verify_key_base64(keyres.verify_key),
@ -230,10 +230,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
keyid = "ed25519:%s" % (testkey.version,)
fetcher = PerspectivesKeyFetcher(self.hs2)
d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
res = self.get_success(d)
self.assertIn(self.hs.hostname, res)
keyres = res[self.hs.hostname][keyid]
self.assertIn(keyid, res)
keyres = res[keyid]
assert isinstance(keyres, FetchKeyResult)
self.assertEqual(
signedjson.key.encode_verify_key_base64(keyres.verify_key),
@ -247,10 +247,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
keyid = "ed25519:%s" % (self.hs_signing_key.version,)
fetcher = PerspectivesKeyFetcher(self.hs2)
d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
res = self.get_success(d)
self.assertIn(self.hs.hostname, res)
keyres = res[self.hs.hostname][keyid]
self.assertIn(keyid, res)
keyres = res[keyid]
assert isinstance(keyres, FetchKeyResult)
self.assertEqual(
signedjson.key.encode_verify_key_base64(keyres.verify_key),

View file

@ -0,0 +1,13 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,13 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,96 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from synapse.logging.context import LoggingContext
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from tests import unittest
class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store: EventsWorkerStore = hs.get_datastore()
# insert some test data
for rid in ("room1", "room2"):
self.get_success(
self.store.db_pool.simple_insert(
"rooms",
{"room_id": rid, "room_version": 4},
)
)
for idx, (rid, eid) in enumerate(
(
("room1", "event10"),
("room1", "event11"),
("room1", "event12"),
("room2", "event20"),
)
):
self.get_success(
self.store.db_pool.simple_insert(
"events",
{
"event_id": eid,
"room_id": rid,
"topological_ordering": idx,
"stream_ordering": idx,
"type": "test",
"processed": True,
"outlier": False,
},
)
)
self.get_success(
self.store.db_pool.simple_insert(
"event_json",
{
"event_id": eid,
"room_id": rid,
"json": json.dumps({"type": "test", "room_id": rid}),
"internal_metadata": "{}",
"format_version": 3,
},
)
)
def test_simple(self):
with LoggingContext(name="test") as ctx:
res = self.get_success(
self.store.have_seen_events("room1", ["event10", "event19"])
)
self.assertEquals(res, {"event10"})
# that should result in a single db query
self.assertEquals(ctx.get_resource_usage().db_txn_count, 1)
# a second lookup of the same events should cause no queries
with LoggingContext(name="test") as ctx:
res = self.get_success(
self.store.have_seen_events("room1", ["event10", "event19"])
)
self.assertEquals(res, {"event10"})
self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
def test_query_via_event_cache(self):
# fetch an event into the event cache
self.get_success(self.store.get_event("event10"))
# looking it up should now cause no db hits
with LoggingContext(name="test") as ctx:
res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
self.assertEquals(res, {"event10"})
self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)

View file

@ -45,37 +45,32 @@ class BatchingQueueTestCase(TestCase):
self._pending_calls.append((values, d))
return await make_deferred_yieldable(d)
def _get_sample_with_name(self, metric, name) -> int:
"""For a prometheus metric get the value of the sample that has a
matching "name" label.
"""
for sample in metric.collect()[0].samples:
if sample.labels.get("name") == name:
return sample.value
self.fail("Found no matching sample")
def _assert_metrics(self, queued, keys, in_flight):
"""Assert that the metrics are correct"""
self.assertEqual(len(number_queued.collect()), 1)
self.assertEqual(len(number_queued.collect()[0].samples), 1)
sample = self._get_sample_with_name(number_queued, self.queue._name)
self.assertEqual(
number_queued.collect()[0].samples[0].labels,
{"name": self.queue._name},
)
self.assertEqual(
number_queued.collect()[0].samples[0].value,
sample,
queued,
"number_queued",
)
self.assertEqual(len(number_of_keys.collect()), 1)
self.assertEqual(len(number_of_keys.collect()[0].samples), 1)
self.assertEqual(
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
)
self.assertEqual(
number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys"
)
sample = self._get_sample_with_name(number_of_keys, self.queue._name)
self.assertEqual(sample, keys, "number_of_keys")
self.assertEqual(len(number_in_flight.collect()), 1)
self.assertEqual(len(number_in_flight.collect()[0].samples), 1)
sample = self._get_sample_with_name(number_in_flight, self.queue._name)
self.assertEqual(
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
)
self.assertEqual(
number_in_flight.collect()[0].samples[0].value,
sample,
in_flight,
"number_in_flight",
)