Merge branch 'release-v1.18.0' into matrix-org-hotfixes

Richard van der Hoff 2020-07-28 10:15:22 +01:00
commit b2ccc72a00
111 changed files with 2482 additions and 2118 deletions

@@ -1,3 +1,96 @@

Synapse 1.18.0rc1 (2020-07-27)
==============================

Features
--------
- Include room states on invite events that are sent to application services. Contributed by @Sorunome. ([\#6455](https://github.com/matrix-org/synapse/issues/6455))
- Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel. ([\#7613](https://github.com/matrix-org/synapse/issues/7613), [\#7953](https://github.com/matrix-org/synapse/issues/7953))
- Add experimental support for running multiple federation sender processes. ([\#7798](https://github.com/matrix-org/synapse/issues/7798))
- Add the option to validate the `iss` and `aud` claims for JWT logins. ([\#7827](https://github.com/matrix-org/synapse/issues/7827))
- Add support for handling registration requests across multiple client reader workers. ([\#7830](https://github.com/matrix-org/synapse/issues/7830))
- Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7842](https://github.com/matrix-org/synapse/issues/7842))
- Allow email subjects to be customised through Synapse's configuration. ([\#7846](https://github.com/matrix-org/synapse/issues/7846))
- Add the ability to re-activate an account from the admin API. ([\#7847](https://github.com/matrix-org/synapse/issues/7847), [\#7908](https://github.com/matrix-org/synapse/issues/7908))
- Add experimental support for running multiple pusher workers. ([\#7855](https://github.com/matrix-org/synapse/issues/7855))
- Add experimental support for moving typing off master. ([\#7869](https://github.com/matrix-org/synapse/issues/7869), [\#7959](https://github.com/matrix-org/synapse/issues/7959))
- Report CPU metrics to prometheus for time spent processing replication commands. ([\#7879](https://github.com/matrix-org/synapse/issues/7879))
- Support oEmbed for media previews. ([\#7920](https://github.com/matrix-org/synapse/issues/7920))
- Abort federation requests where the client disconnects before the ratelimiter expires. ([\#7930](https://github.com/matrix-org/synapse/issues/7930))
- Cache responses to `/_matrix/federation/v1/state_ids` to reduce duplicated work. ([\#7931](https://github.com/matrix-org/synapse/issues/7931))

Bugfixes
--------
- Fix detection of out of sync remote device lists when receiving events from remote users. ([\#7815](https://github.com/matrix-org/synapse/issues/7815))
- Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain. ([\#7817](https://github.com/matrix-org/synapse/issues/7817))
- Fix a bug causing Synapse to misinterpret the value `off` for `encryption_enabled_by_default_for_room_type` in its configuration file(s) if that value isn't surrounded by quotes. This bug was introduced in v1.16.0. ([\#7822](https://github.com/matrix-org/synapse/issues/7822))
- Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails. ([\#7829](https://github.com/matrix-org/synapse/issues/7829))
- Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`. ([\#7844](https://github.com/matrix-org/synapse/issues/7844))
- Fix "AttributeError: 'str' object has no attribute 'get'" error message when applying per-room message retention policies. The bug was introduced in Synapse 1.7.0. ([\#7850](https://github.com/matrix-org/synapse/issues/7850))
- Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation. ([\#7854](https://github.com/matrix-org/synapse/issues/7854))
- Fix a bug which allowed empty rooms to be rejoined over federation. ([\#7859](https://github.com/matrix-org/synapse/issues/7859))
- Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers. ([\#7866](https://github.com/matrix-org/synapse/issues/7866))
- Fix a long standing bug where the tracing of async functions with opentracing was broken. ([\#7872](https://github.com/matrix-org/synapse/issues/7872), [\#7961](https://github.com/matrix-org/synapse/issues/7961))
- Fix "TypeError in `synapse.notifier`" exceptions. ([\#7880](https://github.com/matrix-org/synapse/issues/7880))
- Fix deprecation warning due to invalid escape sequences. ([\#7895](https://github.com/matrix-org/synapse/issues/7895))

Updates to the Docker image
---------------------------
- Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196. ([\#7839](https://github.com/matrix-org/synapse/issues/7839))

Improved Documentation
----------------------
- Provide instructions on using `register_new_matrix_user` via docker. ([\#7885](https://github.com/matrix-org/synapse/issues/7885))
- Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation. ([\#7889](https://github.com/matrix-org/synapse/issues/7889))
- Reorder database paragraphs to promote postgres over sqlite. ([\#7933](https://github.com/matrix-org/synapse/issues/7933))
- Update the dates of ACME v1's end of life in [`ACME.md`](https://github.com/matrix-org/synapse/blob/master/docs/ACME.md). ([\#7934](https://github.com/matrix-org/synapse/issues/7934))

Deprecations and Removals
-------------------------
- Remove unused `synapse_replication_tcp_resource_invalidate_cache` prometheus metric. ([\#7878](https://github.com/matrix-org/synapse/issues/7878))
- Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim. ([\#7888](https://github.com/matrix-org/synapse/issues/7888))

Internal Changes
----------------
- Switch parts of the codebase from `simplejson` to the standard library `json`. ([\#7802](https://github.com/matrix-org/synapse/issues/7802))
- Add type hints to the http server code and remove an unused parameter. ([\#7813](https://github.com/matrix-org/synapse/issues/7813))
- Add type hints to synapse.api.errors module. ([\#7820](https://github.com/matrix-org/synapse/issues/7820))
- Ensure that calls to `json.dumps` are compatible with the standard library json. ([\#7836](https://github.com/matrix-org/synapse/issues/7836))
- Remove redundant `retry_on_integrity_error` wrapper for event persistence code. ([\#7848](https://github.com/matrix-org/synapse/issues/7848))
- Consistently use `db_to_json` to convert from database values to JSON objects. ([\#7849](https://github.com/matrix-org/synapse/issues/7849))
- Convert various parts of the codebase to async/await. ([\#7851](https://github.com/matrix-org/synapse/issues/7851), [\#7860](https://github.com/matrix-org/synapse/issues/7860), [\#7868](https://github.com/matrix-org/synapse/issues/7868), [\#7871](https://github.com/matrix-org/synapse/issues/7871), [\#7873](https://github.com/matrix-org/synapse/issues/7873), [\#7874](https://github.com/matrix-org/synapse/issues/7874), [\#7884](https://github.com/matrix-org/synapse/issues/7884), [\#7912](https://github.com/matrix-org/synapse/issues/7912), [\#7935](https://github.com/matrix-org/synapse/issues/7935), [\#7939](https://github.com/matrix-org/synapse/issues/7939), [\#7942](https://github.com/matrix-org/synapse/issues/7942), [\#7944](https://github.com/matrix-org/synapse/issues/7944))
- Add support for handling registration requests across multiple client reader workers. ([\#7853](https://github.com/matrix-org/synapse/issues/7853))
- Small performance improvement in typing processing. ([\#7856](https://github.com/matrix-org/synapse/issues/7856))
- The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100. ([\#7858](https://github.com/matrix-org/synapse/issues/7858))
- Optimise queueing of inbound replication commands. ([\#7861](https://github.com/matrix-org/synapse/issues/7861))
- Add some type annotations to `HomeServer` and `BaseHandler`. ([\#7870](https://github.com/matrix-org/synapse/issues/7870))
- Clean up `PreserveLoggingContext`. ([\#7877](https://github.com/matrix-org/synapse/issues/7877))
- Change "unknown room version" logging from 'error' to 'warning'. ([\#7881](https://github.com/matrix-org/synapse/issues/7881))
- Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`. ([\#7882](https://github.com/matrix-org/synapse/issues/7882))
- Return an empty body for OPTIONS requests. ([\#7886](https://github.com/matrix-org/synapse/issues/7886))
- Fix typo in generated config file. Contributed by @ThiefMaster. ([\#7890](https://github.com/matrix-org/synapse/issues/7890))
- Import ABC from `collections.abc` for Python 3.10 compatibility. ([\#7892](https://github.com/matrix-org/synapse/issues/7892))
- Remove unused functions `time_function`, `trace_function`, `get_previous_frames`
and `get_previous_frame` from `synapse.logging.utils` module. ([\#7897](https://github.com/matrix-org/synapse/issues/7897))
- Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI. ([\#7914](https://github.com/matrix-org/synapse/issues/7914))
- Use Element CSS and logo in notification emails when app name is Element. ([\#7919](https://github.com/matrix-org/synapse/issues/7919))
- Optimisation to /sync handling: skip serializing the response if the client has already disconnected. ([\#7927](https://github.com/matrix-org/synapse/issues/7927))
- When a client disconnects, don't log it as 'Error processing request'. ([\#7928](https://github.com/matrix-org/synapse/issues/7928))
- Add debugging to `/sync` response generation (disabled by default). ([\#7929](https://github.com/matrix-org/synapse/issues/7929))
- Update comments that refer to Deferreds for async functions. ([\#7945](https://github.com/matrix-org/synapse/issues/7945))
- Simplify error handling in federation handler. ([\#7950](https://github.com/matrix-org/synapse/issues/7950))

Synapse 1.17.0 (2020-07-13)
===========================

@@ -1 +0,0 @@
Include room states on invite events that are sent to application services. Contributed by @Sorunome.

@@ -1 +0,0 @@
Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel.

@@ -1 +0,0 @@
Add experimental support for running multiple federation sender processes.

@@ -1 +0,0 @@
Switch from simplejson to the standard library json.

@@ -1 +0,0 @@
Add type hints to the http server code and remove an unused parameter.

@@ -1 +0,0 @@
Fix detection of out of sync remote device lists when receiving events from remote users.

@@ -1 +0,0 @@
Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain.

@@ -1 +0,0 @@
Add type hints to synapse.api.errors module.

@@ -1 +0,0 @@
Fix a bug causing Synapse to misinterpret the value `off` for `encryption_enabled_by_default_for_room_type` in its configuration file(s) if that value isn't surrounded by quotes. This bug was introduced in v1.16.0.

@@ -1 +0,0 @@
Add the option to validate the `iss` and `aud` claims for JWT logins.

@@ -1 +0,0 @@
Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails.

@@ -1 +0,0 @@
Add support for handling registration requests across multiple client reader workers.

@@ -1 +0,0 @@
Ensure that calls to `json.dumps` are compatible with the standard library json.

@@ -1 +0,0 @@
Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196.

@@ -1 +0,0 @@
Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH.

@@ -1 +0,0 @@
Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`.

@@ -1 +0,0 @@
Allow email subjects to be customised through Synapse's configuration.

@@ -1 +0,0 @@
Add the ability to re-activate an account from the admin API.

@@ -1 +0,0 @@
Remove redundant `retry_on_integrity_error` wrapper for event persistence code.

@@ -1 +0,0 @@
Consistently use `db_to_json` to convert from database values to JSON objects.

@@ -1 +0,0 @@
Fix "AttributeError: 'str' object has no attribute 'get'" error message when applying per-room message retention policies. The bug was introduced in Synapse 1.7.0.

@@ -1 +0,0 @@
Convert E2E keys and room keys handlers to async/await.

@@ -1 +0,0 @@
Add support for handling registration requests across multiple client reader workers.

@@ -1 +0,0 @@
Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation.

@@ -1 +0,0 @@
Add experimental support for running multiple pusher workers.

@@ -1 +0,0 @@
Small performance improvement in typing processing.

@@ -1 +0,0 @@
The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100.

@@ -1 +0,0 @@
Fix a bug which allowed empty rooms to be rejoined over federation.

@@ -1 +0,0 @@
Convert _base, profile, and _receipts handlers to async/await.

@@ -1 +0,0 @@
Optimise queueing of inbound replication commands.

@@ -1 +0,0 @@
Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers.

@@ -1 +0,0 @@
Convert synapse.app and federation client to async/await.

@@ -1 +0,0 @@
Add experimental support for moving typing off master.

@@ -1 +0,0 @@
Add some type annotations to `HomeServer` and `BaseHandler`.

@@ -1 +0,0 @@
Convert device handler to async/await.

@@ -1 +0,0 @@
Fix a long standing bug where the tracing of async functions with opentracing was broken.

@@ -1 +0,0 @@
Convert the federation agent and related code to async/await.

changelog.d/7876.bugfix (new file)

@@ -0,0 +1 @@
Fix an `AssertionError` exception introduced in v1.18.0rc1.

changelog.d/7876.misc (new file)

@@ -0,0 +1 @@
Further optimise queueing of inbound replication commands.

@@ -1 +0,0 @@
Clean up `PreserveLoggingContext`.

@@ -1 +0,0 @@
Remove unused `synapse_replication_tcp_resource_invalidate_cache` prometheus metric.

@@ -1 +0,0 @@
Report CPU metrics to prometheus for time spent processing replication commands.

@@ -1 +0,0 @@
Fix "TypeError in `synapse.notifier`" exceptions.

@@ -1 +0,0 @@
Change "unknown room version" logging from 'error' to 'warning'.

@@ -1 +0,0 @@
Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`.

@@ -1 +0,0 @@
Convert the message handler to async/await.

@@ -1 +0,0 @@
Provide instructions on using `register_new_matrix_user` via docker.

@@ -1 +0,0 @@
Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim.

@@ -1 +0,0 @@
Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation.

@@ -1 +0,0 @@
Fix typo in generated config file. Contributed by @ThiefMaster.

@@ -1 +0,0 @@
Import ABC from `collections.abc` for Python 3.10 compatibility.

@@ -1 +0,0 @@
Fix deprecation warning due to invalid escape sequences.

@@ -1,2 +0,0 @@
Remove unused functions `time_function`, `trace_function`, `get_previous_frames`
and `get_previous_frame` from `synapse.logging.utils` module.

@@ -1 +0,0 @@
Add the ability to re-activate an account from the admin API.

@@ -1 +0,0 @@
Convert `RoomListHandler` to async/await.

@@ -1 +0,0 @@
Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI.

@@ -1 +0,0 @@
Use Element CSS and logo in notification emails when app name is Element.

@@ -1 +0,0 @@
Optimisation to /sync handling: skip serializing the response if the client has already disconnected.

@@ -1 +0,0 @@
When a client disconnects, don't log it as 'Error processing request'.

@@ -1 +0,0 @@
Add debugging to `/sync` response generation (disabled by default).

@@ -1 +0,0 @@
Abort federation requests where the client disconnects before the ratelimiter expires.

@@ -1 +0,0 @@
Cache responses to `/_matrix/federation/v1/state_ids` to reduce duplicated work.

@@ -1 +0,0 @@
Reorder database paragraphs to promote postgres over sqlite.

@@ -1 +0,0 @@
Update the dates of ACME v1's end of life in [`ACME.md`](https://github.com/matrix-org/synapse/blob/master/docs/ACME.md).

@@ -1 +0,0 @@
Convert the auth providers to be async/await.

@@ -1 +0,0 @@
Convert presence handler helpers to async/await.

@@ -36,7 +36,7 @@ try:
except ImportError:
pass
__version__ = "1.17.0"
__version__ = "1.18.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

@@ -127,8 +127,10 @@ class Auth(object):
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
member = yield defer.ensureDeferred(
self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
)
membership = member.membership if member else None
@@ -665,8 +667,10 @@
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = yield self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
visibility = yield defer.ensureDeferred(
self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
)
if (
visibility
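
Both hunks above follow the same bridging pattern used throughout this release: `state.get_current_state` is now an async function, so `inlineCallbacks`-style callers wrap it in `defer.ensureDeferred` before yielding. A minimal sketch of that pattern, with a hypothetical coroutine standing in for Synapse's real API:

```python
from twisted.internet import defer

async def get_current_state(room_id):
    # hypothetical coroutine standing in for StateHandler.get_current_state
    return {"membership": "join"}

@defer.inlineCallbacks
def check_membership(room_id):
    # `yield` in an inlineCallbacks generator expects a Deferred, so the
    # coroutine must first be converted with defer.ensureDeferred
    state = yield defer.ensureDeferred(get_current_state(room_id))
    return state["membership"]
```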

@@ -87,7 +87,6 @@ from synapse.replication.tcp.streams import (
ReceiptsStream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
@@ -644,7 +643,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
super(GenericWorkerReplicationHandler, self).__init__(hs)
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence
self.notifier = hs.get_notifier()
@@ -681,11 +679,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
await self.pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
elif stream_name == TypingStream.NAME:
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:

@@ -106,8 +106,8 @@ class EventBuilder(object):
Deferred[FrozenEvent]
"""
state_ids = yield self._state.get_current_state_ids(
self.room_id, prev_event_ids
state_ids = yield defer.ensureDeferred(
self._state.get_current_state_ids(self.room_id, prev_event_ids)
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)

@@ -348,7 +348,9 @@ class FederationSender(object):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains = yield self.state.get_current_hosts_in_room(room_id)
domains = yield defer.ensureDeferred(
self.state.get_current_hosts_in_room(room_id)
)
domains = [
d
for d in domains

@@ -72,7 +72,7 @@ class AdminHandler(BaseHandler):
writer (ExfiltrationWriter)
Returns:
defer.Deferred: Resolves when all data for a user has been written.
Resolves when all data for a user has been written.
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in

@@ -16,10 +16,11 @@
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json, json
from signedjson.key import decode_verify_key_bytes
from signedjson.key import VerifyKey, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
@@ -265,7 +266,9 @@ class E2eKeysHandler(object):
return ret
async def get_cross_signing_keys_from_cache(self, query, from_user_id):
async def get_cross_signing_keys_from_cache(
self, query, from_user_id
) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database
Args:
@@ -277,8 +280,7 @@
can see.
Returns:
defer.Deferred[dict[str, dict[str, dict]]]: map from
(master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
A map from (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
@@ -312,16 +314,17 @@
}
@trace
async def query_local_devices(self, query):
async def query_local_devices(
self, query: Dict[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
Args:
query (dict[string, list[string]|None): map from user_id to a list
query: map from user_id to a list
of devices to query (None for all devices)
Returns:
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
local_query = []
@@ -1004,7 +1007,7 @@ class E2eKeysHandler(object):
async def _retrieve_cross_signing_keys_for_remote_user(
self, user: UserID, desired_key_type: str,
):
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1015,8 +1018,7 @@
desired_key_type: The type of key to receive. One of "master", "self_signing"
Returns:
Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple
of the retrieved key content, the key's ID and the matching VerifyKey.
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
try:

@@ -1394,7 +1394,7 @@ class FederationHandler(BaseHandler):
# it's just a best-effort thing at this point. We do want to do
# them roughly in order, though, otherwise we'll end up making
# lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it.
# have. Hence we fire off the background task, but don't wait for it.
run_in_background(self._handle_queued_pdus, room_queue)
@@ -1887,9 +1887,6 @@
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
success = False
try:
if (
not event.internal_metadata.is_outlier()
@@ -1903,12 +1900,11 @@
await self.persist_events_and_notify(
[(event, context)], backfilled=backfilled
)
success = True
finally:
if not success:
run_in_background(
self.store.remove_push_actions_from_staging, event.event_id
)
except Exception:
run_in_background(
self.store.remove_push_actions_from_staging, event.event_id
)
raise
return context
@@ -2994,7 +2990,9 @@
else:
user_joined_room(self.distributor, user, room_id)
async def get_room_complexity(self, remote_room_hosts, room_id):
async def get_room_complexity(
self, remote_room_hosts: List[str], room_id: str
) -> Optional[dict]:
"""
Fetch the complexity of a remote room over federation.
@@ -3003,7 +3001,7 @@
room_id (str): The room ID to ask about.
Returns:
Deferred[dict] or Deferred[None]: Dict contains the complexity
Dict contains the complexity
metric versions, while None means we could not fetch the complexity.
"""

@@ -19,6 +19,7 @@
import logging
import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from canonicaljson import json
from signedjson.key import decode_verify_key_bytes
@@ -36,6 +37,7 @@ from synapse.api.errors import (
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, Requester
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -59,23 +61,23 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_http_client()
self.hs = hs
async def threepid_from_creds(self, id_server, creds):
async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Args:
id_server (str): The identity server to validate 3PIDs against. Must be a
id_server: The identity server to validate 3PIDs against. Must be a
complete URL including the protocol (http(s)://)
creds (dict[str, str]): Dictionary containing the following keys:
creds: Dictionary containing the following keys:
* client_secret|clientSecret: A unique secret str provided by the client
* sid: The ID of the validation session
Returns:
Deferred[dict[str,str|int]|None]: A dictionary consisting of response params to
the /getValidated3pid endpoint of the Identity Service API, or None if the
threepid was not found
A dictionary consisting of response params to the /getValidated3pid
endpoint of the Identity Service API, or None if the threepid was not found
"""
client_secret = creds.get("client_secret") or creds.get("clientSecret")
if not client_secret:
@@ -119,26 +121,27 @@
return None
async def bind_threepid(
self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
):
self,
client_secret: str,
sid: str,
mxid: str,
id_server: str,
id_access_token: Optional[str] = None,
use_v2: bool = True,
) -> JsonDict:
"""Bind a 3PID to an identity server
Args:
client_secret (str): A unique secret provided by the client
sid (str): The ID of the validation session
mxid (str): The MXID to bind the 3PID to
id_server (str): The domain of the identity server to query
id_access_token (str): The access token to authenticate to the identity
client_secret: A unique secret provided by the client
sid: The ID of the validation session
mxid: The MXID to bind the 3PID to
id_server: The domain of the identity server to query
id_access_token: The access token to authenticate to the identity
server with, if necessary. Required if use_v2 is true
use_v2 (bool): Whether to use v2 Identity Service API endpoints. Defaults to True
use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
Returns:
Deferred[dict]: The response from the identity server
The response from the identity server
"""
logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
@@ -151,7 +154,7 @@
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
headers["Authorization"] = create_id_access_token_header(id_access_token)
headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
else:
bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
@@ -187,20 +190,20 @@
)
return res
async def try_unbind_threepid(self, mxid, threepid):
async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on
Args:
mxid (str): Matrix user ID of binding to be removed
threepid (dict): Dict with medium & address of binding to be
mxid: Matrix user ID of binding to be removed
threepid: Dict with medium & address of binding to be
removed, and an optional id_server.
Raises:
SynapseError: If we failed to contact the identity server
Returns:
Deferred[bool]: True on success, otherwise False if the identity
True on success, otherwise False if the identity
server doesn't support unbinding (or no identity server found to
contact).
"""
@@ -223,19 +226,21 @@
return changed
async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
async def try_unbind_threepid_with_id_server(
self, mxid: str, threepid: dict, id_server: str
) -> bool:
"""Removes a binding from an identity server
Args:
mxid (str): Matrix user ID of binding to be removed
threepid (dict): Dict with medium & address of binding to be removed
id_server (str): Identity server to unbind from
mxid: Matrix user ID of binding to be removed
threepid: Dict with medium & address of binding to be removed
id_server: Identity server to unbind from
Raises:
SynapseError: If we failed to contact the identity server
Returns:
Deferred[bool]: True on success, otherwise False if the identity
True on success, otherwise False if the identity
server doesn't support unbinding
"""
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
@@ -287,23 +292,23 @@
async def send_threepid_validation(
self,
email_address,
client_secret,
send_attempt,
send_email_func,
next_link=None,
):
email_address: str,
client_secret: str,
send_attempt: int,
send_email_func: Callable[[str, str, str, str], Awaitable],
next_link: Optional[str] = None,
) -> str:
"""Send a threepid validation email for password reset or
registration purposes
Args:
email_address (str): The user's email address
client_secret (str): The provided client secret
send_attempt (int): Which send attempt this is
send_email_func (func): A function that takes an email address, token,
client_secret and session_id, sends an email
and returns a Deferred.
next_link (str|None): The URL to redirect the user to after validation
email_address: The user's email address
client_secret: The provided client secret
send_attempt: Which send attempt this is
send_email_func: A function that takes an email address, token,
client_secret and session_id, sends an email
and returns an Awaitable.
next_link: The URL to redirect the user to after validation
Returns:
The new session_id upon success
@@ -372,17 +377,22 @@
return session_id
async def requestEmailToken(
self, id_server, email, client_secret, send_attempt, next_link=None
):
self,
id_server: str,
email: str,
client_secret: str,
send_attempt: int,
next_link: Optional[str] = None,
) -> JsonDict:
"""
Request an external server send an email on our behalf for the purposes of threepid
validation.
Args:
id_server (str): The identity server to proxy to
email (str): The email to send the message to
client_secret (str): The unique client_secret sends by the user
send_attempt (int): Which attempt this is
id_server: The identity server to proxy to
email: The email to send the message to
client_secret: The unique client_secret sent by the user
send_attempt: Which attempt this is
next_link: A link to redirect the user to once they submit the token
Returns:
@@ -419,22 +429,22 @@
async def requestMsisdnToken(
self,
id_server,
country,
phone_number,
client_secret,
send_attempt,
next_link=None,
):
id_server: str,
country: str,
phone_number: str,
client_secret: str,
send_attempt: int,
next_link: Optional[str] = None,
) -> JsonDict:
"""
Request an external server send an SMS message on our behalf for the purposes of
threepid validation.
Args:
id_server (str): The identity server to proxy to
country (str): The country code of the phone number
phone_number (str): The number to send the message to
client_secret (str): The unique client_secret sends by the user
send_attempt (int): Which attempt this is
id_server: The identity server to proxy to
country: The country code of the phone number
phone_number: The number to send the message to
client_secret: The unique client_secret sent by the user
send_attempt: Which attempt this is
next_link: A link to redirect the user to once they submit the token
Returns:
@@ -480,17 +490,18 @@
)
return data
async def validate_threepid_session(self, client_secret, sid):
async def validate_threepid_session(
self, client_secret: str, sid: str
) -> Optional[JsonDict]:
"""Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally.
Args:
client_secret (str): A secret provided by the client
sid (str): The ID of the session
client_secret: A secret provided by the client
sid: The ID of the session
Returns:
Dict[str, str|int] if validation was successful, otherwise None
The json response if validation was successful, otherwise None
"""
# XXX: We shouldn't need to keep wrapping and unwrapping this value
threepid_creds = {"client_secret": client_secret, "sid": sid}
@@ -523,23 +534,22 @@
return validation_session
async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str
) -> JsonDict:
"""Proxy a POST submitToken request to an identity server for verification purposes
Args:
id_server (str): The identity server URL to contact
client_secret (str): Secret provided by the client
sid (str): The ID of the session
token (str): The verification token
id_server: The identity server URL to contact
client_secret: Secret provided by the client
sid: The ID of the session
token: The verification token
Raises:
SynapseError: If we failed to contact the identity server
Returns:
Deferred[dict]: The response dict from the identity server
The response dict from the identity server
"""
body = {"client_secret": client_secret, "sid": sid, "token": token}
@@ -554,19 +564,25 @@
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
async def lookup_3pid(
self,
id_server: str,
medium: str,
address: str,
id_access_token: Optional[str] = None,
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
id_server: The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
id_access_token (str|None): The access token to authenticate to the identity
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
id_access_token: The access token to authenticate to the identity
server with
Returns:
str|None: the matrix ID of the 3pid, or None if it is not recognized.
the matrix ID of the 3pid, or None if it is not recognized.
"""
if id_access_token is not None:
try:
@@ -591,17 +607,19 @@
return await self._lookup_3pid_v1(id_server, medium, address)
async def _lookup_3pid_v1(self, id_server, medium, address):
async def _lookup_3pid_v1(
self, id_server: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server (str): The server name (including port, if required)
id_server: The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = await self.blacklisting_http_client.get_json(
@@ -621,18 +639,20 @@
return None
async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
async def _lookup_3pid_v2(
self, id_server: str, id_access_token: str, medium: str, address: str
) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
id_server (str): The server name (including port, if required)
id_server: The server name (including port, if required)
of the identity server to use.
id_access_token (str): The access token to authenticate to the identity server with
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
id_access_token: The access token to authenticate to the identity server with
medium: The type of the third party identifier (e.g. "email").
address: The third party identifier (e.g. "foo@example.com").
Returns:
Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
the matrix ID of the 3pid, or None if it is not recognised.
"""
# Check what hashing details are supported by this identity server
try:
@@ -757,49 +777,48 @@
async def ask_id_server_for_third_party_invite(
self,
requester,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url,
id_access_token=None,
):
requester: Requester,
id_server: str,
medium: str,
address: str,
room_id: str,
inviter_user_id: str,
room_alias: str,
room_avatar_url: str,
room_join_rules: str,
room_name: str,
inviter_display_name: str,
inviter_avatar_url: str,
id_access_token: Optional[str] = None,
) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]:
"""
Asks an identity server for a third party invite.
Args:
requester (Requester)
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
requester
id_server: hostname + optional port for the identity server.
medium: The literal string "email".
address: The third party address being invited.
room_id: The ID of the room to which the user is invited.
inviter_user_id: The user ID of the inviter.
room_alias: An alias for the room, for cosmetic notifications.
room_avatar_url: The URL of the room's avatar, for cosmetic
notifications.
room_join_rules (str): The join rules of the email (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
room_join_rules: The join rules of the room (e.g. "public").
room_name: The m.room.name of the room.
inviter_display_name: The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
inviter_avatar_url: The URL of the inviter's avatar.
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
A tuple containing:
token: The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
display_name: A user-friendly name to represent the invited user.
"""
invite_config = {
"medium": medium,
@@ -896,15 +915,15 @@
return token, public_keys, fallback_public_key, display_name
def create_id_access_token_header(id_access_token):
def create_id_access_token_header(id_access_token: str) -> List[str]:
"""Create an Authorization header for passing to SimpleHttpClient as the header value
of an HTTP request.
Args:
id_access_token (str): An identity server access token.
id_access_token: An identity server access token.
Returns:
list[str]: The ascii-encoded bearer token encased in a list.
The ascii-encoded bearer token encased in a list.
"""
# Prefix with Bearer
bearer_token = "Bearer %s" % id_access_token

@@ -859,9 +859,6 @@ class EventCreationHandler(object):
await self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
success = False
try:
# If we're a worker we need to hit out to the master.
if not self._is_event_writer:
@@ -877,22 +874,20 @@
)
stream_id = result["stream_id"]
event.internal_metadata.stream_ordering = stream_id
success = True
return stream_id
stream_id = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
success = True
return stream_id
finally:
if not success:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
run_in_background(
self.store.remove_push_actions_from_staging, event.event_id
)
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
run_in_background(
self.store.remove_push_actions_from_staging, event.event_id
)
raise
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
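
The comment deleted above explains the old shape: under `inlineCallbacks`, re-raising an exception lost the original stack trace, so the code tracked a `success` flag and cleaned up in a `finally` block. With async/await, a plain `try`/`except` that re-raises behaves correctly. A sketch of the simplified shape, with hypothetical `persist` and `cleanup` stand-ins for Synapse's helpers:

```python
import asyncio

async def persist(event_id: str) -> int:
    return 1  # hypothetical stand-in for persisting the event

async def cleanup(event_id: str) -> None:
    pass  # hypothetical stand-in for remove_push_actions_from_staging

def run_in_background(f, *args) -> None:
    asyncio.ensure_future(f(*args))  # stand-in for Synapse's helper

async def handle_new_client_event(event_id: str) -> int:
    try:
        return await persist(event_id)
    except Exception:
        # fire off the cleanup without awaiting it, then propagate the
        # original exception with its full traceback intact
        run_in_background(cleanup, event_id)
        raise
```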

@@ -928,8 +928,8 @@ class PresenceHandler(BasePresenceHandler):
# TODO: Check that this is actually a new server joining the
# room.
user_ids = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
users = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, users))
states_d = await self.current_state_for_users(user_ids)

@@ -119,7 +119,7 @@ class RoomCreationHandler(BaseHandler):
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
) -> str:
"""Replace a room with a new room with a different version
Args:
@@ -128,7 +128,7 @@
new_version: the new room version to use
Returns:
Deferred[unicode]: the new room id
the new room id
"""
await self.ratelimit(requester)
@@ -239,7 +239,7 @@
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
):
) -> None:
"""Send updated power levels in both rooms after an upgrade
Args:
@@ -247,9 +247,6 @@
old_room_id: the id of the room to be replaced
new_room_id: the id of the replacement room
old_room_state: the state map for the old room
Returns:
Deferred
"""
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
@@ -322,7 +319,7 @@
new_room_id: str,
new_room_version: RoomVersion,
tombstone_event_id: str,
):
) -> None:
"""Populate a new room based on an old room
Args:
@@ -332,8 +329,6 @@
created with _generate_room_id())
new_room_version: the new room version to use
tombstone_event_id: the ID of the tombstone event in the old room.
Returns:
Deferred
"""
user_id = requester.user.to_string()

@@ -15,6 +15,7 @@
import itertools
import logging
from typing import Iterable
from unpaddedbase64 import decode_base64, encode_base64
@@ -37,7 +38,7 @@ class SearchHandler(BaseHandler):
self.state_store = self.storage.state
self.auth = hs.get_auth()
async def get_old_rooms_from_upgraded_room(self, room_id):
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
@@ -48,10 +49,10 @@
The full list of all found rooms is then returned.
Args:
room_id (str): id of the room to search through.
room_id: id of the room to search through.
Returns:
Deferred[iterable[str]]: predecessor room ids
Predecessor room ids
"""
historical_room_ids = []

@@ -424,10 +424,6 @@ class SyncHandler(object):
potential_recents: Optional[List[EventBase]] = None,
newly_joined_room: bool = False,
) -> TimelineBatch:
"""
Returns:
a Deferred TimelineBatch
"""
with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit()
block_all_timeline = (

@@ -442,21 +442,6 @@ class StaticResource(File):
return super().render_GET(request)
def _options_handler(request):
"""Request handler for OPTIONS requests
This is a request handler suitable for return from
_get_handler_for_request. It returns a 200 and an empty body.
Args:
request (twisted.web.http.Request):
Returns:
Tuple[int, dict]: http code, response body.
"""
return 200, {}
def _unrecognised_request_handler(request):
"""Request handler for unrecognised requests
@@ -490,11 +475,12 @@
"""Responds to OPTION requests for itself and all children."""
def render_OPTIONS(self, request):
code, response_json_object = _options_handler(request)
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")
return respond_with_json(
request, code, response_json_object, send_cors=True, canonical_json=False,
)
set_cors_headers(request)
return b""
def getChildWithDefault(self, path, request):
if request.method == b"OPTIONS":

@@ -737,24 +737,14 @@ def trace(func=None, opname=None):
@wraps(func)
async def _trace_inner(*args, **kwargs):
if opentracing is None:
with start_active_span(_opname):
return await func(*args, **kwargs)
with start_active_span(_opname) as scope:
try:
return await func(*args, **kwargs)
except Exception:
scope.span.set_tag(tags.ERROR, True)
raise
else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
def _trace_inner(*args, **kwargs):
if opentracing is None:
return func(*args, **kwargs)
scope = start_active_span(_opname)
scope.__enter__()
@@ -767,7 +757,6 @@
return result
def err_back(result):
scope.span.set_tag(tags.ERROR, True)
scope.__exit__(None, None, None)
return result
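
The rewrite gives `trace` a dedicated branch for coroutine functions: the span is held open by a `with` block around the `await`, instead of being entered and exited around creation of the coroutine object. A hedged usage sketch (the decorated function is invented for illustration):

```python
from synapse.logging.opentracing import trace

@trace
async def fetch_profile(user_id: str) -> dict:
    # with the async-aware decorator, everything awaited in here runs
    # inside the active span; previously the span closed as soon as the
    # coroutine object was created, breaking the trace
    return {"user_id": user_id}
```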

@@ -116,6 +116,8 @@ class _LogContextScope(Scope):
if self._enter_logcontext:
self.logcontext.__enter__()
return self
def __exit__(self, type, value, traceback):
if type == twisted.internet.defer._DefGen_Return:
super(_LogContextScope, self).__exit__(None, None, None)

@@ -304,7 +304,9 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(
context.get_current_state_ids()
)
push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids))

@@ -24,6 +24,7 @@ from twisted.internet.protocol import ReconnectingClientFactory
from synapse.api.constants import EventTypes
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams import TypingStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
@@ -104,6 +105,7 @@ class ReplicationDataHandler:
self._clock = hs.get_clock()
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler()
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
@@ -127,6 +129,12 @@
"""
self.store.process_replication_rows(stream_name, instance_name, token, rows)
if stream_name == TypingStream.NAME:
self._typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows]
)
if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.

@@ -16,6 +16,7 @@
import logging
from typing import (
Any,
Awaitable,
Dict,
Iterable,
Iterator,
@@ -33,6 +34,7 @@ from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
@@ -152,7 +154,7 @@ class ReplicationCommandHandler:
# When POSITION or RDATA commands arrive, we stick them in a queue and process
# them in order in a separate background process.
# the streams which are currently being processed by _unsafe_process_stream
# the streams which are currently being processed by _unsafe_process_queue
self._processing_streams = set() # type: Set[str]
# for each stream, a queue of commands that are awaiting processing, and the
@@ -185,7 +187,7 @@
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
async def _add_command_to_stream_queue(
def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@@ -199,33 +201,34 @@
logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
return
# if we're already processing this stream, stick the new command in the
# queue, and we're done.
queue.append((cmd, conn))
# if we're already processing this stream, there's nothing more to do:
# the new entry on the queue will get picked up in due course
if stream_name in self._processing_streams:
queue.append((cmd, conn))
return
# otherwise, process the new command.
# fire off a background process to start processing the queue.
run_as_background_process(
"process-replication-data", self._unsafe_process_queue, stream_name
)
# arguably we should start off a new background process here, but nothing
# will be too upset if we don't return for ages, so let's save the overhead
# and use the existing logcontext.
async def _unsafe_process_queue(self, stream_name: str):
"""Processes the command queue for the given stream, until it is empty
Does not check if there is already a thread processing the queue, hence "unsafe"
"""
assert stream_name not in self._processing_streams
self._processing_streams.add(stream_name)
try:
# might as well skip the queue for this one, since it must be empty
assert not queue
await self._process_command(cmd, conn, stream_name)
# now process any other commands that have built up while we were
# dealing with that one.
queue = self._command_queues_by_stream.get(stream_name)
while queue:
cmd, conn = queue.popleft()
try:
await self._process_command(cmd, conn, stream_name)
except Exception:
logger.exception("Failed to handle command %s", cmd)
finally:
self._processing_streams.discard(stream_name)
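
In outline: each incoming command is appended to a per-stream FIFO, and a background processor is started only when no drainer is already running for that stream; `_unsafe_process_queue` then works through the queue until it is empty. A minimal sketch of the same pattern, using `asyncio` in place of `run_as_background_process` and invented names throughout:

```python
import asyncio
from collections import deque
from typing import Deque, Dict, Set

class CommandQueue:
    def __init__(self) -> None:
        self._queues: Dict[str, Deque[str]] = {}
        self._processing: Set[str] = set()

    def add_command(self, stream_name: str, cmd: str) -> None:
        self._queues.setdefault(stream_name, deque()).append(cmd)
        if stream_name in self._processing:
            # an existing drainer will pick the new entry up in due course
            return
        self._processing.add(stream_name)
        asyncio.ensure_future(self._drain(stream_name))

    async def _drain(self, stream_name: str) -> None:
        try:
            queue = self._queues[stream_name]
            while queue:
                cmd = queue.popleft()
                print("processing", stream_name, cmd)  # stand-in for real work
                await asyncio.sleep(0)
        finally:
            self._processing.discard(stream_name)
```
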
@@ -299,7 +302,7 @@
"""
return self._streams_to_replicate
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection):
@@ -318,57 +321,73 @@
)
)
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
def on_USER_SYNC(
self, conn: AbstractConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
if self._is_master:
await self._presence_handler.update_external_syncs_row(
return self._presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
else:
return None
async def on_CLEAR_USER_SYNC(
def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
):
) -> Optional[Awaitable[None]]:
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
async def on_FEDERATION_ACK(
self, conn: AbstractConnection, cmd: FederationAckCommand
):
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
async def on_REMOVE_PUSHER(
def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
):
) -> Optional[Awaitable[None]]:
remove_pusher_counter.inc()
if self._is_master:
await self._store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)
return self._handle_remove_pusher(cmd)
else:
return None
self._notifier.on_new_replication_data()
async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
await self._store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
self._notifier.on_new_replication_data()
def on_USER_IP(
self, conn: AbstractConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
if self._is_master:
await self._store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
return self._handle_user_ip(cmd)
else:
return None
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
async def _handle_user_ip(self, cmd: UserIpCommand):
await self._store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -382,7 +401,7 @@
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
await self._add_command_to_stream_queue(conn, cmd)
self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
@@ -459,14 +478,14 @@
stream_name, instance_name, token, rows
)
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
await self._add_command_to_stream_queue(conn, cmd)
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
@@ -526,9 +545,7 @@
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
):
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)

@@ -50,6 +50,7 @@ import abc
import fcntl
import logging
import struct
from inspect import isawaitable
from typing import TYPE_CHECKING, List
from prometheus_client import Counter
@@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
`ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
if so, that will get run as a background process.
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
-        self._logging_context = BackgroundProcessLoggingContext(
-            "replication_command_handler-%s" % self.conn_id
-        )
+        ctx_name = "replication-conn-%s" % self.conn_id
+        self._logging_context = BackgroundProcessLoggingContext(ctx_name)
+        self._logging_context.request = ctx_name
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
# Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)

-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
         First calls `self.on_<COMMAND>` if it exists, then calls
-        `self.command_handler.on_<COMMAND>` if it exists. This allows for
-        protocol level handling of commands (e.g. PINGs), before delegating to
-        the handler.
+        `self.command_handler.on_<COMMAND>` if it exists (which can optionally
+        return an Awaitable).
+
+        This allows for protocol level handling of commands (e.g. PINGs), before
+        delegating to the handler.
Args:
cmd: received command
@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
-            await cmd_func(cmd)
+            cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
-            await cmd_func(self, cmd)
+            res = cmd_func(self, cmd)
+
+            # the handler might be a coroutine: fire it off as a background process
+            # if so.
+            if isawaitable(res):
+                run_as_background_process(
+                    "replication-" + cmd.get_logcontext_id(), lambda: res
+                )
handled = True
if not handled:
@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
-    async def on_PING(self, line):
+    def on_PING(self, line):
         self.received_ping = True

-    async def on_ERROR(self, cmd):
+    def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ServerCommand(self.server_name))
super().connectionMade()
-    async def on_NAME(self, cmd):
+    def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Once we've connected subscribe to the necessary streams
self.replicate()
-    async def on_SERVER(self, cmd):
+    def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from inspect import isawaitable
from typing import TYPE_CHECKING
import txredisapi
@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances.
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
# Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)

-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
         """Handle a command we have received over the replication stream.

-        By default delegates to on_<COMMAND>, which should return an awaitable.
+        Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
+        Awaitable).
Args:
cmd: received command
"""
-        handled = False
-
-        # First call any command handlers on this instance. These are for redis
-        # specific handling.
-        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(cmd)
-            handled = True
-
-        # Then call out to the handler.
         cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(self, cmd)
-            handled = True
-
-        if not handled:
+        if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
+            return
+
+        res = cmd_func(self, cmd)
+
+        # the handler might be a coroutine: fire it off as a background process
+        # if so.
+        if isawaitable(res):
+            run_as_background_process(
+                "replication-" + cmd.get_logcontext_id(), lambda: res
+            )
def connectionLost(self, reason):
logger.info("Lost connection to redis")

View file

@ -17,8 +17,7 @@
"""
import logging
import re
-from twisted.internet import defer
+from typing import Iterable, Pattern
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
@ -27,15 +26,23 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
-def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
+def client_patterns(
+    path_regex: str,
+    releases: Iterable[int] = (0,),
+    unstable: bool = True,
+    v1: bool = False,
+) -> Iterable[Pattern]:
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
-        path_regex (str): The regex string to match. This should NOT have a ^
+        path_regex: The regex string to match. This should NOT have a ^
         as this will be prefixed.
+        releases: An iterable of releases to include this endpoint under.
+        unstable: If true, include this endpoint under the "unstable" prefix.
+        v1: If true, include this endpoint under the "api/v1" prefix.

     Returns:
-        SRE_Pattern
+        An iterable of patterns.
"""
patterns = []
@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
-    Takes a on_POST method which returns a deferred (errcode, body) response
+    Takes a on_POST method which returns an Awaitable (errcode, body) response
     and adds exception handling to turn a InteractiveAuthIncompleteError into
     a 401 response.

     Normal usage is:

     @interactive_auth_handler
-    @defer.inlineCallbacks
-    def on_POST(self, request):
+    async def on_POST(self, request):
         # ...
-        yield self.auth_handler.check_auth
-    """
+        await self.auth_handler.check_auth
+    """

-    def wrapped(*args, **kwargs):
-        res = defer.ensureDeferred(orig(*args, **kwargs))
-        res.addErrback(_catch_incomplete_interactive_auth)
-        return res
+    async def wrapped(*args, **kwargs):
+        try:
+            return await orig(*args, **kwargs)
+        except InteractiveAuthIncompleteError as e:
+            return 401, e.result

     return wrapped

-
-def _catch_incomplete_interactive_auth(f):
-    """helper for interactive_auth_handler
-
-    Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
-
-    Args:
-        f (failure.Failure):
-    """
-    f.trap(InteractiveAuthIncompleteError)
-    return 401, f.value.result
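The before/after shapes of this decorator are worth seeing side by side: the errback chain becomes a plain try/except around the awaited handler. A self-contained sketch (the error class here is a toy stand-in for the real one):

    import asyncio

    class InteractiveAuthIncompleteError(Exception):
        # toy stand-in for synapse.api.errors.InteractiveAuthIncompleteError
        def __init__(self, result):
            self.result = result

    def interactive_auth_handler(orig):
        async def wrapped(*args, **kwargs):
            try:
                return await orig(*args, **kwargs)
            except InteractiveAuthIncompleteError as e:
                # incomplete auth turns into a 401 plus the partial result
                return 401, e.result
        return wrapped

    @interactive_auth_handler
    async def on_POST(request):
        raise InteractiveAuthIncompleteError({"flows": [{"stages": ["m.login.password"]}]})

    print(asyncio.run(on_POST(None)))  # (401, {'flows': [{'stages': ['m.login.password']}]})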

View file

@ -18,7 +18,6 @@ import logging
import os
import urllib
-from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@ -77,8 +76,9 @@ def respond_404(request):
)
-@defer.inlineCallbacks
-def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
+async def respond_with_file(
+    request, media_type, file_path, file_size=None, upload_name=None
+):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@ -89,7 +89,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
-            yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+            await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request)
else:
@ -198,8 +198,9 @@ def _can_encode_filename_as_token(x):
return True
-@defer.inlineCallbacks
-def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+async def respond_with_responder(
+    request, responder, media_type, file_size, upload_name=None
+):
"""Responds to the request with given responder. If responder is None then
returns 404.
@ -218,7 +219,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:
-            yield responder.write_to_consumer(request)
+            await responder.write_to_consumer(request)
except Exception as e:
# The majority of the time this will be due to the client having gone
# away. Unfortunately, Twisted simply throws a generic exception at us

View file

@ -14,17 +14,18 @@
# limitations under the License.
 import contextlib
+import inspect
 import logging
 import os
 import shutil
+from typing import Optional

-from twisted.internet import defer
 from twisted.protocols.basic import FileSender

 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
 from synapse.util.file_consumer import BackgroundFileConsumer

-from ._base import Responder
+from ._base import FileInfo, Responder
logger = logging.getLogger(__name__)
@ -46,25 +47,24 @@ class MediaStorage(object):
self.filepaths = filepaths
self.storage_providers = storage_providers
-    @defer.inlineCallbacks
-    def store_file(self, source, file_info):
+    async def store_file(self, source, file_info: FileInfo) -> str:
         """Write `source` to the on disk media store, and also any other
         configured storage providers

         Args:
             source: A file like object that should be written
-            file_info (FileInfo): Info about the file to store
+            file_info: Info about the file to store

         Returns:
-            Deferred[str]: the file path written to in the primary media store
+            the file path written to in the primary media store
         """

         with self.store_into_file(file_info) as (f, fname, finish_cb):
             # Write to the main repository
-            yield defer_to_thread(
+            await defer_to_thread(
                 self.hs.get_reactor(), _write_file_synchronously, source, f
             )
-            yield finish_cb()
+            await finish_cb()

         return fname
@ -75,7 +75,7 @@ class MediaStorage(object):
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
-        on disk, and finish_cb is a function that returns a Deferred.
+        on disk, and finish_cb is a function that returns an awaitable.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
@ -91,7 +91,7 @@ class MediaStorage(object):
with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ...
-                yield finish_cb()
+                await finish_cb()
"""
path = self._file_info_to_path(file_info)
@ -103,10 +103,13 @@ class MediaStorage(object):
finished_called = [False]
-        @defer.inlineCallbacks
-        def finish():
-            for provider in self.storage_providers:
-                yield provider.store_file(path, file_info)
+        async def finish():
+            for provider in self.storage_providers:
+                # store_file is supposed to return an Awaitable, but guard
+                # against improper implementations.
+                result = provider.store_file(path, file_info)
+                if inspect.isawaitable(result):
+                    await result

             finished_called[0] = True
@ -123,17 +126,15 @@ class MediaStorage(object):
if not finished_called:
raise Exception("Finished callback not called")
-    @defer.inlineCallbacks
-    def fetch_media(self, file_info):
+    async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
         """Attempts to fetch media described by file_info from the local cache
         and configured storage providers.

         Args:
-            file_info (FileInfo)
+            file_info

         Returns:
-            Deferred[Responder|None]: Returns a Responder if the file was found,
-            otherwise None.
+            Returns a Responder if the file was found, otherwise None.
         """
path = self._file_info_to_path(file_info)
@ -142,23 +143,26 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
-            res = yield provider.fetch(path, file_info)
+            res = provider.fetch(path, file_info)
+            # Fetch is supposed to return an Awaitable, but guard against
+            # improper implementations.
+            if inspect.isawaitable(res):
+                res = await res
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
return None
-    @defer.inlineCallbacks
-    def ensure_media_is_in_local_cache(self, file_info):
+    async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
         """Ensures that the given file is in the local cache. Attempts to
         download it from storage providers if it isn't.

         Args:
-            file_info (FileInfo)
+            file_info

         Returns:
-            Deferred[str]: Full path to local file
+            Full path to local file
         """
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
@ -170,14 +174,18 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
-            res = yield provider.fetch(path, file_info)
+            res = provider.fetch(path, file_info)
+            # Fetch is supposed to return an Awaitable, but guard against
+            # improper implementations.
+            if inspect.isawaitable(res):
+                res = await res
             if res:
                 with res:
                     consumer = BackgroundFileConsumer(
                         open(local_path, "wb"), self.hs.get_reactor()
                     )
-                    yield res.write_to_consumer(consumer)
-                    yield consumer.wait()
+                    await res.write_to_consumer(consumer)
+                    await consumer.wait()
                 return local_path
return local_path
raise Exception("file could not be found")

View file

@ -26,6 +26,7 @@ import traceback
from typing import Dict, Optional
from urllib import parse as urlparse
import attr
from canonicaljson import json
from twisted.internet import defer
@ -56,6 +57,65 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
ONE_HOUR = 60 * 60 * 1000
# A map of globs to API endpoints.
_oembed_globs = {
# Twitter.
"https://publish.twitter.com/oembed": [
"https://twitter.com/*/status/*",
"https://*.twitter.com/*/status/*",
"https://twitter.com/*/moments/*",
"https://*.twitter.com/*/moments/*",
# Include the HTTP versions too.
"http://twitter.com/*/status/*",
"http://*.twitter.com/*/status/*",
"http://twitter.com/*/moments/*",
"http://*.twitter.com/*/moments/*",
],
}
# Convert the globs to regular expressions.
_oembed_patterns = {}
for endpoint, globs in _oembed_globs.items():
for glob in globs:
# Convert the glob into a sane regular expression to match against. The
# rules followed will be slightly different for the domain portion vs.
# the rest.
#
# 1. The scheme must be one of HTTP / HTTPS (and have no globs).
# 2. The domain can have globs, but we limit it to characters that can
# reasonably be a domain part.
# TODO: This does not attempt to handle Unicode domain names.
# 3. Other parts allow a glob to be any one, or more, characters.
results = urlparse.urlparse(glob)
# Ensure the scheme does not have wildcards (and is a sane scheme).
if results.scheme not in {"http", "https"}:
raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
pattern = urlparse.urlunparse(
[
results.scheme,
re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
]
+ [re.escape(part).replace("\\*", ".+") for part in results[2:]]
)
_oembed_patterns[re.compile(pattern)] = endpoint
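To make the glob-to-regex conversion concrete, here is the same transformation applied by hand to one Twitter glob (runnable standalone; the exact escaping of `/` varies slightly by Python version):

    import re
    from urllib import parse as urlparse

    glob = "https://twitter.com/*/status/*"
    results = urlparse.urlparse(glob)
    pattern = urlparse.urlunparse(
        [
            results.scheme,
            re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
        ]
        + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
    )
    print(pattern)  # e.g. https://twitter\.com/.+/status/.+
    compiled = re.compile(pattern)
    print(bool(compiled.fullmatch("https://twitter.com/user/status/1234")))  # True
    print(bool(compiled.fullmatch("https://evil.example.com/x/status/1")))   # False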
@attr.s
class OEmbedResult:
# Either HTML content or URL must be provided.
html = attr.ib(type=Optional[str])
url = attr.ib(type=Optional[str])
title = attr.ib(type=Optional[str])
# Number of seconds to cache the content.
cache_age = attr.ib(type=int)
class OEmbedError(Exception):
"""An error occurred processing the oEmbed object."""
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
@ -99,7 +159,7 @@ class PreviewUrlResource(DirectServeJsonResource):
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
-            expiry_ms=60 * 60 * 1000,
+            expiry_ms=ONE_HOUR,
)
if self._worker_run_media_background_jobs:
@ -310,6 +370,87 @@ class PreviewUrlResource(DirectServeJsonResource):
return jsonog.encode("utf8")
def _get_oembed_url(self, url: str) -> Optional[str]:
"""
Check whether the URL should be downloaded as oEmbed content instead.
Params:
url: The URL to check.
Returns:
A URL to use instead or None if the original URL should be used.
"""
for url_pattern, endpoint in _oembed_patterns.items():
if url_pattern.fullmatch(url):
return endpoint
# No match.
return None
async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
"""
Request content from an oEmbed endpoint.
Params:
endpoint: The oEmbed API endpoint.
url: The URL to pass to the API.
Returns:
An object representing the metadata returned.
Raises:
OEmbedError if fetching or parsing of the oEmbed information fails.
"""
try:
logger.debug("Trying to get oEmbed content for url '%s'", url)
result = await self.client.get_json(
endpoint,
# TODO Specify max height / width.
# Note that only the JSON format is supported.
args={"url": url},
)
# Ensure there's a version of 1.0.
if result.get("version") != "1.0":
raise OEmbedError("Invalid version: %s" % (result.get("version"),))
oembed_type = result.get("type")
# Ensure the cache age is None or an int.
cache_age = result.get("cache_age")
if cache_age:
cache_age = int(cache_age)
oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
# HTML content.
if oembed_type == "rich":
oembed_result.html = result.get("html")
return oembed_result
if oembed_type == "photo":
oembed_result.url = result.get("url")
return oembed_result
# TODO Handle link and video types.
if "thumbnail_url" in result:
oembed_result.url = result.get("thumbnail_url")
return oembed_result
raise OEmbedError("Incompatible oEmbed information.")
except OEmbedError as e:
# Trap OEmbedErrors first so we can directly re-raise them.
logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
raise
except Exception as e:
# Trap any exception and let the code follow as usual.
# FIXME: pass through 404s and other error messages nicely
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
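For reference, the shape of oEmbed document this method consumes looks roughly like the following (illustrative values, not a captured API response); note that `cache_age` may arrive as a string, which is why it is coerced with `int()`:

    # A trimmed "rich"-type oEmbed document of the kind handled above:
    result = {
        "version": "1.0",
        "type": "rich",
        "title": "Example post",
        "cache_age": "3600",  # the spec allows a string here
        "html": "<blockquote>...</blockquote>",
    }

    assert result.get("version") == "1.0"

    cache_age = result.get("cache_age")
    if cache_age:
        cache_age = int(cache_age)

    # type == "rich" means the html field becomes the stored text/html preview;
    # type == "photo" would instead supply a url field to download.
    print(result["type"], cache_age)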
async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
@ -319,54 +460,90 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
-        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        # If this URL can be accessed via oEmbed, use that instead.
+        url_to_download = url
+        oembed_url = self._get_oembed_url(url)
+        if oembed_url:
+            # The result might be a new URL to download, or it might be HTML content.
             try:
-                logger.debug("Trying to get preview for url '%s'", url)
-                length, headers, uri, code = await self.client.get_file(
-                    url,
-                    output_stream=f,
-                    max_size=self.max_spider_size,
-                    headers={"Accept-Language": self.url_preview_accept_language},
-                )
-            except SynapseError:
-                # Pass SynapseErrors through directly, so that the servlet
-                # handler will return a SynapseError to the client instead of
-                # blank data or a 500.
-                raise
-            except DNSLookupError:
-                # DNS lookup returned no results
-                # Note: This will also be the case if one of the resolved IP
-                # addresses is blacklisted
-                raise SynapseError(
-                    502,
-                    "DNS resolution failure during URL preview generation",
-                    Codes.UNKNOWN,
-                )
-            except Exception as e:
-                # FIXME: pass through 404s and other error messages nicely
-                logger.warning("Error downloading %s: %r", url, e)
+                oembed_result = await self._get_oembed_content(oembed_url, url)
+                if oembed_result.url:
+                    url_to_download = oembed_result.url
+                elif oembed_result.html:
+                    url_to_download = None
+            except OEmbedError:
+                # If an error occurs, try doing a normal preview.
+                pass

-                raise SynapseError(
-                    500,
-                    "Failed to download content: %s"
-                    % (traceback.format_exception_only(sys.exc_info()[0], e),),
-                    Codes.UNKNOWN,
-                )
-            await finish()
+        if url_to_download:
+            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+                try:
+                    logger.debug("Trying to get preview for url '%s'", url_to_download)
+                    length, headers, uri, code = await self.client.get_file(
+                        url_to_download,
+                        output_stream=f,
+                        max_size=self.max_spider_size,
+                        headers={"Accept-Language": self.url_preview_accept_language},
+                    )
+                except SynapseError:
+                    # Pass SynapseErrors through directly, so that the servlet
+                    # handler will return a SynapseError to the client instead of
+                    # blank data or a 500.
+                    raise
+                except DNSLookupError:
+                    # DNS lookup returned no results
+                    # Note: This will also be the case if one of the resolved IP
+                    # addresses is blacklisted
+                    raise SynapseError(
+                        502,
+                        "DNS resolution failure during URL preview generation",
+                        Codes.UNKNOWN,
+                    )
+                except Exception as e:
+                    # FIXME: pass through 404s and other error messages nicely
+                    logger.warning("Error downloading %s: %r", url_to_download, e)
+
+                    raise SynapseError(
+                        500,
+                        "Failed to download content: %s"
+                        % (traceback.format_exception_only(sys.exc_info()[0], e),),
+                        Codes.UNKNOWN,
+                    )
+                await finish()
+
+            if b"Content-Type" in headers:
+                media_type = headers[b"Content-Type"][0].decode("ascii")
+            else:
+                media_type = "application/octet-stream"
+
+            download_name = get_filename_from_headers(headers)
+
+            # FIXME: we should calculate a proper expiration based on the
+            # Cache-Control and Expire headers. But for now, assume 1 hour.
+            expires = ONE_HOUR
+            etag = headers["ETag"][0] if "ETag" in headers else None
+        else:
+            html_bytes = oembed_result.html.encode("utf-8")  # type: ignore
+            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+                f.write(html_bytes)
+                await finish()
+
+            media_type = "text/html"
+            download_name = oembed_result.title
+            length = len(html_bytes)
+            # If a specific cache age was not given, assume 1 hour.
+            expires = oembed_result.cache_age or ONE_HOUR
+            uri = oembed_url
+            code = 200
+            etag = None

         try:
-            if b"Content-Type" in headers:
-                media_type = headers[b"Content-Type"][0].decode("ascii")
-            else:
-                media_type = "application/octet-stream"
+            time_now_ms = self.clock.time_msec()

-            download_name = get_filename_from_headers(headers)
-
             await self.store.store_local_media(
                 media_id=file_id,
                 media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
+                time_now_ms=time_now_ms,
upload_name=download_name,
media_length=length,
user_id=user,
@ -389,10 +566,8 @@ class PreviewUrlResource(DirectServeJsonResource):
"filename": fname,
"uri": uri,
"response_code": code,
-            # FIXME: we should calculate a proper expiration based on the
-            # Cache-Control and Expire headers. But for now, assume 1 hour.
-            "expires": 60 * 60 * 1000,
-            "etag": headers["ETag"][0] if "ETag" in headers else None,
+            "expires": expires,
+            "etag": etag,
}
def _start_expire_url_cache_data(self):
@ -449,7 +624,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
-        expire_before = now - 2 * 24 * 60 * 60 * 1000
+        expire_before = now - 2 * 24 * ONE_HOUR
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []

View file

@ -31,6 +31,7 @@ import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
from synapse.handlers.typing import FollowerTypingHandler
from synapse.replication.tcp.streams import Stream
class HomeServer(object):
@ -150,3 +151,5 @@ class HomeServer(object):
pass
def should_send_federation(self) -> bool:
pass
def get_typing_handler(self) -> FollowerTypingHandler:
pass

View file

@ -16,14 +16,12 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set
+from typing import Awaitable, Dict, Iterable, List, Optional, Set
import attr
from frozendict import frozendict
from prometheus_client import Histogram
-from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
@ -108,8 +107,7 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
-    @defer.inlineCallbacks
-    def get_current_state(
+    async def get_current_state(
         self, room_id, event_type=None, state_key="", latest_event_ids=None
     ):
""" Retrieves the current state for the room. This is done by
@ -126,20 +124,20 @@ class StateHandler(object):
map from (type, state_key) to event
"""
if not latest_event_ids:
-            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
-        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
-                event = yield self.store.get_event(event_id, allow_none=True)
+                event = await self.store.get_event(event_id, allow_none=True)
return event
-        state_map = yield self.store.get_events(
+        state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
state = {
@ -148,8 +146,7 @@ class StateHandler(object):
return state
-    @defer.inlineCallbacks
-    def get_current_state_ids(self, room_id, latest_event_ids=None):
+    async def get_current_state_ids(self, room_id, latest_event_ids=None):
"""Get the current state, or the state at a set of events, for a room
Args:
@ -164,41 +161,38 @@ class StateHandler(object):
(event_type, state_key) -> event_id
"""
if not latest_event_ids:
-            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
return state
-    @defer.inlineCallbacks
-    def get_current_users_in_room(self, room_id, latest_event_ids=None):
+    async def get_current_users_in_room(
+        self, room_id: str, latest_event_ids: Optional[List[str]] = None
+    ) -> Dict[str, ProfileInfo]:
"""
Get the users who are currently in a room.
Args:
-            room_id (str): The ID of the room.
-            latest_event_ids (List[str]|None): Precomputed list of latest
-                event IDs. Will be computed if None.
+            room_id: The ID of the room.
+            latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.

         Returns:
-            Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
-            profileinfo.
+            Dictionary of user IDs to their profileinfo.
"""
if not latest_event_ids:
-            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_users_in_room")
-        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
+        entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        joined_users = await self.store.get_joined_users_from_state(room_id, entry)
return joined_users
-    @defer.inlineCallbacks
-    def get_current_hosts_in_room(self, room_id):
-        event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-        return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
+    async def get_current_hosts_in_room(self, room_id):
+        event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+        return await self.get_hosts_in_room_at_events(room_id, event_ids)
-    @defer.inlineCallbacks
-    def get_hosts_in_room_at_events(self, room_id, event_ids):
+    async def get_hosts_in_room_at_events(self, room_id, event_ids):
"""Get the hosts that were in a room at the given event ids
Args:
@ -208,12 +202,11 @@ class StateHandler(object):
Returns:
Deferred[list[str]]: the hosts in the room at the given events
"""
-        entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
-        joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
+        entry = await self.resolve_state_groups_for_events(room_id, event_ids)
+        joined_hosts = await self.store.get_joined_hosts(room_id, entry)
return joined_hosts
-    @defer.inlineCallbacks
-    def compute_event_context(
+    async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
"""Build an EventContext structure for the event.
@ -278,7 +271,7 @@ class StateHandler(object):
# otherwise, we'll need to resolve the state across the prev_events.
logger.debug("calling resolve_state_groups from compute_event_context")
-            entry = yield self.resolve_state_groups_for_events(
+            entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids()
)
@ -295,7 +288,7 @@ class StateHandler(object):
#
if not state_group_before_event:
-            state_group_before_event = yield self.state_store.store_state_group(
+            state_group_before_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
@ -335,7 +328,7 @@ class StateHandler(object):
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
-        state_group_after_event = yield self.state_store.store_state_group(
+        state_group_after_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
@ -353,8 +346,7 @@ class StateHandler(object):
)
@measure_func()
-    @defer.inlineCallbacks
-    def resolve_state_groups_for_events(self, room_id, event_ids):
+    async def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@ -373,7 +365,7 @@ class StateHandler(object):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
-        state_groups_ids = yield self.state_store.get_state_groups_ids(
+        state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)
@ -382,7 +374,7 @@ class StateHandler(object):
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
-            prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
+            prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
return _StateCacheEntry(
state=state_list,
@ -391,9 +383,9 @@ class StateHandler(object):
delta_ids=delta_ids,
)
-        room_version = yield self.store.get_room_version_id(room_id)
+        room_version = await self.store.get_room_version_id(room_id)

-        result = yield self._state_resolution_handler.resolve_state_groups(
+        result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups_ids,
@ -402,8 +394,7 @@ class StateHandler(object):
)
return result
-    @defer.inlineCallbacks
-    def resolve_events(self, room_version, state_sets, event):
+    async def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@ -414,7 +405,7 @@ class StateHandler(object):
state_map = {ev.event_id: ev for st in state_sets for ev in st}
with Measure(self.clock, "state._resolve_events"):
-            new_state = yield resolve_events_with_store(
+            new_state = await resolve_events_with_store(
self.clock,
event.room_id,
room_version,
@ -451,9 +442,8 @@ class StateResolutionHandler(object):
reset_expiry_on_get=True,
)
-    @defer.inlineCallbacks
     @log_function
-    def resolve_state_groups(
+    async def resolve_state_groups(
         self, room_id, room_version, state_groups_ids, event_map, state_res_store
     ):
"""Resolves conflicts between a set of state groups
@ -479,13 +469,13 @@ class StateResolutionHandler(object):
state_res_store (StateResolutionStore)
Returns:
-            Deferred[_StateCacheEntry]: resolved state
+            _StateCacheEntry: resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
group_names = frozenset(state_groups_ids.keys())
-        with (yield self.resolve_linearizer.queue(group_names)):
+        with (await self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
@ -517,7 +507,7 @@ class StateResolutionHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
-                new_state = yield resolve_events_with_store(
+                new_state = await resolve_events_with_store(
self.clock,
room_id,
room_version,
@ -598,7 +588,7 @@ def resolve_events_with_store(
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
-):
+) -> Awaitable[StateMap[str]]:
"""
Args:
room_id: the room we are working in
@ -619,8 +609,7 @@ def resolve_events_with_store(
state_res_store: a place to fetch events from
Returns:
-        Deferred[dict[(str, str), str]]:
-            a map from (type, state_key) to event_id.
+        a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:
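Every function in this file follows the same mechanical translation from `inlineCallbacks`/`yield` to `async`/`await`. A minimal before/after pair that runs standalone (toy `Store` class; the two styles interoperate because a Deferred is awaitable and `ensureDeferred` adapts a coroutine back into one):

    from twisted.internet import defer

    class Store:
        def get_room_version_id(self, room_id):
            return defer.succeed("5")

    # Before: a generator driven by the inlineCallbacks trampoline.
    @defer.inlineCallbacks
    def get_version_old(store, room_id):
        version = yield store.get_room_version_id(room_id)
        return version

    # After: a native coroutine; Deferreds implement __await__, so the body
    # is a token-for-token translation of the generator above.
    async def get_version_new(store, room_id):
        version = await store.get_room_version_id(room_id)
        return version

    get_version_old(Store(), "!a:x").addCallback(print)                        # 5
    defer.ensureDeferred(get_version_new(Store(), "!a:x")).addCallback(print)  # 5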

View file

@ -15,9 +15,7 @@
import hashlib
import logging
-from typing import Callable, Dict, List, Optional
-
-from twisted.internet import defer
+from typing import Awaitable, Callable, Dict, List, Optional
from synapse import event_auth
from synapse.api.constants import EventTypes
@ -32,12 +30,11 @@ logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
room_id: str,
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable,
+    state_map_factory: Callable[[List[str]], Awaitable],
):
"""
Args:
@ -56,7 +53,7 @@ def resolve_events_with_store(
state_map_factory: will be called
with a list of event_ids that are needed, and should return with
-            a Deferred of dict of event_id to event.
+            an Awaitable that resolves to a dict of event_id to event.
Returns:
Deferred[dict[(str, str), str]]:
@ -80,7 +77,7 @@ def resolve_events_with_store(
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
-    state_map = yield state_map_factory(needed_events)
+    state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@ -110,7 +107,7 @@ def resolve_events_with_store(
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
)
-    state_map_new = yield state_map_factory(new_needed_events)
+    state_map_new = await state_map_factory(new_needed_events)
for event in state_map_new.values():
if event.room_id != room_id:
raise Exception(
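The new `Callable[[List[str]], Awaitable]` annotation pins down the factory contract: callers hand over a function that takes a list of event IDs and returns something awaitable that resolves to an event map. A toy factory satisfying it (illustrative event shape):

    import asyncio
    from typing import Awaitable, Callable, Dict, List

    StateMapFactory = Callable[[List[str]], Awaitable[Dict[str, dict]]]

    async def fetch_events(event_ids: List[str]) -> Dict[str, dict]:
        await asyncio.sleep(0)  # stand-in for a database round-trip
        return {eid: {"event_id": eid} for eid in event_ids}

    async def demo(state_map_factory: StateMapFactory):
        state_map = await state_map_factory(["$event_a", "$event_b"])
        print(sorted(state_map))

    asyncio.run(demo(fetch_events))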

View file

@ -18,8 +18,6 @@ import itertools
import logging
from typing import Dict, List, Optional
-from twisted.internet import defer
import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
@ -32,14 +30,13 @@ from synapse.util import Clock
logger = logging.getLogger(__name__)
-# We want to yield to the reactor occasionally during state res when dealing
+# We want to await to the reactor occasionally during state res when dealing
 # with large data sets, so that we don't exhaust the reactor. This is done by
-# yielding to reactor during loops every N iterations.
-_YIELD_AFTER_ITERATIONS = 100
+# awaiting to reactor during loops every N iterations.
+_AWAIT_AFTER_ITERATIONS = 100
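The renamed constant keeps the same cooperative-scheduling trick: state resolution is CPU-bound Python, so it periodically awaits a zero-second sleep to let the reactor service other work. A standalone sketch (asyncio in place of Synapse's `Clock`):

    import asyncio

    _AWAIT_AFTER_ITERATIONS = 100

    async def sort_events(event_ids):
        ordered = []
        for idx, event_id in enumerate(event_ids, start=1):
            ordered.append(event_id)  # stand-in for the real per-event work
            # Hand control back to the event loop every N iterations so a
            # large resolution doesn't starve other connections.
            if idx % _AWAIT_AFTER_ITERATIONS == 0:
                await asyncio.sleep(0)
        return ordered

    print(len(asyncio.run(sort_events(range(1000)))))  # 1000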
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
@ -87,7 +84,7 @@ def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
-    auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
+    auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
full_conflicted_set = set(
itertools.chain(
@ -95,7 +92,7 @@ def resolve_events_with_store(
)
)
-    events = yield state_res_store.get_events(
+    events = await state_res_store.get_events(
[eid for eid in full_conflicted_set if eid not in event_map],
allow_rejected=True,
)
@ -118,14 +115,14 @@ def resolve_events_with_store(
eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
)
-    sorted_power_events = yield _reverse_topological_power_sort(
+    sorted_power_events = await _reverse_topological_power_sort(
clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
-    resolved_state = yield _iterative_auth_checks(
+    resolved_state = await _iterative_auth_checks(
clock,
room_id,
room_version,
@ -148,13 +145,13 @@ def resolve_events_with_store(
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
-    leftover_events = yield _mainline_sort(
+    leftover_events = await _mainline_sort(
clock, room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
-    resolved_state = yield _iterative_auth_checks(
+    resolved_state = await _iterative_auth_checks(
clock,
room_id,
room_version,
@ -174,8 +171,7 @@ def resolve_events_with_store(
return resolved_state
-@defer.inlineCallbacks
-def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
+async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
"""Return the power level of the sender of the given event according to
their auth events.
@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
Returns:
Deferred[int]
"""
-    event = yield _get_event(room_id, event_id, event_map, state_res_store)
+    event = await _get_event(room_id, event_id, event_map, state_res_store)
pl = None
for aid in event.auth_event_ids():
-        aev = yield _get_event(
+        aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
if pl is None:
# Couldn't find power level. Check if they're the creator of the room
for aid in event.auth_event_ids():
-            aev = yield _get_event(
+            aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
return int(level)
@defer.inlineCallbacks
def _get_auth_chain_difference(state_sets, event_map, state_res_store):
async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Deferred[set[str]]: Set of event IDs
"""
-    difference = yield state_res_store.get_auth_chain_difference(
+    difference = await state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
)
@ -292,8 +287,7 @@ def _is_power_event(event):
return False
-@defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(
+async def _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
):
"""Helper function for _reverse_topological_power_sort that add the event
@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph(
eid = state.pop()
graph.setdefault(eid, set())
-        event = yield _get_event(room_id, eid, event_map, state_res_store)
+        event = await _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
if aid in auth_diff:
if aid not in graph:
@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph(
graph.setdefault(eid, set()).add(aid)
-@defer.inlineCallbacks
-def _reverse_topological_power_sort(
+async def _reverse_topological_power_sort(
clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
"""Returns a list of the event_ids sorted by reverse topological ordering,
@ -344,26 +337,26 @@ def _reverse_topological_power_sort(
graph = {}
for idx, event_id in enumerate(event_ids, start=1):
-        yield _add_event_and_auth_chain_to_graph(
+        await _add_event_and_auth_chain_to_graph(
             graph, room_id, event_id, event_map, state_res_store, auth_diff
         )

-        # We yield occasionally when we're working with large data sets to
+        # We await occasionally when we're working with large data sets to
         # ensure that we don't block the reactor loop for too long.
-        if idx % _YIELD_AFTER_ITERATIONS == 0:
-            yield clock.sleep(0)
+        if idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
event_to_pl = {}
for idx, event_id in enumerate(graph, start=1):
-        pl = yield _get_power_level_for_sender(
+        pl = await _get_power_level_for_sender(
             room_id, event_id, event_map, state_res_store
         )
         event_to_pl[event_id] = pl

-        # We yield occasionally when we're working with large data sets to
+        # We await occasionally when we're working with large data sets to
         # ensure that we don't block the reactor loop for too long.
-        if idx % _YIELD_AFTER_ITERATIONS == 0:
-            yield clock.sleep(0)
+        if idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
def _get_power_order(event_id):
ev = event_map[event_id]
@ -378,8 +371,7 @@ def _reverse_topological_power_sort(
return sorted_events
-@defer.inlineCallbacks
-def _iterative_auth_checks(
+async def _iterative_auth_checks(
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
):
"""Sequentially apply auth checks to each event in given list, updating the
@ -405,7 +397,7 @@ def _iterative_auth_checks(
auth_events = {}
for aid in event.auth_event_ids():
-            ev = yield _get_event(
+            ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
@ -420,7 +412,7 @@ def _iterative_auth_checks(
for key in event_auth.auth_types_for_event(event):
if key in resolved_state:
ev_id = resolved_state[key]
-                ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
+                ev = await _get_event(room_id, ev_id, event_map, state_res_store)
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
@ -438,16 +430,15 @@ def _iterative_auth_checks(
except AuthError:
pass
-            # We yield occasionally when we're working with large data sets to
+            # We await occasionally when we're working with large data sets to
             # ensure that we don't block the reactor loop for too long.
-            if idx % _YIELD_AFTER_ITERATIONS == 0:
-                yield clock.sleep(0)
+            if idx % _AWAIT_AFTER_ITERATIONS == 0:
+                await clock.sleep(0)
return resolved_state
-@defer.inlineCallbacks
-def _mainline_sort(
+async def _mainline_sort(
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
@ -474,21 +465,21 @@ def _mainline_sort(
idx = 0
while pl:
mainline.append(pl)
-        pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
+        pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
auth_events = pl_ev.auth_event_ids()
pl = None
for aid in auth_events:
-            ev = yield _get_event(
+            ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid
break
-        # We yield occasionally when we're working with large data sets to
+        # We await occasionally when we're working with large data sets to
         # ensure that we don't block the reactor loop for too long.
-        if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
-            yield clock.sleep(0)
+        if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
idx += 1
@ -498,23 +489,24 @@ def _mainline_sort(
order_map = {}
for idx, ev_id in enumerate(event_ids, start=1):
-        depth = yield _get_mainline_depth_for_event(
+        depth = await _get_mainline_depth_for_event(
event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
-        # We yield occasionally when we're working with large data sets to
+        # We await occasionally when we're working with large data sets to
         # ensure that we don't block the reactor loop for too long.
-        if idx % _YIELD_AFTER_ITERATIONS == 0:
-            yield clock.sleep(0)
+        if idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
event_ids.sort(key=lambda ev_id: order_map[ev_id])
return event_ids
-@defer.inlineCallbacks
-def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
+async def _get_mainline_depth_for_event(
+    event, mainline_map, event_map, state_res_store
+):
"""Get the mainline depths for the given event based on the mainline map
Args:
@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
event = None
for aid in auth_events:
-        aev = yield _get_event(
+        aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
return 0
-@defer.inlineCallbacks
-def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
+async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
"""Helper function to look up event in event_map, falling back to looking
it up in the store
@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
Deferred[Optional[FrozenEvent]]
"""
if event_id not in event_map:
-        events = yield state_res_store.get_events([event_id], allow_rejected=True)
+        events = await state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
event = event_map.get(event_id)

View file

@ -259,7 +259,7 @@ class PushRulesWorkerStore(
# To do this we set the state_group to a new object as object() != object()
state_group = object()
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)

View file

@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)

View file

@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
room_id
)
-        users_with_profile = yield state.get_current_users_in_room(room_id)
+        users_with_profile = yield defer.ensureDeferred(
+            state.get_current_users_in_room(room_id)
+        )
user_ids = set(users_with_profile)
# Update each user in the user directory.
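These storage hunks show the bridge in the other direction: the callers are still `inlineCallbacks` generators, which only know how to wait on Deferreds, so the now-native coroutines are adapted with `defer.ensureDeferred` before being yielded. A runnable sketch of the idiom:

    from twisted.internet import defer

    async def get_current_users_in_room(room_id):
        # now a native coroutine, as in the state-handler changes above
        return {"@alice:example.com": None}

    @defer.inlineCallbacks
    def update_user_directory(room_id):
        # yielding a raw coroutine here would fail; ensureDeferred adapts it
        users_with_profile = yield defer.ensureDeferred(
            get_current_users_in_room(room_id)
        )
        print(set(users_with_profile))

    update_user_directory("!room:example.com")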

Some files were not shown because too many files have changed in this diff.