Merge branch 'release-v1.18.0' into matrix-org-hotfixes
This commit is contained in:
commit
b2ccc72a00
93
CHANGES.md
93
CHANGES.md
|
@ -1,3 +1,96 @@
|
|||
Synapse 1.18.0rc1 (2020-07-27)
|
||||
==============================
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
- Include room states on invite events that are sent to application services. Contributed by @Sorunome. ([\#6455](https://github.com/matrix-org/synapse/issues/6455))
|
||||
- Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel. ([\#7613](https://github.com/matrix-org/synapse/issues/7613), [\#7953](https://github.com/matrix-org/synapse/issues/7953))
|
||||
- Add experimental support for running multiple federation sender processes. ([\#7798](https://github.com/matrix-org/synapse/issues/7798))
|
||||
- Add the option to validate the `iss` and `aud` claims for JWT logins. ([\#7827](https://github.com/matrix-org/synapse/issues/7827))
|
||||
- Add support for handling registration requests across multiple client reader workers. ([\#7830](https://github.com/matrix-org/synapse/issues/7830))
|
||||
- Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7842](https://github.com/matrix-org/synapse/issues/7842))
|
||||
- Allow email subjects to be customised through Synapse's configuration. ([\#7846](https://github.com/matrix-org/synapse/issues/7846))
|
||||
- Add the ability to re-activate an account from the admin API. ([\#7847](https://github.com/matrix-org/synapse/issues/7847), [\#7908](https://github.com/matrix-org/synapse/issues/7908))
|
||||
- Add experimental support for running multiple pusher workers. ([\#7855](https://github.com/matrix-org/synapse/issues/7855))
|
||||
- Add experimental support for moving typing off master. ([\#7869](https://github.com/matrix-org/synapse/issues/7869), [\#7959](https://github.com/matrix-org/synapse/issues/7959))
|
||||
- Report CPU metrics to prometheus for time spent processing replication commands. ([\#7879](https://github.com/matrix-org/synapse/issues/7879))
|
||||
- Support oEmbed for media previews. ([\#7920](https://github.com/matrix-org/synapse/issues/7920))
|
||||
- Abort federation requests where the client disconnects before the ratelimiter expires. ([\#7930](https://github.com/matrix-org/synapse/issues/7930))
|
||||
- Cache responses to `/_matrix/federation/v1/state_ids` to reduce duplicated work. ([\#7931](https://github.com/matrix-org/synapse/issues/7931))
|
||||
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Fix detection of out of sync remote device lists when receiving events from remote users. ([\#7815](https://github.com/matrix-org/synapse/issues/7815))
|
||||
- Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain. ([\#7817](https://github.com/matrix-org/synapse/issues/7817))
|
||||
- Fix a bug causing Synapse to misinterpret the value `off` for `encryption_enabled_by_default_for_room_type` in its configuration file(s) if that value isn't surrounded by quotes. This bug was introduced in v1.16.0. ([\#7822](https://github.com/matrix-org/synapse/issues/7822))
|
||||
- Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails. ([\#7829](https://github.com/matrix-org/synapse/issues/7829))
|
||||
- Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`. ([\#7844](https://github.com/matrix-org/synapse/issues/7844))
|
||||
- Fix "AttributeError: 'str' object has no attribute 'get'" error message when applying per-room message retention policies. The bug was introduced in Synapse 1.7.0. ([\#7850](https://github.com/matrix-org/synapse/issues/7850))
|
||||
- Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation. ([\#7854](https://github.com/matrix-org/synapse/issues/7854))
|
||||
- Fix a bug which allowed empty rooms to be rejoined over federation. ([\#7859](https://github.com/matrix-org/synapse/issues/7859))
|
||||
- Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers. ([\#7866](https://github.com/matrix-org/synapse/issues/7866))
|
||||
- Fix a long standing bug where the tracing of async functions with opentracing was broken. ([\#7872](https://github.com/matrix-org/synapse/issues/7872), [\#7961](https://github.com/matrix-org/synapse/issues/7961))
|
||||
- Fix "TypeError in `synapse.notifier`" exceptions. ([\#7880](https://github.com/matrix-org/synapse/issues/7880))
|
||||
- Fix deprecation warning due to invalid escape sequences. ([\#7895](https://github.com/matrix-org/synapse/issues/7895))
|
||||
|
||||
|
||||
Updates to the Docker image
|
||||
---------------------------
|
||||
|
||||
- Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196. ([\#7839](https://github.com/matrix-org/synapse/issues/7839))
|
||||
|
||||
|
||||
Improved Documentation
|
||||
----------------------
|
||||
|
||||
- Provide instructions on using `register_new_matrix_user` via docker. ([\#7885](https://github.com/matrix-org/synapse/issues/7885))
|
||||
- Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation. ([\#7889](https://github.com/matrix-org/synapse/issues/7889))
|
||||
- Reorder database paragraphs to promote postgres over sqlite. ([\#7933](https://github.com/matrix-org/synapse/issues/7933))
|
||||
- Update the dates of ACME v1's end of life in [`ACME.md`](https://github.com/matrix-org/synapse/blob/master/docs/ACME.md). ([\#7934](https://github.com/matrix-org/synapse/issues/7934))
|
||||
|
||||
|
||||
Deprecations and Removals
|
||||
-------------------------
|
||||
|
||||
- Remove unused `synapse_replication_tcp_resource_invalidate_cache` prometheus metric. ([\#7878](https://github.com/matrix-org/synapse/issues/7878))
|
||||
- Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim. ([\#7888](https://github.com/matrix-org/synapse/issues/7888))
|
||||
|
||||
|
||||
Internal Changes
|
||||
----------------
|
||||
|
||||
- Switch parts of the codebase from `simplejson` to the standard library `json`. ([\#7802](https://github.com/matrix-org/synapse/issues/7802))
|
||||
- Add type hints to the http server code and remove an unused parameter. ([\#7813](https://github.com/matrix-org/synapse/issues/7813))
|
||||
- Add type hints to synapse.api.errors module. ([\#7820](https://github.com/matrix-org/synapse/issues/7820))
|
||||
- Ensure that calls to `json.dumps` are compatible with the standard library json. ([\#7836](https://github.com/matrix-org/synapse/issues/7836))
|
||||
- Remove redundant `retry_on_integrity_error` wrapper for event persistence code. ([\#7848](https://github.com/matrix-org/synapse/issues/7848))
|
||||
- Consistently use `db_to_json` to convert from database values to JSON objects. ([\#7849](https://github.com/matrix-org/synapse/issues/7849))
|
||||
- Convert various parts of the codebase to async/await. ([\#7851](https://github.com/matrix-org/synapse/issues/7851), [\#7860](https://github.com/matrix-org/synapse/issues/7860), [\#7868](https://github.com/matrix-org/synapse/issues/7868), [\#7871](https://github.com/matrix-org/synapse/issues/7871), [\#7873](https://github.com/matrix-org/synapse/issues/7873), [\#7874](https://github.com/matrix-org/synapse/issues/7874), [\#7884](https://github.com/matrix-org/synapse/issues/7884), [\#7912](https://github.com/matrix-org/synapse/issues/7912), [\#7935](https://github.com/matrix-org/synapse/issues/7935), [\#7939](https://github.com/matrix-org/synapse/issues/7939), [\#7942](https://github.com/matrix-org/synapse/issues/7942), [\#7944](https://github.com/matrix-org/synapse/issues/7944))
|
||||
- Add support for handling registration requests across multiple client reader workers. ([\#7853](https://github.com/matrix-org/synapse/issues/7853))
|
||||
- Small performance improvement in typing processing. ([\#7856](https://github.com/matrix-org/synapse/issues/7856))
|
||||
- The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100. ([\#7858](https://github.com/matrix-org/synapse/issues/7858))
|
||||
- Optimise queueing of inbound replication commands. ([\#7861](https://github.com/matrix-org/synapse/issues/7861))
|
||||
- Add some type annotations to `HomeServer` and `BaseHandler`. ([\#7870](https://github.com/matrix-org/synapse/issues/7870))
|
||||
- Clean up `PreserveLoggingContext`. ([\#7877](https://github.com/matrix-org/synapse/issues/7877))
|
||||
- Change "unknown room version" logging from 'error' to 'warning'. ([\#7881](https://github.com/matrix-org/synapse/issues/7881))
|
||||
- Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`. ([\#7882](https://github.com/matrix-org/synapse/issues/7882))
|
||||
- Return an empty body for OPTIONS requests. ([\#7886](https://github.com/matrix-org/synapse/issues/7886))
|
||||
- Fix typo in generated config file. Contributed by @ThiefMaster. ([\#7890](https://github.com/matrix-org/synapse/issues/7890))
|
||||
- Import ABC from `collections.abc` for Python 3.10 compatibility. ([\#7892](https://github.com/matrix-org/synapse/issues/7892))
|
||||
- Remove unused functions `time_function`, `trace_function`, `get_previous_frames`
|
||||
and `get_previous_frame` from `synapse.logging.utils` module. ([\#7897](https://github.com/matrix-org/synapse/issues/7897))
|
||||
- Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI. ([\#7914](https://github.com/matrix-org/synapse/issues/7914))
|
||||
- Use Element CSS and logo in notification emails when app name is Element. ([\#7919](https://github.com/matrix-org/synapse/issues/7919))
|
||||
- Optimisation to /sync handling: skip serializing the response if the client has already disconnected. ([\#7927](https://github.com/matrix-org/synapse/issues/7927))
|
||||
- When a client disconnects, don't log it as 'Error processing request'. ([\#7928](https://github.com/matrix-org/synapse/issues/7928))
|
||||
- Add debugging to `/sync` response generation (disabled by default). ([\#7929](https://github.com/matrix-org/synapse/issues/7929))
|
||||
- Update comments that refer to Deferreds for async functions. ([\#7945](https://github.com/matrix-org/synapse/issues/7945))
|
||||
- Simplify error handling in federation handler. ([\#7950](https://github.com/matrix-org/synapse/issues/7950))
|
||||
|
||||
|
||||
Synapse 1.17.0 (2020-07-13)
|
||||
===========================
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
Include room states on invite events that are sent to application services. Contributed by @Sorunome.
|
|
@ -1 +0,0 @@
|
|||
Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel.
|
|
@ -1 +0,0 @@
|
|||
Add experimental support for running multiple federation sender processes.
|
|
@ -1 +0,0 @@
|
|||
Switch from simplejson to the standard library json.
|
|
@ -1 +0,0 @@
|
|||
Add type hints to the http server code and remove an unused parameter.
|
|
@ -1 +0,0 @@
|
|||
Fix detection of out of sync remote device lists when receiving events from remote users.
|
|
@ -1 +0,0 @@
|
|||
Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain.
|
|
@ -1 +0,0 @@
|
|||
Add type hints to synapse.api.errors module.
|
|
@ -1 +0,0 @@
|
|||
Fix a bug causing Synapse to misinterpret the value `off` for `encryption_enabled_by_default_for_room_type` in its configuration file(s) if that value isn't surrounded by quotes. This bug was introduced in v1.16.0.
|
|
@ -1 +0,0 @@
|
|||
Add the option to validate the `iss` and `aud` claims for JWT logins.
|
|
@ -1 +0,0 @@
|
|||
Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails.
|
|
@ -1 +0,0 @@
|
|||
Add support for handling registration requests across multiple client reader workers.
|
|
@ -1 +0,0 @@
|
|||
Ensure that calls to `json.dumps` are compatible with the standard library json.
|
|
@ -1 +0,0 @@
|
|||
Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196.
|
|
@ -1 +0,0 @@
|
|||
Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH.
|
|
@ -1 +0,0 @@
|
|||
Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`.
|
|
@ -1 +0,0 @@
|
|||
Allow email subjects to be customised through Synapse's configuration.
|
|
@ -1 +0,0 @@
|
|||
Add the ability to re-activate an account from the admin API.
|
|
@ -1 +0,0 @@
|
|||
Remove redundant `retry_on_integrity_error` wrapper for event persistence code.
|
|
@ -1 +0,0 @@
|
|||
Consistently use `db_to_json` to convert from database values to JSON objects.
|
|
@ -1 +0,0 @@
|
|||
Fix "AttributeError: 'str' object has no attribute 'get'" error message when applying per-room message retention policies. The bug was introduced in Synapse 1.7.0.
|
|
@ -1 +0,0 @@
|
|||
Convert E2E keys and room keys handlers to async/await.
|
|
@ -1 +0,0 @@
|
|||
Add support for handling registration requests across multiple client reader workers.
|
|
@ -1 +0,0 @@
|
|||
Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation.
|
|
@ -1 +0,0 @@
|
|||
Add experimental support for running multiple pusher workers.
|
|
@ -1 +0,0 @@
|
|||
Small performance improvement in typing processing.
|
|
@ -1 +0,0 @@
|
|||
The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100.
|
|
@ -1 +0,0 @@
|
|||
Fix a bug which allowed empty rooms to be rejoined over federation.
|
|
@ -1 +0,0 @@
|
|||
Convert _base, profile, and _receipts handlers to async/await.
|
|
@ -1 +0,0 @@
|
|||
Optimise queueing of inbound replication commands.
|
|
@ -1 +0,0 @@
|
|||
Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers.
|
|
@ -1 +0,0 @@
|
|||
Convert synapse.app and federation client to async/await.
|
|
@ -1 +0,0 @@
|
|||
Add experimental support for moving typing off master.
|
|
@ -1 +0,0 @@
|
|||
Add some type annotations to `HomeServer` and `BaseHandler`.
|
|
@ -1 +0,0 @@
|
|||
Convert device handler to async/await.
|
|
@ -1 +0,0 @@
|
|||
Fix a long standing bug where the tracing of async functions with opentracing was broken.
|
|
@ -1 +0,0 @@
|
|||
Convert the federation agent and related code to async/await.
|
1
changelog.d/7876.bugfix
Normal file
1
changelog.d/7876.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix an `AssertionError` exception introduced in v1.18.0rc1.
|
1
changelog.d/7876.misc
Normal file
1
changelog.d/7876.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Further optimise queueing of inbound replication commands.
|
|
@ -1 +0,0 @@
|
|||
Clean up `PreserveLoggingContext`.
|
|
@ -1 +0,0 @@
|
|||
Remove unused `synapse_replication_tcp_resource_invalidate_cache` prometheus metric.
|
|
@ -1 +0,0 @@
|
|||
Report CPU metrics to prometheus for time spent processing replication commands.
|
|
@ -1 +0,0 @@
|
|||
Fix "TypeError in `synapse.notifier`" exceptions.
|
|
@ -1 +0,0 @@
|
|||
Change "unknown room version" logging from 'error' to 'warning'.
|
|
@ -1 +0,0 @@
|
|||
Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`.
|
|
@ -1 +0,0 @@
|
|||
Convert the message handler to async/await.
|
|
@ -1 +0,0 @@
|
|||
Provide instructions on using `register_new_matrix_user` via docker.
|
|
@ -1 +0,0 @@
|
|||
Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim.
|
|
@ -1 +0,0 @@
|
|||
Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation.
|
|
@ -1 +0,0 @@
|
|||
Fix typo in generated config file. Contributed by @ThiefMaster.
|
|
@ -1 +0,0 @@
|
|||
Import ABC from `collections.abc` for Python 3.10 compatibility.
|
|
@ -1 +0,0 @@
|
|||
Fix deprecation warning due to invalid escape sequences.
|
|
@ -1,2 +0,0 @@
|
|||
Remove unused functions `time_function`, `trace_function`, `get_previous_frames`
|
||||
and `get_previous_frame` from `synapse.logging.utils` module.
|
|
@ -1 +0,0 @@
|
|||
Add the ability to re-activate an account from the admin API.
|
|
@ -1 +0,0 @@
|
|||
Convert `RoomListHandler` to async/await.
|
|
@ -1 +0,0 @@
|
|||
Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI.
|
|
@ -1 +0,0 @@
|
|||
Use Element CSS and logo in notification emails when app name is Element.
|
|
@ -1 +0,0 @@
|
|||
Optimisation to /sync handling: skip serializing the response if the client has already disconnected.
|
|
@ -1 +0,0 @@
|
|||
When a client disconnects, don't log it as 'Error processing request'.
|
|
@ -1 +0,0 @@
|
|||
Add debugging to `/sync` response generation (disabled by default).
|
|
@ -1 +0,0 @@
|
|||
Abort federation requests where the client disconnects before the ratelimiter expires.
|
|
@ -1 +0,0 @@
|
|||
Cache responses to `/_matrix/federation/v1/state_ids` to reduce duplicated work.
|
|
@ -1 +0,0 @@
|
|||
Reorder database paragraphs to promote postgres over sqlite.
|
|
@ -1 +0,0 @@
|
|||
Update the dates of ACME v1's end of life in [`ACME.md`](https://github.com/matrix-org/synapse/blob/master/docs/ACME.md).
|
|
@ -1 +0,0 @@
|
|||
Convert the auth providers to be async/await.
|
|
@ -1 +0,0 @@
|
|||
Convert presence handler helpers to async/await.
|
|
@ -36,7 +36,7 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "1.17.0"
|
||||
__version__ = "1.18.0rc1"
|
||||
|
||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||
# We import here so that we don't have to install a bunch of deps when
|
||||
|
|
|
@ -127,8 +127,10 @@ class Auth(object):
|
|||
if current_state:
|
||||
member = current_state.get((EventTypes.Member, user_id), None)
|
||||
else:
|
||||
member = yield self.state.get_current_state(
|
||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||
member = yield defer.ensureDeferred(
|
||||
self.state.get_current_state(
|
||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||
)
|
||||
)
|
||||
membership = member.membership if member else None
|
||||
|
||||
|
@ -665,8 +667,10 @@ class Auth(object):
|
|||
)
|
||||
return member_event.membership, member_event.event_id
|
||||
except AuthError:
|
||||
visibility = yield self.state.get_current_state(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
visibility = yield defer.ensureDeferred(
|
||||
self.state.get_current_state(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
)
|
||||
if (
|
||||
visibility
|
||||
|
|
|
@ -87,7 +87,6 @@ from synapse.replication.tcp.streams import (
|
|||
ReceiptsStream,
|
||||
TagAccountDataStream,
|
||||
ToDeviceStream,
|
||||
TypingStream,
|
||||
)
|
||||
from synapse.rest.admin import register_servlets_for_media_repo
|
||||
from synapse.rest.client.v1 import events
|
||||
|
@ -644,7 +643,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
|
|||
super(GenericWorkerReplicationHandler, self).__init__(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.typing_handler = hs.get_typing_handler()
|
||||
self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
|
@ -681,11 +679,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
|
|||
await self.pusher_pool.on_new_receipts(
|
||||
token, token, {row.room_id for row in rows}
|
||||
)
|
||||
elif stream_name == TypingStream.NAME:
|
||||
self.typing_handler.process_replication_rows(token, rows)
|
||||
self.notifier.on_new_event(
|
||||
"typing_key", token, rooms=[row.room_id for row in rows]
|
||||
)
|
||||
elif stream_name == ToDeviceStream.NAME:
|
||||
entities = [row.entity for row in rows if row.entity.startswith("@")]
|
||||
if entities:
|
||||
|
|
|
@ -106,8 +106,8 @@ class EventBuilder(object):
|
|||
Deferred[FrozenEvent]
|
||||
"""
|
||||
|
||||
state_ids = yield self._state.get_current_state_ids(
|
||||
self.room_id, prev_event_ids
|
||||
state_ids = yield defer.ensureDeferred(
|
||||
self._state.get_current_state_ids(self.room_id, prev_event_ids)
|
||||
)
|
||||
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
|
||||
|
||||
|
|
|
@ -348,7 +348,9 @@ class FederationSender(object):
|
|||
room_id = receipt.room_id
|
||||
|
||||
# Work out which remote servers should be poked and poke them.
|
||||
domains = yield self.state.get_current_hosts_in_room(room_id)
|
||||
domains = yield defer.ensureDeferred(
|
||||
self.state.get_current_hosts_in_room(room_id)
|
||||
)
|
||||
domains = [
|
||||
d
|
||||
for d in domains
|
||||
|
|
|
@ -72,7 +72,7 @@ class AdminHandler(BaseHandler):
|
|||
writer (ExfiltrationWriter)
|
||||
|
||||
Returns:
|
||||
defer.Deferred: Resolves when all data for a user has been written.
|
||||
Resolves when all data for a user has been written.
|
||||
The returned value is that returned by `writer.finished()`.
|
||||
"""
|
||||
# Get all rooms the user is in or has been in
|
||||
|
|
|
@ -16,10 +16,11 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.key import VerifyKey, decode_verify_key_bytes
|
||||
from signedjson.sign import SignatureVerifyException, verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
|
@ -265,7 +266,9 @@ class E2eKeysHandler(object):
|
|||
|
||||
return ret
|
||||
|
||||
async def get_cross_signing_keys_from_cache(self, query, from_user_id):
|
||||
async def get_cross_signing_keys_from_cache(
|
||||
self, query, from_user_id
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Get cross-signing keys for users from the database
|
||||
|
||||
Args:
|
||||
|
@ -277,8 +280,7 @@ class E2eKeysHandler(object):
|
|||
can see.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[dict[str, dict[str, dict]]]: map from
|
||||
(master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
|
||||
A map from (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
|
||||
"""
|
||||
master_keys = {}
|
||||
self_signing_keys = {}
|
||||
|
@ -312,16 +314,17 @@ class E2eKeysHandler(object):
|
|||
}
|
||||
|
||||
@trace
|
||||
async def query_local_devices(self, query):
|
||||
async def query_local_devices(
|
||||
self, query: Dict[str, Optional[List[str]]]
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Get E2E device keys for local users
|
||||
|
||||
Args:
|
||||
query (dict[string, list[string]|None): map from user_id to a list
|
||||
query: map from user_id to a list
|
||||
of devices to query (None for all devices)
|
||||
|
||||
Returns:
|
||||
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
|
||||
map from user_id -> device_id -> device details
|
||||
A map from user_id -> device_id -> device details
|
||||
"""
|
||||
set_tag("local_query", query)
|
||||
local_query = []
|
||||
|
@ -1004,7 +1007,7 @@ class E2eKeysHandler(object):
|
|||
|
||||
async def _retrieve_cross_signing_keys_for_remote_user(
|
||||
self, user: UserID, desired_key_type: str,
|
||||
):
|
||||
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
|
||||
"""Queries cross-signing keys for a remote user and saves them to the database
|
||||
|
||||
Only the key specified by `key_type` will be returned, while all retrieved keys
|
||||
|
@ -1015,8 +1018,7 @@ class E2eKeysHandler(object):
|
|||
desired_key_type: The type of key to receive. One of "master", "self_signing"
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple
|
||||
of the retrieved key content, the key's ID and the matching VerifyKey.
|
||||
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
|
||||
If the key cannot be retrieved, all values in the tuple will instead be None.
|
||||
"""
|
||||
try:
|
||||
|
|
|
@ -1394,7 +1394,7 @@ class FederationHandler(BaseHandler):
|
|||
# it's just a best-effort thing at this point. We do want to do
|
||||
# them roughly in order, though, otherwise we'll end up making
|
||||
# lots of requests for missing prev_events which we do actually
|
||||
# have. Hence we fire off the deferred, but don't wait for it.
|
||||
# have. Hence we fire off the background task, but don't wait for it.
|
||||
|
||||
run_in_background(self._handle_queued_pdus, room_queue)
|
||||
|
||||
|
@ -1887,9 +1887,6 @@ class FederationHandler(BaseHandler):
|
|||
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
|
||||
)
|
||||
|
||||
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
|
||||
# hack around with a try/finally instead.
|
||||
success = False
|
||||
try:
|
||||
if (
|
||||
not event.internal_metadata.is_outlier()
|
||||
|
@ -1903,12 +1900,11 @@ class FederationHandler(BaseHandler):
|
|||
await self.persist_events_and_notify(
|
||||
[(event, context)], backfilled=backfilled
|
||||
)
|
||||
success = True
|
||||
finally:
|
||||
if not success:
|
||||
run_in_background(
|
||||
self.store.remove_push_actions_from_staging, event.event_id
|
||||
)
|
||||
except Exception:
|
||||
run_in_background(
|
||||
self.store.remove_push_actions_from_staging, event.event_id
|
||||
)
|
||||
raise
|
||||
|
||||
return context
|
||||
|
||||
|
@ -2994,7 +2990,9 @@ class FederationHandler(BaseHandler):
|
|||
else:
|
||||
user_joined_room(self.distributor, user, room_id)
|
||||
|
||||
async def get_room_complexity(self, remote_room_hosts, room_id):
|
||||
async def get_room_complexity(
|
||||
self, remote_room_hosts: List[str], room_id: str
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Fetch the complexity of a remote room over federation.
|
||||
|
||||
|
@ -3003,7 +3001,7 @@ class FederationHandler(BaseHandler):
|
|||
room_id (str): The room ID to ask about.
|
||||
|
||||
Returns:
|
||||
Deferred[dict] or Deferred[None]: Dict contains the complexity
|
||||
Dict contains the complexity
|
||||
metric versions, while None means we could not fetch the complexity.
|
||||
"""
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
|
@ -36,6 +37,7 @@ from synapse.api.errors import (
|
|||
)
|
||||
from synapse.config.emailconfig import ThreepidBehaviour
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.types import JsonDict, Requester
|
||||
from synapse.util.hash import sha256_and_url_safe_base64
|
||||
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
||||
|
||||
|
@ -59,23 +61,23 @@ class IdentityHandler(BaseHandler):
|
|||
self.federation_http_client = hs.get_http_client()
|
||||
self.hs = hs
|
||||
|
||||
async def threepid_from_creds(self, id_server, creds):
|
||||
async def threepid_from_creds(
|
||||
self, id_server: str, creds: Dict[str, str]
|
||||
) -> Optional[JsonDict]:
|
||||
"""
|
||||
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
|
||||
given identity server
|
||||
|
||||
Args:
|
||||
id_server (str): The identity server to validate 3PIDs against. Must be a
|
||||
id_server: The identity server to validate 3PIDs against. Must be a
|
||||
complete URL including the protocol (http(s)://)
|
||||
|
||||
creds (dict[str, str]): Dictionary containing the following keys:
|
||||
creds: Dictionary containing the following keys:
|
||||
* client_secret|clientSecret: A unique secret str provided by the client
|
||||
* sid: The ID of the validation session
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str,str|int]|None]: A dictionary consisting of response params to
|
||||
the /getValidated3pid endpoint of the Identity Service API, or None if the
|
||||
threepid was not found
|
||||
A dictionary consisting of response params to the /getValidated3pid
|
||||
endpoint of the Identity Service API, or None if the threepid was not found
|
||||
"""
|
||||
client_secret = creds.get("client_secret") or creds.get("clientSecret")
|
||||
if not client_secret:
|
||||
|
@ -119,26 +121,27 @@ class IdentityHandler(BaseHandler):
|
|||
return None
|
||||
|
||||
async def bind_threepid(
|
||||
self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
|
||||
):
|
||||
self,
|
||||
client_secret: str,
|
||||
sid: str,
|
||||
mxid: str,
|
||||
id_server: str,
|
||||
id_access_token: Optional[str] = None,
|
||||
use_v2: bool = True,
|
||||
) -> JsonDict:
|
||||
"""Bind a 3PID to an identity server
|
||||
|
||||
Args:
|
||||
client_secret (str): A unique secret provided by the client
|
||||
|
||||
sid (str): The ID of the validation session
|
||||
|
||||
mxid (str): The MXID to bind the 3PID to
|
||||
|
||||
id_server (str): The domain of the identity server to query
|
||||
|
||||
id_access_token (str): The access token to authenticate to the identity
|
||||
client_secret: A unique secret provided by the client
|
||||
sid: The ID of the validation session
|
||||
mxid: The MXID to bind the 3PID to
|
||||
id_server: The domain of the identity server to query
|
||||
id_access_token: The access token to authenticate to the identity
|
||||
server with, if necessary. Required if use_v2 is true
|
||||
|
||||
use_v2 (bool): Whether to use v2 Identity Service API endpoints. Defaults to True
|
||||
use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: The response from the identity server
|
||||
The response from the identity server
|
||||
"""
|
||||
logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
|
||||
|
||||
|
@ -151,7 +154,7 @@ class IdentityHandler(BaseHandler):
|
|||
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
|
||||
if use_v2:
|
||||
bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
|
||||
headers["Authorization"] = create_id_access_token_header(id_access_token)
|
||||
headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
|
||||
else:
|
||||
bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
|
||||
|
||||
|
@ -187,20 +190,20 @@ class IdentityHandler(BaseHandler):
|
|||
)
|
||||
return res
|
||||
|
||||
async def try_unbind_threepid(self, mxid, threepid):
|
||||
async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
|
||||
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
|
||||
identity servers we're aware the binding is present on
|
||||
|
||||
Args:
|
||||
mxid (str): Matrix user ID of binding to be removed
|
||||
threepid (dict): Dict with medium & address of binding to be
|
||||
mxid: Matrix user ID of binding to be removed
|
||||
threepid: Dict with medium & address of binding to be
|
||||
removed, and an optional id_server.
|
||||
|
||||
Raises:
|
||||
SynapseError: If we failed to contact the identity server
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True on success, otherwise False if the identity
|
||||
True on success, otherwise False if the identity
|
||||
server doesn't support unbinding (or no identity server found to
|
||||
contact).
|
||||
"""
|
||||
|
@ -223,19 +226,21 @@ class IdentityHandler(BaseHandler):
|
|||
|
||||
return changed
|
||||
|
||||
async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
|
||||
async def try_unbind_threepid_with_id_server(
|
||||
self, mxid: str, threepid: dict, id_server: str
|
||||
) -> bool:
|
||||
"""Removes a binding from an identity server
|
||||
|
||||
Args:
|
||||
mxid (str): Matrix user ID of binding to be removed
|
||||
threepid (dict): Dict with medium & address of binding to be removed
|
||||
id_server (str): Identity server to unbind from
|
||||
mxid: Matrix user ID of binding to be removed
|
||||
threepid: Dict with medium & address of binding to be removed
|
||||
id_server: Identity server to unbind from
|
||||
|
||||
Raises:
|
||||
SynapseError: If we failed to contact the identity server
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True on success, otherwise False if the identity
|
||||
True on success, otherwise False if the identity
|
||||
server doesn't support unbinding
|
||||
"""
|
||||
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
|
||||
|
@ -287,23 +292,23 @@ class IdentityHandler(BaseHandler):
|
|||
|
||||
async def send_threepid_validation(
|
||||
self,
|
||||
email_address,
|
||||
client_secret,
|
||||
send_attempt,
|
||||
send_email_func,
|
||||
next_link=None,
|
||||
):
|
||||
email_address: str,
|
||||
client_secret: str,
|
||||
send_attempt: int,
|
||||
send_email_func: Callable[[str, str, str, str], Awaitable],
|
||||
next_link: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Send a threepid validation email for password reset or
|
||||
registration purposes
|
||||
|
||||
Args:
|
||||
email_address (str): The user's email address
|
||||
client_secret (str): The provided client secret
|
||||
send_attempt (int): Which send attempt this is
|
||||
send_email_func (func): A function that takes an email address, token,
|
||||
client_secret and session_id, sends an email
|
||||
and returns a Deferred.
|
||||
next_link (str|None): The URL to redirect the user to after validation
|
||||
email_address: The user's email address
|
||||
client_secret: The provided client secret
|
||||
send_attempt: Which send attempt this is
|
||||
send_email_func: A function that takes an email address, token,
|
||||
client_secret and session_id, sends an email
|
||||
and returns an Awaitable.
|
||||
next_link: The URL to redirect the user to after validation
|
||||
|
||||
Returns:
|
||||
The new session_id upon success
|
||||
|
@ -372,17 +377,22 @@ class IdentityHandler(BaseHandler):
|
|||
return session_id
|
||||
|
||||
async def requestEmailToken(
|
||||
self, id_server, email, client_secret, send_attempt, next_link=None
|
||||
):
|
||||
self,
|
||||
id_server: str,
|
||||
email: str,
|
||||
client_secret: str,
|
||||
send_attempt: int,
|
||||
next_link: Optional[str] = None,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Request an external server send an email on our behalf for the purposes of threepid
|
||||
validation.
|
||||
|
||||
Args:
|
||||
id_server (str): The identity server to proxy to
|
||||
email (str): The email to send the message to
|
||||
client_secret (str): The unique client_secret sends by the user
|
||||
send_attempt (int): Which attempt this is
|
||||
id_server: The identity server to proxy to
|
||||
email: The email to send the message to
|
||||
client_secret: The unique client_secret sends by the user
|
||||
send_attempt: Which attempt this is
|
||||
next_link: A link to redirect the user to once they submit the token
|
||||
|
||||
Returns:
|
||||
|
@ -419,22 +429,22 @@ class IdentityHandler(BaseHandler):
|
|||
|
||||
async def requestMsisdnToken(
|
||||
self,
|
||||
id_server,
|
||||
country,
|
||||
phone_number,
|
||||
client_secret,
|
||||
send_attempt,
|
||||
next_link=None,
|
||||
):
|
||||
id_server: str,
|
||||
country: str,
|
||||
phone_number: str,
|
||||
client_secret: str,
|
||||
send_attempt: int,
|
||||
next_link: Optional[str] = None,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Request an external server send an SMS message on our behalf for the purposes of
|
||||
threepid validation.
|
||||
Args:
|
||||
id_server (str): The identity server to proxy to
|
||||
country (str): The country code of the phone number
|
||||
phone_number (str): The number to send the message to
|
||||
client_secret (str): The unique client_secret sends by the user
|
||||
send_attempt (int): Which attempt this is
|
||||
id_server: The identity server to proxy to
|
||||
country: The country code of the phone number
|
||||
phone_number: The number to send the message to
|
||||
client_secret: The unique client_secret sends by the user
|
||||
send_attempt: Which attempt this is
|
||||
next_link: A link to redirect the user to once they submit the token
|
||||
|
||||
Returns:
|
||||
|
@ -480,17 +490,18 @@ class IdentityHandler(BaseHandler):
|
|||
)
|
||||
return data
|
||||
|
||||
async def validate_threepid_session(self, client_secret, sid):
|
||||
async def validate_threepid_session(
|
||||
self, client_secret: str, sid: str
|
||||
) -> Optional[JsonDict]:
|
||||
"""Validates a threepid session with only the client secret and session ID
|
||||
Tries validating against any configured account_threepid_delegates as well as locally.
|
||||
|
||||
Args:
|
||||
client_secret (str): A secret provided by the client
|
||||
|
||||
sid (str): The ID of the session
|
||||
client_secret: A secret provided by the client
|
||||
sid: The ID of the session
|
||||
|
||||
Returns:
|
||||
Dict[str, str|int] if validation was successful, otherwise None
|
||||
The json response if validation was successful, otherwise None
|
||||
"""
|
||||
# XXX: We shouldn't need to keep wrapping and unwrapping this value
|
||||
threepid_creds = {"client_secret": client_secret, "sid": sid}
|
||||
|
@ -523,23 +534,22 @@ class IdentityHandler(BaseHandler):
|
|||
|
||||
return validation_session
|
||||
|
||||
async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
|
||||
async def proxy_msisdn_submit_token(
|
||||
self, id_server: str, client_secret: str, sid: str, token: str
|
||||
) -> JsonDict:
|
||||
"""Proxy a POST submitToken request to an identity server for verification purposes
|
||||
|
||||
Args:
|
||||
id_server (str): The identity server URL to contact
|
||||
|
||||
client_secret (str): Secret provided by the client
|
||||
|
||||
sid (str): The ID of the session
|
||||
|
||||
token (str): The verification token
|
||||
id_server: The identity server URL to contact
|
||||
client_secret: Secret provided by the client
|
||||
sid: The ID of the session
|
||||
token: The verification token
|
||||
|
||||
Raises:
|
||||
SynapseError: If we failed to contact the identity server
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: The response dict from the identity server
|
||||
The response dict from the identity server
|
||||
"""
|
||||
body = {"client_secret": client_secret, "sid": sid, "token": token}
|
||||
|
||||
|
@ -554,19 +564,25 @@ class IdentityHandler(BaseHandler):
|
|||
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
|
||||
raise SynapseError(400, "Error contacting the identity server")
|
||||
|
||||
async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
|
||||
async def lookup_3pid(
|
||||
self,
|
||||
id_server: str,
|
||||
medium: str,
|
||||
address: str,
|
||||
id_access_token: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Looks up a 3pid in the passed identity server.
|
||||
|
||||
Args:
|
||||
id_server (str): The server name (including port, if required)
|
||||
id_server: The server name (including port, if required)
|
||||
of the identity server to use.
|
||||
medium (str): The type of the third party identifier (e.g. "email").
|
||||
address (str): The third party identifier (e.g. "foo@example.com").
|
||||
id_access_token (str|None): The access token to authenticate to the identity
|
||||
medium: The type of the third party identifier (e.g. "email").
|
||||
address: The third party identifier (e.g. "foo@example.com").
|
||||
id_access_token: The access token to authenticate to the identity
|
||||
server with
|
||||
|
||||
Returns:
|
||||
str|None: the matrix ID of the 3pid, or None if it is not recognized.
|
||||
the matrix ID of the 3pid, or None if it is not recognized.
|
||||
"""
|
||||
if id_access_token is not None:
|
||||
try:
|
||||
|
@ -591,17 +607,19 @@ class IdentityHandler(BaseHandler):
|
|||
|
||||
return await self._lookup_3pid_v1(id_server, medium, address)
|
||||
|
||||
async def _lookup_3pid_v1(self, id_server, medium, address):
|
||||
async def _lookup_3pid_v1(
|
||||
self, id_server: str, medium: str, address: str
|
||||
) -> Optional[str]:
|
||||
"""Looks up a 3pid in the passed identity server using v1 lookup.
|
||||
|
||||
Args:
|
||||
id_server (str): The server name (including port, if required)
|
||||
id_server: The server name (including port, if required)
|
||||
of the identity server to use.
|
||||
medium (str): The type of the third party identifier (e.g. "email").
|
||||
address (str): The third party identifier (e.g. "foo@example.com").
|
||||
medium: The type of the third party identifier (e.g. "email").
|
||||
address: The third party identifier (e.g. "foo@example.com").
|
||||
|
||||
Returns:
|
||||
str: the matrix ID of the 3pid, or None if it is not recognized.
|
||||
the matrix ID of the 3pid, or None if it is not recognized.
|
||||
"""
|
||||
try:
|
||||
data = await self.blacklisting_http_client.get_json(
|
||||
|
@ -621,18 +639,20 @@ class IdentityHandler(BaseHandler):
|
|||
|
||||
return None
|
||||
|
||||
async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
|
||||
async def _lookup_3pid_v2(
|
||||
self, id_server: str, id_access_token: str, medium: str, address: str
|
||||
) -> Optional[str]:
|
||||
"""Looks up a 3pid in the passed identity server using v2 lookup.
|
||||
|
||||
Args:
|
||||
id_server (str): The server name (including port, if required)
|
||||
id_server: The server name (including port, if required)
|
||||
of the identity server to use.
|
||||
id_access_token (str): The access token to authenticate to the identity server with
|
||||
medium (str): The type of the third party identifier (e.g. "email").
|
||||
address (str): The third party identifier (e.g. "foo@example.com").
|
||||
id_access_token: The access token to authenticate to the identity server with
|
||||
medium: The type of the third party identifier (e.g. "email").
|
||||
address: The third party identifier (e.g. "foo@example.com").
|
||||
|
||||
Returns:
|
||||
Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
|
||||
the matrix ID of the 3pid, or None if it is not recognised.
|
||||
"""
|
||||
# Check what hashing details are supported by this identity server
|
||||
try:
|
||||
|
@ -757,49 +777,48 @@ class IdentityHandler(BaseHandler):
|
|||
|
||||
async def ask_id_server_for_third_party_invite(
|
||||
self,
|
||||
requester,
|
||||
id_server,
|
||||
medium,
|
||||
address,
|
||||
room_id,
|
||||
inviter_user_id,
|
||||
room_alias,
|
||||
room_avatar_url,
|
||||
room_join_rules,
|
||||
room_name,
|
||||
inviter_display_name,
|
||||
inviter_avatar_url,
|
||||
id_access_token=None,
|
||||
):
|
||||
requester: Requester,
|
||||
id_server: str,
|
||||
medium: str,
|
||||
address: str,
|
||||
room_id: str,
|
||||
inviter_user_id: str,
|
||||
room_alias: str,
|
||||
room_avatar_url: str,
|
||||
room_join_rules: str,
|
||||
room_name: str,
|
||||
inviter_display_name: str,
|
||||
inviter_avatar_url: str,
|
||||
id_access_token: Optional[str] = None,
|
||||
) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]:
|
||||
"""
|
||||
Asks an identity server for a third party invite.
|
||||
|
||||
Args:
|
||||
requester (Requester)
|
||||
id_server (str): hostname + optional port for the identity server.
|
||||
medium (str): The literal string "email".
|
||||
address (str): The third party address being invited.
|
||||
room_id (str): The ID of the room to which the user is invited.
|
||||
inviter_user_id (str): The user ID of the inviter.
|
||||
room_alias (str): An alias for the room, for cosmetic notifications.
|
||||
room_avatar_url (str): The URL of the room's avatar, for cosmetic
|
||||
requester
|
||||
id_server: hostname + optional port for the identity server.
|
||||
medium: The literal string "email".
|
||||
address: The third party address being invited.
|
||||
room_id: The ID of the room to which the user is invited.
|
||||
inviter_user_id: The user ID of the inviter.
|
||||
room_alias: An alias for the room, for cosmetic notifications.
|
||||
room_avatar_url: The URL of the room's avatar, for cosmetic
|
||||
notifications.
|
||||
room_join_rules (str): The join rules of the email (e.g. "public").
|
||||
room_name (str): The m.room.name of the room.
|
||||
inviter_display_name (str): The current display name of the
|
||||
room_join_rules: The join rules of the email (e.g. "public").
|
||||
room_name: The m.room.name of the room.
|
||||
inviter_display_name: The current display name of the
|
||||
inviter.
|
||||
inviter_avatar_url (str): The URL of the inviter's avatar.
|
||||
inviter_avatar_url: The URL of the inviter's avatar.
|
||||
id_access_token (str|None): The access token to authenticate to the identity
|
||||
server with
|
||||
|
||||
Returns:
|
||||
A deferred tuple containing:
|
||||
token (str): The token which must be signed to prove authenticity.
|
||||
A tuple containing:
|
||||
token: The token which must be signed to prove authenticity.
|
||||
public_keys ([{"public_key": str, "key_validity_url": str}]):
|
||||
public_key is a base64-encoded ed25519 public key.
|
||||
fallback_public_key: One element from public_keys.
|
||||
display_name (str): A user-friendly name to represent the invited
|
||||
user.
|
||||
display_name: A user-friendly name to represent the invited user.
|
||||
"""
|
||||
invite_config = {
|
||||
"medium": medium,
|
||||
|
@ -896,15 +915,15 @@ class IdentityHandler(BaseHandler):
|
|||
return token, public_keys, fallback_public_key, display_name
|
||||
|
||||
|
||||
def create_id_access_token_header(id_access_token):
|
||||
def create_id_access_token_header(id_access_token: str) -> List[str]:
|
||||
"""Create an Authorization header for passing to SimpleHttpClient as the header value
|
||||
of an HTTP request.
|
||||
|
||||
Args:
|
||||
id_access_token (str): An identity server access token.
|
||||
id_access_token: An identity server access token.
|
||||
|
||||
Returns:
|
||||
list[str]: The ascii-encoded bearer token encased in a list.
|
||||
The ascii-encoded bearer token encased in a list.
|
||||
"""
|
||||
# Prefix with Bearer
|
||||
bearer_token = "Bearer %s" % id_access_token
|
||||
|
|
|
@ -859,9 +859,6 @@ class EventCreationHandler(object):
|
|||
|
||||
await self.action_generator.handle_push_actions_for_event(event, context)
|
||||
|
||||
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
|
||||
# hack around with a try/finally instead.
|
||||
success = False
|
||||
try:
|
||||
# If we're a worker we need to hit out to the master.
|
||||
if not self._is_event_writer:
|
||||
|
@ -877,22 +874,20 @@ class EventCreationHandler(object):
|
|||
)
|
||||
stream_id = result["stream_id"]
|
||||
event.internal_metadata.stream_ordering = stream_id
|
||||
success = True
|
||||
return stream_id
|
||||
|
||||
stream_id = await self.persist_and_notify_client_event(
|
||||
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
|
||||
)
|
||||
|
||||
success = True
|
||||
return stream_id
|
||||
finally:
|
||||
if not success:
|
||||
# Ensure that we actually remove the entries in the push actions
|
||||
# staging area, if we calculated them.
|
||||
run_in_background(
|
||||
self.store.remove_push_actions_from_staging, event.event_id
|
||||
)
|
||||
except Exception:
|
||||
# Ensure that we actually remove the entries in the push actions
|
||||
# staging area, if we calculated them.
|
||||
run_in_background(
|
||||
self.store.remove_push_actions_from_staging, event.event_id
|
||||
)
|
||||
raise
|
||||
|
||||
async def _validate_canonical_alias(
|
||||
self, directory_handler, room_alias_str: str, expected_room_id: str
|
||||
|
|
|
@ -928,8 +928,8 @@ class PresenceHandler(BasePresenceHandler):
|
|||
# TODO: Check that this is actually a new server joining the
|
||||
# room.
|
||||
|
||||
user_ids = await self.state.get_current_users_in_room(room_id)
|
||||
user_ids = list(filter(self.is_mine_id, user_ids))
|
||||
users = await self.state.get_current_users_in_room(room_id)
|
||||
user_ids = list(filter(self.is_mine_id, users))
|
||||
|
||||
states_d = await self.current_state_for_users(user_ids)
|
||||
|
||||
|
|
|
@ -119,7 +119,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
async def upgrade_room(
|
||||
self, requester: Requester, old_room_id: str, new_version: RoomVersion
|
||||
):
|
||||
) -> str:
|
||||
"""Replace a room with a new room with a different version
|
||||
|
||||
Args:
|
||||
|
@ -128,7 +128,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
new_version: the new room version to use
|
||||
|
||||
Returns:
|
||||
Deferred[unicode]: the new room id
|
||||
the new room id
|
||||
"""
|
||||
await self.ratelimit(requester)
|
||||
|
||||
|
@ -239,7 +239,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
old_room_id: str,
|
||||
new_room_id: str,
|
||||
old_room_state: StateMap[str],
|
||||
):
|
||||
) -> None:
|
||||
"""Send updated power levels in both rooms after an upgrade
|
||||
|
||||
Args:
|
||||
|
@ -247,9 +247,6 @@ class RoomCreationHandler(BaseHandler):
|
|||
old_room_id: the id of the room to be replaced
|
||||
new_room_id: the id of the replacement room
|
||||
old_room_state: the state map for the old room
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
|
||||
|
||||
|
@ -322,7 +319,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
new_room_id: str,
|
||||
new_room_version: RoomVersion,
|
||||
tombstone_event_id: str,
|
||||
):
|
||||
) -> None:
|
||||
"""Populate a new room based on an old room
|
||||
|
||||
Args:
|
||||
|
@ -332,8 +329,6 @@ class RoomCreationHandler(BaseHandler):
|
|||
created with _gemerate_room_id())
|
||||
new_room_version: the new room version to use
|
||||
tombstone_event_id: the ID of the tombstone event in the old room.
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Iterable
|
||||
|
||||
from unpaddedbase64 import decode_base64, encode_base64
|
||||
|
||||
|
@ -37,7 +38,7 @@ class SearchHandler(BaseHandler):
|
|||
self.state_store = self.storage.state
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def get_old_rooms_from_upgraded_room(self, room_id):
|
||||
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
|
||||
"""Retrieves room IDs of old rooms in the history of an upgraded room.
|
||||
|
||||
We do so by checking the m.room.create event of the room for a
|
||||
|
@ -48,10 +49,10 @@ class SearchHandler(BaseHandler):
|
|||
The full list of all found rooms in then returned.
|
||||
|
||||
Args:
|
||||
room_id (str): id of the room to search through.
|
||||
room_id: id of the room to search through.
|
||||
|
||||
Returns:
|
||||
Deferred[iterable[str]]: predecessor room ids
|
||||
Predecessor room ids
|
||||
"""
|
||||
|
||||
historical_room_ids = []
|
||||
|
|
|
@ -424,10 +424,6 @@ class SyncHandler(object):
|
|||
potential_recents: Optional[List[EventBase]] = None,
|
||||
newly_joined_room: bool = False,
|
||||
) -> TimelineBatch:
|
||||
"""
|
||||
Returns:
|
||||
a Deferred TimelineBatch
|
||||
"""
|
||||
with Measure(self.clock, "load_filtered_recents"):
|
||||
timeline_limit = sync_config.filter_collection.timeline_limit()
|
||||
block_all_timeline = (
|
||||
|
|
|
@ -442,21 +442,6 @@ class StaticResource(File):
|
|||
return super().render_GET(request)
|
||||
|
||||
|
||||
def _options_handler(request):
|
||||
"""Request handler for OPTIONS requests
|
||||
|
||||
This is a request handler suitable for return from
|
||||
_get_handler_for_request. It returns a 200 and an empty body.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request):
|
||||
|
||||
Returns:
|
||||
Tuple[int, dict]: http code, response body.
|
||||
"""
|
||||
return 200, {}
|
||||
|
||||
|
||||
def _unrecognised_request_handler(request):
|
||||
"""Request handler for unrecognised requests
|
||||
|
||||
|
@ -490,11 +475,12 @@ class OptionsResource(resource.Resource):
|
|||
"""Responds to OPTION requests for itself and all children."""
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
code, response_json_object = _options_handler(request)
|
||||
request.setResponseCode(204)
|
||||
request.setHeader(b"Content-Length", b"0")
|
||||
|
||||
return respond_with_json(
|
||||
request, code, response_json_object, send_cors=True, canonical_json=False,
|
||||
)
|
||||
set_cors_headers(request)
|
||||
|
||||
return b""
|
||||
|
||||
def getChildWithDefault(self, path, request):
|
||||
if request.method == b"OPTIONS":
|
||||
|
|
|
@ -737,24 +737,14 @@ def trace(func=None, opname=None):
|
|||
|
||||
@wraps(func)
|
||||
async def _trace_inner(*args, **kwargs):
|
||||
if opentracing is None:
|
||||
with start_active_span(_opname):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
with start_active_span(_opname) as scope:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception:
|
||||
scope.span.set_tag(tags.ERROR, True)
|
||||
raise
|
||||
|
||||
else:
|
||||
# The other case here handles both sync functions and those
|
||||
# decorated with inlineDeferred.
|
||||
@wraps(func)
|
||||
def _trace_inner(*args, **kwargs):
|
||||
if opentracing is None:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
scope = start_active_span(_opname)
|
||||
scope.__enter__()
|
||||
|
||||
|
@ -767,7 +757,6 @@ def trace(func=None, opname=None):
|
|||
return result
|
||||
|
||||
def err_back(result):
|
||||
scope.span.set_tag(tags.ERROR, True)
|
||||
scope.__exit__(None, None, None)
|
||||
return result
|
||||
|
||||
|
|
|
@ -116,6 +116,8 @@ class _LogContextScope(Scope):
|
|||
if self._enter_logcontext:
|
||||
self.logcontext.__enter__()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
if type == twisted.internet.defer._DefGen_Return:
|
||||
super(_LogContextScope, self).__exit__(None, None, None)
|
||||
|
|
|
@ -304,7 +304,9 @@ class RulesForRoom(object):
|
|||
|
||||
push_rules_delta_state_cache_metric.inc_hits()
|
||||
else:
|
||||
current_state_ids = yield context.get_current_state_ids()
|
||||
current_state_ids = yield defer.ensureDeferred(
|
||||
context.get_current_state_ids()
|
||||
)
|
||||
push_rules_delta_state_cache_metric.inc_misses()
|
||||
|
||||
push_rules_state_size_counter.inc(len(current_state_ids))
|
||||
|
|
|
@ -24,6 +24,7 @@ from twisted.internet.protocol import ReconnectingClientFactory
|
|||
from synapse.api.constants import EventTypes
|
||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.streams import TypingStream
|
||||
from synapse.replication.tcp.streams.events import (
|
||||
EventsStream,
|
||||
EventsStreamEventRow,
|
||||
|
@ -104,6 +105,7 @@ class ReplicationDataHandler:
|
|||
self._clock = hs.get_clock()
|
||||
self._streams = hs.get_replication_streams()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self._typing_handler = hs.get_typing_handler()
|
||||
|
||||
# Map from stream to list of deferreds waiting for the stream to
|
||||
# arrive at a particular position. The lists are sorted by stream position.
|
||||
|
@ -127,6 +129,12 @@ class ReplicationDataHandler:
|
|||
"""
|
||||
self.store.process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
if stream_name == TypingStream.NAME:
|
||||
self._typing_handler.process_replication_rows(token, rows)
|
||||
self.notifier.on_new_event(
|
||||
"typing_key", token, rooms=[row.room_id for row in rows]
|
||||
)
|
||||
|
||||
if stream_name == EventsStream.NAME:
|
||||
# We shouldn't get multiple rows per token for events stream, so
|
||||
# we don't need to optimise this for multiple rows.
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
|
@ -33,6 +34,7 @@ from typing_extensions import Deque
|
|||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
|
||||
from synapse.replication.tcp.commands import (
|
||||
ClearUserSyncsCommand,
|
||||
|
@ -152,7 +154,7 @@ class ReplicationCommandHandler:
|
|||
# When POSITION or RDATA commands arrive, we stick them in a queue and process
|
||||
# them in order in a separate background process.
|
||||
|
||||
# the streams which are currently being processed by _unsafe_process_stream
|
||||
# the streams which are currently being processed by _unsafe_process_queue
|
||||
self._processing_streams = set() # type: Set[str]
|
||||
|
||||
# for each stream, a queue of commands that are awaiting processing, and the
|
||||
|
@ -185,7 +187,7 @@ class ReplicationCommandHandler:
|
|||
if self._is_master:
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
|
||||
async def _add_command_to_stream_queue(
|
||||
def _add_command_to_stream_queue(
|
||||
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
|
||||
) -> None:
|
||||
"""Queue the given received command for processing
|
||||
|
@ -199,33 +201,34 @@ class ReplicationCommandHandler:
|
|||
logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
|
||||
return
|
||||
|
||||
# if we're already processing this stream, stick the new command in the
|
||||
# queue, and we're done.
|
||||
queue.append((cmd, conn))
|
||||
|
||||
# if we're already processing this stream, there's nothing more to do:
|
||||
# the new entry on the queue will get picked up in due course
|
||||
if stream_name in self._processing_streams:
|
||||
queue.append((cmd, conn))
|
||||
return
|
||||
|
||||
# otherwise, process the new command.
|
||||
# fire off a background process to start processing the queue.
|
||||
run_as_background_process(
|
||||
"process-replication-data", self._unsafe_process_queue, stream_name
|
||||
)
|
||||
|
||||
# arguably we should start off a new background process here, but nothing
|
||||
# will be too upset if we don't return for ages, so let's save the overhead
|
||||
# and use the existing logcontext.
|
||||
async def _unsafe_process_queue(self, stream_name: str):
|
||||
"""Processes the command queue for the given stream, until it is empty
|
||||
|
||||
Does not check if there is already a thread processing the queue, hence "unsafe"
|
||||
"""
|
||||
assert stream_name not in self._processing_streams
|
||||
|
||||
self._processing_streams.add(stream_name)
|
||||
try:
|
||||
# might as well skip the queue for this one, since it must be empty
|
||||
assert not queue
|
||||
await self._process_command(cmd, conn, stream_name)
|
||||
|
||||
# now process any other commands that have built up while we were
|
||||
# dealing with that one.
|
||||
queue = self._command_queues_by_stream.get(stream_name)
|
||||
while queue:
|
||||
cmd, conn = queue.popleft()
|
||||
try:
|
||||
await self._process_command(cmd, conn, stream_name)
|
||||
except Exception:
|
||||
logger.exception("Failed to handle command %s", cmd)
|
||||
|
||||
finally:
|
||||
self._processing_streams.discard(stream_name)
|
||||
|
||||
|
@ -299,7 +302,7 @@ class ReplicationCommandHandler:
|
|||
"""
|
||||
return self._streams_to_replicate
|
||||
|
||||
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
||||
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
||||
self.send_positions_to_connection(conn)
|
||||
|
||||
def send_positions_to_connection(self, conn: AbstractConnection):
|
||||
|
@ -318,57 +321,73 @@ class ReplicationCommandHandler:
|
|||
)
|
||||
)
|
||||
|
||||
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
|
||||
def on_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: UserSyncCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
user_sync_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
await self._presence_handler.update_external_syncs_row(
|
||||
return self._presence_handler.update_external_syncs_row(
|
||||
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def on_CLEAR_USER_SYNC(
|
||||
def on_CLEAR_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
|
||||
):
|
||||
) -> Optional[Awaitable[None]]:
|
||||
if self._is_master:
|
||||
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def on_FEDERATION_ACK(
|
||||
self, conn: AbstractConnection, cmd: FederationAckCommand
|
||||
):
|
||||
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
|
||||
federation_ack_counter.inc()
|
||||
|
||||
if self._federation_sender:
|
||||
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
||||
|
||||
async def on_REMOVE_PUSHER(
|
||||
def on_REMOVE_PUSHER(
|
||||
self, conn: AbstractConnection, cmd: RemovePusherCommand
|
||||
):
|
||||
) -> Optional[Awaitable[None]]:
|
||||
remove_pusher_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
await self._store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
|
||||
)
|
||||
return self._handle_remove_pusher(cmd)
|
||||
else:
|
||||
return None
|
||||
|
||||
self._notifier.on_new_replication_data()
|
||||
async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
|
||||
await self._store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
|
||||
)
|
||||
|
||||
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
|
||||
self._notifier.on_new_replication_data()
|
||||
|
||||
def on_USER_IP(
|
||||
self, conn: AbstractConnection, cmd: UserIpCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
user_ip_cache_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
await self._store.insert_client_ip(
|
||||
cmd.user_id,
|
||||
cmd.access_token,
|
||||
cmd.ip,
|
||||
cmd.user_agent,
|
||||
cmd.device_id,
|
||||
cmd.last_seen,
|
||||
)
|
||||
return self._handle_user_ip(cmd)
|
||||
else:
|
||||
return None
|
||||
|
||||
if self._server_notices_sender:
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
async def _handle_user_ip(self, cmd: UserIpCommand):
|
||||
await self._store.insert_client_ip(
|
||||
cmd.user_id,
|
||||
cmd.access_token,
|
||||
cmd.ip,
|
||||
cmd.user_agent,
|
||||
cmd.device_id,
|
||||
cmd.last_seen,
|
||||
)
|
||||
|
||||
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
||||
assert self._server_notices_sender is not None
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore RDATA that are just our own echoes
|
||||
return
|
||||
|
@ -382,7 +401,7 @@ class ReplicationCommandHandler:
|
|||
# 2. so we don't race with getting a POSITION command and fetching
|
||||
# missing RDATA.
|
||||
|
||||
await self._add_command_to_stream_queue(conn, cmd)
|
||||
self._add_command_to_stream_queue(conn, cmd)
|
||||
|
||||
async def _process_rdata(
|
||||
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
|
||||
|
@ -459,14 +478,14 @@ class ReplicationCommandHandler:
|
|||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore POSITION that are just our own echoes
|
||||
return
|
||||
|
||||
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
|
||||
|
||||
await self._add_command_to_stream_queue(conn, cmd)
|
||||
self._add_command_to_stream_queue(conn, cmd)
|
||||
|
||||
async def _process_position(
|
||||
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
|
||||
|
@ -526,9 +545,7 @@ class ReplicationCommandHandler:
|
|||
|
||||
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(
|
||||
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
|
||||
):
|
||||
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
|
||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ import abc
|
|||
import fcntl
|
||||
import logging
|
||||
import struct
|
||||
from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
|
||||
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
|
||||
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
|
||||
`ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
|
||||
if so, that will get run as a background process.
|
||||
|
||||
It also sends `PING` periodically, and correctly times out remote connections
|
||||
(if they send a `PING` command)
|
||||
|
@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
|
||||
# a logcontext which we use for processing incoming commands. We declare it as a
|
||||
# background process so that the CPU stats get reported to prometheus.
|
||||
self._logging_context = BackgroundProcessLoggingContext(
|
||||
"replication_command_handler-%s" % self.conn_id
|
||||
)
|
||||
ctx_name = "replication-conn-%s" % self.conn_id
|
||||
self._logging_context = BackgroundProcessLoggingContext(ctx_name)
|
||||
self._logging_context.request = ctx_name
|
||||
|
||||
def connectionMade(self):
|
||||
logger.info("[%s] Connection established", self.id())
|
||||
|
@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
|
||||
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
|
||||
|
||||
# Now lets try and call on_<CMD_NAME> function
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
|
||||
)
|
||||
self.handle_command(cmd)
|
||||
|
||||
async def handle_command(self, cmd: Command):
|
||||
def handle_command(self, cmd: Command) -> None:
|
||||
"""Handle a command we have received over the replication stream.
|
||||
|
||||
First calls `self.on_<COMMAND>` if it exists, then calls
|
||||
`self.command_handler.on_<COMMAND>` if it exists. This allows for
|
||||
protocol level handling of commands (e.g. PINGs), before delegating to
|
||||
the handler.
|
||||
`self.command_handler.on_<COMMAND>` if it exists (which can optionally
|
||||
return an Awaitable).
|
||||
|
||||
This allows for protocol level handling of commands (e.g. PINGs), before
|
||||
delegating to the handler.
|
||||
|
||||
Args:
|
||||
cmd: received command
|
||||
|
@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
# specific handling.
|
||||
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(self, cmd)
|
||||
res = cmd_func(self, cmd)
|
||||
|
||||
# the handler might be a coroutine: fire it off as a background process
|
||||
# if so.
|
||||
|
||||
if isawaitable(res):
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), lambda: res
|
||||
)
|
||||
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
|
@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
for cmd in pending:
|
||||
self.send_command(cmd)
|
||||
|
||||
async def on_PING(self, line):
|
||||
def on_PING(self, line):
|
||||
self.received_ping = True
|
||||
|
||||
async def on_ERROR(self, cmd):
|
||||
def on_ERROR(self, cmd):
|
||||
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
|
||||
|
||||
def pauseProducing(self):
|
||||
|
@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
self.send_command(ServerCommand(self.server_name))
|
||||
super().connectionMade()
|
||||
|
||||
async def on_NAME(self, cmd):
|
||||
def on_NAME(self, cmd):
|
||||
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
||||
self.name = cmd.data
|
||||
|
||||
|
@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
# Once we've connected subscribe to the necessary streams
|
||||
self.replicate()
|
||||
|
||||
async def on_SERVER(self, cmd):
|
||||
def on_SERVER(self, cmd):
|
||||
if cmd.data != self.server_name:
|
||||
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
||||
self.send_error("Wrong remote")
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import txredisapi
|
||||
|
@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||
# remote instances.
|
||||
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
|
||||
|
||||
# Now lets try and call on_<CMD_NAME> function
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
|
||||
)
|
||||
self.handle_command(cmd)
|
||||
|
||||
async def handle_command(self, cmd: Command):
|
||||
def handle_command(self, cmd: Command) -> None:
|
||||
"""Handle a command we have received over the replication stream.
|
||||
|
||||
By default delegates to on_<COMMAND>, which should return an awaitable.
|
||||
Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
|
||||
Awaitable).
|
||||
|
||||
Args:
|
||||
cmd: received command
|
||||
"""
|
||||
handled = False
|
||||
|
||||
# First call any command handlers on this instance. These are for redis
|
||||
# specific handling.
|
||||
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(self, cmd)
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
if not cmd_func:
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
return
|
||||
|
||||
res = cmd_func(self, cmd)
|
||||
|
||||
# the handler might be a coroutine: fire it off as a background process
|
||||
# if so.
|
||||
|
||||
if isawaitable(res):
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), lambda: res
|
||||
)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
logger.info("Lost connection to redis")
|
||||
|
|
|
@ -17,8 +17,7 @@
|
|||
"""
|
||||
import logging
|
||||
import re
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Iterable, Pattern
|
||||
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||
from synapse.api.urls import CLIENT_API_PREFIX
|
||||
|
@ -27,15 +26,23 @@ from synapse.types import JsonDict
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
|
||||
def client_patterns(
|
||||
path_regex: str,
|
||||
releases: Iterable[int] = (0,),
|
||||
unstable: bool = True,
|
||||
v1: bool = False,
|
||||
) -> Iterable[Pattern]:
|
||||
"""Creates a regex compiled client path with the correct client path
|
||||
prefix.
|
||||
|
||||
Args:
|
||||
path_regex (str): The regex string to match. This should NOT have a ^
|
||||
path_regex: The regex string to match. This should NOT have a ^
|
||||
as this will be prefixed.
|
||||
releases: An iterable of releases to include this endpoint under.
|
||||
unstable: If true, include this endpoint under the "unstable" prefix.
|
||||
v1: If true, include this endpoint under the "api/v1" prefix.
|
||||
Returns:
|
||||
SRE_Pattern
|
||||
An iterable of patterns.
|
||||
"""
|
||||
patterns = []
|
||||
|
||||
|
@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
|
|||
def interactive_auth_handler(orig):
|
||||
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
|
||||
|
||||
Takes a on_POST method which returns a deferred (errcode, body) response
|
||||
Takes a on_POST method which returns an Awaitable (errcode, body) response
|
||||
and adds exception handling to turn a InteractiveAuthIncompleteError into
|
||||
a 401 response.
|
||||
|
||||
Normal usage is:
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
async def on_POST(self, request):
|
||||
# ...
|
||||
yield self.auth_handler.check_auth
|
||||
"""
|
||||
await self.auth_handler.check_auth
|
||||
"""
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
res = defer.ensureDeferred(orig(*args, **kwargs))
|
||||
res.addErrback(_catch_incomplete_interactive_auth)
|
||||
return res
|
||||
async def wrapped(*args, **kwargs):
|
||||
try:
|
||||
return await orig(*args, **kwargs)
|
||||
except InteractiveAuthIncompleteError as e:
|
||||
return 401, e.result
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _catch_incomplete_interactive_auth(f):
|
||||
"""helper for interactive_auth_handler
|
||||
|
||||
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
|
||||
|
||||
Args:
|
||||
f (failure.Failure):
|
||||
"""
|
||||
f.trap(InteractiveAuthIncompleteError)
|
||||
return 401, f.value.result
|
||||
|
|
|
@ -18,7 +18,6 @@ import logging
|
|||
import os
|
||||
import urllib
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||
|
@ -77,8 +76,9 @@ def respond_404(request):
|
|||
)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
|
||||
async def respond_with_file(
|
||||
request, media_type, file_path, file_size=None, upload_name=None
|
||||
):
|
||||
logger.debug("Responding with %r", file_path)
|
||||
|
||||
if os.path.isfile(file_path):
|
||||
|
@ -89,7 +89,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
|
|||
add_file_headers(request, media_type, file_size, upload_name)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
|
||||
await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
|
||||
|
||||
finish_request(request)
|
||||
else:
|
||||
|
@ -198,8 +198,9 @@ def _can_encode_filename_as_token(x):
|
|||
return True
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
|
||||
async def respond_with_responder(
|
||||
request, responder, media_type, file_size, upload_name=None
|
||||
):
|
||||
"""Responds to the request with given responder. If responder is None then
|
||||
returns 404.
|
||||
|
||||
|
@ -218,7 +219,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
|
|||
add_file_headers(request, media_type, file_size, upload_name)
|
||||
try:
|
||||
with responder:
|
||||
yield responder.write_to_consumer(request)
|
||||
await responder.write_to_consumer(request)
|
||||
except Exception as e:
|
||||
# The majority of the time this will be due to the client having gone
|
||||
# away. Unfortunately, Twisted simply throws a generic exception at us
|
||||
|
|
|
@ -14,17 +14,18 @@
|
|||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
|
||||
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||
|
||||
from ._base import Responder
|
||||
from ._base import FileInfo, Responder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -46,25 +47,24 @@ class MediaStorage(object):
|
|||
self.filepaths = filepaths
|
||||
self.storage_providers = storage_providers
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_file(self, source, file_info):
|
||||
async def store_file(self, source, file_info: FileInfo) -> str:
|
||||
"""Write `source` to the on disk media store, and also any other
|
||||
configured storage providers
|
||||
|
||||
Args:
|
||||
source: A file like object that should be written
|
||||
file_info (FileInfo): Info about the file to store
|
||||
file_info: Info about the file to store
|
||||
|
||||
Returns:
|
||||
Deferred[str]: the file path written to in the primary media store
|
||||
the file path written to in the primary media store
|
||||
"""
|
||||
|
||||
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
||||
# Write to the main repository
|
||||
yield defer_to_thread(
|
||||
await defer_to_thread(
|
||||
self.hs.get_reactor(), _write_file_synchronously, source, f
|
||||
)
|
||||
yield finish_cb()
|
||||
await finish_cb()
|
||||
|
||||
return fname
|
||||
|
||||
|
@ -75,7 +75,7 @@ class MediaStorage(object):
|
|||
|
||||
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
|
||||
like object that can be written to, fname is the absolute path of file
|
||||
on disk, and finish_cb is a function that returns a Deferred.
|
||||
on disk, and finish_cb is a function that returns an awaitable.
|
||||
|
||||
fname can be used to read the contents from after upload, e.g. to
|
||||
generate thumbnails.
|
||||
|
@ -91,7 +91,7 @@ class MediaStorage(object):
|
|||
|
||||
with media_storage.store_into_file(info) as (f, fname, finish_cb):
|
||||
# .. write into f ...
|
||||
yield finish_cb()
|
||||
await finish_cb()
|
||||
"""
|
||||
|
||||
path = self._file_info_to_path(file_info)
|
||||
|
@ -103,10 +103,13 @@ class MediaStorage(object):
|
|||
|
||||
finished_called = [False]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def finish():
|
||||
async def finish():
|
||||
for provider in self.storage_providers:
|
||||
yield provider.store_file(path, file_info)
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
result = provider.store_file(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
||||
finished_called[0] = True
|
||||
|
||||
|
@ -123,17 +126,15 @@ class MediaStorage(object):
|
|||
if not finished_called:
|
||||
raise Exception("Finished callback not called")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_media(self, file_info):
|
||||
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
|
||||
"""Attempts to fetch media described by file_info from the local cache
|
||||
and configured storage providers.
|
||||
|
||||
Args:
|
||||
file_info (FileInfo)
|
||||
file_info
|
||||
|
||||
Returns:
|
||||
Deferred[Responder|None]: Returns a Responder if the file was found,
|
||||
otherwise None.
|
||||
Returns a Responder if the file was found, otherwise None.
|
||||
"""
|
||||
|
||||
path = self._file_info_to_path(file_info)
|
||||
|
@ -142,23 +143,26 @@ class MediaStorage(object):
|
|||
return FileResponder(open(local_path, "rb"))
|
||||
|
||||
for provider in self.storage_providers:
|
||||
res = yield provider.fetch(path, file_info)
|
||||
res = provider.fetch(path, file_info)
|
||||
# Fetch is supposed to return an Awaitable, but guard against
|
||||
# improper implementations.
|
||||
if inspect.isawaitable(res):
|
||||
res = await res
|
||||
if res:
|
||||
logger.debug("Streaming %s from %s", path, provider)
|
||||
return res
|
||||
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def ensure_media_is_in_local_cache(self, file_info):
|
||||
async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
|
||||
"""Ensures that the given file is in the local cache. Attempts to
|
||||
download it from storage providers if it isn't.
|
||||
|
||||
Args:
|
||||
file_info (FileInfo)
|
||||
file_info
|
||||
|
||||
Returns:
|
||||
Deferred[str]: Full path to local file
|
||||
Full path to local file
|
||||
"""
|
||||
path = self._file_info_to_path(file_info)
|
||||
local_path = os.path.join(self.local_media_directory, path)
|
||||
|
@ -170,14 +174,18 @@ class MediaStorage(object):
|
|||
os.makedirs(dirname)
|
||||
|
||||
for provider in self.storage_providers:
|
||||
res = yield provider.fetch(path, file_info)
|
||||
res = provider.fetch(path, file_info)
|
||||
# Fetch is supposed to return an Awaitable, but guard against
|
||||
# improper implementations.
|
||||
if inspect.isawaitable(res):
|
||||
res = await res
|
||||
if res:
|
||||
with res:
|
||||
consumer = BackgroundFileConsumer(
|
||||
open(local_path, "wb"), self.hs.get_reactor()
|
||||
)
|
||||
yield res.write_to_consumer(consumer)
|
||||
yield consumer.wait()
|
||||
await res.write_to_consumer(consumer)
|
||||
await consumer.wait()
|
||||
return local_path
|
||||
|
||||
raise Exception("file could not be found")
|
||||
|
|
|
@ -26,6 +26,7 @@ import traceback
|
|||
from typing import Dict, Optional
|
||||
from urllib import parse as urlparse
|
||||
|
||||
import attr
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -56,6 +57,65 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
|
|||
OG_TAG_NAME_MAXLEN = 50
|
||||
OG_TAG_VALUE_MAXLEN = 1000
|
||||
|
||||
ONE_HOUR = 60 * 60 * 1000
|
||||
|
||||
# A map of globs to API endpoints.
|
||||
_oembed_globs = {
|
||||
# Twitter.
|
||||
"https://publish.twitter.com/oembed": [
|
||||
"https://twitter.com/*/status/*",
|
||||
"https://*.twitter.com/*/status/*",
|
||||
"https://twitter.com/*/moments/*",
|
||||
"https://*.twitter.com/*/moments/*",
|
||||
# Include the HTTP versions too.
|
||||
"http://twitter.com/*/status/*",
|
||||
"http://*.twitter.com/*/status/*",
|
||||
"http://twitter.com/*/moments/*",
|
||||
"http://*.twitter.com/*/moments/*",
|
||||
],
|
||||
}
|
||||
# Convert the globs to regular expressions.
|
||||
_oembed_patterns = {}
|
||||
for endpoint, globs in _oembed_globs.items():
|
||||
for glob in globs:
|
||||
# Convert the glob into a sane regular expression to match against. The
|
||||
# rules followed will be slightly different for the domain portion vs.
|
||||
# the rest.
|
||||
#
|
||||
# 1. The scheme must be one of HTTP / HTTPS (and have no globs).
|
||||
# 2. The domain can have globs, but we limit it to characters that can
|
||||
# reasonably be a domain part.
|
||||
# TODO: This does not attempt to handle Unicode domain names.
|
||||
# 3. Other parts allow a glob to be any one, or more, characters.
|
||||
results = urlparse.urlparse(glob)
|
||||
|
||||
# Ensure the scheme does not have wildcards (and is a sane scheme).
|
||||
if results.scheme not in {"http", "https"}:
|
||||
raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
|
||||
|
||||
pattern = urlparse.urlunparse(
|
||||
[
|
||||
results.scheme,
|
||||
re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
|
||||
]
|
||||
+ [re.escape(part).replace("\\*", ".+") for part in results[2:]]
|
||||
)
|
||||
_oembed_patterns[re.compile(pattern)] = endpoint
|
||||
|
||||
|
||||
@attr.s
|
||||
class OEmbedResult:
|
||||
# Either HTML content or URL must be provided.
|
||||
html = attr.ib(type=Optional[str])
|
||||
url = attr.ib(type=Optional[str])
|
||||
title = attr.ib(type=Optional[str])
|
||||
# Number of seconds to cache the content.
|
||||
cache_age = attr.ib(type=int)
|
||||
|
||||
|
||||
class OEmbedError(Exception):
|
||||
"""An error occurred processing the oEmbed object."""
|
||||
|
||||
|
||||
class PreviewUrlResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
@ -99,7 +159,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
cache_name="url_previews",
|
||||
clock=self.clock,
|
||||
# don't spider URLs more often than once an hour
|
||||
expiry_ms=60 * 60 * 1000,
|
||||
expiry_ms=ONE_HOUR,
|
||||
)
|
||||
|
||||
if self._worker_run_media_background_jobs:
|
||||
|
@ -310,6 +370,87 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
|
||||
return jsonog.encode("utf8")
|
||||
|
||||
def _get_oembed_url(self, url: str) -> Optional[str]:
|
||||
"""
|
||||
Check whether the URL should be downloaded as oEmbed content instead.
|
||||
|
||||
Params:
|
||||
url: The URL to check.
|
||||
|
||||
Returns:
|
||||
A URL to use instead or None if the original URL should be used.
|
||||
"""
|
||||
for url_pattern, endpoint in _oembed_patterns.items():
|
||||
if url_pattern.fullmatch(url):
|
||||
return endpoint
|
||||
|
||||
# No match.
|
||||
return None
|
||||
|
||||
async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
|
||||
"""
|
||||
Request content from an oEmbed endpoint.
|
||||
|
||||
Params:
|
||||
endpoint: The oEmbed API endpoint.
|
||||
url: The URL to pass to the API.
|
||||
|
||||
Returns:
|
||||
An object representing the metadata returned.
|
||||
|
||||
Raises:
|
||||
OEmbedError if fetching or parsing of the oEmbed information fails.
|
||||
"""
|
||||
try:
|
||||
logger.debug("Trying to get oEmbed content for url '%s'", url)
|
||||
result = await self.client.get_json(
|
||||
endpoint,
|
||||
# TODO Specify max height / width.
|
||||
# Note that only the JSON format is supported.
|
||||
args={"url": url},
|
||||
)
|
||||
|
||||
# Ensure there's a version of 1.0.
|
||||
if result.get("version") != "1.0":
|
||||
raise OEmbedError("Invalid version: %s" % (result.get("version"),))
|
||||
|
||||
oembed_type = result.get("type")
|
||||
|
||||
# Ensure the cache age is None or an int.
|
||||
cache_age = result.get("cache_age")
|
||||
if cache_age:
|
||||
cache_age = int(cache_age)
|
||||
|
||||
oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
|
||||
|
||||
# HTML content.
|
||||
if oembed_type == "rich":
|
||||
oembed_result.html = result.get("html")
|
||||
return oembed_result
|
||||
|
||||
if oembed_type == "photo":
|
||||
oembed_result.url = result.get("url")
|
||||
return oembed_result
|
||||
|
||||
# TODO Handle link and video types.
|
||||
|
||||
if "thumbnail_url" in result:
|
||||
oembed_result.url = result.get("thumbnail_url")
|
||||
return oembed_result
|
||||
|
||||
raise OEmbedError("Incompatible oEmbed information.")
|
||||
|
||||
except OEmbedError as e:
|
||||
# Trap OEmbedErrors first so we can directly re-raise them.
|
||||
logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
# Trap any exception and let the code follow as usual.
|
||||
# FIXME: pass through 404s and other error messages nicely
|
||||
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
|
||||
raise OEmbedError() from e
|
||||
|
||||
async def _download_url(self, url, user):
|
||||
# TODO: we should probably honour robots.txt... except in practice
|
||||
# we're most likely being explicitly triggered by a human rather than a
|
||||
|
@ -319,54 +460,90 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
|
||||
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
# If this URL can be accessed via oEmbed, use that instead.
|
||||
url_to_download = url
|
||||
oembed_url = self._get_oembed_url(url)
|
||||
if oembed_url:
|
||||
# The result might be a new URL to download, or it might be HTML content.
|
||||
try:
|
||||
logger.debug("Trying to get preview for url '%s'", url)
|
||||
length, headers, uri, code = await self.client.get_file(
|
||||
url,
|
||||
output_stream=f,
|
||||
max_size=self.max_spider_size,
|
||||
headers={"Accept-Language": self.url_preview_accept_language},
|
||||
)
|
||||
except SynapseError:
|
||||
# Pass SynapseErrors through directly, so that the servlet
|
||||
# handler will return a SynapseError to the client instead of
|
||||
# blank data or a 500.
|
||||
raise
|
||||
except DNSLookupError:
|
||||
# DNS lookup returned no results
|
||||
# Note: This will also be the case if one of the resolved IP
|
||||
# addresses is blacklisted
|
||||
raise SynapseError(
|
||||
502,
|
||||
"DNS resolution failure during URL preview generation",
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
except Exception as e:
|
||||
# FIXME: pass through 404s and other error messages nicely
|
||||
logger.warning("Error downloading %s: %r", url, e)
|
||||
oembed_result = await self._get_oembed_content(oembed_url, url)
|
||||
if oembed_result.url:
|
||||
url_to_download = oembed_result.url
|
||||
elif oembed_result.html:
|
||||
url_to_download = None
|
||||
except OEmbedError:
|
||||
# If an error occurs, try doing a normal preview.
|
||||
pass
|
||||
|
||||
raise SynapseError(
|
||||
500,
|
||||
"Failed to download content: %s"
|
||||
% (traceback.format_exception_only(sys.exc_info()[0], e),),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
await finish()
|
||||
if url_to_download:
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
try:
|
||||
logger.debug("Trying to get preview for url '%s'", url_to_download)
|
||||
length, headers, uri, code = await self.client.get_file(
|
||||
url_to_download,
|
||||
output_stream=f,
|
||||
max_size=self.max_spider_size,
|
||||
headers={"Accept-Language": self.url_preview_accept_language},
|
||||
)
|
||||
except SynapseError:
|
||||
# Pass SynapseErrors through directly, so that the servlet
|
||||
# handler will return a SynapseError to the client instead of
|
||||
# blank data or a 500.
|
||||
raise
|
||||
except DNSLookupError:
|
||||
# DNS lookup returned no results
|
||||
# Note: This will also be the case if one of the resolved IP
|
||||
# addresses is blacklisted
|
||||
raise SynapseError(
|
||||
502,
|
||||
"DNS resolution failure during URL preview generation",
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
except Exception as e:
|
||||
# FIXME: pass through 404s and other error messages nicely
|
||||
logger.warning("Error downloading %s: %r", url_to_download, e)
|
||||
|
||||
raise SynapseError(
|
||||
500,
|
||||
"Failed to download content: %s"
|
||||
% (traceback.format_exception_only(sys.exc_info()[0], e),),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
await finish()
|
||||
|
||||
if b"Content-Type" in headers:
|
||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||
else:
|
||||
media_type = "application/octet-stream"
|
||||
|
||||
download_name = get_filename_from_headers(headers)
|
||||
|
||||
# FIXME: we should calculate a proper expiration based on the
|
||||
# Cache-Control and Expire headers. But for now, assume 1 hour.
|
||||
expires = ONE_HOUR
|
||||
etag = headers["ETag"][0] if "ETag" in headers else None
|
||||
else:
|
||||
html_bytes = oembed_result.html.encode("utf-8") # type: ignore
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
f.write(html_bytes)
|
||||
await finish()
|
||||
|
||||
media_type = "text/html"
|
||||
download_name = oembed_result.title
|
||||
length = len(html_bytes)
|
||||
# If a specific cache age was not given, assume 1 hour.
|
||||
expires = oembed_result.cache_age or ONE_HOUR
|
||||
uri = oembed_url
|
||||
code = 200
|
||||
etag = None
|
||||
|
||||
try:
|
||||
if b"Content-Type" in headers:
|
||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||
else:
|
||||
media_type = "application/octet-stream"
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
download_name = get_filename_from_headers(headers)
|
||||
|
||||
await self.store.store_local_media(
|
||||
media_id=file_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
time_now_ms=time_now_ms,
|
||||
upload_name=download_name,
|
||||
media_length=length,
|
||||
user_id=user,
|
||||
|
@ -389,10 +566,8 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
"filename": fname,
|
||||
"uri": uri,
|
||||
"response_code": code,
|
||||
# FIXME: we should calculate a proper expiration based on the
|
||||
# Cache-Control and Expire headers. But for now, assume 1 hour.
|
||||
"expires": 60 * 60 * 1000,
|
||||
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
||||
"expires": expires,
|
||||
"etag": etag,
|
||||
}
|
||||
|
||||
def _start_expire_url_cache_data(self):
|
||||
|
@ -449,7 +624,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
# These may be cached for a bit on the client (i.e., they
|
||||
# may have a room open with a preview url thing open).
|
||||
# So we wait a couple of days before deleting, just in case.
|
||||
expire_before = now - 2 * 24 * 60 * 60 * 1000
|
||||
expire_before = now - 2 * 24 * ONE_HOUR
|
||||
media_ids = await self.store.get_url_cache_media_before(expire_before)
|
||||
|
||||
removed_media = []
|
||||
|
|
|
@ -31,6 +31,7 @@ import synapse.server_notices.server_notices_sender
|
|||
import synapse.state
|
||||
import synapse.storage
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
from synapse.handlers.typing import FollowerTypingHandler
|
||||
from synapse.replication.tcp.streams import Stream
|
||||
|
||||
class HomeServer(object):
|
||||
|
@ -150,3 +151,5 @@ class HomeServer(object):
|
|||
pass
|
||||
def should_send_federation(self) -> bool:
|
||||
pass
|
||||
def get_typing_handler(self) -> FollowerTypingHandler:
|
||||
pass
|
||||
|
|
|
@ -16,14 +16,12 @@
|
|||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Iterable, List, Optional, Set
|
||||
from typing import Awaitable, Dict, Iterable, List, Optional, Set
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
|
||||
from synapse.events import EventBase
|
||||
|
@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext
|
|||
from synapse.logging.utils import log_function
|
||||
from synapse.state import v1, v2
|
||||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.types import StateMap
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
@ -108,8 +107,7 @@ class StateHandler(object):
|
|||
self.hs = hs
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state(
|
||||
async def get_current_state(
|
||||
self, room_id, event_type=None, state_key="", latest_event_ids=None
|
||||
):
|
||||
""" Retrieves the current state for the room. This is done by
|
||||
|
@ -126,20 +124,20 @@ class StateHandler(object):
|
|||
map from (type, state_key) to event
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state")
|
||||
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
if event_type:
|
||||
event_id = state.get((event_type, state_key))
|
||||
event = None
|
||||
if event_id:
|
||||
event = yield self.store.get_event(event_id, allow_none=True)
|
||||
event = await self.store.get_event(event_id, allow_none=True)
|
||||
return event
|
||||
|
||||
state_map = yield self.store.get_events(
|
||||
state_map = await self.store.get_events(
|
||||
list(state.values()), get_prev_content=False
|
||||
)
|
||||
state = {
|
||||
|
@ -148,8 +146,7 @@ class StateHandler(object):
|
|||
|
||||
return state
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state_ids(self, room_id, latest_event_ids=None):
|
||||
async def get_current_state_ids(self, room_id, latest_event_ids=None):
|
||||
"""Get the current state, or the state at a set of events, for a room
|
||||
|
||||
Args:
|
||||
|
@ -164,41 +161,38 @@ class StateHandler(object):
|
|||
(event_type, state_key) -> event_id
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
return state
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_users_in_room(self, room_id, latest_event_ids=None):
|
||||
async def get_current_users_in_room(
|
||||
self, room_id: str, latest_event_ids: Optional[List[str]] = None
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
"""
|
||||
Get the users who are currently in a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The ID of the room.
|
||||
latest_event_ids (List[str]|None): Precomputed list of latest
|
||||
event IDs. Will be computed if None.
|
||||
room_id: The ID of the room.
|
||||
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
|
||||
Returns:
|
||||
Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
|
||||
profileinfo.
|
||||
Dictionary of user IDs to their profileinfo.
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
logger.debug("calling resolve_state_groups from get_current_users_in_room")
|
||||
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
|
||||
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
joined_users = await self.store.get_joined_users_from_state(room_id, entry)
|
||||
return joined_users
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_hosts_in_room(self, room_id):
|
||||
event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
|
||||
async def get_current_hosts_in_room(self, room_id):
|
||||
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_hosts_in_room_at_events(self, room_id, event_ids):
|
||||
async def get_hosts_in_room_at_events(self, room_id, event_ids):
|
||||
"""Get the hosts that were in a room at the given event ids
|
||||
|
||||
Args:
|
||||
|
@ -208,12 +202,11 @@ class StateHandler(object):
|
|||
Returns:
|
||||
Deferred[list[str]]: the hosts in the room at the given events
|
||||
"""
|
||||
entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
|
||||
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
|
||||
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
|
||||
joined_hosts = await self.store.get_joined_hosts(room_id, entry)
|
||||
return joined_hosts
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def compute_event_context(
|
||||
async def compute_event_context(
|
||||
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
|
||||
):
|
||||
"""Build an EventContext structure for the event.
|
||||
|
@ -278,7 +271,7 @@ class StateHandler(object):
|
|||
# otherwise, we'll need to resolve the state across the prev_events.
|
||||
logger.debug("calling resolve_state_groups from compute_event_context")
|
||||
|
||||
entry = yield self.resolve_state_groups_for_events(
|
||||
entry = await self.resolve_state_groups_for_events(
|
||||
event.room_id, event.prev_event_ids()
|
||||
)
|
||||
|
||||
|
@ -295,7 +288,7 @@ class StateHandler(object):
|
|||
#
|
||||
|
||||
if not state_group_before_event:
|
||||
state_group_before_event = yield self.state_store.store_state_group(
|
||||
state_group_before_event = await self.state_store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=state_group_before_event_prev_group,
|
||||
|
@ -335,7 +328,7 @@ class StateHandler(object):
|
|||
state_ids_after_event[key] = event.event_id
|
||||
delta_ids = {key: event.event_id}
|
||||
|
||||
state_group_after_event = yield self.state_store.store_state_group(
|
||||
state_group_after_event = await self.state_store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=state_group_before_event,
|
||||
|
@ -353,8 +346,7 @@ class StateHandler(object):
|
|||
)
|
||||
|
||||
@measure_func()
|
||||
@defer.inlineCallbacks
|
||||
def resolve_state_groups_for_events(self, room_id, event_ids):
|
||||
async def resolve_state_groups_for_events(self, room_id, event_ids):
|
||||
""" Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
||||
|
@ -373,7 +365,7 @@ class StateHandler(object):
|
|||
# map from state group id to the state in that state group (where
|
||||
# 'state' is a map from state key to event id)
|
||||
# dict[int, dict[(str, str), str]]
|
||||
state_groups_ids = yield self.state_store.get_state_groups_ids(
|
||||
state_groups_ids = await self.state_store.get_state_groups_ids(
|
||||
room_id, event_ids
|
||||
)
|
||||
|
||||
|
@ -382,7 +374,7 @@ class StateHandler(object):
|
|||
elif len(state_groups_ids) == 1:
|
||||
name, state_list = list(state_groups_ids.items()).pop()
|
||||
|
||||
prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
|
||||
prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
|
||||
|
||||
return _StateCacheEntry(
|
||||
state=state_list,
|
||||
|
@ -391,9 +383,9 @@ class StateHandler(object):
|
|||
delta_ids=delta_ids,
|
||||
)
|
||||
|
||||
room_version = yield self.store.get_room_version_id(room_id)
|
||||
room_version = await self.store.get_room_version_id(room_id)
|
||||
|
||||
result = yield self._state_resolution_handler.resolve_state_groups(
|
||||
result = await self._state_resolution_handler.resolve_state_groups(
|
||||
room_id,
|
||||
room_version,
|
||||
state_groups_ids,
|
||||
|
@ -402,8 +394,7 @@ class StateHandler(object):
|
|||
)
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_events(self, room_version, state_sets, event):
|
||||
async def resolve_events(self, room_version, state_sets, event):
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||
)
|
||||
|
@ -414,7 +405,7 @@ class StateHandler(object):
|
|||
state_map = {ev.event_id: ev for st in state_sets for ev in st}
|
||||
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = yield resolve_events_with_store(
|
||||
new_state = await resolve_events_with_store(
|
||||
self.clock,
|
||||
event.room_id,
|
||||
room_version,
|
||||
|
@ -451,9 +442,8 @@ class StateResolutionHandler(object):
|
|||
reset_expiry_on_get=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def resolve_state_groups(
|
||||
async def resolve_state_groups(
|
||||
self, room_id, room_version, state_groups_ids, event_map, state_res_store
|
||||
):
|
||||
"""Resolves conflicts between a set of state groups
|
||||
|
@ -479,13 +469,13 @@ class StateResolutionHandler(object):
|
|||
state_res_store (StateResolutionStore)
|
||||
|
||||
Returns:
|
||||
Deferred[_StateCacheEntry]: resolved state
|
||||
_StateCacheEntry: resolved state
|
||||
"""
|
||||
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
|
||||
|
||||
group_names = frozenset(state_groups_ids.keys())
|
||||
|
||||
with (yield self.resolve_linearizer.queue(group_names)):
|
||||
with (await self.resolve_linearizer.queue(group_names)):
|
||||
if self._state_cache is not None:
|
||||
cache = self._state_cache.get(group_names, None)
|
||||
if cache:
|
||||
|
@ -517,7 +507,7 @@ class StateResolutionHandler(object):
|
|||
if conflicted_state:
|
||||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = yield resolve_events_with_store(
|
||||
new_state = await resolve_events_with_store(
|
||||
self.clock,
|
||||
room_id,
|
||||
room_version,
|
||||
|
@ -598,7 +588,7 @@ def resolve_events_with_store(
|
|||
state_sets: List[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_res_store: "StateResolutionStore",
|
||||
):
|
||||
) -> Awaitable[StateMap[str]]:
|
||||
"""
|
||||
Args:
|
||||
room_id: the room we are working in
|
||||
|
@ -619,8 +609,7 @@ def resolve_events_with_store(
|
|||
state_res_store: a place to fetch events from
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
v = KNOWN_ROOM_VERSIONS[room_version]
|
||||
if v.state_res == StateResolutionVersions.V1:
|
||||
|
|
|
@ -15,9 +15,7 @@
|
|||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
|
@ -32,12 +30,11 @@ logger = logging.getLogger(__name__)
|
|||
POWER_KEY = (EventTypes.PowerLevels, "")
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_events_with_store(
|
||||
async def resolve_events_with_store(
|
||||
room_id: str,
|
||||
state_sets: List[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_map_factory: Callable,
|
||||
state_map_factory: Callable[[List[str]], Awaitable],
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -56,7 +53,7 @@ def resolve_events_with_store(
|
|||
|
||||
state_map_factory: will be called
|
||||
with a list of event_ids that are needed, and should return with
|
||||
a Deferred of dict of event_id to event.
|
||||
an Awaitable that resolves to a dict of event_id to event.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]]:
|
||||
|
@ -80,7 +77,7 @@ def resolve_events_with_store(
|
|||
|
||||
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||
# the state events which are in conflict (and those in event_map)
|
||||
state_map = yield state_map_factory(needed_events)
|
||||
state_map = await state_map_factory(needed_events)
|
||||
if event_map is not None:
|
||||
state_map.update(event_map)
|
||||
|
||||
|
@ -110,7 +107,7 @@ def resolve_events_with_store(
|
|||
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
|
||||
)
|
||||
|
||||
state_map_new = yield state_map_factory(new_needed_events)
|
||||
state_map_new = await state_map_factory(new_needed_events)
|
||||
for event in state_map_new.values():
|
||||
if event.room_id != room_id:
|
||||
raise Exception(
|
||||
|
|
|
@ -18,8 +18,6 @@ import itertools
|
|||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.state
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
|
@ -32,14 +30,13 @@ from synapse.util import Clock
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# We want to yield to the reactor occasionally during state res when dealing
|
||||
# We want to await to the reactor occasionally during state res when dealing
|
||||
# with large data sets, so that we don't exhaust the reactor. This is done by
|
||||
# yielding to reactor during loops every N iterations.
|
||||
_YIELD_AFTER_ITERATIONS = 100
|
||||
# awaiting to reactor during loops every N iterations.
|
||||
_AWAIT_AFTER_ITERATIONS = 100
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_events_with_store(
|
||||
async def resolve_events_with_store(
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
|
@ -87,7 +84,7 @@ def resolve_events_with_store(
|
|||
|
||||
# Also fetch all auth events that appear in only some of the state sets'
|
||||
# auth chains.
|
||||
auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
|
||||
auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
|
||||
|
||||
full_conflicted_set = set(
|
||||
itertools.chain(
|
||||
|
@ -95,7 +92,7 @@ def resolve_events_with_store(
|
|||
)
|
||||
)
|
||||
|
||||
events = yield state_res_store.get_events(
|
||||
events = await state_res_store.get_events(
|
||||
[eid for eid in full_conflicted_set if eid not in event_map],
|
||||
allow_rejected=True,
|
||||
)
|
||||
|
@ -118,14 +115,14 @@ def resolve_events_with_store(
|
|||
eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
|
||||
)
|
||||
|
||||
sorted_power_events = yield _reverse_topological_power_sort(
|
||||
sorted_power_events = await _reverse_topological_power_sort(
|
||||
clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
|
||||
)
|
||||
|
||||
logger.debug("sorted %d power events", len(sorted_power_events))
|
||||
|
||||
# Now sequentially auth each one
|
||||
resolved_state = yield _iterative_auth_checks(
|
||||
resolved_state = await _iterative_auth_checks(
|
||||
clock,
|
||||
room_id,
|
||||
room_version,
|
||||
|
@ -148,13 +145,13 @@ def resolve_events_with_store(
|
|||
logger.debug("sorting %d remaining events", len(leftover_events))
|
||||
|
||||
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
|
||||
leftover_events = yield _mainline_sort(
|
||||
leftover_events = await _mainline_sort(
|
||||
clock, room_id, leftover_events, pl, event_map, state_res_store
|
||||
)
|
||||
|
||||
logger.debug("resolving remaining events")
|
||||
|
||||
resolved_state = yield _iterative_auth_checks(
|
||||
resolved_state = await _iterative_auth_checks(
|
||||
clock,
|
||||
room_id,
|
||||
room_version,
|
||||
|
@ -174,8 +171,7 @@ def resolve_events_with_store(
|
|||
return resolved_state
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
|
||||
async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
|
||||
"""Return the power level of the sender of the given event according to
|
||||
their auth events.
|
||||
|
||||
|
@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
|
|||
Returns:
|
||||
Deferred[int]
|
||||
"""
|
||||
event = yield _get_event(room_id, event_id, event_map, state_res_store)
|
||||
event = await _get_event(room_id, event_id, event_map, state_res_store)
|
||||
|
||||
pl = None
|
||||
for aid in event.auth_event_ids():
|
||||
aev = yield _get_event(
|
||||
aev = await _get_event(
|
||||
room_id, aid, event_map, state_res_store, allow_none=True
|
||||
)
|
||||
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
|
||||
|
@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
|
|||
if pl is None:
|
||||
# Couldn't find power level. Check if they're the creator of the room
|
||||
for aid in event.auth_event_ids():
|
||||
aev = yield _get_event(
|
||||
aev = await _get_event(
|
||||
room_id, aid, event_map, state_res_store, allow_none=True
|
||||
)
|
||||
if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
|
||||
|
@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
|
|||
return int(level)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
||||
async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
||||
"""Compare the auth chains of each state set and return the set of events
|
||||
that only appear in some but not all of the auth chains.
|
||||
|
||||
|
@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
|||
Deferred[set[str]]: Set of event IDs
|
||||
"""
|
||||
|
||||
difference = yield state_res_store.get_auth_chain_difference(
|
||||
difference = await state_res_store.get_auth_chain_difference(
|
||||
[set(state_set.values()) for state_set in state_sets]
|
||||
)
|
||||
|
||||
|
@ -292,8 +287,7 @@ def _is_power_event(event):
|
|||
return False
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _add_event_and_auth_chain_to_graph(
|
||||
async def _add_event_and_auth_chain_to_graph(
|
||||
graph, room_id, event_id, event_map, state_res_store, auth_diff
|
||||
):
|
||||
"""Helper function for _reverse_topological_power_sort that add the event
|
||||
|
@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph(
|
|||
eid = state.pop()
|
||||
graph.setdefault(eid, set())
|
||||
|
||||
event = yield _get_event(room_id, eid, event_map, state_res_store)
|
||||
event = await _get_event(room_id, eid, event_map, state_res_store)
|
||||
for aid in event.auth_event_ids():
|
||||
if aid in auth_diff:
|
||||
if aid not in graph:
|
||||
|
@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph(
|
|||
graph.setdefault(eid, set()).add(aid)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _reverse_topological_power_sort(
|
||||
async def _reverse_topological_power_sort(
|
||||
clock, room_id, event_ids, event_map, state_res_store, auth_diff
|
||||
):
|
||||
"""Returns a list of the event_ids sorted by reverse topological ordering,
|
||||
|
@ -344,26 +337,26 @@ def _reverse_topological_power_sort(
|
|||
|
||||
graph = {}
|
||||
for idx, event_id in enumerate(event_ids, start=1):
|
||||
yield _add_event_and_auth_chain_to_graph(
|
||||
await _add_event_and_auth_chain_to_graph(
|
||||
graph, room_id, event_id, event_map, state_res_store, auth_diff
|
||||
)
|
||||
|
||||
# We yield occasionally when we're working with large data sets to
|
||||
# We await occasionally when we're working with large data sets to
|
||||
# ensure that we don't block the reactor loop for too long.
|
||||
if idx % _YIELD_AFTER_ITERATIONS == 0:
|
||||
yield clock.sleep(0)
|
||||
if idx % _AWAIT_AFTER_ITERATIONS == 0:
|
||||
await clock.sleep(0)
|
||||
|
||||
event_to_pl = {}
|
||||
for idx, event_id in enumerate(graph, start=1):
|
||||
pl = yield _get_power_level_for_sender(
|
||||
pl = await _get_power_level_for_sender(
|
||||
room_id, event_id, event_map, state_res_store
|
||||
)
|
||||
event_to_pl[event_id] = pl
|
||||
|
||||
# We yield occasionally when we're working with large data sets to
|
||||
# We await occasionally when we're working with large data sets to
|
||||
# ensure that we don't block the reactor loop for too long.
|
||||
if idx % _YIELD_AFTER_ITERATIONS == 0:
|
||||
yield clock.sleep(0)
|
||||
if idx % _AWAIT_AFTER_ITERATIONS == 0:
|
||||
await clock.sleep(0)
|
||||
|
||||
def _get_power_order(event_id):
|
||||
ev = event_map[event_id]
|
||||
|
@ -378,8 +371,7 @@ def _reverse_topological_power_sort(
|
|||
return sorted_events
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _iterative_auth_checks(
|
||||
async def _iterative_auth_checks(
|
||||
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
|
||||
):
|
||||
"""Sequentially apply auth checks to each event in given list, updating the
|
||||
|
@ -405,7 +397,7 @@ def _iterative_auth_checks(
|
|||
|
||||
auth_events = {}
|
||||
for aid in event.auth_event_ids():
|
||||
ev = yield _get_event(
|
||||
ev = await _get_event(
|
||||
room_id, aid, event_map, state_res_store, allow_none=True
|
||||
)
|
||||
|
||||
|
@ -420,7 +412,7 @@ def _iterative_auth_checks(
|
|||
for key in event_auth.auth_types_for_event(event):
|
||||
if key in resolved_state:
|
||||
ev_id = resolved_state[key]
|
||||
ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
|
||||
ev = await _get_event(room_id, ev_id, event_map, state_res_store)
|
||||
|
||||
if ev.rejected_reason is None:
|
||||
auth_events[key] = event_map[ev_id]
|
||||
|
@ -438,16 +430,15 @@ def _iterative_auth_checks(
|
|||
except AuthError:
|
||||
pass
|
||||
|
||||
# We yield occasionally when we're working with large data sets to
|
||||
# We await occasionally when we're working with large data sets to
|
||||
# ensure that we don't block the reactor loop for too long.
|
||||
if idx % _YIELD_AFTER_ITERATIONS == 0:
|
||||
yield clock.sleep(0)
|
||||
if idx % _AWAIT_AFTER_ITERATIONS == 0:
|
||||
await clock.sleep(0)
|
||||
|
||||
return resolved_state
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _mainline_sort(
|
||||
async def _mainline_sort(
|
||||
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
|
||||
):
|
||||
"""Returns a sorted list of event_ids sorted by mainline ordering based on
|
||||
|
@ -474,21 +465,21 @@ def _mainline_sort(
|
|||
idx = 0
|
||||
while pl:
|
||||
mainline.append(pl)
|
||||
pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
|
||||
pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
|
||||
auth_events = pl_ev.auth_event_ids()
|
||||
pl = None
|
||||
for aid in auth_events:
|
||||
ev = yield _get_event(
|
||||
ev = await _get_event(
|
||||
room_id, aid, event_map, state_res_store, allow_none=True
|
||||
)
|
||||
if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
|
||||
pl = aid
|
||||
break
|
||||
|
||||
# We yield occasionally when we're working with large data sets to
|
||||
# We await occasionally when we're working with large data sets to
|
||||
# ensure that we don't block the reactor loop for too long.
|
||||
if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
|
||||
yield clock.sleep(0)
|
||||
if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
|
||||
await clock.sleep(0)
|
||||
|
||||
idx += 1
|
||||
|
||||
|
@ -498,23 +489,24 @@ def _mainline_sort(
|
|||
|
||||
order_map = {}
|
||||
for idx, ev_id in enumerate(event_ids, start=1):
|
||||
depth = yield _get_mainline_depth_for_event(
|
||||
depth = await _get_mainline_depth_for_event(
|
||||
event_map[ev_id], mainline_map, event_map, state_res_store
|
||||
)
|
||||
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
|
||||
|
||||
# We yield occasionally when we're working with large data sets to
|
||||
# We await occasionally when we're working with large data sets to
|
||||
# ensure that we don't block the reactor loop for too long.
|
||||
if idx % _YIELD_AFTER_ITERATIONS == 0:
|
||||
yield clock.sleep(0)
|
||||
if idx % _AWAIT_AFTER_ITERATIONS == 0:
|
||||
await clock.sleep(0)
|
||||
|
||||
event_ids.sort(key=lambda ev_id: order_map[ev_id])
|
||||
|
||||
return event_ids
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
|
||||
async def _get_mainline_depth_for_event(
|
||||
event, mainline_map, event_map, state_res_store
|
||||
):
|
||||
"""Get the mainline depths for the given event based on the mainline map
|
||||
|
||||
Args:
|
||||
|
@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
|
|||
event = None
|
||||
|
||||
for aid in auth_events:
|
||||
aev = yield _get_event(
|
||||
aev = await _get_event(
|
||||
room_id, aid, event_map, state_res_store, allow_none=True
|
||||
)
|
||||
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
|
||||
|
@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
|
|||
return 0
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
|
||||
async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
|
||||
"""Helper function to look up event in event_map, falling back to looking
|
||||
it up in the store
|
||||
|
||||
|
@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
|
|||
Deferred[Optional[FrozenEvent]]
|
||||
"""
|
||||
if event_id not in event_map:
|
||||
events = yield state_res_store.get_events([event_id], allow_rejected=True)
|
||||
events = await state_res_store.get_events([event_id], allow_rejected=True)
|
||||
event_map.update(events)
|
||||
event = event_map.get(event_id)
|
||||
|
||||
|
|
|
@ -259,7 +259,7 @@ class PushRulesWorkerStore(
|
|||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
current_state_ids = yield context.get_current_state_ids()
|
||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||
result = yield self._bulk_get_push_rules_for_room(
|
||||
event.room_id, state_group, current_state_ids, event=event
|
||||
)
|
||||
|
|
|
@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
current_state_ids = yield context.get_current_state_ids()
|
||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||
result = yield self._get_joined_users_from_context(
|
||||
event.room_id, state_group, current_state_ids, event=event, context=context
|
||||
)
|
||||
|
|
|
@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
room_id
|
||||
)
|
||||
|
||||
users_with_profile = yield state.get_current_users_in_room(room_id)
|
||||
users_with_profile = yield defer.ensureDeferred(
|
||||
state.get_current_users_in_room(room_id)
|
||||
)
|
||||
user_ids = set(users_with_profile)
|
||||
|
||||
# Update each user in the user directory.
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue