mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 05:53:51 +01:00
Merge pull request #6178 from matrix-org/babolivier/factor_out_bg_updates
Factor out backgroung updates
This commit is contained in:
commit
59d6290ed9
10 changed files with 671 additions and 605 deletions
1
changelog.d/6178.bugfix
Normal file
1
changelog.d/6178.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Make the `synapse_port_db` script create the right indexes on a new PostgreSQL database.
|
|
@ -33,16 +33,9 @@ logger = logging.getLogger(__name__)
|
|||
LAST_SEEN_GRANULARITY = 120 * 1000
|
||||
|
||||
|
||||
class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
|
||||
)
|
||||
|
||||
super(ClientIpStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.user_ips_max_age = hs.config.user_ips_max_age
|
||||
super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_index_update(
|
||||
"user_ips_device_index",
|
||||
|
@ -92,19 +85,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
"devices_last_seen", self._devices_last_seen_update
|
||||
)
|
||||
|
||||
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
|
||||
self._batch_row_update = {}
|
||||
|
||||
self._client_ip_looper = self._clock.looping_call(
|
||||
self._update_client_ips_batch, 5 * 1000
|
||||
)
|
||||
self.hs.get_reactor().addSystemEventTrigger(
|
||||
"before", "shutdown", self._update_client_ips_batch
|
||||
)
|
||||
|
||||
if self.user_ips_max_age:
|
||||
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _remove_user_ip_nonunique(self, progress, batch_size):
|
||||
def f(conn):
|
||||
|
@ -303,6 +283,110 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
|
||||
return batch_size
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _devices_last_seen_update(self, progress, batch_size):
|
||||
"""Background update to insert last seen info into devices table
|
||||
"""
|
||||
|
||||
last_user_id = progress.get("last_user_id", "")
|
||||
last_device_id = progress.get("last_device_id", "")
|
||||
|
||||
def _devices_last_seen_update_txn(txn):
|
||||
# This consists of two queries:
|
||||
#
|
||||
# 1. The sub-query searches for the next N devices and joins
|
||||
# against user_ips to find the max last_seen associated with
|
||||
# that device.
|
||||
# 2. The outer query then joins again against user_ips on
|
||||
# user/device/last_seen. This *should* hopefully only
|
||||
# return one row, but if it does return more than one then
|
||||
# we'll just end up updating the same device row multiple
|
||||
# times, which is fine.
|
||||
|
||||
if self.database_engine.supports_tuple_comparison:
|
||||
where_clause = "(user_id, device_id) > (?, ?)"
|
||||
where_args = [last_user_id, last_device_id]
|
||||
else:
|
||||
# We explicitly do a `user_id >= ? AND (...)` here to ensure
|
||||
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
|
||||
# makes it hard for query optimiser to tell that it can use the
|
||||
# index on user_id
|
||||
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
|
||||
where_args = [last_user_id, last_user_id, last_device_id]
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
last_seen, ip, user_agent, user_id, device_id
|
||||
FROM (
|
||||
SELECT
|
||||
user_id, device_id, MAX(u.last_seen) AS last_seen
|
||||
FROM devices
|
||||
INNER JOIN user_ips AS u USING (user_id, device_id)
|
||||
WHERE %(where_clause)s
|
||||
GROUP BY user_id, device_id
|
||||
ORDER BY user_id ASC, device_id ASC
|
||||
LIMIT ?
|
||||
) c
|
||||
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
|
||||
""" % {
|
||||
"where_clause": where_clause
|
||||
}
|
||||
txn.execute(sql, where_args + [batch_size])
|
||||
|
||||
rows = txn.fetchall()
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
sql = """
|
||||
UPDATE devices
|
||||
SET last_seen = ?, ip = ?, user_agent = ?
|
||||
WHERE user_id = ? AND device_id = ?
|
||||
"""
|
||||
txn.execute_batch(sql, rows)
|
||||
|
||||
_, _, _, user_id, device_id = rows[-1]
|
||||
self._background_update_progress_txn(
|
||||
txn,
|
||||
"devices_last_seen",
|
||||
{"last_user_id": user_id, "last_device_id": device_id},
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
||||
updated = yield self.runInteraction(
|
||||
"_devices_last_seen_update", _devices_last_seen_update_txn
|
||||
)
|
||||
|
||||
if not updated:
|
||||
yield self._end_background_update("devices_last_seen")
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
class ClientIpStore(ClientIpBackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
|
||||
)
|
||||
|
||||
super(ClientIpStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.user_ips_max_age = hs.config.user_ips_max_age
|
||||
|
||||
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
|
||||
self._batch_row_update = {}
|
||||
|
||||
self._client_ip_looper = self._clock.looping_call(
|
||||
self._update_client_ips_batch, 5 * 1000
|
||||
)
|
||||
self.hs.get_reactor().addSystemEventTrigger(
|
||||
"before", "shutdown", self._update_client_ips_batch
|
||||
)
|
||||
|
||||
if self.user_ips_max_age:
|
||||
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_client_ip(
|
||||
self, user_id, access_token, ip, user_agent, device_id, now=None
|
||||
|
@ -454,85 +538,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _devices_last_seen_update(self, progress, batch_size):
|
||||
"""Background update to insert last seen info into devices table
|
||||
"""
|
||||
|
||||
last_user_id = progress.get("last_user_id", "")
|
||||
last_device_id = progress.get("last_device_id", "")
|
||||
|
||||
def _devices_last_seen_update_txn(txn):
|
||||
# This consists of two queries:
|
||||
#
|
||||
# 1. The sub-query searches for the next N devices and joins
|
||||
# against user_ips to find the max last_seen associated with
|
||||
# that device.
|
||||
# 2. The outer query then joins again against user_ips on
|
||||
# user/device/last_seen. This *should* hopefully only
|
||||
# return one row, but if it does return more than one then
|
||||
# we'll just end up updating the same device row multiple
|
||||
# times, which is fine.
|
||||
|
||||
if self.database_engine.supports_tuple_comparison:
|
||||
where_clause = "(user_id, device_id) > (?, ?)"
|
||||
where_args = [last_user_id, last_device_id]
|
||||
else:
|
||||
# We explicitly do a `user_id >= ? AND (...)` here to ensure
|
||||
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
|
||||
# makes it hard for query optimiser to tell that it can use the
|
||||
# index on user_id
|
||||
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
|
||||
where_args = [last_user_id, last_user_id, last_device_id]
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
last_seen, ip, user_agent, user_id, device_id
|
||||
FROM (
|
||||
SELECT
|
||||
user_id, device_id, MAX(u.last_seen) AS last_seen
|
||||
FROM devices
|
||||
INNER JOIN user_ips AS u USING (user_id, device_id)
|
||||
WHERE %(where_clause)s
|
||||
GROUP BY user_id, device_id
|
||||
ORDER BY user_id ASC, device_id ASC
|
||||
LIMIT ?
|
||||
) c
|
||||
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
|
||||
""" % {
|
||||
"where_clause": where_clause
|
||||
}
|
||||
txn.execute(sql, where_args + [batch_size])
|
||||
|
||||
rows = txn.fetchall()
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
sql = """
|
||||
UPDATE devices
|
||||
SET last_seen = ?, ip = ?, user_agent = ?
|
||||
WHERE user_id = ? AND device_id = ?
|
||||
"""
|
||||
txn.execute_batch(sql, rows)
|
||||
|
||||
_, _, _, user_id, device_id = rows[-1]
|
||||
self._background_update_progress_txn(
|
||||
txn,
|
||||
"devices_last_seen",
|
||||
{"last_user_id": user_id, "last_device_id": device_id},
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
||||
updated = yield self.runInteraction(
|
||||
"_devices_last_seen_update", _devices_last_seen_update_txn
|
||||
)
|
||||
|
||||
if not updated:
|
||||
yield self._end_background_update("devices_last_seen")
|
||||
|
||||
return updated
|
||||
|
||||
@wrap_as_background_process("prune_old_user_ips")
|
||||
async def _prune_old_user_ips(self):
|
||||
"""Removes entries in user IPs older than the configured period.
|
||||
|
|
|
@ -208,11 +208,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
|
||||
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
|
||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(DeviceInboxStore, self).__init__(db_conn, hs)
|
||||
super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_index_update(
|
||||
"device_inbox_stream_index",
|
||||
|
@ -225,6 +225,26 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
|||
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_drop_index_device_inbox(self, progress, batch_size):
|
||||
def reindex_txn(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
|
||||
txn.close()
|
||||
|
||||
yield self.runWithConnection(reindex_txn)
|
||||
|
||||
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
|
||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(DeviceInboxStore, self).__init__(db_conn, hs)
|
||||
|
||||
# Map of (user_id, device_id) to the last stream_id that has been
|
||||
# deleted up to. This is so that we can no op deletions.
|
||||
self._last_device_delete_cache = ExpiringCache(
|
||||
|
@ -435,16 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
|||
return self.runInteraction(
|
||||
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_drop_index_device_inbox(self, progress, batch_size):
|
||||
def reindex_txn(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
|
||||
txn.close()
|
||||
|
||||
yield self.runWithConnection(reindex_txn)
|
||||
|
||||
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
|
||||
|
||||
return 1
|
||||
|
|
|
@ -512,17 +512,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
return results
|
||||
|
||||
|
||||
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(DeviceStore, self).__init__(db_conn, hs)
|
||||
|
||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||
# the device exists.
|
||||
self.device_id_exists_cache = Cache(
|
||||
name="device_id_exists", keylen=2, max_entries=10000
|
||||
)
|
||||
|
||||
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
|
||||
super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_index_update(
|
||||
"device_lists_stream_idx",
|
||||
|
@ -555,6 +547,31 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
|||
self._drop_device_list_streams_non_unique_indexes,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
||||
def f(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
||||
txn.close()
|
||||
|
||||
yield self.runWithConnection(f)
|
||||
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
|
||||
return 1
|
||||
|
||||
|
||||
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(DeviceStore, self).__init__(db_conn, hs)
|
||||
|
||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||
# the device exists.
|
||||
self.device_id_exists_cache = Cache(
|
||||
name="device_id_exists", keylen=2, max_entries=10000
|
||||
)
|
||||
|
||||
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_device(self, user_id, device_id, initial_device_display_name):
|
||||
"""Ensure the given device is known; add it to the store if not
|
||||
|
@ -910,15 +927,3 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
|||
"_prune_old_outbound_device_pokes",
|
||||
_prune_txn,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
||||
def f(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
||||
txn.close()
|
||||
|
||||
yield self.runWithConnection(f)
|
||||
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
|
||||
return 1
|
||||
|
|
|
@ -15,11 +15,9 @@
|
|||
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||
|
||||
|
||||
class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
"""Persistence for attachments and avatars"""
|
||||
|
||||
class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(MediaRepositoryStore, self).__init__(db_conn, hs)
|
||||
super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_index_update(
|
||||
update_name="local_media_repository_url_idx",
|
||||
|
@ -29,6 +27,13 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
|||
where_clause="url_cache IS NOT NULL",
|
||||
)
|
||||
|
||||
|
||||
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
"""Persistence for attachments and avatars"""
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(MediaRepositoryStore, self).__init__(db_conn, hs)
|
||||
|
||||
def get_local_media(self, media_id):
|
||||
"""Get the metadata for a local piece of media
|
||||
Returns:
|
||||
|
|
|
@ -787,13 +787,14 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
|
||||
class RegistrationStore(
|
||||
class RegistrationBackgroundUpdateStore(
|
||||
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
|
||||
):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(RegistrationStore, self).__init__(db_conn, hs)
|
||||
super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self.config = hs.config
|
||||
|
||||
self.register_background_index_update(
|
||||
"access_tokens_device_index",
|
||||
|
@ -809,8 +810,6 @@ class RegistrationStore(
|
|||
columns=["creation_ts"],
|
||||
)
|
||||
|
||||
self._account_validity = hs.config.account_validity
|
||||
|
||||
# we no longer use refresh tokens, but it's possible that some people
|
||||
# might have a background update queued to build this index. Just
|
||||
# clear the background update.
|
||||
|
@ -824,17 +823,6 @@ class RegistrationStore(
|
|||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
||||
)
|
||||
|
||||
# Create a background job for culling expired 3PID validity tokens
|
||||
def start_cull():
|
||||
# run as a background process to make sure that the database transactions
|
||||
# have a logcontext to report to
|
||||
return run_as_background_process(
|
||||
"cull_expired_threepid_validation_tokens",
|
||||
self.cull_expired_threepid_validation_tokens,
|
||||
)
|
||||
|
||||
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_update_set_deactivated_flag(self, progress, batch_size):
|
||||
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
||||
|
@ -896,6 +884,54 @@ class RegistrationStore(
|
|||
|
||||
return nb_processed
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _bg_user_threepids_grandfather(self, progress, batch_size):
|
||||
"""We now track which identity servers a user binds their 3PID to, so
|
||||
we need to handle the case of existing bindings where we didn't track
|
||||
this.
|
||||
|
||||
We do this by grandfathering in existing user threepids assuming that
|
||||
they used one of the server configured trusted identity servers.
|
||||
"""
|
||||
id_servers = set(self.config.trusted_third_party_id_servers)
|
||||
|
||||
def _bg_user_threepids_grandfather_txn(txn):
|
||||
sql = """
|
||||
INSERT INTO user_threepid_id_server
|
||||
(user_id, medium, address, id_server)
|
||||
SELECT user_id, medium, address, ?
|
||||
FROM user_threepids
|
||||
"""
|
||||
|
||||
txn.executemany(sql, [(id_server,) for id_server in id_servers])
|
||||
|
||||
if id_servers:
|
||||
yield self.runInteraction(
|
||||
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
|
||||
)
|
||||
|
||||
yield self._end_background_update("user_threepids_grandfather")
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(RegistrationStore, self).__init__(db_conn, hs)
|
||||
|
||||
self._account_validity = hs.config.account_validity
|
||||
|
||||
# Create a background job for culling expired 3PID validity tokens
|
||||
def start_cull():
|
||||
# run as a background process to make sure that the database transactions
|
||||
# have a logcontext to report to
|
||||
return run_as_background_process(
|
||||
"cull_expired_threepid_validation_tokens",
|
||||
self.cull_expired_threepid_validation_tokens,
|
||||
)
|
||||
|
||||
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
|
||||
"""Adds an access token for the given user.
|
||||
|
@ -1244,36 +1280,6 @@ class RegistrationStore(
|
|||
desc="get_users_pending_deactivation",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _bg_user_threepids_grandfather(self, progress, batch_size):
|
||||
"""We now track which identity servers a user binds their 3PID to, so
|
||||
we need to handle the case of existing bindings where we didn't track
|
||||
this.
|
||||
|
||||
We do this by grandfathering in existing user threepids assuming that
|
||||
they used one of the server configured trusted identity servers.
|
||||
"""
|
||||
id_servers = set(self.config.trusted_third_party_id_servers)
|
||||
|
||||
def _bg_user_threepids_grandfather_txn(txn):
|
||||
sql = """
|
||||
INSERT INTO user_threepid_id_server
|
||||
(user_id, medium, address, id_server)
|
||||
SELECT user_id, medium, address, ?
|
||||
FROM user_threepids
|
||||
"""
|
||||
|
||||
txn.executemany(sql, [(id_server,) for id_server in id_servers])
|
||||
|
||||
if id_servers:
|
||||
yield self.runInteraction(
|
||||
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
|
||||
)
|
||||
|
||||
yield self._end_background_update("user_threepids_grandfather")
|
||||
|
||||
return 1
|
||||
|
||||
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
|
||||
"""Attempt to validate a threepid session using a token
|
||||
|
||||
|
@ -1465,17 +1471,6 @@ class RegistrationStore(
|
|||
self.clock.time_msec(),
|
||||
)
|
||||
|
||||
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
|
||||
self._simple_update_one_txn(
|
||||
txn=txn,
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
updatevalues={"deactivated": 1 if deactivated else 0},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_deactivated_status, (user_id,)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_user_deactivated_status(self, user_id, deactivated):
|
||||
"""Set the `deactivated` property for the provided user to the provided value.
|
||||
|
@ -1491,3 +1486,14 @@ class RegistrationStore(
|
|||
user_id,
|
||||
deactivated,
|
||||
)
|
||||
|
||||
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
|
||||
self._simple_update_one_txn(
|
||||
txn=txn,
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
updatevalues={"deactivated": 1 if deactivated else 0},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_deactivated_status, (user_id,)
|
||||
)
|
||||
|
|
|
@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes, Membership
|
|||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import LoggingTransaction
|
||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||
from synapse.storage.engines import Sqlite3Engine
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
from synapse.types import get_domain_from_id
|
||||
|
@ -820,9 +821,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return set(room_ids)
|
||||
|
||||
|
||||
class RoomMemberStore(RoomMemberWorkerStore):
|
||||
class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(RoomMemberStore, self).__init__(db_conn, hs)
|
||||
super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
self.register_background_update_handler(
|
||||
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
|
||||
)
|
||||
|
@ -838,112 +839,6 @@ class RoomMemberStore(RoomMemberWorkerStore):
|
|||
where_clause="forgotten = 1",
|
||||
)
|
||||
|
||||
def _store_room_members_txn(self, txn, events, backfilled):
|
||||
"""Store a room member in the database.
|
||||
"""
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="room_memberships",
|
||||
values=[
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"user_id": event.state_key,
|
||||
"sender": event.user_id,
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
"display_name": event.content.get("displayname", None),
|
||||
"avatar_url": event.content.get("avatar_url", None),
|
||||
}
|
||||
for event in events
|
||||
],
|
||||
)
|
||||
|
||||
for event in events:
|
||||
txn.call_after(
|
||||
self._membership_stream_cache.entity_has_changed,
|
||||
event.state_key,
|
||||
event.internal_metadata.stream_ordering,
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
|
||||
)
|
||||
|
||||
# We update the local_invites table only if the event is "current",
|
||||
# i.e., its something that has just happened. If the event is an
|
||||
# outlier it is only current if its an "out of band membership",
|
||||
# like a remote invite or a rejection of a remote invite.
|
||||
is_new_state = not backfilled and (
|
||||
not event.internal_metadata.is_outlier()
|
||||
or event.internal_metadata.is_out_of_band_membership()
|
||||
)
|
||||
is_mine = self.hs.is_mine_id(event.state_key)
|
||||
if is_new_state and is_mine:
|
||||
if event.membership == Membership.INVITE:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="local_invites",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"invitee": event.state_key,
|
||||
"inviter": event.sender,
|
||||
"room_id": event.room_id,
|
||||
"stream_id": event.internal_metadata.stream_ordering,
|
||||
},
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
|
||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||
" AND replaced_by is NULL"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.state_key,
|
||||
),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def locally_reject_invite(self, user_id, room_id):
|
||||
sql = (
|
||||
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
|
||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||
" AND replaced_by is NULL"
|
||||
)
|
||||
|
||||
def f(txn, stream_ordering):
|
||||
txn.execute(sql, (stream_ordering, True, room_id, user_id))
|
||||
|
||||
with self._stream_id_gen.get_next() as stream_ordering:
|
||||
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
|
||||
|
||||
def forget(self, user_id, room_id):
|
||||
"""Indicate that user_id wishes to discard history for room_id."""
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"UPDATE"
|
||||
" room_memberships"
|
||||
" SET"
|
||||
" forgotten = 1"
|
||||
" WHERE"
|
||||
" user_id = ?"
|
||||
" AND"
|
||||
" room_id = ?"
|
||||
)
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
|
||||
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_forgotten_rooms_for_user, (user_id,)
|
||||
)
|
||||
|
||||
return self.runInteraction("forget_membership", f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_add_membership_profile(self, progress, batch_size):
|
||||
target_min_stream_id = progress.get(
|
||||
|
@ -1078,6 +973,117 @@ class RoomMemberStore(RoomMemberWorkerStore):
|
|||
return row_count
|
||||
|
||||
|
||||
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(RoomMemberStore, self).__init__(db_conn, hs)
|
||||
|
||||
def _store_room_members_txn(self, txn, events, backfilled):
|
||||
"""Store a room member in the database.
|
||||
"""
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="room_memberships",
|
||||
values=[
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"user_id": event.state_key,
|
||||
"sender": event.user_id,
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
"display_name": event.content.get("displayname", None),
|
||||
"avatar_url": event.content.get("avatar_url", None),
|
||||
}
|
||||
for event in events
|
||||
],
|
||||
)
|
||||
|
||||
for event in events:
|
||||
txn.call_after(
|
||||
self._membership_stream_cache.entity_has_changed,
|
||||
event.state_key,
|
||||
event.internal_metadata.stream_ordering,
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
|
||||
)
|
||||
|
||||
# We update the local_invites table only if the event is "current",
|
||||
# i.e., its something that has just happened. If the event is an
|
||||
# outlier it is only current if its an "out of band membership",
|
||||
# like a remote invite or a rejection of a remote invite.
|
||||
is_new_state = not backfilled and (
|
||||
not event.internal_metadata.is_outlier()
|
||||
or event.internal_metadata.is_out_of_band_membership()
|
||||
)
|
||||
is_mine = self.hs.is_mine_id(event.state_key)
|
||||
if is_new_state and is_mine:
|
||||
if event.membership == Membership.INVITE:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="local_invites",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"invitee": event.state_key,
|
||||
"inviter": event.sender,
|
||||
"room_id": event.room_id,
|
||||
"stream_id": event.internal_metadata.stream_ordering,
|
||||
},
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
|
||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||
" AND replaced_by is NULL"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.state_key,
|
||||
),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def locally_reject_invite(self, user_id, room_id):
|
||||
sql = (
|
||||
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
|
||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||
" AND replaced_by is NULL"
|
||||
)
|
||||
|
||||
def f(txn, stream_ordering):
|
||||
txn.execute(sql, (stream_ordering, True, room_id, user_id))
|
||||
|
||||
with self._stream_id_gen.get_next() as stream_ordering:
|
||||
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
|
||||
|
||||
def forget(self, user_id, room_id):
|
||||
"""Indicate that user_id wishes to discard history for room_id."""
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"UPDATE"
|
||||
" room_memberships"
|
||||
" SET"
|
||||
" forgotten = 1"
|
||||
" WHERE"
|
||||
" user_id = ?"
|
||||
" AND"
|
||||
" room_id = ?"
|
||||
)
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
|
||||
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_forgotten_rooms_for_user, (user_id,)
|
||||
)
|
||||
|
||||
return self.runInteraction("forget_membership", f)
|
||||
|
||||
|
||||
class _JoinedHostsCache(object):
|
||||
"""Cache for joined hosts in a room that is optimised to handle updates
|
||||
via state deltas.
|
||||
|
|
|
@ -36,7 +36,7 @@ SearchEntry = namedtuple(
|
|||
)
|
||||
|
||||
|
||||
class SearchStore(BackgroundUpdateStore):
|
||||
class SearchBackgroundUpdateStore(BackgroundUpdateStore):
|
||||
|
||||
EVENT_SEARCH_UPDATE_NAME = "event_search"
|
||||
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
|
||||
|
@ -44,7 +44,7 @@ class SearchStore(BackgroundUpdateStore):
|
|||
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SearchStore, self).__init__(db_conn, hs)
|
||||
super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
|
||||
if not hs.config.enable_search:
|
||||
return
|
||||
|
@ -289,29 +289,6 @@ class SearchStore(BackgroundUpdateStore):
|
|||
|
||||
return num_rows
|
||||
|
||||
def store_event_search_txn(self, txn, event, key, value):
|
||||
"""Add event to the search table
|
||||
|
||||
Args:
|
||||
txn (cursor):
|
||||
event (EventBase):
|
||||
key (str):
|
||||
value (str):
|
||||
"""
|
||||
self.store_search_entries_txn(
|
||||
txn,
|
||||
(
|
||||
SearchEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
event_id=event.event_id,
|
||||
room_id=event.room_id,
|
||||
stream_ordering=event.internal_metadata.stream_ordering,
|
||||
origin_server_ts=event.origin_server_ts,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def store_search_entries_txn(self, txn, entries):
|
||||
"""Add entries to the search table
|
||||
|
||||
|
@ -358,6 +335,34 @@ class SearchStore(BackgroundUpdateStore):
|
|||
# This should be unreachable.
|
||||
raise Exception("Unrecognized database engine")
|
||||
|
||||
|
||||
class SearchStore(SearchBackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SearchStore, self).__init__(db_conn, hs)
|
||||
|
||||
def store_event_search_txn(self, txn, event, key, value):
|
||||
"""Add event to the search table
|
||||
|
||||
Args:
|
||||
txn (cursor):
|
||||
event (EventBase):
|
||||
key (str):
|
||||
value (str):
|
||||
"""
|
||||
self.store_search_entries_txn(
|
||||
txn,
|
||||
(
|
||||
SearchEntry(
|
||||
key=key,
|
||||
value=value,
|
||||
event_id=event.event_id,
|
||||
room_id=event.room_id,
|
||||
stream_ordering=event.internal_metadata.stream_ordering,
|
||||
origin_server_ts=event.origin_server_ts,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def search_msgs(self, room_ids, search_term, keys):
|
||||
"""Performs a full text search over events with given keys.
|
||||
|
|
|
@ -353,8 +353,158 @@ class StateFilter(object):
|
|||
return member_filter, non_member_filter
|
||||
|
||||
|
||||
class StateGroupBackgroundUpdateStore(SQLBaseStore):
|
||||
"""Defines functions related to state groups needed to run the state backgroud
|
||||
updates.
|
||||
"""
|
||||
|
||||
def _count_state_group_hops_txn(self, txn, state_group):
|
||||
"""Given a state group, count how many hops there are in the tree.
|
||||
|
||||
This is used to ensure the delta chains don't get too long.
|
||||
"""
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = """
|
||||
WITH RECURSIVE state(state_group) AS (
|
||||
VALUES(?::bigint)
|
||||
UNION ALL
|
||||
SELECT prev_state_group FROM state_group_edges e, state s
|
||||
WHERE s.state_group = e.state_group
|
||||
)
|
||||
SELECT count(*) FROM state;
|
||||
"""
|
||||
|
||||
txn.execute(sql, (state_group,))
|
||||
row = txn.fetchone()
|
||||
if row and row[0]:
|
||||
return row[0]
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||
next_group = state_group
|
||||
count = 0
|
||||
|
||||
while next_group:
|
||||
next_group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
keyvalues={"state_group": next_group},
|
||||
retcol="prev_state_group",
|
||||
allow_none=True,
|
||||
)
|
||||
if next_group:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
def _get_state_groups_from_groups_txn(
|
||||
self, txn, groups, state_filter=StateFilter.all()
|
||||
):
|
||||
results = {group: {} for group in groups}
|
||||
|
||||
where_clause, where_args = state_filter.make_sql_filter_clause()
|
||||
|
||||
# Unless the filter clause is empty, we're going to append it after an
|
||||
# existing where clause
|
||||
if where_clause:
|
||||
where_clause = " AND (%s)" % (where_clause,)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# Temporarily disable sequential scans in this transaction. This is
|
||||
# a temporary hack until we can add the right indices in
|
||||
txn.execute("SET LOCAL enable_seqscan=off")
|
||||
|
||||
# The below query walks the state_group tree so that the "state"
|
||||
# table includes all state_groups in the tree. It then joins
|
||||
# against `state_groups_state` to fetch the latest state.
|
||||
# It assumes that previous state groups are always numerically
|
||||
# lesser.
|
||||
# The PARTITION is used to get the event_id in the greatest state
|
||||
# group for the given type, state_key.
|
||||
# This may return multiple rows per (type, state_key), but last_value
|
||||
# should be the same.
|
||||
sql = """
|
||||
WITH RECURSIVE state(state_group) AS (
|
||||
VALUES(?::bigint)
|
||||
UNION ALL
|
||||
SELECT prev_state_group FROM state_group_edges e, state s
|
||||
WHERE s.state_group = e.state_group
|
||||
)
|
||||
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
|
||||
PARTITION BY type, state_key ORDER BY state_group ASC
|
||||
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
|
||||
) AS event_id FROM state_groups_state
|
||||
WHERE state_group IN (
|
||||
SELECT state_group FROM state
|
||||
)
|
||||
"""
|
||||
|
||||
for group in groups:
|
||||
args = [group]
|
||||
args.extend(where_args)
|
||||
|
||||
txn.execute(sql + where_clause, args)
|
||||
for row in txn:
|
||||
typ, state_key, event_id = row
|
||||
key = (typ, state_key)
|
||||
results[group][key] = event_id
|
||||
else:
|
||||
max_entries_returned = state_filter.max_entries_returned()
|
||||
|
||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||
for group in groups:
|
||||
next_group = group
|
||||
|
||||
while next_group:
|
||||
# We did this before by getting the list of group ids, and
|
||||
# then passing that list to sqlite to get latest event for
|
||||
# each (type, state_key). However, that was terribly slow
|
||||
# without the right indices (which we can't add until
|
||||
# after we finish deduping state, which requires this func)
|
||||
args = [next_group]
|
||||
args.extend(where_args)
|
||||
|
||||
txn.execute(
|
||||
"SELECT type, state_key, event_id FROM state_groups_state"
|
||||
" WHERE state_group = ? " + where_clause,
|
||||
args,
|
||||
)
|
||||
results[group].update(
|
||||
((typ, state_key), event_id)
|
||||
for typ, state_key, event_id in txn
|
||||
if (typ, state_key) not in results[group]
|
||||
)
|
||||
|
||||
# If the number of entries in the (type,state_key)->event_id dict
|
||||
# matches the number of (type,state_keys) types we were searching
|
||||
# for, then we must have found them all, so no need to go walk
|
||||
# further down the tree... UNLESS our types filter contained
|
||||
# wildcards (i.e. Nones) in which case we have to do an exhaustive
|
||||
# search
|
||||
if (
|
||||
max_entries_returned is not None
|
||||
and len(results[group]) == max_entries_returned
|
||||
):
|
||||
break
|
||||
|
||||
next_group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
keyvalues={"state_group": next_group},
|
||||
retcol="prev_state_group",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# this inherits from EventsWorkerStore because it calls self.get_events
|
||||
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
class StateGroupWorkerStore(
|
||||
EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
|
||||
):
|
||||
"""The parts of StateGroupStore that can be called from workers.
|
||||
"""
|
||||
|
||||
|
@ -694,107 +844,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
def _get_state_groups_from_groups_txn(
|
||||
self, txn, groups, state_filter=StateFilter.all()
|
||||
):
|
||||
results = {group: {} for group in groups}
|
||||
|
||||
where_clause, where_args = state_filter.make_sql_filter_clause()
|
||||
|
||||
# Unless the filter clause is empty, we're going to append it after an
|
||||
# existing where clause
|
||||
if where_clause:
|
||||
where_clause = " AND (%s)" % (where_clause,)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# Temporarily disable sequential scans in this transaction. This is
|
||||
# a temporary hack until we can add the right indices in
|
||||
txn.execute("SET LOCAL enable_seqscan=off")
|
||||
|
||||
# The below query walks the state_group tree so that the "state"
|
||||
# table includes all state_groups in the tree. It then joins
|
||||
# against `state_groups_state` to fetch the latest state.
|
||||
# It assumes that previous state groups are always numerically
|
||||
# lesser.
|
||||
# The PARTITION is used to get the event_id in the greatest state
|
||||
# group for the given type, state_key.
|
||||
# This may return multiple rows per (type, state_key), but last_value
|
||||
# should be the same.
|
||||
sql = """
|
||||
WITH RECURSIVE state(state_group) AS (
|
||||
VALUES(?::bigint)
|
||||
UNION ALL
|
||||
SELECT prev_state_group FROM state_group_edges e, state s
|
||||
WHERE s.state_group = e.state_group
|
||||
)
|
||||
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
|
||||
PARTITION BY type, state_key ORDER BY state_group ASC
|
||||
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
|
||||
) AS event_id FROM state_groups_state
|
||||
WHERE state_group IN (
|
||||
SELECT state_group FROM state
|
||||
)
|
||||
"""
|
||||
|
||||
for group in groups:
|
||||
args = [group]
|
||||
args.extend(where_args)
|
||||
|
||||
txn.execute(sql + where_clause, args)
|
||||
for row in txn:
|
||||
typ, state_key, event_id = row
|
||||
key = (typ, state_key)
|
||||
results[group][key] = event_id
|
||||
else:
|
||||
max_entries_returned = state_filter.max_entries_returned()
|
||||
|
||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||
for group in groups:
|
||||
next_group = group
|
||||
|
||||
while next_group:
|
||||
# We did this before by getting the list of group ids, and
|
||||
# then passing that list to sqlite to get latest event for
|
||||
# each (type, state_key). However, that was terribly slow
|
||||
# without the right indices (which we can't add until
|
||||
# after we finish deduping state, which requires this func)
|
||||
args = [next_group]
|
||||
args.extend(where_args)
|
||||
|
||||
txn.execute(
|
||||
"SELECT type, state_key, event_id FROM state_groups_state"
|
||||
" WHERE state_group = ? " + where_clause,
|
||||
args,
|
||||
)
|
||||
results[group].update(
|
||||
((typ, state_key), event_id)
|
||||
for typ, state_key, event_id in txn
|
||||
if (typ, state_key) not in results[group]
|
||||
)
|
||||
|
||||
# If the number of entries in the (type,state_key)->event_id dict
|
||||
# matches the number of (type,state_keys) types we were searching
|
||||
# for, then we must have found them all, so no need to go walk
|
||||
# further down the tree... UNLESS our types filter contained
|
||||
# wildcards (i.e. Nones) in which case we have to do an exhaustive
|
||||
# search
|
||||
if (
|
||||
max_entries_returned is not None
|
||||
and len(results[group]) == max_entries_returned
|
||||
):
|
||||
break
|
||||
|
||||
next_group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
keyvalues={"state_group": next_group},
|
||||
retcol="prev_state_group",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||
"""Given a list of event_ids and type tuples, return a list of state
|
||||
|
@ -1238,66 +1287,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return self.runInteraction("store_state_group", _store_state_group_txn)
|
||||
|
||||
def _count_state_group_hops_txn(self, txn, state_group):
|
||||
"""Given a state group, count how many hops there are in the tree.
|
||||
|
||||
This is used to ensure the delta chains don't get too long.
|
||||
"""
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = """
|
||||
WITH RECURSIVE state(state_group) AS (
|
||||
VALUES(?::bigint)
|
||||
UNION ALL
|
||||
SELECT prev_state_group FROM state_group_edges e, state s
|
||||
WHERE s.state_group = e.state_group
|
||||
)
|
||||
SELECT count(*) FROM state;
|
||||
"""
|
||||
|
||||
txn.execute(sql, (state_group,))
|
||||
row = txn.fetchone()
|
||||
if row and row[0]:
|
||||
return row[0]
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||
next_group = state_group
|
||||
count = 0
|
||||
|
||||
while next_group:
|
||||
next_group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
keyvalues={"state_group": next_group},
|
||||
retcol="prev_state_group",
|
||||
allow_none=True,
|
||||
)
|
||||
if next_group:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||
""" Keeps track of the state at a given event.
|
||||
|
||||
This is done by the concept of `state groups`. Every event is a assigned
|
||||
a state group (identified by an arbitrary string), which references a
|
||||
collection of state events. The current state of an event is then the
|
||||
collection of state events referenced by the event's state group.
|
||||
|
||||
Hence, every change in the current state causes a new state group to be
|
||||
generated. However, if no change happens (e.g., if we get a message event
|
||||
with only one parent it inherits the state group from its parent.)
|
||||
|
||||
There are three tables:
|
||||
* `state_groups`: Stores group name, first event with in the group and
|
||||
room id.
|
||||
* `event_to_state_groups`: Maps events to state groups.
|
||||
* `state_groups_state`: Maps state group to state events.
|
||||
"""
|
||||
class StateBackgroundUpdateStore(
|
||||
StateGroupBackgroundUpdateStore, BackgroundUpdateStore
|
||||
):
|
||||
|
||||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
||||
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
|
||||
|
@ -1305,7 +1298,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
|||
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(StateStore, self).__init__(db_conn, hs)
|
||||
super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
self.register_background_update_handler(
|
||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
||||
self._background_deduplicate_state,
|
||||
|
@ -1327,34 +1320,6 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
|||
columns=["state_group"],
|
||||
)
|
||||
|
||||
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
|
||||
state_groups = {}
|
||||
for event, context in events_and_contexts:
|
||||
if event.internal_metadata.is_outlier():
|
||||
continue
|
||||
|
||||
# if the event was rejected, just give it the same state as its
|
||||
# predecessor.
|
||||
if context.rejected:
|
||||
state_groups[event.event_id] = context.prev_group
|
||||
continue
|
||||
|
||||
state_groups[event.event_id] = context.state_group
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_to_state_groups",
|
||||
values=[
|
||||
{"state_group": state_group_id, "event_id": event_id}
|
||||
for event_id, state_group_id in iteritems(state_groups)
|
||||
],
|
||||
)
|
||||
|
||||
for event_id, state_group_id in iteritems(state_groups):
|
||||
txn.call_after(
|
||||
self._get_state_group_for_event.prefill, (event_id,), state_group_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_deduplicate_state(self, progress, batch_size):
|
||||
"""This background update will slowly deduplicate state by reencoding
|
||||
|
@ -1527,3 +1492,54 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
|||
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
|
||||
""" Keeps track of the state at a given event.
|
||||
|
||||
This is done by the concept of `state groups`. Every event is a assigned
|
||||
a state group (identified by an arbitrary string), which references a
|
||||
collection of state events. The current state of an event is then the
|
||||
collection of state events referenced by the event's state group.
|
||||
|
||||
Hence, every change in the current state causes a new state group to be
|
||||
generated. However, if no change happens (e.g., if we get a message event
|
||||
with only one parent it inherits the state group from its parent.)
|
||||
|
||||
There are three tables:
|
||||
* `state_groups`: Stores group name, first event with in the group and
|
||||
room id.
|
||||
* `event_to_state_groups`: Maps events to state groups.
|
||||
* `state_groups_state`: Maps state group to state events.
|
||||
"""
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(StateStore, self).__init__(db_conn, hs)
|
||||
|
||||
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
|
||||
state_groups = {}
|
||||
for event, context in events_and_contexts:
|
||||
if event.internal_metadata.is_outlier():
|
||||
continue
|
||||
|
||||
# if the event was rejected, just give it the same state as its
|
||||
# predecessor.
|
||||
if context.rejected:
|
||||
state_groups[event.event_id] = context.prev_group
|
||||
continue
|
||||
|
||||
state_groups[event.event_id] = context.state_group
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_to_state_groups",
|
||||
values=[
|
||||
{"state_group": state_group_id, "event_id": event_id}
|
||||
for event_id, state_group_id in iteritems(state_groups)
|
||||
],
|
||||
)
|
||||
|
||||
for event_id, state_group_id in iteritems(state_groups):
|
||||
txn.call_after(
|
||||
self._get_state_group_for_event.prefill, (event_id,), state_group_id
|
||||
)
|
||||
|
|
|
@ -32,14 +32,14 @@ logger = logging.getLogger(__name__)
|
|||
TEMP_TABLE = "_temp_populate_user_directory"
|
||||
|
||||
|
||||
class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||
class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore):
|
||||
|
||||
# How many records do we calculate before sending it to
|
||||
# add_users_who_share_private_rooms?
|
||||
SHARE_PRIVATE_WORKING_SET = 500
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(UserDirectoryStore, self).__init__(db_conn, hs)
|
||||
super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.server_name = hs.hostname
|
||||
|
||||
|
@ -452,55 +452,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
|||
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
|
||||
)
|
||||
|
||||
def remove_from_user_dir(self, user_id):
|
||||
def _remove_from_user_dir_txn(txn):
|
||||
self._simple_delete_txn(
|
||||
txn, table="user_directory", keyvalues={"user_id": user_id}
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn, table="user_directory_search", keyvalues={"user_id": user_id}
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="users_who_share_private_rooms",
|
||||
keyvalues={"user_id": user_id},
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="users_who_share_private_rooms",
|
||||
keyvalues={"other_user_id": user_id},
|
||||
)
|
||||
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||
|
||||
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_in_dir_due_to_room(self, room_id):
|
||||
"""Get all user_ids that are in the room directory because they're
|
||||
in the given room_id
|
||||
"""
|
||||
user_ids_share_pub = yield self._simple_select_onecol(
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="user_id",
|
||||
desc="get_users_in_dir_due_to_room",
|
||||
)
|
||||
|
||||
user_ids_share_priv = yield self._simple_select_onecol(
|
||||
table="users_who_share_private_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="other_user_id",
|
||||
desc="get_users_in_dir_due_to_room",
|
||||
)
|
||||
|
||||
user_ids = set(user_ids_share_pub)
|
||||
user_ids.update(user_ids_share_priv)
|
||||
|
||||
return user_ids
|
||||
|
||||
def add_users_who_share_private_room(self, room_id, user_id_tuples):
|
||||
"""Insert entries into the users_who_share_private_rooms table. The first
|
||||
user should be a local user.
|
||||
|
@ -551,6 +502,98 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
|||
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
|
||||
)
|
||||
|
||||
def delete_all_from_user_dir(self):
|
||||
"""Delete the entire user directory
|
||||
"""
|
||||
|
||||
def _delete_all_from_user_dir_txn(txn):
|
||||
txn.execute("DELETE FROM user_directory")
|
||||
txn.execute("DELETE FROM user_directory_search")
|
||||
txn.execute("DELETE FROM users_in_public_rooms")
|
||||
txn.execute("DELETE FROM users_who_share_private_rooms")
|
||||
txn.call_after(self.get_user_in_directory.invalidate_all)
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_user_in_directory(self, user_id):
|
||||
return self._simple_select_one(
|
||||
table="user_directory",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("display_name", "avatar_url"),
|
||||
allow_none=True,
|
||||
desc="get_user_in_directory",
|
||||
)
|
||||
|
||||
def update_user_directory_stream_pos(self, stream_id):
|
||||
return self._simple_update_one(
|
||||
table="user_directory_stream_pos",
|
||||
keyvalues={},
|
||||
updatevalues={"stream_id": stream_id},
|
||||
desc="update_user_directory_stream_pos",
|
||||
)
|
||||
|
||||
|
||||
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
|
||||
# How many records do we calculate before sending it to
|
||||
# add_users_who_share_private_rooms?
|
||||
SHARE_PRIVATE_WORKING_SET = 500
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(UserDirectoryStore, self).__init__(db_conn, hs)
|
||||
|
||||
def remove_from_user_dir(self, user_id):
|
||||
def _remove_from_user_dir_txn(txn):
|
||||
self._simple_delete_txn(
|
||||
txn, table="user_directory", keyvalues={"user_id": user_id}
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn, table="user_directory_search", keyvalues={"user_id": user_id}
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="users_who_share_private_rooms",
|
||||
keyvalues={"user_id": user_id},
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="users_who_share_private_rooms",
|
||||
keyvalues={"other_user_id": user_id},
|
||||
)
|
||||
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||
|
||||
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_in_dir_due_to_room(self, room_id):
|
||||
"""Get all user_ids that are in the room directory because they're
|
||||
in the given room_id
|
||||
"""
|
||||
user_ids_share_pub = yield self._simple_select_onecol(
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="user_id",
|
||||
desc="get_users_in_dir_due_to_room",
|
||||
)
|
||||
|
||||
user_ids_share_priv = yield self._simple_select_onecol(
|
||||
table="users_who_share_private_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="other_user_id",
|
||||
desc="get_users_in_dir_due_to_room",
|
||||
)
|
||||
|
||||
user_ids = set(user_ids_share_pub)
|
||||
user_ids.update(user_ids_share_priv)
|
||||
|
||||
return user_ids
|
||||
|
||||
def remove_user_who_share_room(self, user_id, room_id):
|
||||
"""
|
||||
Deletes entries in the users_who_share_*_rooms table. The first
|
||||
|
@ -637,31 +680,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
|||
|
||||
return [room_id for room_id, in rows]
|
||||
|
||||
def delete_all_from_user_dir(self):
|
||||
"""Delete the entire user directory
|
||||
"""
|
||||
|
||||
def _delete_all_from_user_dir_txn(txn):
|
||||
txn.execute("DELETE FROM user_directory")
|
||||
txn.execute("DELETE FROM user_directory_search")
|
||||
txn.execute("DELETE FROM users_in_public_rooms")
|
||||
txn.execute("DELETE FROM users_who_share_private_rooms")
|
||||
txn.call_after(self.get_user_in_directory.invalidate_all)
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_user_in_directory(self, user_id):
|
||||
return self._simple_select_one(
|
||||
table="user_directory",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("display_name", "avatar_url"),
|
||||
allow_none=True,
|
||||
desc="get_user_in_directory",
|
||||
)
|
||||
|
||||
def get_user_directory_stream_pos(self):
|
||||
return self._simple_select_one_onecol(
|
||||
table="user_directory_stream_pos",
|
||||
|
@ -670,14 +688,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
|||
desc="get_user_directory_stream_pos",
|
||||
)
|
||||
|
||||
def update_user_directory_stream_pos(self, stream_id):
|
||||
return self._simple_update_one(
|
||||
table="user_directory_stream_pos",
|
||||
keyvalues={},
|
||||
updatevalues={"stream_id": stream_id},
|
||||
desc="update_user_directory_stream_pos",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def search_user_dir(self, user_id, search_term, limit):
|
||||
"""Searches for users in directory
|
||||
|
|
Loading…
Reference in a new issue