
Some cleanups to device inbox store. (#9041)

Erik Johnston 2021-01-07 17:20:44 +00:00 committed by GitHub
parent 9066c2fd7f
commit 63593134a1
3 changed files with 59 additions and 57 deletions

changelog.d/9041.misc (new file)

@@ -0,0 +1 @@
+Various cleanups to device inbox store.

synapse/replication/slave/storage/deviceinbox.py

@@ -18,7 +18,6 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -37,13 +36,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             self._device_inbox_id_gen.get_current_token(),
         )

-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == ToDeviceStream.NAME:
             self._device_inbox_id_gen.advance(instance_name, token)
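
Note on this removal: the cache is not gone. It moves into DeviceInboxWorkerStore (see the next file), which SlavedDeviceInboxStore already inherits from, so the replica store now picks the cache up via super().__init__(). A minimal, self-contained sketch of that pattern, using hypothetical stand-in classes rather than the real synapse ones:

# Illustrative stand-ins only; not the synapse classes themselves.
class ExpiringCacheStub(dict):
    """Hypothetical stand-in for synapse.util.caches.expiringcache.ExpiringCache."""


class DeviceInboxWorkerStoreSketch:
    def __init__(self):
        # The shared worker store now owns the cache ...
        self._last_device_delete_cache = ExpiringCacheStub()


class SlavedDeviceInboxStoreSketch(DeviceInboxWorkerStoreSketch):
    def __init__(self):
        # ... so the slaved store no longer constructs its own copy.
        super().__init__()


assert hasattr(SlavedDeviceInboxStoreSketch(), "_last_device_delete_cache")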

synapse/storage/databases/main/deviceinbox.py

@@ -17,7 +17,7 @@ import logging
 from typing import List, Tuple

 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
 class DeviceInboxWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        # Map of (user_id, device_id) to the last stream_id that has been
+        # deleted up to. This is so that we can no op deletions.
+        self._last_device_delete_cache = ExpiringCache(
+            cache_name="last_device_delete_cache",
+            clock=self._clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+        )
+
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
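
The comment carried over with the cache explains its purpose: once we have recorded that messages for a (user_id, device_id) pair were deleted up to some stream ID, later delete requests at or below that point can return immediately. A rough, self-contained sketch of that no-op pattern (illustrative only, not the actual delete_messages_for_device implementation; a plain dict and made-up data stand in for ExpiringCache and the database):

# Illustrative sketch of "no op deletions" via a last-deleted-up-to cache.
_inbox = {("@alice:example.com", "DEV1"): [1, 2, 3, 4, 5]}  # stream IDs per device
_last_device_delete_cache = {}


def delete_messages_for_device(user_id, device_id, up_to_stream_id):
    key = (user_id, device_id)
    # Already deleted at least this far for this device: nothing to do.
    if _last_device_delete_cache.get(key, 0) >= up_to_stream_id:
        return 0

    messages = _inbox.get(key, [])
    remaining = [s for s in messages if s > up_to_stream_id]
    _inbox[key] = remaining

    # Remember how far we have deleted so later calls can no-op.
    _last_device_delete_cache[key] = up_to_stream_id
    return len(messages) - len(remaining)


print(delete_messages_for_device("@alice:example.com", "DEV1", 3))  # 3
print(delete_messages_for_device("@alice:example.com", "DEV1", 3))  # 0 (no-op)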
@@ -310,20 +322,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"

-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        # Map of (user_id, device_id) to the last stream_id that has been
-        # deleted up to. This is so that we can no op deletions.
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
     @trace
     async def add_messages_to_device_inbox(
         self,
@@ -351,16 +349,19 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             # Add the remote messages to the federation outbox.
             # We'll send them to a remote server when we next send a
             # federation transaction to that destination.
-            sql = (
-                "INSERT INTO device_federation_outbox"
-                " (destination, stream_id, queued_ts, messages_json)"
-                " VALUES (?,?,?,?)"
-            )
-            rows = []
-            for destination, edu in remote_messages_by_destination.items():
-                edu_json = json_encoder.encode(edu)
-                rows.append((destination, stream_id, now_ms, edu_json))
-            txn.executemany(sql, rows)
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="device_federation_outbox",
+                values=[
+                    {
+                        "destination": destination,
+                        "stream_id": stream_id,
+                        "queued_ts": now_ms,
+                        "messages_json": json_encoder.encode(edu),
+                    }
+                    for destination, edu in remote_messages_by_destination.items()
+                ],
+            )

         async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
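
Both this hunk and the final one in this file replace a hand-built INSERT plus txn.executemany with DatabasePool.simple_insert_many_txn, which takes the target table and a list of column-to-value dicts. Roughly, such a helper derives the column list from the dict keys and batches one parameterized insert; the standalone sqlite3 sketch below shows the idea (an illustration, not synapse's implementation):

import sqlite3


def simple_insert_many_txn(txn, table, values):
    # Derive the column list from the dict keys and batch one parameterized
    # INSERT per row. Table/column names are assumed to be trusted here.
    if not values:
        return
    keys = list(values[0].keys())
    sql = "INSERT INTO %s (%s) VALUES (%s)" % (
        table,
        ", ".join(keys),
        ", ".join("?" for _ in keys),
    )
    txn.executemany(sql, [tuple(row[k] for k in keys) for row in values])


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE device_federation_outbox "
    "(destination TEXT, stream_id INTEGER, queued_ts INTEGER, messages_json TEXT)"
)
simple_insert_many_txn(
    conn,
    "device_federation_outbox",
    [{"destination": "remote.example.com", "stream_id": 1, "queued_ts": 0, "messages_json": "{}"}],
)
print(conn.execute("SELECT count(*) FROM device_federation_outbox").fetchone())  # (1,)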
@@ -433,32 +434,37 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
-                sql = "SELECT device_id FROM devices WHERE user_id = ?"
-                txn.execute(sql, (user_id,))
+                devices = self.db_pool.simple_select_onecol_txn(
+                    txn,
+                    table="devices",
+                    keyvalues={"user_id": user_id},
+                    retcol="device_id",
+                )
+
                 message_json = json_encoder.encode(messages_by_device["*"])
-                for row in txn:
+                for device_id in devices:
                     # Add the message for all devices for this user on this
                     # server.
-                    device = row[0]
-                    messages_json_for_user[device] = message_json
+                    messages_json_for_user[device_id] = message_json
             else:
                 if not devices:
                     continue

-                clause, args = make_in_list_sql_clause(
-                    txn.database_engine, "device_id", devices
-                )
-                sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
-
-                # TODO: Maybe this needs to be done in batches if there are
-                # too many local devices for a given user.
-                txn.execute(sql, [user_id] + list(args))
-                for row in txn:
+                rows = self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="devices",
+                    keyvalues={"user_id": user_id},
+                    column="device_id",
+                    iterable=devices,
+                    retcols=("device_id",),
+                )
+
+                for row in rows:
                     # Only insert into the local inbox if the device exists on
                     # this server
-                    device = row[0]
-                    message_json = json_encoder.encode(messages_by_device[device])
-                    messages_json_for_user[device] = message_json
+                    device_id = row["device_id"]
+                    message_json = json_encoder.encode(messages_by_device[device_id])
+                    messages_json_for_user[device_id] = message_json

             if messages_json_for_user:
                 local_by_user_then_device[user_id] = messages_json_for_user
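
The two branches above swap hand-written SELECTs (including the make_in_list_sql_clause helper dropped from the imports) for the generic query helpers: simple_select_onecol_txn returns a single column as a flat list, and simple_select_many_txn returns dict rows for the values in iterable, which is why the loop now reads row["device_id"] instead of row[0]. A rough standalone sketch of those semantics, again using sqlite3 stand-ins rather than synapse's DatabasePool:

import sqlite3


def simple_select_onecol_txn(txn, table, keyvalues, retcol):
    # Return a single column as a flat list, filtered by exact-match keyvalues.
    where = " AND ".join("%s = ?" % k for k in keyvalues)
    sql = "SELECT %s FROM %s WHERE %s" % (retcol, table, where)
    return [row[0] for row in txn.execute(sql, tuple(keyvalues.values()))]


def simple_select_many_txn(txn, table, keyvalues, column, iterable, retcols):
    # Return dict rows where `column` is one of `iterable`, plus exact-match
    # keyvalues (assumed non-empty here for brevity).
    if not iterable:
        return []
    where = " AND ".join("%s = ?" % k for k in keyvalues)
    in_clause = "%s IN (%s)" % (column, ", ".join("?" for _ in iterable))
    sql = "SELECT %s FROM %s WHERE %s AND %s" % (", ".join(retcols), table, where, in_clause)
    args = tuple(keyvalues.values()) + tuple(iterable)
    return [dict(zip(retcols, row)) for row in txn.execute(sql, args)]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE devices (user_id TEXT, device_id TEXT)")
conn.executemany("INSERT INTO devices VALUES (?, ?)", [("@a:x", "D1"), ("@a:x", "D2")])
print(simple_select_onecol_txn(conn, "devices", {"user_id": "@a:x"}, "device_id"))  # ['D1', 'D2']
print(simple_select_many_txn(conn, "devices", {"user_id": "@a:x"}, "device_id", ["D1"], ("device_id",)))  # [{'device_id': 'D1'}]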
@@ -466,14 +472,17 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         if not local_by_user_then_device:
             return

-        sql = (
-            "INSERT INTO device_inbox"
-            " (user_id, device_id, stream_id, message_json)"
-            " VALUES (?,?,?,?)"
-        )
-        rows = []
-        for user_id, messages_by_device in local_by_user_then_device.items():
-            for device_id, message_json in messages_by_device.items():
-                rows.append((user_id, device_id, stream_id, message_json))
-
-        txn.executemany(sql, rows)
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_inbox",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "stream_id": stream_id,
+                    "message_json": message_json,
+                }
+                for user_id, messages_by_device in local_by_user_then_device.items()
+                for device_id, message_json in messages_by_device.items()
+            ],
+        )