mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 08:23:51 +01:00
Merge pull request #6178 from matrix-org/babolivier/factor_out_bg_updates
Factor out backgroung updates
This commit is contained in:
commit
59d6290ed9
10 changed files with 671 additions and 605 deletions
1
changelog.d/6178.bugfix
Normal file
1
changelog.d/6178.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Make the `synapse_port_db` script create the right indexes on a new PostgreSQL database.
|
|
@ -33,16 +33,9 @@ logger = logging.getLogger(__name__)
|
||||||
LAST_SEEN_GRANULARITY = 120 * 1000
|
LAST_SEEN_GRANULARITY = 120 * 1000
|
||||||
|
|
||||||
|
|
||||||
class ClientIpStore(background_updates.BackgroundUpdateStore):
|
class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
|
super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
self.client_ip_last_seen = Cache(
|
|
||||||
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
|
|
||||||
)
|
|
||||||
|
|
||||||
super(ClientIpStore, self).__init__(db_conn, hs)
|
|
||||||
|
|
||||||
self.user_ips_max_age = hs.config.user_ips_max_age
|
|
||||||
|
|
||||||
self.register_background_index_update(
|
self.register_background_index_update(
|
||||||
"user_ips_device_index",
|
"user_ips_device_index",
|
||||||
|
@ -92,19 +85,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
"devices_last_seen", self._devices_last_seen_update
|
"devices_last_seen", self._devices_last_seen_update
|
||||||
)
|
)
|
||||||
|
|
||||||
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
|
|
||||||
self._batch_row_update = {}
|
|
||||||
|
|
||||||
self._client_ip_looper = self._clock.looping_call(
|
|
||||||
self._update_client_ips_batch, 5 * 1000
|
|
||||||
)
|
|
||||||
self.hs.get_reactor().addSystemEventTrigger(
|
|
||||||
"before", "shutdown", self._update_client_ips_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.user_ips_max_age:
|
|
||||||
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _remove_user_ip_nonunique(self, progress, batch_size):
|
def _remove_user_ip_nonunique(self, progress, batch_size):
|
||||||
def f(conn):
|
def f(conn):
|
||||||
|
@ -303,6 +283,110 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
return batch_size
|
return batch_size
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _devices_last_seen_update(self, progress, batch_size):
|
||||||
|
"""Background update to insert last seen info into devices table
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_user_id = progress.get("last_user_id", "")
|
||||||
|
last_device_id = progress.get("last_device_id", "")
|
||||||
|
|
||||||
|
def _devices_last_seen_update_txn(txn):
|
||||||
|
# This consists of two queries:
|
||||||
|
#
|
||||||
|
# 1. The sub-query searches for the next N devices and joins
|
||||||
|
# against user_ips to find the max last_seen associated with
|
||||||
|
# that device.
|
||||||
|
# 2. The outer query then joins again against user_ips on
|
||||||
|
# user/device/last_seen. This *should* hopefully only
|
||||||
|
# return one row, but if it does return more than one then
|
||||||
|
# we'll just end up updating the same device row multiple
|
||||||
|
# times, which is fine.
|
||||||
|
|
||||||
|
if self.database_engine.supports_tuple_comparison:
|
||||||
|
where_clause = "(user_id, device_id) > (?, ?)"
|
||||||
|
where_args = [last_user_id, last_device_id]
|
||||||
|
else:
|
||||||
|
# We explicitly do a `user_id >= ? AND (...)` here to ensure
|
||||||
|
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
|
||||||
|
# makes it hard for query optimiser to tell that it can use the
|
||||||
|
# index on user_id
|
||||||
|
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
|
||||||
|
where_args = [last_user_id, last_user_id, last_device_id]
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT
|
||||||
|
last_seen, ip, user_agent, user_id, device_id
|
||||||
|
FROM (
|
||||||
|
SELECT
|
||||||
|
user_id, device_id, MAX(u.last_seen) AS last_seen
|
||||||
|
FROM devices
|
||||||
|
INNER JOIN user_ips AS u USING (user_id, device_id)
|
||||||
|
WHERE %(where_clause)s
|
||||||
|
GROUP BY user_id, device_id
|
||||||
|
ORDER BY user_id ASC, device_id ASC
|
||||||
|
LIMIT ?
|
||||||
|
) c
|
||||||
|
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
|
||||||
|
""" % {
|
||||||
|
"where_clause": where_clause
|
||||||
|
}
|
||||||
|
txn.execute(sql, where_args + [batch_size])
|
||||||
|
|
||||||
|
rows = txn.fetchall()
|
||||||
|
if not rows:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
UPDATE devices
|
||||||
|
SET last_seen = ?, ip = ?, user_agent = ?
|
||||||
|
WHERE user_id = ? AND device_id = ?
|
||||||
|
"""
|
||||||
|
txn.execute_batch(sql, rows)
|
||||||
|
|
||||||
|
_, _, _, user_id, device_id = rows[-1]
|
||||||
|
self._background_update_progress_txn(
|
||||||
|
txn,
|
||||||
|
"devices_last_seen",
|
||||||
|
{"last_user_id": user_id, "last_device_id": device_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(rows)
|
||||||
|
|
||||||
|
updated = yield self.runInteraction(
|
||||||
|
"_devices_last_seen_update", _devices_last_seen_update_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
if not updated:
|
||||||
|
yield self._end_background_update("devices_last_seen")
|
||||||
|
|
||||||
|
return updated
|
||||||
|
|
||||||
|
|
||||||
|
class ClientIpStore(ClientIpBackgroundUpdateStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
|
||||||
|
self.client_ip_last_seen = Cache(
|
||||||
|
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
|
||||||
|
)
|
||||||
|
|
||||||
|
super(ClientIpStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
self.user_ips_max_age = hs.config.user_ips_max_age
|
||||||
|
|
||||||
|
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
|
||||||
|
self._batch_row_update = {}
|
||||||
|
|
||||||
|
self._client_ip_looper = self._clock.looping_call(
|
||||||
|
self._update_client_ips_batch, 5 * 1000
|
||||||
|
)
|
||||||
|
self.hs.get_reactor().addSystemEventTrigger(
|
||||||
|
"before", "shutdown", self._update_client_ips_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.user_ips_max_age:
|
||||||
|
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def insert_client_ip(
|
def insert_client_ip(
|
||||||
self, user_id, access_token, ip, user_agent, device_id, now=None
|
self, user_id, access_token, ip, user_agent, device_id, now=None
|
||||||
|
@ -454,85 +538,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
|
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _devices_last_seen_update(self, progress, batch_size):
|
|
||||||
"""Background update to insert last seen info into devices table
|
|
||||||
"""
|
|
||||||
|
|
||||||
last_user_id = progress.get("last_user_id", "")
|
|
||||||
last_device_id = progress.get("last_device_id", "")
|
|
||||||
|
|
||||||
def _devices_last_seen_update_txn(txn):
|
|
||||||
# This consists of two queries:
|
|
||||||
#
|
|
||||||
# 1. The sub-query searches for the next N devices and joins
|
|
||||||
# against user_ips to find the max last_seen associated with
|
|
||||||
# that device.
|
|
||||||
# 2. The outer query then joins again against user_ips on
|
|
||||||
# user/device/last_seen. This *should* hopefully only
|
|
||||||
# return one row, but if it does return more than one then
|
|
||||||
# we'll just end up updating the same device row multiple
|
|
||||||
# times, which is fine.
|
|
||||||
|
|
||||||
if self.database_engine.supports_tuple_comparison:
|
|
||||||
where_clause = "(user_id, device_id) > (?, ?)"
|
|
||||||
where_args = [last_user_id, last_device_id]
|
|
||||||
else:
|
|
||||||
# We explicitly do a `user_id >= ? AND (...)` here to ensure
|
|
||||||
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
|
|
||||||
# makes it hard for query optimiser to tell that it can use the
|
|
||||||
# index on user_id
|
|
||||||
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
|
|
||||||
where_args = [last_user_id, last_user_id, last_device_id]
|
|
||||||
|
|
||||||
sql = """
|
|
||||||
SELECT
|
|
||||||
last_seen, ip, user_agent, user_id, device_id
|
|
||||||
FROM (
|
|
||||||
SELECT
|
|
||||||
user_id, device_id, MAX(u.last_seen) AS last_seen
|
|
||||||
FROM devices
|
|
||||||
INNER JOIN user_ips AS u USING (user_id, device_id)
|
|
||||||
WHERE %(where_clause)s
|
|
||||||
GROUP BY user_id, device_id
|
|
||||||
ORDER BY user_id ASC, device_id ASC
|
|
||||||
LIMIT ?
|
|
||||||
) c
|
|
||||||
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
|
|
||||||
""" % {
|
|
||||||
"where_clause": where_clause
|
|
||||||
}
|
|
||||||
txn.execute(sql, where_args + [batch_size])
|
|
||||||
|
|
||||||
rows = txn.fetchall()
|
|
||||||
if not rows:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
sql = """
|
|
||||||
UPDATE devices
|
|
||||||
SET last_seen = ?, ip = ?, user_agent = ?
|
|
||||||
WHERE user_id = ? AND device_id = ?
|
|
||||||
"""
|
|
||||||
txn.execute_batch(sql, rows)
|
|
||||||
|
|
||||||
_, _, _, user_id, device_id = rows[-1]
|
|
||||||
self._background_update_progress_txn(
|
|
||||||
txn,
|
|
||||||
"devices_last_seen",
|
|
||||||
{"last_user_id": user_id, "last_device_id": device_id},
|
|
||||||
)
|
|
||||||
|
|
||||||
return len(rows)
|
|
||||||
|
|
||||||
updated = yield self.runInteraction(
|
|
||||||
"_devices_last_seen_update", _devices_last_seen_update_txn
|
|
||||||
)
|
|
||||||
|
|
||||||
if not updated:
|
|
||||||
yield self._end_background_update("devices_last_seen")
|
|
||||||
|
|
||||||
return updated
|
|
||||||
|
|
||||||
@wrap_as_background_process("prune_old_user_ips")
|
@wrap_as_background_process("prune_old_user_ips")
|
||||||
async def _prune_old_user_ips(self):
|
async def _prune_old_user_ips(self):
|
||||||
"""Removes entries in user IPs older than the configured period.
|
"""Removes entries in user IPs older than the configured period.
|
||||||
|
|
|
@ -208,11 +208,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
|
||||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(DeviceInboxStore, self).__init__(db_conn, hs)
|
super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.register_background_index_update(
|
self.register_background_index_update(
|
||||||
"device_inbox_stream_index",
|
"device_inbox_stream_index",
|
||||||
|
@ -225,6 +225,26 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||||
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
|
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _background_drop_index_device_inbox(self, progress, batch_size):
|
||||||
|
def reindex_txn(conn):
|
||||||
|
txn = conn.cursor()
|
||||||
|
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
|
||||||
|
txn.close()
|
||||||
|
|
||||||
|
yield self.runWithConnection(reindex_txn)
|
||||||
|
|
||||||
|
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
|
||||||
|
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
|
||||||
|
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(DeviceInboxStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
# Map of (user_id, device_id) to the last stream_id that has been
|
# Map of (user_id, device_id) to the last stream_id that has been
|
||||||
# deleted up to. This is so that we can no op deletions.
|
# deleted up to. This is so that we can no op deletions.
|
||||||
self._last_device_delete_cache = ExpiringCache(
|
self._last_device_delete_cache = ExpiringCache(
|
||||||
|
@ -435,16 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_all_new_device_messages", get_all_new_device_messages_txn
|
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _background_drop_index_device_inbox(self, progress, batch_size):
|
|
||||||
def reindex_txn(conn):
|
|
||||||
txn = conn.cursor()
|
|
||||||
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
|
|
||||||
txn.close()
|
|
||||||
|
|
||||||
yield self.runWithConnection(reindex_txn)
|
|
||||||
|
|
||||||
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
|
|
||||||
|
|
||||||
return 1
|
|
||||||
|
|
|
@ -512,17 +512,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(DeviceStore, self).__init__(db_conn, hs)
|
super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
|
||||||
# the device exists.
|
|
||||||
self.device_id_exists_cache = Cache(
|
|
||||||
name="device_id_exists", keylen=2, max_entries=10000
|
|
||||||
)
|
|
||||||
|
|
||||||
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
|
|
||||||
|
|
||||||
self.register_background_index_update(
|
self.register_background_index_update(
|
||||||
"device_lists_stream_idx",
|
"device_lists_stream_idx",
|
||||||
|
@ -555,6 +547,31 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||||
self._drop_device_list_streams_non_unique_indexes,
|
self._drop_device_list_streams_non_unique_indexes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
||||||
|
def f(conn):
|
||||||
|
txn = conn.cursor()
|
||||||
|
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
||||||
|
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
||||||
|
txn.close()
|
||||||
|
|
||||||
|
yield self.runWithConnection(f)
|
||||||
|
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(DeviceStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||||
|
# the device exists.
|
||||||
|
self.device_id_exists_cache = Cache(
|
||||||
|
name="device_id_exists", keylen=2, max_entries=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def store_device(self, user_id, device_id, initial_device_display_name):
|
def store_device(self, user_id, device_id, initial_device_display_name):
|
||||||
"""Ensure the given device is known; add it to the store if not
|
"""Ensure the given device is known; add it to the store if not
|
||||||
|
@ -910,15 +927,3 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||||
"_prune_old_outbound_device_pokes",
|
"_prune_old_outbound_device_pokes",
|
||||||
_prune_txn,
|
_prune_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
|
||||||
def f(conn):
|
|
||||||
txn = conn.cursor()
|
|
||||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
|
||||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
|
||||||
txn.close()
|
|
||||||
|
|
||||||
yield self.runWithConnection(f)
|
|
||||||
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
|
|
||||||
return 1
|
|
||||||
|
|
|
@ -15,11 +15,9 @@
|
||||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||||
|
|
||||||
|
|
||||||
class MediaRepositoryStore(BackgroundUpdateStore):
|
class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
|
||||||
"""Persistence for attachments and avatars"""
|
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(MediaRepositoryStore, self).__init__(db_conn, hs)
|
super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.register_background_index_update(
|
self.register_background_index_update(
|
||||||
update_name="local_media_repository_url_idx",
|
update_name="local_media_repository_url_idx",
|
||||||
|
@ -29,6 +27,13 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
||||||
where_clause="url_cache IS NOT NULL",
|
where_clause="url_cache IS NOT NULL",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
|
"""Persistence for attachments and avatars"""
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(MediaRepositoryStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
def get_local_media(self, media_id):
|
def get_local_media(self, media_id):
|
||||||
"""Get the metadata for a local piece of media
|
"""Get the metadata for a local piece of media
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -787,13 +787,14 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(
|
class RegistrationBackgroundUpdateStore(
|
||||||
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
|
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
|
||||||
):
|
):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(RegistrationStore, self).__init__(db_conn, hs)
|
super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.config = hs.config
|
||||||
|
|
||||||
self.register_background_index_update(
|
self.register_background_index_update(
|
||||||
"access_tokens_device_index",
|
"access_tokens_device_index",
|
||||||
|
@ -809,8 +810,6 @@ class RegistrationStore(
|
||||||
columns=["creation_ts"],
|
columns=["creation_ts"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self._account_validity = hs.config.account_validity
|
|
||||||
|
|
||||||
# we no longer use refresh tokens, but it's possible that some people
|
# we no longer use refresh tokens, but it's possible that some people
|
||||||
# might have a background update queued to build this index. Just
|
# might have a background update queued to build this index. Just
|
||||||
# clear the background update.
|
# clear the background update.
|
||||||
|
@ -824,17 +823,6 @@ class RegistrationStore(
|
||||||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a background job for culling expired 3PID validity tokens
|
|
||||||
def start_cull():
|
|
||||||
# run as a background process to make sure that the database transactions
|
|
||||||
# have a logcontext to report to
|
|
||||||
return run_as_background_process(
|
|
||||||
"cull_expired_threepid_validation_tokens",
|
|
||||||
self.cull_expired_threepid_validation_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _background_update_set_deactivated_flag(self, progress, batch_size):
|
def _background_update_set_deactivated_flag(self, progress, batch_size):
|
||||||
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
||||||
|
@ -896,6 +884,54 @@ class RegistrationStore(
|
||||||
|
|
||||||
return nb_processed
|
return nb_processed
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _bg_user_threepids_grandfather(self, progress, batch_size):
|
||||||
|
"""We now track which identity servers a user binds their 3PID to, so
|
||||||
|
we need to handle the case of existing bindings where we didn't track
|
||||||
|
this.
|
||||||
|
|
||||||
|
We do this by grandfathering in existing user threepids assuming that
|
||||||
|
they used one of the server configured trusted identity servers.
|
||||||
|
"""
|
||||||
|
id_servers = set(self.config.trusted_third_party_id_servers)
|
||||||
|
|
||||||
|
def _bg_user_threepids_grandfather_txn(txn):
|
||||||
|
sql = """
|
||||||
|
INSERT INTO user_threepid_id_server
|
||||||
|
(user_id, medium, address, id_server)
|
||||||
|
SELECT user_id, medium, address, ?
|
||||||
|
FROM user_threepids
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.executemany(sql, [(id_server,) for id_server in id_servers])
|
||||||
|
|
||||||
|
if id_servers:
|
||||||
|
yield self.runInteraction(
|
||||||
|
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self._end_background_update("user_threepids_grandfather")
|
||||||
|
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(RegistrationStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
self._account_validity = hs.config.account_validity
|
||||||
|
|
||||||
|
# Create a background job for culling expired 3PID validity tokens
|
||||||
|
def start_cull():
|
||||||
|
# run as a background process to make sure that the database transactions
|
||||||
|
# have a logcontext to report to
|
||||||
|
return run_as_background_process(
|
||||||
|
"cull_expired_threepid_validation_tokens",
|
||||||
|
self.cull_expired_threepid_validation_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
|
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
|
||||||
"""Adds an access token for the given user.
|
"""Adds an access token for the given user.
|
||||||
|
@ -1244,36 +1280,6 @@ class RegistrationStore(
|
||||||
desc="get_users_pending_deactivation",
|
desc="get_users_pending_deactivation",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _bg_user_threepids_grandfather(self, progress, batch_size):
|
|
||||||
"""We now track which identity servers a user binds their 3PID to, so
|
|
||||||
we need to handle the case of existing bindings where we didn't track
|
|
||||||
this.
|
|
||||||
|
|
||||||
We do this by grandfathering in existing user threepids assuming that
|
|
||||||
they used one of the server configured trusted identity servers.
|
|
||||||
"""
|
|
||||||
id_servers = set(self.config.trusted_third_party_id_servers)
|
|
||||||
|
|
||||||
def _bg_user_threepids_grandfather_txn(txn):
|
|
||||||
sql = """
|
|
||||||
INSERT INTO user_threepid_id_server
|
|
||||||
(user_id, medium, address, id_server)
|
|
||||||
SELECT user_id, medium, address, ?
|
|
||||||
FROM user_threepids
|
|
||||||
"""
|
|
||||||
|
|
||||||
txn.executemany(sql, [(id_server,) for id_server in id_servers])
|
|
||||||
|
|
||||||
if id_servers:
|
|
||||||
yield self.runInteraction(
|
|
||||||
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self._end_background_update("user_threepids_grandfather")
|
|
||||||
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
|
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
|
||||||
"""Attempt to validate a threepid session using a token
|
"""Attempt to validate a threepid session using a token
|
||||||
|
|
||||||
|
@ -1465,17 +1471,6 @@ class RegistrationStore(
|
||||||
self.clock.time_msec(),
|
self.clock.time_msec(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
|
|
||||||
self._simple_update_one_txn(
|
|
||||||
txn=txn,
|
|
||||||
table="users",
|
|
||||||
keyvalues={"name": user_id},
|
|
||||||
updatevalues={"deactivated": 1 if deactivated else 0},
|
|
||||||
)
|
|
||||||
self._invalidate_cache_and_stream(
|
|
||||||
txn, self.get_user_deactivated_status, (user_id,)
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_user_deactivated_status(self, user_id, deactivated):
|
def set_user_deactivated_status(self, user_id, deactivated):
|
||||||
"""Set the `deactivated` property for the provided user to the provided value.
|
"""Set the `deactivated` property for the provided user to the provided value.
|
||||||
|
@ -1491,3 +1486,14 @@ class RegistrationStore(
|
||||||
user_id,
|
user_id,
|
||||||
deactivated,
|
deactivated,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
|
||||||
|
self._simple_update_one_txn(
|
||||||
|
txn=txn,
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user_id},
|
||||||
|
updatevalues={"deactivated": 1 if deactivated else 0},
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_user_deactivated_status, (user_id,)
|
||||||
|
)
|
||||||
|
|
|
@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import LoggingTransaction
|
from synapse.storage._base import LoggingTransaction
|
||||||
|
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||||
from synapse.storage.engines import Sqlite3Engine
|
from synapse.storage.engines import Sqlite3Engine
|
||||||
from synapse.storage.events_worker import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
|
@ -820,9 +821,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
return set(room_ids)
|
return set(room_ids)
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberStore(RoomMemberWorkerStore):
|
class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(RoomMemberStore, self).__init__(db_conn, hs)
|
super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
||||||
)
|
)
|
||||||
|
@ -838,112 +839,6 @@ class RoomMemberStore(RoomMemberWorkerStore):
|
||||||
where_clause="forgotten = 1",
|
where_clause="forgotten = 1",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _store_room_members_txn(self, txn, events, backfilled):
|
|
||||||
"""Store a room member in the database.
|
|
||||||
"""
|
|
||||||
self._simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="room_memberships",
|
|
||||||
values=[
|
|
||||||
{
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"user_id": event.state_key,
|
|
||||||
"sender": event.user_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"membership": event.membership,
|
|
||||||
"display_name": event.content.get("displayname", None),
|
|
||||||
"avatar_url": event.content.get("avatar_url", None),
|
|
||||||
}
|
|
||||||
for event in events
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
for event in events:
|
|
||||||
txn.call_after(
|
|
||||||
self._membership_stream_cache.entity_has_changed,
|
|
||||||
event.state_key,
|
|
||||||
event.internal_metadata.stream_ordering,
|
|
||||||
)
|
|
||||||
txn.call_after(
|
|
||||||
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
|
|
||||||
)
|
|
||||||
|
|
||||||
# We update the local_invites table only if the event is "current",
|
|
||||||
# i.e., its something that has just happened. If the event is an
|
|
||||||
# outlier it is only current if its an "out of band membership",
|
|
||||||
# like a remote invite or a rejection of a remote invite.
|
|
||||||
is_new_state = not backfilled and (
|
|
||||||
not event.internal_metadata.is_outlier()
|
|
||||||
or event.internal_metadata.is_out_of_band_membership()
|
|
||||||
)
|
|
||||||
is_mine = self.hs.is_mine_id(event.state_key)
|
|
||||||
if is_new_state and is_mine:
|
|
||||||
if event.membership == Membership.INVITE:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="local_invites",
|
|
||||||
values={
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"invitee": event.state_key,
|
|
||||||
"inviter": event.sender,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"stream_id": event.internal_metadata.stream_ordering,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sql = (
|
|
||||||
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
|
|
||||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
|
||||||
" AND replaced_by is NULL"
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
sql,
|
|
||||||
(
|
|
||||||
event.internal_metadata.stream_ordering,
|
|
||||||
event.event_id,
|
|
||||||
event.room_id,
|
|
||||||
event.state_key,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def locally_reject_invite(self, user_id, room_id):
|
|
||||||
sql = (
|
|
||||||
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
|
|
||||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
|
||||||
" AND replaced_by is NULL"
|
|
||||||
)
|
|
||||||
|
|
||||||
def f(txn, stream_ordering):
|
|
||||||
txn.execute(sql, (stream_ordering, True, room_id, user_id))
|
|
||||||
|
|
||||||
with self._stream_id_gen.get_next() as stream_ordering:
|
|
||||||
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
|
|
||||||
|
|
||||||
def forget(self, user_id, room_id):
|
|
||||||
"""Indicate that user_id wishes to discard history for room_id."""
|
|
||||||
|
|
||||||
def f(txn):
|
|
||||||
sql = (
|
|
||||||
"UPDATE"
|
|
||||||
" room_memberships"
|
|
||||||
" SET"
|
|
||||||
" forgotten = 1"
|
|
||||||
" WHERE"
|
|
||||||
" user_id = ?"
|
|
||||||
" AND"
|
|
||||||
" room_id = ?"
|
|
||||||
)
|
|
||||||
txn.execute(sql, (user_id, room_id))
|
|
||||||
|
|
||||||
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
|
|
||||||
self._invalidate_cache_and_stream(
|
|
||||||
txn, self.get_forgotten_rooms_for_user, (user_id,)
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.runInteraction("forget_membership", f)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _background_add_membership_profile(self, progress, batch_size):
|
def _background_add_membership_profile(self, progress, batch_size):
|
||||||
target_min_stream_id = progress.get(
|
target_min_stream_id = progress.get(
|
||||||
|
@ -1078,6 +973,117 @@ class RoomMemberStore(RoomMemberWorkerStore):
|
||||||
return row_count
|
return row_count
|
||||||
|
|
||||||
|
|
||||||
|
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(RoomMemberStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
def _store_room_members_txn(self, txn, events, backfilled):
|
||||||
|
"""Store a room member in the database.
|
||||||
|
"""
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="room_memberships",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"user_id": event.state_key,
|
||||||
|
"sender": event.user_id,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"membership": event.membership,
|
||||||
|
"display_name": event.content.get("displayname", None),
|
||||||
|
"avatar_url": event.content.get("avatar_url", None),
|
||||||
|
}
|
||||||
|
for event in events
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
txn.call_after(
|
||||||
|
self._membership_stream_cache.entity_has_changed,
|
||||||
|
event.state_key,
|
||||||
|
event.internal_metadata.stream_ordering,
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# We update the local_invites table only if the event is "current",
|
||||||
|
# i.e., its something that has just happened. If the event is an
|
||||||
|
# outlier it is only current if its an "out of band membership",
|
||||||
|
# like a remote invite or a rejection of a remote invite.
|
||||||
|
is_new_state = not backfilled and (
|
||||||
|
not event.internal_metadata.is_outlier()
|
||||||
|
or event.internal_metadata.is_out_of_band_membership()
|
||||||
|
)
|
||||||
|
is_mine = self.hs.is_mine_id(event.state_key)
|
||||||
|
if is_new_state and is_mine:
|
||||||
|
if event.membership == Membership.INVITE:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="local_invites",
|
||||||
|
values={
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"invitee": event.state_key,
|
||||||
|
"inviter": event.sender,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"stream_id": event.internal_metadata.stream_ordering,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sql = (
|
||||||
|
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
|
||||||
|
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||||
|
" AND replaced_by is NULL"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
event.internal_metadata.stream_ordering,
|
||||||
|
event.event_id,
|
||||||
|
event.room_id,
|
||||||
|
event.state_key,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def locally_reject_invite(self, user_id, room_id):
|
||||||
|
sql = (
|
||||||
|
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
|
||||||
|
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||||
|
" AND replaced_by is NULL"
|
||||||
|
)
|
||||||
|
|
||||||
|
def f(txn, stream_ordering):
|
||||||
|
txn.execute(sql, (stream_ordering, True, room_id, user_id))
|
||||||
|
|
||||||
|
with self._stream_id_gen.get_next() as stream_ordering:
|
||||||
|
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
|
||||||
|
|
||||||
|
def forget(self, user_id, room_id):
|
||||||
|
"""Indicate that user_id wishes to discard history for room_id."""
|
||||||
|
|
||||||
|
def f(txn):
|
||||||
|
sql = (
|
||||||
|
"UPDATE"
|
||||||
|
" room_memberships"
|
||||||
|
" SET"
|
||||||
|
" forgotten = 1"
|
||||||
|
" WHERE"
|
||||||
|
" user_id = ?"
|
||||||
|
" AND"
|
||||||
|
" room_id = ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, room_id))
|
||||||
|
|
||||||
|
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_forgotten_rooms_for_user, (user_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.runInteraction("forget_membership", f)
|
||||||
|
|
||||||
|
|
||||||
class _JoinedHostsCache(object):
|
class _JoinedHostsCache(object):
|
||||||
"""Cache for joined hosts in a room that is optimised to handle updates
|
"""Cache for joined hosts in a room that is optimised to handle updates
|
||||||
via state deltas.
|
via state deltas.
|
||||||
|
|
|
@ -36,7 +36,7 @@ SearchEntry = namedtuple(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SearchStore(BackgroundUpdateStore):
|
class SearchBackgroundUpdateStore(BackgroundUpdateStore):
|
||||||
|
|
||||||
EVENT_SEARCH_UPDATE_NAME = "event_search"
|
EVENT_SEARCH_UPDATE_NAME = "event_search"
|
||||||
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
|
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
|
||||||
|
@ -44,7 +44,7 @@ class SearchStore(BackgroundUpdateStore):
|
||||||
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
|
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SearchStore, self).__init__(db_conn, hs)
|
super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
if not hs.config.enable_search:
|
if not hs.config.enable_search:
|
||||||
return
|
return
|
||||||
|
@ -289,29 +289,6 @@ class SearchStore(BackgroundUpdateStore):
|
||||||
|
|
||||||
return num_rows
|
return num_rows
|
||||||
|
|
||||||
def store_event_search_txn(self, txn, event, key, value):
|
|
||||||
"""Add event to the search table
|
|
||||||
|
|
||||||
Args:
|
|
||||||
txn (cursor):
|
|
||||||
event (EventBase):
|
|
||||||
key (str):
|
|
||||||
value (str):
|
|
||||||
"""
|
|
||||||
self.store_search_entries_txn(
|
|
||||||
txn,
|
|
||||||
(
|
|
||||||
SearchEntry(
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
event_id=event.event_id,
|
|
||||||
room_id=event.room_id,
|
|
||||||
stream_ordering=event.internal_metadata.stream_ordering,
|
|
||||||
origin_server_ts=event.origin_server_ts,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def store_search_entries_txn(self, txn, entries):
|
def store_search_entries_txn(self, txn, entries):
|
||||||
"""Add entries to the search table
|
"""Add entries to the search table
|
||||||
|
|
||||||
|
@ -358,6 +335,34 @@ class SearchStore(BackgroundUpdateStore):
|
||||||
# This should be unreachable.
|
# This should be unreachable.
|
||||||
raise Exception("Unrecognized database engine")
|
raise Exception("Unrecognized database engine")
|
||||||
|
|
||||||
|
|
||||||
|
class SearchStore(SearchBackgroundUpdateStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(SearchStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
def store_event_search_txn(self, txn, event, key, value):
|
||||||
|
"""Add event to the search table
|
||||||
|
|
||||||
|
Args:
|
||||||
|
txn (cursor):
|
||||||
|
event (EventBase):
|
||||||
|
key (str):
|
||||||
|
value (str):
|
||||||
|
"""
|
||||||
|
self.store_search_entries_txn(
|
||||||
|
txn,
|
||||||
|
(
|
||||||
|
SearchEntry(
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
event_id=event.event_id,
|
||||||
|
room_id=event.room_id,
|
||||||
|
stream_ordering=event.internal_metadata.stream_ordering,
|
||||||
|
origin_server_ts=event.origin_server_ts,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def search_msgs(self, room_ids, search_term, keys):
|
def search_msgs(self, room_ids, search_term, keys):
|
||||||
"""Performs a full text search over events with given keys.
|
"""Performs a full text search over events with given keys.
|
||||||
|
|
|
@ -353,8 +353,158 @@ class StateFilter(object):
|
||||||
return member_filter, non_member_filter
|
return member_filter, non_member_filter
|
||||||
|
|
||||||
|
|
||||||
|
class StateGroupBackgroundUpdateStore(SQLBaseStore):
|
||||||
|
"""Defines functions related to state groups needed to run the state backgroud
|
||||||
|
updates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _count_state_group_hops_txn(self, txn, state_group):
|
||||||
|
"""Given a state group, count how many hops there are in the tree.
|
||||||
|
|
||||||
|
This is used to ensure the delta chains don't get too long.
|
||||||
|
"""
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
sql = """
|
||||||
|
WITH RECURSIVE state(state_group) AS (
|
||||||
|
VALUES(?::bigint)
|
||||||
|
UNION ALL
|
||||||
|
SELECT prev_state_group FROM state_group_edges e, state s
|
||||||
|
WHERE s.state_group = e.state_group
|
||||||
|
)
|
||||||
|
SELECT count(*) FROM state;
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (state_group,))
|
||||||
|
row = txn.fetchone()
|
||||||
|
if row and row[0]:
|
||||||
|
return row[0]
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||||
|
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||||
|
next_group = state_group
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
while next_group:
|
||||||
|
next_group = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="state_group_edges",
|
||||||
|
keyvalues={"state_group": next_group},
|
||||||
|
retcol="prev_state_group",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if next_group:
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _get_state_groups_from_groups_txn(
|
||||||
|
self, txn, groups, state_filter=StateFilter.all()
|
||||||
|
):
|
||||||
|
results = {group: {} for group in groups}
|
||||||
|
|
||||||
|
where_clause, where_args = state_filter.make_sql_filter_clause()
|
||||||
|
|
||||||
|
# Unless the filter clause is empty, we're going to append it after an
|
||||||
|
# existing where clause
|
||||||
|
if where_clause:
|
||||||
|
where_clause = " AND (%s)" % (where_clause,)
|
||||||
|
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
# Temporarily disable sequential scans in this transaction. This is
|
||||||
|
# a temporary hack until we can add the right indices in
|
||||||
|
txn.execute("SET LOCAL enable_seqscan=off")
|
||||||
|
|
||||||
|
# The below query walks the state_group tree so that the "state"
|
||||||
|
# table includes all state_groups in the tree. It then joins
|
||||||
|
# against `state_groups_state` to fetch the latest state.
|
||||||
|
# It assumes that previous state groups are always numerically
|
||||||
|
# lesser.
|
||||||
|
# The PARTITION is used to get the event_id in the greatest state
|
||||||
|
# group for the given type, state_key.
|
||||||
|
# This may return multiple rows per (type, state_key), but last_value
|
||||||
|
# should be the same.
|
||||||
|
sql = """
|
||||||
|
WITH RECURSIVE state(state_group) AS (
|
||||||
|
VALUES(?::bigint)
|
||||||
|
UNION ALL
|
||||||
|
SELECT prev_state_group FROM state_group_edges e, state s
|
||||||
|
WHERE s.state_group = e.state_group
|
||||||
|
)
|
||||||
|
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
|
||||||
|
PARTITION BY type, state_key ORDER BY state_group ASC
|
||||||
|
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
|
||||||
|
) AS event_id FROM state_groups_state
|
||||||
|
WHERE state_group IN (
|
||||||
|
SELECT state_group FROM state
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
for group in groups:
|
||||||
|
args = [group]
|
||||||
|
args.extend(where_args)
|
||||||
|
|
||||||
|
txn.execute(sql + where_clause, args)
|
||||||
|
for row in txn:
|
||||||
|
typ, state_key, event_id = row
|
||||||
|
key = (typ, state_key)
|
||||||
|
results[group][key] = event_id
|
||||||
|
else:
|
||||||
|
max_entries_returned = state_filter.max_entries_returned()
|
||||||
|
|
||||||
|
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||||
|
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||||
|
for group in groups:
|
||||||
|
next_group = group
|
||||||
|
|
||||||
|
while next_group:
|
||||||
|
# We did this before by getting the list of group ids, and
|
||||||
|
# then passing that list to sqlite to get latest event for
|
||||||
|
# each (type, state_key). However, that was terribly slow
|
||||||
|
# without the right indices (which we can't add until
|
||||||
|
# after we finish deduping state, which requires this func)
|
||||||
|
args = [next_group]
|
||||||
|
args.extend(where_args)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
"SELECT type, state_key, event_id FROM state_groups_state"
|
||||||
|
" WHERE state_group = ? " + where_clause,
|
||||||
|
args,
|
||||||
|
)
|
||||||
|
results[group].update(
|
||||||
|
((typ, state_key), event_id)
|
||||||
|
for typ, state_key, event_id in txn
|
||||||
|
if (typ, state_key) not in results[group]
|
||||||
|
)
|
||||||
|
|
||||||
|
# If the number of entries in the (type,state_key)->event_id dict
|
||||||
|
# matches the number of (type,state_keys) types we were searching
|
||||||
|
# for, then we must have found them all, so no need to go walk
|
||||||
|
# further down the tree... UNLESS our types filter contained
|
||||||
|
# wildcards (i.e. Nones) in which case we have to do an exhaustive
|
||||||
|
# search
|
||||||
|
if (
|
||||||
|
max_entries_returned is not None
|
||||||
|
and len(results[group]) == max_entries_returned
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
|
next_group = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="state_group_edges",
|
||||||
|
keyvalues={"state_group": next_group},
|
||||||
|
retcol="prev_state_group",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
# this inherits from EventsWorkerStore because it calls self.get_events
|
# this inherits from EventsWorkerStore because it calls self.get_events
|
||||||
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
class StateGroupWorkerStore(
|
||||||
|
EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
|
||||||
|
):
|
||||||
"""The parts of StateGroupStore that can be called from workers.
|
"""The parts of StateGroupStore that can be called from workers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -694,107 +844,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _get_state_groups_from_groups_txn(
|
|
||||||
self, txn, groups, state_filter=StateFilter.all()
|
|
||||||
):
|
|
||||||
results = {group: {} for group in groups}
|
|
||||||
|
|
||||||
where_clause, where_args = state_filter.make_sql_filter_clause()
|
|
||||||
|
|
||||||
# Unless the filter clause is empty, we're going to append it after an
|
|
||||||
# existing where clause
|
|
||||||
if where_clause:
|
|
||||||
where_clause = " AND (%s)" % (where_clause,)
|
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
|
||||||
# Temporarily disable sequential scans in this transaction. This is
|
|
||||||
# a temporary hack until we can add the right indices in
|
|
||||||
txn.execute("SET LOCAL enable_seqscan=off")
|
|
||||||
|
|
||||||
# The below query walks the state_group tree so that the "state"
|
|
||||||
# table includes all state_groups in the tree. It then joins
|
|
||||||
# against `state_groups_state` to fetch the latest state.
|
|
||||||
# It assumes that previous state groups are always numerically
|
|
||||||
# lesser.
|
|
||||||
# The PARTITION is used to get the event_id in the greatest state
|
|
||||||
# group for the given type, state_key.
|
|
||||||
# This may return multiple rows per (type, state_key), but last_value
|
|
||||||
# should be the same.
|
|
||||||
sql = """
|
|
||||||
WITH RECURSIVE state(state_group) AS (
|
|
||||||
VALUES(?::bigint)
|
|
||||||
UNION ALL
|
|
||||||
SELECT prev_state_group FROM state_group_edges e, state s
|
|
||||||
WHERE s.state_group = e.state_group
|
|
||||||
)
|
|
||||||
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
|
|
||||||
PARTITION BY type, state_key ORDER BY state_group ASC
|
|
||||||
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
|
|
||||||
) AS event_id FROM state_groups_state
|
|
||||||
WHERE state_group IN (
|
|
||||||
SELECT state_group FROM state
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
for group in groups:
|
|
||||||
args = [group]
|
|
||||||
args.extend(where_args)
|
|
||||||
|
|
||||||
txn.execute(sql + where_clause, args)
|
|
||||||
for row in txn:
|
|
||||||
typ, state_key, event_id = row
|
|
||||||
key = (typ, state_key)
|
|
||||||
results[group][key] = event_id
|
|
||||||
else:
|
|
||||||
max_entries_returned = state_filter.max_entries_returned()
|
|
||||||
|
|
||||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
|
||||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
|
||||||
for group in groups:
|
|
||||||
next_group = group
|
|
||||||
|
|
||||||
while next_group:
|
|
||||||
# We did this before by getting the list of group ids, and
|
|
||||||
# then passing that list to sqlite to get latest event for
|
|
||||||
# each (type, state_key). However, that was terribly slow
|
|
||||||
# without the right indices (which we can't add until
|
|
||||||
# after we finish deduping state, which requires this func)
|
|
||||||
args = [next_group]
|
|
||||||
args.extend(where_args)
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
"SELECT type, state_key, event_id FROM state_groups_state"
|
|
||||||
" WHERE state_group = ? " + where_clause,
|
|
||||||
args,
|
|
||||||
)
|
|
||||||
results[group].update(
|
|
||||||
((typ, state_key), event_id)
|
|
||||||
for typ, state_key, event_id in txn
|
|
||||||
if (typ, state_key) not in results[group]
|
|
||||||
)
|
|
||||||
|
|
||||||
# If the number of entries in the (type,state_key)->event_id dict
|
|
||||||
# matches the number of (type,state_keys) types we were searching
|
|
||||||
# for, then we must have found them all, so no need to go walk
|
|
||||||
# further down the tree... UNLESS our types filter contained
|
|
||||||
# wildcards (i.e. Nones) in which case we have to do an exhaustive
|
|
||||||
# search
|
|
||||||
if (
|
|
||||||
max_entries_returned is not None
|
|
||||||
and len(results[group]) == max_entries_returned
|
|
||||||
):
|
|
||||||
break
|
|
||||||
|
|
||||||
next_group = self._simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="state_group_edges",
|
|
||||||
keyvalues={"state_group": next_group},
|
|
||||||
retcol="prev_state_group",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
|
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||||
"""Given a list of event_ids and type tuples, return a list of state
|
"""Given a list of event_ids and type tuples, return a list of state
|
||||||
|
@ -1238,66 +1287,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
return self.runInteraction("store_state_group", _store_state_group_txn)
|
return self.runInteraction("store_state_group", _store_state_group_txn)
|
||||||
|
|
||||||
def _count_state_group_hops_txn(self, txn, state_group):
|
|
||||||
"""Given a state group, count how many hops there are in the tree.
|
|
||||||
|
|
||||||
This is used to ensure the delta chains don't get too long.
|
class StateBackgroundUpdateStore(
|
||||||
"""
|
StateGroupBackgroundUpdateStore, BackgroundUpdateStore
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
):
|
||||||
sql = """
|
|
||||||
WITH RECURSIVE state(state_group) AS (
|
|
||||||
VALUES(?::bigint)
|
|
||||||
UNION ALL
|
|
||||||
SELECT prev_state_group FROM state_group_edges e, state s
|
|
||||||
WHERE s.state_group = e.state_group
|
|
||||||
)
|
|
||||||
SELECT count(*) FROM state;
|
|
||||||
"""
|
|
||||||
|
|
||||||
txn.execute(sql, (state_group,))
|
|
||||||
row = txn.fetchone()
|
|
||||||
if row and row[0]:
|
|
||||||
return row[0]
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
|
||||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
|
||||||
next_group = state_group
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
while next_group:
|
|
||||||
next_group = self._simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="state_group_edges",
|
|
||||||
keyvalues={"state_group": next_group},
|
|
||||||
retcol="prev_state_group",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
if next_group:
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
||||||
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
|
||||||
""" Keeps track of the state at a given event.
|
|
||||||
|
|
||||||
This is done by the concept of `state groups`. Every event is a assigned
|
|
||||||
a state group (identified by an arbitrary string), which references a
|
|
||||||
collection of state events. The current state of an event is then the
|
|
||||||
collection of state events referenced by the event's state group.
|
|
||||||
|
|
||||||
Hence, every change in the current state causes a new state group to be
|
|
||||||
generated. However, if no change happens (e.g., if we get a message event
|
|
||||||
with only one parent it inherits the state group from its parent.)
|
|
||||||
|
|
||||||
There are three tables:
|
|
||||||
* `state_groups`: Stores group name, first event with in the group and
|
|
||||||
room id.
|
|
||||||
* `event_to_state_groups`: Maps events to state groups.
|
|
||||||
* `state_groups_state`: Maps state group to state events.
|
|
||||||
"""
|
|
||||||
|
|
||||||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
||||||
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
|
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
|
||||||
|
@ -1305,7 +1298,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||||
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
|
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(StateStore, self).__init__(db_conn, hs)
|
super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
self.register_background_update_handler(
|
self.register_background_update_handler(
|
||||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
||||||
self._background_deduplicate_state,
|
self._background_deduplicate_state,
|
||||||
|
@ -1327,34 +1320,6 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||||
columns=["state_group"],
|
columns=["state_group"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
|
|
||||||
state_groups = {}
|
|
||||||
for event, context in events_and_contexts:
|
|
||||||
if event.internal_metadata.is_outlier():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# if the event was rejected, just give it the same state as its
|
|
||||||
# predecessor.
|
|
||||||
if context.rejected:
|
|
||||||
state_groups[event.event_id] = context.prev_group
|
|
||||||
continue
|
|
||||||
|
|
||||||
state_groups[event.event_id] = context.state_group
|
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="event_to_state_groups",
|
|
||||||
values=[
|
|
||||||
{"state_group": state_group_id, "event_id": event_id}
|
|
||||||
for event_id, state_group_id in iteritems(state_groups)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
for event_id, state_group_id in iteritems(state_groups):
|
|
||||||
txn.call_after(
|
|
||||||
self._get_state_group_for_event.prefill, (event_id,), state_group_id
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _background_deduplicate_state(self, progress, batch_size):
|
def _background_deduplicate_state(self, progress, batch_size):
|
||||||
"""This background update will slowly deduplicate state by reencoding
|
"""This background update will slowly deduplicate state by reencoding
|
||||||
|
@ -1527,3 +1492,54 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||||
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
|
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
|
||||||
|
""" Keeps track of the state at a given event.
|
||||||
|
|
||||||
|
This is done by the concept of `state groups`. Every event is a assigned
|
||||||
|
a state group (identified by an arbitrary string), which references a
|
||||||
|
collection of state events. The current state of an event is then the
|
||||||
|
collection of state events referenced by the event's state group.
|
||||||
|
|
||||||
|
Hence, every change in the current state causes a new state group to be
|
||||||
|
generated. However, if no change happens (e.g., if we get a message event
|
||||||
|
with only one parent it inherits the state group from its parent.)
|
||||||
|
|
||||||
|
There are three tables:
|
||||||
|
* `state_groups`: Stores group name, first event with in the group and
|
||||||
|
room id.
|
||||||
|
* `event_to_state_groups`: Maps events to state groups.
|
||||||
|
* `state_groups_state`: Maps state group to state events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(StateStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
|
||||||
|
state_groups = {}
|
||||||
|
for event, context in events_and_contexts:
|
||||||
|
if event.internal_metadata.is_outlier():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# if the event was rejected, just give it the same state as its
|
||||||
|
# predecessor.
|
||||||
|
if context.rejected:
|
||||||
|
state_groups[event.event_id] = context.prev_group
|
||||||
|
continue
|
||||||
|
|
||||||
|
state_groups[event.event_id] = context.state_group
|
||||||
|
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="event_to_state_groups",
|
||||||
|
values=[
|
||||||
|
{"state_group": state_group_id, "event_id": event_id}
|
||||||
|
for event_id, state_group_id in iteritems(state_groups)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
for event_id, state_group_id in iteritems(state_groups):
|
||||||
|
txn.call_after(
|
||||||
|
self._get_state_group_for_event.prefill, (event_id,), state_group_id
|
||||||
|
)
|
||||||
|
|
|
@ -32,14 +32,14 @@ logger = logging.getLogger(__name__)
|
||||||
TEMP_TABLE = "_temp_populate_user_directory"
|
TEMP_TABLE = "_temp_populate_user_directory"
|
||||||
|
|
||||||
|
|
||||||
class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
|
|
||||||
# How many records do we calculate before sending it to
|
# How many records do we calculate before sending it to
|
||||||
# add_users_who_share_private_rooms?
|
# add_users_who_share_private_rooms?
|
||||||
SHARE_PRIVATE_WORKING_SET = 500
|
SHARE_PRIVATE_WORKING_SET = 500
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(UserDirectoryStore, self).__init__(db_conn, hs)
|
super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
|
@ -452,55 +452,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
|
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def remove_from_user_dir(self, user_id):
|
|
||||||
def _remove_from_user_dir_txn(txn):
|
|
||||||
self._simple_delete_txn(
|
|
||||||
txn, table="user_directory", keyvalues={"user_id": user_id}
|
|
||||||
)
|
|
||||||
self._simple_delete_txn(
|
|
||||||
txn, table="user_directory_search", keyvalues={"user_id": user_id}
|
|
||||||
)
|
|
||||||
self._simple_delete_txn(
|
|
||||||
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
|
|
||||||
)
|
|
||||||
self._simple_delete_txn(
|
|
||||||
txn,
|
|
||||||
table="users_who_share_private_rooms",
|
|
||||||
keyvalues={"user_id": user_id},
|
|
||||||
)
|
|
||||||
self._simple_delete_txn(
|
|
||||||
txn,
|
|
||||||
table="users_who_share_private_rooms",
|
|
||||||
keyvalues={"other_user_id": user_id},
|
|
||||||
)
|
|
||||||
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
|
||||||
|
|
||||||
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_users_in_dir_due_to_room(self, room_id):
|
|
||||||
"""Get all user_ids that are in the room directory because they're
|
|
||||||
in the given room_id
|
|
||||||
"""
|
|
||||||
user_ids_share_pub = yield self._simple_select_onecol(
|
|
||||||
table="users_in_public_rooms",
|
|
||||||
keyvalues={"room_id": room_id},
|
|
||||||
retcol="user_id",
|
|
||||||
desc="get_users_in_dir_due_to_room",
|
|
||||||
)
|
|
||||||
|
|
||||||
user_ids_share_priv = yield self._simple_select_onecol(
|
|
||||||
table="users_who_share_private_rooms",
|
|
||||||
keyvalues={"room_id": room_id},
|
|
||||||
retcol="other_user_id",
|
|
||||||
desc="get_users_in_dir_due_to_room",
|
|
||||||
)
|
|
||||||
|
|
||||||
user_ids = set(user_ids_share_pub)
|
|
||||||
user_ids.update(user_ids_share_priv)
|
|
||||||
|
|
||||||
return user_ids
|
|
||||||
|
|
||||||
def add_users_who_share_private_room(self, room_id, user_id_tuples):
|
def add_users_who_share_private_room(self, room_id, user_id_tuples):
|
||||||
"""Insert entries into the users_who_share_private_rooms table. The first
|
"""Insert entries into the users_who_share_private_rooms table. The first
|
||||||
user should be a local user.
|
user should be a local user.
|
||||||
|
@ -551,6 +502,98 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
|
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def delete_all_from_user_dir(self):
|
||||||
|
"""Delete the entire user directory
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _delete_all_from_user_dir_txn(txn):
|
||||||
|
txn.execute("DELETE FROM user_directory")
|
||||||
|
txn.execute("DELETE FROM user_directory_search")
|
||||||
|
txn.execute("DELETE FROM users_in_public_rooms")
|
||||||
|
txn.execute("DELETE FROM users_who_share_private_rooms")
|
||||||
|
txn.call_after(self.get_user_in_directory.invalidate_all)
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def get_user_in_directory(self, user_id):
|
||||||
|
return self._simple_select_one(
|
||||||
|
table="user_directory",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcols=("display_name", "avatar_url"),
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_user_in_directory",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_user_directory_stream_pos(self, stream_id):
|
||||||
|
return self._simple_update_one(
|
||||||
|
table="user_directory_stream_pos",
|
||||||
|
keyvalues={},
|
||||||
|
updatevalues={"stream_id": stream_id},
|
||||||
|
desc="update_user_directory_stream_pos",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
|
|
||||||
|
# How many records do we calculate before sending it to
|
||||||
|
# add_users_who_share_private_rooms?
|
||||||
|
SHARE_PRIVATE_WORKING_SET = 500
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(UserDirectoryStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
def remove_from_user_dir(self, user_id):
|
||||||
|
def _remove_from_user_dir_txn(txn):
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn, table="user_directory", keyvalues={"user_id": user_id}
|
||||||
|
)
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn, table="user_directory_search", keyvalues={"user_id": user_id}
|
||||||
|
)
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
|
||||||
|
)
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="users_who_share_private_rooms",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
)
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="users_who_share_private_rooms",
|
||||||
|
keyvalues={"other_user_id": user_id},
|
||||||
|
)
|
||||||
|
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||||
|
|
||||||
|
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_users_in_dir_due_to_room(self, room_id):
|
||||||
|
"""Get all user_ids that are in the room directory because they're
|
||||||
|
in the given room_id
|
||||||
|
"""
|
||||||
|
user_ids_share_pub = yield self._simple_select_onecol(
|
||||||
|
table="users_in_public_rooms",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcol="user_id",
|
||||||
|
desc="get_users_in_dir_due_to_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_ids_share_priv = yield self._simple_select_onecol(
|
||||||
|
table="users_who_share_private_rooms",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcol="other_user_id",
|
||||||
|
desc="get_users_in_dir_due_to_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_ids = set(user_ids_share_pub)
|
||||||
|
user_ids.update(user_ids_share_priv)
|
||||||
|
|
||||||
|
return user_ids
|
||||||
|
|
||||||
def remove_user_who_share_room(self, user_id, room_id):
|
def remove_user_who_share_room(self, user_id, room_id):
|
||||||
"""
|
"""
|
||||||
Deletes entries in the users_who_share_*_rooms table. The first
|
Deletes entries in the users_who_share_*_rooms table. The first
|
||||||
|
@ -637,31 +680,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
|
|
||||||
return [room_id for room_id, in rows]
|
return [room_id for room_id, in rows]
|
||||||
|
|
||||||
def delete_all_from_user_dir(self):
|
|
||||||
"""Delete the entire user directory
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _delete_all_from_user_dir_txn(txn):
|
|
||||||
txn.execute("DELETE FROM user_directory")
|
|
||||||
txn.execute("DELETE FROM user_directory_search")
|
|
||||||
txn.execute("DELETE FROM users_in_public_rooms")
|
|
||||||
txn.execute("DELETE FROM users_who_share_private_rooms")
|
|
||||||
txn.call_after(self.get_user_in_directory.invalidate_all)
|
|
||||||
|
|
||||||
return self.runInteraction(
|
|
||||||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached()
|
|
||||||
def get_user_in_directory(self, user_id):
|
|
||||||
return self._simple_select_one(
|
|
||||||
table="user_directory",
|
|
||||||
keyvalues={"user_id": user_id},
|
|
||||||
retcols=("display_name", "avatar_url"),
|
|
||||||
allow_none=True,
|
|
||||||
desc="get_user_in_directory",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_user_directory_stream_pos(self):
|
def get_user_directory_stream_pos(self):
|
||||||
return self._simple_select_one_onecol(
|
return self._simple_select_one_onecol(
|
||||||
table="user_directory_stream_pos",
|
table="user_directory_stream_pos",
|
||||||
|
@ -670,14 +688,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
desc="get_user_directory_stream_pos",
|
desc="get_user_directory_stream_pos",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_user_directory_stream_pos(self, stream_id):
|
|
||||||
return self._simple_update_one(
|
|
||||||
table="user_directory_stream_pos",
|
|
||||||
keyvalues={},
|
|
||||||
updatevalues={"stream_id": stream_id},
|
|
||||||
desc="update_user_directory_stream_pos",
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def search_user_dir(self, user_id, search_term, limit):
|
def search_user_dir(self, user_id, search_term, limit):
|
||||||
"""Searches for users in directory
|
"""Searches for users in directory
|
||||||
|
|
Loading…
Reference in a new issue