mirror of
https://mau.dev/maunium/synapse.git
synced 2024-05-17 02:53:46 +02:00
Refactor chain fetching (#17044)
The same queries were previously duplicated in two places.
This commit is contained in:
parent
fd48fc4585
commit
ec174d0470
1
changelog.d/17044.misc
Normal file
1
changelog.d/17044.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Refactor auth chain fetching to reduce duplication.
|
|
@ -27,6 +27,7 @@ from typing import (
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
@ -279,64 +280,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
# Now we look up all links for the chains we have, adding chains that
|
# Now we look up all links for the chains we have, adding chains that
|
||||||
# are reachable from any event.
|
# are reachable from any event.
|
||||||
#
|
|
||||||
# This query is structured to first get all chain IDs reachable, and
|
|
||||||
# then pull out all links from those chains. This does pull out more
|
|
||||||
# rows than is strictly necessary, however there isn't a way of
|
|
||||||
# structuring the recursive part of query to pull out the links without
|
|
||||||
# also returning large quantities of redundant data (which can make it a
|
|
||||||
# lot slower).
|
|
||||||
sql = """
|
|
||||||
WITH RECURSIVE links(chain_id) AS (
|
|
||||||
SELECT
|
|
||||||
DISTINCT origin_chain_id
|
|
||||||
FROM event_auth_chain_links WHERE %s
|
|
||||||
UNION
|
|
||||||
SELECT
|
|
||||||
target_chain_id
|
|
||||||
FROM event_auth_chain_links
|
|
||||||
INNER JOIN links ON (chain_id = origin_chain_id)
|
|
||||||
)
|
|
||||||
SELECT
|
|
||||||
origin_chain_id, origin_sequence_number,
|
|
||||||
target_chain_id, target_sequence_number
|
|
||||||
FROM links
|
|
||||||
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# A map from chain ID to max sequence number *reachable* from any event ID.
|
# A map from chain ID to max sequence number *reachable* from any event ID.
|
||||||
chains: Dict[int, int] = {}
|
chains: Dict[int, int] = {}
|
||||||
|
for links in self._get_chain_links(txn, set(event_chains.keys())):
|
||||||
# Add all linked chains reachable from initial set of chains.
|
|
||||||
chains_to_fetch = set(event_chains.keys())
|
|
||||||
while chains_to_fetch:
|
|
||||||
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
|
|
||||||
chains_to_fetch.difference_update(batch2)
|
|
||||||
clause, args = make_in_list_sql_clause(
|
|
||||||
txn.database_engine, "origin_chain_id", batch2
|
|
||||||
)
|
|
||||||
txn.execute(sql % (clause,), args)
|
|
||||||
|
|
||||||
links: Dict[int, List[Tuple[int, int, int]]] = {}
|
|
||||||
|
|
||||||
for (
|
|
||||||
origin_chain_id,
|
|
||||||
origin_sequence_number,
|
|
||||||
target_chain_id,
|
|
||||||
target_sequence_number,
|
|
||||||
) in txn:
|
|
||||||
links.setdefault(origin_chain_id, []).append(
|
|
||||||
(origin_sequence_number, target_chain_id, target_sequence_number)
|
|
||||||
)
|
|
||||||
|
|
||||||
for chain_id in links:
|
for chain_id in links:
|
||||||
if chain_id not in event_chains:
|
if chain_id not in event_chains:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
_materialize(chain_id, event_chains[chain_id], links, chains)
|
_materialize(chain_id, event_chains[chain_id], links, chains)
|
||||||
|
|
||||||
chains_to_fetch.difference_update(chains)
|
|
||||||
|
|
||||||
# Add the initial set of chains, excluding the sequence corresponding to
|
# Add the initial set of chains, excluding the sequence corresponding to
|
||||||
# initial event.
|
# initial event.
|
||||||
for chain_id, seq_no in event_chains.items():
|
for chain_id, seq_no in event_chains.items():
|
||||||
|
@ -380,6 +333,68 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@classmethod
def _get_chain_links(
    cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
    """Fetch all auth chain links from the given set of chains, and all
    links from those chains, recursively.

    Note: This may return links that are not reachable from the given
    chains.

    The chains are queried in batches of at most 1000 origin chain IDs. Any
    chain that shows up as an origin in a batch's results is not queried
    again in a later batch.

    Args:
        txn: The database transaction to run the queries on.
        chains_to_fetch: The chain IDs to start from. The set is copied
            internally and is not modified.

    Returns:
        A generator that produces, once per batch, a dict from origin chain
        ID to a list of 3-tuples of origin sequence number, target chain ID
        and target sequence number.
    """

    # This query is structured to first get all chain IDs reachable, and
    # then pull out all links from those chains. This does pull out more
    # rows than is strictly necessary, however there isn't a way of
    # structuring the recursive part of query to pull out the links without
    # also returning large quantities of redundant data (which can make it a
    # lot slower).
    sql = """
        WITH RECURSIVE links(chain_id) AS (
            SELECT
                DISTINCT origin_chain_id
            FROM event_auth_chain_links WHERE %s
            UNION
            SELECT
                target_chain_id
            FROM event_auth_chain_links
            INNER JOIN links ON (chain_id = origin_chain_id)
        )
        SELECT
            origin_chain_id, origin_sequence_number,
            target_chain_id, target_sequence_number
        FROM links
        INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
    """

    # Work on a private copy so that the caller's set is left untouched.
    remaining = set(chains_to_fetch)

    while remaining:
        batch = tuple(itertools.islice(remaining, 1000))
        remaining.difference_update(batch)
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "origin_chain_id", batch
        )
        txn.execute(sql % (clause,), args)

        links: Dict[int, List[Tuple[int, int, int]]] = {}

        for (
            origin_chain_id,
            origin_sequence_number,
            target_chain_id,
            target_sequence_number,
        ) in txn:
            links.setdefault(origin_chain_id, []).append(
                (origin_sequence_number, target_chain_id, target_sequence_number)
            )

        # Every origin chain in the result has just had its links pulled
        # out by the query above, so skip it in later batches.
        remaining.difference_update(links)

        yield links
|
||||||
|
|
||||||
def _get_auth_chain_ids_txn(
|
def _get_auth_chain_ids_txn(
|
||||||
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
|
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
|
@ -564,53 +579,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
# Now we look up all links for the chains we have, adding chains that
|
# Now we look up all links for the chains we have, adding chains that
|
||||||
# are reachable from any event.
|
# are reachable from any event.
|
||||||
#
|
|
||||||
# This query is structured to first get all chain IDs reachable, and
|
|
||||||
# then pull out all links from those chains. This does pull out more
|
|
||||||
# rows than is strictly necessary, however there isn't a way of
|
|
||||||
# structuring the recursive part of query to pull out the links without
|
|
||||||
# also returning large quantities of redundant data (which can make it a
|
|
||||||
# lot slower).
|
|
||||||
sql = """
|
|
||||||
WITH RECURSIVE links(chain_id) AS (
|
|
||||||
SELECT
|
|
||||||
DISTINCT origin_chain_id
|
|
||||||
FROM event_auth_chain_links WHERE %s
|
|
||||||
UNION
|
|
||||||
SELECT
|
|
||||||
target_chain_id
|
|
||||||
FROM event_auth_chain_links
|
|
||||||
INNER JOIN links ON (chain_id = origin_chain_id)
|
|
||||||
)
|
|
||||||
SELECT
|
|
||||||
origin_chain_id, origin_sequence_number,
|
|
||||||
target_chain_id, target_sequence_number
|
|
||||||
FROM links
|
|
||||||
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# (We need to take a copy of `seen_chains` as we want to mutate it in
|
|
||||||
# the loop)
|
|
||||||
chains_to_fetch = set(seen_chains)
|
|
||||||
while chains_to_fetch:
|
|
||||||
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
|
|
||||||
clause, args = make_in_list_sql_clause(
|
|
||||||
txn.database_engine, "origin_chain_id", batch2
|
|
||||||
)
|
|
||||||
txn.execute(sql % (clause,), args)
|
|
||||||
|
|
||||||
links: Dict[int, List[Tuple[int, int, int]]] = {}
|
|
||||||
|
|
||||||
for (
|
|
||||||
origin_chain_id,
|
|
||||||
origin_sequence_number,
|
|
||||||
target_chain_id,
|
|
||||||
target_sequence_number,
|
|
||||||
) in txn:
|
|
||||||
links.setdefault(origin_chain_id, []).append(
|
|
||||||
(origin_sequence_number, target_chain_id, target_sequence_number)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# (We need to take a copy of `seen_chains` as the function mutates it)
|
||||||
|
for links in self._get_chain_links(txn, set(seen_chains)):
|
||||||
for chains in set_to_chain:
|
for chains in set_to_chain:
|
||||||
for chain_id in links:
|
for chain_id in links:
|
||||||
if chain_id not in chains:
|
if chain_id not in chains:
|
||||||
|
@ -618,7 +589,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
_materialize(chain_id, chains[chain_id], links, chains)
|
_materialize(chain_id, chains[chain_id], links, chains)
|
||||||
|
|
||||||
chains_to_fetch.difference_update(chains)
|
|
||||||
seen_chains.update(chains)
|
seen_chains.update(chains)
|
||||||
|
|
||||||
# Now for each chain we figure out the maximum sequence number reachable
|
# Now for each chain we figure out the maximum sequence number reachable
|
||||||
|
|
Loading…
Reference in a new issue