mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 18:23:28 +01:00
Port SyncHandler to async/await
This commit is contained in:
parent
d085a8a0a5
commit
8437e2383e
6 changed files with 182 additions and 191 deletions
|
@ -16,8 +16,6 @@
|
|||
import logging
|
||||
import random
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.events import EventBase
|
||||
|
@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler):
|
|||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_stream(
|
||||
async def get_stream(
|
||||
self,
|
||||
auth_user_id,
|
||||
pagin_config,
|
||||
|
@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler):
|
|||
"""
|
||||
|
||||
if room_id:
|
||||
blocked = yield self.store.is_room_blocked(room_id)
|
||||
blocked = await self.store.is_room_blocked(room_id)
|
||||
if blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
|
||||
# send any outstanding server notices to the user.
|
||||
yield self._server_notices_sender.on_user_syncing(auth_user_id)
|
||||
await self._server_notices_sender.on_user_syncing(auth_user_id)
|
||||
|
||||
auth_user = UserID.from_string(auth_user_id)
|
||||
presence_handler = self.hs.get_presence_handler()
|
||||
|
||||
context = yield presence_handler.user_syncing(
|
||||
context = await presence_handler.user_syncing(
|
||||
auth_user_id, affect_presence=affect_presence
|
||||
)
|
||||
with context:
|
||||
|
@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler):
|
|||
# thundering herds on restart.
|
||||
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
|
||||
|
||||
events, tokens = yield self.notifier.get_events_for(
|
||||
events, tokens = await self.notifier.get_events_for(
|
||||
auth_user,
|
||||
pagin_config,
|
||||
timeout,
|
||||
|
@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler):
|
|||
# Send down presence.
|
||||
if event.state_key == auth_user_id:
|
||||
# Send down presence for everyone in the room.
|
||||
users = yield self.state.get_current_users_in_room(
|
||||
users = await self.state.get_current_users_in_room(
|
||||
event.room_id
|
||||
)
|
||||
states = yield presence_handler.get_states(users, as_event=True)
|
||||
states = await presence_handler.get_states(users, as_event=True)
|
||||
to_add.extend(states)
|
||||
else:
|
||||
|
||||
ev = yield presence_handler.get_state(
|
||||
ev = await presence_handler.get_state(
|
||||
UserID.from_string(event.state_key), as_event=True
|
||||
)
|
||||
to_add.append(ev)
|
||||
|
@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler):
|
|||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
chunks = yield self._event_serializer.serialize_events(
|
||||
chunks = await self._event_serializer.serialize_events(
|
||||
events,
|
||||
time_now,
|
||||
as_client_event=as_client_event,
|
||||
|
@ -151,8 +148,7 @@ class EventHandler(BaseHandler):
|
|||
super(EventHandler, self).__init__(hs)
|
||||
self.storage = hs.get_storage()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, user, room_id, event_id):
|
||||
async def get_event(self, user, room_id, event_id):
|
||||
"""Retrieve a single specified event.
|
||||
|
||||
Args:
|
||||
|
@ -167,15 +163,15 @@ class EventHandler(BaseHandler):
|
|||
AuthError if the user does not have the rights to inspect this
|
||||
event.
|
||||
"""
|
||||
event = yield self.store.get_event(event_id, check_room_id=room_id)
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
if not event:
|
||||
return None
|
||||
|
||||
users = yield self.store.get_users_in_room(event.room_id)
|
||||
users = await self.store.get_users_in_room(event.room_id)
|
||||
is_peeking = user.to_string() not in users
|
||||
|
||||
filtered = yield filter_events_for_client(
|
||||
filtered = await filter_events_for_client(
|
||||
self.storage, user.to_string(), [event], is_peeking=is_peeking
|
||||
)
|
||||
|
||||
|
|
|
@ -22,8 +22,6 @@ from six import iteritems, itervalues
|
|||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.logging.context import LoggingContext
|
||||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
|
@ -241,8 +239,7 @@ class SyncHandler(object):
|
|||
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_sync_for_user(
|
||||
async def wait_for_sync_for_user(
|
||||
self, sync_config, since_token=None, timeout=0, full_state=False
|
||||
):
|
||||
"""Get the sync for a client if we have new data for it now. Otherwise
|
||||
|
@ -255,9 +252,9 @@ class SyncHandler(object):
|
|||
# not been exceeded (if not part of the group by this point, almost certain
|
||||
# auth_blocking will occur)
|
||||
user_id = sync_config.user.to_string()
|
||||
yield self.auth.check_auth_blocking(user_id)
|
||||
await self.auth.check_auth_blocking(user_id)
|
||||
|
||||
res = yield self.response_cache.wrap(
|
||||
res = await self.response_cache.wrap(
|
||||
sync_config.request_key,
|
||||
self._wait_for_sync_for_user,
|
||||
sync_config,
|
||||
|
@ -267,8 +264,9 @@ class SyncHandler(object):
|
|||
)
|
||||
return res
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state):
|
||||
async def _wait_for_sync_for_user(
|
||||
self, sync_config, since_token, timeout, full_state
|
||||
):
|
||||
if since_token is None:
|
||||
sync_type = "initial_sync"
|
||||
elif full_state:
|
||||
|
@ -283,7 +281,7 @@ class SyncHandler(object):
|
|||
if timeout == 0 or since_token is None or full_state:
|
||||
# we are going to return immediately, so don't bother calling
|
||||
# notifier.wait_for_events.
|
||||
result = yield self.current_sync_for_user(
|
||||
result = await self.current_sync_for_user(
|
||||
sync_config, since_token, full_state=full_state
|
||||
)
|
||||
else:
|
||||
|
@ -291,7 +289,7 @@ class SyncHandler(object):
|
|||
def current_sync_callback(before_token, after_token):
|
||||
return self.current_sync_for_user(sync_config, since_token)
|
||||
|
||||
result = yield self.notifier.wait_for_events(
|
||||
result = await self.notifier.wait_for_events(
|
||||
sync_config.user.to_string(),
|
||||
timeout,
|
||||
current_sync_callback,
|
||||
|
@ -314,15 +312,13 @@ class SyncHandler(object):
|
|||
"""
|
||||
return self.generate_sync_result(sync_config, since_token, full_state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push_rules_for_user(self, user):
|
||||
async def push_rules_for_user(self, user):
|
||||
user_id = user.to_string()
|
||||
rules = yield self.store.get_push_rules_for_user(user_id)
|
||||
rules = await self.store.get_push_rules_for_user(user_id)
|
||||
rules = format_push_rules_for_user(user, rules)
|
||||
return rules
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
|
||||
async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
|
||||
"""Get the ephemeral events for each room the user is in
|
||||
Args:
|
||||
sync_result_builder(SyncResultBuilder)
|
||||
|
@ -343,7 +339,7 @@ class SyncHandler(object):
|
|||
room_ids = sync_result_builder.joined_room_ids
|
||||
|
||||
typing_source = self.event_sources.sources["typing"]
|
||||
typing, typing_key = yield typing_source.get_new_events(
|
||||
typing, typing_key = typing_source.get_new_events(
|
||||
user=sync_config.user,
|
||||
from_key=typing_key,
|
||||
limit=sync_config.filter_collection.ephemeral_limit(),
|
||||
|
@ -365,7 +361,7 @@ class SyncHandler(object):
|
|||
receipt_key = since_token.receipt_key if since_token else "0"
|
||||
|
||||
receipt_source = self.event_sources.sources["receipt"]
|
||||
receipts, receipt_key = yield receipt_source.get_new_events(
|
||||
receipts, receipt_key = await receipt_source.get_new_events(
|
||||
user=sync_config.user,
|
||||
from_key=receipt_key,
|
||||
limit=sync_config.filter_collection.ephemeral_limit(),
|
||||
|
@ -382,8 +378,7 @@ class SyncHandler(object):
|
|||
|
||||
return now_token, ephemeral_by_room
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _load_filtered_recents(
|
||||
async def _load_filtered_recents(
|
||||
self,
|
||||
room_id,
|
||||
sync_config,
|
||||
|
@ -415,10 +410,10 @@ class SyncHandler(object):
|
|||
# ensure that we always include current state in the timeline
|
||||
current_state_ids = frozenset()
|
||||
if any(e.is_state() for e in recents):
|
||||
current_state_ids = yield self.state.get_current_state_ids(room_id)
|
||||
current_state_ids = await self.state.get_current_state_ids(room_id)
|
||||
current_state_ids = frozenset(itervalues(current_state_ids))
|
||||
|
||||
recents = yield filter_events_for_client(
|
||||
recents = await filter_events_for_client(
|
||||
self.storage,
|
||||
sync_config.user.to_string(),
|
||||
recents,
|
||||
|
@ -449,14 +444,14 @@ class SyncHandler(object):
|
|||
# Otherwise, we want to return the last N events in the room
|
||||
# in toplogical ordering.
|
||||
if since_key:
|
||||
events, end_key = yield self.store.get_room_events_stream_for_room(
|
||||
events, end_key = await self.store.get_room_events_stream_for_room(
|
||||
room_id,
|
||||
limit=load_limit + 1,
|
||||
from_key=since_key,
|
||||
to_key=end_key,
|
||||
)
|
||||
else:
|
||||
events, end_key = yield self.store.get_recent_events_for_room(
|
||||
events, end_key = await self.store.get_recent_events_for_room(
|
||||
room_id, limit=load_limit + 1, end_token=end_key
|
||||
)
|
||||
loaded_recents = sync_config.filter_collection.filter_room_timeline(
|
||||
|
@ -468,10 +463,10 @@ class SyncHandler(object):
|
|||
# ensure that we always include current state in the timeline
|
||||
current_state_ids = frozenset()
|
||||
if any(e.is_state() for e in loaded_recents):
|
||||
current_state_ids = yield self.state.get_current_state_ids(room_id)
|
||||
current_state_ids = await self.state.get_current_state_ids(room_id)
|
||||
current_state_ids = frozenset(itervalues(current_state_ids))
|
||||
|
||||
loaded_recents = yield filter_events_for_client(
|
||||
loaded_recents = await filter_events_for_client(
|
||||
self.storage,
|
||||
sync_config.user.to_string(),
|
||||
loaded_recents,
|
||||
|
@ -498,8 +493,7 @@ class SyncHandler(object):
|
|||
limited=limited or newly_joined_room,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_after_event(self, event, state_filter=StateFilter.all()):
|
||||
async def get_state_after_event(self, event, state_filter=StateFilter.all()):
|
||||
"""
|
||||
Get the room state after the given event
|
||||
|
||||
|
@ -511,7 +505,7 @@ class SyncHandler(object):
|
|||
Returns:
|
||||
A Deferred map from ((type, state_key)->Event)
|
||||
"""
|
||||
state_ids = yield self.state_store.get_state_ids_for_event(
|
||||
state_ids = await self.state_store.get_state_ids_for_event(
|
||||
event.event_id, state_filter=state_filter
|
||||
)
|
||||
if event.is_state():
|
||||
|
@ -519,8 +513,9 @@ class SyncHandler(object):
|
|||
state_ids[(event.type, event.state_key)] = event.event_id
|
||||
return state_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
|
||||
async def get_state_at(
|
||||
self, room_id, stream_position, state_filter=StateFilter.all()
|
||||
):
|
||||
""" Get the room state at a particular stream position
|
||||
|
||||
Args:
|
||||
|
@ -536,13 +531,13 @@ class SyncHandler(object):
|
|||
# get_recent_events_for_room operates by topo ordering. This therefore
|
||||
# does not reliably give you the state at the given stream position.
|
||||
# (https://github.com/matrix-org/synapse/issues/3305)
|
||||
last_events, _ = yield self.store.get_recent_events_for_room(
|
||||
last_events, _ = await self.store.get_recent_events_for_room(
|
||||
room_id, end_token=stream_position.room_key, limit=1
|
||||
)
|
||||
|
||||
if last_events:
|
||||
last_event = last_events[-1]
|
||||
state = yield self.get_state_after_event(
|
||||
state = await self.get_state_after_event(
|
||||
last_event, state_filter=state_filter
|
||||
)
|
||||
|
||||
|
@ -551,8 +546,7 @@ class SyncHandler(object):
|
|||
state = {}
|
||||
return state
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def compute_summary(self, room_id, sync_config, batch, state, now_token):
|
||||
async def compute_summary(self, room_id, sync_config, batch, state, now_token):
|
||||
""" Works out a room summary block for this room, summarising the number
|
||||
of joined members in the room, and providing the 'hero' members if the
|
||||
room has no name so clients can consistently name rooms. Also adds
|
||||
|
@ -574,7 +568,7 @@ class SyncHandler(object):
|
|||
# FIXME: we could/should get this from room_stats when matthew/stats lands
|
||||
|
||||
# FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
|
||||
last_events, _ = yield self.store.get_recent_event_ids_for_room(
|
||||
last_events, _ = await self.store.get_recent_event_ids_for_room(
|
||||
room_id, end_token=now_token.room_key, limit=1
|
||||
)
|
||||
|
||||
|
@ -582,7 +576,7 @@ class SyncHandler(object):
|
|||
return None
|
||||
|
||||
last_event = last_events[-1]
|
||||
state_ids = yield self.state_store.get_state_ids_for_event(
|
||||
state_ids = await self.state_store.get_state_ids_for_event(
|
||||
last_event.event_id,
|
||||
state_filter=StateFilter.from_types(
|
||||
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
|
||||
|
@ -590,7 +584,7 @@ class SyncHandler(object):
|
|||
)
|
||||
|
||||
# this is heavily cached, thus: fast.
|
||||
details = yield self.store.get_room_summary(room_id)
|
||||
details = await self.store.get_room_summary(room_id)
|
||||
|
||||
name_id = state_ids.get((EventTypes.Name, ""))
|
||||
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
|
||||
|
@ -608,12 +602,12 @@ class SyncHandler(object):
|
|||
# calculating heroes. Empty strings are falsey, so we check
|
||||
# for the "name" value and default to an empty string.
|
||||
if name_id:
|
||||
name = yield self.store.get_event(name_id, allow_none=True)
|
||||
name = await self.store.get_event(name_id, allow_none=True)
|
||||
if name and name.content.get("name"):
|
||||
return summary
|
||||
|
||||
if canonical_alias_id:
|
||||
canonical_alias = yield self.store.get_event(
|
||||
canonical_alias = await self.store.get_event(
|
||||
canonical_alias_id, allow_none=True
|
||||
)
|
||||
if canonical_alias and canonical_alias.content.get("alias"):
|
||||
|
@ -678,7 +672,7 @@ class SyncHandler(object):
|
|||
)
|
||||
]
|
||||
|
||||
missing_hero_state = yield self.store.get_events(missing_hero_event_ids)
|
||||
missing_hero_state = await self.store.get_events(missing_hero_event_ids)
|
||||
missing_hero_state = missing_hero_state.values()
|
||||
|
||||
for s in missing_hero_state:
|
||||
|
@ -697,8 +691,7 @@ class SyncHandler(object):
|
|||
logger.debug("found LruCache for %r", cache_key)
|
||||
return cache
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def compute_state_delta(
|
||||
async def compute_state_delta(
|
||||
self, room_id, batch, sync_config, since_token, now_token, full_state
|
||||
):
|
||||
""" Works out the difference in state between the start of the timeline
|
||||
|
@ -759,16 +752,16 @@ class SyncHandler(object):
|
|||
|
||||
if full_state:
|
||||
if batch:
|
||||
current_state_ids = yield self.state_store.get_state_ids_for_event(
|
||||
current_state_ids = await self.state_store.get_state_ids_for_event(
|
||||
batch.events[-1].event_id, state_filter=state_filter
|
||||
)
|
||||
|
||||
state_ids = yield self.state_store.get_state_ids_for_event(
|
||||
state_ids = await self.state_store.get_state_ids_for_event(
|
||||
batch.events[0].event_id, state_filter=state_filter
|
||||
)
|
||||
|
||||
else:
|
||||
current_state_ids = yield self.get_state_at(
|
||||
current_state_ids = await self.get_state_at(
|
||||
room_id, stream_position=now_token, state_filter=state_filter
|
||||
)
|
||||
|
||||
|
@ -783,13 +776,13 @@ class SyncHandler(object):
|
|||
)
|
||||
elif batch.limited:
|
||||
if batch:
|
||||
state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
|
||||
state_at_timeline_start = await self.state_store.get_state_ids_for_event(
|
||||
batch.events[0].event_id, state_filter=state_filter
|
||||
)
|
||||
else:
|
||||
# We can get here if the user has ignored the senders of all
|
||||
# the recent events.
|
||||
state_at_timeline_start = yield self.get_state_at(
|
||||
state_at_timeline_start = await self.get_state_at(
|
||||
room_id, stream_position=now_token, state_filter=state_filter
|
||||
)
|
||||
|
||||
|
@ -807,19 +800,19 @@ class SyncHandler(object):
|
|||
# about them).
|
||||
state_filter = StateFilter.all()
|
||||
|
||||
state_at_previous_sync = yield self.get_state_at(
|
||||
state_at_previous_sync = await self.get_state_at(
|
||||
room_id, stream_position=since_token, state_filter=state_filter
|
||||
)
|
||||
|
||||
if batch:
|
||||
current_state_ids = yield self.state_store.get_state_ids_for_event(
|
||||
current_state_ids = await self.state_store.get_state_ids_for_event(
|
||||
batch.events[-1].event_id, state_filter=state_filter
|
||||
)
|
||||
else:
|
||||
# Its not clear how we get here, but empirically we do
|
||||
# (#5407). Logging has been added elsewhere to try and
|
||||
# figure out where this state comes from.
|
||||
current_state_ids = yield self.get_state_at(
|
||||
current_state_ids = await self.get_state_at(
|
||||
room_id, stream_position=now_token, state_filter=state_filter
|
||||
)
|
||||
|
||||
|
@ -843,7 +836,7 @@ class SyncHandler(object):
|
|||
# So we fish out all the member events corresponding to the
|
||||
# timeline here, and then dedupe any redundant ones below.
|
||||
|
||||
state_ids = yield self.state_store.get_state_ids_for_event(
|
||||
state_ids = await self.state_store.get_state_ids_for_event(
|
||||
batch.events[0].event_id,
|
||||
# we only want members!
|
||||
state_filter=StateFilter.from_types(
|
||||
|
@ -883,7 +876,7 @@ class SyncHandler(object):
|
|||
|
||||
state = {}
|
||||
if state_ids:
|
||||
state = yield self.store.get_events(list(state_ids.values()))
|
||||
state = await self.store.get_events(list(state_ids.values()))
|
||||
|
||||
return {
|
||||
(e.type, e.state_key): e
|
||||
|
@ -892,10 +885,9 @@ class SyncHandler(object):
|
|||
)
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def unread_notifs_for_room_id(self, room_id, sync_config):
|
||||
async def unread_notifs_for_room_id(self, room_id, sync_config):
|
||||
with Measure(self.clock, "unread_notifs_for_room_id"):
|
||||
last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
|
||||
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
|
||||
user_id=sync_config.user.to_string(),
|
||||
room_id=room_id,
|
||||
receipt_type="m.read",
|
||||
|
@ -903,7 +895,7 @@ class SyncHandler(object):
|
|||
|
||||
notifs = []
|
||||
if last_unread_event_id:
|
||||
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
|
||||
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
|
||||
room_id, sync_config.user.to_string(), last_unread_event_id
|
||||
)
|
||||
return notifs
|
||||
|
@ -912,8 +904,9 @@ class SyncHandler(object):
|
|||
# count is whatever it was last time.
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_sync_result(self, sync_config, since_token=None, full_state=False):
|
||||
async def generate_sync_result(
|
||||
self, sync_config, since_token=None, full_state=False
|
||||
):
|
||||
"""Generates a sync result.
|
||||
|
||||
Args:
|
||||
|
@ -928,7 +921,7 @@ class SyncHandler(object):
|
|||
# this is due to some of the underlying streams not supporting the ability
|
||||
# to query up to a given point.
|
||||
# Always use the `now_token` in `SyncResultBuilder`
|
||||
now_token = yield self.event_sources.get_current_token()
|
||||
now_token = await self.event_sources.get_current_token()
|
||||
|
||||
logger.info(
|
||||
"Calculating sync response for %r between %s and %s",
|
||||
|
@ -944,10 +937,9 @@ class SyncHandler(object):
|
|||
# See https://github.com/matrix-org/matrix-doc/issues/1144
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
joined_room_ids = yield self.get_rooms_for_user_at(
|
||||
joined_room_ids = await self.get_rooms_for_user_at(
|
||||
user_id, now_token.room_stream_id
|
||||
)
|
||||
|
||||
sync_result_builder = SyncResultBuilder(
|
||||
sync_config,
|
||||
full_state,
|
||||
|
@ -956,11 +948,11 @@ class SyncHandler(object):
|
|||
joined_room_ids=joined_room_ids,
|
||||
)
|
||||
|
||||
account_data_by_room = yield self._generate_sync_entry_for_account_data(
|
||||
account_data_by_room = await self._generate_sync_entry_for_account_data(
|
||||
sync_result_builder
|
||||
)
|
||||
|
||||
res = yield self._generate_sync_entry_for_rooms(
|
||||
res = await self._generate_sync_entry_for_rooms(
|
||||
sync_result_builder, account_data_by_room
|
||||
)
|
||||
newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
|
||||
|
@ -970,13 +962,13 @@ class SyncHandler(object):
|
|||
since_token is None and sync_config.filter_collection.blocks_all_presence()
|
||||
)
|
||||
if self.hs_config.use_presence and not block_all_presence_data:
|
||||
yield self._generate_sync_entry_for_presence(
|
||||
await self._generate_sync_entry_for_presence(
|
||||
sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
|
||||
)
|
||||
|
||||
yield self._generate_sync_entry_for_to_device(sync_result_builder)
|
||||
await self._generate_sync_entry_for_to_device(sync_result_builder)
|
||||
|
||||
device_lists = yield self._generate_sync_entry_for_device_list(
|
||||
device_lists = await self._generate_sync_entry_for_device_list(
|
||||
sync_result_builder,
|
||||
newly_joined_rooms=newly_joined_rooms,
|
||||
newly_joined_or_invited_users=newly_joined_or_invited_users,
|
||||
|
@ -987,11 +979,11 @@ class SyncHandler(object):
|
|||
device_id = sync_config.device_id
|
||||
one_time_key_counts = {}
|
||||
if device_id:
|
||||
one_time_key_counts = yield self.store.count_e2e_one_time_keys(
|
||||
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
||||
user_id, device_id
|
||||
)
|
||||
|
||||
yield self._generate_sync_entry_for_groups(sync_result_builder)
|
||||
await self._generate_sync_entry_for_groups(sync_result_builder)
|
||||
|
||||
# debug for https://github.com/matrix-org/synapse/issues/4422
|
||||
for joined_room in sync_result_builder.joined:
|
||||
|
@ -1015,18 +1007,17 @@ class SyncHandler(object):
|
|||
)
|
||||
|
||||
@measure_func("_generate_sync_entry_for_groups")
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_groups(self, sync_result_builder):
|
||||
async def _generate_sync_entry_for_groups(self, sync_result_builder):
|
||||
user_id = sync_result_builder.sync_config.user.to_string()
|
||||
since_token = sync_result_builder.since_token
|
||||
now_token = sync_result_builder.now_token
|
||||
|
||||
if since_token and since_token.groups_key:
|
||||
results = yield self.store.get_groups_changes_for_user(
|
||||
results = self.store.get_groups_changes_for_user(
|
||||
user_id, since_token.groups_key, now_token.groups_key
|
||||
)
|
||||
else:
|
||||
results = yield self.store.get_all_groups_for_user(
|
||||
results = await self.store.get_all_groups_for_user(
|
||||
user_id, now_token.groups_key
|
||||
)
|
||||
|
||||
|
@ -1059,8 +1050,7 @@ class SyncHandler(object):
|
|||
)
|
||||
|
||||
@measure_func("_generate_sync_entry_for_device_list")
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_device_list(
|
||||
async def _generate_sync_entry_for_device_list(
|
||||
self,
|
||||
sync_result_builder,
|
||||
newly_joined_rooms,
|
||||
|
@ -1108,32 +1098,32 @@ class SyncHandler(object):
|
|||
# room with by looking at all users that have left a room plus users
|
||||
# that were in a room we've left.
|
||||
|
||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||
users_who_share_room = await self.store.get_users_who_share_room_with_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
# Step 1a, check for changes in devices of users we share a room with
|
||||
users_that_have_changed = yield self.store.get_users_whose_devices_changed(
|
||||
users_that_have_changed = await self.store.get_users_whose_devices_changed(
|
||||
since_token.device_list_key, users_who_share_room
|
||||
)
|
||||
|
||||
# Step 1b, check for newly joined rooms
|
||||
for room_id in newly_joined_rooms:
|
||||
joined_users = yield self.state.get_current_users_in_room(room_id)
|
||||
joined_users = await self.state.get_current_users_in_room(room_id)
|
||||
newly_joined_or_invited_users.update(joined_users)
|
||||
|
||||
# TODO: Check that these users are actually new, i.e. either they
|
||||
# weren't in the previous sync *or* they left and rejoined.
|
||||
users_that_have_changed.update(newly_joined_or_invited_users)
|
||||
|
||||
user_signatures_changed = yield self.store.get_users_whose_signatures_changed(
|
||||
user_signatures_changed = await self.store.get_users_whose_signatures_changed(
|
||||
user_id, since_token.device_list_key
|
||||
)
|
||||
users_that_have_changed.update(user_signatures_changed)
|
||||
|
||||
# Now find users that we no longer track
|
||||
for room_id in newly_left_rooms:
|
||||
left_users = yield self.state.get_current_users_in_room(room_id)
|
||||
left_users = await self.state.get_current_users_in_room(room_id)
|
||||
newly_left_users.update(left_users)
|
||||
|
||||
# Remove any users that we still share a room with.
|
||||
|
@ -1143,8 +1133,7 @@ class SyncHandler(object):
|
|||
else:
|
||||
return DeviceLists(changed=[], left=[])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_to_device(self, sync_result_builder):
|
||||
async def _generate_sync_entry_for_to_device(self, sync_result_builder):
|
||||
"""Generates the portion of the sync response. Populates
|
||||
`sync_result_builder` with the result.
|
||||
|
||||
|
@ -1165,14 +1154,14 @@ class SyncHandler(object):
|
|||
# We only delete messages when a new message comes in, but that's
|
||||
# fine so long as we delete them at some point.
|
||||
|
||||
deleted = yield self.store.delete_messages_for_device(
|
||||
deleted = await self.store.delete_messages_for_device(
|
||||
user_id, device_id, since_stream_id
|
||||
)
|
||||
logger.debug(
|
||||
"Deleted %d to-device messages up to %d", deleted, since_stream_id
|
||||
)
|
||||
|
||||
messages, stream_id = yield self.store.get_new_messages_for_device(
|
||||
messages, stream_id = await self.store.get_new_messages_for_device(
|
||||
user_id, device_id, since_stream_id, now_token.to_device_key
|
||||
)
|
||||
|
||||
|
@ -1190,8 +1179,7 @@ class SyncHandler(object):
|
|||
else:
|
||||
sync_result_builder.to_device = []
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_account_data(self, sync_result_builder):
|
||||
async def _generate_sync_entry_for_account_data(self, sync_result_builder):
|
||||
"""Generates the account data portion of the sync response. Populates
|
||||
`sync_result_builder` with the result.
|
||||
|
||||
|
@ -1209,25 +1197,25 @@ class SyncHandler(object):
|
|||
(
|
||||
account_data,
|
||||
account_data_by_room,
|
||||
) = yield self.store.get_updated_account_data_for_user(
|
||||
) = self.store.get_updated_account_data_for_user(
|
||||
user_id, since_token.account_data_key
|
||||
)
|
||||
|
||||
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
|
||||
push_rules_changed = await self.store.have_push_rules_changed_for_user(
|
||||
user_id, int(since_token.push_rules_key)
|
||||
)
|
||||
|
||||
if push_rules_changed:
|
||||
account_data["m.push_rules"] = yield self.push_rules_for_user(
|
||||
account_data["m.push_rules"] = await self.push_rules_for_user(
|
||||
sync_config.user
|
||||
)
|
||||
else:
|
||||
(
|
||||
account_data,
|
||||
account_data_by_room,
|
||||
) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
|
||||
) = await self.store.get_account_data_for_user(sync_config.user.to_string())
|
||||
|
||||
account_data["m.push_rules"] = yield self.push_rules_for_user(
|
||||
account_data["m.push_rules"] = await self.push_rules_for_user(
|
||||
sync_config.user
|
||||
)
|
||||
|
||||
|
@ -1242,8 +1230,7 @@ class SyncHandler(object):
|
|||
|
||||
return account_data_by_room
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_presence(
|
||||
async def _generate_sync_entry_for_presence(
|
||||
self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
|
||||
):
|
||||
"""Generates the presence portion of the sync response. Populates the
|
||||
|
@ -1271,7 +1258,7 @@ class SyncHandler(object):
|
|||
presence_key = None
|
||||
include_offline = False
|
||||
|
||||
presence, presence_key = yield presence_source.get_new_events(
|
||||
presence, presence_key = await presence_source.get_new_events(
|
||||
user=user,
|
||||
from_key=presence_key,
|
||||
is_guest=sync_config.is_guest,
|
||||
|
@ -1283,12 +1270,12 @@ class SyncHandler(object):
|
|||
|
||||
extra_users_ids = set(newly_joined_or_invited_users)
|
||||
for room_id in newly_joined_rooms:
|
||||
users = yield self.state.get_current_users_in_room(room_id)
|
||||
users = await self.state.get_current_users_in_room(room_id)
|
||||
extra_users_ids.update(users)
|
||||
extra_users_ids.discard(user.to_string())
|
||||
|
||||
if extra_users_ids:
|
||||
states = yield self.presence_handler.get_states(extra_users_ids)
|
||||
states = await self.presence_handler.get_states(extra_users_ids)
|
||||
presence.extend(states)
|
||||
|
||||
# Deduplicate the presence entries so that there's at most one per user
|
||||
|
@ -1298,8 +1285,9 @@ class SyncHandler(object):
|
|||
|
||||
sync_result_builder.presence = presence
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room):
|
||||
async def _generate_sync_entry_for_rooms(
|
||||
self, sync_result_builder, account_data_by_room
|
||||
):
|
||||
"""Generates the rooms portion of the sync response. Populates the
|
||||
`sync_result_builder` with the result.
|
||||
|
||||
|
@ -1321,7 +1309,7 @@ class SyncHandler(object):
|
|||
if block_all_room_ephemeral:
|
||||
ephemeral_by_room = {}
|
||||
else:
|
||||
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
|
||||
now_token, ephemeral_by_room = await self.ephemeral_by_room(
|
||||
sync_result_builder,
|
||||
now_token=sync_result_builder.now_token,
|
||||
since_token=sync_result_builder.since_token,
|
||||
|
@ -1333,16 +1321,16 @@ class SyncHandler(object):
|
|||
since_token = sync_result_builder.since_token
|
||||
if not sync_result_builder.full_state:
|
||||
if since_token and not ephemeral_by_room and not account_data_by_room:
|
||||
have_changed = yield self._have_rooms_changed(sync_result_builder)
|
||||
have_changed = await self._have_rooms_changed(sync_result_builder)
|
||||
if not have_changed:
|
||||
tags_by_room = yield self.store.get_updated_tags(
|
||||
tags_by_room = await self.store.get_updated_tags(
|
||||
user_id, since_token.account_data_key
|
||||
)
|
||||
if not tags_by_room:
|
||||
logger.debug("no-oping sync")
|
||||
return [], [], [], []
|
||||
|
||||
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
|
||||
ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
|
||||
"m.ignored_user_list", user_id=user_id
|
||||
)
|
||||
|
||||
|
@ -1352,18 +1340,18 @@ class SyncHandler(object):
|
|||
ignored_users = frozenset()
|
||||
|
||||
if since_token:
|
||||
res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
|
||||
res = await self._get_rooms_changed(sync_result_builder, ignored_users)
|
||||
room_entries, invited, newly_joined_rooms, newly_left_rooms = res
|
||||
|
||||
tags_by_room = yield self.store.get_updated_tags(
|
||||
tags_by_room = await self.store.get_updated_tags(
|
||||
user_id, since_token.account_data_key
|
||||
)
|
||||
else:
|
||||
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
|
||||
res = await self._get_all_rooms(sync_result_builder, ignored_users)
|
||||
room_entries, invited, newly_joined_rooms = res
|
||||
newly_left_rooms = []
|
||||
|
||||
tags_by_room = yield self.store.get_tags_for_user(user_id)
|
||||
tags_by_room = await self.store.get_tags_for_user(user_id)
|
||||
|
||||
def handle_room_entries(room_entry):
|
||||
return self._generate_room_entry(
|
||||
|
@ -1376,7 +1364,7 @@ class SyncHandler(object):
|
|||
always_include=sync_result_builder.full_state,
|
||||
)
|
||||
|
||||
yield concurrently_execute(handle_room_entries, room_entries, 10)
|
||||
await concurrently_execute(handle_room_entries, room_entries, 10)
|
||||
|
||||
sync_result_builder.invited.extend(invited)
|
||||
|
||||
|
@ -1410,8 +1398,7 @@ class SyncHandler(object):
|
|||
newly_left_users,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _have_rooms_changed(self, sync_result_builder):
|
||||
async def _have_rooms_changed(self, sync_result_builder):
|
||||
"""Returns whether there may be any new events that should be sent down
|
||||
the sync. Returns True if there are.
|
||||
"""
|
||||
|
@ -1422,7 +1409,7 @@ class SyncHandler(object):
|
|||
assert since_token
|
||||
|
||||
# Get a list of membership change events that have happened.
|
||||
rooms_changed = yield self.store.get_membership_changes_for_user(
|
||||
rooms_changed = await self.store.get_membership_changes_for_user(
|
||||
user_id, since_token.room_key, now_token.room_key
|
||||
)
|
||||
|
||||
|
@ -1435,8 +1422,7 @@ class SyncHandler(object):
|
|||
return True
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_rooms_changed(self, sync_result_builder, ignored_users):
|
||||
async def _get_rooms_changed(self, sync_result_builder, ignored_users):
|
||||
"""Gets the the changes that have happened since the last sync.
|
||||
|
||||
Args:
|
||||
|
@ -1461,7 +1447,7 @@ class SyncHandler(object):
|
|||
assert since_token
|
||||
|
||||
# Get a list of membership change events that have happened.
|
||||
rooms_changed = yield self.store.get_membership_changes_for_user(
|
||||
rooms_changed = await self.store.get_membership_changes_for_user(
|
||||
user_id, since_token.room_key, now_token.room_key
|
||||
)
|
||||
|
||||
|
@ -1499,11 +1485,11 @@ class SyncHandler(object):
|
|||
continue
|
||||
|
||||
if room_id in sync_result_builder.joined_room_ids or has_join:
|
||||
old_state_ids = yield self.get_state_at(room_id, since_token)
|
||||
old_state_ids = await self.get_state_at(room_id, since_token)
|
||||
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
|
||||
old_mem_ev = None
|
||||
if old_mem_ev_id:
|
||||
old_mem_ev = yield self.store.get_event(
|
||||
old_mem_ev = await self.store.get_event(
|
||||
old_mem_ev_id, allow_none=True
|
||||
)
|
||||
|
||||
|
@ -1536,13 +1522,13 @@ class SyncHandler(object):
|
|||
newly_left_rooms.append(room_id)
|
||||
else:
|
||||
if not old_state_ids:
|
||||
old_state_ids = yield self.get_state_at(room_id, since_token)
|
||||
old_state_ids = await self.get_state_at(room_id, since_token)
|
||||
old_mem_ev_id = old_state_ids.get(
|
||||
(EventTypes.Member, user_id), None
|
||||
)
|
||||
old_mem_ev = None
|
||||
if old_mem_ev_id:
|
||||
old_mem_ev = yield self.store.get_event(
|
||||
old_mem_ev = await self.store.get_event(
|
||||
old_mem_ev_id, allow_none=True
|
||||
)
|
||||
if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
|
||||
|
@ -1566,7 +1552,7 @@ class SyncHandler(object):
|
|||
|
||||
if leave_events:
|
||||
leave_event = leave_events[-1]
|
||||
leave_stream_token = yield self.store.get_stream_token_for_event(
|
||||
leave_stream_token = await self.store.get_stream_token_for_event(
|
||||
leave_event.event_id
|
||||
)
|
||||
leave_token = since_token.copy_and_replace(
|
||||
|
@ -1603,7 +1589,7 @@ class SyncHandler(object):
|
|||
timeline_limit = sync_config.filter_collection.timeline_limit()
|
||||
|
||||
# Get all events for rooms we're currently joined to.
|
||||
room_to_events = yield self.store.get_room_events_stream_for_rooms(
|
||||
room_to_events = await self.store.get_room_events_stream_for_rooms(
|
||||
room_ids=sync_result_builder.joined_room_ids,
|
||||
from_key=since_token.room_key,
|
||||
to_key=now_token.room_key,
|
||||
|
@ -1652,8 +1638,7 @@ class SyncHandler(object):
|
|||
|
||||
return room_entries, invited, newly_joined_rooms, newly_left_rooms
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_all_rooms(self, sync_result_builder, ignored_users):
|
||||
async def _get_all_rooms(self, sync_result_builder, ignored_users):
|
||||
"""Returns entries for all rooms for the user.
|
||||
|
||||
Args:
|
||||
|
@ -1677,7 +1662,7 @@ class SyncHandler(object):
|
|||
Membership.BAN,
|
||||
)
|
||||
|
||||
room_list = yield self.store.get_rooms_for_user_where_membership_is(
|
||||
room_list = await self.store.get_rooms_for_user_where_membership_is(
|
||||
user_id=user_id, membership_list=membership_list
|
||||
)
|
||||
|
||||
|
@ -1700,7 +1685,7 @@ class SyncHandler(object):
|
|||
elif event.membership == Membership.INVITE:
|
||||
if event.sender in ignored_users:
|
||||
continue
|
||||
invite = yield self.store.get_event(event.event_id)
|
||||
invite = await self.store.get_event(event.event_id)
|
||||
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
|
||||
elif event.membership in (Membership.LEAVE, Membership.BAN):
|
||||
# Always send down rooms we were banned or kicked from.
|
||||
|
@ -1726,8 +1711,7 @@ class SyncHandler(object):
|
|||
|
||||
return room_entries, invited, []
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_room_entry(
|
||||
async def _generate_room_entry(
|
||||
self,
|
||||
sync_result_builder,
|
||||
ignored_users,
|
||||
|
@ -1769,7 +1753,7 @@ class SyncHandler(object):
|
|||
since_token = room_builder.since_token
|
||||
upto_token = room_builder.upto_token
|
||||
|
||||
batch = yield self._load_filtered_recents(
|
||||
batch = await self._load_filtered_recents(
|
||||
room_id,
|
||||
sync_config,
|
||||
now_token=upto_token,
|
||||
|
@ -1796,7 +1780,7 @@ class SyncHandler(object):
|
|||
# tag was added by synapse e.g. for server notice rooms.
|
||||
if full_state:
|
||||
user_id = sync_result_builder.sync_config.user.to_string()
|
||||
tags = yield self.store.get_tags_for_room(user_id, room_id)
|
||||
tags = await self.store.get_tags_for_room(user_id, room_id)
|
||||
|
||||
# If there aren't any tags, don't send the empty tags list down
|
||||
# sync
|
||||
|
@ -1821,7 +1805,7 @@ class SyncHandler(object):
|
|||
):
|
||||
return
|
||||
|
||||
state = yield self.compute_state_delta(
|
||||
state = await self.compute_state_delta(
|
||||
room_id, batch, sync_config, since_token, now_token, full_state=full_state
|
||||
)
|
||||
|
||||
|
@ -1844,7 +1828,7 @@ class SyncHandler(object):
|
|||
)
|
||||
or since_token is None
|
||||
):
|
||||
summary = yield self.compute_summary(
|
||||
summary = await self.compute_summary(
|
||||
room_id, sync_config, batch, state, now_token
|
||||
)
|
||||
|
||||
|
@ -1861,7 +1845,7 @@ class SyncHandler(object):
|
|||
)
|
||||
|
||||
if room_sync or always_include:
|
||||
notifs = yield self.unread_notifs_for_room_id(room_id, sync_config)
|
||||
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
|
||||
|
||||
if notifs is not None:
|
||||
unread_notifications["notification_count"] = notifs["notify_count"]
|
||||
|
@ -1887,8 +1871,7 @@ class SyncHandler(object):
|
|||
else:
|
||||
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_for_user_at(self, user_id, stream_ordering):
|
||||
async def get_rooms_for_user_at(self, user_id, stream_ordering):
|
||||
"""Get set of joined rooms for a user at the given stream ordering.
|
||||
|
||||
The stream ordering *must* be recent, otherwise this may throw an
|
||||
|
@ -1903,7 +1886,7 @@ class SyncHandler(object):
|
|||
Deferred[frozenset[str]]: Set of room_ids the user is in at given
|
||||
stream_ordering.
|
||||
"""
|
||||
joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id)
|
||||
joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
|
||||
|
||||
joined_room_ids = set()
|
||||
|
||||
|
@ -1921,10 +1904,10 @@ class SyncHandler(object):
|
|||
|
||||
logger.info("User joined room after current token: %s", room_id)
|
||||
|
||||
extrems = yield self.store.get_forward_extremeties_for_room(
|
||||
extrems = await self.store.get_forward_extremeties_for_room(
|
||||
room_id, stream_ordering
|
||||
)
|
||||
users_in_room = yield self.state.get_current_users_in_room(room_id, extrems)
|
||||
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
|
||||
if user_id in users_in_room:
|
||||
joined_room_ids.add(room_id)
|
||||
|
||||
|
|
|
@ -304,8 +304,7 @@ class Notifier(object):
|
|||
without waking up any of the normal user event streams"""
|
||||
self.notify_replication()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_events(
|
||||
async def wait_for_events(
|
||||
self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
|
||||
):
|
||||
"""Wait until the callback returns a non empty response or the
|
||||
|
@ -313,9 +312,9 @@ class Notifier(object):
|
|||
"""
|
||||
user_stream = self.user_to_user_stream.get(user_id)
|
||||
if user_stream is None:
|
||||
current_token = yield self.event_sources.get_current_token()
|
||||
current_token = await self.event_sources.get_current_token()
|
||||
if room_ids is None:
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
user_stream = _NotifierUserStream(
|
||||
user_id=user_id,
|
||||
rooms=room_ids,
|
||||
|
@ -344,11 +343,11 @@ class Notifier(object):
|
|||
self.hs.get_reactor(),
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
yield listener.deferred
|
||||
await listener.deferred
|
||||
|
||||
current_token = user_stream.current_token
|
||||
|
||||
result = yield callback(prev_token, current_token)
|
||||
result = await callback(prev_token, current_token)
|
||||
if result:
|
||||
break
|
||||
|
||||
|
@ -364,12 +363,11 @@ class Notifier(object):
|
|||
# This happened if there was no timeout or if the timeout had
|
||||
# already expired.
|
||||
current_token = user_stream.current_token
|
||||
result = yield callback(prev_token, current_token)
|
||||
result = await callback(prev_token, current_token)
|
||||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events_for(
|
||||
async def get_events_for(
|
||||
self,
|
||||
user,
|
||||
pagination_config,
|
||||
|
@ -391,15 +389,14 @@ class Notifier(object):
|
|||
"""
|
||||
from_token = pagination_config.from_token
|
||||
if not from_token:
|
||||
from_token = yield self.event_sources.get_current_token()
|
||||
from_token = await self.event_sources.get_current_token()
|
||||
|
||||
limit = pagination_config.limit
|
||||
|
||||
room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
|
||||
room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
|
||||
is_peeking = not is_joined
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_for_updates(before_token, after_token):
|
||||
async def check_for_updates(before_token, after_token):
|
||||
if not after_token.is_after(before_token):
|
||||
return EventStreamResult([], (from_token, from_token))
|
||||
|
||||
|
@ -415,7 +412,7 @@ class Notifier(object):
|
|||
if only_keys and name not in only_keys:
|
||||
continue
|
||||
|
||||
new_events, new_key = yield source.get_new_events(
|
||||
new_events, new_key = await source.get_new_events(
|
||||
user=user,
|
||||
from_key=getattr(from_token, keyname),
|
||||
limit=limit,
|
||||
|
@ -425,7 +422,7 @@ class Notifier(object):
|
|||
)
|
||||
|
||||
if name == "room":
|
||||
new_events = yield filter_events_for_client(
|
||||
new_events = await filter_events_for_client(
|
||||
self.storage,
|
||||
user.to_string(),
|
||||
new_events,
|
||||
|
@ -461,7 +458,7 @@ class Notifier(object):
|
|||
user_id_for_stream,
|
||||
)
|
||||
|
||||
result = yield self.wait_for_events(
|
||||
result = await self.wait_for_events(
|
||||
user_id_for_stream,
|
||||
timeout,
|
||||
check_for_updates,
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from functools import wraps
|
||||
|
||||
|
@ -64,12 +65,22 @@ def measure_func(name=None):
|
|||
def wrapper(func):
|
||||
block_name = func.__name__ if name is None else name
|
||||
|
||||
@wraps(func)
|
||||
@defer.inlineCallbacks
|
||||
def measured_func(self, *args, **kwargs):
|
||||
with Measure(self.clock, block_name):
|
||||
r = yield func(self, *args, **kwargs)
|
||||
return r
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def measured_func(self, *args, **kwargs):
|
||||
with Measure(self.clock, block_name):
|
||||
r = await func(self, *args, **kwargs)
|
||||
return r
|
||||
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
@defer.inlineCallbacks
|
||||
def measured_func(self, *args, **kwargs):
|
||||
with Measure(self.clock, block_name):
|
||||
r = yield func(self, *args, **kwargs)
|
||||
return r
|
||||
|
||||
return measured_func
|
||||
|
||||
|
|
|
@ -12,54 +12,53 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import Codes, ResourceLimitError
|
||||
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
|
||||
from synapse.handlers.sync import SyncConfig, SyncHandler
|
||||
from synapse.handlers.sync import SyncConfig
|
||||
from synapse.types import UserID
|
||||
|
||||
import tests.unittest
|
||||
import tests.utils
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
|
||||
class SyncTestCase(tests.unittest.TestCase):
|
||||
class SyncTestCase(tests.unittest.HomeserverTestCase):
|
||||
""" Tests Sync Handler. """
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hs = yield setup_test_homeserver(self.addCleanup)
|
||||
self.sync_handler = SyncHandler(self.hs)
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.hs = hs
|
||||
self.sync_handler = self.hs.get_sync_handler()
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_wait_for_sync_for_user_auth_blocking(self):
|
||||
|
||||
user_id1 = "@user1:server"
|
||||
user_id2 = "@user2:server"
|
||||
sync_config = self._generate_sync_config(user_id1)
|
||||
|
||||
self.reactor.advance(100) # So we get not 0 time
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.hs.config.max_mau_value = 1
|
||||
|
||||
# Check that the happy case does not throw errors
|
||||
yield self.store.upsert_monthly_active_user(user_id1)
|
||||
yield self.sync_handler.wait_for_sync_for_user(sync_config)
|
||||
self.get_success(self.store.upsert_monthly_active_user(user_id1))
|
||||
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
|
||||
|
||||
# Test that global lock works
|
||||
self.hs.config.hs_disabled = True
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield self.sync_handler.wait_for_sync_for_user(sync_config)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
e = self.get_failure(
|
||||
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
|
||||
)
|
||||
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
self.hs.config.hs_disabled = False
|
||||
|
||||
sync_config = self._generate_sync_config(user_id2)
|
||||
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield self.sync_handler.wait_for_sync_for_user(sync_config)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
e = self.get_failure(
|
||||
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
|
||||
)
|
||||
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def _generate_sync_config(self, user_id):
|
||||
return SyncConfig(
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
import gc
|
||||
import hashlib
|
||||
import hmac
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
@ -25,7 +26,7 @@ from mock import Mock
|
|||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet.defer import Deferred, succeed
|
||||
from twisted.internet.defer import Deferred, ensureDeferred, succeed
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
@ -415,6 +416,8 @@ class HomeserverTestCase(TestCase):
|
|||
self.reactor.pump([by] * 100)
|
||||
|
||||
def get_success(self, d, by=0.0):
|
||||
if inspect.isawaitable(d):
|
||||
d = ensureDeferred(d)
|
||||
if not isinstance(d, Deferred):
|
||||
return d
|
||||
self.pump(by=by)
|
||||
|
@ -424,6 +427,8 @@ class HomeserverTestCase(TestCase):
|
|||
"""
|
||||
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
|
||||
"""
|
||||
if inspect.isawaitable(d):
|
||||
d = ensureDeferred(d)
|
||||
if not isinstance(d, Deferred):
|
||||
return d
|
||||
self.pump()
|
||||
|
|
Loading…
Reference in a new issue