0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 03:41:54 +01:00

Add presence federation stream (#9819)

This commit is contained in:
Erik Johnston 2021-04-20 14:11:24 +01:00 committed by GitHub
parent db70435de7
commit de0d088adc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 426 additions and 31 deletions

1
changelog.d/9819.feature Normal file
View file

@ -0,0 +1 @@
Add experimental support for handling presence on a worker.

View file

@ -24,6 +24,7 @@ The methods that define policy are:
import abc import abc
import contextlib import contextlib
import logging import logging
from bisect import bisect
from contextlib import contextmanager from contextlib import contextmanager
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -53,7 +54,9 @@ from synapse.replication.http.presence import (
ReplicationBumpPresenceActiveTime, ReplicationBumpPresenceActiveTime,
ReplicationPresenceSetState, ReplicationPresenceSetState,
) )
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.state import StateHandler from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
@ -128,10 +131,10 @@ class BasePresenceHandler(abc.ABC):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self._federation = None self._federation = None
if hs.should_send_federation() or not hs.config.worker_app: if hs.should_send_federation():
self._federation = hs.get_federation_sender() self._federation = hs.get_federation_sender()
self._send_federation = hs.should_send_federation() self._federation_queue = PresenceFederationQueue(hs, self)
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
@ -254,9 +257,17 @@ class BasePresenceHandler(abc.ABC):
""" """
pass pass
async def process_replication_rows(self, token, rows): async def process_replication_rows(
"""Process presence stream rows received over replication.""" self, stream_name: str, instance_name: str, token: int, rows: list
pass ):
"""Process streams received over replication."""
await self._federation_queue.process_replication_rows(
stream_name, instance_name, token, rows
)
def get_federation_queue(self) -> "PresenceFederationQueue":
"""Get the presence federation queue."""
return self._federation_queue
async def maybe_send_presence_to_interested_destinations( async def maybe_send_presence_to_interested_destinations(
self, states: List[UserPresenceState] self, states: List[UserPresenceState]
@ -266,12 +277,9 @@ class BasePresenceHandler(abc.ABC):
users. users.
""" """
if not self._send_federation: if not self._federation:
return return
# If this worker sends federation we must have a FederationSender.
assert self._federation
states = [s for s in states if self.is_mine_id(s.user_id)] states = [s for s in states if self.is_mine_id(s.user_id)]
if not states: if not states:
@ -427,7 +435,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
# If this is a federation sender, notify about presence updates. # If this is a federation sender, notify about presence updates.
await self.maybe_send_presence_to_interested_destinations(states) await self.maybe_send_presence_to_interested_destinations(states)
async def process_replication_rows(self, token, rows): async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
await super().process_replication_rows(stream_name, instance_name, token, rows)
if stream_name != PresenceStream.NAME:
return
states = [ states = [
UserPresenceState( UserPresenceState(
row.user_id, row.user_id,
@ -729,12 +744,10 @@ class PresenceHandler(BasePresenceHandler):
self.state, self.state,
) )
# Since this is master we know that we have a federation sender or
# queue, and so this will be defined.
assert self._federation
for destinations, states in hosts_and_states: for destinations, states in hosts_and_states:
self._federation.send_presence_to_destinations(states, destinations) self._federation_queue.send_presence_to_destinations(
states, destinations
)
async def _handle_timeouts(self): async def _handle_timeouts(self):
"""Checks the presence of users that have timed out and updates as """Checks the presence of users that have timed out and updates as
@ -1213,13 +1226,9 @@ class PresenceHandler(BasePresenceHandler):
user_presence_states user_presence_states
) )
# Since this is master we know that we have a federation sender or
# queue, and so this will be defined.
assert self._federation
# Send out user presence updates for each destination # Send out user presence updates for each destination
for destination, user_state_set in presence_destinations.items(): for destination, user_state_set in presence_destinations.items():
self._federation.send_presence_to_destinations( self._federation_queue.send_presence_to_destinations(
destinations=[destination], states=user_state_set destinations=[destination], states=user_state_set
) )
@ -1864,3 +1873,197 @@ async def get_interested_remotes(
hosts_and_states.append(([host], states)) hosts_and_states.append(([host], states))
return hosts_and_states return hosts_and_states
class PresenceFederationQueue:
"""Handles sending ad hoc presence updates over federation, which are *not*
due to state updates (that get handled via the presence stream), e.g.
federation pings and sending existing present states to newly joined hosts.
Only the last N minutes will be queued, so if a federation sender instance
is down for longer then some updates will be dropped. This is OK as presence
is ephemeral, and so it will self correct eventually.
On workers the class tracks the last received position of the stream from
replication, and handles querying for missed updates over HTTP replication,
c.f. `get_current_token` and `get_replication_rows`.
"""
# How long to keep entries in the queue for. Workers that are down for
# longer than this duration will miss out on older updates.
_KEEP_ITEMS_IN_QUEUE_FOR_MS = 5 * 60 * 1000
# How often to check if we can expire entries from the queue.
_CLEAR_ITEMS_EVERY_MS = 60 * 1000
def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler):
self._clock = hs.get_clock()
self._notifier = hs.get_notifier()
self._instance_name = hs.get_instance_name()
self._presence_handler = presence_handler
self._repl_client = ReplicationGetStreamUpdates.make_client(hs)
# Should we keep a queue of recent presence updates? We only bother if
# another process may be handling federation sending.
self._queue_presence_updates = True
# Whether this instance is a presence writer.
self._presence_writer = hs.config.worker.worker_app is None
# The FederationSender instance, if this process sends federation traffic directly.
self._federation = None
if hs.should_send_federation():
self._federation = hs.get_federation_sender()
# We don't bother queuing up presence states if only this instance
# is sending federation.
if hs.config.worker.federation_shard_config.instances == [
self._instance_name
]:
self._queue_presence_updates = False
# The queue of recently queued updates as tuples of: `(timestamp,
# stream_id, destinations, user_ids)`. We don't store the full states
# for efficiency, and remote workers will already have the full states
# cached.
self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]]
self._next_id = 1
# Map from instance name to current token
self._current_tokens = {} # type: Dict[str, int]
if self._queue_presence_updates:
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
def _clear_queue(self):
"""Clear out older entries from the queue."""
clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
# The queue is sorted by timestamp, so we can bisect to find the right
# place to purge before. Note that we are searching using a 1-tuple with
# the time, which does The Right Thing since the queue is a tuple where
# the first item is a timestamp.
index = bisect(self._queue, (clear_before,))
self._queue = self._queue[index:]
def send_presence_to_destinations(
self, states: Collection[UserPresenceState], destinations: Collection[str]
) -> None:
"""Send the presence states to the given destinations.
Will forward to the local federation sender (if there is one) and queue
to send over replication (if there are other federation sender instances.).
Must only be called on the master process.
"""
# This should only be called on a presence writer.
assert self._presence_writer
if self._federation:
self._federation.send_presence_to_destinations(
states=states,
destinations=destinations,
)
if not self._queue_presence_updates:
return
now = self._clock.time_msec()
stream_id = self._next_id
self._next_id += 1
self._queue.append((now, stream_id, destinations, {s.user_id for s in states}))
self._notifier.notify_replication()
def get_current_token(self, instance_name: str) -> int:
"""Get the current position of the stream.
On workers this returns the last stream ID received from replication.
"""
if instance_name == self._instance_name:
return self._next_id - 1
else:
return self._current_tokens.get(instance_name, 0)
async def get_replication_rows(
self,
instance_name: str,
from_token: int,
upto_token: int,
target_row_count: int,
) -> Tuple[List[Tuple[int, Tuple[str, str]]], int, bool]:
"""Get all the updates between the two tokens.
We return rows in the form of `(destination, user_id)` to keep the size
of each row bounded (rather than returning the sets in a row).
On workers this will query the master process via HTTP replication.
"""
if instance_name != self._instance_name:
# If not local we query over http replication from the master
result = await self._repl_client(
instance_name=instance_name,
stream_name=PresenceFederationStream.NAME,
from_token=from_token,
upto_token=upto_token,
)
return result["updates"], result["upto_token"], result["limited"]
# We can find the correct position in the queue by noting that there is
# exactly one entry per stream ID, and that the last entry has an ID of
# `self._next_id - 1`, so we can count backwards from the end.
#
# Since the start of the queue is periodically truncated we need to
# handle the case where `from_token` stream ID has already been dropped.
start_idx = max(from_token - self._next_id, -len(self._queue))
to_send = [] # type: List[Tuple[int, Tuple[str, str]]]
limited = False
new_id = upto_token
for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
if stream_id > upto_token:
break
new_id = stream_id
to_send.extend(
(stream_id, (destination, user_id))
for destination in destinations
for user_id in user_ids
)
if len(to_send) > target_row_count:
limited = True
break
return to_send, new_id, limited
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
if stream_name != PresenceFederationStream.NAME:
return
# We keep track of the current tokens (so that we can catch up with anything we missed after a disconnect)
self._current_tokens[instance_name] = token
# If we're a federation sender we pull out the presence states to send
# and forward them on.
if not self._federation:
return
hosts_to_users = {} # type: Dict[str, Set[str]]
for row in rows:
hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
for host, user_ids in hosts_to_users.items():
states = await self._presence_handler.current_state_for_users(user_ids)
self._federation.send_presence_to_destinations(
states=states.values(),
destinations=[host],
)

