0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-13 21:33:20 +01:00

Convert simple_select_many_batch, simple_select_many_txn to tuples. (#16444)

This commit is contained in:
Patrick Cloke 2023-10-11 13:24:56 -04:00 committed by GitHub
parent d6b7d49a61
commit a4904dcb04
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 641 additions and 443 deletions

1
changelog.d/16444.misc Normal file
View file

@ -0,0 +1 @@
Reduce memory allocations.

View file

@ -1874,9 +1874,9 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows.
Filters rows by whether the value of `column` is in `iterable`.
@ -1888,10 +1888,13 @@ class DatabasePool:
keyvalues: dict of column names and values to select the rows with
desc: description of the transaction, for logging and metrics
batch_size: the number of rows for each select query
Returns:
The results as a list of tuples.
"""
keyvalues = keyvalues or {}
results: List[Dict[str, Any]] = []
results: List[Tuple[Any, ...]] = []
for chunk in batch_iter(iterable, batch_size):
rows = await self.runInteraction(
@ -1918,9 +1921,9 @@ class DatabasePool:
iterable: Collection[Any],
keyvalues: Dict[str, Any],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows.
Filters rows by whether the value of `column` is in `iterable`.
@ -1931,6 +1934,9 @@ class DatabasePool:
iterable: list
keyvalues: dict of column names and values to select the rows with
retcols: list of strings giving the names of the columns to return
Returns:
The results as a list of tuples.
"""
if not iterable:
return []
@ -1949,7 +1955,7 @@ class DatabasePool:
)
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
return txn.fetchall()
async def simple_update(
self,

View file

@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query:
user_device_dicts = self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"hidden": False},
retcols=("device_id",),
user_device_dicts = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"hidden": False},
retcols=("device_id",),
),
)
device_ids_to_query.update(
{row["device_id"] for row in user_device_dicts}
)
device_ids_to_query.update({row[0] for row in user_device_dicts})
if not device_ids_to_query:
# We've ended up with no devices to query.
@ -845,20 +846,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# We exclude hidden devices (such as cross-signing keys) here as they are
# not expected to receive to-device messages.
rows = self.db_pool.simple_select_many_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
column="device_id",
iterable=devices,
retcols=("device_id",),
rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
column="device_id",
iterable=devices,
retcols=("device_id",),
),
)
for row in rows:
for (device_id,) in rows:
# Only insert into the local inbox if the device exists on
# this server
device_id = row["device_id"]
with start_active_span("serialise_to_device_message"):
msg = messages_by_device[device_id]
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])

View file

@ -1052,16 +1052,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
) -> Mapping[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id"),
desc="get_device_list_last_stream_id_for_remotes",
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id"),
desc="get_device_list_last_stream_id_for_remotes",
),
)
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})
results.update(rows)
return results
@ -1077,22 +1080,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
column="user_id",
iterable=user_ids,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync_with_iterable",
)
else:
rows = await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync",
row_tuples = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
column="user_id",
iterable=user_ids,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
)
return {row["user_id"] for row in rows}
return {row[0] for row in row_tuples}
else:
rows = cast(
List[Dict[str, str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync",
),
)
return {row["user_id"] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection

View file

@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
A map from (algorithm, key_id) to json string for key
"""
rows = await self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
retcols=("algorithm", "key_id", "key_json"),
keyvalues={"user_id": user_id, "device_id": device_id},
desc="add_e2e_one_time_keys_check",
rows = cast(
List[Tuple[str, str, str]],
await self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
retcols=("algorithm", "key_id", "key_json"),
keyvalues={"user_id": user_id, "device_id": device_id},
desc="add_e2e_one_time_keys_check",
),
)
result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result

View file

@ -1049,15 +1049,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args:
event_ids: The event IDs to calculate the max depth of.
"""
rows = await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=(
"event_id",
"depth",
rows = cast(
List[Tuple[str, int]],
await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=(
"event_id",
"depth",
),
desc="get_max_depth_of",
),
desc="get_max_depth_of",
)
if not rows:
@ -1065,10 +1068,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else:
max_depth_event_id = ""
current_max_depth = 0
for row in rows:
if row["depth"] > current_max_depth:
max_depth_event_id = row["event_id"]
current_max_depth = row["depth"]
for event_id, depth in rows:
if depth > current_max_depth:
max_depth_event_id = event_id
current_max_depth = depth
return max_depth_event_id, current_max_depth
@ -1078,15 +1081,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args:
event_ids: The event IDs to calculate the max depth of.
"""
rows = await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=(
"event_id",
"depth",
rows = cast(
List[Tuple[str, int]],
await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=(
"event_id",
"depth",
),
desc="get_min_depth_of",
),
desc="get_min_depth_of",
)
if not rows:
@ -1094,10 +1100,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else:
min_depth_event_id = ""
current_min_depth = MAX_DEPTH
for row in rows:
if row["depth"] < current_min_depth:
min_depth_event_id = row["event_id"]
current_min_depth = row["depth"]
for event_id, depth in rows:
if depth < current_min_depth:
min_depth_event_id = event_id
current_min_depth = depth
return min_depth_event_id, current_min_depth
@ -1553,19 +1559,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A filtered down list of `event_ids` that have previous failed pull attempts.
"""
rows = await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=("event_id",),
desc="get_event_ids_with_failed_pull_attempts",
rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=("event_id",),
desc="get_event_ids_with_failed_pull_attempts",
),
)
event_ids_with_failed_pull_attempts: Set[str] = {
row["event_id"] for row in rows
}
return event_ids_with_failed_pull_attempts
return {row[0] for row in rows}
@trace
async def get_event_ids_to_not_pull_from_backoff(
@ -1585,32 +1590,34 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A dictionary of event_ids that should not be attempted to be pulled and the
next timestamp at which we may try pulling them again.
"""
event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=(
"event_id",
"last_attempt_ts",
"num_attempts",
event_failed_pull_attempts = cast(
List[Tuple[str, int, int]],
await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=(
"event_id",
"last_attempt_ts",
"num_attempts",
),
desc="get_event_ids_to_not_pull_from_backoff",
),
desc="get_event_ids_to_not_pull_from_backoff",
)
current_time = self._clock.time_msec()
event_ids_with_backoff = {}
for event_failed_pull_attempt in event_failed_pull_attempts:
event_id = event_failed_pull_attempt["event_id"]
for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts:
# Exponential back-off (up to the upper bound) so we don't try to
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
backoff_end_time = (
event_failed_pull_attempt["last_attempt_ts"]
last_attempt_ts
+ (
2
** min(
event_failed_pull_attempt["num_attempts"],
num_attempts,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
)
)

View file

@ -27,6 +27,7 @@ from typing import (
Optional,
Set,
Tuple,
Union,
cast,
)
@ -501,16 +502,19 @@ class PersistEventsStore:
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
rows = self.db_pool.simple_select_many_txn(
txn,
table="rooms",
column="room_id",
iterable={event.room_id for event in events if event.is_state()},
keyvalues={},
retcols=("room_id", "has_auth_chain_index"),
rows = cast(
List[Tuple[str, Optional[Union[int, bool]]]],
self.db_pool.simple_select_many_txn(
txn,
table="rooms",
column="room_id",
iterable={event.room_id for event in events if event.is_state()},
keyvalues={},
retcols=("room_id", "has_auth_chain_index"),
),
)
rooms_using_chain_index = {
row["room_id"] for row in rows if row["has_auth_chain_index"]
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
}
state_events = {
@ -571,19 +575,18 @@ class PersistEventsStore:
# We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where
# we didn't have the auth chain when we first persisted.
rows = db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="room_id",
iterable=set(event_to_room_id.values()),
retcols=("event_id", "type", "state_key"),
auth_chain_to_calc_rows = cast(
List[Tuple[str, str, str]],
db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="room_id",
iterable=set(event_to_room_id.values()),
retcols=("event_id", "type", "state_key"),
),
)
for row in rows:
event_id = row["event_id"]
event_type = row["type"]
state_key = row["state_key"]
for event_id, event_type, state_key in auth_chain_to_calc_rows:
# (We could pull out the auth events for all rows at once using
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
@ -753,23 +756,31 @@ class PersistEventsStore:
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
rows = db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_links",
column="origin_chain_id",
iterable={chain_id for chain_id, _ in chain_map.values()},
keyvalues={},
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
auth_chain_rows = cast(
List[Tuple[int, int, int, int]],
db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_links",
column="origin_chain_id",
iterable={chain_id for chain_id, _ in chain_map.values()},
keyvalues={},
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
),
),
)
for row in rows:
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in auth_chain_rows:
chain_links.add_link(
(row["origin_chain_id"], row["origin_sequence_number"]),
(row["target_chain_id"], row["target_sequence_number"]),
(origin_chain_id, origin_sequence_number),
(target_chain_id, target_sequence_number),
new=False,
)

View file

@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
ev_rows = self.db_pool.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
iterable=chunk,
retcols=["event_id", "json"],
keyvalues={},
ev_rows = cast(
List[Tuple[str, str]],
self.db_pool.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
iterable=chunk,
retcols=["event_id", "json"],
keyvalues={},
),
)
for row in ev_rows:
event_id = row["event_id"]
event_json = db_to_json(row["json"])
for event_id, json in ev_rows:
event_json = db_to_json(json)
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if deleted:
# We now need to invalidate the caches of these rooms
rows = self.db_pool.simple_select_many_txn(
txn,
table="events",
column="event_id",
iterable=to_delete,
keyvalues={},
retcols=("room_id",),
rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="events",
column="event_id",
iterable=to_delete,
keyvalues={},
retcols=("room_id",),
),
)
room_ids = {row["room_id"] for row in rows}
room_ids = {row[0] for row in rows}
for room_id in room_ids:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
count = len(rows)
# We also need to fetch the auth events for them.
auth_events = self.db_pool.simple_select_many_txn(
txn,
table="event_auth",
column="event_id",
iterable=event_to_room_id,
keyvalues={},
retcols=("event_id", "auth_id"),
auth_events = cast(
List[Tuple[str, str]],
self.db_pool.simple_select_many_txn(
txn,
table="event_auth",
column="event_id",
iterable=event_to_room_id,
keyvalues={},
retcols=("event_id", "auth_id"),
),
)
event_to_auth_chain: Dict[str, List[str]] = {}
for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
for event_id, auth_id in auth_events:
event_to_auth_chain.setdefault(event_id, []).append(auth_id)
# Calculate and persist the chain cover index for this set of events.
#

View file

@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
),
)
return {r["event_id"] for r in rows}
return {r[0] for r in rows}
@trace
@tag_args
@ -2336,15 +2339,18 @@ class EventsWorkerStore(SQLBaseStore):
a dict mapping from event id to partial-stateness. We return True for
any of the events which are unknown (or are outliers).
"""
result = await self.db_pool.simple_select_many_batch(
table="partial_state_events",
column="event_id",
iterable=event_ids,
retcols=["event_id"],
desc="get_partial_state_events",
result = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="partial_state_events",
column="event_id",
iterable=event_ids,
retcols=["event_id"],
desc="get_partial_state_events",
),
)
# convert the result to a dict, to make @cachedList work
partial = {r["event_id"] for r in result}
partial = {r[0] for r in result}
return {e_id: e_id in partial for e_id in event_ids}
@cached()

