Add a database method to fetch to-device messages for multiple users
This method is quite similar to the one below, except that it doesn't support device ids, and supports querying with more than one user id, both of which are relevant to application services. The results are also formatted in a different data structure, so I'm not sure how much we could really share here between the two methods.
This commit is contained in:
parent
59b5eeefe1
commit
43bdfbbc85
|
@ -13,13 +13,17 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.logging import issue9533_logger
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.replication.tcp.streams import ToDeviceStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
|
@ -116,6 +120,97 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
def get_to_device_stream_token(self):
    """Return the current position of the to-device message stream id generator."""
    id_gen = self._device_inbox_id_gen
    return id_gen.get_current_token()
|
||||
|
||||
async def get_new_messages(
    self,
    user_ids: Collection[str],
    from_stream_id: int,
    to_stream_id: int,
    limit: int = 100,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
    """
    Retrieve to-device messages for a given set of user IDs.

    Only to-device messages with stream ids between the given boundaries
    (from < X <= to) are returned.

    Note that multiple messages can have the same stream id. Stream ids are
    unique to *messages*, but there might be multiple recipients of a message, and
    thus multiple entries in the device_inbox table with the same stream id.

    Args:
        user_ids: The users to retrieve to-device messages for.
        from_stream_id: The lower boundary of stream id to filter with (exclusive).
        to_stream_id: The upper boundary of stream id to filter with (inclusive).
        limit: The maximum number of to-device messages to return.

    Returns:
        A tuple containing the following:
            * A dict mapping (user ID, device ID) to the list of message dicts
              destined for that device, in ascending stream id order.
            * The stream id that to-device messages were processed up to. The calling
              function can check this value against `to_stream_id` to see if we hit
              the given limit when fetching messages.
    """
    # Fast path: bail out if none of these users have any messages newer than
    # `from_stream_id`, per the in-memory stream cache. Avoids a database hit.
    for user_id in user_ids:
        if self._device_inbox_stream_cache.has_entity_changed(
            user_id, from_stream_id
        ):
            break
    else:
        # No user in the set has changed; nothing to fetch.
        return {}, to_stream_id

    def get_new_messages_txn(
        txn: LoggingTransaction,
    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
        # Build a query to select messages from any of the given users that are between
        # the given stream id bounds

        # Scope to only the given users. We need to use this method as doing so is
        # different across database engines.
        many_clause_sql, many_clause_args = make_in_list_sql_clause(
            self.database_engine, "user_id", user_ids
        )

        sql = f"""
            SELECT stream_id, user_id, device_id, message_json FROM device_inbox
            WHERE {many_clause_sql}
            AND ? < stream_id AND stream_id <= ?
            ORDER BY stream_id ASC
            LIMIT ?
        """

        txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id, limit))

        # Create a dictionary of (user ID, device ID) -> list of messages that
        # that device is meant to receive.
        recipient_user_id_device_id_to_messages: Dict[
            Tuple[str, str], List[JsonDict]
        ] = {}

        # Default to the upper bound; overwritten per-row below, and reset again
        # after the loop if the limit was not reached.
        stream_pos = to_stream_id
        total_messages_processed = 0
        for row in txn:
            # Record the last-processed stream position, to return later.
            # Note that we process messages here in order of ascending stream id.
            stream_pos = row[0]
            recipient_user_id = row[1]
            recipient_device_id = row[2]
            # message_json column holds the serialized message payload.
            message_dict = db_to_json(row[3])

            recipient_user_id_device_id_to_messages.setdefault(
                (recipient_user_id, recipient_device_id), []
            ).append(message_dict)
            total_messages_processed += 1

        # If we don't end up hitting the limit, we still want to return the equivalent
        # value of `to_stream_id` to the calling function. This is needed as we'll
        # be processing up to `to_stream_id`, without necessarily fetching a message
        # with a stream id of `to_stream_id`.
        #
        # NOTE(review): if the LIMIT cuts off mid-way through a group of rows
        # sharing a stream id (multiple recipients of one message), the returned
        # `stream_pos` is that shared stream id even though some of its rows were
        # not fetched; a caller resuming with `from_stream_id = stream_pos`
        # (exclusive lower bound) would skip them — confirm callers handle this.
        if total_messages_processed < limit:
            stream_pos = to_stream_id

        return recipient_user_id_device_id_to_messages, stream_pos

    return await self.db_pool.runInteraction(
        "get_new_messages", get_new_messages_txn
    )
|
||||
|
||||
async def get_new_messages_for_device(
|
||||
self,
|
||||
user_id: str,
|
||||
|
|
Loading…
Reference in a new issue