Clarifications and small fixes to to-device related code (#11247)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
Andrew Morgan 2021-11-09 14:31:15 +00:00 committed by GitHub
parent b6f4d122ef
commit a026695083
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 78 additions and 17 deletions

1
changelog.d/11247.misc Normal file
View file

@ -0,0 +1 @@
Clean up code relating to to-device messages and sending ephemeral events to application services.

View file

@ -188,7 +188,7 @@ class ApplicationServicesHandler:
self, self,
stream_key: str, stream_key: str,
new_token: Union[int, RoomStreamToken], new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None, users: Collection[Union[str, UserID]],
) -> None: ) -> None:
""" """
This is called by the notifier in the background when an ephemeral event is handled This is called by the notifier in the background when an ephemeral event is handled
@ -203,7 +203,9 @@ class ApplicationServicesHandler:
value for `stream_key` will cause this function to return early. value for `stream_key` will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into Ephemeral events will only be pushed to appservices that have opted into
them. receiving them by setting `push_ephemeral` to true in their registration
file. Note that while MSC2409 is experimental, this option is called
`de.sorunome.msc2409.push_ephemeral`.
Appservices will only receive ephemeral events that fall within their Appservices will only receive ephemeral events that fall within their
registered user and room namespaces. registered user and room namespaces.
@ -214,6 +216,7 @@ class ApplicationServicesHandler:
if not self.notify_appservices: if not self.notify_appservices:
return return
# Ignore any unsupported streams
if stream_key not in ("typing_key", "receipt_key", "presence_key"): if stream_key not in ("typing_key", "receipt_key", "presence_key"):
return return
@ -230,18 +233,25 @@ class ApplicationServicesHandler:
# Additional context: https://github.com/matrix-org/synapse/pull/11137 # Additional context: https://github.com/matrix-org/synapse/pull/11137
assert isinstance(new_token, int) assert isinstance(new_token, int)
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
# Note that whether these events are actually relevant to these appservices
# is decided later on.
services = [ services = [
service service
for service in self.store.get_app_services() for service in self.store.get_app_services()
if service.supports_ephemeral if service.supports_ephemeral
] ]
if not services: if not services:
# Bail out early if none of the target appservices have explicitly registered
# to receive these ephemeral events.
return return
# We only start a new background process if necessary rather than # We only start a new background process if necessary rather than
# optimistically (to cut down on overhead). # optimistically (to cut down on overhead).
self._notify_interested_services_ephemeral( self._notify_interested_services_ephemeral(
services, stream_key, new_token, users or [] services, stream_key, new_token, users
) )
@wrap_as_background_process("notify_interested_services_ephemeral") @wrap_as_background_process("notify_interested_services_ephemeral")
@ -252,7 +262,7 @@ class ApplicationServicesHandler:
new_token: int, new_token: int,
users: Collection[Union[str, UserID]], users: Collection[Union[str, UserID]],
) -> None: ) -> None:
logger.debug("Checking interested services for %s" % (stream_key)) logger.debug("Checking interested services for %s", stream_key)
with Measure(self.clock, "notify_interested_services_ephemeral"): with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services: for service in services:
if stream_key == "typing_key": if stream_key == "typing_key":
@ -345,6 +355,9 @@ class ApplicationServicesHandler:
Args: Args:
service: The application service to check for which events it should receive. service: The application service to check for which events it should receive.
new_token: A receipts event stream token. Purely used to double-check that the
from_token we pull from the database isn't greater than or equal to this
token. Prevents accidentally duplicating work.
Returns: Returns:
A list of JSON dictionaries containing data derived from the read receipts that A list of JSON dictionaries containing data derived from the read receipts that
@ -382,6 +395,9 @@ class ApplicationServicesHandler:
Args: Args:
service: The application service that ephemeral events are being sent to. service: The application service that ephemeral events are being sent to.
users: The users that should receive the presence update. users: The users that should receive the presence update.
new_token: A presence update stream token. Purely used to double-check that the
from_token we pull from the database isn't greater than or equal to this
token. Prevents accidentally duplicating work.
Returns: Returns:
A list of json dictionaries containing data derived from the presence events A list of json dictionaries containing data derived from the presence events

View file

