Add database method to fetch to-device messages by user_ids from db

This method is quite similar to the one below, except that it doesn't
support device ids, and supports querying with more than one user id,
both of which are relevant to application services.

The results are also formatted in a different data structure, so I'm
not sure how much we could really share here between the two methods.
This commit is contained in:
Andrew Morgan 2021-11-05 15:59:08 +00:00
parent 59b5eeefe1
commit 43bdfbbc85

View file

@ -13,13 +13,17 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
@ -116,6 +120,97 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
    """Return the current token reported by the device inbox ID generator."""
    id_generator = self._device_inbox_id_gen
    return id_generator.get_current_token()
async def get_new_messages(
    self,
    user_ids: Collection[str],
    from_stream_id: int,
    to_stream_id: int,
    limit: int = 100,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
    """
    Retrieve to-device messages for a given set of user IDs.

    Only to-device messages with stream ids between the given boundaries
    (from < X <= to) are returned.

    Note that multiple messages can have the same stream id. Stream ids are
    unique to *messages*, but there might be multiple recipients of a message, and
    thus multiple entries in the device_inbox table with the same stream id.

    Because the returned stream id is used by callers as the exclusive lower
    bound of their next request, every entry sharing the returned stream id must
    be part of the result. When the limit falls in the middle of such a group of
    entries, the rest of the group is fetched as well, so slightly more than
    `limit` messages may be returned.

    Args:
        user_ids: The users to retrieve to-device messages for.
        from_stream_id: The lower boundary of stream id to filter with (exclusive).
        to_stream_id: The upper boundary of stream id to filter with (inclusive).
        limit: The number of device_inbox entries to fetch before stopping. The
            result may exceed this by the remaining entries that share the last
            fetched stream id.

    Returns:
        A tuple containing the following:
            * A dictionary of (user ID, device ID) -> list of to-device messages
              for that device, in order of ascending stream id.
            * The stream id that to-device messages were processed up to. The calling
              function can check this value against `to_stream_id` to see if we hit
              the given limit when fetching messages.
    """
    # Bail out if none of these users have any messages. This also covers an
    # empty `user_ids`, for which there is nothing to query.
    has_changed = self._device_inbox_stream_cache.has_entity_changed
    if not any(has_changed(user_id, from_stream_id) for user_id in user_ids):
        return {}, to_stream_id

    def get_new_messages_txn(
        txn: LoggingTransaction,
    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
        # Scope to only the given users. We need to use this helper as the way
        # of doing so is different across database engines.
        many_clause_sql, many_clause_args = make_in_list_sql_clause(
            self.database_engine, "user_id", user_ids
        )

        # Select messages for any of the given users that are between the given
        # stream id bounds.
        sql = f"""
            SELECT stream_id, user_id, device_id, message_json FROM device_inbox
            WHERE {many_clause_sql}
                AND ? < stream_id AND stream_id <= ?
            ORDER BY stream_id ASC
            LIMIT ?
        """
        txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id, limit))
        rows = list(txn)

        if rows and len(rows) >= limit:
            # We hit the limit, so there may be more to fetch: report the last
            # stream id we reached rather than `to_stream_id`.
            stream_pos = rows[-1][0]

            # The limit may have cut through the entries that share this last
            # stream id. The caller will resume strictly *after* `stream_pos`,
            # so anything left behind would never be delivered. Replace the
            # possibly-partial tail with the complete set for that stream id.
            boundary_sql = f"""
                SELECT stream_id, user_id, device_id, message_json FROM device_inbox
                WHERE {many_clause_sql}
                    AND stream_id = ?
            """
            txn.execute(boundary_sql, (*many_clause_args, stream_pos))
            rows = [row for row in rows if row[0] != stream_pos]
            rows.extend(txn)
        else:
            # If we don't end up hitting the limit, we still want to return the
            # equivalent value of `to_stream_id` to the calling function. This is
            # needed as we'll be processing up to `to_stream_id`, without
            # necessarily fetching a message with a stream id of `to_stream_id`.
            stream_pos = to_stream_id

        # Create a dictionary of (user ID, device ID) -> list of messages that
        # that device is meant to receive.
        recipient_user_id_device_id_to_messages: Dict[
            Tuple[str, str], List[JsonDict]
        ] = {}
        for _, recipient_user_id, recipient_device_id, message_json in rows:
            recipient_user_id_device_id_to_messages.setdefault(
                (recipient_user_id, recipient_device_id), []
            ).append(db_to_json(message_json))

        return recipient_user_id_device_id_to_messages, stream_pos

    return await self.db_pool.runInteraction(
        "get_new_messages", get_new_messages_txn
    )
async def get_new_messages_for_device(
self,
user_id: str,