Port PresenceHandler to async/await (#6991)

This commit is contained in:
Erik Johnston 2020-02-26 15:33:26 +00:00 committed by GitHub
parent 7728d87fd7
commit 1f773eec91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 113 additions and 115 deletions

1
changelog.d/6991.misc Normal file
View file

@ -0,0 +1 @@
Port `synapse.handlers.presence` to async/await.

View file

@ -1016,11 +1016,10 @@ class EventCreationHandler(object):
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
@defer.inlineCallbacks
def _bump_active_time(self, user):
async def _bump_active_time(self, user):
try:
presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user)
await presence.bump_presence_active_time(user)
except Exception:
logger.exception("Error bumping presence active time")

View file

@ -24,11 +24,12 @@ The methods that define policy are:
import logging
from contextlib import contextmanager
from typing import Dict, Set
from typing import Dict, List, Set
from six import iteritems, itervalues
from prometheus_client import Counter
from typing_extensions import ContextManager
from twisted.internet import defer
@ -42,10 +43,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
MYPY = False
if MYPY:
import synapse.server
logger = logging.getLogger(__name__)
@ -97,7 +102,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
self.clock = hs.get_clock()
@ -150,7 +154,7 @@ class PresenceHandler(object):
# Set of users who have presence in the `user_to_current_state` that
# have not yet been persisted
self.unpersisted_users_changes = set()
self.unpersisted_users_changes = set() # type: Set[str]
hs.get_reactor().addSystemEventTrigger(
"before",
@ -160,12 +164,11 @@ class PresenceHandler(object):
self._on_shutdown,
)
self.serial_to_user = {}
self._next_serial = 1
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
self.user_to_num_current_syncs = {}
self.user_to_num_current_syncs = {} # type: Dict[str, int]
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
@ -213,8 +216,7 @@ class PresenceHandler(object):
self._event_pos = self.store.get_current_events_token()
self._event_processing = False
@defer.inlineCallbacks
def _on_shutdown(self):
async def _on_shutdown(self):
"""Gets called when shutting down. This lets us persist any updates that
we haven't yet persisted, e.g. updates that only changes some internal
timers. This allows changes to persist across startup without having to
@ -235,7 +237,7 @@ class PresenceHandler(object):
if self.unpersisted_users_changes:
yield self.store.update_presence(
await self.store.update_presence(
[
self.user_to_current_state[user_id]
for user_id in self.unpersisted_users_changes
@ -243,8 +245,7 @@ class PresenceHandler(object):
)
logger.info("Finished _on_shutdown")
@defer.inlineCallbacks
def _persist_unpersisted_changes(self):
async def _persist_unpersisted_changes(self):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
@ -253,12 +254,11 @@ class PresenceHandler(object):
if unpersisted:
logger.info("Persisting %d unpersisted presence updates", len(unpersisted))
yield self.store.update_presence(
await self.store.update_presence(
[self.user_to_current_state[user_id] for user_id in unpersisted]
)
@defer.inlineCallbacks
def _update_states(self, new_states):
async def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
should be sent to clients/servers.
@ -267,7 +267,7 @@ class PresenceHandler(object):
with Measure(self.clock, "presence_update_states"):
# NOTE: We purposefully don't yield between now and when we've
# NOTE: We purposefully don't await between now and when we've
# calculated what we want to do with the new states, to avoid races.
to_notify = {} # Changes we want to notify everyone about
@ -311,7 +311,7 @@ class PresenceHandler(object):
if to_notify:
notified_presence_counter.inc(len(to_notify))
yield self._persist_and_notify(list(to_notify.values()))
await self._persist_and_notify(list(to_notify.values()))
self.unpersisted_users_changes |= {s.user_id for s in new_states}
self.unpersisted_users_changes -= set(to_notify.keys())
@ -326,7 +326,7 @@ class PresenceHandler(object):
self._push_to_remotes(to_federation_ping.values())
def _handle_timeouts(self):
async def _handle_timeouts(self):
"""Checks the presence of users that have timed out and updates as
appropriate.
"""
@ -368,10 +368,9 @@ class PresenceHandler(object):
now=now,
)
return self._update_states(changes)
return await self._update_states(changes)
@defer.inlineCallbacks
def bump_presence_active_time(self, user):
async def bump_presence_active_time(self, user):
"""We've seen the user do something that indicates they're interacting
with the app.
"""
@ -383,16 +382,17 @@ class PresenceHandler(object):
bump_active_time_counter.inc()
prev_state = yield self.current_state_for_user(user_id)
prev_state = await self.current_state_for_user(user_id)
new_fields = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE
yield self._update_states([prev_state.copy_and_replace(**new_fields)])
await self._update_states([prev_state.copy_and_replace(**new_fields)])
@defer.inlineCallbacks
def user_syncing(self, user_id, affect_presence=True):
async def user_syncing(
self, user_id: str, affect_presence: bool = True
) -> ContextManager[None]:
"""Returns a context manager that should surround any stream requests
from the user.
@ -415,11 +415,11 @@ class PresenceHandler(object):
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
prev_state = yield self.current_state_for_user(user_id)
prev_state = await self.current_state_for_user(user_id)
if prev_state.state == PresenceState.OFFLINE:
# If they're currently offline then bring them online, otherwise
# just update the last sync times.
yield self._update_states(
await self._update_states(
[
prev_state.copy_and_replace(
state=PresenceState.ONLINE,
@ -429,7 +429,7 @@ class PresenceHandler(object):
]
)
else:
yield self._update_states(
await self._update_states(
[
prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec()
@ -437,13 +437,12 @@ class PresenceHandler(object):
]
)
@defer.inlineCallbacks
def _end():
async def _end():
try:
self.user_to_num_current_syncs[user_id] -= 1
prev_state = yield self.current_state_for_user(user_id)
yield self._update_states(
prev_state = await self.current_state_for_user(user_id)
await self._update_states(
[
prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec()
@ -480,8 +479,7 @@ class PresenceHandler(object):
else:
return set()
@defer.inlineCallbacks
def update_external_syncs_row(
async def update_external_syncs_row(
self, process_id, user_id, is_syncing, sync_time_msec
):
"""Update the syncing users for an external process as a delta.
@ -494,8 +492,8 @@ class PresenceHandler(object):
is_syncing (bool): Whether or not the user is now syncing
sync_time_msec(int): Time in ms when the user was last syncing
"""
with (yield self.external_sync_linearizer.queue(process_id)):
prev_state = yield self.current_state_for_user(user_id)
with (await self.external_sync_linearizer.queue(process_id)):
prev_state = await self.current_state_for_user(user_id)
process_presence = self.external_process_to_current_syncs.setdefault(
process_id, set()
@ -525,25 +523,24 @@ class PresenceHandler(object):
process_presence.discard(user_id)
if updates:
yield self._update_states(updates)
await self._update_states(updates)
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
@defer.inlineCallbacks
def update_external_syncs_clear(self, process_id):
async def update_external_syncs_clear(self, process_id):
"""Marks all users that had been marked as syncing by a given process
as offline.
Used when the process has stopped/disappeared.
"""
with (yield self.external_sync_linearizer.queue(process_id)):
with (await self.external_sync_linearizer.queue(process_id)):
process_presence = self.external_process_to_current_syncs.pop(
process_id, set()
)
prev_states = yield self.current_state_for_users(process_presence)
prev_states = await self.current_state_for_users(process_presence)
time_now_ms = self.clock.time_msec()
yield self._update_states(
await self._update_states(
[
prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
for prev_state in itervalues(prev_states)
@ -551,15 +548,13 @@ class PresenceHandler(object):
)
self.external_process_last_updated_ms.pop(process_id, None)
@defer.inlineCallbacks
def current_state_for_user(self, user_id):
async def current_state_for_user(self, user_id):
"""Get the current presence state for a user.
"""
res = yield self.current_state_for_users([user_id])
res = await self.current_state_for_users([user_id])
return res[user_id]
@defer.inlineCallbacks
def current_state_for_users(self, user_ids):
async def current_state_for_users(self, user_ids):
"""Get the current presence state for multiple users.
Returns:
@ -574,7 +569,7 @@ class PresenceHandler(object):
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = yield self.store.get_presence_for_users(missing)
res = await self.store.get_presence_for_users(missing)
states.update(res)
missing = [user_id for user_id, state in iteritems(states) if not state]
@ -587,14 +582,13 @@ class PresenceHandler(object):
return states
@defer.inlineCallbacks
def _persist_and_notify(self, states):
async def _persist_and_notify(self, states):
"""Persist states in the database, poke the notifier and send to
interested remote servers
"""
stream_id, max_token = yield self.store.update_presence(states)
stream_id, max_token = await self.store.update_presence(states)
parties = yield get_interested_parties(self.store, states)
parties = await get_interested_parties(self.store, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@ -606,9 +600,8 @@ class PresenceHandler(object):
self._push_to_remotes(states)
@defer.inlineCallbacks
def notify_for_states(self, state, stream_id):
parties = yield get_interested_parties(self.store, [state])
async def notify_for_states(self, state, stream_id):
parties = await get_interested_parties(self.store, [state])
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@ -626,8 +619,7 @@ class PresenceHandler(object):
"""
self.federation.send_presence(states)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
async def incoming_presence(self, origin, content):
"""Called when we receive a `m.presence` EDU from a remote server.
"""
now = self.clock.time_msec()
@ -670,21 +662,19 @@ class PresenceHandler(object):
new_fields["status_msg"] = push.get("status_msg", None)
new_fields["currently_active"] = push.get("currently_active", False)
prev_state = yield self.current_state_for_user(user_id)
prev_state = await self.current_state_for_user(user_id)
updates.append(prev_state.copy_and_replace(**new_fields))
if updates:
federation_presence_counter.inc(len(updates))
yield self._update_states(updates)
await self._update_states(updates)
@defer.inlineCallbacks
def get_state(self, target_user, as_event=False):
results = yield self.get_states([target_user.to_string()], as_event=as_event)
async def get_state(self, target_user, as_event=False):
results = await self.get_states([target_user.to_string()], as_event=as_event)
return results[0]
@defer.inlineCallbacks
def get_states(self, target_user_ids, as_event=False):
async def get_states(self, target_user_ids, as_event=False):
"""Get the presence state for users.
Args:
@ -695,7 +685,7 @@ class PresenceHandler(object):
list
"""
updates = yield self.current_state_for_users(target_user_ids)
updates = await self.current_state_for_users(target_user_ids)
updates = list(updates.values())
for user_id in set(target_user_ids) - {u.user_id for u in updates}:
@ -713,8 +703,7 @@ class PresenceHandler(object):
else:
return updates
@defer.inlineCallbacks
def set_state(self, target_user, state, ignore_status_msg=False):
async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
status_msg = state.get("status_msg", None)
@ -730,7 +719,7 @@ class PresenceHandler(object):
user_id = target_user.to_string()
prev_state = yield self.current_state_for_user(user_id)
prev_state = await self.current_state_for_user(user_id)
new_fields = {"state": presence}
@ -741,16 +730,15 @@ class PresenceHandler(object):
if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec()
yield self._update_states([prev_state.copy_and_replace(**new_fields)])
await self._update_states([prev_state.copy_and_replace(**new_fields)])
@defer.inlineCallbacks
def is_visible(self, observed_user, observer_user):
async def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence.
"""
observer_room_ids = yield self.store.get_rooms_for_user(
observer_room_ids = await self.store.get_rooms_for_user(
observer_user.to_string()
)
observed_room_ids = yield self.store.get_rooms_for_user(
observed_room_ids = await self.store.get_rooms_for_user(
observed_user.to_string()
)
@ -759,8 +747,7 @@ class PresenceHandler(object):
return False
@defer.inlineCallbacks
def get_all_presence_updates(self, last_id, current_id):
async def get_all_presence_updates(self, last_id, current_id):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@ -775,7 +762,7 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
rows = yield self.store.get_all_presence_updates(last_id, current_id)
rows = await self.store.get_all_presence_updates(last_id, current_id)
return rows
def notify_new_event(self):
@ -786,20 +773,18 @@ class PresenceHandler(object):
if self._event_processing:
return
@defer.inlineCallbacks
def _process_presence():
async def _process_presence():
assert not self._event_processing
self._event_processing = True
try:
yield self._unsafe_process()
await self._unsafe_process()
finally:
self._event_processing = False
run_as_background_process("presence.notify_new_event", _process_presence)
@defer.inlineCallbacks
def _unsafe_process(self):
async def _unsafe_process(self):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "presence_delta"):
@ -812,10 +797,10 @@ class PresenceHandler(object):
self._event_pos,
room_max_stream_ordering,
)
max_pos, deltas = yield self.store.get_current_state_deltas(
max_pos, deltas = await self.store.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
yield self._handle_state_delta(deltas)
await self._handle_state_delta(deltas)
self._event_pos = max_pos
@ -824,8 +809,7 @@ class PresenceHandler(object):
max_pos
)
@defer.inlineCallbacks
def _handle_state_delta(self, deltas):
async def _handle_state_delta(self, deltas):
"""Process current state deltas to find new joins that need to be
handled.
"""
@ -846,13 +830,13 @@ class PresenceHandler(object):
# joins.
continue
event = yield self.store.get_event(event_id, allow_none=True)
event = await self.store.get_event(event_id, allow_none=True)
if not event or event.content.get("membership") != Membership.JOIN:
# We only care about joins
continue
if prev_event_id:
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if (
prev_event
and prev_event.content.get("membership") == Membership.JOIN
@ -860,10 +844,9 @@ class PresenceHandler(object):
# Ignore changes to join events.
continue
yield self._on_user_joined_room(room_id, state_key)
await self._on_user_joined_room(room_id, state_key)
@defer.inlineCallbacks
def _on_user_joined_room(self, room_id, user_id):
async def _on_user_joined_room(self, room_id, user_id):
"""Called when we detect a user joining the room via the current state
delta stream.
@ -882,8 +865,8 @@ class PresenceHandler(object):
# TODO: We should be able to filter the hosts down to those that
# haven't previously seen the user
state = yield self.current_state_for_user(user_id)
hosts = yield self.state.get_current_hosts_in_room(room_id)
state = await self.current_state_for_user(user_id)
hosts = await self.state.get_current_hosts_in_room(room_id)
# Filter out ourselves.
hosts = {host for host in hosts if host != self.server_name}
@ -903,10 +886,10 @@ class PresenceHandler(object):
# TODO: Check that this is actually a new server joining the
# room.
user_ids = yield self.state.get_current_users_in_room(room_id)
user_ids = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
states = yield self.current_state_for_users(user_ids)
states = await self.current_state_for_users(user_ids)
# Filter out old presence, i.e. offline presence states where
# the user hasn't been active for a week. We can change this
@ -996,9 +979,8 @@ class PresenceEventSource(object):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@defer.inlineCallbacks
@log_function
def get_new_events(
async def get_new_events(
self,
user,
from_key,
@ -1045,7 +1027,7 @@ class PresenceEventSource(object):
presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache
users_interested_in = yield self._get_interested_in(user, explicit_room_id)
users_interested_in = await self._get_interested_in(user, explicit_room_id)
user_ids_changed = set()
changed = None
@ -1071,7 +1053,7 @@ class PresenceEventSource(object):
else:
user_ids_changed = users_interested_in
updates = yield presence.current_state_for_users(user_ids_changed)
updates = await presence.current_state_for_users(user_ids_changed)
if include_offline:
return (list(updates.values()), max_token)
@ -1084,11 +1066,11 @@ class PresenceEventSource(object):
def get_current_key(self):
return self.store.get_current_presence_token()
def get_pagination_rows(self, user, pagination_config, key):
return self.get_new_events(user, from_key=None, include_offline=False)
async def get_pagination_rows(self, user, pagination_config, key):
return await self.get_new_events(user, from_key=None, include_offline=False)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_interested_in(self, user, explicit_room_id, cache_context):
@cached(num_args=2, cache_context=True)
async def _get_interested_in(self, user, explicit_room_id, cache_context):
"""Returns the set of users that the given user should see presence
updates for
"""
@ -1096,13 +1078,13 @@ class PresenceEventSource(object):
users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(
user_ids = await self.store.get_users_in_room(
explicit_room_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(user_ids)
@ -1277,8 +1259,8 @@ def get_interested_parties(store, states):
2-tuple: `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]`
"""
room_ids_to_states = {}
users_to_states = {}
room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
users_to_states = {} # type: Dict[str, List[UserPresenceState]]
for state in states:
room_ids = yield store.get_rooms_for_user(state.user_id)
for room_id in room_ids:

View file

@ -323,7 +323,11 @@ class ReplicationStreamer(object):
# We need to tell the presence handler that the connection has been
# lost so that it can handle any ongoing syncs on that connection.
self.presence_handler.update_external_syncs_clear(connection.conn_id)
run_as_background_process(
"update_external_syncs_clear",
self.presence_handler.update_external_syncs_clear,
connection.conn_id,
)
def _batch_updates(updates):

View file

@ -3,6 +3,7 @@ import twisted.internet
import synapse.api.auth
import synapse.config.homeserver
import synapse.crypto.keyring
import synapse.federation.federation_server
import synapse.federation.sender
import synapse.federation.transport.client
import synapse.handlers
@ -107,5 +108,9 @@ class HomeServer(object):
self,
) -> synapse.replication.tcp.client.ReplicationClientHandler:
pass
def get_federation_registry(
self,
) -> synapse.federation.federation_server.FederationHandlerRegistry:
pass
def is_mine_id(self, domain_id: str) -> bool:
pass

View file

@ -494,9 +494,11 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id, "@test2:server")
# Mark test2 as online, test will be offline with a last_active of 0
self.get_success(
self.presence_handler.set_state(
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
)
)
self.reactor.pump([0]) # Wait for presence updates to be handled
#
@ -543,15 +545,19 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
room_id = self.helper.create_room_as(self.user_id)
# Mark test as online
self.get_success(
self.presence_handler.set_state(
UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
)
)
# Mark test2 as online, test will be offline with a last_active of 0.
# Note we don't join them to the room yet
self.get_success(
self.presence_handler.set_state(
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
)
)
# Add servers to the room
self._add_new_user(room_id, "@alice:server2")

View file

@ -183,6 +183,7 @@ commands = mypy \
synapse/events/spamcheck.py \
synapse/federation/sender \
synapse/federation/transport \
synapse/handlers/presence.py \
synapse/handlers/sync.py \
synapse/handlers/ui_auth \
synapse/logging/ \