Port SyncHandler to async/await

This commit is contained in:
Erik Johnston 2019-12-05 17:58:25 +00:00
parent d085a8a0a5
commit 8437e2383e
6 changed files with 182 additions and 191 deletions

View file

@ -16,8 +16,6 @@
import logging import logging
import random import random
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase from synapse.events import EventBase
@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler):
self._server_notices_sender = hs.get_server_notices_sender() self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
@log_function @log_function
def get_stream( async def get_stream(
self, self,
auth_user_id, auth_user_id,
pagin_config, pagin_config,
@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler):
""" """
if room_id: if room_id:
blocked = yield self.store.is_room_blocked(room_id) blocked = await self.store.is_room_blocked(room_id)
if blocked: if blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
# send any outstanding server notices to the user. # send any outstanding server notices to the user.
yield self._server_notices_sender.on_user_syncing(auth_user_id) await self._server_notices_sender.on_user_syncing(auth_user_id)
auth_user = UserID.from_string(auth_user_id) auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler() presence_handler = self.hs.get_presence_handler()
context = yield presence_handler.user_syncing( context = await presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence auth_user_id, affect_presence=affect_presence
) )
with context: with context:
@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for( events, tokens = await self.notifier.get_events_for(
auth_user, auth_user,
pagin_config, pagin_config,
timeout, timeout,
@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler):
# Send down presence. # Send down presence.
if event.state_key == auth_user_id: if event.state_key == auth_user_id:
# Send down presence for everyone in the room. # Send down presence for everyone in the room.
users = yield self.state.get_current_users_in_room( users = await self.state.get_current_users_in_room(
event.room_id event.room_id
) )
states = yield presence_handler.get_states(users, as_event=True) states = await presence_handler.get_states(users, as_event=True)
to_add.extend(states) to_add.extend(states)
else: else:
ev = yield presence_handler.get_state( ev = await presence_handler.get_state(
UserID.from_string(event.state_key), as_event=True UserID.from_string(event.state_key), as_event=True
) )
to_add.append(ev) to_add.append(ev)
@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
chunks = yield self._event_serializer.serialize_events( chunks = await self._event_serializer.serialize_events(
events, events,
time_now, time_now,
as_client_event=as_client_event, as_client_event=as_client_event,
@ -151,8 +148,7 @@ class EventHandler(BaseHandler):
super(EventHandler, self).__init__(hs) super(EventHandler, self).__init__(hs)
self.storage = hs.get_storage() self.storage = hs.get_storage()
@defer.inlineCallbacks async def get_event(self, user, room_id, event_id):
def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event. """Retrieve a single specified event.
Args: Args:
@ -167,15 +163,15 @@ class EventHandler(BaseHandler):
AuthError if the user does not have the rights to inspect this AuthError if the user does not have the rights to inspect this
event. event.
""" """
event = yield self.store.get_event(event_id, check_room_id=room_id) event = await self.store.get_event(event_id, check_room_id=room_id)
if not event: if not event:
return None return None
users = yield self.store.get_users_in_room(event.room_id) users = await self.store.get_users_in_room(event.room_id)
is_peeking = user.to_string() not in users is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client( filtered = await filter_events_for_client(
self.storage, user.to_string(), [event], is_peeking=is_peeking self.storage, user.to_string(), [event], is_peeking=is_peeking
) )

View file

@ -22,8 +22,6 @@ from six import iteritems, itervalues
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
@ -241,8 +239,7 @@ class SyncHandler(object):
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
) )
@defer.inlineCallbacks async def wait_for_sync_for_user(
def wait_for_sync_for_user(
self, sync_config, since_token=None, timeout=0, full_state=False self, sync_config, since_token=None, timeout=0, full_state=False
): ):
"""Get the sync for a client if we have new data for it now. Otherwise """Get the sync for a client if we have new data for it now. Otherwise
@ -255,9 +252,9 @@ class SyncHandler(object):
# not been exceeded (if not part of the group by this point, almost certain # not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur) # auth_blocking will occur)
user_id = sync_config.user.to_string() user_id = sync_config.user.to_string()
yield self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(user_id)
res = yield self.response_cache.wrap( res = await self.response_cache.wrap(
sync_config.request_key, sync_config.request_key,
self._wait_for_sync_for_user, self._wait_for_sync_for_user,
sync_config, sync_config,
@ -267,8 +264,9 @@ class SyncHandler(object):
) )
return res return res
@defer.inlineCallbacks async def _wait_for_sync_for_user(
def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): self, sync_config, since_token, timeout, full_state
):
if since_token is None: if since_token is None:
sync_type = "initial_sync" sync_type = "initial_sync"
elif full_state: elif full_state:
@ -283,7 +281,7 @@ class SyncHandler(object):
if timeout == 0 or since_token is None or full_state: if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling # we are going to return immediately, so don't bother calling
# notifier.wait_for_events. # notifier.wait_for_events.
result = yield self.current_sync_for_user( result = await self.current_sync_for_user(
sync_config, since_token, full_state=full_state sync_config, since_token, full_state=full_state
) )
else: else:
@ -291,7 +289,7 @@ class SyncHandler(object):
def current_sync_callback(before_token, after_token): def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token) return self.current_sync_for_user(sync_config, since_token)
result = yield self.notifier.wait_for_events( result = await self.notifier.wait_for_events(
sync_config.user.to_string(), sync_config.user.to_string(),
timeout, timeout,
current_sync_callback, current_sync_callback,
@ -314,15 +312,13 @@ class SyncHandler(object):
""" """
return self.generate_sync_result(sync_config, since_token, full_state) return self.generate_sync_result(sync_config, since_token, full_state)
@defer.inlineCallbacks async def push_rules_for_user(self, user):
def push_rules_for_user(self, user):
user_id = user.to_string() user_id = user.to_string()
rules = yield self.store.get_push_rules_for_user(user_id) rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules) rules = format_push_rules_for_user(user, rules)
return rules return rules
@defer.inlineCallbacks async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in """Get the ephemeral events for each room the user is in
Args: Args:
sync_result_builder(SyncResultBuilder) sync_result_builder(SyncResultBuilder)
@ -343,7 +339,7 @@ class SyncHandler(object):
room_ids = sync_result_builder.joined_room_ids room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events( typing, typing_key = typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter_collection.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
@ -365,7 +361,7 @@ class SyncHandler(object):
receipt_key = since_token.receipt_key if since_token else "0" receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"] receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = yield receipt_source.get_new_events( receipts, receipt_key = await receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
limit=sync_config.filter_collection.ephemeral_limit(), limit=sync_config.filter_collection.ephemeral_limit(),
@ -382,8 +378,7 @@ class SyncHandler(object):
return now_token, ephemeral_by_room return now_token, ephemeral_by_room
@defer.inlineCallbacks async def _load_filtered_recents(
def _load_filtered_recents(
self, self,
room_id, room_id,
sync_config, sync_config,
@ -415,10 +410,10 @@ class SyncHandler(object):
# ensure that we always include current state in the timeline # ensure that we always include current state in the timeline
current_state_ids = frozenset() current_state_ids = frozenset()
if any(e.is_state() for e in recents): if any(e.is_state() for e in recents):
current_state_ids = yield self.state.get_current_state_ids(room_id) current_state_ids = await self.state.get_current_state_ids(room_id)
current_state_ids = frozenset(itervalues(current_state_ids)) current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client( recents = await filter_events_for_client(
self.storage, self.storage,
sync_config.user.to_string(), sync_config.user.to_string(),
recents, recents,
@ -449,14 +444,14 @@ class SyncHandler(object):
# Otherwise, we want to return the last N events in the room # Otherwise, we want to return the last N events in the room
# in toplogical ordering. # in toplogical ordering.
if since_key: if since_key:
events, end_key = yield self.store.get_room_events_stream_for_room( events, end_key = await self.store.get_room_events_stream_for_room(
room_id, room_id,
limit=load_limit + 1, limit=load_limit + 1,
from_key=since_key, from_key=since_key,
to_key=end_key, to_key=end_key,
) )
else: else:
events, end_key = yield self.store.get_recent_events_for_room( events, end_key = await self.store.get_recent_events_for_room(
room_id, limit=load_limit + 1, end_token=end_key room_id, limit=load_limit + 1, end_token=end_key
) )
loaded_recents = sync_config.filter_collection.filter_room_timeline( loaded_recents = sync_config.filter_collection.filter_room_timeline(
@ -468,10 +463,10 @@ class SyncHandler(object):
# ensure that we always include current state in the timeline # ensure that we always include current state in the timeline
current_state_ids = frozenset() current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents): if any(e.is_state() for e in loaded_recents):
current_state_ids = yield self.state.get_current_state_ids(room_id) current_state_ids = await self.state.get_current_state_ids(room_id)
current_state_ids = frozenset(itervalues(current_state_ids)) current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client( loaded_recents = await filter_events_for_client(
self.storage, self.storage,
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
@ -498,8 +493,7 @@ class SyncHandler(object):
limited=limited or newly_joined_room, limited=limited or newly_joined_room,
) )
@defer.inlineCallbacks async def get_state_after_event(self, event, state_filter=StateFilter.all()):
def get_state_after_event(self, event, state_filter=StateFilter.all()):
""" """
Get the room state after the given event Get the room state after the given event
@ -511,7 +505,7 @@ class SyncHandler(object):
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state_ids = yield self.state_store.get_state_ids_for_event( state_ids = await self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter event.event_id, state_filter=state_filter
) )
if event.is_state(): if event.is_state():
@ -519,8 +513,9 @@ class SyncHandler(object):
state_ids[(event.type, event.state_key)] = event.event_id state_ids[(event.type, event.state_key)] = event.event_id
return state_ids return state_ids
@defer.inlineCallbacks async def get_state_at(
def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): self, room_id, stream_position, state_filter=StateFilter.all()
):
""" Get the room state at a particular stream position """ Get the room state at a particular stream position
Args: Args:
@ -536,13 +531,13 @@ class SyncHandler(object):
# get_recent_events_for_room operates by topo ordering. This therefore # get_recent_events_for_room operates by topo ordering. This therefore
# does not reliably give you the state at the given stream position. # does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305) # (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room( last_events, _ = await self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1 room_id, end_token=stream_position.room_key, limit=1
) )
if last_events: if last_events:
last_event = last_events[-1] last_event = last_events[-1]
state = yield self.get_state_after_event( state = await self.get_state_after_event(
last_event, state_filter=state_filter last_event, state_filter=state_filter
) )
@ -551,8 +546,7 @@ class SyncHandler(object):
state = {} state = {}
return state return state
@defer.inlineCallbacks async def compute_summary(self, room_id, sync_config, batch, state, now_token):
def compute_summary(self, room_id, sync_config, batch, state, now_token):
""" Works out a room summary block for this room, summarising the number """ Works out a room summary block for this room, summarising the number
of joined members in the room, and providing the 'hero' members if the of joined members in the room, and providing the 'hero' members if the
room has no name so clients can consistently name rooms. Also adds room has no name so clients can consistently name rooms. Also adds
@ -574,7 +568,7 @@ class SyncHandler(object):
# FIXME: we could/should get this from room_stats when matthew/stats lands # FIXME: we could/should get this from room_stats when matthew/stats lands
# FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
last_events, _ = yield self.store.get_recent_event_ids_for_room( last_events, _ = await self.store.get_recent_event_ids_for_room(
room_id, end_token=now_token.room_key, limit=1 room_id, end_token=now_token.room_key, limit=1
) )
@ -582,7 +576,7 @@ class SyncHandler(object):
return None return None
last_event = last_events[-1] last_event = last_events[-1]
state_ids = yield self.state_store.get_state_ids_for_event( state_ids = await self.state_store.get_state_ids_for_event(
last_event.event_id, last_event.event_id,
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@ -590,7 +584,7 @@ class SyncHandler(object):
) )
# this is heavily cached, thus: fast. # this is heavily cached, thus: fast.
details = yield self.store.get_room_summary(room_id) details = await self.store.get_room_summary(room_id)
name_id = state_ids.get((EventTypes.Name, "")) name_id = state_ids.get((EventTypes.Name, ""))
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
@ -608,12 +602,12 @@ class SyncHandler(object):
# calculating heroes. Empty strings are falsey, so we check # calculating heroes. Empty strings are falsey, so we check
# for the "name" value and default to an empty string. # for the "name" value and default to an empty string.
if name_id: if name_id:
name = yield self.store.get_event(name_id, allow_none=True) name = await self.store.get_event(name_id, allow_none=True)
if name and name.content.get("name"): if name and name.content.get("name"):
return summary return summary
if canonical_alias_id: if canonical_alias_id:
canonical_alias = yield self.store.get_event( canonical_alias = await self.store.get_event(
canonical_alias_id, allow_none=True canonical_alias_id, allow_none=True
) )
if canonical_alias and canonical_alias.content.get("alias"): if canonical_alias and canonical_alias.content.get("alias"):
@ -678,7 +672,7 @@ class SyncHandler(object):
) )
] ]
missing_hero_state = yield self.store.get_events(missing_hero_event_ids) missing_hero_state = await self.store.get_events(missing_hero_event_ids)
missing_hero_state = missing_hero_state.values() missing_hero_state = missing_hero_state.values()
for s in missing_hero_state: for s in missing_hero_state:
@ -697,8 +691,7 @@ class SyncHandler(object):
logger.debug("found LruCache for %r", cache_key) logger.debug("found LruCache for %r", cache_key)
return cache return cache
@defer.inlineCallbacks async def compute_state_delta(
def compute_state_delta(
self, room_id, batch, sync_config, since_token, now_token, full_state self, room_id, batch, sync_config, since_token, now_token, full_state
): ):
""" Works out the difference in state between the start of the timeline """ Works out the difference in state between the start of the timeline
@ -759,16 +752,16 @@ class SyncHandler(object):
if full_state: if full_state:
if batch: if batch:
current_state_ids = yield self.state_store.get_state_ids_for_event( current_state_ids = await self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id, state_filter=state_filter
) )
state_ids = yield self.state_store.get_state_ids_for_event( state_ids = await self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id, state_filter=state_filter
) )
else: else:
current_state_ids = yield self.get_state_at( current_state_ids = await self.get_state_at(
room_id, stream_position=now_token, state_filter=state_filter room_id, stream_position=now_token, state_filter=state_filter
) )
@ -783,13 +776,13 @@ class SyncHandler(object):
) )
elif batch.limited: elif batch.limited:
if batch: if batch:
state_at_timeline_start = yield self.state_store.get_state_ids_for_event( state_at_timeline_start = await self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id, state_filter=state_filter
) )
else: else:
# We can get here if the user has ignored the senders of all # We can get here if the user has ignored the senders of all
# the recent events. # the recent events.
state_at_timeline_start = yield self.get_state_at( state_at_timeline_start = await self.get_state_at(
room_id, stream_position=now_token, state_filter=state_filter room_id, stream_position=now_token, state_filter=state_filter
) )
@ -807,19 +800,19 @@ class SyncHandler(object):
# about them). # about them).
state_filter = StateFilter.all() state_filter = StateFilter.all()
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = await self.get_state_at(
room_id, stream_position=since_token, state_filter=state_filter room_id, stream_position=since_token, state_filter=state_filter
) )
if batch: if batch:
current_state_ids = yield self.state_store.get_state_ids_for_event( current_state_ids = await self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id, state_filter=state_filter
) )
else: else:
# Its not clear how we get here, but empirically we do # Its not clear how we get here, but empirically we do
# (#5407). Logging has been added elsewhere to try and # (#5407). Logging has been added elsewhere to try and
# figure out where this state comes from. # figure out where this state comes from.
current_state_ids = yield self.get_state_at( current_state_ids = await self.get_state_at(
room_id, stream_position=now_token, state_filter=state_filter room_id, stream_position=now_token, state_filter=state_filter
) )
@ -843,7 +836,7 @@ class SyncHandler(object):
# So we fish out all the member events corresponding to the # So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below. # timeline here, and then dedupe any redundant ones below.
state_ids = yield self.state_store.get_state_ids_for_event( state_ids = await self.state_store.get_state_ids_for_event(
batch.events[0].event_id, batch.events[0].event_id,
# we only want members! # we only want members!
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
@ -883,7 +876,7 @@ class SyncHandler(object):
state = {} state = {}
if state_ids: if state_ids:
state = yield self.store.get_events(list(state_ids.values())) state = await self.store.get_events(list(state_ids.values()))
return { return {
(e.type, e.state_key): e (e.type, e.state_key): e
@ -892,10 +885,9 @@ class SyncHandler(object):
) )
} }
@defer.inlineCallbacks async def unread_notifs_for_room_id(self, room_id, sync_config):
def unread_notifs_for_room_id(self, room_id, sync_config):
with Measure(self.clock, "unread_notifs_for_room_id"): with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
room_id=room_id, room_id=room_id,
receipt_type="m.read", receipt_type="m.read",
@ -903,7 +895,7 @@ class SyncHandler(object):
notifs = [] notifs = []
if last_unread_event_id: if last_unread_event_id:
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user( notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id room_id, sync_config.user.to_string(), last_unread_event_id
) )
return notifs return notifs
@ -912,8 +904,9 @@ class SyncHandler(object):
# count is whatever it was last time. # count is whatever it was last time.
return None return None
@defer.inlineCallbacks async def generate_sync_result(
def generate_sync_result(self, sync_config, since_token=None, full_state=False): self, sync_config, since_token=None, full_state=False
):
"""Generates a sync result. """Generates a sync result.
Args: Args:
@ -928,7 +921,7 @@ class SyncHandler(object):
# this is due to some of the underlying streams not supporting the ability # this is due to some of the underlying streams not supporting the ability
# to query up to a given point. # to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder` # Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token() now_token = await self.event_sources.get_current_token()
logger.info( logger.info(
"Calculating sync response for %r between %s and %s", "Calculating sync response for %r between %s and %s",
@ -944,10 +937,9 @@ class SyncHandler(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144 # See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError() raise NotImplementedError()
else: else:
joined_room_ids = yield self.get_rooms_for_user_at( joined_room_ids = await self.get_rooms_for_user_at(
user_id, now_token.room_stream_id user_id, now_token.room_stream_id
) )
sync_result_builder = SyncResultBuilder( sync_result_builder = SyncResultBuilder(
sync_config, sync_config,
full_state, full_state,
@ -956,11 +948,11 @@ class SyncHandler(object):
joined_room_ids=joined_room_ids, joined_room_ids=joined_room_ids,
) )
account_data_by_room = yield self._generate_sync_entry_for_account_data( account_data_by_room = await self._generate_sync_entry_for_account_data(
sync_result_builder sync_result_builder
) )
res = yield self._generate_sync_entry_for_rooms( res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room sync_result_builder, account_data_by_room
) )
newly_joined_rooms, newly_joined_or_invited_users, _, _ = res newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
@ -970,13 +962,13 @@ class SyncHandler(object):
since_token is None and sync_config.filter_collection.blocks_all_presence() since_token is None and sync_config.filter_collection.blocks_all_presence()
) )
if self.hs_config.use_presence and not block_all_presence_data: if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence( await self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
) )
yield self._generate_sync_entry_for_to_device(sync_result_builder) await self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list( device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder, sync_result_builder,
newly_joined_rooms=newly_joined_rooms, newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_users=newly_joined_or_invited_users, newly_joined_or_invited_users=newly_joined_or_invited_users,
@ -987,11 +979,11 @@ class SyncHandler(object):
device_id = sync_config.device_id device_id = sync_config.device_id
one_time_key_counts = {} one_time_key_counts = {}
if device_id: if device_id:
one_time_key_counts = yield self.store.count_e2e_one_time_keys( one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id user_id, device_id
) )
yield self._generate_sync_entry_for_groups(sync_result_builder) await self._generate_sync_entry_for_groups(sync_result_builder)
# debug for https://github.com/matrix-org/synapse/issues/4422 # debug for https://github.com/matrix-org/synapse/issues/4422
for joined_room in sync_result_builder.joined: for joined_room in sync_result_builder.joined:
@ -1015,18 +1007,17 @@ class SyncHandler(object):
) )
@measure_func("_generate_sync_entry_for_groups") @measure_func("_generate_sync_entry_for_groups")
@defer.inlineCallbacks async def _generate_sync_entry_for_groups(self, sync_result_builder):
def _generate_sync_entry_for_groups(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string() user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token now_token = sync_result_builder.now_token
if since_token and since_token.groups_key: if since_token and since_token.groups_key:
results = yield self.store.get_groups_changes_for_user( results = self.store.get_groups_changes_for_user(
user_id, since_token.groups_key, now_token.groups_key user_id, since_token.groups_key, now_token.groups_key
) )
else: else:
results = yield self.store.get_all_groups_for_user( results = await self.store.get_all_groups_for_user(
user_id, now_token.groups_key user_id, now_token.groups_key
) )
@ -1059,8 +1050,7 @@ class SyncHandler(object):
) )
@measure_func("_generate_sync_entry_for_device_list") @measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks async def _generate_sync_entry_for_device_list(
def _generate_sync_entry_for_device_list(
self, self,
sync_result_builder, sync_result_builder,
newly_joined_rooms, newly_joined_rooms,
@ -1108,32 +1098,32 @@ class SyncHandler(object):
# room with by looking at all users that have left a room plus users # room with by looking at all users that have left a room plus users
# that were in a room we've left. # that were in a room we've left.
users_who_share_room = yield self.store.get_users_who_share_room_with_user( users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id user_id
) )
# Step 1a, check for changes in devices of users we share a room with # Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = yield self.store.get_users_whose_devices_changed( users_that_have_changed = await self.store.get_users_whose_devices_changed(
since_token.device_list_key, users_who_share_room since_token.device_list_key, users_who_share_room
) )
# Step 1b, check for newly joined rooms # Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms: for room_id in newly_joined_rooms:
joined_users = yield self.state.get_current_users_in_room(room_id) joined_users = await self.state.get_current_users_in_room(room_id)
newly_joined_or_invited_users.update(joined_users) newly_joined_or_invited_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they # TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined. # weren't in the previous sync *or* they left and rejoined.
users_that_have_changed.update(newly_joined_or_invited_users) users_that_have_changed.update(newly_joined_or_invited_users)
user_signatures_changed = yield self.store.get_users_whose_signatures_changed( user_signatures_changed = await self.store.get_users_whose_signatures_changed(
user_id, since_token.device_list_key user_id, since_token.device_list_key
) )
users_that_have_changed.update(user_signatures_changed) users_that_have_changed.update(user_signatures_changed)
# Now find users that we no longer track # Now find users that we no longer track
for room_id in newly_left_rooms: for room_id in newly_left_rooms:
left_users = yield self.state.get_current_users_in_room(room_id) left_users = await self.state.get_current_users_in_room(room_id)
newly_left_users.update(left_users) newly_left_users.update(left_users)
# Remove any users that we still share a room with. # Remove any users that we still share a room with.
@ -1143,8 +1133,7 @@ class SyncHandler(object):
else: else:
return DeviceLists(changed=[], left=[]) return DeviceLists(changed=[], left=[])
@defer.inlineCallbacks async def _generate_sync_entry_for_to_device(self, sync_result_builder):
def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates """Generates the portion of the sync response. Populates
`sync_result_builder` with the result. `sync_result_builder` with the result.
@ -1165,14 +1154,14 @@ class SyncHandler(object):
# We only delete messages when a new message comes in, but that's # We only delete messages when a new message comes in, but that's
# fine so long as we delete them at some point. # fine so long as we delete them at some point.
deleted = yield self.store.delete_messages_for_device( deleted = await self.store.delete_messages_for_device(
user_id, device_id, since_stream_id user_id, device_id, since_stream_id
) )
logger.debug( logger.debug(
"Deleted %d to-device messages up to %d", deleted, since_stream_id "Deleted %d to-device messages up to %d", deleted, since_stream_id
) )
messages, stream_id = yield self.store.get_new_messages_for_device( messages, stream_id = await self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key user_id, device_id, since_stream_id, now_token.to_device_key
) )
@ -1190,8 +1179,7 @@ class SyncHandler(object):
else: else:
sync_result_builder.to_device = [] sync_result_builder.to_device = []
@defer.inlineCallbacks async def _generate_sync_entry_for_account_data(self, sync_result_builder):
def _generate_sync_entry_for_account_data(self, sync_result_builder):
"""Generates the account data portion of the sync response. Populates """Generates the account data portion of the sync response. Populates
`sync_result_builder` with the result. `sync_result_builder` with the result.
@ -1209,25 +1197,25 @@ class SyncHandler(object):
( (
account_data, account_data,
account_data_by_room, account_data_by_room,
) = yield self.store.get_updated_account_data_for_user( ) = self.store.get_updated_account_data_for_user(
user_id, since_token.account_data_key user_id, since_token.account_data_key
) )
push_rules_changed = yield self.store.have_push_rules_changed_for_user( push_rules_changed = await self.store.have_push_rules_changed_for_user(
user_id, int(since_token.push_rules_key) user_id, int(since_token.push_rules_key)
) )
if push_rules_changed: if push_rules_changed:
account_data["m.push_rules"] = yield self.push_rules_for_user( account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user sync_config.user
) )
else: else:
( (
account_data, account_data,
account_data_by_room, account_data_by_room,
) = yield self.store.get_account_data_for_user(sync_config.user.to_string()) ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
account_data["m.push_rules"] = yield self.push_rules_for_user( account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user sync_config.user
) )
@ -1242,8 +1230,7 @@ class SyncHandler(object):
return account_data_by_room return account_data_by_room
@defer.inlineCallbacks async def _generate_sync_entry_for_presence(
def _generate_sync_entry_for_presence(
self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
): ):
"""Generates the presence portion of the sync response. Populates the """Generates the presence portion of the sync response. Populates the
@ -1271,7 +1258,7 @@ class SyncHandler(object):
presence_key = None presence_key = None
include_offline = False include_offline = False
presence, presence_key = yield presence_source.get_new_events( presence, presence_key = await presence_source.get_new_events(
user=user, user=user,
from_key=presence_key, from_key=presence_key,
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
@ -1283,12 +1270,12 @@ class SyncHandler(object):
extra_users_ids = set(newly_joined_or_invited_users) extra_users_ids = set(newly_joined_or_invited_users)
for room_id in newly_joined_rooms: for room_id in newly_joined_rooms:
users = yield self.state.get_current_users_in_room(room_id) users = await self.state.get_current_users_in_room(room_id)
extra_users_ids.update(users) extra_users_ids.update(users)
extra_users_ids.discard(user.to_string()) extra_users_ids.discard(user.to_string())
if extra_users_ids: if extra_users_ids:
states = yield self.presence_handler.get_states(extra_users_ids) states = await self.presence_handler.get_states(extra_users_ids)
presence.extend(states) presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user # Deduplicate the presence entries so that there's at most one per user
@ -1298,8 +1285,9 @@ class SyncHandler(object):
sync_result_builder.presence = presence sync_result_builder.presence = presence
@defer.inlineCallbacks async def _generate_sync_entry_for_rooms(
def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): self, sync_result_builder, account_data_by_room
):
"""Generates the rooms portion of the sync response. Populates the """Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result. `sync_result_builder` with the result.
@ -1321,7 +1309,7 @@ class SyncHandler(object):
if block_all_room_ephemeral: if block_all_room_ephemeral:
ephemeral_by_room = {} ephemeral_by_room = {}
else: else:
now_token, ephemeral_by_room = yield self.ephemeral_by_room( now_token, ephemeral_by_room = await self.ephemeral_by_room(
sync_result_builder, sync_result_builder,
now_token=sync_result_builder.now_token, now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token, since_token=sync_result_builder.since_token,
@ -1333,16 +1321,16 @@ class SyncHandler(object):
since_token = sync_result_builder.since_token since_token = sync_result_builder.since_token
if not sync_result_builder.full_state: if not sync_result_builder.full_state:
if since_token and not ephemeral_by_room and not account_data_by_room: if since_token and not ephemeral_by_room and not account_data_by_room:
have_changed = yield self._have_rooms_changed(sync_result_builder) have_changed = await self._have_rooms_changed(sync_result_builder)
if not have_changed: if not have_changed:
tags_by_room = yield self.store.get_updated_tags( tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key user_id, since_token.account_data_key
) )
if not tags_by_room: if not tags_by_room:
logger.debug("no-oping sync") logger.debug("no-oping sync")
return [], [], [], [] return [], [], [], []
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id "m.ignored_user_list", user_id=user_id
) )
@ -1352,18 +1340,18 @@ class SyncHandler(object):
ignored_users = frozenset() ignored_users = frozenset()
if since_token: if since_token:
res = yield self._get_rooms_changed(sync_result_builder, ignored_users) res = await self._get_rooms_changed(sync_result_builder, ignored_users)
room_entries, invited, newly_joined_rooms, newly_left_rooms = res room_entries, invited, newly_joined_rooms, newly_left_rooms = res
tags_by_room = yield self.store.get_updated_tags( tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key user_id, since_token.account_data_key
) )
else: else:
res = yield self._get_all_rooms(sync_result_builder, ignored_users) res = await self._get_all_rooms(sync_result_builder, ignored_users)
room_entries, invited, newly_joined_rooms = res room_entries, invited, newly_joined_rooms = res
newly_left_rooms = [] newly_left_rooms = []
tags_by_room = yield self.store.get_tags_for_user(user_id) tags_by_room = await self.store.get_tags_for_user(user_id)
def handle_room_entries(room_entry): def handle_room_entries(room_entry):
return self._generate_room_entry( return self._generate_room_entry(
@ -1376,7 +1364,7 @@ class SyncHandler(object):
always_include=sync_result_builder.full_state, always_include=sync_result_builder.full_state,
) )
yield concurrently_execute(handle_room_entries, room_entries, 10) await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited) sync_result_builder.invited.extend(invited)
@ -1410,8 +1398,7 @@ class SyncHandler(object):
newly_left_users, newly_left_users,
) )
@defer.inlineCallbacks async def _have_rooms_changed(self, sync_result_builder):
def _have_rooms_changed(self, sync_result_builder):
"""Returns whether there may be any new events that should be sent down """Returns whether there may be any new events that should be sent down
the sync. Returns True if there are. the sync. Returns True if there are.
""" """
@ -1422,7 +1409,7 @@ class SyncHandler(object):
assert since_token assert since_token
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user( rooms_changed = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key
) )
@ -1435,8 +1422,7 @@ class SyncHandler(object):
return True return True
return False return False
@defer.inlineCallbacks async def _get_rooms_changed(self, sync_result_builder, ignored_users):
def _get_rooms_changed(self, sync_result_builder, ignored_users):
"""Gets the the changes that have happened since the last sync. """Gets the the changes that have happened since the last sync.
Args: Args:
@ -1461,7 +1447,7 @@ class SyncHandler(object):
assert since_token assert since_token
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user( rooms_changed = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key
) )
@ -1499,11 +1485,11 @@ class SyncHandler(object):
continue continue
if room_id in sync_result_builder.joined_room_ids or has_join: if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token) old_state_ids = await self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None old_mem_ev = None
if old_mem_ev_id: if old_mem_ev_id:
old_mem_ev = yield self.store.get_event( old_mem_ev = await self.store.get_event(
old_mem_ev_id, allow_none=True old_mem_ev_id, allow_none=True
) )
@ -1536,13 +1522,13 @@ class SyncHandler(object):
newly_left_rooms.append(room_id) newly_left_rooms.append(room_id)
else: else:
if not old_state_ids: if not old_state_ids:
old_state_ids = yield self.get_state_at(room_id, since_token) old_state_ids = await self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get( old_mem_ev_id = old_state_ids.get(
(EventTypes.Member, user_id), None (EventTypes.Member, user_id), None
) )
old_mem_ev = None old_mem_ev = None
if old_mem_ev_id: if old_mem_ev_id:
old_mem_ev = yield self.store.get_event( old_mem_ev = await self.store.get_event(
old_mem_ev_id, allow_none=True old_mem_ev_id, allow_none=True
) )
if old_mem_ev and old_mem_ev.membership == Membership.JOIN: if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
@ -1566,7 +1552,7 @@ class SyncHandler(object):
if leave_events: if leave_events:
leave_event = leave_events[-1] leave_event = leave_events[-1]
leave_stream_token = yield self.store.get_stream_token_for_event( leave_stream_token = await self.store.get_stream_token_for_event(
leave_event.event_id leave_event.event_id
) )
leave_token = since_token.copy_and_replace( leave_token = since_token.copy_and_replace(
@ -1603,7 +1589,7 @@ class SyncHandler(object):
timeline_limit = sync_config.filter_collection.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
# Get all events for rooms we're currently joined to. # Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms( room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=sync_result_builder.joined_room_ids, room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key, from_key=since_token.room_key,
to_key=now_token.room_key, to_key=now_token.room_key,
@ -1652,8 +1638,7 @@ class SyncHandler(object):
return room_entries, invited, newly_joined_rooms, newly_left_rooms return room_entries, invited, newly_joined_rooms, newly_left_rooms
@defer.inlineCallbacks async def _get_all_rooms(self, sync_result_builder, ignored_users):
def _get_all_rooms(self, sync_result_builder, ignored_users):
"""Returns entries for all rooms for the user. """Returns entries for all rooms for the user.
Args: Args:
@ -1677,7 +1662,7 @@ class SyncHandler(object):
Membership.BAN, Membership.BAN,
) )
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = await self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=membership_list user_id=user_id, membership_list=membership_list
) )
@ -1700,7 +1685,7 @@ class SyncHandler(object):
elif event.membership == Membership.INVITE: elif event.membership == Membership.INVITE:
if event.sender in ignored_users: if event.sender in ignored_users:
continue continue
invite = yield self.store.get_event(event.event_id) invite = await self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
elif event.membership in (Membership.LEAVE, Membership.BAN): elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from. # Always send down rooms we were banned or kicked from.
@ -1726,8 +1711,7 @@ class SyncHandler(object):
return room_entries, invited, [] return room_entries, invited, []
@defer.inlineCallbacks async def _generate_room_entry(
def _generate_room_entry(
self, self,
sync_result_builder, sync_result_builder,
ignored_users, ignored_users,
@ -1769,7 +1753,7 @@ class SyncHandler(object):
since_token = room_builder.since_token since_token = room_builder.since_token
upto_token = room_builder.upto_token upto_token = room_builder.upto_token
batch = yield self._load_filtered_recents( batch = await self._load_filtered_recents(
room_id, room_id,
sync_config, sync_config,
now_token=upto_token, now_token=upto_token,
@ -1796,7 +1780,7 @@ class SyncHandler(object):
# tag was added by synapse e.g. for server notice rooms. # tag was added by synapse e.g. for server notice rooms.
if full_state: if full_state:
user_id = sync_result_builder.sync_config.user.to_string() user_id = sync_result_builder.sync_config.user.to_string()
tags = yield self.store.get_tags_for_room(user_id, room_id) tags = await self.store.get_tags_for_room(user_id, room_id)
# If there aren't any tags, don't send the empty tags list down # If there aren't any tags, don't send the empty tags list down
# sync # sync
@ -1821,7 +1805,7 @@ class SyncHandler(object):
): ):
return return
state = yield self.compute_state_delta( state = await self.compute_state_delta(
room_id, batch, sync_config, since_token, now_token, full_state=full_state room_id, batch, sync_config, since_token, now_token, full_state=full_state
) )
@ -1844,7 +1828,7 @@ class SyncHandler(object):
) )
or since_token is None or since_token is None
): ):
summary = yield self.compute_summary( summary = await self.compute_summary(
room_id, sync_config, batch, state, now_token room_id, sync_config, batch, state, now_token
) )
@ -1861,7 +1845,7 @@ class SyncHandler(object):
) )
if room_sync or always_include: if room_sync or always_include:
notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
if notifs is not None: if notifs is not None:
unread_notifications["notification_count"] = notifs["notify_count"] unread_notifications["notification_count"] = notifs["notify_count"]
@ -1887,8 +1871,7 @@ class SyncHandler(object):
else: else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype) raise Exception("Unrecognized rtype: %r", room_builder.rtype)
@defer.inlineCallbacks async def get_rooms_for_user_at(self, user_id, stream_ordering):
def get_rooms_for_user_at(self, user_id, stream_ordering):
"""Get set of joined rooms for a user at the given stream ordering. """Get set of joined rooms for a user at the given stream ordering.
The stream ordering *must* be recent, otherwise this may throw an The stream ordering *must* be recent, otherwise this may throw an
@ -1903,7 +1886,7 @@ class SyncHandler(object):
Deferred[frozenset[str]]: Set of room_ids the user is in at given Deferred[frozenset[str]]: Set of room_ids the user is in at given
stream_ordering. stream_ordering.
""" """
joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
joined_room_ids = set() joined_room_ids = set()
@ -1921,10 +1904,10 @@ class SyncHandler(object):
logger.info("User joined room after current token: %s", room_id) logger.info("User joined room after current token: %s", room_id)
extrems = yield self.store.get_forward_extremeties_for_room( extrems = await self.store.get_forward_extremeties_for_room(
room_id, stream_ordering room_id, stream_ordering
) )
users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room: if user_id in users_in_room:
joined_room_ids.add(room_id) joined_room_ids.add(room_id)

View file

@ -304,8 +304,7 @@ class Notifier(object):
without waking up any of the normal user event streams""" without waking up any of the normal user event streams"""
self.notify_replication() self.notify_replication()
@defer.inlineCallbacks async def wait_for_events(
def wait_for_events(
self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
): ):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
@ -313,9 +312,9 @@ class Notifier(object):
""" """
user_stream = self.user_to_user_stream.get(user_id) user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None: if user_stream is None:
current_token = yield self.event_sources.get_current_token() current_token = await self.event_sources.get_current_token()
if room_ids is None: if room_ids is None:
room_ids = yield self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user_id=user_id, user_id=user_id,
rooms=room_ids, rooms=room_ids,
@ -344,11 +343,11 @@ class Notifier(object):
self.hs.get_reactor(), self.hs.get_reactor(),
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
yield listener.deferred await listener.deferred
current_token = user_stream.current_token current_token = user_stream.current_token
result = yield callback(prev_token, current_token) result = await callback(prev_token, current_token)
if result: if result:
break break
@ -364,12 +363,11 @@ class Notifier(object):
# This happened if there was no timeout or if the timeout had # This happened if there was no timeout or if the timeout had
# already expired. # already expired.
current_token = user_stream.current_token current_token = user_stream.current_token
result = yield callback(prev_token, current_token) result = await callback(prev_token, current_token)
return result return result
@defer.inlineCallbacks async def get_events_for(
def get_events_for(
self, self,
user, user,
pagination_config, pagination_config,
@ -391,15 +389,14 @@ class Notifier(object):
""" """
from_token = pagination_config.from_token from_token = pagination_config.from_token
if not from_token: if not from_token:
from_token = yield self.event_sources.get_current_token() from_token = await self.event_sources.get_current_token()
limit = pagination_config.limit limit = pagination_config.limit
room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id) room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined is_peeking = not is_joined
@defer.inlineCallbacks async def check_for_updates(before_token, after_token):
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
return EventStreamResult([], (from_token, from_token)) return EventStreamResult([], (from_token, from_token))
@ -415,7 +412,7 @@ class Notifier(object):
if only_keys and name not in only_keys: if only_keys and name not in only_keys:
continue continue
new_events, new_key = yield source.get_new_events( new_events, new_key = await source.get_new_events(
user=user, user=user,
from_key=getattr(from_token, keyname), from_key=getattr(from_token, keyname),
limit=limit, limit=limit,
@ -425,7 +422,7 @@ class Notifier(object):
) )
if name == "room": if name == "room":
new_events = yield filter_events_for_client( new_events = await filter_events_for_client(
self.storage, self.storage,
user.to_string(), user.to_string(),
new_events, new_events,
@ -461,7 +458,7 @@ class Notifier(object):
user_id_for_stream, user_id_for_stream,
) )
result = yield self.wait_for_events( result = await self.wait_for_events(
user_id_for_stream, user_id_for_stream,
timeout, timeout,
check_for_updates, check_for_updates,

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
from functools import wraps from functools import wraps
@ -64,12 +65,22 @@ def measure_func(name=None):
def wrapper(func): def wrapper(func):
block_name = func.__name__ if name is None else name block_name = func.__name__ if name is None else name
@wraps(func) if inspect.iscoroutinefunction(func):
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs): @wraps(func)
with Measure(self.clock, block_name): async def measured_func(self, *args, **kwargs):
r = yield func(self, *args, **kwargs) with Measure(self.clock, block_name):
return r r = await func(self, *args, **kwargs)
return r
else:
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
return measured_func return measured_func

View file

@ -12,54 +12,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig, SyncHandler from synapse.handlers.sync import SyncConfig
from synapse.types import UserID from synapse.types import UserID
import tests.unittest import tests.unittest
import tests.utils import tests.utils
from tests.utils import setup_test_homeserver
class SyncTestCase(tests.unittest.TestCase): class SyncTestCase(tests.unittest.HomeserverTestCase):
""" Tests Sync Handler. """ """ Tests Sync Handler. """
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self): self.hs = hs
self.hs = yield setup_test_homeserver(self.addCleanup) self.sync_handler = self.hs.get_sync_handler()
self.sync_handler = SyncHandler(self.hs)
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self): def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:server" user_id1 = "@user1:server"
user_id2 = "@user2:server" user_id2 = "@user2:server"
sync_config = self._generate_sync_config(user_id1) sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1 self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors # Check that the happy case does not throw errors
yield self.store.upsert_monthly_active_user(user_id1) self.get_success(self.store.upsert_monthly_active_user(user_id1))
yield self.sync_handler.wait_for_sync_for_user(sync_config) self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works # Test that global lock works
self.hs.config.hs_disabled = True self.hs.config.hs_disabled = True
with self.assertRaises(ResourceLimitError) as e: e = self.get_failure(
yield self.sync_handler.wait_for_sync_for_user(sync_config) self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) )
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2) sync_config = self._generate_sync_config(user_id2)
with self.assertRaises(ResourceLimitError) as e: e = self.get_failure(
yield self.sync_handler.wait_for_sync_for_user(sync_config) self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) )
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id): def _generate_sync_config(self, user_id):
return SyncConfig( return SyncConfig(

View file

@ -18,6 +18,7 @@
import gc import gc
import hashlib import hashlib
import hmac import hmac
import inspect
import logging import logging
import time import time
@ -25,7 +26,7 @@ from mock import Mock
from canonicaljson import json from canonicaljson import json
from twisted.internet.defer import Deferred, succeed from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.threadpool import ThreadPool from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest from twisted.trial import unittest
@ -415,6 +416,8 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100) self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0): def get_success(self, d, by=0.0):
if inspect.isawaitable(d):
d = ensureDeferred(d)
if not isinstance(d, Deferred): if not isinstance(d, Deferred):
return d return d
self.pump(by=by) self.pump(by=by)
@ -424,6 +427,8 @@ class HomeserverTestCase(TestCase):
""" """
Run a Deferred and get a Failure from it. The failure must be of the type `exc`. Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
""" """
if inspect.isawaitable(d):
d = ensureDeferred(d)
if not isinstance(d, Deferred): if not isinstance(d, Deferred):
return d return d
self.pump() self.pump()