0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-22 11:20:26 +01:00

Convert the message handler to async/await. (#7884)

This commit is contained in:
Patrick Cloke 2020-07-22 12:29:15 -04:00 committed by GitHub
parent a4cf94a3c2
commit cc9bb3dc3f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 273 additions and 238 deletions

1
changelog.d/7884.misc Normal file
View file

@ -0,0 +1 @@
Convert the message handler to async/await.

View file

@ -15,12 +15,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
from twisted.internet.interfaces import IDelayedCall from twisted.internet.interfaces import IDelayedCall
from synapse import event_auth from synapse import event_auth
@ -41,13 +39,22 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import Collection, RoomAlias, UserID, create_requester from synapse.types import (
Collection,
Requester,
RoomAlias,
StreamToken,
UserID,
create_requester,
)
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
@ -84,14 +91,22 @@ class MessageHandler(object):
"_schedule_next_expiry", self._schedule_next_expiry "_schedule_next_expiry", self._schedule_next_expiry
) )
@defer.inlineCallbacks async def get_room_data(
def get_room_data( self,
self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False user_id: str = None,
): room_id: str = None,
event_type: Optional[str] = None,
state_key: str = "",
is_guest: bool = False,
) -> dict:
""" Get data from a room. """ Get data from a room.
Args: Args:
event : The room path event user_id
room_id
event_type
state_key
is_guest
Returns: Returns:
The path data content. The path data content.
Raises: Raises:
@ -100,30 +115,29 @@ class MessageHandler(object):
( (
membership, membership,
membership_event_id, membership_event_id,
) = yield self.auth.check_user_in_room_or_world_readable( ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True room_id, user_id, allow_departed_users=True
) )
if membership == Membership.JOIN: if membership == Membership.JOIN:
data = yield self.state.get_current_state(room_id, event_type, state_key) data = await self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
room_state = yield self.state_store.get_state_for_events( room_state = await self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key]) [membership_event_id], StateFilter.from_types([key])
) )
data = room_state[membership_event_id].get(key) data = room_state[membership_event_id].get(key)
return data return data
@defer.inlineCallbacks async def get_state_events(
def get_state_events(
self, self,
user_id, user_id: str,
room_id, room_id: str,
state_filter=StateFilter.all(), state_filter: StateFilter = StateFilter.all(),
at_token=None, at_token: Optional[StreamToken] = None,
is_guest=False, is_guest: bool = False,
): ) -> List[dict]:
"""Retrieve all state events for a given room. If the user is """Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has joined to the room then return the current state. If the user has
left the room return the state events from when they left. If an explicit left the room return the state events from when they left. If an explicit
@ -131,15 +145,14 @@ class MessageHandler(object):
visible. visible.
Args: Args:
user_id(str): The user requesting state events. user_id: The user requesting state events.
room_id(str): The room ID to get all state events from. room_id: The room ID to get all state events from.
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state from the database.
from the database. at_token: the stream token of the at which we are requesting
at_token(StreamToken|None): the stream token of the at which we are requesting
the stats. If the user is not allowed to view the state as of that the stats. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current stream token, we raise a 403 SynapseError. If None, returns the current
state based on the current_state_events table. state based on the current_state_events table.
is_guest(bool): whether this user is a guest is_guest: whether this user is a guest
Returns: Returns:
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
Raises: Raises:
@ -153,20 +166,20 @@ class MessageHandler(object):
# get_recent_events_for_room operates by topo ordering. This therefore # get_recent_events_for_room operates by topo ordering. This therefore
# does not reliably give you the state at the given stream position. # does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305) # (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room( last_events, _ = await self.store.get_recent_events_for_room(
room_id, end_token=at_token.room_key, limit=1 room_id, end_token=at_token.room_key, limit=1
) )
if not last_events: if not last_events:
raise NotFoundError("Can't find event for token %s" % (at_token,)) raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client( visible_events = await filter_events_for_client(
self.storage, user_id, last_events, filter_send_to_client=False self.storage, user_id, last_events, filter_send_to_client=False
) )
event = last_events[0] event = last_events[0]
if visible_events: if visible_events:
room_state = yield self.state_store.get_state_for_events( room_state = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter [event.event_id], state_filter=state_filter
) )
room_state = room_state[event.event_id] room_state = room_state[event.event_id]
@ -180,23 +193,23 @@ class MessageHandler(object):
( (
membership, membership,
membership_event_id, membership_event_id,
) = yield self.auth.check_user_in_room_or_world_readable( ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True room_id, user_id, allow_departed_users=True
) )
if membership == Membership.JOIN: if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids( state_ids = await self.store.get_filtered_current_state_ids(
room_id, state_filter=state_filter room_id, state_filter=state_filter
) )
room_state = yield self.store.get_events(state_ids.values()) room_state = await self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
room_state = yield self.state_store.get_state_for_events( room_state = await self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter [membership_event_id], state_filter=state_filter
) )
room_state = room_state[membership_event_id] room_state = room_state[membership_event_id]
now = self.clock.time_msec() now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events( events = await self._event_serializer.serialize_events(
room_state.values(), room_state.values(),
now, now,
# We don't bother bundling aggregations in when asked for state # We don't bother bundling aggregations in when asked for state
@ -205,15 +218,14 @@ class MessageHandler(object):
) )
return events return events
@defer.inlineCallbacks async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
def get_joined_members(self, requester, room_id):
"""Get all the joined members in the room and their profile information. """Get all the joined members in the room and their profile information.
If the user has left the room return the state events from when they left. If the user has left the room return the state events from when they left.
Args: Args:
requester(Requester): The user requesting state events. requester: The user requesting state events.
room_id(str): The room ID to get all state events from. room_id: The room ID to get all state events from.
Returns: Returns:
A dict of user_id to profile info A dict of user_id to profile info
""" """
@ -221,7 +233,7 @@ class MessageHandler(object):
if not requester.app_service: if not requester.app_service:
# We check AS auth after fetching the room membership, as it # We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway. # requires us to pull out all joined members anyway.
membership, _ = yield self.auth.check_user_in_room_or_world_readable( membership, _ = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True room_id, user_id, allow_departed_users=True
) )
if membership != Membership.JOIN: if membership != Membership.JOIN:
@ -229,7 +241,7 @@ class MessageHandler(object):
"Getting joined members after leaving is not implemented" "Getting joined members after leaving is not implemented"
) )
users_with_profile = yield self.state.get_current_users_in_room(room_id) users_with_profile = await self.state.get_current_users_in_room(room_id)
# If this is an AS, double check that they are allowed to see the members. # If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there # This can either be because the AS user is in the room or because there
@ -250,7 +262,7 @@ class MessageHandler(object):
for user_id, profile in users_with_profile.items() for user_id, profile in users_with_profile.items()
} }
def maybe_schedule_expiry(self, event): def maybe_schedule_expiry(self, event: EventBase):
"""Schedule the expiry of an event if there's not already one scheduled, """Schedule the expiry of an event if there's not already one scheduled,
or if the one running is for an event that will expire after the provided or if the one running is for an event that will expire after the provided
timestamp. timestamp.
@ -259,7 +271,7 @@ class MessageHandler(object):
the master process, and therefore needs to be run on there. the master process, and therefore needs to be run on there.
Args: Args:
event (EventBase): The event to schedule the expiry of. event: The event to schedule the expiry of.
""" """
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
@ -270,8 +282,7 @@ class MessageHandler(object):
# a task scheduled for a timestamp that's sooner than the provided one. # a task scheduled for a timestamp that's sooner than the provided one.
self._schedule_expiry_for_event(event.event_id, expiry_ts) self._schedule_expiry_for_event(event.event_id, expiry_ts)
@defer.inlineCallbacks async def _schedule_next_expiry(self):
def _schedule_next_expiry(self):
"""Retrieve the ID and the expiry timestamp of the next event to be expired, """Retrieve the ID and the expiry timestamp of the next event to be expired,
and schedule an expiry task for it. and schedule an expiry task for it.
@ -279,18 +290,18 @@ class MessageHandler(object):
future call to save_expiry_ts can schedule a new expiry task. future call to save_expiry_ts can schedule a new expiry task.
""" """
# Try to get the expiry timestamp of the next event to expire. # Try to get the expiry timestamp of the next event to expire.
res = yield self.store.get_next_event_to_expire() res = await self.store.get_next_event_to_expire()
if res: if res:
event_id, expiry_ts = res event_id, expiry_ts = res
self._schedule_expiry_for_event(event_id, expiry_ts) self._schedule_expiry_for_event(event_id, expiry_ts)
def _schedule_expiry_for_event(self, event_id, expiry_ts): def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
"""Schedule an expiry task for the provided event if there's not already one """Schedule an expiry task for the provided event if there's not already one
scheduled at a timestamp that's sooner than the provided one. scheduled at a timestamp that's sooner than the provided one.
Args: Args:
event_id (str): The ID of the event to expire. event_id: The ID of the event to expire.
expiry_ts (int): The timestamp at which to expire the event. expiry_ts: The timestamp at which to expire the event.
""" """
if self._scheduled_expiry: if self._scheduled_expiry:
# If the provided timestamp refers to a time before the scheduled time of the # If the provided timestamp refers to a time before the scheduled time of the
@ -320,8 +331,7 @@ class MessageHandler(object):
event_id, event_id,
) )
@defer.inlineCallbacks async def _expire_event(self, event_id: str):
def _expire_event(self, event_id):
"""Retrieve and expire an event that needs to be expired from the database. """Retrieve and expire an event that needs to be expired from the database.
If the event doesn't exist in the database, log it and delete the expiry date If the event doesn't exist in the database, log it and delete the expiry date
@ -336,12 +346,12 @@ class MessageHandler(object):
try: try:
# Expire the event if we know about it. This function also deletes the expiry # Expire the event if we know about it. This function also deletes the expiry
# date from the database in the same database transaction. # date from the database in the same database transaction.
yield self.store.expire_event(event_id) await self.store.expire_event(event_id)
except Exception as e: except Exception as e:
logger.error("Could not expire event %s: %r", event_id, e) logger.error("Could not expire event %s: %r", event_id, e)
# Schedule the expiry of the next event to expire. # Schedule the expiry of the next event to expire.
yield self._schedule_next_expiry() await self._schedule_next_expiry()
# The duration (in ms) after which rooms should be removed # The duration (in ms) after which rooms should be removed
@ -423,16 +433,15 @@ class EventCreationHandler(object):
self._dummy_events_threshold = hs.config.dummy_events_threshold self._dummy_events_threshold = hs.config.dummy_events_threshold
@defer.inlineCallbacks async def create_event(
def create_event(
self, self,
requester, requester: Requester,
event_dict, event_dict: dict,
token_id=None, token_id: Optional[str] = None,
txn_id=None, txn_id: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None, prev_event_ids: Optional[Collection[str]] = None,
require_consent=True, require_consent: bool = True,
): ) -> Tuple[EventBase, EventContext]:
""" """
Given a dict from a client, create a new event. Given a dict from a client, create a new event.
@ -443,31 +452,29 @@ class EventCreationHandler(object):
Args: Args:
requester requester
event_dict (dict): An entire event event_dict: An entire event
token_id (str) token_id
txn_id (str) txn_id
prev_event_ids: prev_event_ids:
the forward extremities to use as the prev_events for the the forward extremities to use as the prev_events for the
new event. new event.
If None, they will be requested from the database. If None, they will be requested from the database.
require_consent: Whether to check if the requester has
require_consent (bool): Whether to check if the requester has consented to the privacy policy.
consented to privacy policy.
Raises: Raises:
ResourceLimitError if server is blocked to some resource being ResourceLimitError if server is blocked to some resource being
exceeded exceeded
Returns: Returns:
Tuple of created event (FrozenEvent), Context Tuple of created event, Context
""" """
yield self.auth.check_auth_blocking(requester.user.to_string()) await self.auth.check_auth_blocking(requester.user.to_string())
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"] room_version = event_dict["content"]["room_version"]
else: else:
try: try:
room_version = yield self.store.get_room_version_id( room_version = await self.store.get_room_version_id(
event_dict["room_id"] event_dict["room_id"]
) )
except NotFoundError: except NotFoundError:
@ -488,15 +495,11 @@ class EventCreationHandler(object):
try: try:
if "displayname" not in content: if "displayname" not in content:
displayname = yield defer.ensureDeferred( displayname = await profile.get_displayname(target)
profile.get_displayname(target)
)
if displayname is not None: if displayname is not None:
content["displayname"] = displayname content["displayname"] = displayname
if "avatar_url" not in content: if "avatar_url" not in content:
avatar_url = yield defer.ensureDeferred( avatar_url = await profile.get_avatar_url(target)
profile.get_avatar_url(target)
)
if avatar_url is not None: if avatar_url is not None:
content["avatar_url"] = avatar_url content["avatar_url"] = avatar_url
except Exception as e: except Exception as e:
@ -504,9 +507,9 @@ class EventCreationHandler(object):
"Failed to get profile information for %r: %s", target, e "Failed to get profile information for %r: %s", target, e
) )
is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
if require_consent and not is_exempt: if require_consent and not is_exempt:
yield self.assert_accepted_privacy_policy(requester) await self.assert_accepted_privacy_policy(requester)
if token_id is not None: if token_id is not None:
builder.internal_metadata.token_id = token_id builder.internal_metadata.token_id = token_id
@ -514,7 +517,7 @@ class EventCreationHandler(object):
if txn_id is not None: if txn_id is not None:
builder.internal_metadata.txn_id = txn_id builder.internal_metadata.txn_id = txn_id
event, context = yield self.create_new_client_event( event, context = await self.create_new_client_event(
builder=builder, requester=requester, prev_event_ids=prev_event_ids, builder=builder, requester=requester, prev_event_ids=prev_event_ids,
) )
@ -530,10 +533,10 @@ class EventCreationHandler(object):
# federation as well as those created locally. As of room v3, aliases events # federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to # can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check(). # tolerate them in event_auth.check().
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = ( prev_event = (
yield self.store.get_event(prev_event_id, allow_none=True) await self.store.get_event(prev_event_id, allow_none=True)
if prev_event_id if prev_event_id
else None else None
) )
@ -556,37 +559,36 @@ class EventCreationHandler(object):
return (event, context) return (event, context)
def _is_exempt_from_privacy_policy(self, builder, requester): async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
) -> bool:
""""Determine if an event to be sent is exempt from having to consent """"Determine if an event to be sent is exempt from having to consent
to the privacy policy to the privacy policy
Args: Args:
builder (synapse.events.builder.EventBuilder): event being created builder: event being created
requester (Requster): user requesting this event requester: user requesting this event
Returns: Returns:
Deferred[bool]: true if the event can be sent without the user true if the event can be sent without the user consenting
consenting
""" """
# the only thing the user can do is join the server notices room. # the only thing the user can do is join the server notices room.
if builder.type == EventTypes.Member: if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
if membership == Membership.JOIN: if membership == Membership.JOIN:
return self._is_server_notices_room(builder.room_id) return await self._is_server_notices_room(builder.room_id)
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
# the user is always allowed to leave (but not kick people) # the user is always allowed to leave (but not kick people)
return builder.state_key == requester.user.to_string() return builder.state_key == requester.user.to_string()
return succeed(False) return False
@defer.inlineCallbacks async def _is_server_notices_room(self, room_id: str) -> bool:
def _is_server_notices_room(self, room_id):
if self.config.server_notices_mxid is None: if self.config.server_notices_mxid is None:
return False return False
user_ids = yield self.store.get_users_in_room(room_id) user_ids = await self.store.get_users_in_room(room_id)
return self.config.server_notices_mxid in user_ids return self.config.server_notices_mxid in user_ids
@defer.inlineCallbacks async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
def assert_accepted_privacy_policy(self, requester):
"""Check if a user has accepted the privacy policy """Check if a user has accepted the privacy policy
Called when the given user is about to do something that requires Called when the given user is about to do something that requires
@ -595,12 +597,10 @@ class EventCreationHandler(object):
raised. raised.
Args: Args:
requester (synapse.types.Requester): requester: The user making the request
The user making the request
Returns: Returns:
Deferred[None]: returns normally if the user has consented or is Returns normally if the user has consented or is exempt
exempt
Raises: Raises:
ConsentNotGivenError: if the user has not given consent yet ConsentNotGivenError: if the user has not given consent yet
@ -621,7 +621,7 @@ class EventCreationHandler(object):
): ):
return return
u = yield self.store.get_user_by_id(user_id) u = await self.store.get_user_by_id(user_id)
assert u is not None assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent # support and bot users are not required to consent
@ -639,16 +639,20 @@ class EventCreationHandler(object):
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
async def send_nonmember_event( async def send_nonmember_event(
self, requester, event, context, ratelimit=True self,
requester: Requester,
event: EventBase,
context: EventContext,
ratelimit: bool = True,
) -> int: ) -> int:
""" """
Persists and notifies local clients and federation of an event. Persists and notifies local clients and federation of an event.
Args: Args:
event (FrozenEvent) the event to send. requester
context (Context) the context of the event. event the event to send.
ratelimit (bool): Whether to rate limit this send. context: the context of the event.
is_guest (bool): Whether the sender is a guest. ratelimit: Whether to rate limit this send.
Return: Return:
The stream_id of the persisted event. The stream_id of the persisted event.
@ -676,19 +680,20 @@ class EventCreationHandler(object):
requester=requester, event=event, context=context, ratelimit=ratelimit requester=requester, event=event, context=context, ratelimit=ratelimit
) )
@defer.inlineCallbacks async def deduplicate_state_event(
def deduplicate_state_event(self, event, context): self, event: EventBase, context: EventContext
) -> None:
""" """
Checks whether event is in the latest resolved state in context. Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context. If so, returns the version of the event in context.
Otherwise, returns None. Otherwise, returns None.
""" """
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key)) prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id: if not prev_event_id:
return return
prev_event = yield self.store.get_event(prev_event_id, allow_none=True) prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if not prev_event: if not prev_event:
return return
@ -700,7 +705,11 @@ class EventCreationHandler(object):
return return
async def create_and_send_nonmember_event( async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None self,
requester: Requester,
event_dict: EventBase,
ratelimit: bool = True,
txn_id: Optional[str] = None,
) -> Tuple[EventBase, int]: ) -> Tuple[EventBase, int]:
""" """
Creates an event, then sends it. Creates an event, then sends it.
@ -730,17 +739,17 @@ class EventCreationHandler(object):
return event, stream_id return event, stream_id
@measure_func("create_new_client_event") @measure_func("create_new_client_event")
@defer.inlineCallbacks async def create_new_client_event(
def create_new_client_event( self,
self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None builder: EventBuilder,
): requester: Optional[Requester] = None,
prev_event_ids: Optional[Collection[str]] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client """Create a new event for a local client
Args: Args:
builder (EventBuilder): builder:
requester:
requester (synapse.types.Requester|None):
prev_event_ids: prev_event_ids:
the forward extremities to use as the prev_events for the the forward extremities to use as the prev_events for the
new event. new event.
@ -748,7 +757,7 @@ class EventCreationHandler(object):
If None, they will be requested from the database. If None, they will be requested from the database.
Returns: Returns:
Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] Tuple of created event, context
""" """
if prev_event_ids is not None: if prev_event_ids is not None:
@ -757,10 +766,10 @@ class EventCreationHandler(object):
% (len(prev_event_ids),) % (len(prev_event_ids),)
) )
else: else:
prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id) prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
event = yield builder.build(prev_event_ids=prev_event_ids) event = await builder.build(prev_event_ids=prev_event_ids)
context = yield self.state.compute_event_context(event) context = await self.state.compute_event_context(event)
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service
@ -774,7 +783,7 @@ class EventCreationHandler(object):
relates_to = relation["event_id"] relates_to = relation["event_id"]
aggregation_key = relation["key"] aggregation_key = relation["key"]
already_exists = yield self.store.has_user_annotated_event( already_exists = await self.store.has_user_annotated_event(
relates_to, event.type, aggregation_key, event.sender relates_to, event.type, aggregation_key, event.sender
) )
if already_exists: if already_exists:
@ -786,7 +795,12 @@ class EventCreationHandler(object):
@measure_func("handle_new_client_event") @measure_func("handle_new_client_event")
async def handle_new_client_event( async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[] self,
requester: Requester,
event: EventBase,
context: EventContext,
ratelimit: bool = True,
extra_users: List[UserID] = [],
) -> int: ) -> int:
"""Processes a new event. This includes checking auth, persisting it, """Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc. notifying users, sending to remote servers, etc.
@ -795,11 +809,11 @@ class EventCreationHandler(object):
processing. processing.
Args: Args:
requester (Requester) requester
event (FrozenEvent) event
context (EventContext) context
ratelimit (bool) ratelimit
extra_users (list(UserID)): Any extra users to notify about event extra_users: Any extra users to notify about event
Return: Return:
The stream_id of the persisted event. The stream_id of the persisted event.
@ -878,10 +892,9 @@ class EventCreationHandler(object):
self.store.remove_push_actions_from_staging, event.event_id self.store.remove_push_actions_from_staging, event.event_id
) )
@defer.inlineCallbacks async def _validate_canonical_alias(
def _validate_canonical_alias( self, directory_handler, room_alias_str: str, expected_room_id: str
self, directory_handler, room_alias_str, expected_room_id ) -> None:
):
""" """
Ensure that the given room alias points to the expected room ID. Ensure that the given room alias points to the expected room ID.
@ -892,9 +905,7 @@ class EventCreationHandler(object):
""" """
room_alias = RoomAlias.from_string(room_alias_str) room_alias = RoomAlias.from_string(room_alias_str)
try: try:
mapping = yield defer.ensureDeferred( mapping = await directory_handler.get_association(room_alias)
directory_handler.get_association(room_alias)
)
except SynapseError as e: except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND: if e.errcode == Codes.NOT_FOUND:
@ -913,7 +924,12 @@ class EventCreationHandler(object):
) )
async def persist_and_notify_client_event( async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[] self,
requester: Requester,
event: EventBase,
context: EventContext,
ratelimit: bool = True,
extra_users: List[UserID] = [],
) -> int: ) -> int:
"""Called when we have fully built the event, have already """Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth. calculated the push actions for the event, and checked auth.
@ -1106,7 +1122,7 @@ class EventCreationHandler(object):
return event_stream_id return event_stream_id
async def _bump_active_time(self, user): async def _bump_active_time(self, user: UserID) -> None:
try: try:
presence = self.hs.get_presence_handler() presence = self.hs.get_presence_handler()
await presence.bump_presence_active_time(user) await presence.bump_presence_active_time(user)

View file

@ -41,9 +41,11 @@ class TestEventContext(unittest.HomeserverTestCase):
serialize/deserialize. serialize/deserialize.
""" """
event, context = create_event( event, context = self.get_success(
create_event(
self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
) )
)
self._check_serialize_deserialize(event, context) self._check_serialize_deserialize(event, context)
@ -51,13 +53,15 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (with not previous entry) """Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize. is the same after serialize/deserialize.
""" """
event, context = create_event( event, context = self.get_success(
create_event(
self.hs, self.hs,
room_id=self.room_id, room_id=self.room_id,
type="m.test", type="m.test",
sender=self.user_id, sender=self.user_id,
state_key="", state_key="",
) )
)
self._check_serialize_deserialize(event, context) self._check_serialize_deserialize(event, context)
@ -65,7 +69,8 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (which replaces a """Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize. previous entry) is the same after serialize/deserialize.
""" """
event, context = create_event( event, context = self.get_success(
create_event(
self.hs, self.hs,
room_id=self.room_id, room_id=self.room_id,
type="m.room.member", type="m.room.member",
@ -73,6 +78,7 @@ class TestEventContext(unittest.HomeserverTestCase):
state_key=self.user_id, state_key=self.user_id,
content={"membership": "leave"}, content={"membership": "leave"},
) )
)
self._check_serialize_deserialize(event, context) self._check_serialize_deserialize(event, context)

View file

@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
OTHER_USER = "@other_user:localhost" OTHER_USER = "@other_user:localhost"
# have the user join # have the user join
self.get_success(
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
)
# Update existing power levels with mod at PL50 # Update existing power levels with mod at PL50
pls = self.helper.get_state( pls = self.helper.get_state(
@ -157,7 +159,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
# roll back all the state by de-modding the user # roll back all the state by de-modding the user
prev_events = fork_point prev_events = fork_point
pls["users"][OTHER_USER] = 0 pls["users"][OTHER_USER] = 0
pl_event = inject_event( pl_event = self.get_success(
inject_event(
self.hs, self.hs,
prev_event_ids=prev_events, prev_event_ids=prev_events,
type=EventTypes.PowerLevels, type=EventTypes.PowerLevels,
@ -166,6 +169,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
room_id=self.room_id, room_id=self.room_id,
content=pls, content=pls,
) )
)
# one more bit of state that doesn't get rolled back # one more bit of state that doesn't get rolled back
state2 = self._inject_state_event() state2 = self._inject_state_event()
@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# have the users join # have the users join
for u in user_ids: for u in user_ids:
self.get_success(
inject_member_event(self.hs, self.room_id, u, Membership.JOIN) inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
)
# Update existing power levels with mod at PL50 # Update existing power levels with mod at PL50
pls = self.helper.get_state( pls = self.helper.get_state(
@ -306,7 +312,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_events = [] pl_events = []
for u in user_ids: for u in user_ids:
pls["users"][u] = 0 pls["users"][u] = 0
e = inject_event( e = self.get_success(
inject_event(
self.hs, self.hs,
prev_event_ids=prev_events, prev_event_ids=prev_events,
type=EventTypes.PowerLevels, type=EventTypes.PowerLevels,
@ -315,6 +322,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
room_id=self.room_id, room_id=self.room_id,
content=pls, content=pls,
) )
)
prev_events = [e.event_id] prev_events = [e.event_id]
pl_events.append(e) pl_events.append(e)
@ -434,7 +442,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
body = "event %i" % (self.event_count,) body = "event %i" % (self.event_count,)
self.event_count += 1 self.event_count += 1
return inject_event( return self.get_success(
inject_event(
self.hs, self.hs,
room_id=self.room_id, room_id=self.room_id,
sender=sender, sender=sender,
@ -442,6 +451,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
content={"body": body}, content={"body": body},
**kwargs **kwargs
) )
)
def _inject_state_event( def _inject_state_event(
self, self,
@ -459,7 +469,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
if body is None: if body is None:
body = "state event %s" % (state_key,) body = "state event %s" % (state_key,)
return inject_event( return self.get_success(
inject_event(
self.hs, self.hs,
room_id=self.room_id, room_id=self.room_id,
sender=sender, sender=sender,
@ -467,3 +478,4 @@ class EventsStreamTestCase(BaseStreamTestCase):
state_key=state_key, state_key=state_key,
content={"body": body}, content={"body": body},
) )
)

View file

@ -118,12 +118,15 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test_get_joined_users_from_context(self): def test_get_joined_users_from_context(self):
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
bob_event = event_injection.inject_member_event( bob_event = self.get_success(
event_injection.inject_member_event(
self.hs, room, self.u_bob, Membership.JOIN self.hs, room, self.u_bob, Membership.JOIN
) )
)
# first, create a regular event # first, create a regular event
event, context = event_injection.create_event( event, context = self.get_success(
event_injection.create_event(
self.hs, self.hs,
room_id=room, room_id=room,
sender=self.u_alice, sender=self.u_alice,
@ -131,6 +134,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
type="m.test.1", type="m.test.1",
content={}, content={},
) )
)
users = self.get_success( users = self.get_success(
self.store.get_joined_users_from_context(event, context) self.store.get_joined_users_from_context(event, context)
@ -140,7 +144,8 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Regression test for #7376: create a state event whose key matches bob's # Regression test for #7376: create a state event whose key matches bob's
# user_id, but which is *not* a membership event, and persist that; then check # user_id, but which is *not* a membership event, and persist that; then check
# that `get_joined_users_from_context` returns the correct users for the next event. # that `get_joined_users_from_context` returns the correct users for the next event.
non_member_event = event_injection.inject_event( non_member_event = self.get_success(
event_injection.inject_event(
self.hs, self.hs,
room_id=room, room_id=room,
sender=self.u_bob, sender=self.u_bob,
@ -149,7 +154,9 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
state_key=self.u_bob, state_key=self.u_bob,
content={}, content={},
) )
event, context = event_injection.create_event( )
event, context = self.get_success(
event_injection.create_event(
self.hs, self.hs,
room_id=room, room_id=room,
sender=self.u_alice, sender=self.u_alice,
@ -157,6 +164,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
type="m.test.3", type="m.test.3",
content={}, content={},
) )
)
users = self.get_success( users = self.get_success(
self.store.get_joined_users_from_context(event, context) self.store.get_joined_users_from_context(event, context)
) )

View file

@ -64,8 +64,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
}, },
) )
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield defer.ensureDeferred(
builder self.event_creation_handler.create_new_client_event(builder)
) )
yield self.storage.persistence.persist_event(event, context) yield self.storage.persistence.persist_event(event, context)

View file

@ -22,14 +22,12 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.types import Collection from synapse.types import Collection
from tests.test_utils import get_awaitable_result
""" """
Utility functions for poking events into the storage of the server under test. Utility functions for poking events into the storage of the server under test.
""" """
def inject_member_event( async def inject_member_event(
hs: synapse.server.HomeServer, hs: synapse.server.HomeServer,
room_id: str, room_id: str,
sender: str, sender: str,
@ -46,7 +44,7 @@ def inject_member_event(
if extra_content: if extra_content:
content.update(extra_content) content.update(extra_content)
return inject_event( return await inject_event(
hs, hs,
room_id=room_id, room_id=room_id,
type=EventTypes.Member, type=EventTypes.Member,
@ -57,7 +55,7 @@ def inject_member_event(
) )
def inject_event( async def inject_event(
hs: synapse.server.HomeServer, hs: synapse.server.HomeServer,
room_version: Optional[str] = None, room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None, prev_event_ids: Optional[Collection[str]] = None,
@ -72,37 +70,27 @@ def inject_event(
prev_event_ids: prev_events for the event. If not specified, will be looked up prev_event_ids: prev_events for the event. If not specified, will be looked up
kwargs: fields for the event to be created kwargs: fields for the event to be created
""" """
test_reactor = hs.get_reactor() event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
event, context = create_event(hs, room_version, prev_event_ids, **kwargs) await hs.get_storage().persistence.persist_event(event, context)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
return event return event
def create_event( async def create_event(
hs: synapse.server.HomeServer, hs: synapse.server.HomeServer,
room_version: Optional[str] = None, room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None, prev_event_ids: Optional[Collection[str]] = None,
**kwargs **kwargs
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, EventContext]:
test_reactor = hs.get_reactor()
if room_version is None: if room_version is None:
d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
test_reactor.advance(0)
room_version = get_awaitable_result(d)
builder = hs.get_event_builder_factory().for_room_version( builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs KNOWN_ROOM_VERSIONS[room_version], kwargs
) )
d = hs.get_event_creation_handler().create_new_client_event( event, context = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids builder, prev_event_ids=prev_event_ids
) )
test_reactor.advance(0)
event, context = get_awaitable_result(d)
return event, context return event, context

View file

@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# #
# before we do that, we persist some other events to act as state. # before we do that, we persist some other events to act as state.
self.inject_visibility("@admin:hs", "joined") yield self.inject_visibility("@admin:hs", "joined")
for i in range(0, 10): for i in range(0, 10):
yield self.inject_room_member("@resident%i:hs" % i) yield self.inject_room_member("@resident%i:hs" % i)
@ -137,8 +137,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
}, },
) )
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield defer.ensureDeferred(
builder self.event_creation_handler.create_new_client_event(builder)
) )
yield self.storage.persistence.persist_event(event, context) yield self.storage.persistence.persist_event(event, context)
return event return event
@ -158,8 +158,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
}, },
) )
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield defer.ensureDeferred(
builder self.event_creation_handler.create_new_client_event(builder)
) )
yield self.storage.persistence.persist_event(event, context) yield self.storage.persistence.persist_event(event, context)
@ -179,8 +179,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
}, },
) )
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield defer.ensureDeferred(
builder self.event_creation_handler.create_new_client_event(builder)
) )
yield self.storage.persistence.persist_event(event, context) yield self.storage.persistence.persist_event(event, context)

View file

@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase):
user: MXID of the user to inject the membership for. user: MXID of the user to inject the membership for.
membership: The membership type. membership: The membership type.
""" """
self.get_success(
event_injection.inject_member_event(self.hs, room, user, membership) event_injection.inject_member_event(self.hs, room, user, membership)
)
class FederatingHomeserverTestCase(HomeserverTestCase): class FederatingHomeserverTestCase(HomeserverTestCase):

View file

@ -671,6 +671,8 @@ def create_room(hs, room_id, creator_id):
}, },
) )
event, context = yield event_creation_handler.create_new_client_event(builder) event, context = yield defer.ensureDeferred(
event_creation_handler.create_new_client_event(builder)
)
yield persistence_store.persist_event(event, context) yield persistence_store.persist_event(event, context)