View file

@ -29,7 +29,6 @@ from synapse.replication.tcp.streams import (
AccountDataStream, AccountDataStream,
DeviceListsStream, DeviceListsStream,
GroupServerStream, GroupServerStream,
PresenceStream,
PushersStream, PushersStream,
PushRulesStream, PushRulesStream,
ReceiptsStream, ReceiptsStream,
@ -191,8 +190,6 @@ class ReplicationDataHandler:
self.stop_pusher(row.user_id, row.app_id, row.pushkey) self.stop_pusher(row.user_id, row.app_id, row.pushkey)
else: else:
await self.start_pusher(row.user_id, row.app_id, row.pushkey) await self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == PresenceStream.NAME:
await self._presence_handler.process_replication_rows(token, rows)
elif stream_name == EventsStream.NAME: elif stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so # We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows. # we don't need to optimise this for multiple rows.
@ -221,6 +218,10 @@ class ReplicationDataHandler:
membership=row.data.membership, membership=row.data.membership,
) )
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
)
# Notify any waiting deferreds. The list is ordered by position so we # Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is # just iterate through the list until we reach a position that is
# greater than the received row position. # greater than the received row position.

View file

@ -30,6 +30,7 @@ from synapse.replication.tcp.streams._base import (
CachesStream, CachesStream,
DeviceListsStream, DeviceListsStream,
GroupServerStream, GroupServerStream,
PresenceFederationStream,
PresenceStream, PresenceStream,
PublicRoomsStream, PublicRoomsStream,
PushersStream, PushersStream,
@ -50,6 +51,7 @@ STREAMS_MAP = {
EventsStream, EventsStream,
BackfillStream, BackfillStream,
PresenceStream, PresenceStream,
PresenceFederationStream,
TypingStream, TypingStream,
ReceiptsStream, ReceiptsStream,
PushRulesStream, PushRulesStream,
@ -71,6 +73,7 @@ __all__ = [
"Stream", "Stream",
"BackfillStream", "BackfillStream",
"PresenceStream", "PresenceStream",
"PresenceFederationStream",
"TypingStream", "TypingStream",
"ReceiptsStream", "ReceiptsStream",
"PushRulesStream", "PushRulesStream",

View file

@ -290,6 +290,30 @@ class PresenceStream(Stream):
) )
class PresenceFederationStream(Stream):
"""A stream used to send ad hoc presence updates over federation.
Streams the remote destination and the user ID of the presence state to
send.
"""
@attr.s(slots=True, auto_attribs=True)
class PresenceFederationStreamRow:
destination: str
user_id: str
NAME = "presence_federation"
ROW_TYPE = PresenceFederationStreamRow
def __init__(self, hs: "HomeServer"):
federation_queue = hs.get_presence_handler().get_federation_queue()
super().__init__(
hs.get_instance_name(),
federation_queue.get_current_token,
federation_queue.get_replication_rows,
)
class TypingStream(Stream): class TypingStream(Stream):
TypingStreamRow = namedtuple( TypingStreamRow = namedtuple(
"TypingStreamRow", ("room_id", "user_ids") # str # list(str) "TypingStreamRow", ("room_id", "user_ids") # str # list(str)

View file

@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder from synapse.events.builder import EventBuilder
from synapse.federation.sender import FederationSender
from synapse.handlers.presence import ( from synapse.handlers.presence import (
EXTERNAL_PROCESS_EXPIRY, EXTERNAL_PROCESS_EXPIRY,
FEDERATION_PING_INTERVAL, FEDERATION_PING_INTERVAL,
@ -471,6 +472,168 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.state, PresenceState.OFFLINE)
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.instance_name = hs.get_instance_name()
self.queue = self.presence_handler.get_federation_queue()
def test_send_and_get(self):
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
prev_token = self.queue.get_current_token(self.instance_name)
self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
self.queue.send_presence_to_destinations((state3,), ("dest3",))
now_token = self.queue.get_current_token(self.instance_name)
rows, upto_token, limited = self.get_success(
self.queue.get_replication_rows("master", prev_token, now_token, 10)
)
self.assertEqual(upto_token, now_token)
self.assertFalse(limited)
expected_rows = [
(1, ("dest1", "@user1:test")),
(1, ("dest2", "@user1:test")),
(1, ("dest1", "@user2:test")),
(1, ("dest2", "@user2:test")),
(2, ("dest3", "@user3:test")),
]
self.assertCountEqual(rows, expected_rows)
def test_send_and_get_split(self):
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
prev_token = self.queue.get_current_token(self.instance_name)
self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
now_token = self.queue.get_current_token(self.instance_name)
self.queue.send_presence_to_destinations((state3,), ("dest3",))
rows, upto_token, limited = self.get_success(
self.queue.get_replication_rows("master", prev_token, now_token, 10)
)
self.assertEqual(upto_token, now_token)
self.assertFalse(limited)
expected_rows = [
(1, ("dest1", "@user1:test")),
(1, ("dest2", "@user1:test")),
(1, ("dest1", "@user2:test")),
(1, ("dest2", "@user2:test")),
]
self.assertCountEqual(rows, expected_rows)
def test_clear_queue_all(self):
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
prev_token = self.queue.get_current_token(self.instance_name)
self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
self.queue.send_presence_to_destinations((state3,), ("dest3",))
self.reactor.advance(10 * 60 * 1000)
now_token = self.queue.get_current_token(self.instance_name)
rows, upto_token, limited = self.get_success(
self.queue.get_replication_rows("master", prev_token, now_token, 10)
)
self.assertEqual(upto_token, now_token)
self.assertFalse(limited)
self.assertCountEqual(rows, [])
prev_token = self.queue.get_current_token(self.instance_name)
self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
self.queue.send_presence_to_destinations((state3,), ("dest3",))
now_token = self.queue.get_current_token(self.instance_name)
rows, upto_token, limited = self.get_success(
self.queue.get_replication_rows("master", prev_token, now_token, 10)
)
self.assertEqual(upto_token, now_token)
self.assertFalse(limited)
expected_rows = [
(3, ("dest1", "@user1:test")),
(3, ("dest2", "@user1:test")),
(3, ("dest1", "@user2:test")),
(3, ("dest2", "@user2:test")),
(4, ("dest3", "@user3:test")),
]
self.assertCountEqual(rows, expected_rows)
def test_partially_clear_queue(self):
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
prev_token = self.queue.get_current_token(self.instance_name)
self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
self.reactor.advance(2 * 60 * 1000)
self.queue.send_presence_to_destinations((state3,), ("dest3",))
self.reactor.advance(4 * 60 * 1000)
now_token = self.queue.get_current_token(self.instance_name)
rows, upto_token, limited = self.get_success(
self.queue.get_replication_rows("master", prev_token, now_token, 10)
)
self.assertEqual(upto_token, now_token)
self.assertFalse(limited)
expected_rows = [
(2, ("dest3", "@user3:test")),
]
self.assertCountEqual(rows, [])
prev_token = self.queue.get_current_token(self.instance_name)
self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
self.queue.send_presence_to_destinations((state3,), ("dest3",))
now_token = self.queue.get_current_token(self.instance_name)
rows, upto_token, limited = self.get_success(
self.queue.get_replication_rows("master", prev_token, now_token, 10)
)
self.assertEqual(upto_token, now_token)
self.assertFalse(limited)
expected_rows = [
(3, ("dest1", "@user1:test")),
(3, ("dest2", "@user1:test")),
(3, ("dest1", "@user2:test")),
(3, ("dest2", "@user2:test")),
(4, ("dest3", "@user3:test")),
]
self.assertCountEqual(rows, expected_rows)
class PresenceJoinTestCase(unittest.HomeserverTestCase): class PresenceJoinTestCase(unittest.HomeserverTestCase):
"""Tests remote servers get told about presence of users in the room when """Tests remote servers get told about presence of users in the room when
they join and when new local users join. they join and when new local users join.
@ -482,10 +645,17 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"server", federation_http_client=None, federation_sender=Mock() "server",
federation_http_client=None,
federation_sender=Mock(spec=FederationSender),
) )
return hs return hs
def default_config(self):
config = super().default_config()
config["send_federation"] = True
return config
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@ -529,9 +699,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Add a new remote server to the room # Add a new remote server to the room
self._add_new_user(room_id, "@alice:server2") self._add_new_user(room_id, "@alice:server2")
# We shouldn't have sent out any local presence *updates*
self.federation_sender.send_presence.assert_not_called()
# When new server is joined we send it the local users presence states. # When new server is joined we send it the local users presence states.
# We expect to only see user @test2:server, as @test:server is offline # We expect to only see user @test2:server, as @test:server is offline
# and has a zero last_active_ts # and has a zero last_active_ts
@ -550,7 +717,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.federation_sender.reset_mock() self.federation_sender.reset_mock()
self._add_new_user(room_id, "@bob:server3") self._add_new_user(room_id, "@bob:server3")
self.federation_sender.send_presence.assert_not_called()
self.federation_sender.send_presence_to_destinations.assert_called_once_with( self.federation_sender.send_presence_to_destinations.assert_called_once_with(
destinations=["server3"], states={expected_state} destinations=["server3"], states={expected_state}
) )
@ -595,9 +761,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.reactor.pump([0]) # Wait for presence updates to be handled self.reactor.pump([0]) # Wait for presence updates to be handled
# We shouldn't have sent out any local presence *updates*
self.federation_sender.send_presence.assert_not_called()
# We expect to only send test2 presence to server2 and server3 # We expect to only send test2 presence to server2 and server3
expected_state = self.get_success( expected_state = self.get_success(
self.presence_handler.current_state_for_user("@test2:server") self.presence_handler.current_state_for_user("@test2:server")