@ -89,6 +89,13 @@ class DeviceMessageHandler:
) )
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
"""
Handle receiving to-device messages from remote homeservers.
Args:
origin: The remote homeserver.
content: The JSON dictionary containing the to-device messages.
"""
local_messages = {} local_messages = {}
sender_user_id = content["sender"] sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id): if origin != get_domain_from_id(sender_user_id):
@ -135,12 +142,16 @@ class DeviceMessageHandler:
message_type, sender_user_id, by_device message_type, sender_user_id, by_device
) )
stream_id = await self.store.add_messages_from_remote_to_device_inbox( # Add messages to the database.
# Retrieve the stream id of the last-processed to-device message.
last_stream_id = await self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages origin, message_id, local_messages
) )
# Notify listeners that there are new to-device messages to process,
# handing them the latest stream id.
self.notifier.on_new_event( self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys() "to_device_key", last_stream_id, users=local_messages.keys()
) )
async def _check_for_unknown_devices( async def _check_for_unknown_devices(
@ -195,6 +206,14 @@ class DeviceMessageHandler:
message_type: str, message_type: str,
messages: Dict[str, Dict[str, JsonDict]], messages: Dict[str, Dict[str, JsonDict]],
) -> None: ) -> None:
"""
Handle a request from a user to send to-device message(s).
Args:
requester: The user that is sending the to-device messages.
message_type: The type of to-device messages that are being sent.
messages: A dictionary containing recipients mapped to messages intended for them.
"""
sender_user_id = requester.user.to_string() sender_user_id = requester.user.to_string()
message_id = random_string(16) message_id = random_string(16)
@ -257,12 +276,16 @@ class DeviceMessageHandler:
"org.matrix.opentracing_context": json_encoder.encode(context), "org.matrix.opentracing_context": json_encoder.encode(context),
} }
stream_id = await self.store.add_messages_to_device_inbox( # Add messages to the database.
# Retrieve the stream id of the last-processed to-device message.
last_stream_id = await self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents local_messages, remote_edu_contents
) )
# Notify listeners that there are new to-device messages to process,
# handing them the latest stream id.
self.notifier.on_new_event( self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys() "to_device_key", last_stream_id, users=local_messages.keys()
) )
if self.federation_sender: if self.federation_sender:

View file

@ -412,16 +412,16 @@ class ApplicationServiceTransactionWorkerStore(
) )
async def set_type_stream_id_for_appservice( async def set_type_stream_id_for_appservice(
self, service: ApplicationService, type: str, pos: Optional[int] self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None: ) -> None:
if type not in ("read_receipt", "presence"): if stream_type not in ("read_receipt", "presence"):
raise ValueError( raise ValueError(
"Expected type to be a valid application stream id type, got %s" "Expected type to be a valid application stream id type, got %s"
% (type,) % (stream_type,)
) )
def set_type_stream_id_for_appservice_txn(txn): def set_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type stream_id_type = "%s_stream_id" % stream_type
txn.execute( txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?" "UPDATE application_services_state SET %s = ? WHERE as_id=?"
% stream_id_type, % stream_id_type,

View file

@ -134,7 +134,10 @@ class DeviceInboxWorkerStore(SQLBaseStore):
limit: The maximum number of messages to retrieve. limit: The maximum number of messages to retrieve.
Returns: Returns:
A list of messages for the device and where in the stream the messages got to. A tuple containing:
* A list of messages for the device.
* The max stream token of these messages. There may be more to retrieve
if the given limit was reached.
""" """
has_changed = self._device_inbox_stream_cache.has_entity_changed( has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id user_id, last_stream_id
@ -153,12 +156,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute( txn.execute(
sql, (user_id, device_id, last_stream_id, current_stream_id, limit) sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
) )
messages = [] messages = []
stream_pos = current_stream_id
for row in txn: for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(db_to_json(row[1])) messages.append(db_to_json(row[1]))
# If the limit was not reached we know that there's no more data for this
# user/device pair up to current_stream_id.
if len(messages) < limit: if len(messages) < limit:
stream_pos = current_stream_id stream_pos = current_stream_id
return messages, stream_pos return messages, stream_pos
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -260,13 +270,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" LIMIT ?" " LIMIT ?"
) )
txn.execute(sql, (destination, last_stream_id, current_stream_id, limit)) txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
messages = [] messages = []
stream_pos = current_stream_id
for row in txn: for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(db_to_json(row[1])) messages.append(db_to_json(row[1]))
# If the limit was not reached we know that there's no more data for this
# user/device pair up to current_stream_id.
if len(messages) < limit: if len(messages) < limit:
log_kv({"message": "Set stream position to current position"}) log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id stream_pos = current_stream_id
return messages, stream_pos return messages, stream_pos
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -372,8 +389,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"""Used to send messages from this server. """Used to send messages from this server.
Args: Args:
local_messages_by_user_and_device: local_messages_by_user_then_device:
Dictionary of user_id to device_id to message. Dictionary of recipient user_id to recipient device_id to message.
remote_messages_by_destination: remote_messages_by_destination:
Dictionary of destination server_name to the EDU JSON to send. Dictionary of destination server_name to the EDU JSON to send.

View file

@ -272,7 +272,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
make_awaitable(([event], None)) make_awaitable(([event], None))
) )
self.handler.notify_interested_services_ephemeral("receipt_key", 580) self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
)
self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with( self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
interested_service, [event] interested_service, [event]
) )
@ -300,7 +302,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
make_awaitable(([event], None)) make_awaitable(([event], None))
) )
self.handler.notify_interested_services_ephemeral("receipt_key", 579) self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
)
self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called() self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
def _mkservice(self, is_interested, protocols=None): def _mkservice(self, is_interested, protocols=None):