Convert the message handler to async/await. (#7884)

This commit is contained in:
Patrick Cloke 2020-07-22 12:29:15 -04:00 committed by GitHub
parent a4cf94a3c2
commit cc9bb3dc3f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 273 additions and 238 deletions

1
changelog.d/7884.misc Normal file
View file

@ -0,0 +1 @@
Convert the message handler to async/await.

View file

@ -15,12 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
from twisted.internet.interfaces import IDelayedCall
from synapse import event_auth
@ -41,13 +39,22 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Collection, RoomAlias, UserID, create_requester
from synapse.types import (
Collection,
Requester,
RoomAlias,
StreamToken,
UserID,
create_requester,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
@ -84,14 +91,22 @@ class MessageHandler(object):
"_schedule_next_expiry", self._schedule_next_expiry
)
@defer.inlineCallbacks
def get_room_data(
self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False
):
async def get_room_data(
self,
user_id: str = None,
room_id: str = None,
event_type: Optional[str] = None,
state_key: str = "",
is_guest: bool = False,
) -> dict:
""" Get data from a room.
Args:
event : The room path event
user_id
room_id
event_type
state_key
is_guest
Returns:
The path data content.
Raises:
@ -100,30 +115,29 @@ class MessageHandler(object):
(
membership,
membership_event_id,
) = yield self.auth.check_user_in_room_or_world_readable(
) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
if membership == Membership.JOIN:
data = yield self.state.get_current_state(room_id, event_type, state_key)
data = await self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.state_store.get_state_for_events(
room_state = await self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
return data
@defer.inlineCallbacks
def get_state_events(
async def get_state_events(
self,
user_id,
room_id,
state_filter=StateFilter.all(),
at_token=None,
is_guest=False,
):
user_id: str,
room_id: str,
state_filter: StateFilter = StateFilter.all(),
at_token: Optional[StreamToken] = None,
is_guest: bool = False,
) -> List[dict]:
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
left the room return the state events from when they left. If an explicit
@ -131,15 +145,14 @@ class MessageHandler(object):
visible.
Args:
user_id(str): The user requesting state events.
room_id(str): The room ID to get all state events from.
state_filter (StateFilter): The state filter used to fetch state
from the database.
at_token(StreamToken|None): the stream token of the at which we are requesting
user_id: The user requesting state events.
room_id: The room ID to get all state events from.
state_filter: The state filter used to fetch state from the database.
at_token: the stream token of the at which we are requesting
the stats. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current
state based on the current_state_events table.
is_guest(bool): whether this user is a guest
is_guest: whether this user is a guest
Returns:
A list of dicts representing state events. [{}, {}, {}]
Raises:
@ -153,20 +166,20 @@ class MessageHandler(object):
# get_recent_events_for_room operates by topo ordering. This therefore
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
last_events, _ = await self.store.get_recent_events_for_room(
room_id, end_token=at_token.room_key, limit=1
)
if not last_events:
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
visible_events = await filter_events_for_client(
self.storage, user_id, last_events, filter_send_to_client=False
)
event = last_events[0]
if visible_events:
room_state = yield self.state_store.get_state_for_events(
room_state = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
room_state = room_state[event.event_id]
@ -180,23 +193,23 @@ class MessageHandler(object):
(
membership,
membership_event_id,
) = yield self.auth.check_user_in_room_or_world_readable(
) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
state_ids = await self.store.get_filtered_current_state_ids(
room_id, state_filter=state_filter
)
room_state = yield self.store.get_events(state_ids.values())
room_state = await self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
room_state = yield self.state_store.get_state_for_events(
room_state = await self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter
)
room_state = room_state[membership_event_id]
now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(
events = await self._event_serializer.serialize_events(
room_state.values(),
now,
# We don't bother bundling aggregations in when asked for state
@ -205,15 +218,14 @@ class MessageHandler(object):
)
return events
@defer.inlineCallbacks
def get_joined_members(self, requester, room_id):
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
"""Get all the joined members in the room and their profile information.
If the user has left the room return the state events from when they left.
Args:
requester(Requester): The user requesting state events.
room_id(str): The room ID to get all state events from.
requester: The user requesting state events.
room_id: The room ID to get all state events from.
Returns:
A dict of user_id to profile info
"""
@ -221,7 +233,7 @@ class MessageHandler(object):
if not requester.app_service:
# We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway.
membership, _ = yield self.auth.check_user_in_room_or_world_readable(
membership, _ = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
if membership != Membership.JOIN:
@ -229,7 +241,7 @@ class MessageHandler(object):
"Getting joined members after leaving is not implemented"
)
users_with_profile = yield self.state.get_current_users_in_room(room_id)
users_with_profile = await self.state.get_current_users_in_room(room_id)
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there
@ -250,7 +262,7 @@ class MessageHandler(object):
for user_id, profile in users_with_profile.items()
}
def maybe_schedule_expiry(self, event):
def maybe_schedule_expiry(self, event: EventBase):
"""Schedule the expiry of an event if there's not already one scheduled,
or if the one running is for an event that will expire after the provided
timestamp.
@ -259,7 +271,7 @@ class MessageHandler(object):
the master process, and therefore needs to be run on there.
Args:
event (EventBase): The event to schedule the expiry of.
event: The event to schedule the expiry of.
"""
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
@ -270,8 +282,7 @@ class MessageHandler(object):
# a task scheduled for a timestamp that's sooner than the provided one.
self._schedule_expiry_for_event(event.event_id, expiry_ts)
@defer.inlineCallbacks
def _schedule_next_expiry(self):
async def _schedule_next_expiry(self):
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
and schedule an expiry task for it.
@ -279,18 +290,18 @@ class MessageHandler(object):
future call to save_expiry_ts can schedule a new expiry task.
"""
# Try to get the expiry timestamp of the next event to expire.
res = yield self.store.get_next_event_to_expire()
res = await self.store.get_next_event_to_expire()
if res:
event_id, expiry_ts = res
self._schedule_expiry_for_event(event_id, expiry_ts)
def _schedule_expiry_for_event(self, event_id, expiry_ts):
def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
"""Schedule an expiry task for the provided event if there's not already one
scheduled at a timestamp that's sooner than the provided one.
Args:
event_id (str): The ID of the event to expire.
expiry_ts (int): The timestamp at which to expire the event.
event_id: The ID of the event to expire.
expiry_ts: The timestamp at which to expire the event.
"""
if self._scheduled_expiry:
# If the provided timestamp refers to a time before the scheduled time of the
@ -320,8 +331,7 @@ class MessageHandler(object):
event_id,
)
@defer.inlineCallbacks
def _expire_event(self, event_id):
async def _expire_event(self, event_id: str):
"""Retrieve and expire an event that needs to be expired from the database.
If the event doesn't exist in the database, log it and delete the expiry date
@ -336,12 +346,12 @@ class MessageHandler(object):
try:
# Expire the event if we know about it. This function also deletes the expiry
# date from the database in the same database transaction.
yield self.store.expire_event(event_id)
await self.store.expire_event(event_id)
except Exception as e:
logger.error("Could not expire event %s: %r", event_id, e)
# Schedule the expiry of the next event to expire.
yield self._schedule_next_expiry()
await self._schedule_next_expiry()
# The duration (in ms) after which rooms should be removed
@ -423,16 +433,15 @@ class EventCreationHandler(object):
self._dummy_events_threshold = hs.config.dummy_events_threshold
@defer.inlineCallbacks
def create_event(
async def create_event(
self,
requester,
event_dict,
token_id=None,
txn_id=None,
requester: Requester,
event_dict: dict,
token_id: Optional[str] = None,
txn_id: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
require_consent=True,
):
require_consent: bool = True,
) -> Tuple[EventBase, EventContext]:
"""
Given a dict from a client, create a new event.
@ -443,31 +452,29 @@ class EventCreationHandler(object):
Args:
requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
event_dict: An entire event
token_id
txn_id
prev_event_ids:
the forward extremities to use as the prev_events for the
new event.
If None, they will be requested from the database.
require_consent (bool): Whether to check if the requester has
consented to privacy policy.
require_consent: Whether to check if the requester has
consented to the privacy policy.
Raises:
ResourceLimitError if server is blocked to some resource being
exceeded
Returns:
Tuple of created event (FrozenEvent), Context
Tuple of created event, Context
"""
yield self.auth.check_auth_blocking(requester.user.to_string())
await self.auth.check_auth_blocking(requester.user.to_string())
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"]
else:
try:
room_version = yield self.store.get_room_version_id(
room_version = await self.store.get_room_version_id(
event_dict["room_id"]
)
except NotFoundError:
@ -488,15 +495,11 @@ class EventCreationHandler(object):
try:
if "displayname" not in content:
displayname = yield defer.ensureDeferred(
profile.get_displayname(target)
)
displayname = await profile.get_displayname(target)
if displayname is not None:
content["displayname"] = displayname
if "avatar_url" not in content:
avatar_url = yield defer.ensureDeferred(
profile.get_avatar_url(target)
)
avatar_url = await profile.get_avatar_url(target)
if avatar_url is not None:
content["avatar_url"] = avatar_url
except Exception as e:
@ -504,9 +507,9 @@ class EventCreationHandler(object):
"Failed to get profile information for %r: %s", target, e
)
is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
if require_consent and not is_exempt:
yield self.assert_accepted_privacy_policy(requester)
await self.assert_accepted_privacy_policy(requester)
if token_id is not None:
builder.internal_metadata.token_id = token_id
@ -514,7 +517,7 @@ class EventCreationHandler(object):
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self.create_new_client_event(
event, context = await self.create_new_client_event(
builder=builder, requester=requester, prev_event_ids=prev_event_ids,
)
@ -530,10 +533,10 @@ class EventCreationHandler(object):
# federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check().
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = (
yield self.store.get_event(prev_event_id, allow_none=True)
await self.store.get_event(prev_event_id, allow_none=True)
if prev_event_id
else None
)
@ -556,37 +559,36 @@ class EventCreationHandler(object):
return (event, context)
def _is_exempt_from_privacy_policy(self, builder, requester):
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
) -> bool:
""""Determine if an event to be sent is exempt from having to consent
to the privacy policy
Args:
builder (synapse.events.builder.EventBuilder): event being created
requester (Requster): user requesting this event
builder: event being created
requester: user requesting this event
Returns:
Deferred[bool]: true if the event can be sent without the user
consenting
true if the event can be sent without the user consenting
"""
# the only thing the user can do is join the server notices room.
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
return self._is_server_notices_room(builder.room_id)
return await self._is_server_notices_room(builder.room_id)
elif membership == Membership.LEAVE:
# the user is always allowed to leave (but not kick people)
return builder.state_key == requester.user.to_string()
return succeed(False)
return False
@defer.inlineCallbacks
def _is_server_notices_room(self, room_id):
async def _is_server_notices_room(self, room_id: str) -> bool:
if self.config.server_notices_mxid is None:
return False
user_ids = yield self.store.get_users_in_room(room_id)
user_ids = await self.store.get_users_in_room(room_id)
return self.config.server_notices_mxid in user_ids
@defer.inlineCallbacks
def assert_accepted_privacy_policy(self, requester):
async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
"""Check if a user has accepted the privacy policy
Called when the given user is about to do something that requires
@ -595,12 +597,10 @@ class EventCreationHandler(object):
raised.
Args:
requester (synapse.types.Requester):
The user making the request
requester: The user making the request
Returns:
Deferred[None]: returns normally if the user has consented or is
exempt
Returns normally if the user has consented or is exempt
Raises:
ConsentNotGivenError: if the user has not given consent yet
@ -621,7 +621,7 @@ class EventCreationHandler(object):
):
return
u = yield self.store.get_user_by_id(user_id)
u = await self.store.get_user_by_id(user_id)
assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent
@ -639,16 +639,20 @@ class EventCreationHandler(object):
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
async def send_nonmember_event(
self, requester, event, context, ratelimit=True
self,
requester: Requester,
event: EventBase,
context: EventContext,
ratelimit: bool = True,
) -> int:
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
requester
event the event to send.
context: the context of the event.
ratelimit: Whether to rate limit this send.
Return:
The stream_id of the persisted event.
@ -676,19 +680,20 @@ class EventCreationHandler(object):
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
async def deduplicate_state_event(
self, event: EventBase, context: EventContext
) -> None:
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@ -700,7 +705,11 @@ class EventCreationHandler(object):
return
async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None
self,
requester: Requester,
event_dict: EventBase,
ratelimit: bool = True,
txn_id: Optional[str] = None,
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
@ -730,17 +739,17 @@ class EventCreationHandler(object):
return event, stream_id
@measure_func("create_new_client_event")
@defer.inlineCallbacks
def create_new_client_event(
self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None
):
async def create_new_client_event(
self,
builder: EventBuilder,
requester: Optional[Requester] = None,
prev_event_ids: Optional[Collection[str]] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
Args:
builder (EventBuilder):
requester (synapse.types.Requester|None):
builder:
requester:
prev_event_ids:
the forward extremities to use as the prev_events for the
new event.
@ -748,7 +757,7 @@ class EventCreationHandler(object):
If None, they will be requested from the database.
Returns:
Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
Tuple of created event, context
"""
if prev_event_ids is not None:
@ -757,10 +766,10 @@ class EventCreationHandler(object):
% (len(prev_event_ids),)
)
else:
prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id)
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
event = yield builder.build(prev_event_ids=prev_event_ids)
context = yield self.state.compute_event_context(event)
event = await builder.build(prev_event_ids=prev_event_ids)
context = await self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
@ -774,7 +783,7 @@ class EventCreationHandler(object):
relates_to = relation["event_id"]
aggregation_key = relation["key"]
already_exists = yield self.store.has_user_annotated_event(
already_exists = await self.store.has_user_annotated_event(
relates_to, event.type, aggregation_key, event.sender
)
if already_exists:
@ -786,7 +795,12 @@ class EventCreationHandler(object):
@measure_func("handle_new_client_event")
async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
self,
requester: Requester,
event: EventBase,
context: EventContext,
ratelimit: bool = True,
extra_users: List[UserID] = [],
) -> int:
"""Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
@ -795,11 +809,11 @@ class EventCreationHandler(object):
processing.
Args:
requester (Requester)
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
requester
event
context
ratelimit
extra_users: Any extra users to notify about event
Return:
The stream_id of the persisted event.
@ -878,10 +892,9 @@ class EventCreationHandler(object):
self.store.remove_push_actions_from_staging, event.event_id
)
@defer.inlineCallbacks
def _validate_canonical_alias(
self, directory_handler, room_alias_str, expected_room_id
):
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:
"""
Ensure that the given room alias points to the expected room ID.
@ -892,9 +905,7 @@ class EventCreationHandler(object):
"""
room_alias = RoomAlias.from_string(room_alias_str)
try:
mapping = yield defer.ensureDeferred(
directory_handler.get_association(room_alias)
)
mapping = await directory_handler.get_association(room_alias)
except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND:
@ -913,7 +924,12 @@ class EventCreationHandler(object):
)
async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
self,
requester: Requester,
event: EventBase,
context: EventContext,
ratelimit: bool = True,
extra_users: List[UserID] = [],
) -> int:
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
@ -1106,7 +1122,7 @@ class EventCreationHandler(object):
return event_stream_id
async def _bump_active_time(self, user):
async def _bump_active_time(self, user: UserID) -> None:
try:
presence = self.hs.get_presence_handler()
await presence.bump_presence_active_time(user)

