Convert federation backfill to async

PaginationHandler.get_messages is only called by RoomMessageListRestServlet,
which is async.

Chase the code path down from there:
 - FederationHandler.maybe_backfill (and nested try_backfill)
 - FederationHandler.backfill
This commit is contained in:
Richard van der Hoff 2019-12-10 16:54:34 +00:00
parent 7c429f92d6
commit 7712e751b8
2 changed files with 35 additions and 39 deletions

View file

@ -756,8 +756,7 @@ class FederationHandler(BaseHandler):
yield self.user_joined_room(user, room_id) yield self.user_joined_room(user, room_id)
@log_function @log_function
@defer.inlineCallbacks async def backfill(self, dest, room_id, limit, extremities):
def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id` """ Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side This will attempt to get more events from the remote. If the other side
@ -774,9 +773,9 @@ class FederationHandler(BaseHandler):
if dest == self.server_name: if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.") raise SynapseError(400, "Can't backfill from self.")
room_version = yield self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
events = yield self.federation_client.backfill( events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities dest, room_id, limit=limit, extremities=extremities
) )
@ -791,7 +790,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev) # self._sanity_check_event(ev)
# Don't bother processing events we already have. # Don't bother processing events we already have.
seen_events = yield self.store.have_events_in_timeline( seen_events = await self.store.have_events_in_timeline(
set(e.event_id for e in events) set(e.event_id for e in events)
) )
@ -814,7 +813,7 @@ class FederationHandler(BaseHandler):
state_events = {} state_events = {}
events_to_state = {} events_to_state = {}
for e_id in edges: for e_id in edges:
state, auth = yield self._get_state_for_room( state, auth = await self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id destination=dest, room_id=room_id, event_id=e_id
) )
auth_events.update({a.event_id: a for a in auth}) auth_events.update({a.event_id: a for a in auth})
@ -839,7 +838,7 @@ class FederationHandler(BaseHandler):
# We repeatedly do this until we stop finding new auth events. # We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch: while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth) logger.info("Missing auth for backfill: %r", missing_auth)
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events) auth_events.update(ret_events)
required_auth.update( required_auth.update(
@ -853,7 +852,7 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch, missing_auth - failed_to_fetch,
) )
results = yield make_deferred_yieldable( results = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background( run_in_background(
@ -880,7 +879,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events) failed_to_fetch = missing_auth - set(auth_events)
seen_events = yield self.store.have_seen_events( seen_events = await self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys()) set(auth_events.keys()) | set(state_events.keys())
) )
@ -942,7 +941,7 @@ class FederationHandler(BaseHandler):
) )
) )
yield self._handle_new_events(dest, ev_infos, backfilled=True) await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one # Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth) events.sort(key=lambda e: e.depth)
@ -958,16 +957,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the # We store these one at a time since each event depends on the
# previous to work out the state. # previous to work out the state.
# TODO: We can probably do something more clever here. # TODO: We can probably do something more clever here.
yield self._handle_new_event(dest, event, backfilled=True) await self._handle_new_event(dest, event, backfilled=True)
return events return events
@defer.inlineCallbacks async def maybe_backfill(self, room_id, current_depth):
def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating, """Checks the database to see if we should backfill before paginating,
and if so do. and if so do.
""" """
extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities: if not extremities:
logger.debug("Not backfilling as no extremeties found.") logger.debug("Not backfilling as no extremeties found.")
@ -999,9 +997,9 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event # state *before* the event, ignoring the special casing certain event
# types have. # types have.
forward_events = yield self.store.get_successor_events(list(extremities)) forward_events = await self.store.get_successor_events(list(extremities))
extremities_events = yield self.store.get_events( extremities_events = await self.store.get_events(
forward_events, forward_events,
redact_behaviour=EventRedactBehaviour.AS_IS, redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False, get_prev_content=False,
@ -1009,7 +1007,7 @@ class FederationHandler(BaseHandler):
# We set `check_history_visibility_only` as we might otherwise get false # We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased. # positives from users having been erased.
filtered_extremities = yield filter_events_for_server( filtered_extremities = await filter_events_for_server(
self.storage, self.storage,
self.server_name, self.server_name,
list(extremities_events.values()), list(extremities_events.values()),
@ -1039,7 +1037,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room # First we try hosts that are already in the room
# TODO: HEURISTIC ALERT. # TODO: HEURISTIC ALERT.
curr_state = yield self.state_handler.get_current_state(room_id) curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state): def get_domains_from_state(state):
"""Get joined domains from state """Get joined domains from state
@ -1078,12 +1076,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name domain for domain, depth in curr_domains if domain != self.server_name
] ]
@defer.inlineCallbacks async def try_backfill(domains):
def try_backfill(domains):
# TODO: Should we try multiple of these at a time? # TODO: Should we try multiple of these at a time?
for dom in domains: for dom in domains:
try: try:
yield self.backfill( await self.backfill(
dom, room_id, limit=100, extremities=extremities dom, room_id, limit=100, extremities=extremities
) )
# If this succeeded then we probably already have the # If this succeeded then we probably already have the
@ -1114,7 +1111,7 @@ class FederationHandler(BaseHandler):
return False return False
success = yield try_backfill(likely_domains) success = await try_backfill(likely_domains)
if success: if success:
return True return True
@ -1128,7 +1125,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
states = yield make_deferred_yieldable( states = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
) )
@ -1138,7 +1135,7 @@ class FederationHandler(BaseHandler):
# event_ids. # event_ids.
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)], [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False, get_prev_content=False,
) )
@ -1154,7 +1151,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id]) likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill( success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains] [dom for dom, _ in likely_domains if dom not in tried_domains]
) )
if success: if success:

View file

@ -280,8 +280,7 @@ class PaginationHandler(object):
await self.storage.purge_events.purge_room(room_id) await self.storage.purge_events.purge_room(room_id)
@defer.inlineCallbacks async def get_messages(
def get_messages(
self, self,
requester, requester,
room_id=None, room_id=None,
@ -307,7 +306,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
else: else:
pagin_config.from_token = ( pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token_for_pagination() await self.hs.get_event_sources().get_current_token_for_pagination()
) )
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
@ -319,11 +318,11 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
with (yield self.pagination_lock.read(room_id)): with (await self.pagination_lock.read(room_id)):
( (
membership, membership,
member_event_id, member_event_id,
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b": if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This # if we're going backwards, we might need to backfill. This
@ -331,7 +330,7 @@ class PaginationHandler(object):
if room_token.topological: if room_token.topological:
max_topo = room_token.topological max_topo = room_token.topological
else: else:
max_topo = yield self.store.get_max_topological_token( max_topo = await self.store.get_max_topological_token(
room_id, room_token.stream room_id, room_token.stream
) )
@ -339,18 +338,18 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the # they left the room, to save the effort of loading from the
# database. # database.
leave_token = yield self.store.get_topological_token_for_event( leave_token = await self.store.get_topological_token_for_event(
member_event_id member_event_id
) )
leave_token = RoomStreamToken.parse(leave_token) leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo: if leave_token.topological < max_topo:
source_config.from_key = str(leave_token) source_config.from_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill( await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo room_id, max_topo
) )
events, next_key = yield self.store.paginate_room_events( events, next_key = await self.store.paginate_room_events(
room_id=room_id, room_id=room_id,
from_key=source_config.from_key, from_key=source_config.from_key,
to_key=source_config.to_key, to_key=source_config.to_key,
@ -365,7 +364,7 @@ class PaginationHandler(object):
if event_filter: if event_filter:
events = event_filter.filter(events) events = event_filter.filter(events)
events = yield filter_events_for_client( events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None) self.storage, user_id, events, is_peeking=(member_event_id is None)
) )
@ -385,19 +384,19 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events (EventTypes.Member, event.sender) for event in events
) )
state_ids = yield self.state_store.get_state_ids_for_event( state_ids = await self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter events[0].event_id, state_filter=state_filter
) )
if state_ids: if state_ids:
state = yield self.store.get_events(list(state_ids.values())) state = await self.store.get_events(list(state_ids.values()))
state = state.values() state = state.values()
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
chunk = { chunk = {
"chunk": ( "chunk": (
yield self._event_serializer.serialize_events( await self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event events, time_now, as_client_event=as_client_event
) )
), ),
@ -406,7 +405,7 @@ class PaginationHandler(object):
} }
if state: if state:
chunk["state"] = yield self._event_serializer.serialize_events( chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event state, time_now, as_client_event=as_client_event
) )