Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
This commit is contained in:
commit
09361655d2
11
CHANGES.md
11
CHANGES.md
|
@ -1,5 +1,12 @@
|
||||||
Synapse 1.35.0rc3 (2021-05-28)
|
Synapse 1.35.0 (2021-06-01)
|
||||||
==============================
|
===========================
|
||||||
|
|
||||||
|
Note that [the tag](https://github.com/matrix-org/synapse/releases/tag/v1.35.0rc3) and [docker images](https://hub.docker.com/layers/matrixdotorg/synapse/v1.35.0rc3/images/sha256-34ccc87bd99a17e2cbc0902e678b5937d16bdc1991ead097eee6096481ecf2c4?context=explore) for `v1.35.0rc3` were incorrectly built. If you are experiencing issues with either, it is recommended to upgrade to the equivalent tag or docker image for the `v1.35.0` release.
|
||||||
|
|
||||||
|
Deprecations and Removals
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
- The core Synapse development team plan to drop support for the [unstable API of MSC2858](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2858-Multiple-SSO-Identity-Providers.md#unstable-prefix), including the undocumented `experimental.msc2858_enabled` config option, in August 2021. Client authors should ensure that their clients are updated to use the stable API (which has been supported since Synapse 1.30) well before that time, to give their users time to upgrade. ([\#10101](https://github.com/matrix-org/synapse/issues/10101))
|
||||||
|
|
||||||
Bugfixes
|
Bugfixes
|
||||||
--------
|
--------
|
||||||
|
|
1
changelog.d/10035.feature
Normal file
1
changelog.d/10035.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Rewrite logic around verifying JSON object and fetching server keys to be more performant and use less memory.
|
1
changelog.d/10048.misc
Normal file
1
changelog.d/10048.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add `parse_strings_from_args` for parsing an array from query parameters.
|
1
changelog.d/10074.misc
Normal file
1
changelog.d/10074.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update opentracing to inject the right context into the carrier.
|
1
changelog.d/10077.feature
Normal file
1
changelog.d/10077.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Make reason and score parameters optional for reporting content. Implements [MSC2414](https://github.com/matrix-org/matrix-doc/pull/2414). Contributed by Callum Brown.
|
1
changelog.d/10084.feature
Normal file
1
changelog.d/10084.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add support for routing more requests to workers.
|
1
changelog.d/10091.misc
Normal file
1
changelog.d/10091.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Log method and path when dropping request due to size limit.
|
1
changelog.d/10092.bugfix
Normal file
1
changelog.d/10092.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug in the `force_tracing_for_users` option introduced in Synapse v1.35 which meant that the OpenTracing spans produced were missing most tags.
|
1
changelog.d/10102.misc
Normal file
1
changelog.d/10102.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Make `/sync` do fewer state resolutions.
|
1
changelog.d/10109.bugfix
Normal file
1
changelog.d/10109.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug introduced in v1.35.0 where invite-only rooms would be shown to users in a space who were not invited.
|
1
changelog.d/9953.feature
Normal file
1
changelog.d/9953.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve performance of incoming federation transactions in large rooms.
|
1
changelog.d/9973.feature
Normal file
1
changelog.d/9973.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve performance of incoming federation transactions in large rooms.
|
|
@ -1 +0,0 @@
|
||||||
Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`.
|
|
6
debian/changelog
vendored
6
debian/changelog
vendored
|
@ -1,3 +1,9 @@
|
||||||
|
matrix-synapse-py3 (1.35.0) stable; urgency=medium
|
||||||
|
|
||||||
|
* New synapse release 1.35.0.
|
||||||
|
|
||||||
|
-- Synapse Packaging team <packages@matrix.org> Tue, 01 Jun 2021 13:23:35 +0100
|
||||||
|
|
||||||
matrix-synapse-py3 (1.34.0) stable; urgency=medium
|
matrix-synapse-py3 (1.34.0) stable; urgency=medium
|
||||||
|
|
||||||
* New synapse release 1.34.0.
|
* New synapse release 1.34.0.
|
||||||
|
|
|
@ -75,9 +75,9 @@ The following fields are returned in the JSON response body:
|
||||||
* `name`: string - The name of the room.
|
* `name`: string - The name of the room.
|
||||||
* `event_id`: string - The ID of the reported event.
|
* `event_id`: string - The ID of the reported event.
|
||||||
* `user_id`: string - This is the user who reported the event and wrote the reason.
|
* `user_id`: string - This is the user who reported the event and wrote the reason.
|
||||||
* `reason`: string - Comment made by the `user_id` in this report. May be blank.
|
* `reason`: string - Comment made by the `user_id` in this report. May be blank or `null`.
|
||||||
* `score`: integer - Content is reported based upon a negative score, where -100 is
|
* `score`: integer - Content is reported based upon a negative score, where -100 is
|
||||||
"most offensive" and 0 is "inoffensive".
|
"most offensive" and 0 is "inoffensive". May be `null`.
|
||||||
* `sender`: string - This is the ID of the user who sent the original message/event that
|
* `sender`: string - This is the ID of the user who sent the original message/event that
|
||||||
was reported.
|
was reported.
|
||||||
* `canonical_alias`: string - The canonical alias of the room. `null` if the room does not
|
* `canonical_alias`: string - The canonical alias of the room. `null` if the room does not
|
||||||
|
|
|
@ -228,6 +228,9 @@ expressions:
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
|
^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
|
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
|
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/event/
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/joined_rooms$
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/search$
|
||||||
|
|
||||||
# Registration/login requests
|
# Registration/login requests
|
||||||
^/_matrix/client/(api/v1|r0|unstable)/login$
|
^/_matrix/client/(api/v1|r0|unstable)/login$
|
||||||
|
|
|
@ -47,7 +47,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
__version__ = "1.35.0rc3"
|
__version__ = "1.35.0"
|
||||||
|
|
||||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||||
# We import here so that we don't have to install a bunch of deps when
|
# We import here so that we don't have to install a bunch of deps when
|
||||||
|
|
|
@ -206,11 +206,11 @@ class Auth:
|
||||||
requester = create_requester(user_id, app_service=app_service)
|
requester = create_requester(user_id, app_service=app_service)
|
||||||
|
|
||||||
request.requester = user_id
|
request.requester = user_id
|
||||||
|
if user_id in self._force_tracing_for_users:
|
||||||
|
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
|
||||||
opentracing.set_tag("authenticated_entity", user_id)
|
opentracing.set_tag("authenticated_entity", user_id)
|
||||||
opentracing.set_tag("user_id", user_id)
|
opentracing.set_tag("user_id", user_id)
|
||||||
opentracing.set_tag("appservice_id", app_service.id)
|
opentracing.set_tag("appservice_id", app_service.id)
|
||||||
if user_id in self._force_tracing_for_users:
|
|
||||||
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
|
|
||||||
|
|
||||||
return requester
|
return requester
|
||||||
|
|
||||||
|
@ -259,12 +259,12 @@ class Auth:
|
||||||
)
|
)
|
||||||
|
|
||||||
request.requester = requester
|
request.requester = requester
|
||||||
|
if user_info.token_owner in self._force_tracing_for_users:
|
||||||
|
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
|
||||||
opentracing.set_tag("authenticated_entity", user_info.token_owner)
|
opentracing.set_tag("authenticated_entity", user_info.token_owner)
|
||||||
opentracing.set_tag("user_id", user_info.user_id)
|
opentracing.set_tag("user_id", user_info.user_id)
|
||||||
if device_id:
|
if device_id:
|
||||||
opentracing.set_tag("device_id", device_id)
|
opentracing.set_tag("device_id", device_id)
|
||||||
if user_info.token_owner in self._force_tracing_for_users:
|
|
||||||
opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
|
|
||||||
|
|
||||||
return requester
|
return requester
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
|
|
@ -109,7 +109,7 @@ from synapse.storage.databases.main.monthly_active_users import (
|
||||||
MonthlyActiveUsersWorkerStore,
|
MonthlyActiveUsersWorkerStore,
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.presence import PresenceStore
|
from synapse.storage.databases.main.presence import PresenceStore
|
||||||
from synapse.storage.databases.main.search import SearchWorkerStore
|
from synapse.storage.databases.main.search import SearchStore
|
||||||
from synapse.storage.databases.main.stats import StatsStore
|
from synapse.storage.databases.main.stats import StatsStore
|
||||||
from synapse.storage.databases.main.transactions import TransactionWorkerStore
|
from synapse.storage.databases.main.transactions import TransactionWorkerStore
|
||||||
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
|
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
|
||||||
|
@ -242,7 +242,7 @@ class GenericWorkerSlavedStore(
|
||||||
MonthlyActiveUsersWorkerStore,
|
MonthlyActiveUsersWorkerStore,
|
||||||
MediaRepositoryStore,
|
MediaRepositoryStore,
|
||||||
ServerMetricsStore,
|
ServerMetricsStore,
|
||||||
SearchWorkerStore,
|
SearchStore,
|
||||||
TransactionWorkerStore,
|
TransactionWorkerStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
):
|
):
|
||||||
|
|
|
@ -16,8 +16,7 @@
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
from collections import defaultdict
|
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from signedjson.key import (
|
from signedjson.key import (
|
||||||
|
@ -44,17 +43,12 @@ from synapse.api.errors import (
|
||||||
from synapse.config.key import TrustedKeyServer
|
from synapse.config.key import TrustedKeyServer
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.utils import prune_event_dict
|
from synapse.events.utils import prune_event_dict
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
PreserveLoggingContext,
|
|
||||||
make_deferred_yieldable,
|
|
||||||
preserve_fn,
|
|
||||||
run_in_background,
|
|
||||||
)
|
|
||||||
from synapse.storage.keys import FetchKeyResult
|
from synapse.storage.keys import FetchKeyResult
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async_helpers import yieldable_gather_results
|
from synapse.util.async_helpers import yieldable_gather_results
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.batching_queue import BatchingQueue
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -80,32 +74,19 @@ class VerifyJsonRequest:
|
||||||
minimum_valid_until_ts: time at which we require the signing key to
|
minimum_valid_until_ts: time at which we require the signing key to
|
||||||
be valid. (0 implies we don't care)
|
be valid. (0 implies we don't care)
|
||||||
|
|
||||||
request_name: The name of the request.
|
|
||||||
|
|
||||||
key_ids: The set of key_ids to that could be used to verify the JSON object
|
key_ids: The set of key_ids to that could be used to verify the JSON object
|
||||||
|
|
||||||
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
|
|
||||||
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
|
||||||
a verify key has been fetched. The deferreds' callbacks are run with no
|
|
||||||
logcontext.
|
|
||||||
|
|
||||||
If we are unable to find a key which satisfies the request, the deferred
|
|
||||||
errbacks with an M_UNAUTHORIZED SynapseError.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
server_name = attr.ib(type=str)
|
server_name = attr.ib(type=str)
|
||||||
get_json_object = attr.ib(type=Callable[[], JsonDict])
|
get_json_object = attr.ib(type=Callable[[], JsonDict])
|
||||||
minimum_valid_until_ts = attr.ib(type=int)
|
minimum_valid_until_ts = attr.ib(type=int)
|
||||||
request_name = attr.ib(type=str)
|
|
||||||
key_ids = attr.ib(type=List[str])
|
key_ids = attr.ib(type=List[str])
|
||||||
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_json_object(
|
def from_json_object(
|
||||||
server_name: str,
|
server_name: str,
|
||||||
json_object: JsonDict,
|
json_object: JsonDict,
|
||||||
minimum_valid_until_ms: int,
|
minimum_valid_until_ms: int,
|
||||||
request_name: str,
|
|
||||||
):
|
):
|
||||||
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
|
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
|
||||||
object for the given server.
|
object for the given server.
|
||||||
|
@ -115,7 +96,6 @@ class VerifyJsonRequest:
|
||||||
server_name,
|
server_name,
|
||||||
lambda: json_object,
|
lambda: json_object,
|
||||||
minimum_valid_until_ms,
|
minimum_valid_until_ms,
|
||||||
request_name=request_name,
|
|
||||||
key_ids=key_ids,
|
key_ids=key_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -135,16 +115,48 @@ class VerifyJsonRequest:
|
||||||
# memory than the Event object itself.
|
# memory than the Event object itself.
|
||||||
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
|
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
|
||||||
minimum_valid_until_ms,
|
minimum_valid_until_ms,
|
||||||
request_name=event.event_id,
|
|
||||||
key_ids=key_ids,
|
key_ids=key_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_fetch_key_request(self) -> "_FetchKeyRequest":
|
||||||
|
"""Create a key fetch request for all keys needed to satisfy the
|
||||||
|
verification request.
|
||||||
|
"""
|
||||||
|
return _FetchKeyRequest(
|
||||||
|
server_name=self.server_name,
|
||||||
|
minimum_valid_until_ts=self.minimum_valid_until_ts,
|
||||||
|
key_ids=self.key_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class KeyLookupError(ValueError):
|
class KeyLookupError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class _FetchKeyRequest:
|
||||||
|
"""A request for keys for a given server.
|
||||||
|
|
||||||
|
We will continue to try and fetch until we have all the keys listed under
|
||||||
|
`key_ids` (with an appropriate `valid_until_ts` property) or we run out of
|
||||||
|
places to fetch keys from.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
server_name: The name of the server that owns the keys.
|
||||||
|
minimum_valid_until_ts: The timestamp which the keys must be valid until.
|
||||||
|
key_ids: The IDs of the keys to attempt to fetch
|
||||||
|
"""
|
||||||
|
|
||||||
|
server_name = attr.ib(type=str)
|
||||||
|
minimum_valid_until_ts = attr.ib(type=int)
|
||||||
|
key_ids = attr.ib(type=List[str])
|
||||||
|
|
||||||
|
|
||||||
class Keyring:
|
class Keyring:
|
||||||
|
"""Handles verifying signed JSON objects and fetching the keys needed to do
|
||||||
|
so.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
|
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
|
||||||
):
|
):
|
||||||
|
@ -158,22 +170,22 @@ class Keyring:
|
||||||
)
|
)
|
||||||
self._key_fetchers = key_fetchers
|
self._key_fetchers = key_fetchers
|
||||||
|
|
||||||
# map from server name to Deferred. Has an entry for each server with
|
self._server_queue = BatchingQueue(
|
||||||
# an ongoing key download; the Deferred completes once the download
|
"keyring_server",
|
||||||
# completes.
|
clock=hs.get_clock(),
|
||||||
#
|
process_batch_callback=self._inner_fetch_key_requests,
|
||||||
# These are regular, logcontext-agnostic Deferreds.
|
) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]]
|
||||||
self.key_downloads = {} # type: Dict[str, defer.Deferred]
|
|
||||||
|
|
||||||
def verify_json_for_server(
|
async def verify_json_for_server(
|
||||||
self,
|
self,
|
||||||
server_name: str,
|
server_name: str,
|
||||||
json_object: JsonDict,
|
json_object: JsonDict,
|
||||||
validity_time: int,
|
validity_time: int,
|
||||||
request_name: str,
|
) -> None:
|
||||||
) -> defer.Deferred:
|
|
||||||
"""Verify that a JSON object has been signed by a given server
|
"""Verify that a JSON object has been signed by a given server
|
||||||
|
|
||||||
|
Completes if the the object was correctly signed, otherwise raises.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name: name of the server which must have signed this object
|
server_name: name of the server which must have signed this object
|
||||||
|
|
||||||
|
@ -181,52 +193,45 @@ class Keyring:
|
||||||
|
|
||||||
validity_time: timestamp at which we require the signing key to
|
validity_time: timestamp at which we require the signing key to
|
||||||
be valid. (0 implies we don't care)
|
be valid. (0 implies we don't care)
|
||||||
|
|
||||||
request_name: an identifier for this json object (eg, an event id)
|
|
||||||
for logging.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[None]: completes if the the object was correctly signed, otherwise
|
|
||||||
errbacks with an error
|
|
||||||
"""
|
"""
|
||||||
request = VerifyJsonRequest.from_json_object(
|
request = VerifyJsonRequest.from_json_object(
|
||||||
server_name,
|
server_name,
|
||||||
json_object,
|
json_object,
|
||||||
validity_time,
|
validity_time,
|
||||||
request_name,
|
|
||||||
)
|
)
|
||||||
requests = (request,)
|
return await self.process_request(request)
|
||||||
return make_deferred_yieldable(self._verify_objects(requests)[0])
|
|
||||||
|
|
||||||
def verify_json_objects_for_server(
|
def verify_json_objects_for_server(
|
||||||
self, server_and_json: Iterable[Tuple[str, dict, int, str]]
|
self, server_and_json: Iterable[Tuple[str, dict, int]]
|
||||||
) -> List[defer.Deferred]:
|
) -> List[defer.Deferred]:
|
||||||
"""Bulk verifies signatures of json objects, bulk fetching keys as
|
"""Bulk verifies signatures of json objects, bulk fetching keys as
|
||||||
necessary.
|
necessary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_and_json:
|
server_and_json:
|
||||||
Iterable of (server_name, json_object, validity_time, request_name)
|
Iterable of (server_name, json_object, validity_time)
|
||||||
tuples.
|
tuples.
|
||||||
|
|
||||||
validity_time is a timestamp at which the signing key must be
|
validity_time is a timestamp at which the signing key must be
|
||||||
valid.
|
valid.
|
||||||
|
|
||||||
request_name is an identifier for this json object (eg, an event id)
|
|
||||||
for logging.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
List<Deferred[None]>: for each input triplet, a deferred indicating success
|
||||||
or failure to verify each json object's signature for the given
|
or failure to verify each json object's signature for the given
|
||||||
server_name. The deferreds run their callbacks in the sentinel
|
server_name. The deferreds run their callbacks in the sentinel
|
||||||
logcontext.
|
logcontext.
|
||||||
"""
|
"""
|
||||||
return self._verify_objects(
|
return [
|
||||||
VerifyJsonRequest.from_json_object(
|
run_in_background(
|
||||||
server_name, json_object, validity_time, request_name
|
self.process_request,
|
||||||
|
VerifyJsonRequest.from_json_object(
|
||||||
|
server_name,
|
||||||
|
json_object,
|
||||||
|
validity_time,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for server_name, json_object, validity_time, request_name in server_and_json
|
for server_name, json_object, validity_time in server_and_json
|
||||||
)
|
]
|
||||||
|
|
||||||
def verify_events_for_server(
|
def verify_events_for_server(
|
||||||
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
|
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
|
||||||
|
@ -252,321 +257,223 @@ class Keyring:
|
||||||
server_name. The deferreds run their callbacks in the sentinel
|
server_name. The deferreds run their callbacks in the sentinel
|
||||||
logcontext.
|
logcontext.
|
||||||
"""
|
"""
|
||||||
return self._verify_objects(
|
return [
|
||||||
VerifyJsonRequest.from_event(server_name, event, validity_time)
|
run_in_background(
|
||||||
|
self.process_request,
|
||||||
|
VerifyJsonRequest.from_event(
|
||||||
|
server_name,
|
||||||
|
event,
|
||||||
|
validity_time,
|
||||||
|
),
|
||||||
|
)
|
||||||
for server_name, event, validity_time in server_and_events
|
for server_name, event, validity_time in server_and_events
|
||||||
|
]
|
||||||
|
|
||||||
|
async def process_request(self, verify_request: VerifyJsonRequest) -> None:
|
||||||
|
"""Processes the `VerifyJsonRequest`. Raises if the object is not signed
|
||||||
|
by the server, the signatures don't match or we failed to fetch the
|
||||||
|
necessary keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not verify_request.key_ids:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
f"Not signed by {verify_request.server_name}",
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the keys we need to verify to the queue for retrieval. We queue
|
||||||
|
# up requests for the same server so we don't end up with many in flight
|
||||||
|
# requests for the same keys.
|
||||||
|
key_request = verify_request.to_fetch_key_request()
|
||||||
|
found_keys_by_server = await self._server_queue.add_to_queue(
|
||||||
|
key_request, key=verify_request.server_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def _verify_objects(
|
# Since we batch up requests the returned set of keys may contain keys
|
||||||
self, verify_requests: Iterable[VerifyJsonRequest]
|
# from other servers, so we pull out only the ones we care about.s
|
||||||
) -> List[defer.Deferred]:
|
found_keys = found_keys_by_server.get(verify_request.server_name, {})
|
||||||
"""Does the work of verify_json_[objects_]for_server
|
|
||||||
|
|
||||||
|
# Verify each signature we got valid keys for, raising if we can't
|
||||||
|
# verify any of them.
|
||||||
|
verified = False
|
||||||
|
for key_id in verify_request.key_ids:
|
||||||
|
key_result = found_keys.get(key_id)
|
||||||
|
if not key_result:
|
||||||
|
continue
|
||||||
|
|
||||||
Args:
|
if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
|
||||||
verify_requests: Iterable of verification requests.
|
continue
|
||||||
|
|
||||||
Returns:
|
verify_key = key_result.verify_key
|
||||||
List<Deferred[None]>: for each input item, a deferred indicating success
|
json_object = verify_request.get_json_object()
|
||||||
or failure to verify each json object's signature for the given
|
try:
|
||||||
server_name. The deferreds run their callbacks in the sentinel
|
verify_signed_json(
|
||||||
logcontext.
|
json_object,
|
||||||
"""
|
verify_request.server_name,
|
||||||
# a list of VerifyJsonRequests which are awaiting a key lookup
|
verify_key,
|
||||||
key_lookups = []
|
)
|
||||||
handle = preserve_fn(_handle_key_deferred)
|
verified = True
|
||||||
|
except SignatureVerifyException as e:
|
||||||
def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
|
logger.debug(
|
||||||
"""Process an entry in the request list
|
"Error verifying signature for %s:%s:%s with key %s: %s",
|
||||||
|
verify_request.server_name,
|
||||||
Adds a key request to key_lookups, and returns a deferred which
|
verify_key.alg,
|
||||||
will complete or fail (in the sentinel context) when verification completes.
|
verify_key.version,
|
||||||
"""
|
encode_verify_key_base64(verify_key),
|
||||||
if not verify_request.key_ids:
|
str(e),
|
||||||
return defer.fail(
|
)
|
||||||
SynapseError(
|
raise SynapseError(
|
||||||
400,
|
401,
|
||||||
"Not signed by %s" % (verify_request.server_name,),
|
"Invalid signature for server %s with key %s:%s: %s"
|
||||||
Codes.UNAUTHORIZED,
|
% (
|
||||||
)
|
verify_request.server_name,
|
||||||
|
verify_key.alg,
|
||||||
|
verify_key.version,
|
||||||
|
str(e),
|
||||||
|
),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
if not verified:
|
||||||
"Verifying %s for %s with key_ids %s, min_validity %i",
|
raise SynapseError(
|
||||||
verify_request.request_name,
|
401,
|
||||||
|
f"Failed to find any key to satisfy: {key_request}",
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _inner_fetch_key_requests(
|
||||||
|
self, requests: List[_FetchKeyRequest]
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
|
"""Processing function for the queue of `_FetchKeyRequest`."""
|
||||||
|
|
||||||
|
logger.debug("Starting fetch for %s", requests)
|
||||||
|
|
||||||
|
# First we need to deduplicate requests for the same key. We do this by
|
||||||
|
# taking the *maximum* requested `minimum_valid_until_ts` for each pair
|
||||||
|
# of server name/key ID.
|
||||||
|
server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]]
|
||||||
|
for request in requests:
|
||||||
|
by_server = server_to_key_to_ts.setdefault(request.server_name, {})
|
||||||
|
for key_id in request.key_ids:
|
||||||
|
existing_ts = by_server.get(key_id, 0)
|
||||||
|
by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)
|
||||||
|
|
||||||
|
deduped_requests = [
|
||||||
|
_FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
|
||||||
|
for server_name, by_server in server_to_key_to_ts.items()
|
||||||
|
for key_id, minimum_valid_ts in by_server.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.debug("Deduplicated key requests to %s", deduped_requests)
|
||||||
|
|
||||||
|
# For each key we call `_inner_verify_request` which will handle
|
||||||
|
# fetching each key. Note these shouldn't throw if we fail to contact
|
||||||
|
# other servers etc.
|
||||||
|
results_per_request = await yieldable_gather_results(
|
||||||
|
self._inner_fetch_key_request,
|
||||||
|
deduped_requests,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We now convert the returned list of results into a map from server
|
||||||
|
# name to key ID to FetchKeyResult, to return.
|
||||||
|
to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]]
|
||||||
|
for (request, results) in zip(deduped_requests, results_per_request):
|
||||||
|
to_return_by_server = to_return.setdefault(request.server_name, {})
|
||||||
|
for key_id, key_result in results.items():
|
||||||
|
existing = to_return_by_server.get(key_id)
|
||||||
|
if not existing or existing.valid_until_ts < key_result.valid_until_ts:
|
||||||
|
to_return_by_server[key_id] = key_result
|
||||||
|
|
||||||
|
return to_return
|
||||||
|
|
||||||
|
async def _inner_fetch_key_request(
|
||||||
|
self, verify_request: _FetchKeyRequest
|
||||||
|
) -> Dict[str, FetchKeyResult]:
|
||||||
|
"""Attempt to fetch the given key by calling each key fetcher one by
|
||||||
|
one.
|
||||||
|
"""
|
||||||
|
logger.debug("Starting fetch for %s", verify_request)
|
||||||
|
|
||||||
|
found_keys: Dict[str, FetchKeyResult] = {}
|
||||||
|
missing_key_ids = set(verify_request.key_ids)
|
||||||
|
|
||||||
|
for fetcher in self._key_fetchers:
|
||||||
|
if not missing_key_ids:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.debug("Getting keys from %s for %s", fetcher, verify_request)
|
||||||
|
keys = await fetcher.get_keys(
|
||||||
verify_request.server_name,
|
verify_request.server_name,
|
||||||
verify_request.key_ids,
|
list(missing_key_ids),
|
||||||
verify_request.minimum_valid_until_ts,
|
verify_request.minimum_valid_until_ts,
|
||||||
)
|
)
|
||||||
|
|
||||||
# add the key request to the queue, but don't start it off yet.
|
for key_id, key in keys.items():
|
||||||
key_lookups.append(verify_request)
|
if not key:
|
||||||
|
|
||||||
# now run _handle_key_deferred, which will wait for the key request
|
|
||||||
# to complete and then do the verification.
|
|
||||||
#
|
|
||||||
# We want _handle_key_request to log to the right context, so we
|
|
||||||
# wrap it with preserve_fn (aka run_in_background)
|
|
||||||
return handle(verify_request)
|
|
||||||
|
|
||||||
results = [process(r) for r in verify_requests]
|
|
||||||
|
|
||||||
if key_lookups:
|
|
||||||
run_in_background(self._start_key_lookups, key_lookups)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def _start_key_lookups(
|
|
||||||
self, verify_requests: List[VerifyJsonRequest]
|
|
||||||
) -> None:
|
|
||||||
"""Sets off the key fetches for each verify request
|
|
||||||
|
|
||||||
Once each fetch completes, verify_request.key_ready will be resolved.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
verify_requests:
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# map from server name to a set of outstanding request ids
|
|
||||||
server_to_request_ids = {} # type: Dict[str, Set[int]]
|
|
||||||
|
|
||||||
for verify_request in verify_requests:
|
|
||||||
server_name = verify_request.server_name
|
|
||||||
request_id = id(verify_request)
|
|
||||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
|
||||||
|
|
||||||
# Wait for any previous lookups to complete before proceeding.
|
|
||||||
await self.wait_for_previous_lookups(server_to_request_ids.keys())
|
|
||||||
|
|
||||||
# take out a lock on each of the servers by sticking a Deferred in
|
|
||||||
# key_downloads
|
|
||||||
for server_name in server_to_request_ids.keys():
|
|
||||||
self.key_downloads[server_name] = defer.Deferred()
|
|
||||||
logger.debug("Got key lookup lock on %s", server_name)
|
|
||||||
|
|
||||||
# When we've finished fetching all the keys for a given server_name,
|
|
||||||
# drop the lock by resolving the deferred in key_downloads.
|
|
||||||
def drop_server_lock(server_name):
|
|
||||||
d = self.key_downloads.pop(server_name)
|
|
||||||
d.callback(None)
|
|
||||||
|
|
||||||
def lookup_done(res, verify_request):
|
|
||||||
server_name = verify_request.server_name
|
|
||||||
server_requests = server_to_request_ids[server_name]
|
|
||||||
server_requests.remove(id(verify_request))
|
|
||||||
|
|
||||||
# if there are no more requests for this server, we can drop the lock.
|
|
||||||
if not server_requests:
|
|
||||||
logger.debug("Releasing key lookup lock on %s", server_name)
|
|
||||||
drop_server_lock(server_name)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
for verify_request in verify_requests:
|
|
||||||
verify_request.key_ready.addBoth(lookup_done, verify_request)
|
|
||||||
|
|
||||||
# Actually start fetching keys.
|
|
||||||
self._get_server_verify_keys(verify_requests)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error starting key lookups")
|
|
||||||
|
|
||||||
async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
|
|
||||||
"""Waits for any previous key lookups for the given servers to finish.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
server_names: list of servers which we want to look up
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Resolves once all key lookups for the given servers have
|
|
||||||
completed. Follows the synapse rules of logcontext preservation.
|
|
||||||
"""
|
|
||||||
loop_count = 1
|
|
||||||
while True:
|
|
||||||
wait_on = [
|
|
||||||
(server_name, self.key_downloads[server_name])
|
|
||||||
for server_name in server_names
|
|
||||||
if server_name in self.key_downloads
|
|
||||||
]
|
|
||||||
if not wait_on:
|
|
||||||
break
|
|
||||||
logger.info(
|
|
||||||
"Waiting for existing lookups for %s to complete [loop %i]",
|
|
||||||
[w[0] for w in wait_on],
|
|
||||||
loop_count,
|
|
||||||
)
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
await defer.DeferredList((w[1] for w in wait_on))
|
|
||||||
|
|
||||||
loop_count += 1
|
|
||||||
|
|
||||||
def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
|
|
||||||
"""Tries to find at least one key for each verify request
|
|
||||||
|
|
||||||
For each verify_request, verify_request.key_ready is called back with
|
|
||||||
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
|
|
||||||
with a SynapseError if none of the keys are found.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
verify_requests: list of verify requests
|
|
||||||
"""
|
|
||||||
|
|
||||||
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
|
|
||||||
|
|
||||||
async def do_iterations():
|
|
||||||
try:
|
|
||||||
with Measure(self.clock, "get_server_verify_keys"):
|
|
||||||
for f in self._key_fetchers:
|
|
||||||
if not remaining_requests:
|
|
||||||
return
|
|
||||||
await self._attempt_key_fetches_with_fetcher(
|
|
||||||
f, remaining_requests
|
|
||||||
)
|
|
||||||
|
|
||||||
# look for any requests which weren't satisfied
|
|
||||||
while remaining_requests:
|
|
||||||
verify_request = remaining_requests.pop()
|
|
||||||
rq_str = (
|
|
||||||
"VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
|
|
||||||
% (
|
|
||||||
verify_request.server_name,
|
|
||||||
verify_request.key_ids,
|
|
||||||
verify_request.minimum_valid_until_ts,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# If we run the errback immediately, it may cancel our
|
|
||||||
# loggingcontext while we are still in it, so instead we
|
|
||||||
# schedule it for the next time round the reactor.
|
|
||||||
#
|
|
||||||
# (this also ensures that we don't get a stack overflow if we
|
|
||||||
# has a massive queue of lookups waiting for this server).
|
|
||||||
self.clock.call_later(
|
|
||||||
0,
|
|
||||||
verify_request.key_ready.errback,
|
|
||||||
SynapseError(
|
|
||||||
401,
|
|
||||||
"Failed to find any key to satisfy %s" % (rq_str,),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as err:
|
|
||||||
# we don't really expect to get here, because any errors should already
|
|
||||||
# have been caught and logged. But if we do, let's log the error and make
|
|
||||||
# sure that all of the deferreds are resolved.
|
|
||||||
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
for verify_request in remaining_requests:
|
|
||||||
if not verify_request.key_ready.called:
|
|
||||||
verify_request.key_ready.errback(err)
|
|
||||||
|
|
||||||
run_in_background(do_iterations)
|
|
||||||
|
|
||||||
async def _attempt_key_fetches_with_fetcher(
|
|
||||||
self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
|
|
||||||
):
|
|
||||||
"""Use a key fetcher to attempt to satisfy some key requests
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fetcher: fetcher to use to fetch the keys
|
|
||||||
remaining_requests: outstanding key requests.
|
|
||||||
Any successfully-completed requests will be removed from the list.
|
|
||||||
"""
|
|
||||||
# The keys to fetch.
|
|
||||||
# server_name -> key_id -> min_valid_ts
|
|
||||||
missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
|
|
||||||
|
|
||||||
for verify_request in remaining_requests:
|
|
||||||
# any completed requests should already have been removed
|
|
||||||
assert not verify_request.key_ready.called
|
|
||||||
keys_for_server = missing_keys[verify_request.server_name]
|
|
||||||
|
|
||||||
for key_id in verify_request.key_ids:
|
|
||||||
# If we have several requests for the same key, then we only need to
|
|
||||||
# request that key once, but we should do so with the greatest
|
|
||||||
# min_valid_until_ts of the requests, so that we can satisfy all of
|
|
||||||
# the requests.
|
|
||||||
keys_for_server[key_id] = max(
|
|
||||||
keys_for_server.get(key_id, -1),
|
|
||||||
verify_request.minimum_valid_until_ts,
|
|
||||||
)
|
|
||||||
|
|
||||||
results = await fetcher.get_keys(missing_keys)
|
|
||||||
|
|
||||||
completed = []
|
|
||||||
for verify_request in remaining_requests:
|
|
||||||
server_name = verify_request.server_name
|
|
||||||
|
|
||||||
# see if any of the keys we got this time are sufficient to
|
|
||||||
# complete this VerifyJsonRequest.
|
|
||||||
result_keys = results.get(server_name, {})
|
|
||||||
for key_id in verify_request.key_ids:
|
|
||||||
fetch_key_result = result_keys.get(key_id)
|
|
||||||
if not fetch_key_result:
|
|
||||||
# we didn't get a result for this key
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
# If we already have a result for the given key ID we keep the
|
||||||
fetch_key_result.valid_until_ts
|
# one with the highest `valid_until_ts`.
|
||||||
< verify_request.minimum_valid_until_ts
|
existing_key = found_keys.get(key_id)
|
||||||
):
|
if existing_key:
|
||||||
# key was not valid at this point
|
if key.valid_until_ts <= existing_key.valid_until_ts:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# we have a valid key for this request. If we run the callback
|
# We always store the returned key even if it doesn't the
|
||||||
# immediately, it may cancel our loggingcontext while we are still in
|
# `minimum_valid_until_ts` requirement, as some verification
|
||||||
# it, so instead we schedule it for the next time round the reactor.
|
# requests may still be able to be satisfied by it.
|
||||||
#
|
#
|
||||||
# (this also ensures that we don't get a stack overflow if we had
|
# We still keep looking for the key from other fetchers in that
|
||||||
# a massive queue of lookups waiting for this server).
|
# case though.
|
||||||
logger.debug(
|
found_keys[key_id] = key
|
||||||
"Found key %s:%s for %s",
|
|
||||||
server_name,
|
|
||||||
key_id,
|
|
||||||
verify_request.request_name,
|
|
||||||
)
|
|
||||||
self.clock.call_later(
|
|
||||||
0,
|
|
||||||
verify_request.key_ready.callback,
|
|
||||||
(server_name, key_id, fetch_key_result.verify_key),
|
|
||||||
)
|
|
||||||
completed.append(verify_request)
|
|
||||||
break
|
|
||||||
|
|
||||||
remaining_requests.difference_update(completed)
|
if key.valid_until_ts < verify_request.minimum_valid_until_ts:
|
||||||
|
continue
|
||||||
|
|
||||||
|
missing_key_ids.discard(key_id)
|
||||||
|
|
||||||
|
return found_keys
|
||||||
|
|
||||||
|
|
||||||
class KeyFetcher(metaclass=abc.ABCMeta):
|
class KeyFetcher(metaclass=abc.ABCMeta):
|
||||||
@abc.abstractmethod
|
def __init__(self, hs: "HomeServer"):
|
||||||
async def get_keys(
|
self._queue = BatchingQueue(
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
self.__class__.__name__, hs.get_clock(), self._fetch_keys
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
)
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
keys_to_fetch:
|
|
||||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
|
||||||
|
|
||||||
Returns:
|
async def get_keys(
|
||||||
Map from server_name -> key_id -> FetchKeyResult
|
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
"""
|
) -> Dict[str, FetchKeyResult]:
|
||||||
raise NotImplementedError
|
results = await self._queue.add_to_queue(
|
||||||
|
_FetchKeyRequest(
|
||||||
|
server_name=server_name,
|
||||||
|
key_ids=key_ids,
|
||||||
|
minimum_valid_until_ts=minimum_valid_until_ts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results.get(server_name, {})
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def _fetch_keys(
|
||||||
|
self, keys_to_fetch: List[_FetchKeyRequest]
|
||||||
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StoreKeyFetcher(KeyFetcher):
|
class StoreKeyFetcher(KeyFetcher):
|
||||||
"""KeyFetcher impl which fetches keys from our data store"""
|
"""KeyFetcher impl which fetches keys from our data store"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__(hs)
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
async def get_keys(
|
async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]):
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
|
||||||
"""see KeyFetcher.get_keys"""
|
|
||||||
|
|
||||||
key_ids_to_fetch = (
|
key_ids_to_fetch = (
|
||||||
(server_name, key_id)
|
(queue_value.server_name, key_id)
|
||||||
for server_name, keys_for_server in keys_to_fetch.items()
|
for queue_value in keys_to_fetch
|
||||||
for key_id in keys_for_server.keys()
|
for key_id in queue_value.key_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
|
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
|
||||||
|
@ -578,6 +485,8 @@ class StoreKeyFetcher(KeyFetcher):
|
||||||
|
|
||||||
class BaseV2KeyFetcher(KeyFetcher):
|
class BaseV2KeyFetcher(KeyFetcher):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__(hs)
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
|
||||||
|
@ -685,10 +594,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
self.client = hs.get_federation_http_client()
|
self.client = hs.get_federation_http_client()
|
||||||
self.key_servers = self.config.key_servers
|
self.key_servers = self.config.key_servers
|
||||||
|
|
||||||
async def get_keys(
|
async def _fetch_keys(
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
self, keys_to_fetch: List[_FetchKeyRequest]
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""see KeyFetcher.get_keys"""
|
"""see KeyFetcher._fetch_keys"""
|
||||||
|
|
||||||
async def get_key(key_server: TrustedKeyServer) -> Dict:
|
async def get_key(key_server: TrustedKeyServer) -> Dict:
|
||||||
try:
|
try:
|
||||||
|
@ -724,12 +633,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
return union_of_keys
|
return union_of_keys
|
||||||
|
|
||||||
async def get_server_verify_key_v2_indirect(
|
async def get_server_verify_key_v2_indirect(
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
|
self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
keys_to_fetch:
|
keys_to_fetch:
|
||||||
the keys to be fetched. server_name -> key_id -> min_valid_ts
|
the keys to be fetched.
|
||||||
|
|
||||||
key_server: notary server to query for the keys
|
key_server: notary server to query for the keys
|
||||||
|
|
||||||
|
@ -743,7 +652,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
perspective_name = key_server.server_name
|
perspective_name = key_server.server_name
|
||||||
logger.info(
|
logger.info(
|
||||||
"Requesting keys %s from notary server %s",
|
"Requesting keys %s from notary server %s",
|
||||||
keys_to_fetch.items(),
|
keys_to_fetch,
|
||||||
perspective_name,
|
perspective_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -753,11 +662,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
||||||
path="/_matrix/key/v2/query",
|
path="/_matrix/key/v2/query",
|
||||||
data={
|
data={
|
||||||
"server_keys": {
|
"server_keys": {
|
||||||
server_name: {
|
queue_value.server_name: {
|
||||||
key_id: {"minimum_valid_until_ts": min_valid_ts}
|
key_id: {
|
||||||
for key_id, min_valid_ts in server_keys.items()
|
"minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
|
||||||
|
}
|
||||||
|
for key_id in queue_value.key_ids
|
||||||
}
|
}
|
||||||
for server_name, server_keys in keys_to_fetch.items()
|
for queue_value in keys_to_fetch
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -858,7 +769,20 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
self.client = hs.get_federation_http_client()
|
self.client = hs.get_federation_http_client()
|
||||||
|
|
||||||
async def get_keys(
|
async def get_keys(
|
||||||
self, keys_to_fetch: Dict[str, Dict[str, int]]
|
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
|
) -> Dict[str, FetchKeyResult]:
|
||||||
|
results = await self._queue.add_to_queue(
|
||||||
|
_FetchKeyRequest(
|
||||||
|
server_name=server_name,
|
||||||
|
key_ids=key_ids,
|
||||||
|
minimum_valid_until_ts=minimum_valid_until_ts,
|
||||||
|
),
|
||||||
|
key=server_name,
|
||||||
|
)
|
||||||
|
return results.get(server_name, {})
|
||||||
|
|
||||||
|
async def _fetch_keys(
|
||||||
|
self, keys_to_fetch: List[_FetchKeyRequest]
|
||||||
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
) -> Dict[str, Dict[str, FetchKeyResult]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -871,8 +795,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
|
async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
|
||||||
server_name, key_ids = key_to_fetch_item
|
server_name = key_to_fetch_item.server_name
|
||||||
|
key_ids = key_to_fetch_item.key_ids
|
||||||
|
|
||||||
try:
|
try:
|
||||||
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
|
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
|
||||||
results[server_name] = keys
|
results[server_name] = keys
|
||||||
|
@ -883,7 +809,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
||||||
|
|
||||||
await yieldable_gather_results(get_key, keys_to_fetch.items())
|
await yieldable_gather_results(get_key, keys_to_fetch)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def get_server_verify_key_v2_direct(
|
async def get_server_verify_key_v2_direct(
|
||||||
|
@ -955,37 +881,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
||||||
keys.update(response_keys)
|
keys.update(response_keys)
|
||||||
|
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
|
|
||||||
"""Waits for the key to become available, and then performs a verification
|
|
||||||
|
|
||||||
Args:
|
|
||||||
verify_request:
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
SynapseError if there was a problem performing the verification
|
|
||||||
"""
|
|
||||||
server_name = verify_request.server_name
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
_, key_id, verify_key = await verify_request.key_ready
|
|
||||||
|
|
||||||
json_object = verify_request.get_json_object()
|
|
||||||
|
|
||||||
try:
|
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
|
||||||
except SignatureVerifyException as e:
|
|
||||||
logger.debug(
|
|
||||||
"Error verifying signature for %s:%s:%s with key %s: %s",
|
|
||||||
server_name,
|
|
||||||
verify_key.alg,
|
|
||||||
verify_key.version,
|
|
||||||
encode_verify_key_base64(verify_key),
|
|
||||||
str(e),
|
|
||||||
)
|
|
||||||
raise SynapseError(
|
|
||||||
401,
|
|
||||||
"Invalid signature for server %s with key %s:%s: %s"
|
|
||||||
% (server_name, verify_key.alg, verify_key.version, str(e)),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
|
|
|
@ -37,6 +37,7 @@ from synapse.http.servlet import (
|
||||||
)
|
)
|
||||||
from synapse.logging.context import run_in_background
|
from synapse.logging.context import run_in_background
|
||||||
from synapse.logging.opentracing import (
|
from synapse.logging.opentracing import (
|
||||||
|
SynapseTags,
|
||||||
start_active_span,
|
start_active_span,
|
||||||
start_active_span_from_request,
|
start_active_span_from_request,
|
||||||
tags,
|
tags,
|
||||||
|
@ -151,7 +152,9 @@ class Authenticator:
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.keyring.verify_json_for_server(
|
await self.keyring.verify_json_for_server(
|
||||||
origin, json_request, now, "Incoming request"
|
origin,
|
||||||
|
json_request,
|
||||||
|
now,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Request from %s", origin)
|
logger.debug("Request from %s", origin)
|
||||||
|
@ -314,7 +317,7 @@ class BaseFederationServlet:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
request_tags = {
|
request_tags = {
|
||||||
"request_id": request.get_request_id(),
|
SynapseTags.REQUEST_ID: request.get_request_id(),
|
||||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||||
tags.HTTP_METHOD: request.get_method(),
|
tags.HTTP_METHOD: request.get_method(),
|
||||||
tags.HTTP_URL: request.get_redacted_uri(),
|
tags.HTTP_URL: request.get_redacted_uri(),
|
||||||
|
|
|
@ -108,7 +108,9 @@ class GroupAttestationSigning:
|
||||||
|
|
||||||
assert server_name is not None
|
assert server_name is not None
|
||||||
await self.keyring.verify_json_for_server(
|
await self.keyring.verify_json_for_server(
|
||||||
server_name, attestation, now, "Group attestation"
|
server_name,
|
||||||
|
attestation,
|
||||||
|
now,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
|
def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
|
||||||
|
|
|
@ -577,7 +577,9 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# Fetch the state events from the DB, and check we have the auth events.
|
# Fetch the state events from the DB, and check we have the auth events.
|
||||||
event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
|
event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
|
||||||
auth_events_in_store = await self.store.have_seen_events(auth_event_ids)
|
auth_events_in_store = await self.store.have_seen_events(
|
||||||
|
room_id, auth_event_ids
|
||||||
|
)
|
||||||
|
|
||||||
# Check for missing events. We handle state and auth event seperately,
|
# Check for missing events. We handle state and auth event seperately,
|
||||||
# as we want to pull the state from the DB, but we don't for the auth
|
# as we want to pull the state from the DB, but we don't for the auth
|
||||||
|
@ -610,7 +612,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
if missing_auth_events:
|
if missing_auth_events:
|
||||||
auth_events_in_store = await self.store.have_seen_events(
|
auth_events_in_store = await self.store.have_seen_events(
|
||||||
missing_auth_events
|
room_id, missing_auth_events
|
||||||
)
|
)
|
||||||
missing_auth_events.difference_update(auth_events_in_store)
|
missing_auth_events.difference_update(auth_events_in_store)
|
||||||
|
|
||||||
|
@ -710,7 +712,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
|
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
|
||||||
missing_auth_events.difference_update(
|
missing_auth_events.difference_update(
|
||||||
await self.store.have_seen_events(missing_auth_events)
|
await self.store.have_seen_events(room_id, missing_auth_events)
|
||||||
)
|
)
|
||||||
logger.debug("We are also missing %i auth events", len(missing_auth_events))
|
logger.debug("We are also missing %i auth events", len(missing_auth_events))
|
||||||
|
|
||||||
|
@ -2475,7 +2477,7 @@ class FederationHandler(BaseHandler):
|
||||||
#
|
#
|
||||||
# we start by checking if they are in the store, and then try calling /event_auth/.
|
# we start by checking if they are in the store, and then try calling /event_auth/.
|
||||||
if missing_auth:
|
if missing_auth:
|
||||||
have_events = await self.store.have_seen_events(missing_auth)
|
have_events = await self.store.have_seen_events(event.room_id, missing_auth)
|
||||||
logger.debug("Events %s are in the store", have_events)
|
logger.debug("Events %s are in the store", have_events)
|
||||||
missing_auth.difference_update(have_events)
|
missing_auth.difference_update(have_events)
|
||||||
|
|
||||||
|
@ -2494,7 +2496,7 @@ class FederationHandler(BaseHandler):
|
||||||
return context
|
return context
|
||||||
|
|
||||||
seen_remotes = await self.store.have_seen_events(
|
seen_remotes = await self.store.have_seen_events(
|
||||||
[e.event_id for e in remote_auth_chain]
|
event.room_id, [e.event_id for e in remote_auth_chain]
|
||||||
)
|
)
|
||||||
|
|
||||||
for e in remote_auth_chain:
|
for e in remote_auth_chain:
|
||||||
|
|
|
@ -26,7 +26,6 @@ from synapse.api.constants import (
|
||||||
HistoryVisibility,
|
HistoryVisibility,
|
||||||
Membership,
|
Membership,
|
||||||
)
|
)
|
||||||
from synapse.api.errors import AuthError
|
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.utils import format_event_for_client_v2
|
from synapse.events.utils import format_event_for_client_v2
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
@ -456,16 +455,16 @@ class SpaceSummaryHandler:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Otherwise, check if they should be allowed access via membership in a space.
|
# Otherwise, check if they should be allowed access via membership in a space.
|
||||||
try:
|
if self._event_auth_handler.has_restricted_join_rules(
|
||||||
await self._event_auth_handler.check_restricted_join_rules(
|
state_ids, room_version
|
||||||
state_ids, room_version, requester, member_event
|
):
|
||||||
|
allowed_spaces = (
|
||||||
|
await self._event_auth_handler.get_spaces_that_allow_join(state_ids)
|
||||||
)
|
)
|
||||||
except AuthError:
|
if await self._event_auth_handler.is_user_in_rooms(
|
||||||
# The user doesn't have access due to spaces, but might have access
|
allowed_spaces, requester
|
||||||
# another way. Keep trying.
|
):
|
||||||
pass
|
return True
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# If this is a request over federation, check if the host is in the room or
|
# If this is a request over federation, check if the host is in the room or
|
||||||
# is in one of the spaces specified via the join rules.
|
# is in one of the spaces specified via the join rules.
|
||||||
|
|
|
@ -464,7 +464,7 @@ class SyncHandler:
|
||||||
# ensure that we always include current state in the timeline
|
# ensure that we always include current state in the timeline
|
||||||
current_state_ids = frozenset() # type: FrozenSet[str]
|
current_state_ids = frozenset() # type: FrozenSet[str]
|
||||||
if any(e.is_state() for e in recents):
|
if any(e.is_state() for e in recents):
|
||||||
current_state_ids_map = await self.state.get_current_state_ids(
|
current_state_ids_map = await self.store.get_current_state_ids(
|
||||||
room_id
|
room_id
|
||||||
)
|
)
|
||||||
current_state_ids = frozenset(current_state_ids_map.values())
|
current_state_ids = frozenset(current_state_ids_map.values())
|
||||||
|
@ -524,7 +524,7 @@ class SyncHandler:
|
||||||
# ensure that we always include current state in the timeline
|
# ensure that we always include current state in the timeline
|
||||||
current_state_ids = frozenset()
|
current_state_ids = frozenset()
|
||||||
if any(e.is_state() for e in loaded_recents):
|
if any(e.is_state() for e in loaded_recents):
|
||||||
current_state_ids_map = await self.state.get_current_state_ids(
|
current_state_ids_map = await self.store.get_current_state_ids(
|
||||||
room_id
|
room_id
|
||||||
)
|
)
|
||||||
current_state_ids = frozenset(current_state_ids_map.values())
|
current_state_ids = frozenset(current_state_ids_map.values())
|
||||||
|
|
|
@ -15,6 +15,9 @@
|
||||||
""" This module contains base REST classes for constructing REST servlets. """
|
""" This module contains base REST classes for constructing REST servlets. """
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Iterable, List, Optional, Union, overload
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
@ -107,12 +110,11 @@ def parse_boolean_from_args(args, name, default=None, required=False):
|
||||||
|
|
||||||
def parse_string(
|
def parse_string(
|
||||||
request,
|
request,
|
||||||
name,
|
name: Union[bytes, str],
|
||||||
default=None,
|
default: Optional[str] = None,
|
||||||
required=False,
|
required: bool = False,
|
||||||
allowed_values=None,
|
allowed_values: Optional[Iterable[str]] = None,
|
||||||
param_type="string",
|
encoding: Optional[str] = "ascii",
|
||||||
encoding="ascii",
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Parse a string parameter from the request query string.
|
Parse a string parameter from the request query string.
|
||||||
|
@ -122,18 +124,17 @@ def parse_string(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: the twisted HTTP request.
|
request: the twisted HTTP request.
|
||||||
name (bytes|unicode): the name of the query parameter.
|
name: the name of the query parameter.
|
||||||
default (bytes|unicode|None): value to use if the parameter is absent,
|
default: value to use if the parameter is absent,
|
||||||
defaults to None. Must be bytes if encoding is None.
|
defaults to None. Must be bytes if encoding is None.
|
||||||
required (bool): whether to raise a 400 SynapseError if the
|
required: whether to raise a 400 SynapseError if the
|
||||||
parameter is absent, defaults to False.
|
parameter is absent, defaults to False.
|
||||||
allowed_values (list[bytes|unicode]): List of allowed values for the
|
allowed_values: List of allowed values for the
|
||||||
string, or None if any value is allowed, defaults to None. Must be
|
string, or None if any value is allowed, defaults to None. Must be
|
||||||
the same type as name, if given.
|
the same type as name, if given.
|
||||||
encoding (str|None): The encoding to decode the string content with.
|
encoding : The encoding to decode the string content with.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bytes/unicode|None: A string value or the default. Unicode if encoding
|
A string value or the default. Unicode if encoding
|
||||||
was given, bytes otherwise.
|
was given, bytes otherwise.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -142,45 +143,105 @@ def parse_string(
|
||||||
is not one of those allowed values.
|
is not one of those allowed values.
|
||||||
"""
|
"""
|
||||||
return parse_string_from_args(
|
return parse_string_from_args(
|
||||||
request.args, name, default, required, allowed_values, param_type, encoding
|
request.args, name, default, required, allowed_values, encoding
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_string_from_args(
|
def _parse_string_value(
|
||||||
args,
|
value: Union[str, bytes],
|
||||||
name,
|
allowed_values: Optional[Iterable[str]],
|
||||||
default=None,
|
name: str,
|
||||||
required=False,
|
encoding: Optional[str],
|
||||||
allowed_values=None,
|
) -> Union[str, bytes]:
|
||||||
param_type="string",
|
if encoding:
|
||||||
encoding="ascii",
|
try:
|
||||||
):
|
value = value.decode(encoding)
|
||||||
|
except ValueError:
|
||||||
|
raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
|
||||||
|
|
||||||
|
if allowed_values is not None and value not in allowed_values:
|
||||||
|
message = "Query parameter %r must be one of [%s]" % (
|
||||||
|
name,
|
||||||
|
", ".join(repr(v) for v in allowed_values),
|
||||||
|
)
|
||||||
|
raise SynapseError(400, message)
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def parse_strings_from_args(
|
||||||
|
args: List[str],
|
||||||
|
name: Union[bytes, str],
|
||||||
|
default: Optional[List[str]] = None,
|
||||||
|
required: bool = False,
|
||||||
|
allowed_values: Optional[Iterable[str]] = None,
|
||||||
|
encoding: Literal[None] = None,
|
||||||
|
) -> Optional[List[bytes]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def parse_strings_from_args(
|
||||||
|
args: List[str],
|
||||||
|
name: Union[bytes, str],
|
||||||
|
default: Optional[List[str]] = None,
|
||||||
|
required: bool = False,
|
||||||
|
allowed_values: Optional[Iterable[str]] = None,
|
||||||
|
encoding: str = "ascii",
|
||||||
|
) -> Optional[List[str]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def parse_strings_from_args(
|
||||||
|
args: List[str],
|
||||||
|
name: Union[bytes, str],
|
||||||
|
default: Optional[List[str]] = None,
|
||||||
|
required: bool = False,
|
||||||
|
allowed_values: Optional[Iterable[str]] = None,
|
||||||
|
encoding: Optional[str] = "ascii",
|
||||||
|
) -> Optional[List[Union[bytes, str]]]:
|
||||||
|
"""
|
||||||
|
Parse a string parameter from the request query string list.
|
||||||
|
|
||||||
|
If encoding is not None, the content of the query param will be
|
||||||
|
decoded to Unicode using the encoding, otherwise it will be encoded
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: the twisted HTTP request.args list.
|
||||||
|
name: the name of the query parameter.
|
||||||
|
default: value to use if the parameter is absent,
|
||||||
|
defaults to None. Must be bytes if encoding is None.
|
||||||
|
required : whether to raise a 400 SynapseError if the
|
||||||
|
parameter is absent, defaults to False.
|
||||||
|
allowed_values (list[bytes|unicode]): List of allowed values for the
|
||||||
|
string, or None if any value is allowed, defaults to None. Must be
|
||||||
|
the same type as name, if given.
|
||||||
|
encoding: The encoding to decode the string content with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string value or the default. Unicode if encoding
|
||||||
|
was given, bytes otherwise.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError if the parameter is absent and required, or if the
|
||||||
|
parameter is present, must be one of a list of allowed values and
|
||||||
|
is not one of those allowed values.
|
||||||
|
"""
|
||||||
|
|
||||||
if not isinstance(name, bytes):
|
if not isinstance(name, bytes):
|
||||||
name = name.encode("ascii")
|
name = name.encode("ascii")
|
||||||
|
|
||||||
if name in args:
|
if name in args:
|
||||||
value = args[name][0]
|
values = args[name]
|
||||||
|
|
||||||
if encoding:
|
return [
|
||||||
try:
|
_parse_string_value(value, allowed_values, name=name, encoding=encoding)
|
||||||
value = value.decode(encoding)
|
for value in values
|
||||||
except ValueError:
|
]
|
||||||
raise SynapseError(
|
|
||||||
400, "Query parameter %r must be %s" % (name, encoding)
|
|
||||||
)
|
|
||||||
|
|
||||||
if allowed_values is not None and value not in allowed_values:
|
|
||||||
message = "Query parameter %r must be one of [%s]" % (
|
|
||||||
name,
|
|
||||||
", ".join(repr(v) for v in allowed_values),
|
|
||||||
)
|
|
||||||
raise SynapseError(400, message)
|
|
||||||
else:
|
|
||||||
return value
|
|
||||||
else:
|
else:
|
||||||
if required:
|
if required:
|
||||||
message = "Missing %s query parameter %r" % (param_type, name)
|
message = "Missing string query parameter %r" % (name)
|
||||||
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
@ -190,6 +251,55 @@ def parse_string_from_args(
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def parse_string_from_args(
|
||||||
|
args: List[str],
|
||||||
|
name: Union[bytes, str],
|
||||||
|
default: Optional[str] = None,
|
||||||
|
required: bool = False,
|
||||||
|
allowed_values: Optional[Iterable[str]] = None,
|
||||||
|
encoding: Optional[str] = "ascii",
|
||||||
|
) -> Optional[Union[bytes, str]]:
|
||||||
|
"""
|
||||||
|
Parse the string parameter from the request query string list
|
||||||
|
and return the first result.
|
||||||
|
|
||||||
|
If encoding is not None, the content of the query param will be
|
||||||
|
decoded to Unicode using the encoding, otherwise it will be encoded
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: the twisted HTTP request.args list.
|
||||||
|
name: the name of the query parameter.
|
||||||
|
default: value to use if the parameter is absent,
|
||||||
|
defaults to None. Must be bytes if encoding is None.
|
||||||
|
required: whether to raise a 400 SynapseError if the
|
||||||
|
parameter is absent, defaults to False.
|
||||||
|
allowed_values: List of allowed values for the
|
||||||
|
string, or None if any value is allowed, defaults to None. Must be
|
||||||
|
the same type as name, if given.
|
||||||
|
encoding: The encoding to decode the string content with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string value or the default. Unicode if encoding
|
||||||
|
was given, bytes otherwise.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError if the parameter is absent and required, or if the
|
||||||
|
parameter is present, must be one of a list of allowed values and
|
||||||
|
is not one of those allowed values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
strings = parse_strings_from_args(
|
||||||
|
args,
|
||||||
|
name,
|
||||||
|
default=[default],
|
||||||
|
required=required,
|
||||||
|
allowed_values=allowed_values,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
return strings[0]
|
||||||
|
|
||||||
|
|
||||||
def parse_json_value_from_request(request, allow_empty_body=False):
|
def parse_json_value_from_request(request, allow_empty_body=False):
|
||||||
"""Parse a JSON value from the body of a twisted HTTP request.
|
"""Parse a JSON value from the body of a twisted HTTP request.
|
||||||
|
|
||||||
|
@ -215,7 +325,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
|
||||||
try:
|
try:
|
||||||
content = json_decoder.decode(content_bytes.decode("utf-8"))
|
content = json_decoder.decode(content_bytes.decode("utf-8"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Unable to parse JSON: %s", e)
|
logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
|
||||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
|
@ -265,6 +265,12 @@ class SynapseTags:
|
||||||
# Whether the sync response has new data to be returned to the client.
|
# Whether the sync response has new data to be returned to the client.
|
||||||
SYNC_RESULT = "sync.new_data"
|
SYNC_RESULT = "sync.new_data"
|
||||||
|
|
||||||
|
# incoming HTTP request ID (as written in the logs)
|
||||||
|
REQUEST_ID = "request_id"
|
||||||
|
|
||||||
|
# HTTP request tag (used to distinguish full vs incremental syncs, etc)
|
||||||
|
REQUEST_TAG = "request_tag"
|
||||||
|
|
||||||
|
|
||||||
# Block everything by default
|
# Block everything by default
|
||||||
# A regex which matches the server_names to expose traces for.
|
# A regex which matches the server_names to expose traces for.
|
||||||
|
@ -588,7 +594,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
|
||||||
|
|
||||||
span = opentracing.tracer.active_span
|
span = opentracing.tracer.active_span
|
||||||
carrier = {} # type: Dict[str, str]
|
carrier = {} # type: Dict[str, str]
|
||||||
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
|
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
|
||||||
|
|
||||||
for key, value in carrier.items():
|
for key, value in carrier.items():
|
||||||
headers.addRawHeaders(key, value)
|
headers.addRawHeaders(key, value)
|
||||||
|
@ -625,7 +631,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
|
||||||
span = opentracing.tracer.active_span
|
span = opentracing.tracer.active_span
|
||||||
|
|
||||||
carrier = {} # type: Dict[str, str]
|
carrier = {} # type: Dict[str, str]
|
||||||
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
|
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
|
||||||
|
|
||||||
for key, value in carrier.items():
|
for key, value in carrier.items():
|
||||||
headers[key.encode()] = [value.encode()]
|
headers[key.encode()] = [value.encode()]
|
||||||
|
@ -659,7 +665,7 @@ def inject_active_span_text_map(carrier, destination, check_destination=True):
|
||||||
return
|
return
|
||||||
|
|
||||||
opentracing.tracer.inject(
|
opentracing.tracer.inject(
|
||||||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -681,7 +687,7 @@ def get_active_span_text_map(destination=None):
|
||||||
|
|
||||||
carrier = {} # type: Dict[str, str]
|
carrier = {} # type: Dict[str, str]
|
||||||
opentracing.tracer.inject(
|
opentracing.tracer.inject(
|
||||||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
|
||||||
)
|
)
|
||||||
|
|
||||||
return carrier
|
return carrier
|
||||||
|
@ -696,7 +702,7 @@ def active_span_context_as_string():
|
||||||
carrier = {} # type: Dict[str, str]
|
carrier = {} # type: Dict[str, str]
|
||||||
if opentracing:
|
if opentracing:
|
||||||
opentracing.tracer.inject(
|
opentracing.tracer.inject(
|
||||||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
|
||||||
)
|
)
|
||||||
return json_encoder.encode(carrier)
|
return json_encoder.encode(carrier)
|
||||||
|
|
||||||
|
@ -824,7 +830,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
|
||||||
return
|
return
|
||||||
|
|
||||||
request_tags = {
|
request_tags = {
|
||||||
"request_id": request.get_request_id(),
|
SynapseTags.REQUEST_ID: request.get_request_id(),
|
||||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||||
tags.HTTP_METHOD: request.get_method(),
|
tags.HTTP_METHOD: request.get_method(),
|
||||||
tags.HTTP_URL: request.get_redacted_uri(),
|
tags.HTTP_URL: request.get_redacted_uri(),
|
||||||
|
@ -833,9 +839,9 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
|
||||||
|
|
||||||
request_name = request.request_metrics.name
|
request_name = request.request_metrics.name
|
||||||
if extract_context:
|
if extract_context:
|
||||||
scope = start_active_span_from_request(request, request_name, tags=request_tags)
|
scope = start_active_span_from_request(request, request_name)
|
||||||
else:
|
else:
|
||||||
scope = start_active_span(request_name, tags=request_tags)
|
scope = start_active_span(request_name)
|
||||||
|
|
||||||
with scope:
|
with scope:
|
||||||
try:
|
try:
|
||||||
|
@ -845,4 +851,11 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
|
||||||
# with JsonResource).
|
# with JsonResource).
|
||||||
scope.span.set_operation_name(request.request_metrics.name)
|
scope.span.set_operation_name(request.request_metrics.name)
|
||||||
|
|
||||||
scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
|
# set the tags *after* the servlet completes, in case it decided to
|
||||||
|
# prioritise the span (tags will get dropped on unprioritised spans)
|
||||||
|
request_tags[
|
||||||
|
SynapseTags.REQUEST_TAG
|
||||||
|
] = request.request_metrics.start_context.tag
|
||||||
|
|
||||||
|
for k, v in request_tags.items():
|
||||||
|
scope.span.set_tag(k, v)
|
||||||
|
|
|
@ -22,7 +22,11 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.logging.opentracing import noop_context_manager, start_active_span
|
from synapse.logging.opentracing import (
|
||||||
|
SynapseTags,
|
||||||
|
noop_context_manager,
|
||||||
|
start_active_span,
|
||||||
|
)
|
||||||
from synapse.util.async_helpers import maybe_awaitable
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -202,7 +206,9 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
|
||||||
try:
|
try:
|
||||||
ctx = noop_context_manager()
|
ctx = noop_context_manager()
|
||||||
if bg_start_span:
|
if bg_start_span:
|
||||||
ctx = start_active_span(desc, tags={"request_id": str(context)})
|
ctx = start_active_span(
|
||||||
|
desc, tags={SynapseTags.REQUEST_ID: str(context)}
|
||||||
|
)
|
||||||
with ctx:
|
with ctx:
|
||||||
return await maybe_awaitable(func(*args, **kwargs))
|
return await maybe_awaitable(func(*args, **kwargs))
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
@ -1061,15 +1061,15 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
|
||||||
RoomTypingRestServlet(hs).register(http_server)
|
RoomTypingRestServlet(hs).register(http_server)
|
||||||
RoomEventContextServlet(hs).register(http_server)
|
RoomEventContextServlet(hs).register(http_server)
|
||||||
RoomSpaceSummaryRestServlet(hs).register(http_server)
|
RoomSpaceSummaryRestServlet(hs).register(http_server)
|
||||||
|
RoomEventServlet(hs).register(http_server)
|
||||||
|
JoinedRoomsRestServlet(hs).register(http_server)
|
||||||
|
RoomAliasListServlet(hs).register(http_server)
|
||||||
|
SearchRestServlet(hs).register(http_server)
|
||||||
|
|
||||||
# Some servlets only get registered for the main process.
|
# Some servlets only get registered for the main process.
|
||||||
if not is_worker:
|
if not is_worker:
|
||||||
RoomCreateRestServlet(hs).register(http_server)
|
RoomCreateRestServlet(hs).register(http_server)
|
||||||
RoomForgetRestServlet(hs).register(http_server)
|
RoomForgetRestServlet(hs).register(http_server)
|
||||||
SearchRestServlet(hs).register(http_server)
|
|
||||||
JoinedRoomsRestServlet(hs).register(http_server)
|
|
||||||
RoomEventServlet(hs).register(http_server)
|
|
||||||
RoomAliasListServlet(hs).register(http_server)
|
|
||||||
|
|
||||||
|
|
||||||
def register_deprecated_servlets(hs, http_server):
|
def register_deprecated_servlets(hs, http_server):
|
||||||
|
|
|
@ -16,11 +16,7 @@ import logging
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
RestServlet,
|
|
||||||
assert_params_in_dict,
|
|
||||||
parse_json_object_from_request,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ._base import client_patterns
|
from ._base import client_patterns
|
||||||
|
|
||||||
|
@ -42,15 +38,14 @@ class ReportEventRestServlet(RestServlet):
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
assert_params_in_dict(body, ("reason", "score"))
|
|
||||||
|
|
||||||
if not isinstance(body["reason"], str):
|
if not isinstance(body.get("reason", ""), str):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"Param 'reason' must be a string",
|
"Param 'reason' must be a string",
|
||||||
Codes.BAD_JSON,
|
Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
if not isinstance(body["score"], int):
|
if not isinstance(body.get("score", 0), int):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"Param 'score' must be an integer",
|
"Param 'score' must be an integer",
|
||||||
|
@ -61,7 +56,7 @@ class ReportEventRestServlet(RestServlet):
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
event_id=event_id,
|
event_id=event_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
reason=body["reason"],
|
reason=body.get("reason"),
|
||||||
content=body,
|
content=body,
|
||||||
received_ts=self.clock.time_msec(),
|
received_ts=self.clock.time_msec(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
|
||||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
from synapse.util.async_helpers import yieldable_gather_results
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -210,7 +211,13 @@ class RemoteKey(DirectServeJsonResource):
|
||||||
# If there is a cache miss, request the missing keys, then recurse (and
|
# If there is a cache miss, request the missing keys, then recurse (and
|
||||||
# ensure the result is sent).
|
# ensure the result is sent).
|
||||||
if cache_misses and query_remote_on_cache_miss:
|
if cache_misses and query_remote_on_cache_miss:
|
||||||
await self.fetcher.get_keys(cache_misses)
|
await yieldable_gather_results(
|
||||||
|
lambda t: self.fetcher.get_keys(*t),
|
||||||
|
(
|
||||||
|
(server_name, list(keys), 0)
|
||||||
|
for server_name, keys in cache_misses.items()
|
||||||
|
),
|
||||||
|
)
|
||||||
await self.query_keys(request, query, query_remote_on_cache_miss=False)
|
await self.query_keys(request, query, query_remote_on_cache_miss=False)
|
||||||
else:
|
else:
|
||||||
signed_keys = []
|
signed_keys = []
|
||||||
|
|
|
@ -168,6 +168,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
backfilled,
|
backfilled,
|
||||||
):
|
):
|
||||||
self._invalidate_get_event_cache(event_id)
|
self._invalidate_get_event_cache(event_id)
|
||||||
|
self.have_seen_event.invalidate((room_id, event_id))
|
||||||
|
|
||||||
self.get_latest_event_ids_in_room.invalidate((room_id,))
|
self.get_latest_event_ids_in_room.invalidate((room_id,))
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ from typing import (
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
@ -55,7 +56,7 @@ from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
@ -1045,32 +1046,74 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return {r["event_id"] for r in rows}
|
return {r["event_id"] for r in rows}
|
||||||
|
|
||||||
async def have_seen_events(self, event_ids):
|
async def have_seen_events(
|
||||||
|
self, room_id: str, event_ids: Iterable[str]
|
||||||
|
) -> Set[str]:
|
||||||
"""Given a list of event ids, check if we have already processed them.
|
"""Given a list of event ids, check if we have already processed them.
|
||||||
|
|
||||||
|
The room_id is only used to structure the cache (so that it can later be
|
||||||
|
invalidated by room_id) - there is no guarantee that the events are actually
|
||||||
|
in the room in question.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids (iterable[str]):
|
room_id: Room we are polling
|
||||||
|
event_ids: events we are looking for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
set[str]: The events we have already seen.
|
set[str]: The events we have already seen.
|
||||||
"""
|
"""
|
||||||
# if the event cache contains the event, obviously we've seen it.
|
res = await self._have_seen_events_dict(
|
||||||
results = {x for x in event_ids if self._get_event_cache.contains(x)}
|
(room_id, event_id) for event_id in event_ids
|
||||||
|
)
|
||||||
|
return {eid for ((_rid, eid), have_event) in res.items() if have_event}
|
||||||
|
|
||||||
def have_seen_events_txn(txn, chunk):
|
@cachedList("have_seen_event", "keys")
|
||||||
sql = "SELECT event_id FROM events as e WHERE "
|
async def _have_seen_events_dict(
|
||||||
|
self, keys: Iterable[Tuple[str, str]]
|
||||||
|
) -> Dict[Tuple[str, str], bool]:
|
||||||
|
"""Helper for have_seen_events
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a dict {(room_id, event_id)-> bool}
|
||||||
|
"""
|
||||||
|
# if the event cache contains the event, obviously we've seen it.
|
||||||
|
|
||||||
|
cache_results = {
|
||||||
|
(rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
|
||||||
|
}
|
||||||
|
results = {x: True for x in cache_results}
|
||||||
|
|
||||||
|
def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
|
||||||
|
# we deliberately do *not* query the database for room_id, to make the
|
||||||
|
# query an index-only lookup on `events_event_id_key`.
|
||||||
|
#
|
||||||
|
# We therefore pull the events from the database into a set...
|
||||||
|
|
||||||
|
sql = "SELECT event_id FROM events AS e WHERE "
|
||||||
clause, args = make_in_list_sql_clause(
|
clause, args = make_in_list_sql_clause(
|
||||||
txn.database_engine, "e.event_id", chunk
|
txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk]
|
||||||
)
|
)
|
||||||
txn.execute(sql + clause, args)
|
txn.execute(sql + clause, args)
|
||||||
results.update(row[0] for row in txn)
|
found_events = {eid for eid, in txn}
|
||||||
|
|
||||||
for chunk in batch_iter((x for x in event_ids if x not in results), 100):
|
# ... and then we can update the results for each row in the batch
|
||||||
|
results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk})
|
||||||
|
|
||||||
|
# each batch requires its own index scan, so we make the batches as big as
|
||||||
|
# possible.
|
||||||
|
for chunk in batch_iter((k for k in keys if k not in cache_results), 500):
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"have_seen_events", have_seen_events_txn, chunk
|
"have_seen_events", have_seen_events_txn, chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@cached(max_entries=100000, tree=True)
|
||||||
|
async def have_seen_event(self, room_id: str, event_id: str):
|
||||||
|
# this only exists for the benefit of the @cachedList descriptor on
|
||||||
|
# _have_seen_events_dict
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _get_current_state_event_counts_txn(self, txn, room_id):
|
def _get_current_state_event_counts_txn(self, txn, room_id):
|
||||||
"""
|
"""
|
||||||
See get_current_state_event_counts.
|
See get_current_state_event_counts.
|
||||||
|
|
|
@ -16,14 +16,14 @@ import logging
|
||||||
from typing import Any, List, Set, Tuple
|
from typing import Any, List, Set, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.databases.main.state import StateGroupWorkerStore
|
from synapse.storage.databases.main.state import StateGroupWorkerStore
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||||
async def purge_history(
|
async def purge_history(
|
||||||
self, room_id: str, token: str, delete_local_events: bool
|
self, room_id: str, token: str, delete_local_events: bool
|
||||||
) -> Set[int]:
|
) -> Set[int]:
|
||||||
|
@ -203,8 +203,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||||
"DELETE FROM event_to_state_groups "
|
"DELETE FROM event_to_state_groups "
|
||||||
"WHERE event_id IN (SELECT event_id from events_to_purge)"
|
"WHERE event_id IN (SELECT event_id from events_to_purge)"
|
||||||
)
|
)
|
||||||
for event_id, _ in event_rows:
|
|
||||||
txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
|
|
||||||
|
|
||||||
# Delete all remote non-state events
|
# Delete all remote non-state events
|
||||||
for table in (
|
for table in (
|
||||||
|
@ -283,6 +281,20 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||||
# so make sure to keep this actually last.
|
# so make sure to keep this actually last.
|
||||||
txn.execute("DROP TABLE events_to_purge")
|
txn.execute("DROP TABLE events_to_purge")
|
||||||
|
|
||||||
|
for event_id, should_delete in event_rows:
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self._get_state_group_for_event, (event_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# XXX: This is racy, since have_seen_events could be called between the
|
||||||
|
# transaction completing and the invalidation running. On the other hand,
|
||||||
|
# that's no different to calling `have_seen_events` just before the
|
||||||
|
# event is deleted from the database.
|
||||||
|
if should_delete:
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.have_seen_event, (room_id, event_id)
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("[purge] done")
|
logger.info("[purge] done")
|
||||||
|
|
||||||
return referenced_state_groups
|
return referenced_state_groups
|
||||||
|
@ -422,7 +434,11 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||||
# index on them. In any case we should be clearing out 'stream' tables
|
# index on them. In any case we should be clearing out 'stream' tables
|
||||||
# periodically anyway (#5888)
|
# periodically anyway (#5888)
|
||||||
|
|
||||||
# TODO: we could probably usefully do a bunch of cache invalidation here
|
# TODO: we could probably usefully do a bunch more cache invalidation here
|
||||||
|
|
||||||
|
# XXX: as with purge_history, this is racy, but no worse than other races
|
||||||
|
# that already exist.
|
||||||
|
self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
|
||||||
|
|
||||||
logger.info("[purge] done")
|
logger.info("[purge] done")
|
||||||
|
|
||||||
|
|
|
@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
reason: str,
|
reason: Optional[str],
|
||||||
content: JsonDict,
|
content: JsonDict,
|
||||||
received_ts: int,
|
received_ts: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import time
|
import time
|
||||||
|
from typing import Dict, List
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -21,7 +22,6 @@ import signedjson.sign
|
||||||
from nacl.signing import SigningKey
|
from nacl.signing import SigningKey
|
||||||
from signedjson.key import encode_verify_key_base64, get_verify_key
|
from signedjson.key import encode_verify_key_base64, get_verify_key
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.defer import Deferred, ensureDeferred
|
from twisted.internet.defer import Deferred, ensureDeferred
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
@ -92,23 +92,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
# deferred completes.
|
# deferred completes.
|
||||||
first_lookup_deferred = Deferred()
|
first_lookup_deferred = Deferred()
|
||||||
|
|
||||||
async def first_lookup_fetch(keys_to_fetch):
|
async def first_lookup_fetch(
|
||||||
self.assertEquals(current_context().request.id, "context_11")
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
|
) -> Dict[str, FetchKeyResult]:
|
||||||
|
# self.assertEquals(current_context().request.id, "context_11")
|
||||||
|
self.assertEqual(server_name, "server10")
|
||||||
|
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||||
|
self.assertEqual(minimum_valid_until_ts, 0)
|
||||||
|
|
||||||
await make_deferred_yieldable(first_lookup_deferred)
|
await make_deferred_yieldable(first_lookup_deferred)
|
||||||
return {
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
|
||||||
"server10": {
|
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_fetcher.get_keys.side_effect = first_lookup_fetch
|
mock_fetcher.get_keys.side_effect = first_lookup_fetch
|
||||||
|
|
||||||
async def first_lookup():
|
async def first_lookup():
|
||||||
with LoggingContext("context_11", request=FakeRequest("context_11")):
|
with LoggingContext("context_11", request=FakeRequest("context_11")):
|
||||||
res_deferreds = kr.verify_json_objects_for_server(
|
res_deferreds = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
|
[("server10", json1, 0), ("server11", {}, 0)]
|
||||||
)
|
)
|
||||||
|
|
||||||
# the unsigned json should be rejected pretty quickly
|
# the unsigned json should be rejected pretty quickly
|
||||||
|
@ -126,18 +126,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
d0 = ensureDeferred(first_lookup())
|
d0 = ensureDeferred(first_lookup())
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
mock_fetcher.get_keys.assert_called_once()
|
mock_fetcher.get_keys.assert_called_once()
|
||||||
|
|
||||||
# a second request for a server with outstanding requests
|
# a second request for a server with outstanding requests
|
||||||
# should block rather than start a second call
|
# should block rather than start a second call
|
||||||
|
|
||||||
async def second_lookup_fetch(keys_to_fetch):
|
async def second_lookup_fetch(
|
||||||
self.assertEquals(current_context().request.id, "context_12")
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
return {
|
) -> Dict[str, FetchKeyResult]:
|
||||||
"server10": {
|
# self.assertEquals(current_context().request.id, "context_12")
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_fetcher.get_keys.reset_mock()
|
mock_fetcher.get_keys.reset_mock()
|
||||||
mock_fetcher.get_keys.side_effect = second_lookup_fetch
|
mock_fetcher.get_keys.side_effect = second_lookup_fetch
|
||||||
|
@ -146,7 +146,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
async def second_lookup():
|
async def second_lookup():
|
||||||
with LoggingContext("context_12", request=FakeRequest("context_12")):
|
with LoggingContext("context_12", request=FakeRequest("context_12")):
|
||||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1, 0, "test")]
|
[
|
||||||
|
(
|
||||||
|
"server10",
|
||||||
|
json1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||||
second_lookup_state[0] = 1
|
second_lookup_state[0] = 1
|
||||||
|
@ -183,11 +189,11 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
signedjson.sign.sign_json(json1, "server9", key1)
|
signedjson.sign.sign_json(json1, "server9", key1)
|
||||||
|
|
||||||
# should fail immediately on an unsigned object
|
# should fail immediately on an unsigned object
|
||||||
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
|
d = kr.verify_json_for_server("server9", {}, 0)
|
||||||
self.get_failure(d, SynapseError)
|
self.get_failure(d, SynapseError)
|
||||||
|
|
||||||
# should succeed on a signed object
|
# should succeed on a signed object
|
||||||
d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
|
d = kr.verify_json_for_server("server9", json1, 500)
|
||||||
# self.assertFalse(d.called)
|
# self.assertFalse(d.called)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
|
@ -214,24 +220,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
signedjson.sign.sign_json(json1, "server9", key1)
|
signedjson.sign.sign_json(json1, "server9", key1)
|
||||||
|
|
||||||
# should fail immediately on an unsigned object
|
# should fail immediately on an unsigned object
|
||||||
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
|
d = kr.verify_json_for_server("server9", {}, 0)
|
||||||
self.get_failure(d, SynapseError)
|
self.get_failure(d, SynapseError)
|
||||||
|
|
||||||
# should fail on a signed object with a non-zero minimum_valid_until_ms,
|
# should fail on a signed object with a non-zero minimum_valid_until_ms,
|
||||||
# as it tries to refetch the keys and fails.
|
# as it tries to refetch the keys and fails.
|
||||||
d = _verify_json_for_server(
|
d = kr.verify_json_for_server("server9", json1, 500)
|
||||||
kr, "server9", json1, 500, "test signed non-zero min"
|
|
||||||
)
|
|
||||||
self.get_failure(d, SynapseError)
|
self.get_failure(d, SynapseError)
|
||||||
|
|
||||||
# We expect the keyring tried to refetch the key once.
|
# We expect the keyring tried to refetch the key once.
|
||||||
mock_fetcher.get_keys.assert_called_once_with(
|
mock_fetcher.get_keys.assert_called_once_with(
|
||||||
{"server9": {get_key_id(key1): 500}}
|
"server9", [get_key_id(key1)], 500
|
||||||
)
|
)
|
||||||
|
|
||||||
# should succeed on a signed object with a 0 minimum_valid_until_ms
|
# should succeed on a signed object with a 0 minimum_valid_until_ms
|
||||||
d = _verify_json_for_server(
|
d = kr.verify_json_for_server(
|
||||||
kr, "server9", json1, 0, "test signed with zero min"
|
"server9",
|
||||||
|
json1,
|
||||||
|
0,
|
||||||
)
|
)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
|
@ -239,15 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
"""Two requests for the same key should be deduped."""
|
"""Two requests for the same key should be deduped."""
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
async def get_keys(keys_to_fetch):
|
async def get_keys(
|
||||||
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
|
) -> Dict[str, FetchKeyResult]:
|
||||||
# there should only be one request object (with the max validity)
|
# there should only be one request object (with the max validity)
|
||||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
self.assertEqual(server_name, "server1")
|
||||||
|
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||||
|
self.assertEqual(minimum_valid_until_ts, 1500)
|
||||||
|
|
||||||
return {
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
|
||||||
"server1": {
|
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_fetcher = Mock()
|
mock_fetcher = Mock()
|
||||||
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
mock_fetcher.get_keys = Mock(side_effect=get_keys)
|
||||||
|
@ -259,7 +265,14 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
# the first request should succeed; the second should fail because the key
|
# the first request should succeed; the second should fail because the key
|
||||||
# has expired
|
# has expired
|
||||||
results = kr.verify_json_objects_for_server(
|
results = kr.verify_json_objects_for_server(
|
||||||
[("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
|
[
|
||||||
|
(
|
||||||
|
"server1",
|
||||||
|
json1,
|
||||||
|
500,
|
||||||
|
),
|
||||||
|
("server1", json1, 1500),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.assertEqual(len(results), 2)
|
self.assertEqual(len(results), 2)
|
||||||
self.get_success(results[0])
|
self.get_success(results[0])
|
||||||
|
@ -274,19 +287,21 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
"""If the first fetcher cannot provide a recent enough key, we fall back"""
|
"""If the first fetcher cannot provide a recent enough key, we fall back"""
|
||||||
key1 = signedjson.key.generate_signing_key(1)
|
key1 = signedjson.key.generate_signing_key(1)
|
||||||
|
|
||||||
async def get_keys1(keys_to_fetch):
|
async def get_keys1(
|
||||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
return {
|
) -> Dict[str, FetchKeyResult]:
|
||||||
"server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
|
self.assertEqual(server_name, "server1")
|
||||||
}
|
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||||
|
self.assertEqual(minimum_valid_until_ts, 1500)
|
||||||
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
|
||||||
|
|
||||||
async def get_keys2(keys_to_fetch):
|
async def get_keys2(
|
||||||
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
|
server_name: str, key_ids: List[str], minimum_valid_until_ts: int
|
||||||
return {
|
) -> Dict[str, FetchKeyResult]:
|
||||||
"server1": {
|
self.assertEqual(server_name, "server1")
|
||||||
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
|
self.assertEqual(key_ids, [get_key_id(key1)])
|
||||||
}
|
self.assertEqual(minimum_valid_until_ts, 1500)
|
||||||
}
|
return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
|
||||||
|
|
||||||
mock_fetcher1 = Mock()
|
mock_fetcher1 = Mock()
|
||||||
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
|
||||||
|
@ -298,7 +313,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
signedjson.sign.sign_json(json1, "server1", key1)
|
signedjson.sign.sign_json(json1, "server1", key1)
|
||||||
|
|
||||||
results = kr.verify_json_objects_for_server(
|
results = kr.verify_json_objects_for_server(
|
||||||
[("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
|
[
|
||||||
|
(
|
||||||
|
"server1",
|
||||||
|
json1,
|
||||||
|
1200,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"server1",
|
||||||
|
json1,
|
||||||
|
1500,
|
||||||
|
),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.assertEqual(len(results), 2)
|
self.assertEqual(len(results), 2)
|
||||||
self.get_success(results[0])
|
self.get_success(results[0])
|
||||||
|
@ -349,9 +375,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.http_client.get_json.side_effect = get_json
|
self.http_client.get_json.side_effect = get_json
|
||||||
|
|
||||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
k = keys[testverifykey_id]
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
|
||||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(k.verify_key, testverifykey)
|
self.assertEqual(k.verify_key, testverifykey)
|
||||||
self.assertEqual(k.verify_key.alg, "ed25519")
|
self.assertEqual(k.verify_key.alg, "ed25519")
|
||||||
|
@ -378,7 +403,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
# change the server name: the result should be ignored
|
# change the server name: the result should be ignored
|
||||||
response["server_name"] = "OTHER_SERVER"
|
response["server_name"] = "OTHER_SERVER"
|
||||||
|
|
||||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
self.assertEqual(keys, {})
|
self.assertEqual(keys, {})
|
||||||
|
|
||||||
|
|
||||||
|
@ -465,10 +490,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
||||||
|
|
||||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
self.assertIn(testverifykey_id, keys)
|
||||||
self.assertIn(SERVER_NAME, keys)
|
k = keys[testverifykey_id]
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
|
||||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(k.verify_key, testverifykey)
|
self.assertEqual(k.verify_key, testverifykey)
|
||||||
self.assertEqual(k.verify_key.alg, "ed25519")
|
self.assertEqual(k.verify_key.alg, "ed25519")
|
||||||
|
@ -515,10 +539,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
||||||
|
|
||||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
|
self.assertIn(testverifykey_id, keys)
|
||||||
self.assertIn(SERVER_NAME, keys)
|
k = keys[testverifykey_id]
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
|
||||||
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(k.verify_key, testverifykey)
|
self.assertEqual(k.verify_key, testverifykey)
|
||||||
self.assertEqual(k.verify_key.alg, "ed25519")
|
self.assertEqual(k.verify_key.alg, "ed25519")
|
||||||
|
@ -559,14 +582,13 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def get_key_from_perspectives(response):
|
def get_key_from_perspectives(response):
|
||||||
fetcher = PerspectivesKeyFetcher(self.hs)
|
fetcher = PerspectivesKeyFetcher(self.hs)
|
||||||
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
|
|
||||||
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
|
||||||
return self.get_success(fetcher.get_keys(keys_to_fetch))
|
return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
|
||||||
|
|
||||||
# start with a valid response so we can check we are testing the right thing
|
# start with a valid response so we can check we are testing the right thing
|
||||||
response = build_response()
|
response = build_response()
|
||||||
keys = get_key_from_perspectives(response)
|
keys = get_key_from_perspectives(response)
|
||||||
k = keys[SERVER_NAME][testverifykey_id]
|
k = keys[testverifykey_id]
|
||||||
self.assertEqual(k.verify_key, testverifykey)
|
self.assertEqual(k.verify_key, testverifykey)
|
||||||
|
|
||||||
# remove the perspectives server's signature
|
# remove the perspectives server's signature
|
||||||
|
@ -585,23 +607,3 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
def get_key_id(key):
|
def get_key_id(key):
|
||||||
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
|
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
|
||||||
return "%s:%s" % (key.alg, key.version)
|
return "%s:%s" % (key.alg, key.version)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def run_in_context(f, *args, **kwargs):
|
|
||||||
with LoggingContext("testctx"):
|
|
||||||
rv = yield f(*args, **kwargs)
|
|
||||||
return rv
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_json_for_server(kr, *args):
|
|
||||||
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
|
|
||||||
with the patched defer.inlineCallbacks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def v():
|
|
||||||
rv1 = yield kr.verify_json_for_server(*args)
|
|
||||||
return rv1
|
|
||||||
|
|
||||||
return run_in_context(v)
|
|
||||||
|
|
|
@ -64,7 +64,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
|
||||||
user_tok=self.admin_user_tok,
|
user_tok=self.admin_user_tok,
|
||||||
)
|
)
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
self._create_event_and_report(
|
self._create_event_and_report_without_parameters(
|
||||||
room_id=self.room_id2,
|
room_id=self.room_id2,
|
||||||
user_tok=self.admin_user_tok,
|
user_tok=self.admin_user_tok,
|
||||||
)
|
)
|
||||||
|
@ -378,6 +378,19 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
def _create_event_and_report_without_parameters(self, room_id, user_tok):
|
||||||
|
"""Create and report an event, but omit reason and score"""
|
||||||
|
resp = self.helper.send(room_id, tok=user_tok)
|
||||||
|
event_id = resp["event_id"]
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"rooms/%s/report/%s" % (room_id, event_id),
|
||||||
|
json.dumps({}),
|
||||||
|
access_token=user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
def _check_fields(self, content):
|
def _check_fields(self, content):
|
||||||
"""Checks that all attributes are present in an event report"""
|
"""Checks that all attributes are present in an event report"""
|
||||||
for c in content:
|
for c in content:
|
||||||
|
|
83
tests/rest/client/v2_alpha/test_report_event.py
Normal file
83
tests/rest/client/v2_alpha/test_report_event.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright 2021 Callum Brown
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import synapse.rest.admin
|
||||||
|
from synapse.rest.client.v1 import login, room
|
||||||
|
from synapse.rest.client.v2_alpha import report_event
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class ReportEventTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
room.register_servlets,
|
||||||
|
report_event.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||||
|
self.admin_user_tok = self.login("admin", "pass")
|
||||||
|
self.other_user = self.register_user("user", "pass")
|
||||||
|
self.other_user_tok = self.login("user", "pass")
|
||||||
|
|
||||||
|
self.room_id = self.helper.create_room_as(
|
||||||
|
self.other_user, tok=self.other_user_tok, is_public=True
|
||||||
|
)
|
||||||
|
self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok)
|
||||||
|
resp = self.helper.send(self.room_id, tok=self.admin_user_tok)
|
||||||
|
self.event_id = resp["event_id"]
|
||||||
|
self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id)
|
||||||
|
|
||||||
|
def test_reason_str_and_score_int(self):
|
||||||
|
data = {"reason": "this makes me sad", "score": -100}
|
||||||
|
self._assert_status(200, data)
|
||||||
|
|
||||||
|
def test_no_reason(self):
|
||||||
|
data = {"score": 0}
|
||||||
|
self._assert_status(200, data)
|
||||||
|
|
||||||
|
def test_no_score(self):
|
||||||
|
data = {"reason": "this makes me sad"}
|
||||||
|
self._assert_status(200, data)
|
||||||
|
|
||||||
|
def test_no_reason_and_no_score(self):
|
||||||
|
data = {}
|
||||||
|
self._assert_status(200, data)
|
||||||
|
|
||||||
|
def test_reason_int_and_score_str(self):
|
||||||
|
data = {"reason": 10, "score": "string"}
|
||||||
|
self._assert_status(400, data)
|
||||||
|
|
||||||
|
def test_reason_zero_and_score_blank(self):
|
||||||
|
data = {"reason": 0, "score": ""}
|
||||||
|
self._assert_status(400, data)
|
||||||
|
|
||||||
|
def test_reason_and_score_null(self):
|
||||||
|
data = {"reason": None, "score": None}
|
||||||
|
self._assert_status(400, data)
|
||||||
|
|
||||||
|
def _assert_status(self, response_status, data):
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.report_path,
|
||||||
|
json.dumps(data),
|
||||||
|
access_token=self.other_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
response_status, int(channel.result["code"]), msg=channel.result["body"]
|
||||||
|
)
|
|
@ -208,10 +208,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||||
keyid = "ed25519:%s" % (testkey.version,)
|
keyid = "ed25519:%s" % (testkey.version,)
|
||||||
|
|
||||||
fetcher = PerspectivesKeyFetcher(self.hs2)
|
fetcher = PerspectivesKeyFetcher(self.hs2)
|
||||||
d = fetcher.get_keys({"targetserver": {keyid: 1000}})
|
d = fetcher.get_keys("targetserver", [keyid], 1000)
|
||||||
res = self.get_success(d)
|
res = self.get_success(d)
|
||||||
self.assertIn("targetserver", res)
|
self.assertIn(keyid, res)
|
||||||
keyres = res["targetserver"][keyid]
|
keyres = res[keyid]
|
||||||
assert isinstance(keyres, FetchKeyResult)
|
assert isinstance(keyres, FetchKeyResult)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
||||||
|
@ -230,10 +230,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||||
keyid = "ed25519:%s" % (testkey.version,)
|
keyid = "ed25519:%s" % (testkey.version,)
|
||||||
|
|
||||||
fetcher = PerspectivesKeyFetcher(self.hs2)
|
fetcher = PerspectivesKeyFetcher(self.hs2)
|
||||||
d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
|
d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
|
||||||
res = self.get_success(d)
|
res = self.get_success(d)
|
||||||
self.assertIn(self.hs.hostname, res)
|
self.assertIn(keyid, res)
|
||||||
keyres = res[self.hs.hostname][keyid]
|
keyres = res[keyid]
|
||||||
assert isinstance(keyres, FetchKeyResult)
|
assert isinstance(keyres, FetchKeyResult)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
||||||
|
@ -247,10 +247,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||||
keyid = "ed25519:%s" % (self.hs_signing_key.version,)
|
keyid = "ed25519:%s" % (self.hs_signing_key.version,)
|
||||||
|
|
||||||
fetcher = PerspectivesKeyFetcher(self.hs2)
|
fetcher = PerspectivesKeyFetcher(self.hs2)
|
||||||
d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
|
d = fetcher.get_keys(self.hs.hostname, [keyid], 1000)
|
||||||
res = self.get_success(d)
|
res = self.get_success(d)
|
||||||
self.assertIn(self.hs.hostname, res)
|
self.assertIn(keyid, res)
|
||||||
keyres = res[self.hs.hostname][keyid]
|
keyres = res[keyid]
|
||||||
assert isinstance(keyres, FetchKeyResult)
|
assert isinstance(keyres, FetchKeyResult)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
signedjson.key.encode_verify_key_base64(keyres.verify_key),
|
||||||
|
|
13
tests/storage/databases/__init__.py
Normal file
13
tests/storage/databases/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
13
tests/storage/databases/main/__init__.py
Normal file
13
tests/storage/databases/main/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
96
tests/storage/databases/main/test_events_worker.py
Normal file
96
tests/storage/databases/main/test_events_worker.py
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
|
||||||
|
from synapse.logging.context import LoggingContext
|
||||||
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.store: EventsWorkerStore = hs.get_datastore()
|
||||||
|
|
||||||
|
# insert some test data
|
||||||
|
for rid in ("room1", "room2"):
|
||||||
|
self.get_success(
|
||||||
|
self.store.db_pool.simple_insert(
|
||||||
|
"rooms",
|
||||||
|
{"room_id": rid, "room_version": 4},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx, (rid, eid) in enumerate(
|
||||||
|
(
|
||||||
|
("room1", "event10"),
|
||||||
|
("room1", "event11"),
|
||||||
|
("room1", "event12"),
|
||||||
|
("room2", "event20"),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
self.get_success(
|
||||||
|
self.store.db_pool.simple_insert(
|
||||||
|
"events",
|
||||||
|
{
|
||||||
|
"event_id": eid,
|
||||||
|
"room_id": rid,
|
||||||
|
"topological_ordering": idx,
|
||||||
|
"stream_ordering": idx,
|
||||||
|
"type": "test",
|
||||||
|
"processed": True,
|
||||||
|
"outlier": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.get_success(
|
||||||
|
self.store.db_pool.simple_insert(
|
||||||
|
"event_json",
|
||||||
|
{
|
||||||
|
"event_id": eid,
|
||||||
|
"room_id": rid,
|
||||||
|
"json": json.dumps({"type": "test", "room_id": rid}),
|
||||||
|
"internal_metadata": "{}",
|
||||||
|
"format_version": 3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
with LoggingContext(name="test") as ctx:
|
||||||
|
res = self.get_success(
|
||||||
|
self.store.have_seen_events("room1", ["event10", "event19"])
|
||||||
|
)
|
||||||
|
self.assertEquals(res, {"event10"})
|
||||||
|
|
||||||
|
# that should result in a single db query
|
||||||
|
self.assertEquals(ctx.get_resource_usage().db_txn_count, 1)
|
||||||
|
|
||||||
|
# a second lookup of the same events should cause no queries
|
||||||
|
with LoggingContext(name="test") as ctx:
|
||||||
|
res = self.get_success(
|
||||||
|
self.store.have_seen_events("room1", ["event10", "event19"])
|
||||||
|
)
|
||||||
|
self.assertEquals(res, {"event10"})
|
||||||
|
self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
|
||||||
|
|
||||||
|
def test_query_via_event_cache(self):
|
||||||
|
# fetch an event into the event cache
|
||||||
|
self.get_success(self.store.get_event("event10"))
|
||||||
|
|
||||||
|
# looking it up should now cause no db hits
|
||||||
|
with LoggingContext(name="test") as ctx:
|
||||||
|
res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
|
||||||
|
self.assertEquals(res, {"event10"})
|
||||||
|
self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
|
|
@ -45,37 +45,32 @@ class BatchingQueueTestCase(TestCase):
|
||||||
self._pending_calls.append((values, d))
|
self._pending_calls.append((values, d))
|
||||||
return await make_deferred_yieldable(d)
|
return await make_deferred_yieldable(d)
|
||||||
|
|
||||||
|
def _get_sample_with_name(self, metric, name) -> int:
|
||||||
|
"""For a prometheus metric get the value of the sample that has a
|
||||||
|
matching "name" label.
|
||||||
|
"""
|
||||||
|
for sample in metric.collect()[0].samples:
|
||||||
|
if sample.labels.get("name") == name:
|
||||||
|
return sample.value
|
||||||
|
|
||||||
|
self.fail("Found no matching sample")
|
||||||
|
|
||||||
def _assert_metrics(self, queued, keys, in_flight):
|
def _assert_metrics(self, queued, keys, in_flight):
|
||||||
"""Assert that the metrics are correct"""
|
"""Assert that the metrics are correct"""
|
||||||
|
|
||||||
self.assertEqual(len(number_queued.collect()), 1)
|
sample = self._get_sample_with_name(number_queued, self.queue._name)
|
||||||
self.assertEqual(len(number_queued.collect()[0].samples), 1)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
number_queued.collect()[0].samples[0].labels,
|
sample,
|
||||||
{"name": self.queue._name},
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
number_queued.collect()[0].samples[0].value,
|
|
||||||
queued,
|
queued,
|
||||||
"number_queued",
|
"number_queued",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(number_of_keys.collect()), 1)
|
sample = self._get_sample_with_name(number_of_keys, self.queue._name)
|
||||||
self.assertEqual(len(number_of_keys.collect()[0].samples), 1)
|
self.assertEqual(sample, keys, "number_of_keys")
|
||||||
self.assertEqual(
|
|
||||||
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(len(number_in_flight.collect()), 1)
|
sample = self._get_sample_with_name(number_in_flight, self.queue._name)
|
||||||
self.assertEqual(len(number_in_flight.collect()[0].samples), 1)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
|
sample,
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
number_in_flight.collect()[0].samples[0].value,
|
|
||||||
in_flight,
|
in_flight,
|
||||||
"number_in_flight",
|
"number_in_flight",
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue