0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 08:43:55 +01:00

Merge pull request #6178 from matrix-org/babolivier/factor_out_bg_updates

Factor out backgroung updates
This commit is contained in:
Brendan Abolivier 2019-10-09 12:29:01 +01:00 committed by GitHub
commit 59d6290ed9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 671 additions and 605 deletions

1
changelog.d/6178.bugfix Normal file
View file

@ -0,0 +1 @@
Make the `synapse_port_db` script create the right indexes on a new PostgreSQL database.

View file

@ -33,16 +33,9 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000 LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore): class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
self.register_background_index_update( self.register_background_index_update(
"user_ips_device_index", "user_ips_device_index",
@ -92,19 +85,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"devices_last_seen", self._devices_last_seen_update "devices_last_seen", self._devices_last_seen_update
) )
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks @defer.inlineCallbacks
def _remove_user_ip_nonunique(self, progress, batch_size): def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn): def f(conn):
@ -303,6 +283,110 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
return batch_size return batch_size
@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
# against user_ips to find the max last_seen associated with
# that device.
# 2. The outer query then joins again against user_ips on
# user/device/last_seen. This *should* hopefully only
# return one row, but if it does return more than one then
# we'll just end up updating the same device row multiple
# times, which is fine.
if self.database_engine.supports_tuple_comparison:
where_clause = "(user_id, device_id) > (?, ?)"
where_args = [last_user_id, last_device_id]
else:
# We explicitly do a `user_id >= ? AND (...)` here to ensure
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
# makes it hard for query optimiser to tell that it can use the
# index on user_id
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
where_args = [last_user_id, last_user_id, last_device_id]
sql = """
SELECT
last_seen, ip, user_agent, user_id, device_id
FROM (
SELECT
user_id, device_id, MAX(u.last_seen) AS last_seen
FROM devices
INNER JOIN user_ips AS u USING (user_id, device_id)
WHERE %(where_clause)s
GROUP BY user_id, device_id
ORDER BY user_id ASC, device_id ASC
LIMIT ?
) c
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
""" % {
"where_clause": where_clause
}
txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
if not rows:
return 0
sql = """
UPDATE devices
SET last_seen = ?, ip = ?, user_agent = ?
WHERE user_id = ? AND device_id = ?
"""
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
)
return len(rows)
updated = yield self.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self._end_background_update("devices_last_seen")
return updated
class ClientIpStore(ClientIpBackgroundUpdateStore):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks @defer.inlineCallbacks
def insert_client_ip( def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None self, user_id, access_token, ip, user_agent, device_id, now=None
@ -454,85 +538,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
for (access_token, ip), (user_agent, last_seen) in iteritems(results) for (access_token, ip), (user_agent, last_seen) in iteritems(results)
) )
@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
# against user_ips to find the max last_seen associated with
# that device.
# 2. The outer query then joins again against user_ips on
# user/device/last_seen. This *should* hopefully only
# return one row, but if it does return more than one then
# we'll just end up updating the same device row multiple
# times, which is fine.
if self.database_engine.supports_tuple_comparison:
where_clause = "(user_id, device_id) > (?, ?)"
where_args = [last_user_id, last_device_id]
else:
# We explicitly do a `user_id >= ? AND (...)` here to ensure
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
# makes it hard for query optimiser to tell that it can use the
# index on user_id
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
where_args = [last_user_id, last_user_id, last_device_id]
sql = """
SELECT
last_seen, ip, user_agent, user_id, device_id
FROM (
SELECT
user_id, device_id, MAX(u.last_seen) AS last_seen
FROM devices
INNER JOIN user_ips AS u USING (user_id, device_id)
WHERE %(where_clause)s
GROUP BY user_id, device_id
ORDER BY user_id ASC, device_id ASC
LIMIT ?
) c
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
""" % {
"where_clause": where_clause
}
txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
if not rows:
return 0
sql = """
UPDATE devices
SET last_seen = ?, ip = ?, user_agent = ?
WHERE user_id = ? AND device_id = ?
"""
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
)
return len(rows)
updated = yield self.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self._end_background_update("devices_last_seen")
return updated
@wrap_as_background_process("prune_old_user_ips") @wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self): async def _prune_old_user_ips(self):
"""Removes entries in user IPs older than the configured period. """Removes entries in user IPs older than the configured period.

View file

@ -208,11 +208,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs) super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update( self.register_background_index_update(
"device_inbox_stream_index", "device_inbox_stream_index",
@ -225,6 +225,26 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
) )
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been # Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions. # deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache( self._last_device_delete_cache = ExpiringCache(
@ -435,16 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
return self.runInteraction( return self.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn "get_all_new_device_messages", get_all_new_device_messages_txn
) )
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1

View file

@ -512,17 +512,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results return results
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs) super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists", keylen=2, max_entries=10000
)
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
self.register_background_index_update( self.register_background_index_update(
"device_lists_stream_idx", "device_lists_stream_idx",
@ -555,6 +547,31 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._drop_device_list_streams_non_unique_indexes, self._drop_device_list_streams_non_unique_indexes,
) )
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
return 1
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists", keylen=2, max_entries=10000
)
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
@defer.inlineCallbacks @defer.inlineCallbacks
def store_device(self, user_id, device_id, initial_device_display_name): def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not """Ensure the given device is known; add it to the store if not
@ -910,15 +927,3 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"_prune_old_outbound_device_pokes", "_prune_old_outbound_device_pokes",
_prune_txn, _prune_txn,
) )
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
return 1

View file

@ -15,11 +15,9 @@
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
class MediaRepositoryStore(BackgroundUpdateStore): class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs) super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update( self.register_background_index_update(
update_name="local_media_repository_url_idx", update_name="local_media_repository_url_idx",
@ -29,6 +27,13 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause="url_cache IS NOT NULL", where_clause="url_cache IS NOT NULL",
) )
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
def get_local_media(self, media_id): def get_local_media(self, media_id):
"""Get the metadata for a local piece of media """Get the metadata for a local piece of media
Returns: Returns:

View file

@ -787,13 +787,14 @@ class RegistrationWorkerStore(SQLBaseStore):
) )
class RegistrationStore( class RegistrationBackgroundUpdateStore(
RegistrationWorkerStore, background_updates.BackgroundUpdateStore RegistrationWorkerStore, background_updates.BackgroundUpdateStore
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs) super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.config = hs.config
self.register_background_index_update( self.register_background_index_update(
"access_tokens_device_index", "access_tokens_device_index",
@ -809,8 +810,6 @@ class RegistrationStore(
columns=["creation_ts"], columns=["creation_ts"],
) )
self._account_validity = hs.config.account_validity
# we no longer use refresh tokens, but it's possible that some people # we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just # might have a background update queued to build this index. Just
# clear the background update. # clear the background update.
@ -824,17 +823,6 @@ class RegistrationStore(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag "users_set_deactivated_flag", self._background_update_set_deactivated_flag
) )
# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_update_set_deactivated_flag(self, progress, batch_size): def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
@ -896,6 +884,54 @@ class RegistrationStore(
return nb_processed return nb_processed
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
return 1
class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
self._account_validity = hs.config.account_validity
# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
"""Adds an access token for the given user. """Adds an access token for the given user.
@ -1244,36 +1280,6 @@ class RegistrationStore(
desc="get_users_pending_deactivation", desc="get_users_pending_deactivation",
) )
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
return 1
def validate_threepid_session(self, session_id, client_secret, token, current_ts): def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token """Attempt to validate a threepid session using a token
@ -1465,17 +1471,6 @@ class RegistrationStore(
self.clock.time_msec(), self.clock.time_msec(),
) )
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self._simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_user_deactivated_status(self, user_id, deactivated): def set_user_deactivated_status(self, user_id, deactivated):
"""Set the `deactivated` property for the provided user to the provided value. """Set the `deactivated` property for the provided user to the provided value.
@ -1491,3 +1486,14 @@ class RegistrationStore(
user_id, user_id,
deactivated, deactivated,
) )
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self._simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)

View file

@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction from synapse.storage._base import LoggingTransaction
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import Sqlite3Engine from synapse.storage.engines import Sqlite3Engine
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -820,9 +821,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids) return set(room_ids)
class RoomMemberStore(RoomMemberWorkerStore): class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs) super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler( self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
) )
@ -838,112 +839,6 @@ class RoomMemberStore(RoomMemberWorkerStore):
where_clause="forgotten = 1", where_clause="forgotten = 1",
) )
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key,
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. If the event is an
# outlier it is only current if its an "out of band membership",
# like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
},
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size): def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get( target_min_stream_id = progress.get(
@ -1078,6 +973,117 @@ class RoomMemberStore(RoomMemberWorkerStore):
return row_count return row_count
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key,
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. If the event is an
# outlier it is only current if its an "out of band membership",
# like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
},
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
class _JoinedHostsCache(object): class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates """Cache for joined hosts in a room that is optimised to handle updates
via state deltas. via state deltas.

View file

@ -36,7 +36,7 @@ SearchEntry = namedtuple(
) )
class SearchStore(BackgroundUpdateStore): class SearchBackgroundUpdateStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@ -44,7 +44,7 @@ class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs) super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
if not hs.config.enable_search: if not hs.config.enable_search:
return return
@ -289,29 +289,6 @@ class SearchStore(BackgroundUpdateStore):
return num_rows return num_rows
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(
SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),
),
)
def store_search_entries_txn(self, txn, entries): def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table """Add entries to the search table
@ -358,6 +335,34 @@ class SearchStore(BackgroundUpdateStore):
# This should be unreachable. # This should be unreachable.
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs)
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(
SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),
),
)
@defer.inlineCallbacks @defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys): def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.

View file

@ -353,8 +353,158 @@ class StateFilter(object):
return member_filter, non_member_filter return member_filter, non_member_filter
class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""Defines functions related to state groups needed to run the state backgroud
updates.
"""
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
txn.execute("SET LOCAL enable_seqscan=off")
# The below query walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
"""
for group in groups:
args = [group]
args.extend(where_args)
txn.execute(sql + where_clause, args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
next_group = group
while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args,
)
results[group].update(
((typ, state_key), event_id)
for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
)
# If the number of entries in the (type,state_key)->event_id dict
# matches the number of (type,state_keys) types we were searching
# for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
return results
# this inherits from EventsWorkerStore because it calls self.get_events # this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): class StateGroupWorkerStore(
EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
):
"""The parts of StateGroupStore that can be called from workers. """The parts of StateGroupStore that can be called from workers.
""" """
@ -694,107 +844,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results return results
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
txn.execute("SET LOCAL enable_seqscan=off")
# The below query walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
"""
for group in groups:
args = [group]
args.extend(where_args)
txn.execute(sql + where_clause, args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
next_group = group
while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args,
)
results[group].update(
((typ, state_key), event_id)
for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
)
# If the number of entries in the (type,state_key)->event_id dict
# matches the number of (type,state_keys) types we were searching
# for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
return results
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
@ -1238,66 +1287,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return self.runInteraction("store_state_group", _store_state_group_txn) return self.runInteraction("store_state_group", _store_state_group_txn)
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long. class StateBackgroundUpdateStore(
""" StateGroupBackgroundUpdateStore, BackgroundUpdateStore
if isinstance(self.database_engine, PostgresEngine): ):
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
@ -1305,7 +1298,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs) super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler( self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state, self._background_deduplicate_state,
@ -1327,34 +1320,6 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
columns=["state_group"], columns=["state_group"],
) )
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size): def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding """This background update will slowly deduplicate state by reencoding
@ -1527,3 +1492,54 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
return 1 return 1
class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)

View file

@ -32,14 +32,14 @@ logger = logging.getLogger(__name__)
TEMP_TABLE = "_temp_populate_user_directory" TEMP_TABLE = "_temp_populate_user_directory"
class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore):
# How many records do we calculate before sending it to # How many records do we calculate before sending it to
# add_users_who_share_private_rooms? # add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500 SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(UserDirectoryStore, self).__init__(db_conn, hs) super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname
@ -452,55 +452,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"update_profile_in_user_dir", _update_profile_in_user_dir_txn "update_profile_in_user_dir", _update_profile_in_user_dir_txn
) )
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
user_ids_share_pub = yield self._simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids_share_priv = yield self._simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids = set(user_ids_share_pub)
user_ids.update(user_ids_share_priv)
return user_ids
def add_users_who_share_private_room(self, room_id, user_id_tuples): def add_users_who_share_private_room(self, room_id, user_id_tuples):
"""Insert entries into the users_who_share_private_rooms table. The first """Insert entries into the users_who_share_private_rooms table. The first
user should be a local user. user should be a local user.
@ -551,6 +502,98 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"add_users_in_public_rooms", _add_users_in_public_rooms_txn "add_users_in_public_rooms", _add_users_in_public_rooms_txn
) )
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
)
def update_user_directory_stream_pos(self, stream_id):
return self._simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_user_directory_stream_pos",
)
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, db_conn, hs):
super(UserDirectoryStore, self).__init__(db_conn, hs)
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
self._simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
user_ids_share_pub = yield self._simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids_share_priv = yield self._simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids = set(user_ids_share_pub)
user_ids.update(user_ids_share_priv)
return user_ids
def remove_user_who_share_room(self, user_id, room_id): def remove_user_who_share_room(self, user_id, room_id):
""" """
Deletes entries in the users_who_share_*_rooms table. The first Deletes entries in the users_who_share_*_rooms table. The first
@ -637,31 +680,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
return [room_id for room_id, in rows] return [room_id for room_id, in rows]
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
)
def get_user_directory_stream_pos(self): def get_user_directory_stream_pos(self):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="user_directory_stream_pos", table="user_directory_stream_pos",
@ -670,14 +688,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
desc="get_user_directory_stream_pos", desc="get_user_directory_stream_pos",
) )
def update_user_directory_stream_pos(self, stream_id):
return self._simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_user_directory_stream_pos",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def search_user_dir(self, user_id, search_term, limit): def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory """Searches for users in directory