View file

@ -16,7 +16,7 @@
import itertools
import json
import logging
from typing import Dict, Iterable, Mapping, Optional, Tuple
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@ -205,35 +205,39 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
rows = await self.db_pool.simple_select_many_batch(
table="server_keys_json",
column="key_id",
iterable=key_ids,
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
rows = cast(
List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
await self.db_pool.simple_select_many_batch(
table="server_keys_json",
column="key_id",
iterable=key_ids,
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
),
desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
# We sort the rows so that the most recently added entry is picked up.
rows.sort(key=lambda r: r["ts_added_ms"])
# We sort the rows by ts_added_ms so that the most recently added entry
# will stomp over older entries in the dictionary.
rows.sort(key=lambda r: r[2])
return {
row["key_id"]: FetchKeyResultForRemote(
key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
key_json=bytes(key_json),
valid_until_ts=ts_valid_until_ms,
added_ts=ts_added_ms,
)
for row in rows
for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
async def get_all_server_keys_json_for_remote(
@ -260,6 +264,8 @@ class KeyStore(CacheInvalidationWorkerStore):
if not rows:
return {}
# We sort the rows by ts_added_ms so that the most recently added entry
# will stomp over older entries in the dictionary.
rows.sort(key=lambda r: r["ts_added_ms"])
return {

View file

@ -261,27 +261,40 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Mapping[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
keyvalues={},
retcols=(
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
# TODO All these columns are nullable, but we don't expect that:
# https://github.com/matrix-org/synapse/issues/16467
rows = cast(
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
keyvalues={},
retcols=(
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
desc="get_presence_for_users",
),
desc="get_presence_for_users",
)
for row in rows:
row["currently_active"] = bool(row["currently_active"])
return {row["user_id"]: UserPresenceState(**row) for row in rows}
return {
user_id: UserPresenceState(
user_id=user_id,
state=state,
last_active_ts=last_active_ts,
last_federation_update_ts=last_federation_update_ts,
last_user_sync_ts=last_user_sync_ts,
status_msg=status_msg,
currently_active=bool(currently_active),
)
for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
}
async def should_user_receive_full_presence_with_token(
self,
@ -386,6 +399,8 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
limit = 100
offset = 0
while True:
# TODO All these columns are nullable, but we don't expect that:
# https://github.com/matrix-org/synapse/issues/16467
rows = cast(
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
await self.db_pool.runInteraction(

View file

@ -62,20 +62,34 @@ logger = logging.getLogger(__name__)
def _load_rules(
rawrules: List[JsonDict],
rawrules: List[Tuple[str, int, str, str]],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object.
Args:
rawrules: List of tuples of:
* rule ID
* Priority lass
* Conditions (as serialized JSON)
* Actions (as serialized JSON)
enabled_map: A dictionary of rule ID to a boolean of whether the rule is
enabled. This might not include all rule IDs from rawrules.
experimental_config: The `experimental_features` section of the Synapse
config. (Used to check if various features are enabled.)
Returns:
A new FilteredPushRules object.
"""
ruleslist = [
PushRule.from_db(
rule_id=rawrule["rule_id"],
priority_class=rawrule["priority_class"],
conditions=rawrule["conditions"],
actions=rawrule["actions"],
rule_id=rawrule[0],
priority_class=rawrule[1],
conditions=rawrule[2],
actions=rawrule[3],
)
for rawrule in rawrules
]
@ -183,7 +197,19 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
return _load_rules(rows, enabled_map, self.hs.config.experimental)
return _load_rules(
[
(
row["rule_id"],
row["priority_class"],
row["conditions"],
row["actions"],
)
for row in rows
],
enabled_map,
self.hs.config.experimental,
)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
@ -221,21 +247,36 @@ class PushRulesWorkerStore(
if not user_ids:
return {}
raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = {
user_id: [] for user_id in user_ids
}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
retcols=("*",),
desc="bulk_get_push_rules",
batch_size=1000,
rows = cast(
List[Tuple[str, str, int, int, str, str]],
await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
retcols=(
"user_name",
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
),
desc="bulk_get_push_rules",
batch_size=1000,
),
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
# Sort by highest priority_class, then highest priority.
rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))
for row in rows:
raw_rules.setdefault(row["user_name"], []).append(row)
for user_name, rule_id, priority_class, _, conditions, actions in rows:
raw_rules.setdefault(user_name, []).append(
(rule_id, priority_class, conditions, actions)
)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
@ -256,17 +297,19 @@ class PushRulesWorkerStore(
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
batch_size=1000,
rows = cast(
List[Tuple[str, str, Optional[int]]],
await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
batch_size=1000,
),
)
for row in rows:
enabled = bool(row["enabled"])
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
for user_name, rule_id, enabled in rows:
results.setdefault(user_name, {})[rule_id] = bool(enabled)
return results
async def get_all_push_rule_updates(

View file

@ -349,16 +349,19 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_with_types_txn(
txn: LoggingTransaction,
) -> List[str]:
rows = self.db_pool.simple_select_many_txn(
txn=txn,
table="event_relations",
column="relation_type",
iterable=relation_types,
keyvalues={"relates_to_id": event_id},
retcols=["event_id"],
rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn=txn,
table="event_relations",
column="relation_type",
iterable=relation_types,
keyvalues={"relates_to_id": event_id},
retcols=["event_id"],
),
)
return [row["event_id"] for row in rows]
return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event_with_types",

View file

@ -1296,14 +1296,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
complete.
"""
rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch(
table="partial_state_rooms",
column="room_id",
iterable=room_ids,
retcols=("room_id",),
desc="is_partial_state_room_batched",
rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="partial_state_rooms",
column="room_id",
iterable=room_ids,
retcols=("room_id",),
desc="is_partial_state_room_batched",
),
)
partial_state_rooms = {row_dict["room_id"] for row_dict in rows}
partial_state_rooms = {row[0] for row in rows}
return {room_id: room_id in partial_state_rooms for room_id in room_ids}
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(

View file

@ -27,6 +27,7 @@ from typing import (
Set,
Tuple,
Union,
cast,
)
import attr
@ -683,25 +684,28 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from user_id to set of rooms that is currently in.
"""
rows = await self.db_pool.simple_select_many_batch(
table="current_state_events",
column="state_key",
iterable=user_ids,
retcols=(
"state_key",
"room_id",
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_many_batch(
table="current_state_events",
column="state_key",
iterable=user_ids,
retcols=(
"state_key",
"room_id",
),
keyvalues={
"type": EventTypes.Member,
"membership": Membership.JOIN,
},
desc="get_rooms_for_users",
),
keyvalues={
"type": EventTypes.Member,
"membership": Membership.JOIN,
},
desc="get_rooms_for_users",
)
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
for row in rows:
user_rooms[row["state_key"]].add(row["room_id"])
for state_key, room_id in rows:
user_rooms[state_key].add(room_id)
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
@ -892,17 +896,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from event ID to `user_id`, or None if event is not a join.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
retcols=("user_id", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=1000,
desc="_get_user_ids_from_membership_event_ids",
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
retcols=("event_id", "user_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=1000,
desc="_get_user_ids_from_membership_event_ids",
),
)
return {row["event_id"]: row["user_id"] for row in rows}
return dict(rows)
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@ -1202,21 +1209,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
membership event, otherwise the value is None.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
retcols=("user_id", "membership", "event_id"),
keyvalues={},
batch_size=500,
desc="get_membership_from_event_ids",
rows = cast(
List[Tuple[str, str, str]],
await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
retcols=("user_id", "membership", "event_id"),
keyvalues={},
batch_size=500,
desc="get_membership_from_event_ids",
),
)
return {
row["event_id"]: EventIdMembership(
membership=row["membership"], user_id=row["user_id"]
)
for row in rows
event_id: EventIdMembership(membership=membership, user_id=user_id)
for user_id, membership, event_id in rows
}
async def is_local_host_in_room_ignoring_users(

View file

@ -20,10 +20,12 @@ from typing import (
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
cast,
)
import attr
@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
RuntimeError if the state is unknown at any of the given events
"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=("event_id", "state_group"),
desc="_get_state_group_for_events",
rows = cast(
List[Tuple[str, int]],
await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=("event_id", "state_group"),
desc="_get_state_group_for_events",
),
)
res = {row["event_id"]: row["state_group"] for row in rows}
res = dict(rows)
for e in event_ids:
if e not in res:
raise RuntimeError("No state group for unknown or outlier event %s" % e)
@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The subset of state groups that are referenced.
"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
keyvalues={},
retcols=("DISTINCT state_group",),
desc="get_referenced_state_groups",
rows = cast(
List[Tuple[int]],
await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
keyvalues={},
retcols=("DISTINCT state_group",),
desc="get_referenced_state_groups",
),
)
return {row["state_group"] for row in rows}
return {row[0] for row in rows}
async def update_state_for_partial_state_event(
self,
@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
# potentially stale, since there may have been a period where the
# server didn't share a room with the remote user and therefore may
# have missed any device updates.
rows = self.db_pool.simple_select_many_txn(
txn,
table="current_state_events",
column="room_id",
iterable=to_delete,
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
retcols=("state_key",),
rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="current_state_events",
column="room_id",
iterable=to_delete,
keyvalues={
"type": EventTypes.Member,
"membership": Membership.JOIN,
},
retcols=("state_key",),
),
)
potentially_left_users = {row["state_key"] for row in rows}
potentially_left_users = {row[0] for row in rows}
# Now lets actually delete the rooms from the DB.
self.db_pool.simple_delete_many_txn(

View file

@ -506,25 +506,28 @@ class StatsStore(StateDeltasStore):
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
rows = self.db_pool.simple_select_many_txn(
txn,
table="current_state_events",
column="type",
iterable=[
EventTypes.Create,
EventTypes.JoinRules,
EventTypes.RoomHistoryVisibility,
EventTypes.RoomEncryption,
EventTypes.Name,
EventTypes.Topic,
EventTypes.RoomAvatar,
EventTypes.CanonicalAlias,
],
keyvalues={"room_id": room_id, "state_key": ""},
retcols=["event_id"],
rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="current_state_events",
column="type",
iterable=[
EventTypes.Create,
EventTypes.JoinRules,
EventTypes.RoomHistoryVisibility,
EventTypes.RoomEncryption,
EventTypes.Name,
EventTypes.Topic,
EventTypes.RoomAvatar,
EventTypes.CanonicalAlias,
],
keyvalues={"room_id": room_id, "state_key": ""},
retcols=["event_id"],
),
)
event_ids = cast(List[str], [row["event_id"] for row in rows])
event_ids = [row[0] for row in rows]
txn.execute(
"""

View file

@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
async def get_destination_retry_timings_batch(
self, destinations: StrCollection
) -> Mapping[str, Optional[DestinationRetryTimings]]:
rows = await self.db_pool.simple_select_many_batch(
table="destinations",
iterable=destinations,
column="destination",
retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
desc="get_destination_retry_timings_batch",
rows = cast(
List[Tuple[str, Optional[int], Optional[int], Optional[int]]],
await self.db_pool.simple_select_many_batch(
table="destinations",
iterable=destinations,
column="destination",
retcols=(
"destination",
"failure_ts",
"retry_last_ts",
"retry_interval",
),
desc="get_destination_retry_timings_batch",
),
)
return {
row.pop("destination"): DestinationRetryTimings(**row)
for row in rows
if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"]
destination: DestinationRetryTimings(
failure_ts, retry_last_ts, retry_interval
)
for destination, failure_ts, retry_last_ts, retry_interval in rows
if retry_last_ts and failure_ts and retry_interval
}
async def set_destination_retry_timings(

View file

@ -337,13 +337,16 @@ class UIAuthWorkerStore(SQLBaseStore):
# If a registration token was used, decrement the pending counter
# before deleting the session.
rows = self.db_pool.simple_select_many_txn(
txn,
table="ui_auth_sessions_credentials",
column="session_id",
iterable=session_ids,
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
retcols=["result"],
rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="ui_auth_sessions_credentials",
column="session_id",
iterable=session_ids,
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
retcols=["result"],
),
)
# Get the tokens used and how much pending needs to be decremented by.
@ -353,23 +356,25 @@ class UIAuthWorkerStore(SQLBaseStore):
# registration token stage for that session will be True.
# If a token was used to authenticate, but registration was
# never completed, the result will be the token used.
token = db_to_json(r["result"])
token = db_to_json(r[0])
if isinstance(token, str):
token_counts[token] = token_counts.get(token, 0) + 1
# Update the `pending` counters.
if len(token_counts) > 0:
token_rows = self.db_pool.simple_select_many_txn(
txn,
table="registration_tokens",
column="token",
iterable=list(token_counts.keys()),
keyvalues={},
retcols=["token", "pending"],
token_rows = cast(
List[Tuple[str, int]],
self.db_pool.simple_select_many_txn(
txn,
table="registration_tokens",
column="token",
iterable=list(token_counts.keys()),
keyvalues={},
retcols=["token", "pending"],
),
)
for token_row in token_rows:
token = token_row["token"]
new_pending = token_row["pending"] - token_counts[token]
for token, pending in token_rows:
new_pending = pending - token_counts[token]
self.db_pool.simple_update_one_txn(
txn,
table="registration_tokens",

View file

@ -410,25 +410,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
# Next fetch their profiles. Note that not all users have profiles.
profile_rows = self.db_pool.simple_select_many_txn(
txn,
table="profiles",
column="full_user_id",
iterable=list(users_to_insert),
retcols=(
"full_user_id",
"displayname",
"avatar_url",
profile_rows = cast(
List[Tuple[str, Optional[str], Optional[str]]],
self.db_pool.simple_select_many_txn(
txn,
table="profiles",
column="full_user_id",
iterable=list(users_to_insert),
retcols=(
"full_user_id",
"displayname",
"avatar_url",
),
keyvalues={},
),
keyvalues={},
)
profiles = {
row["full_user_id"]: _UserDirProfile(
row["full_user_id"],
row["displayname"],
row["avatar_url"],
)
for row in profile_rows
full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url)
for full_user_id, displayname, avatar_url in profile_rows
}
profiles_to_insert = [
@ -517,18 +516,21 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined]
]
rows = self.db_pool.simple_select_many_txn(
txn,
table="users",
column="name",
iterable=users,
keyvalues={
"deactivated": 0,
},
retcols=("name", "user_type"),
rows = cast(
List[Tuple[str, Optional[str]]],
self.db_pool.simple_select_many_txn(
txn,
table="users",
column="name",
iterable=users,
keyvalues={
"deactivated": 0,
},
retcols=("name", "user_type"),
),
)
return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT]
return [name for name, user_type in rows if user_type != UserTypes.SUPPORT]
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Mapping
from typing import Iterable, List, Mapping, Tuple, cast
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
Returns:
for each user, whether the user has requested erasure.
"""
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
retcols=("user_id",),
desc="are_users_erased",
rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
retcols=("user_id",),
desc="are_users_erased",
),
)
erased_users = {row["user_id"] for row in rows}
erased_users = {row[0] for row in rows}
return {u: u in erased_users for u in user_ids}

View file

@ -13,7 +13,17 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)
import attr
@ -730,19 +740,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
rows = self.db_pool.simple_select_many_txn(
txn,
table="state_group_edges",
column="prev_state_group",
iterable=state_groups_to_delete,
keyvalues={},
retcols=("state_group",),
rows = cast(
List[Tuple[int]],
self.db_pool.simple_select_many_txn(
txn,
table="state_group_edges",
column="prev_state_group",
iterable=state_groups_to_delete,
keyvalues={},
retcols=("state_group",),
),
)
remaining_state_groups = {
row["state_group"]
for row in rows
if row["state_group"] not in state_groups_to_delete
state_group
for state_group, in rows
if state_group not in state_groups_to_delete
}
logger.info(
@ -799,16 +812,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
A mapping from state group to previous state group.
"""
rows = await self.db_pool.simple_select_many_batch(
table="state_group_edges",
column="prev_state_group",
iterable=state_groups,
keyvalues={},
retcols=("prev_state_group", "state_group"),
desc="get_previous_state_groups",
rows = cast(
List[Tuple[int, int]],
await self.db_pool.simple_select_many_batch(
table="state_group_edges",
column="prev_state_group",
iterable=state_groups,
keyvalues={},
retcols=("state_group", "prev_state_group"),
desc="get_previous_state_groups",
),
)
return {row["state_group"]: row["prev_state_group"] for row in rows}
return dict(rows)
async def purge_room_state(
self, room_id: str, state_groups_to_delete: Collection[int]

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Set, Tuple
from typing import Dict, List, Set, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
@ -421,41 +421,53 @@ class EventChainStoreTestCase(HomeserverTestCase):
self, events: List[EventBase]
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
# Fetch the map from event ID -> (chain ID, sequence number)
rows = self.get_success(
self.store.db_pool.simple_select_many_batch(
table="event_auth_chains",
column="event_id",
iterable=[e.event_id for e in events],
retcols=("event_id", "chain_id", "sequence_number"),
keyvalues={},
)
rows = cast(
List[Tuple[str, int, int]],
self.get_success(
self.store.db_pool.simple_select_many_batch(
table="event_auth_chains",
column="event_id",
iterable=[e.event_id for e in events],
retcols=("event_id", "chain_id", "sequence_number"),
keyvalues={},
)
),
)
chain_map = {
row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
event_id: (chain_id, sequence_number)
for event_id, chain_id, sequence_number in rows
}
# Fetch all the links and pass them to the _LinkMap.
rows = self.get_success(
self.store.db_pool.simple_select_many_batch(
table="event_auth_chain_links",
column="origin_chain_id",
iterable=[chain_id for chain_id, _ in chain_map.values()],
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
),
keyvalues={},
)
auth_chain_rows = cast(
List[Tuple[int, int, int, int]],
self.get_success(
self.store.db_pool.simple_select_many_batch(
table="event_auth_chain_links",
column="origin_chain_id",
iterable=[chain_id for chain_id, _ in chain_map.values()],
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
),
keyvalues={},
)
),
)
link_map = _LinkMap()
for row in rows:
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in auth_chain_rows:
added = link_map.add_link(
(row["origin_chain_id"], row["origin_sequence_number"]),
(row["target_chain_id"], row["target_sequence_number"]),
(origin_chain_id, origin_sequence_number),
(target_chain_id, target_sequence_number),
)
# We shouldn't have persisted any redundant links