mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 11:53:49 +01:00
Convert simple_select_many_batch, simple_select_many_txn to tuples. (#16444)
This commit is contained in:
parent
d6b7d49a61
commit
a4904dcb04
23 changed files with 641 additions and 443 deletions
1
changelog.d/16444.misc
Normal file
1
changelog.d/16444.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Reduce memory allocations.
|
|
@ -1874,9 +1874,9 @@ class DatabasePool:
|
||||||
keyvalues: Optional[Dict[str, Any]] = None,
|
keyvalues: Optional[Dict[str, Any]] = None,
|
||||||
desc: str = "simple_select_many_batch",
|
desc: str = "simple_select_many_batch",
|
||||||
batch_size: int = 100,
|
batch_size: int = 100,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows.
|
||||||
|
|
||||||
Filters rows by whether the value of `column` is in `iterable`.
|
Filters rows by whether the value of `column` is in `iterable`.
|
||||||
|
|
||||||
|
@ -1888,10 +1888,13 @@ class DatabasePool:
|
||||||
keyvalues: dict of column names and values to select the rows with
|
keyvalues: dict of column names and values to select the rows with
|
||||||
desc: description of the transaction, for logging and metrics
|
desc: description of the transaction, for logging and metrics
|
||||||
batch_size: the number of rows for each select query
|
batch_size: the number of rows for each select query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The results as a list of tuples.
|
||||||
"""
|
"""
|
||||||
keyvalues = keyvalues or {}
|
keyvalues = keyvalues or {}
|
||||||
|
|
||||||
results: List[Dict[str, Any]] = []
|
results: List[Tuple[Any, ...]] = []
|
||||||
|
|
||||||
for chunk in batch_iter(iterable, batch_size):
|
for chunk in batch_iter(iterable, batch_size):
|
||||||
rows = await self.runInteraction(
|
rows = await self.runInteraction(
|
||||||
|
@ -1918,9 +1921,9 @@ class DatabasePool:
|
||||||
iterable: Collection[Any],
|
iterable: Collection[Any],
|
||||||
keyvalues: Dict[str, Any],
|
keyvalues: Dict[str, Any],
|
||||||
retcols: Iterable[str],
|
retcols: Iterable[str],
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows.
|
||||||
|
|
||||||
Filters rows by whether the value of `column` is in `iterable`.
|
Filters rows by whether the value of `column` is in `iterable`.
|
||||||
|
|
||||||
|
@ -1931,6 +1934,9 @@ class DatabasePool:
|
||||||
iterable: list
|
iterable: list
|
||||||
keyvalues: dict of column names and values to select the rows with
|
keyvalues: dict of column names and values to select the rows with
|
||||||
retcols: list of strings giving the names of the columns to return
|
retcols: list of strings giving the names of the columns to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The results as a list of tuples.
|
||||||
"""
|
"""
|
||||||
if not iterable:
|
if not iterable:
|
||||||
return []
|
return []
|
||||||
|
@ -1949,7 +1955,7 @@ class DatabasePool:
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(sql, values)
|
txn.execute(sql, values)
|
||||||
return cls.cursor_to_dict(txn)
|
return txn.fetchall()
|
||||||
|
|
||||||
async def simple_update(
|
async def simple_update(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
# Note that this is more efficient than just dropping `device_id` from the query,
|
# Note that this is more efficient than just dropping `device_id` from the query,
|
||||||
# since device_inbox has an index on `(user_id, device_id, stream_id)`
|
# since device_inbox has an index on `(user_id, device_id, stream_id)`
|
||||||
if not device_ids_to_query:
|
if not device_ids_to_query:
|
||||||
user_device_dicts = self.db_pool.simple_select_many_txn(
|
user_device_dicts = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="devices",
|
table="devices",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
iterable=user_ids_to_query,
|
iterable=user_ids_to_query,
|
||||||
keyvalues={"hidden": False},
|
keyvalues={"hidden": False},
|
||||||
retcols=("device_id",),
|
retcols=("device_id",),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
device_ids_to_query.update(
|
device_ids_to_query.update({row[0] for row in user_device_dicts})
|
||||||
{row["device_id"] for row in user_device_dicts}
|
|
||||||
)
|
|
||||||
|
|
||||||
if not device_ids_to_query:
|
if not device_ids_to_query:
|
||||||
# We've ended up with no devices to query.
|
# We've ended up with no devices to query.
|
||||||
|
@ -845,20 +846,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
# We exclude hidden devices (such as cross-signing keys) here as they are
|
# We exclude hidden devices (such as cross-signing keys) here as they are
|
||||||
# not expected to receive to-device messages.
|
# not expected to receive to-device messages.
|
||||||
rows = self.db_pool.simple_select_many_txn(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="devices",
|
table="devices",
|
||||||
keyvalues={"user_id": user_id, "hidden": False},
|
keyvalues={"user_id": user_id, "hidden": False},
|
||||||
column="device_id",
|
column="device_id",
|
||||||
iterable=devices,
|
iterable=devices,
|
||||||
retcols=("device_id",),
|
retcols=("device_id",),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in rows:
|
for (device_id,) in rows:
|
||||||
# Only insert into the local inbox if the device exists on
|
# Only insert into the local inbox if the device exists on
|
||||||
# this server
|
# this server
|
||||||
device_id = row["device_id"]
|
|
||||||
|
|
||||||
with start_active_span("serialise_to_device_message"):
|
with start_active_span("serialise_to_device_message"):
|
||||||
msg = messages_by_device[device_id]
|
msg = messages_by_device[device_id]
|
||||||
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
|
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
|
||||||
|
|
|
@ -1052,16 +1052,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
async def get_device_list_last_stream_id_for_remotes(
|
async def get_device_list_last_stream_id_for_remotes(
|
||||||
self, user_ids: Iterable[str]
|
self, user_ids: Iterable[str]
|
||||||
) -> Mapping[str, Optional[str]]:
|
) -> Mapping[str, Optional[str]]:
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="device_lists_remote_extremeties",
|
table="device_lists_remote_extremeties",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
retcols=("user_id", "stream_id"),
|
retcols=("user_id", "stream_id"),
|
||||||
desc="get_device_list_last_stream_id_for_remotes",
|
desc="get_device_list_last_stream_id_for_remotes",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
|
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
|
||||||
results.update({row["user_id"]: row["stream_id"] for row in rows})
|
results.update(rows)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -1077,19 +1080,27 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
The IDs of users whose device lists need resync.
|
The IDs of users whose device lists need resync.
|
||||||
"""
|
"""
|
||||||
if user_ids:
|
if user_ids:
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
row_tuples = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="device_lists_remote_resync",
|
table="device_lists_remote_resync",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
retcols=("user_id",),
|
retcols=("user_id",),
|
||||||
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return {row[0] for row in row_tuples}
|
||||||
else:
|
else:
|
||||||
rows = await self.db_pool.simple_select_list(
|
rows = cast(
|
||||||
|
List[Dict[str, str]],
|
||||||
|
await self.db_pool.simple_select_list(
|
||||||
table="device_lists_remote_resync",
|
table="device_lists_remote_resync",
|
||||||
keyvalues=None,
|
keyvalues=None,
|
||||||
retcols=("user_id",),
|
retcols=("user_id",),
|
||||||
desc="get_user_ids_requiring_device_list_resync",
|
desc="get_user_ids_requiring_device_list_resync",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {row["user_id"] for row in rows}
|
return {row["user_id"] for row in rows}
|
||||||
|
|
|
@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
A map from (algorithm, key_id) to json string for key
|
A map from (algorithm, key_id) to json string for key
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, str, str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="e2e_one_time_keys_json",
|
table="e2e_one_time_keys_json",
|
||||||
column="key_id",
|
column="key_id",
|
||||||
iterable=key_ids,
|
iterable=key_ids,
|
||||||
retcols=("algorithm", "key_id", "key_json"),
|
retcols=("algorithm", "key_id", "key_json"),
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
desc="add_e2e_one_time_keys_check",
|
desc="add_e2e_one_time_keys_check",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
|
result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
|
||||||
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
|
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
@ -1049,7 +1049,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
Args:
|
Args:
|
||||||
event_ids: The event IDs to calculate the max depth of.
|
event_ids: The event IDs to calculate the max depth of.
|
||||||
"""
|
"""
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, int]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="events",
|
table="events",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_ids,
|
iterable=event_ids,
|
||||||
|
@ -1058,6 +1060,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"depth",
|
"depth",
|
||||||
),
|
),
|
||||||
desc="get_max_depth_of",
|
desc="get_max_depth_of",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
|
@ -1065,10 +1068,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
else:
|
else:
|
||||||
max_depth_event_id = ""
|
max_depth_event_id = ""
|
||||||
current_max_depth = 0
|
current_max_depth = 0
|
||||||
for row in rows:
|
for event_id, depth in rows:
|
||||||
if row["depth"] > current_max_depth:
|
if depth > current_max_depth:
|
||||||
max_depth_event_id = row["event_id"]
|
max_depth_event_id = event_id
|
||||||
current_max_depth = row["depth"]
|
current_max_depth = depth
|
||||||
|
|
||||||
return max_depth_event_id, current_max_depth
|
return max_depth_event_id, current_max_depth
|
||||||
|
|
||||||
|
@ -1078,7 +1081,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
Args:
|
Args:
|
||||||
event_ids: The event IDs to calculate the max depth of.
|
event_ids: The event IDs to calculate the max depth of.
|
||||||
"""
|
"""
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, int]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="events",
|
table="events",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_ids,
|
iterable=event_ids,
|
||||||
|
@ -1087,6 +1092,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"depth",
|
"depth",
|
||||||
),
|
),
|
||||||
desc="get_min_depth_of",
|
desc="get_min_depth_of",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
|
@ -1094,10 +1100,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
else:
|
else:
|
||||||
min_depth_event_id = ""
|
min_depth_event_id = ""
|
||||||
current_min_depth = MAX_DEPTH
|
current_min_depth = MAX_DEPTH
|
||||||
for row in rows:
|
for event_id, depth in rows:
|
||||||
if row["depth"] < current_min_depth:
|
if depth < current_min_depth:
|
||||||
min_depth_event_id = row["event_id"]
|
min_depth_event_id = event_id
|
||||||
current_min_depth = row["depth"]
|
current_min_depth = depth
|
||||||
|
|
||||||
return min_depth_event_id, current_min_depth
|
return min_depth_event_id, current_min_depth
|
||||||
|
|
||||||
|
@ -1553,19 +1559,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
A filtered down list of `event_ids` that have previous failed pull attempts.
|
A filtered down list of `event_ids` that have previous failed pull attempts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="event_failed_pull_attempts",
|
table="event_failed_pull_attempts",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_ids,
|
iterable=event_ids,
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("event_id",),
|
retcols=("event_id",),
|
||||||
desc="get_event_ids_with_failed_pull_attempts",
|
desc="get_event_ids_with_failed_pull_attempts",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
event_ids_with_failed_pull_attempts: Set[str] = {
|
return {row[0] for row in rows}
|
||||||
row["event_id"] for row in rows
|
|
||||||
}
|
|
||||||
|
|
||||||
return event_ids_with_failed_pull_attempts
|
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def get_event_ids_to_not_pull_from_backoff(
|
async def get_event_ids_to_not_pull_from_backoff(
|
||||||
|
@ -1585,7 +1590,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
A dictionary of event_ids that should not be attempted to be pulled and the
|
A dictionary of event_ids that should not be attempted to be pulled and the
|
||||||
next timestamp at which we may try pulling them again.
|
next timestamp at which we may try pulling them again.
|
||||||
"""
|
"""
|
||||||
event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
|
event_failed_pull_attempts = cast(
|
||||||
|
List[Tuple[str, int, int]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="event_failed_pull_attempts",
|
table="event_failed_pull_attempts",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_ids,
|
iterable=event_ids,
|
||||||
|
@ -1596,21 +1603,21 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"num_attempts",
|
"num_attempts",
|
||||||
),
|
),
|
||||||
desc="get_event_ids_to_not_pull_from_backoff",
|
desc="get_event_ids_to_not_pull_from_backoff",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
current_time = self._clock.time_msec()
|
current_time = self._clock.time_msec()
|
||||||
|
|
||||||
event_ids_with_backoff = {}
|
event_ids_with_backoff = {}
|
||||||
for event_failed_pull_attempt in event_failed_pull_attempts:
|
for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts:
|
||||||
event_id = event_failed_pull_attempt["event_id"]
|
|
||||||
# Exponential back-off (up to the upper bound) so we don't try to
|
# Exponential back-off (up to the upper bound) so we don't try to
|
||||||
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
|
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
|
||||||
backoff_end_time = (
|
backoff_end_time = (
|
||||||
event_failed_pull_attempt["last_attempt_ts"]
|
last_attempt_ts
|
||||||
+ (
|
+ (
|
||||||
2
|
2
|
||||||
** min(
|
** min(
|
||||||
event_failed_pull_attempt["num_attempts"],
|
num_attempts,
|
||||||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
|
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -27,6 +27,7 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -501,16 +502,19 @@ class PersistEventsStore:
|
||||||
|
|
||||||
# We ignore legacy rooms that we aren't filling the chain cover index
|
# We ignore legacy rooms that we aren't filling the chain cover index
|
||||||
# for.
|
# for.
|
||||||
rows = self.db_pool.simple_select_many_txn(
|
rows = cast(
|
||||||
|
List[Tuple[str, Optional[Union[int, bool]]]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="rooms",
|
table="rooms",
|
||||||
column="room_id",
|
column="room_id",
|
||||||
iterable={event.room_id for event in events if event.is_state()},
|
iterable={event.room_id for event in events if event.is_state()},
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("room_id", "has_auth_chain_index"),
|
retcols=("room_id", "has_auth_chain_index"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
rooms_using_chain_index = {
|
rooms_using_chain_index = {
|
||||||
row["room_id"] for row in rows if row["has_auth_chain_index"]
|
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
|
||||||
}
|
}
|
||||||
|
|
||||||
state_events = {
|
state_events = {
|
||||||
|
@ -571,19 +575,18 @@ class PersistEventsStore:
|
||||||
# We check if there are any events that need to be handled in the rooms
|
# We check if there are any events that need to be handled in the rooms
|
||||||
# we're looking at. These should just be out of band memberships, where
|
# we're looking at. These should just be out of band memberships, where
|
||||||
# we didn't have the auth chain when we first persisted.
|
# we didn't have the auth chain when we first persisted.
|
||||||
rows = db_pool.simple_select_many_txn(
|
auth_chain_to_calc_rows = cast(
|
||||||
|
List[Tuple[str, str, str]],
|
||||||
|
db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_auth_chain_to_calculate",
|
table="event_auth_chain_to_calculate",
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
column="room_id",
|
column="room_id",
|
||||||
iterable=set(event_to_room_id.values()),
|
iterable=set(event_to_room_id.values()),
|
||||||
retcols=("event_id", "type", "state_key"),
|
retcols=("event_id", "type", "state_key"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for row in rows:
|
for event_id, event_type, state_key in auth_chain_to_calc_rows:
|
||||||
event_id = row["event_id"]
|
|
||||||
event_type = row["type"]
|
|
||||||
state_key = row["state_key"]
|
|
||||||
|
|
||||||
# (We could pull out the auth events for all rows at once using
|
# (We could pull out the auth events for all rows at once using
|
||||||
# simple_select_many, but this case happens rarely and almost always
|
# simple_select_many, but this case happens rarely and almost always
|
||||||
# with a single row.)
|
# with a single row.)
|
||||||
|
@ -753,7 +756,9 @@ class PersistEventsStore:
|
||||||
# Step 1, fetch all existing links from all the chains we've seen
|
# Step 1, fetch all existing links from all the chains we've seen
|
||||||
# referenced.
|
# referenced.
|
||||||
chain_links = _LinkMap()
|
chain_links = _LinkMap()
|
||||||
rows = db_pool.simple_select_many_txn(
|
auth_chain_rows = cast(
|
||||||
|
List[Tuple[int, int, int, int]],
|
||||||
|
db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_auth_chain_links",
|
table="event_auth_chain_links",
|
||||||
column="origin_chain_id",
|
column="origin_chain_id",
|
||||||
|
@ -765,11 +770,17 @@ class PersistEventsStore:
|
||||||
"target_chain_id",
|
"target_chain_id",
|
||||||
"target_sequence_number",
|
"target_sequence_number",
|
||||||
),
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for row in rows:
|
for (
|
||||||
|
origin_chain_id,
|
||||||
|
origin_sequence_number,
|
||||||
|
target_chain_id,
|
||||||
|
target_sequence_number,
|
||||||
|
) in auth_chain_rows:
|
||||||
chain_links.add_link(
|
chain_links.add_link(
|
||||||
(row["origin_chain_id"], row["origin_sequence_number"]),
|
(origin_chain_id, origin_sequence_number),
|
||||||
(row["target_chain_id"], row["target_sequence_number"]),
|
(target_chain_id, target_sequence_number),
|
||||||
new=False,
|
new=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||||
|
|
||||||
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
|
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
ev_rows = self.db_pool.simple_select_many_txn(
|
ev_rows = cast(
|
||||||
|
List[Tuple[str, str]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_json",
|
table="event_json",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=chunk,
|
iterable=chunk,
|
||||||
retcols=["event_id", "json"],
|
retcols=["event_id", "json"],
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in ev_rows:
|
for event_id, json in ev_rows:
|
||||||
event_id = row["event_id"]
|
event_json = db_to_json(json)
|
||||||
event_json = db_to_json(row["json"])
|
|
||||||
try:
|
try:
|
||||||
origin_server_ts = event_json["origin_server_ts"]
|
origin_server_ts = event_json["origin_server_ts"]
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
|
@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||||
|
|
||||||
if deleted:
|
if deleted:
|
||||||
# We now need to invalidate the caches of these rooms
|
# We now need to invalidate the caches of these rooms
|
||||||
rows = self.db_pool.simple_select_many_txn(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="events",
|
table="events",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=to_delete,
|
iterable=to_delete,
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("room_id",),
|
retcols=("room_id",),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
room_ids = {row["room_id"] for row in rows}
|
room_ids = {row[0] for row in rows}
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
|
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
|
||||||
|
@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||||
count = len(rows)
|
count = len(rows)
|
||||||
|
|
||||||
# We also need to fetch the auth events for them.
|
# We also need to fetch the auth events for them.
|
||||||
auth_events = self.db_pool.simple_select_many_txn(
|
auth_events = cast(
|
||||||
|
List[Tuple[str, str]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_auth",
|
table="event_auth",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_to_room_id,
|
iterable=event_to_room_id,
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("event_id", "auth_id"),
|
retcols=("event_id", "auth_id"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
event_to_auth_chain: Dict[str, List[str]] = {}
|
event_to_auth_chain: Dict[str, List[str]] = {}
|
||||||
for row in auth_events:
|
for event_id, auth_id in auth_events:
|
||||||
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
|
event_to_auth_chain.setdefault(event_id, []).append(auth_id)
|
||||||
|
|
||||||
# Calculate and persist the chain cover index for this set of events.
|
# Calculate and persist the chain cover index for this set of events.
|
||||||
#
|
#
|
||||||
|
|
|
@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
"""Given a list of event ids, check if we have already processed and
|
"""Given a list of event ids, check if we have already processed and
|
||||||
stored them as non outliers.
|
stored them as non outliers.
|
||||||
"""
|
"""
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="events",
|
table="events",
|
||||||
retcols=("event_id",),
|
retcols=("event_id",),
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=list(event_ids),
|
iterable=list(event_ids),
|
||||||
keyvalues={"outlier": False},
|
keyvalues={"outlier": False},
|
||||||
desc="have_events_in_timeline",
|
desc="have_events_in_timeline",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {r["event_id"] for r in rows}
|
return {r[0] for r in rows}
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@tag_args
|
@tag_args
|
||||||
|
@ -2336,15 +2339,18 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
a dict mapping from event id to partial-stateness. We return True for
|
a dict mapping from event id to partial-stateness. We return True for
|
||||||
any of the events which are unknown (or are outliers).
|
any of the events which are unknown (or are outliers).
|
||||||
"""
|
"""
|
||||||
result = await self.db_pool.simple_select_many_batch(
|
result = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="partial_state_events",
|
table="partial_state_events",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_ids,
|
iterable=event_ids,
|
||||||
retcols=["event_id"],
|
retcols=["event_id"],
|
||||||
desc="get_partial_state_events",
|
desc="get_partial_state_events",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
# convert the result to a dict, to make @cachedList work
|
# convert the result to a dict, to make @cachedList work
|
||||||
partial = {r["event_id"] for r in result}
|
partial = {r[0] for r in result}
|
||||||
return {e_id: e_id in partial for e_id in event_ids}
|
return {e_id: e_id in partial for e_id in event_ids}
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Iterable, Mapping, Optional, Tuple
|
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
@ -205,7 +205,9 @@ class KeyStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
If we have multiple entries for a given key ID, returns the most recent.
|
If we have multiple entries for a given key ID, returns the most recent.
|
||||||
"""
|
"""
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="server_keys_json",
|
table="server_keys_json",
|
||||||
column="key_id",
|
column="key_id",
|
||||||
iterable=key_ids,
|
iterable=key_ids,
|
||||||
|
@ -218,22 +220,24 @@ class KeyStore(CacheInvalidationWorkerStore):
|
||||||
"key_json",
|
"key_json",
|
||||||
),
|
),
|
||||||
desc="get_server_keys_json_for_remote",
|
desc="get_server_keys_json_for_remote",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# We sort the rows so that the most recently added entry is picked up.
|
# We sort the rows by ts_added_ms so that the most recently added entry
|
||||||
rows.sort(key=lambda r: r["ts_added_ms"])
|
# will stomp over older entries in the dictionary.
|
||||||
|
rows.sort(key=lambda r: r[2])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
row["key_id"]: FetchKeyResultForRemote(
|
key_id: FetchKeyResultForRemote(
|
||||||
# Cast to bytes since postgresql returns a memoryview.
|
# Cast to bytes since postgresql returns a memoryview.
|
||||||
key_json=bytes(row["key_json"]),
|
key_json=bytes(key_json),
|
||||||
valid_until_ts=row["ts_valid_until_ms"],
|
valid_until_ts=ts_valid_until_ms,
|
||||||
added_ts=row["ts_added_ms"],
|
added_ts=ts_added_ms,
|
||||||
)
|
)
|
||||||
for row in rows
|
for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_all_server_keys_json_for_remote(
|
async def get_all_server_keys_json_for_remote(
|
||||||
|
@ -260,6 +264,8 @@ class KeyStore(CacheInvalidationWorkerStore):
|
||||||
if not rows:
|
if not rows:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
# We sort the rows by ts_added_ms so that the most recently added entry
|
||||||
|
# will stomp over older entries in the dictionary.
|
||||||
rows.sort(key=lambda r: r["ts_added_ms"])
|
rows.sort(key=lambda r: r["ts_added_ms"])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -261,7 +261,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
||||||
async def get_presence_for_users(
|
async def get_presence_for_users(
|
||||||
self, user_ids: Iterable[str]
|
self, user_ids: Iterable[str]
|
||||||
) -> Mapping[str, UserPresenceState]:
|
) -> Mapping[str, UserPresenceState]:
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
# TODO All these columns are nullable, but we don't expect that:
|
||||||
|
# https://github.com/matrix-org/synapse/issues/16467
|
||||||
|
rows = cast(
|
||||||
|
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="presence_stream",
|
table="presence_stream",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
|
@ -276,12 +280,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
||||||
"currently_active",
|
"currently_active",
|
||||||
),
|
),
|
||||||
desc="get_presence_for_users",
|
desc="get_presence_for_users",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in rows:
|
return {
|
||||||
row["currently_active"] = bool(row["currently_active"])
|
user_id: UserPresenceState(
|
||||||
|
user_id=user_id,
|
||||||
return {row["user_id"]: UserPresenceState(**row) for row in rows}
|
state=state,
|
||||||
|
last_active_ts=last_active_ts,
|
||||||
|
last_federation_update_ts=last_federation_update_ts,
|
||||||
|
last_user_sync_ts=last_user_sync_ts,
|
||||||
|
status_msg=status_msg,
|
||||||
|
currently_active=bool(currently_active),
|
||||||
|
)
|
||||||
|
for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
|
||||||
|
}
|
||||||
|
|
||||||
async def should_user_receive_full_presence_with_token(
|
async def should_user_receive_full_presence_with_token(
|
||||||
self,
|
self,
|
||||||
|
@ -386,6 +399,8 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
||||||
limit = 100
|
limit = 100
|
||||||
offset = 0
|
offset = 0
|
||||||
while True:
|
while True:
|
||||||
|
# TODO All these columns are nullable, but we don't expect that:
|
||||||
|
# https://github.com/matrix-org/synapse/issues/16467
|
||||||
rows = cast(
|
rows = cast(
|
||||||
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
|
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
|
|
@ -62,20 +62,34 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _load_rules(
|
def _load_rules(
|
||||||
rawrules: List[JsonDict],
|
rawrules: List[Tuple[str, int, str, str]],
|
||||||
enabled_map: Dict[str, bool],
|
enabled_map: Dict[str, bool],
|
||||||
experimental_config: ExperimentalConfig,
|
experimental_config: ExperimentalConfig,
|
||||||
) -> FilteredPushRules:
|
) -> FilteredPushRules:
|
||||||
"""Take the DB rows returned from the DB and convert them into a full
|
"""Take the DB rows returned from the DB and convert them into a full
|
||||||
`FilteredPushRules` object.
|
`FilteredPushRules` object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rawrules: List of tuples of:
|
||||||
|
* rule ID
|
||||||
|
* Priority lass
|
||||||
|
* Conditions (as serialized JSON)
|
||||||
|
* Actions (as serialized JSON)
|
||||||
|
enabled_map: A dictionary of rule ID to a boolean of whether the rule is
|
||||||
|
enabled. This might not include all rule IDs from rawrules.
|
||||||
|
experimental_config: The `experimental_features` section of the Synapse
|
||||||
|
config. (Used to check if various features are enabled.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new FilteredPushRules object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ruleslist = [
|
ruleslist = [
|
||||||
PushRule.from_db(
|
PushRule.from_db(
|
||||||
rule_id=rawrule["rule_id"],
|
rule_id=rawrule[0],
|
||||||
priority_class=rawrule["priority_class"],
|
priority_class=rawrule[1],
|
||||||
conditions=rawrule["conditions"],
|
conditions=rawrule[2],
|
||||||
actions=rawrule["actions"],
|
actions=rawrule[3],
|
||||||
)
|
)
|
||||||
for rawrule in rawrules
|
for rawrule in rawrules
|
||||||
]
|
]
|
||||||
|
@ -183,7 +197,19 @@ class PushRulesWorkerStore(
|
||||||
|
|
||||||
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
|
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
|
||||||
|
|
||||||
return _load_rules(rows, enabled_map, self.hs.config.experimental)
|
return _load_rules(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
row["rule_id"],
|
||||||
|
row["priority_class"],
|
||||||
|
row["conditions"],
|
||||||
|
row["actions"],
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
],
|
||||||
|
enabled_map,
|
||||||
|
self.hs.config.experimental,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
|
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
|
||||||
results = await self.db_pool.simple_select_list(
|
results = await self.db_pool.simple_select_list(
|
||||||
|
@ -221,21 +247,36 @@ class PushRulesWorkerStore(
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
|
raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = {
|
||||||
|
user_id: [] for user_id in user_ids
|
||||||
|
}
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, str, int, int, str, str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="push_rules",
|
table="push_rules",
|
||||||
column="user_name",
|
column="user_name",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
retcols=("*",),
|
retcols=(
|
||||||
|
"user_name",
|
||||||
|
"rule_id",
|
||||||
|
"priority_class",
|
||||||
|
"priority",
|
||||||
|
"conditions",
|
||||||
|
"actions",
|
||||||
|
),
|
||||||
desc="bulk_get_push_rules",
|
desc="bulk_get_push_rules",
|
||||||
batch_size=1000,
|
batch_size=1000,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
|
# Sort by highest priority_class, then highest priority.
|
||||||
|
rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))
|
||||||
|
|
||||||
for row in rows:
|
for user_name, rule_id, priority_class, _, conditions, actions in rows:
|
||||||
raw_rules.setdefault(row["user_name"], []).append(row)
|
raw_rules.setdefault(user_name, []).append(
|
||||||
|
(rule_id, priority_class, conditions, actions)
|
||||||
|
)
|
||||||
|
|
||||||
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
|
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
|
||||||
|
|
||||||
|
@ -256,17 +297,19 @@ class PushRulesWorkerStore(
|
||||||
|
|
||||||
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
|
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, str, Optional[int]]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="push_rules_enable",
|
table="push_rules_enable",
|
||||||
column="user_name",
|
column="user_name",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
retcols=("user_name", "rule_id", "enabled"),
|
retcols=("user_name", "rule_id", "enabled"),
|
||||||
desc="bulk_get_push_rules_enabled",
|
desc="bulk_get_push_rules_enabled",
|
||||||
batch_size=1000,
|
batch_size=1000,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for row in rows:
|
for user_name, rule_id, enabled in rows:
|
||||||
enabled = bool(row["enabled"])
|
results.setdefault(user_name, {})[rule_id] = bool(enabled)
|
||||||
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def get_all_push_rule_updates(
|
async def get_all_push_rule_updates(
|
||||||
|
|
|
@ -349,16 +349,19 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
def get_all_relation_ids_for_event_with_types_txn(
|
def get_all_relation_ids_for_event_with_types_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
rows = self.db_pool.simple_select_many_txn(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn=txn,
|
txn=txn,
|
||||||
table="event_relations",
|
table="event_relations",
|
||||||
column="relation_type",
|
column="relation_type",
|
||||||
iterable=relation_types,
|
iterable=relation_types,
|
||||||
keyvalues={"relates_to_id": event_id},
|
keyvalues={"relates_to_id": event_id},
|
||||||
retcols=["event_id"],
|
retcols=["event_id"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return [row["event_id"] for row in rows]
|
return [row[0] for row in rows]
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
desc="get_all_relation_ids_for_event_with_types",
|
desc="get_all_relation_ids_for_event_with_types",
|
||||||
|
|
|
@ -1296,14 +1296,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
complete.
|
complete.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="partial_state_rooms",
|
table="partial_state_rooms",
|
||||||
column="room_id",
|
column="room_id",
|
||||||
iterable=room_ids,
|
iterable=room_ids,
|
||||||
retcols=("room_id",),
|
retcols=("room_id",),
|
||||||
desc="is_partial_state_room_batched",
|
desc="is_partial_state_room_batched",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
partial_state_rooms = {row_dict["room_id"] for row_dict in rows}
|
partial_state_rooms = {row[0] for row in rows}
|
||||||
return {room_id: room_id in partial_state_rooms for room_id in room_ids}
|
return {room_id: room_id in partial_state_rooms for room_id in room_ids}
|
||||||
|
|
||||||
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
|
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
|
||||||
|
|
|
@ -27,6 +27,7 @@ from typing import (
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -683,7 +684,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
Map from user_id to set of rooms that is currently in.
|
Map from user_id to set of rooms that is currently in.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="current_state_events",
|
table="current_state_events",
|
||||||
column="state_key",
|
column="state_key",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
|
@ -696,12 +699,13 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
"membership": Membership.JOIN,
|
"membership": Membership.JOIN,
|
||||||
},
|
},
|
||||||
desc="get_rooms_for_users",
|
desc="get_rooms_for_users",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
|
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
|
||||||
|
|
||||||
for row in rows:
|
for state_key, room_id in rows:
|
||||||
user_rooms[row["state_key"]].add(row["room_id"])
|
user_rooms[state_key].add(room_id)
|
||||||
|
|
||||||
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
|
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
|
||||||
|
|
||||||
|
@ -892,17 +896,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
Map from event ID to `user_id`, or None if event is not a join.
|
Map from event ID to `user_id`, or None if event is not a join.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="room_memberships",
|
table="room_memberships",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_ids,
|
iterable=event_ids,
|
||||||
retcols=("user_id", "event_id"),
|
retcols=("event_id", "user_id"),
|
||||||
keyvalues={"membership": Membership.JOIN},
|
keyvalues={"membership": Membership.JOIN},
|
||||||
batch_size=1000,
|
batch_size=1000,
|
||||||
desc="_get_user_ids_from_membership_event_ids",
|
desc="_get_user_ids_from_membership_event_ids",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {row["event_id"]: row["user_id"] for row in rows}
|
return dict(rows)
|
||||||
|
|
||||||
@cached(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
async def is_host_joined(self, room_id: str, host: str) -> bool:
|
async def is_host_joined(self, room_id: str, host: str) -> bool:
|
||||||
|
@ -1202,7 +1209,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
membership event, otherwise the value is None.
|
membership event, otherwise the value is None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, str, str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="room_memberships",
|
table="room_memberships",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=member_event_ids,
|
iterable=member_event_ids,
|
||||||
|
@ -1210,13 +1219,12 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
batch_size=500,
|
batch_size=500,
|
||||||
desc="get_membership_from_event_ids",
|
desc="get_membership_from_event_ids",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
row["event_id"]: EventIdMembership(
|
event_id: EventIdMembership(membership=membership, user_id=user_id)
|
||||||
membership=row["membership"], user_id=row["user_id"]
|
for user_id, membership, event_id in rows
|
||||||
)
|
|
||||||
for row in rows
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def is_local_host_in_room_ignoring_users(
|
async def is_local_host_in_room_ignoring_users(
|
||||||
|
|
|
@ -20,10 +20,12 @@ from typing import (
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError if the state is unknown at any of the given events
|
RuntimeError if the state is unknown at any of the given events
|
||||||
"""
|
"""
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, int]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_ids,
|
iterable=event_ids,
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("event_id", "state_group"),
|
retcols=("event_id", "state_group"),
|
||||||
desc="_get_state_group_for_events",
|
desc="_get_state_group_for_events",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
res = {row["event_id"]: row["state_group"] for row in rows}
|
res = dict(rows)
|
||||||
for e in event_ids:
|
for e in event_ids:
|
||||||
if e not in res:
|
if e not in res:
|
||||||
raise RuntimeError("No state group for unknown or outlier event %s" % e)
|
raise RuntimeError("No state group for unknown or outlier event %s" % e)
|
||||||
|
@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
The subset of state groups that are referenced.
|
The subset of state groups that are referenced.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[int]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
column="state_group",
|
column="state_group",
|
||||||
iterable=state_groups,
|
iterable=state_groups,
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("DISTINCT state_group",),
|
retcols=("DISTINCT state_group",),
|
||||||
desc="get_referenced_state_groups",
|
desc="get_referenced_state_groups",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {row["state_group"] for row in rows}
|
return {row[0] for row in rows}
|
||||||
|
|
||||||
async def update_state_for_partial_state_event(
|
async def update_state_for_partial_state_event(
|
||||||
self,
|
self,
|
||||||
|
@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||||
# potentially stale, since there may have been a period where the
|
# potentially stale, since there may have been a period where the
|
||||||
# server didn't share a room with the remote user and therefore may
|
# server didn't share a room with the remote user and therefore may
|
||||||
# have missed any device updates.
|
# have missed any device updates.
|
||||||
rows = self.db_pool.simple_select_many_txn(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="current_state_events",
|
table="current_state_events",
|
||||||
column="room_id",
|
column="room_id",
|
||||||
iterable=to_delete,
|
iterable=to_delete,
|
||||||
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
|
keyvalues={
|
||||||
|
"type": EventTypes.Member,
|
||||||
|
"membership": Membership.JOIN,
|
||||||
|
},
|
||||||
retcols=("state_key",),
|
retcols=("state_key",),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
potentially_left_users = {row["state_key"] for row in rows}
|
potentially_left_users = {row[0] for row in rows}
|
||||||
|
|
||||||
# Now lets actually delete the rooms from the DB.
|
# Now lets actually delete the rooms from the DB.
|
||||||
self.db_pool.simple_delete_many_txn(
|
self.db_pool.simple_delete_many_txn(
|
||||||
|
|
|
@ -506,7 +506,9 @@ class StatsStore(StateDeltasStore):
|
||||||
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
|
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
|
||||||
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
|
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
|
||||||
|
|
||||||
rows = self.db_pool.simple_select_many_txn(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="current_state_events",
|
table="current_state_events",
|
||||||
column="type",
|
column="type",
|
||||||
|
@ -522,9 +524,10 @@ class StatsStore(StateDeltasStore):
|
||||||
],
|
],
|
||||||
keyvalues={"room_id": room_id, "state_key": ""},
|
keyvalues={"room_id": room_id, "state_key": ""},
|
||||||
retcols=["event_id"],
|
retcols=["event_id"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
event_ids = cast(List[str], [row["event_id"] for row in rows])
|
event_ids = [row[0] for row in rows]
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||||
async def get_destination_retry_timings_batch(
|
async def get_destination_retry_timings_batch(
|
||||||
self, destinations: StrCollection
|
self, destinations: StrCollection
|
||||||
) -> Mapping[str, Optional[DestinationRetryTimings]]:
|
) -> Mapping[str, Optional[DestinationRetryTimings]]:
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str, Optional[int], Optional[int], Optional[int]]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="destinations",
|
table="destinations",
|
||||||
iterable=destinations,
|
iterable=destinations,
|
||||||
column="destination",
|
column="destination",
|
||||||
retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
|
retcols=(
|
||||||
|
"destination",
|
||||||
|
"failure_ts",
|
||||||
|
"retry_last_ts",
|
||||||
|
"retry_interval",
|
||||||
|
),
|
||||||
desc="get_destination_retry_timings_batch",
|
desc="get_destination_retry_timings_batch",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
row.pop("destination"): DestinationRetryTimings(**row)
|
destination: DestinationRetryTimings(
|
||||||
for row in rows
|
failure_ts, retry_last_ts, retry_interval
|
||||||
if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"]
|
)
|
||||||
|
for destination, failure_ts, retry_last_ts, retry_interval in rows
|
||||||
|
if retry_last_ts and failure_ts and retry_interval
|
||||||
}
|
}
|
||||||
|
|
||||||
async def set_destination_retry_timings(
|
async def set_destination_retry_timings(
|
||||||
|
|
|
@ -337,13 +337,16 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
# If a registration token was used, decrement the pending counter
|
# If a registration token was used, decrement the pending counter
|
||||||
# before deleting the session.
|
# before deleting the session.
|
||||||
rows = self.db_pool.simple_select_many_txn(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="ui_auth_sessions_credentials",
|
table="ui_auth_sessions_credentials",
|
||||||
column="session_id",
|
column="session_id",
|
||||||
iterable=session_ids,
|
iterable=session_ids,
|
||||||
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
|
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
|
||||||
retcols=["result"],
|
retcols=["result"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the tokens used and how much pending needs to be decremented by.
|
# Get the tokens used and how much pending needs to be decremented by.
|
||||||
|
@ -353,23 +356,25 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||||
# registration token stage for that session will be True.
|
# registration token stage for that session will be True.
|
||||||
# If a token was used to authenticate, but registration was
|
# If a token was used to authenticate, but registration was
|
||||||
# never completed, the result will be the token used.
|
# never completed, the result will be the token used.
|
||||||
token = db_to_json(r["result"])
|
token = db_to_json(r[0])
|
||||||
if isinstance(token, str):
|
if isinstance(token, str):
|
||||||
token_counts[token] = token_counts.get(token, 0) + 1
|
token_counts[token] = token_counts.get(token, 0) + 1
|
||||||
|
|
||||||
# Update the `pending` counters.
|
# Update the `pending` counters.
|
||||||
if len(token_counts) > 0:
|
if len(token_counts) > 0:
|
||||||
token_rows = self.db_pool.simple_select_many_txn(
|
token_rows = cast(
|
||||||
|
List[Tuple[str, int]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="registration_tokens",
|
table="registration_tokens",
|
||||||
column="token",
|
column="token",
|
||||||
iterable=list(token_counts.keys()),
|
iterable=list(token_counts.keys()),
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=["token", "pending"],
|
retcols=["token", "pending"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for token_row in token_rows:
|
for token, pending in token_rows:
|
||||||
token = token_row["token"]
|
new_pending = pending - token_counts[token]
|
||||||
new_pending = token_row["pending"] - token_counts[token]
|
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="registration_tokens",
|
table="registration_tokens",
|
||||||
|
|
|
@ -410,7 +410,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Next fetch their profiles. Note that not all users have profiles.
|
# Next fetch their profiles. Note that not all users have profiles.
|
||||||
profile_rows = self.db_pool.simple_select_many_txn(
|
profile_rows = cast(
|
||||||
|
List[Tuple[str, Optional[str], Optional[str]]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="profiles",
|
table="profiles",
|
||||||
column="full_user_id",
|
column="full_user_id",
|
||||||
|
@ -421,14 +423,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
"avatar_url",
|
"avatar_url",
|
||||||
),
|
),
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
profiles = {
|
profiles = {
|
||||||
row["full_user_id"]: _UserDirProfile(
|
full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url)
|
||||||
row["full_user_id"],
|
for full_user_id, displayname, avatar_url in profile_rows
|
||||||
row["displayname"],
|
|
||||||
row["avatar_url"],
|
|
||||||
)
|
|
||||||
for row in profile_rows
|
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles_to_insert = [
|
profiles_to_insert = [
|
||||||
|
@ -517,7 +516,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined]
|
and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined]
|
||||||
]
|
]
|
||||||
|
|
||||||
rows = self.db_pool.simple_select_many_txn(
|
rows = cast(
|
||||||
|
List[Tuple[str, Optional[str]]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="users",
|
table="users",
|
||||||
column="name",
|
column="name",
|
||||||
|
@ -526,9 +527,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
"deactivated": 0,
|
"deactivated": 0,
|
||||||
},
|
},
|
||||||
retcols=("name", "user_type"),
|
retcols=("name", "user_type"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT]
|
return [name for name, user_type in rows if user_type != UserTypes.SUPPORT]
|
||||||
|
|
||||||
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
|
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
|
||||||
"""Check if the room is either world_readable or publically joinable"""
|
"""Check if the room is either world_readable or publically joinable"""
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Iterable, Mapping
|
from typing import Iterable, List, Mapping, Tuple, cast
|
||||||
|
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||||
|
@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
|
||||||
Returns:
|
Returns:
|
||||||
for each user, whether the user has requested erasure.
|
for each user, whether the user has requested erasure.
|
||||||
"""
|
"""
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[str]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="erased_users",
|
table="erased_users",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
retcols=("user_id",),
|
retcols=("user_id",),
|
||||||
desc="are_users_erased",
|
desc="are_users_erased",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
erased_users = {row["user_id"] for row in rows}
|
erased_users = {row[0] for row in rows}
|
||||||
|
|
||||||
return {u: u in erased_users for u in user_ids}
|
return {u: u in erased_users for u in user_ids}
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -730,19 +740,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
"[purge] found %i state groups to delete", len(state_groups_to_delete)
|
"[purge] found %i state groups to delete", len(state_groups_to_delete)
|
||||||
)
|
)
|
||||||
|
|
||||||
rows = self.db_pool.simple_select_many_txn(
|
rows = cast(
|
||||||
|
List[Tuple[int]],
|
||||||
|
self.db_pool.simple_select_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="state_group_edges",
|
table="state_group_edges",
|
||||||
column="prev_state_group",
|
column="prev_state_group",
|
||||||
iterable=state_groups_to_delete,
|
iterable=state_groups_to_delete,
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("state_group",),
|
retcols=("state_group",),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
remaining_state_groups = {
|
remaining_state_groups = {
|
||||||
row["state_group"]
|
state_group
|
||||||
for row in rows
|
for state_group, in rows
|
||||||
if row["state_group"] not in state_groups_to_delete
|
if state_group not in state_groups_to_delete
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -799,16 +812,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
A mapping from state group to previous state group.
|
A mapping from state group to previous state group.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = cast(
|
||||||
|
List[Tuple[int, int]],
|
||||||
|
await self.db_pool.simple_select_many_batch(
|
||||||
table="state_group_edges",
|
table="state_group_edges",
|
||||||
column="prev_state_group",
|
column="prev_state_group",
|
||||||
iterable=state_groups,
|
iterable=state_groups,
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("prev_state_group", "state_group"),
|
retcols=("state_group", "prev_state_group"),
|
||||||
desc="get_previous_state_groups",
|
desc="get_previous_state_groups",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {row["state_group"]: row["prev_state_group"] for row in rows}
|
return dict(rows)
|
||||||
|
|
||||||
async def purge_room_state(
|
async def purge_room_state(
|
||||||
self, room_id: str, state_groups_to_delete: Collection[int]
|
self, room_id: str, state_groups_to_delete: Collection[int]
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Dict, List, Set, Tuple
|
from typing import Dict, List, Set, Tuple, cast
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
|
@ -421,7 +421,9 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
self, events: List[EventBase]
|
self, events: List[EventBase]
|
||||||
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
|
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
|
||||||
# Fetch the map from event ID -> (chain ID, sequence number)
|
# Fetch the map from event ID -> (chain ID, sequence number)
|
||||||
rows = self.get_success(
|
rows = cast(
|
||||||
|
List[Tuple[str, int, int]],
|
||||||
|
self.get_success(
|
||||||
self.store.db_pool.simple_select_many_batch(
|
self.store.db_pool.simple_select_many_batch(
|
||||||
table="event_auth_chains",
|
table="event_auth_chains",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
|
@ -429,14 +431,18 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
retcols=("event_id", "chain_id", "sequence_number"),
|
retcols=("event_id", "chain_id", "sequence_number"),
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
)
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_map = {
|
chain_map = {
|
||||||
row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
|
event_id: (chain_id, sequence_number)
|
||||||
|
for event_id, chain_id, sequence_number in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
# Fetch all the links and pass them to the _LinkMap.
|
# Fetch all the links and pass them to the _LinkMap.
|
||||||
rows = self.get_success(
|
auth_chain_rows = cast(
|
||||||
|
List[Tuple[int, int, int, int]],
|
||||||
|
self.get_success(
|
||||||
self.store.db_pool.simple_select_many_batch(
|
self.store.db_pool.simple_select_many_batch(
|
||||||
table="event_auth_chain_links",
|
table="event_auth_chain_links",
|
||||||
column="origin_chain_id",
|
column="origin_chain_id",
|
||||||
|
@ -449,13 +455,19 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
),
|
),
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
)
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
link_map = _LinkMap()
|
link_map = _LinkMap()
|
||||||
for row in rows:
|
for (
|
||||||
|
origin_chain_id,
|
||||||
|
origin_sequence_number,
|
||||||
|
target_chain_id,
|
||||||
|
target_sequence_number,
|
||||||
|
) in auth_chain_rows:
|
||||||
added = link_map.add_link(
|
added = link_map.add_link(
|
||||||
(row["origin_chain_id"], row["origin_sequence_number"]),
|
(origin_chain_id, origin_sequence_number),
|
||||||
(row["target_chain_id"], row["target_sequence_number"]),
|
(target_chain_id, target_sequence_number),
|
||||||
)
|
)
|
||||||
|
|
||||||
# We shouldn't have persisted any redundant links
|
# We shouldn't have persisted any redundant links
|
||||||
|
|
Loading…
Reference in a new issue