mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-22 11:20:26 +01:00
Convert the message handler to async/await. (#7884)
This commit is contained in:
parent
a4cf94a3c2
commit
cc9bb3dc3f
10 changed files with 273 additions and 238 deletions
1
changelog.d/7884.misc
Normal file
1
changelog.d/7884.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert the message handler to async/await.
|
|
@ -15,12 +15,10 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, json
|
from canonicaljson import encode_canonical_json, json
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.defer import succeed
|
|
||||||
from twisted.internet.interfaces import IDelayedCall
|
from twisted.internet.interfaces import IDelayedCall
|
||||||
|
|
||||||
from synapse import event_auth
|
from synapse import event_auth
|
||||||
|
@ -41,13 +39,22 @@ from synapse.api.errors import (
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||||
from synapse.api.urls import ConsentURIBuilder
|
from synapse.api.urls import ConsentURIBuilder
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
from synapse.events.builder import EventBuilder
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.logging.context import run_in_background
|
from synapse.logging.context import run_in_background
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
|
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import Collection, RoomAlias, UserID, create_requester
|
from synapse.types import (
|
||||||
|
Collection,
|
||||||
|
Requester,
|
||||||
|
RoomAlias,
|
||||||
|
StreamToken,
|
||||||
|
UserID,
|
||||||
|
create_requester,
|
||||||
|
)
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.frozenutils import frozendict_json_encoder
|
from synapse.util.frozenutils import frozendict_json_encoder
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
|
@ -84,14 +91,22 @@ class MessageHandler(object):
|
||||||
"_schedule_next_expiry", self._schedule_next_expiry
|
"_schedule_next_expiry", self._schedule_next_expiry
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_room_data(
|
||||||
def get_room_data(
|
self,
|
||||||
self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False
|
user_id: str = None,
|
||||||
):
|
room_id: str = None,
|
||||||
|
event_type: Optional[str] = None,
|
||||||
|
state_key: str = "",
|
||||||
|
is_guest: bool = False,
|
||||||
|
) -> dict:
|
||||||
""" Get data from a room.
|
""" Get data from a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event : The room path event
|
user_id
|
||||||
|
room_id
|
||||||
|
event_type
|
||||||
|
state_key
|
||||||
|
is_guest
|
||||||
Returns:
|
Returns:
|
||||||
The path data content.
|
The path data content.
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -100,30 +115,29 @@ class MessageHandler(object):
|
||||||
(
|
(
|
||||||
membership,
|
membership,
|
||||||
membership_event_id,
|
membership_event_id,
|
||||||
) = yield self.auth.check_user_in_room_or_world_readable(
|
) = await self.auth.check_user_in_room_or_world_readable(
|
||||||
room_id, user_id, allow_departed_users=True
|
room_id, user_id, allow_departed_users=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
data = yield self.state.get_current_state(room_id, event_type, state_key)
|
data = await self.state.get_current_state(room_id, event_type, state_key)
|
||||||
elif membership == Membership.LEAVE:
|
elif membership == Membership.LEAVE:
|
||||||
key = (event_type, state_key)
|
key = (event_type, state_key)
|
||||||
room_state = yield self.state_store.get_state_for_events(
|
room_state = await self.state_store.get_state_for_events(
|
||||||
[membership_event_id], StateFilter.from_types([key])
|
[membership_event_id], StateFilter.from_types([key])
|
||||||
)
|
)
|
||||||
data = room_state[membership_event_id].get(key)
|
data = room_state[membership_event_id].get(key)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_state_events(
|
||||||
def get_state_events(
|
|
||||||
self,
|
self,
|
||||||
user_id,
|
user_id: str,
|
||||||
room_id,
|
room_id: str,
|
||||||
state_filter=StateFilter.all(),
|
state_filter: StateFilter = StateFilter.all(),
|
||||||
at_token=None,
|
at_token: Optional[StreamToken] = None,
|
||||||
is_guest=False,
|
is_guest: bool = False,
|
||||||
):
|
) -> List[dict]:
|
||||||
"""Retrieve all state events for a given room. If the user is
|
"""Retrieve all state events for a given room. If the user is
|
||||||
joined to the room then return the current state. If the user has
|
joined to the room then return the current state. If the user has
|
||||||
left the room return the state events from when they left. If an explicit
|
left the room return the state events from when they left. If an explicit
|
||||||
|
@ -131,15 +145,14 @@ class MessageHandler(object):
|
||||||
visible.
|
visible.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): The user requesting state events.
|
user_id: The user requesting state events.
|
||||||
room_id(str): The room ID to get all state events from.
|
room_id: The room ID to get all state events from.
|
||||||
state_filter (StateFilter): The state filter used to fetch state
|
state_filter: The state filter used to fetch state from the database.
|
||||||
from the database.
|
at_token: the stream token of the at which we are requesting
|
||||||
at_token(StreamToken|None): the stream token of the at which we are requesting
|
|
||||||
the stats. If the user is not allowed to view the state as of that
|
the stats. If the user is not allowed to view the state as of that
|
||||||
stream token, we raise a 403 SynapseError. If None, returns the current
|
stream token, we raise a 403 SynapseError. If None, returns the current
|
||||||
state based on the current_state_events table.
|
state based on the current_state_events table.
|
||||||
is_guest(bool): whether this user is a guest
|
is_guest: whether this user is a guest
|
||||||
Returns:
|
Returns:
|
||||||
A list of dicts representing state events. [{}, {}, {}]
|
A list of dicts representing state events. [{}, {}, {}]
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -153,20 +166,20 @@ class MessageHandler(object):
|
||||||
# get_recent_events_for_room operates by topo ordering. This therefore
|
# get_recent_events_for_room operates by topo ordering. This therefore
|
||||||
# does not reliably give you the state at the given stream position.
|
# does not reliably give you the state at the given stream position.
|
||||||
# (https://github.com/matrix-org/synapse/issues/3305)
|
# (https://github.com/matrix-org/synapse/issues/3305)
|
||||||
last_events, _ = yield self.store.get_recent_events_for_room(
|
last_events, _ = await self.store.get_recent_events_for_room(
|
||||||
room_id, end_token=at_token.room_key, limit=1
|
room_id, end_token=at_token.room_key, limit=1
|
||||||
)
|
)
|
||||||
|
|
||||||
if not last_events:
|
if not last_events:
|
||||||
raise NotFoundError("Can't find event for token %s" % (at_token,))
|
raise NotFoundError("Can't find event for token %s" % (at_token,))
|
||||||
|
|
||||||
visible_events = yield filter_events_for_client(
|
visible_events = await filter_events_for_client(
|
||||||
self.storage, user_id, last_events, filter_send_to_client=False
|
self.storage, user_id, last_events, filter_send_to_client=False
|
||||||
)
|
)
|
||||||
|
|
||||||
event = last_events[0]
|
event = last_events[0]
|
||||||
if visible_events:
|
if visible_events:
|
||||||
room_state = yield self.state_store.get_state_for_events(
|
room_state = await self.state_store.get_state_for_events(
|
||||||
[event.event_id], state_filter=state_filter
|
[event.event_id], state_filter=state_filter
|
||||||
)
|
)
|
||||||
room_state = room_state[event.event_id]
|
room_state = room_state[event.event_id]
|
||||||
|
@ -180,23 +193,23 @@ class MessageHandler(object):
|
||||||
(
|
(
|
||||||
membership,
|
membership,
|
||||||
membership_event_id,
|
membership_event_id,
|
||||||
) = yield self.auth.check_user_in_room_or_world_readable(
|
) = await self.auth.check_user_in_room_or_world_readable(
|
||||||
room_id, user_id, allow_departed_users=True
|
room_id, user_id, allow_departed_users=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
state_ids = yield self.store.get_filtered_current_state_ids(
|
state_ids = await self.store.get_filtered_current_state_ids(
|
||||||
room_id, state_filter=state_filter
|
room_id, state_filter=state_filter
|
||||||
)
|
)
|
||||||
room_state = yield self.store.get_events(state_ids.values())
|
room_state = await self.store.get_events(state_ids.values())
|
||||||
elif membership == Membership.LEAVE:
|
elif membership == Membership.LEAVE:
|
||||||
room_state = yield self.state_store.get_state_for_events(
|
room_state = await self.state_store.get_state_for_events(
|
||||||
[membership_event_id], state_filter=state_filter
|
[membership_event_id], state_filter=state_filter
|
||||||
)
|
)
|
||||||
room_state = room_state[membership_event_id]
|
room_state = room_state[membership_event_id]
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
events = yield self._event_serializer.serialize_events(
|
events = await self._event_serializer.serialize_events(
|
||||||
room_state.values(),
|
room_state.values(),
|
||||||
now,
|
now,
|
||||||
# We don't bother bundling aggregations in when asked for state
|
# We don't bother bundling aggregations in when asked for state
|
||||||
|
@ -205,15 +218,14 @@ class MessageHandler(object):
|
||||||
)
|
)
|
||||||
return events
|
return events
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
|
||||||
def get_joined_members(self, requester, room_id):
|
|
||||||
"""Get all the joined members in the room and their profile information.
|
"""Get all the joined members in the room and their profile information.
|
||||||
|
|
||||||
If the user has left the room return the state events from when they left.
|
If the user has left the room return the state events from when they left.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requester(Requester): The user requesting state events.
|
requester: The user requesting state events.
|
||||||
room_id(str): The room ID to get all state events from.
|
room_id: The room ID to get all state events from.
|
||||||
Returns:
|
Returns:
|
||||||
A dict of user_id to profile info
|
A dict of user_id to profile info
|
||||||
"""
|
"""
|
||||||
|
@ -221,7 +233,7 @@ class MessageHandler(object):
|
||||||
if not requester.app_service:
|
if not requester.app_service:
|
||||||
# We check AS auth after fetching the room membership, as it
|
# We check AS auth after fetching the room membership, as it
|
||||||
# requires us to pull out all joined members anyway.
|
# requires us to pull out all joined members anyway.
|
||||||
membership, _ = yield self.auth.check_user_in_room_or_world_readable(
|
membership, _ = await self.auth.check_user_in_room_or_world_readable(
|
||||||
room_id, user_id, allow_departed_users=True
|
room_id, user_id, allow_departed_users=True
|
||||||
)
|
)
|
||||||
if membership != Membership.JOIN:
|
if membership != Membership.JOIN:
|
||||||
|
@ -229,7 +241,7 @@ class MessageHandler(object):
|
||||||
"Getting joined members after leaving is not implemented"
|
"Getting joined members after leaving is not implemented"
|
||||||
)
|
)
|
||||||
|
|
||||||
users_with_profile = yield self.state.get_current_users_in_room(room_id)
|
users_with_profile = await self.state.get_current_users_in_room(room_id)
|
||||||
|
|
||||||
# If this is an AS, double check that they are allowed to see the members.
|
# If this is an AS, double check that they are allowed to see the members.
|
||||||
# This can either be because the AS user is in the room or because there
|
# This can either be because the AS user is in the room or because there
|
||||||
|
@ -250,7 +262,7 @@ class MessageHandler(object):
|
||||||
for user_id, profile in users_with_profile.items()
|
for user_id, profile in users_with_profile.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def maybe_schedule_expiry(self, event):
|
def maybe_schedule_expiry(self, event: EventBase):
|
||||||
"""Schedule the expiry of an event if there's not already one scheduled,
|
"""Schedule the expiry of an event if there's not already one scheduled,
|
||||||
or if the one running is for an event that will expire after the provided
|
or if the one running is for an event that will expire after the provided
|
||||||
timestamp.
|
timestamp.
|
||||||
|
@ -259,7 +271,7 @@ class MessageHandler(object):
|
||||||
the master process, and therefore needs to be run on there.
|
the master process, and therefore needs to be run on there.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (EventBase): The event to schedule the expiry of.
|
event: The event to schedule the expiry of.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
|
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
|
||||||
|
@ -270,8 +282,7 @@ class MessageHandler(object):
|
||||||
# a task scheduled for a timestamp that's sooner than the provided one.
|
# a task scheduled for a timestamp that's sooner than the provided one.
|
||||||
self._schedule_expiry_for_event(event.event_id, expiry_ts)
|
self._schedule_expiry_for_event(event.event_id, expiry_ts)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _schedule_next_expiry(self):
|
||||||
def _schedule_next_expiry(self):
|
|
||||||
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
|
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
|
||||||
and schedule an expiry task for it.
|
and schedule an expiry task for it.
|
||||||
|
|
||||||
|
@ -279,18 +290,18 @@ class MessageHandler(object):
|
||||||
future call to save_expiry_ts can schedule a new expiry task.
|
future call to save_expiry_ts can schedule a new expiry task.
|
||||||
"""
|
"""
|
||||||
# Try to get the expiry timestamp of the next event to expire.
|
# Try to get the expiry timestamp of the next event to expire.
|
||||||
res = yield self.store.get_next_event_to_expire()
|
res = await self.store.get_next_event_to_expire()
|
||||||
if res:
|
if res:
|
||||||
event_id, expiry_ts = res
|
event_id, expiry_ts = res
|
||||||
self._schedule_expiry_for_event(event_id, expiry_ts)
|
self._schedule_expiry_for_event(event_id, expiry_ts)
|
||||||
|
|
||||||
def _schedule_expiry_for_event(self, event_id, expiry_ts):
|
def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
|
||||||
"""Schedule an expiry task for the provided event if there's not already one
|
"""Schedule an expiry task for the provided event if there's not already one
|
||||||
scheduled at a timestamp that's sooner than the provided one.
|
scheduled at a timestamp that's sooner than the provided one.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_id (str): The ID of the event to expire.
|
event_id: The ID of the event to expire.
|
||||||
expiry_ts (int): The timestamp at which to expire the event.
|
expiry_ts: The timestamp at which to expire the event.
|
||||||
"""
|
"""
|
||||||
if self._scheduled_expiry:
|
if self._scheduled_expiry:
|
||||||
# If the provided timestamp refers to a time before the scheduled time of the
|
# If the provided timestamp refers to a time before the scheduled time of the
|
||||||
|
@ -320,8 +331,7 @@ class MessageHandler(object):
|
||||||
event_id,
|
event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _expire_event(self, event_id: str):
|
||||||
def _expire_event(self, event_id):
|
|
||||||
"""Retrieve and expire an event that needs to be expired from the database.
|
"""Retrieve and expire an event that needs to be expired from the database.
|
||||||
|
|
||||||
If the event doesn't exist in the database, log it and delete the expiry date
|
If the event doesn't exist in the database, log it and delete the expiry date
|
||||||
|
@ -336,12 +346,12 @@ class MessageHandler(object):
|
||||||
try:
|
try:
|
||||||
# Expire the event if we know about it. This function also deletes the expiry
|
# Expire the event if we know about it. This function also deletes the expiry
|
||||||
# date from the database in the same database transaction.
|
# date from the database in the same database transaction.
|
||||||
yield self.store.expire_event(event_id)
|
await self.store.expire_event(event_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Could not expire event %s: %r", event_id, e)
|
logger.error("Could not expire event %s: %r", event_id, e)
|
||||||
|
|
||||||
# Schedule the expiry of the next event to expire.
|
# Schedule the expiry of the next event to expire.
|
||||||
yield self._schedule_next_expiry()
|
await self._schedule_next_expiry()
|
||||||
|
|
||||||
|
|
||||||
# The duration (in ms) after which rooms should be removed
|
# The duration (in ms) after which rooms should be removed
|
||||||
|
@ -423,16 +433,15 @@ class EventCreationHandler(object):
|
||||||
|
|
||||||
self._dummy_events_threshold = hs.config.dummy_events_threshold
|
self._dummy_events_threshold = hs.config.dummy_events_threshold
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_event(
|
||||||
def create_event(
|
|
||||||
self,
|
self,
|
||||||
requester,
|
requester: Requester,
|
||||||
event_dict,
|
event_dict: dict,
|
||||||
token_id=None,
|
token_id: Optional[str] = None,
|
||||||
txn_id=None,
|
txn_id: Optional[str] = None,
|
||||||
prev_event_ids: Optional[Collection[str]] = None,
|
prev_event_ids: Optional[Collection[str]] = None,
|
||||||
require_consent=True,
|
require_consent: bool = True,
|
||||||
):
|
) -> Tuple[EventBase, EventContext]:
|
||||||
"""
|
"""
|
||||||
Given a dict from a client, create a new event.
|
Given a dict from a client, create a new event.
|
||||||
|
|
||||||
|
@ -443,31 +452,29 @@ class EventCreationHandler(object):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requester
|
requester
|
||||||
event_dict (dict): An entire event
|
event_dict: An entire event
|
||||||
token_id (str)
|
token_id
|
||||||
txn_id (str)
|
txn_id
|
||||||
|
|
||||||
prev_event_ids:
|
prev_event_ids:
|
||||||
the forward extremities to use as the prev_events for the
|
the forward extremities to use as the prev_events for the
|
||||||
new event.
|
new event.
|
||||||
|
|
||||||
If None, they will be requested from the database.
|
If None, they will be requested from the database.
|
||||||
|
require_consent: Whether to check if the requester has
|
||||||
require_consent (bool): Whether to check if the requester has
|
consented to the privacy policy.
|
||||||
consented to privacy policy.
|
|
||||||
Raises:
|
Raises:
|
||||||
ResourceLimitError if server is blocked to some resource being
|
ResourceLimitError if server is blocked to some resource being
|
||||||
exceeded
|
exceeded
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of created event (FrozenEvent), Context
|
Tuple of created event, Context
|
||||||
"""
|
"""
|
||||||
yield self.auth.check_auth_blocking(requester.user.to_string())
|
await self.auth.check_auth_blocking(requester.user.to_string())
|
||||||
|
|
||||||
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
|
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
|
||||||
room_version = event_dict["content"]["room_version"]
|
room_version = event_dict["content"]["room_version"]
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
room_version = yield self.store.get_room_version_id(
|
room_version = await self.store.get_room_version_id(
|
||||||
event_dict["room_id"]
|
event_dict["room_id"]
|
||||||
)
|
)
|
||||||
except NotFoundError:
|
except NotFoundError:
|
||||||
|
@ -488,15 +495,11 @@ class EventCreationHandler(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "displayname" not in content:
|
if "displayname" not in content:
|
||||||
displayname = yield defer.ensureDeferred(
|
displayname = await profile.get_displayname(target)
|
||||||
profile.get_displayname(target)
|
|
||||||
)
|
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
content["displayname"] = displayname
|
content["displayname"] = displayname
|
||||||
if "avatar_url" not in content:
|
if "avatar_url" not in content:
|
||||||
avatar_url = yield defer.ensureDeferred(
|
avatar_url = await profile.get_avatar_url(target)
|
||||||
profile.get_avatar_url(target)
|
|
||||||
)
|
|
||||||
if avatar_url is not None:
|
if avatar_url is not None:
|
||||||
content["avatar_url"] = avatar_url
|
content["avatar_url"] = avatar_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -504,9 +507,9 @@ class EventCreationHandler(object):
|
||||||
"Failed to get profile information for %r: %s", target, e
|
"Failed to get profile information for %r: %s", target, e
|
||||||
)
|
)
|
||||||
|
|
||||||
is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
|
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
|
||||||
if require_consent and not is_exempt:
|
if require_consent and not is_exempt:
|
||||||
yield self.assert_accepted_privacy_policy(requester)
|
await self.assert_accepted_privacy_policy(requester)
|
||||||
|
|
||||||
if token_id is not None:
|
if token_id is not None:
|
||||||
builder.internal_metadata.token_id = token_id
|
builder.internal_metadata.token_id = token_id
|
||||||
|
@ -514,7 +517,7 @@ class EventCreationHandler(object):
|
||||||
if txn_id is not None:
|
if txn_id is not None:
|
||||||
builder.internal_metadata.txn_id = txn_id
|
builder.internal_metadata.txn_id = txn_id
|
||||||
|
|
||||||
event, context = yield self.create_new_client_event(
|
event, context = await self.create_new_client_event(
|
||||||
builder=builder, requester=requester, prev_event_ids=prev_event_ids,
|
builder=builder, requester=requester, prev_event_ids=prev_event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -530,10 +533,10 @@ class EventCreationHandler(object):
|
||||||
# federation as well as those created locally. As of room v3, aliases events
|
# federation as well as those created locally. As of room v3, aliases events
|
||||||
# can be created by users that are not in the room, therefore we have to
|
# can be created by users that are not in the room, therefore we have to
|
||||||
# tolerate them in event_auth.check().
|
# tolerate them in event_auth.check().
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
|
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
|
||||||
prev_event = (
|
prev_event = (
|
||||||
yield self.store.get_event(prev_event_id, allow_none=True)
|
await self.store.get_event(prev_event_id, allow_none=True)
|
||||||
if prev_event_id
|
if prev_event_id
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
@ -556,37 +559,36 @@ class EventCreationHandler(object):
|
||||||
|
|
||||||
return (event, context)
|
return (event, context)
|
||||||
|
|
||||||
def _is_exempt_from_privacy_policy(self, builder, requester):
|
async def _is_exempt_from_privacy_policy(
|
||||||
|
self, builder: EventBuilder, requester: Requester
|
||||||
|
) -> bool:
|
||||||
""""Determine if an event to be sent is exempt from having to consent
|
""""Determine if an event to be sent is exempt from having to consent
|
||||||
to the privacy policy
|
to the privacy policy
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
builder (synapse.events.builder.EventBuilder): event being created
|
builder: event being created
|
||||||
requester (Requster): user requesting this event
|
requester: user requesting this event
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[bool]: true if the event can be sent without the user
|
true if the event can be sent without the user consenting
|
||||||
consenting
|
|
||||||
"""
|
"""
|
||||||
# the only thing the user can do is join the server notices room.
|
# the only thing the user can do is join the server notices room.
|
||||||
if builder.type == EventTypes.Member:
|
if builder.type == EventTypes.Member:
|
||||||
membership = builder.content.get("membership", None)
|
membership = builder.content.get("membership", None)
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
return self._is_server_notices_room(builder.room_id)
|
return await self._is_server_notices_room(builder.room_id)
|
||||||
elif membership == Membership.LEAVE:
|
elif membership == Membership.LEAVE:
|
||||||
# the user is always allowed to leave (but not kick people)
|
# the user is always allowed to leave (but not kick people)
|
||||||
return builder.state_key == requester.user.to_string()
|
return builder.state_key == requester.user.to_string()
|
||||||
return succeed(False)
|
return False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _is_server_notices_room(self, room_id: str) -> bool:
|
||||||
def _is_server_notices_room(self, room_id):
|
|
||||||
if self.config.server_notices_mxid is None:
|
if self.config.server_notices_mxid is None:
|
||||||
return False
|
return False
|
||||||
user_ids = yield self.store.get_users_in_room(room_id)
|
user_ids = await self.store.get_users_in_room(room_id)
|
||||||
return self.config.server_notices_mxid in user_ids
|
return self.config.server_notices_mxid in user_ids
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
|
||||||
def assert_accepted_privacy_policy(self, requester):
|
|
||||||
"""Check if a user has accepted the privacy policy
|
"""Check if a user has accepted the privacy policy
|
||||||
|
|
||||||
Called when the given user is about to do something that requires
|
Called when the given user is about to do something that requires
|
||||||
|
@ -595,12 +597,10 @@ class EventCreationHandler(object):
|
||||||
raised.
|
raised.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requester (synapse.types.Requester):
|
requester: The user making the request
|
||||||
The user making the request
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[None]: returns normally if the user has consented or is
|
Returns normally if the user has consented or is exempt
|
||||||
exempt
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ConsentNotGivenError: if the user has not given consent yet
|
ConsentNotGivenError: if the user has not given consent yet
|
||||||
|
@ -621,7 +621,7 @@ class EventCreationHandler(object):
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
u = yield self.store.get_user_by_id(user_id)
|
u = await self.store.get_user_by_id(user_id)
|
||||||
assert u is not None
|
assert u is not None
|
||||||
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
|
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
|
||||||
# support and bot users are not required to consent
|
# support and bot users are not required to consent
|
||||||
|
@ -639,16 +639,20 @@ class EventCreationHandler(object):
|
||||||
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
|
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
|
||||||
|
|
||||||
async def send_nonmember_event(
|
async def send_nonmember_event(
|
||||||
self, requester, event, context, ratelimit=True
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
event: EventBase,
|
||||||
|
context: EventContext,
|
||||||
|
ratelimit: bool = True,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Persists and notifies local clients and federation of an event.
|
Persists and notifies local clients and federation of an event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (FrozenEvent) the event to send.
|
requester
|
||||||
context (Context) the context of the event.
|
event the event to send.
|
||||||
ratelimit (bool): Whether to rate limit this send.
|
context: the context of the event.
|
||||||
is_guest (bool): Whether the sender is a guest.
|
ratelimit: Whether to rate limit this send.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
The stream_id of the persisted event.
|
The stream_id of the persisted event.
|
||||||
|
@ -676,19 +680,20 @@ class EventCreationHandler(object):
|
||||||
requester=requester, event=event, context=context, ratelimit=ratelimit
|
requester=requester, event=event, context=context, ratelimit=ratelimit
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def deduplicate_state_event(
|
||||||
def deduplicate_state_event(self, event, context):
|
self, event: EventBase, context: EventContext
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Checks whether event is in the latest resolved state in context.
|
Checks whether event is in the latest resolved state in context.
|
||||||
|
|
||||||
If so, returns the version of the event in context.
|
If so, returns the version of the event in context.
|
||||||
Otherwise, returns None.
|
Otherwise, returns None.
|
||||||
"""
|
"""
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
||||||
if not prev_event_id:
|
if not prev_event_id:
|
||||||
return
|
return
|
||||||
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
prev_event = await self.store.get_event(prev_event_id, allow_none=True)
|
||||||
if not prev_event:
|
if not prev_event:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -700,7 +705,11 @@ class EventCreationHandler(object):
|
||||||
return
|
return
|
||||||
|
|
||||||
async def create_and_send_nonmember_event(
|
async def create_and_send_nonmember_event(
|
||||||
self, requester, event_dict, ratelimit=True, txn_id=None
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
event_dict: EventBase,
|
||||||
|
ratelimit: bool = True,
|
||||||
|
txn_id: Optional[str] = None,
|
||||||
) -> Tuple[EventBase, int]:
|
) -> Tuple[EventBase, int]:
|
||||||
"""
|
"""
|
||||||
Creates an event, then sends it.
|
Creates an event, then sends it.
|
||||||
|
@ -730,17 +739,17 @@ class EventCreationHandler(object):
|
||||||
return event, stream_id
|
return event, stream_id
|
||||||
|
|
||||||
@measure_func("create_new_client_event")
|
@measure_func("create_new_client_event")
|
||||||
@defer.inlineCallbacks
|
async def create_new_client_event(
|
||||||
def create_new_client_event(
|
self,
|
||||||
self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None
|
builder: EventBuilder,
|
||||||
):
|
requester: Optional[Requester] = None,
|
||||||
|
prev_event_ids: Optional[Collection[str]] = None,
|
||||||
|
) -> Tuple[EventBase, EventContext]:
|
||||||
"""Create a new event for a local client
|
"""Create a new event for a local client
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
builder (EventBuilder):
|
builder:
|
||||||
|
requester:
|
||||||
requester (synapse.types.Requester|None):
|
|
||||||
|
|
||||||
prev_event_ids:
|
prev_event_ids:
|
||||||
the forward extremities to use as the prev_events for the
|
the forward extremities to use as the prev_events for the
|
||||||
new event.
|
new event.
|
||||||
|
@ -748,7 +757,7 @@ class EventCreationHandler(object):
|
||||||
If None, they will be requested from the database.
|
If None, they will be requested from the database.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
|
Tuple of created event, context
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if prev_event_ids is not None:
|
if prev_event_ids is not None:
|
||||||
|
@ -757,10 +766,10 @@ class EventCreationHandler(object):
|
||||||
% (len(prev_event_ids),)
|
% (len(prev_event_ids),)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id)
|
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
|
||||||
|
|
||||||
event = yield builder.build(prev_event_ids=prev_event_ids)
|
event = await builder.build(prev_event_ids=prev_event_ids)
|
||||||
context = yield self.state.compute_event_context(event)
|
context = await self.state.compute_event_context(event)
|
||||||
if requester:
|
if requester:
|
||||||
context.app_service = requester.app_service
|
context.app_service = requester.app_service
|
||||||
|
|
||||||
|
@ -774,7 +783,7 @@ class EventCreationHandler(object):
|
||||||
relates_to = relation["event_id"]
|
relates_to = relation["event_id"]
|
||||||
aggregation_key = relation["key"]
|
aggregation_key = relation["key"]
|
||||||
|
|
||||||
already_exists = yield self.store.has_user_annotated_event(
|
already_exists = await self.store.has_user_annotated_event(
|
||||||
relates_to, event.type, aggregation_key, event.sender
|
relates_to, event.type, aggregation_key, event.sender
|
||||||
)
|
)
|
||||||
if already_exists:
|
if already_exists:
|
||||||
|
@ -786,7 +795,12 @@ class EventCreationHandler(object):
|
||||||
|
|
||||||
@measure_func("handle_new_client_event")
|
@measure_func("handle_new_client_event")
|
||||||
async def handle_new_client_event(
|
async def handle_new_client_event(
|
||||||
self, requester, event, context, ratelimit=True, extra_users=[]
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
event: EventBase,
|
||||||
|
context: EventContext,
|
||||||
|
ratelimit: bool = True,
|
||||||
|
extra_users: List[UserID] = [],
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Processes a new event. This includes checking auth, persisting it,
|
"""Processes a new event. This includes checking auth, persisting it,
|
||||||
notifying users, sending to remote servers, etc.
|
notifying users, sending to remote servers, etc.
|
||||||
|
@ -795,11 +809,11 @@ class EventCreationHandler(object):
|
||||||
processing.
|
processing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requester (Requester)
|
requester
|
||||||
event (FrozenEvent)
|
event
|
||||||
context (EventContext)
|
context
|
||||||
ratelimit (bool)
|
ratelimit
|
||||||
extra_users (list(UserID)): Any extra users to notify about event
|
extra_users: Any extra users to notify about event
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
The stream_id of the persisted event.
|
The stream_id of the persisted event.
|
||||||
|
@ -878,10 +892,9 @@ class EventCreationHandler(object):
|
||||||
self.store.remove_push_actions_from_staging, event.event_id
|
self.store.remove_push_actions_from_staging, event.event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _validate_canonical_alias(
|
||||||
def _validate_canonical_alias(
|
self, directory_handler, room_alias_str: str, expected_room_id: str
|
||||||
self, directory_handler, room_alias_str, expected_room_id
|
) -> None:
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Ensure that the given room alias points to the expected room ID.
|
Ensure that the given room alias points to the expected room ID.
|
||||||
|
|
||||||
|
@ -892,9 +905,7 @@ class EventCreationHandler(object):
|
||||||
"""
|
"""
|
||||||
room_alias = RoomAlias.from_string(room_alias_str)
|
room_alias = RoomAlias.from_string(room_alias_str)
|
||||||
try:
|
try:
|
||||||
mapping = yield defer.ensureDeferred(
|
mapping = await directory_handler.get_association(room_alias)
|
||||||
directory_handler.get_association(room_alias)
|
|
||||||
)
|
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
|
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
|
||||||
if e.errcode == Codes.NOT_FOUND:
|
if e.errcode == Codes.NOT_FOUND:
|
||||||
|
@ -913,7 +924,12 @@ class EventCreationHandler(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def persist_and_notify_client_event(
|
async def persist_and_notify_client_event(
|
||||||
self, requester, event, context, ratelimit=True, extra_users=[]
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
event: EventBase,
|
||||||
|
context: EventContext,
|
||||||
|
ratelimit: bool = True,
|
||||||
|
extra_users: List[UserID] = [],
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Called when we have fully built the event, have already
|
"""Called when we have fully built the event, have already
|
||||||
calculated the push actions for the event, and checked auth.
|
calculated the push actions for the event, and checked auth.
|
||||||
|
@ -1106,7 +1122,7 @@ class EventCreationHandler(object):
|
||||||
|
|
||||||
return event_stream_id
|
return event_stream_id
|
||||||
|
|
||||||
async def _bump_active_time(self, user):
|
async def _bump_active_time(self, user: UserID) -> None:
|
||||||
try:
|
try:
|
||||||
presence = self.hs.get_presence_handler()
|
presence = self.hs.get_presence_handler()
|
||||||
await presence.bump_presence_active_time(user)
|
await presence.bump_presence_active_time(user)
|
||||||
|
|
|
@ -41,9 +41,11 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||||
serialize/deserialize.
|
serialize/deserialize.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
event, context = create_event(
|
event, context = self.get_success(
|
||||||
|
create_event(
|
||||||
self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
|
self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self._check_serialize_deserialize(event, context)
|
self._check_serialize_deserialize(event, context)
|
||||||
|
|
||||||
|
@ -51,13 +53,15 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||||
"""Test that an EventContext for a state event (with not previous entry)
|
"""Test that an EventContext for a state event (with not previous entry)
|
||||||
is the same after serialize/deserialize.
|
is the same after serialize/deserialize.
|
||||||
"""
|
"""
|
||||||
event, context = create_event(
|
event, context = self.get_success(
|
||||||
|
create_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
room_id=self.room_id,
|
room_id=self.room_id,
|
||||||
type="m.test",
|
type="m.test",
|
||||||
sender=self.user_id,
|
sender=self.user_id,
|
||||||
state_key="",
|
state_key="",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self._check_serialize_deserialize(event, context)
|
self._check_serialize_deserialize(event, context)
|
||||||
|
|
||||||
|
@ -65,7 +69,8 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||||
"""Test that an EventContext for a state event (which replaces a
|
"""Test that an EventContext for a state event (which replaces a
|
||||||
previous entry) is the same after serialize/deserialize.
|
previous entry) is the same after serialize/deserialize.
|
||||||
"""
|
"""
|
||||||
event, context = create_event(
|
event, context = self.get_success(
|
||||||
|
create_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
room_id=self.room_id,
|
room_id=self.room_id,
|
||||||
type="m.room.member",
|
type="m.room.member",
|
||||||
|
@ -73,6 +78,7 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||||
state_key=self.user_id,
|
state_key=self.user_id,
|
||||||
content={"membership": "leave"},
|
content={"membership": "leave"},
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self._check_serialize_deserialize(event, context)
|
self._check_serialize_deserialize(event, context)
|
||||||
|
|
||||||
|
|
|
@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
OTHER_USER = "@other_user:localhost"
|
OTHER_USER = "@other_user:localhost"
|
||||||
|
|
||||||
# have the user join
|
# have the user join
|
||||||
|
self.get_success(
|
||||||
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
|
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
|
||||||
|
)
|
||||||
|
|
||||||
# Update existing power levels with mod at PL50
|
# Update existing power levels with mod at PL50
|
||||||
pls = self.helper.get_state(
|
pls = self.helper.get_state(
|
||||||
|
@ -157,7 +159,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
# roll back all the state by de-modding the user
|
# roll back all the state by de-modding the user
|
||||||
prev_events = fork_point
|
prev_events = fork_point
|
||||||
pls["users"][OTHER_USER] = 0
|
pls["users"][OTHER_USER] = 0
|
||||||
pl_event = inject_event(
|
pl_event = self.get_success(
|
||||||
|
inject_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
prev_event_ids=prev_events,
|
prev_event_ids=prev_events,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
|
@ -166,6 +169,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
room_id=self.room_id,
|
room_id=self.room_id,
|
||||||
content=pls,
|
content=pls,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# one more bit of state that doesn't get rolled back
|
# one more bit of state that doesn't get rolled back
|
||||||
state2 = self._inject_state_event()
|
state2 = self._inject_state_event()
|
||||||
|
@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
|
|
||||||
# have the users join
|
# have the users join
|
||||||
for u in user_ids:
|
for u in user_ids:
|
||||||
|
self.get_success(
|
||||||
inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
|
inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
|
||||||
|
)
|
||||||
|
|
||||||
# Update existing power levels with mod at PL50
|
# Update existing power levels with mod at PL50
|
||||||
pls = self.helper.get_state(
|
pls = self.helper.get_state(
|
||||||
|
@ -306,7 +312,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
pl_events = []
|
pl_events = []
|
||||||
for u in user_ids:
|
for u in user_ids:
|
||||||
pls["users"][u] = 0
|
pls["users"][u] = 0
|
||||||
e = inject_event(
|
e = self.get_success(
|
||||||
|
inject_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
prev_event_ids=prev_events,
|
prev_event_ids=prev_events,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
|
@ -315,6 +322,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
room_id=self.room_id,
|
room_id=self.room_id,
|
||||||
content=pls,
|
content=pls,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
prev_events = [e.event_id]
|
prev_events = [e.event_id]
|
||||||
pl_events.append(e)
|
pl_events.append(e)
|
||||||
|
|
||||||
|
@ -434,7 +442,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
body = "event %i" % (self.event_count,)
|
body = "event %i" % (self.event_count,)
|
||||||
self.event_count += 1
|
self.event_count += 1
|
||||||
|
|
||||||
return inject_event(
|
return self.get_success(
|
||||||
|
inject_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
room_id=self.room_id,
|
room_id=self.room_id,
|
||||||
sender=sender,
|
sender=sender,
|
||||||
|
@ -442,6 +451,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
content={"body": body},
|
content={"body": body},
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _inject_state_event(
|
def _inject_state_event(
|
||||||
self,
|
self,
|
||||||
|
@ -459,7 +469,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
if body is None:
|
if body is None:
|
||||||
body = "state event %s" % (state_key,)
|
body = "state event %s" % (state_key,)
|
||||||
|
|
||||||
return inject_event(
|
return self.get_success(
|
||||||
|
inject_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
room_id=self.room_id,
|
room_id=self.room_id,
|
||||||
sender=sender,
|
sender=sender,
|
||||||
|
@ -467,3 +478,4 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
state_key=state_key,
|
state_key=state_key,
|
||||||
content={"body": body},
|
content={"body": body},
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
|
@ -118,12 +118,15 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_get_joined_users_from_context(self):
|
def test_get_joined_users_from_context(self):
|
||||||
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
|
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
|
||||||
bob_event = event_injection.inject_member_event(
|
bob_event = self.get_success(
|
||||||
|
event_injection.inject_member_event(
|
||||||
self.hs, room, self.u_bob, Membership.JOIN
|
self.hs, room, self.u_bob, Membership.JOIN
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# first, create a regular event
|
# first, create a regular event
|
||||||
event, context = event_injection.create_event(
|
event, context = self.get_success(
|
||||||
|
event_injection.create_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
room_id=room,
|
room_id=room,
|
||||||
sender=self.u_alice,
|
sender=self.u_alice,
|
||||||
|
@ -131,6 +134,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
type="m.test.1",
|
type="m.test.1",
|
||||||
content={},
|
content={},
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
users = self.get_success(
|
users = self.get_success(
|
||||||
self.store.get_joined_users_from_context(event, context)
|
self.store.get_joined_users_from_context(event, context)
|
||||||
|
@ -140,7 +144,8 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
# Regression test for #7376: create a state event whose key matches bob's
|
# Regression test for #7376: create a state event whose key matches bob's
|
||||||
# user_id, but which is *not* a membership event, and persist that; then check
|
# user_id, but which is *not* a membership event, and persist that; then check
|
||||||
# that `get_joined_users_from_context` returns the correct users for the next event.
|
# that `get_joined_users_from_context` returns the correct users for the next event.
|
||||||
non_member_event = event_injection.inject_event(
|
non_member_event = self.get_success(
|
||||||
|
event_injection.inject_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
room_id=room,
|
room_id=room,
|
||||||
sender=self.u_bob,
|
sender=self.u_bob,
|
||||||
|
@ -149,7 +154,9 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
state_key=self.u_bob,
|
state_key=self.u_bob,
|
||||||
content={},
|
content={},
|
||||||
)
|
)
|
||||||
event, context = event_injection.create_event(
|
)
|
||||||
|
event, context = self.get_success(
|
||||||
|
event_injection.create_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
room_id=room,
|
room_id=room,
|
||||||
sender=self.u_alice,
|
sender=self.u_alice,
|
||||||
|
@ -157,6 +164,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
type="m.test.3",
|
type="m.test.3",
|
||||||
content={},
|
content={},
|
||||||
)
|
)
|
||||||
|
)
|
||||||
users = self.get_success(
|
users = self.get_success(
|
||||||
self.store.get_joined_users_from_context(event, context)
|
self.store.get_joined_users_from_context(event, context)
|
||||||
)
|
)
|
||||||
|
|
|
@ -64,8 +64,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield defer.ensureDeferred(
|
||||||
builder
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.storage.persistence.persist_event(event, context)
|
yield self.storage.persistence.persist_event(event, context)
|
||||||
|
|
|
@ -22,14 +22,12 @@ from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.types import Collection
|
from synapse.types import Collection
|
||||||
|
|
||||||
from tests.test_utils import get_awaitable_result
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Utility functions for poking events into the storage of the server under test.
|
Utility functions for poking events into the storage of the server under test.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def inject_member_event(
|
async def inject_member_event(
|
||||||
hs: synapse.server.HomeServer,
|
hs: synapse.server.HomeServer,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
sender: str,
|
sender: str,
|
||||||
|
@ -46,7 +44,7 @@ def inject_member_event(
|
||||||
if extra_content:
|
if extra_content:
|
||||||
content.update(extra_content)
|
content.update(extra_content)
|
||||||
|
|
||||||
return inject_event(
|
return await inject_event(
|
||||||
hs,
|
hs,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
type=EventTypes.Member,
|
type=EventTypes.Member,
|
||||||
|
@ -57,7 +55,7 @@ def inject_member_event(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def inject_event(
|
async def inject_event(
|
||||||
hs: synapse.server.HomeServer,
|
hs: synapse.server.HomeServer,
|
||||||
room_version: Optional[str] = None,
|
room_version: Optional[str] = None,
|
||||||
prev_event_ids: Optional[Collection[str]] = None,
|
prev_event_ids: Optional[Collection[str]] = None,
|
||||||
|
@ -72,37 +70,27 @@ def inject_event(
|
||||||
prev_event_ids: prev_events for the event. If not specified, will be looked up
|
prev_event_ids: prev_events for the event. If not specified, will be looked up
|
||||||
kwargs: fields for the event to be created
|
kwargs: fields for the event to be created
|
||||||
"""
|
"""
|
||||||
test_reactor = hs.get_reactor()
|
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
|
||||||
|
|
||||||
event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
|
await hs.get_storage().persistence.persist_event(event, context)
|
||||||
|
|
||||||
d = hs.get_storage().persistence.persist_event(event, context)
|
|
||||||
test_reactor.advance(0)
|
|
||||||
get_awaitable_result(d)
|
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
|
||||||
def create_event(
|
async def create_event(
|
||||||
hs: synapse.server.HomeServer,
|
hs: synapse.server.HomeServer,
|
||||||
room_version: Optional[str] = None,
|
room_version: Optional[str] = None,
|
||||||
prev_event_ids: Optional[Collection[str]] = None,
|
prev_event_ids: Optional[Collection[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Tuple[EventBase, EventContext]:
|
) -> Tuple[EventBase, EventContext]:
|
||||||
test_reactor = hs.get_reactor()
|
|
||||||
|
|
||||||
if room_version is None:
|
if room_version is None:
|
||||||
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
|
room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
|
||||||
test_reactor.advance(0)
|
|
||||||
room_version = get_awaitable_result(d)
|
|
||||||
|
|
||||||
builder = hs.get_event_builder_factory().for_room_version(
|
builder = hs.get_event_builder_factory().for_room_version(
|
||||||
KNOWN_ROOM_VERSIONS[room_version], kwargs
|
KNOWN_ROOM_VERSIONS[room_version], kwargs
|
||||||
)
|
)
|
||||||
d = hs.get_event_creation_handler().create_new_client_event(
|
event, context = await hs.get_event_creation_handler().create_new_client_event(
|
||||||
builder, prev_event_ids=prev_event_ids
|
builder, prev_event_ids=prev_event_ids
|
||||||
)
|
)
|
||||||
test_reactor.advance(0)
|
|
||||||
event, context = get_awaitable_result(d)
|
|
||||||
|
|
||||||
return event, context
|
return event, context
|
||||||
|
|
|
@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
#
|
#
|
||||||
|
|
||||||
# before we do that, we persist some other events to act as state.
|
# before we do that, we persist some other events to act as state.
|
||||||
self.inject_visibility("@admin:hs", "joined")
|
yield self.inject_visibility("@admin:hs", "joined")
|
||||||
for i in range(0, 10):
|
for i in range(0, 10):
|
||||||
yield self.inject_room_member("@resident%i:hs" % i)
|
yield self.inject_room_member("@resident%i:hs" % i)
|
||||||
|
|
||||||
|
@ -137,8 +137,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield defer.ensureDeferred(
|
||||||
builder
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
yield self.storage.persistence.persist_event(event, context)
|
yield self.storage.persistence.persist_event(event, context)
|
||||||
return event
|
return event
|
||||||
|
@ -158,8 +158,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield defer.ensureDeferred(
|
||||||
builder
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.storage.persistence.persist_event(event, context)
|
yield self.storage.persistence.persist_event(event, context)
|
||||||
|
@ -179,8 +179,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield defer.ensureDeferred(
|
||||||
builder
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.storage.persistence.persist_event(event, context)
|
yield self.storage.persistence.persist_event(event, context)
|
||||||
|
|
|
@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase):
|
||||||
user: MXID of the user to inject the membership for.
|
user: MXID of the user to inject the membership for.
|
||||||
membership: The membership type.
|
membership: The membership type.
|
||||||
"""
|
"""
|
||||||
|
self.get_success(
|
||||||
event_injection.inject_member_event(self.hs, room, user, membership)
|
event_injection.inject_member_event(self.hs, room, user, membership)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FederatingHomeserverTestCase(HomeserverTestCase):
|
class FederatingHomeserverTestCase(HomeserverTestCase):
|
||||||
|
|
|
@ -671,6 +671,8 @@ def create_room(hs, room_id, creator_id):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield event_creation_handler.create_new_client_event(builder)
|
event, context = yield defer.ensureDeferred(
|
||||||
|
event_creation_handler.create_new_client_event(builder)
|
||||||
|
)
|
||||||
|
|
||||||
yield persistence_store.persist_event(event, context)
|
yield persistence_store.persist_event(event, context)
|
||||||
|
|
Loading…
Add table
Reference in a new issue