0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-10 20:11:32 +01:00

Some cleanups to device inbox store. (#9041)

This commit is contained in:
Erik Johnston 2021-01-07 17:20:44 +00:00 committed by GitHub
parent 9066c2fd7f
commit 63593134a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 57 deletions

1
changelog.d/9041.misc Normal file
View file

@ -0,0 +1 @@
Various cleanups to device inbox store.

View file

@ -18,7 +18,6 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -37,13 +36,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
self._device_inbox_id_gen.get_current_token(),
)
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(instance_name, token)

View file

@ -17,7 +17,7 @@ import logging
from typing import List, Tuple
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
@ -310,20 +322,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
@trace
async def add_messages_to_device_inbox(
self,
@ -351,16 +349,19 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add the remote messages to the federation outbox.
# We'll send them to a remote server when we next send a
# federation transaction to that destination.
sql = (
"INSERT INTO device_federation_outbox"
" (destination, stream_id, queued_ts, messages_json)"
" VALUES (?,?,?,?)"
self.db_pool.simple_insert_many_txn(
txn,
table="device_federation_outbox",
values=[
{
"destination": destination,
"stream_id": stream_id,
"queued_ts": now_ms,
"messages_json": json_encoder.encode(edu),
}
for destination, edu in remote_messages_by_destination.items()
],
)
rows = []
for destination, edu in remote_messages_by_destination.items():
edu_json = json_encoder.encode(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
@ -433,32 +434,37 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = "SELECT device_id FROM devices WHERE user_id = ?"
txn.execute(sql, (user_id,))
devices = self.db_pool.simple_select_onecol_txn(
txn,
table="devices",
keyvalues={"user_id": user_id},
retcol="device_id",
)
message_json = json_encoder.encode(messages_by_device["*"])
for row in txn:
for device_id in devices:
# Add the message for all devices for this user on this
# server.
device = row[0]
messages_json_for_user[device] = message_json
messages_json_for_user[device_id] = message_json
else:
if not devices:
continue
clause, args = make_in_list_sql_clause(
txn.database_engine, "device_id", devices
rows = self.db_pool.simple_select_many_txn(
txn,
table="devices",
keyvalues={"user_id": user_id},
column="device_id",
iterable=devices,
retcols=("device_id",),
)
sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
txn.execute(sql, [user_id] + list(args))
for row in txn:
for row in rows:
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
message_json = json_encoder.encode(messages_by_device[device])
messages_json_for_user[device] = message_json
device_id = row["device_id"]
message_json = json_encoder.encode(messages_by_device[device_id])
messages_json_for_user[device_id] = message_json
if messages_json_for_user:
local_by_user_then_device[user_id] = messages_json_for_user
@ -466,14 +472,17 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
if not local_by_user_then_device:
return
sql = (
"INSERT INTO device_inbox"
" (user_id, device_id, stream_id, message_json)"
" VALUES (?,?,?,?)"
self.db_pool.simple_insert_many_txn(
txn,
table="device_inbox",
values=[
{
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"message_json": message_json,
}
for user_id, messages_by_device in local_by_user_then_device.items()
for device_id, message_json in messages_by_device.items()
],
)
rows = []
for user_id, messages_by_device in local_by_user_then_device.items():
for device_id, message_json in messages_by_device.items():
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)