View file

@ -41,9 +41,11 @@ class TestEventContext(unittest.HomeserverTestCase):
serialize/deserialize.
"""
event, context = create_event(
event, context = self.get_success(
create_event(
self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
)
)
self._check_serialize_deserialize(event, context)
@ -51,13 +53,15 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize.
"""
event, context = create_event(
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.test",
sender=self.user_id,
state_key="",
)
)
self._check_serialize_deserialize(event, context)
@ -65,7 +69,8 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
event, context = create_event(
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.room.member",
@ -73,6 +78,7 @@ class TestEventContext(unittest.HomeserverTestCase):
state_key=self.user_id,
content={"membership": "leave"},
)
)
self._check_serialize_deserialize(event, context)

View file

@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
OTHER_USER = "@other_user:localhost"
# have the user join
self.get_success(
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@ -157,7 +159,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
# roll back all the state by de-modding the user
prev_events = fork_point
pls["users"][OTHER_USER] = 0
pl_event = inject_event(
pl_event = self.get_success(
inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
@ -166,6 +169,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
room_id=self.room_id,
content=pls,
)
)
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# have the users join
for u in user_ids:
self.get_success(
inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@ -306,7 +312,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_events = []
for u in user_ids:
pls["users"][u] = 0
e = inject_event(
e = self.get_success(
inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
@ -315,6 +322,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
room_id=self.room_id,
content=pls,
)
)
prev_events = [e.event_id]
pl_events.append(e)
@ -434,7 +442,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
body = "event %i" % (self.event_count,)
self.event_count += 1
return inject_event(
return self.get_success(
inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
@ -442,6 +451,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
content={"body": body},
**kwargs
)
)
def _inject_state_event(
self,
@ -459,7 +469,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
if body is None:
body = "state event %s" % (state_key,)
return inject_event(
return self.get_success(
inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
@ -467,3 +478,4 @@ class EventsStreamTestCase(BaseStreamTestCase):
state_key=state_key,
content={"body": body},
)
)

View file

@ -118,12 +118,15 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test_get_joined_users_from_context(self):
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
bob_event = event_injection.inject_member_event(
bob_event = self.get_success(
event_injection.inject_member_event(
self.hs, room, self.u_bob, Membership.JOIN
)
)
# first, create a regular event
event, context = event_injection.create_event(
event, context = self.get_success(
event_injection.create_event(
self.hs,
room_id=room,
sender=self.u_alice,
@ -131,6 +134,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
type="m.test.1",
content={},
)
)
users = self.get_success(
self.store.get_joined_users_from_context(event, context)
@ -140,7 +144,8 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Regression test for #7376: create a state event whose key matches bob's
# user_id, but which is *not* a membership event, and persist that; then check
# that `get_joined_users_from_context` returns the correct users for the next event.
non_member_event = event_injection.inject_event(
non_member_event = self.get_success(
event_injection.inject_event(
self.hs,
room_id=room,
sender=self.u_bob,
@ -149,7 +154,9 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
state_key=self.u_bob,
content={},
)
event, context = event_injection.create_event(
)
event, context = self.get_success(
event_injection.create_event(
self.hs,
room_id=room,
sender=self.u_alice,
@ -157,6 +164,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
type="m.test.3",
content={},
)
)
users = self.get_success(
self.store.get_joined_users_from_context(event, context)
)

View file

@ -64,8 +64,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
event, context = yield self.event_creation_handler.create_new_client_event(
builder
event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)

View file

@ -22,14 +22,12 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Collection
from tests.test_utils import get_awaitable_result
"""
Utility functions for poking events into the storage of the server under test.
"""
def inject_member_event(
async def inject_member_event(
hs: synapse.server.HomeServer,
room_id: str,
sender: str,
@ -46,7 +44,7 @@ def inject_member_event(
if extra_content:
content.update(extra_content)
return inject_event(
return await inject_event(
hs,
room_id=room_id,
type=EventTypes.Member,
@ -57,7 +55,7 @@ def inject_member_event(
)
def inject_event(
async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
@ -72,37 +70,27 @@ def inject_event(
prev_event_ids: prev_events for the event. If not specified, will be looked up
kwargs: fields for the event to be created
"""
test_reactor = hs.get_reactor()
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
await hs.get_storage().persistence.persist_event(event, context)
return event
def create_event(
async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
test_reactor = hs.get_reactor()
if room_version is None:
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
test_reactor.advance(0)
room_version = get_awaitable_result(d)
room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
d = hs.get_event_creation_handler().create_new_client_event(
event, context = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
test_reactor.advance(0)
event, context = get_awaitable_result(d)
return event, context

View file

@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
#
# before we do that, we persist some other events to act as state.
self.inject_visibility("@admin:hs", "joined")
yield self.inject_visibility("@admin:hs", "joined")
for i in range(0, 10):
yield self.inject_room_member("@resident%i:hs" % i)
@ -137,8 +137,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
event, context = yield self.event_creation_handler.create_new_client_event(
builder
event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
return event
@ -158,8 +158,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
event, context = yield self.event_creation_handler.create_new_client_event(
builder
event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
@ -179,8 +179,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
event, context = yield self.event_creation_handler.create_new_client_event(
builder
event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)

View file

@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase):
user: MXID of the user to inject the membership for.
membership: The membership type.
"""
self.get_success(
event_injection.inject_member_event(self.hs, room, user, membership)
)
class FederatingHomeserverTestCase(HomeserverTestCase):

View file

@ -671,6 +671,8 @@ def create_room(hs, room_id, creator_id):
},
)
event, context = yield event_creation_handler.create_new_client_event(builder)
event, context = yield defer.ensureDeferred(
event_creation_handler.create_new_client_event(builder)
)
yield persistence_store.persist_event(event, context)