From 55b0aa847a61774b6a3acdc4b177a20dc019f01a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Apr 2024 15:24:08 +0100 Subject: [PATCH] Fix GHSA-3h7q-rfh9-xm4v Weakness in auth chain indexing allows DoS from remote room members through disk fill and high CPU usage. A remote Matrix user with malicious intent, sharing a room with Synapse instances before 1.104.1, can dispatch specially crafted events to exploit a weakness in how the auth chain cover index is calculated. This can induce high CPU consumption and accumulate excessive data in the database of such instances, resulting in a denial of service. Servers in private federations, or those that do not federate, are not affected. --- changelog.d/17044.misc | 1 + synapse/storage/databases/main/events.py | 108 ++++++++--------------- synapse/storage/schema/__init__.py | 8 +- tests/storage/test_event_chain.py | 104 +++++++++++++++------- 4 files changed, 117 insertions(+), 104 deletions(-) create mode 100644 changelog.d/17044.misc diff --git a/changelog.d/17044.misc b/changelog.d/17044.misc new file mode 100644 index 000000000..a1439752d --- /dev/null +++ b/changelog.d/17044.misc @@ -0,0 +1 @@ +Refactor auth chain fetching to reduce duplication. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a6fda3f43..1e731d56b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -19,6 +19,7 @@ # [This file includes modifications made by New Vector Limited] # # +import collections import itertools import logging from collections import OrderedDict @@ -53,6 +54,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.event_federation import EventFederationStore from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.engines import PostgresEngine @@ -768,40 +770,26 @@ class PersistEventsStore: # that have the same chain ID as the event. # 2. For each retained auth event we: # a. Add a link from the event's to the auth event's chain - # ID/sequence number; and - # b. Add a link from the event to every chain reachable by the - # auth event. + # ID/sequence number # Step 1, fetch all existing links from all the chains we've seen # referenced. chain_links = _LinkMap() - auth_chain_rows = cast( - List[Tuple[int, int, int, int]], - db_pool.simple_select_many_txn( - txn, - table="event_auth_chain_links", - column="origin_chain_id", - iterable={chain_id for chain_id, _ in chain_map.values()}, - keyvalues={}, - retcols=( - "origin_chain_id", - "origin_sequence_number", - "target_chain_id", - "target_sequence_number", - ), - ), - ) - for ( - origin_chain_id, - origin_sequence_number, - target_chain_id, - target_sequence_number, - ) in auth_chain_rows: - chain_links.add_link( - (origin_chain_id, origin_sequence_number), - (target_chain_id, target_sequence_number), - new=False, - ) + + for links in EventFederationStore._get_chain_links( + txn, {chain_id for chain_id, _ in chain_map.values()} + ): + for origin_chain_id, inner_links in links.items(): + for ( + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in inner_links: + chain_links.add_link( + (origin_chain_id, origin_sequence_number), + (target_chain_id, target_sequence_number), + new=False, + ) # We do this in toplogical order to avoid adding redundant links. for event_id in sorted_topologically( @@ -836,18 +824,6 @@ class PersistEventsStore: (chain_id, sequence_number), (auth_chain_id, auth_sequence_number) ) - # Step 2b, add a link to chains reachable from the auth - # event. - for target_id, target_seq in chain_links.get_links_from( - (auth_chain_id, auth_sequence_number) - ): - if target_id == chain_id: - continue - - chain_links.add_link( - (chain_id, sequence_number), (target_id, target_seq) - ) - db_pool.simple_insert_many_txn( txn, table="event_auth_chain_links", @@ -2451,31 +2427,6 @@ class _LinkMap: current_links[src_seq] = target_seq return True - def get_links_from( - self, src_tuple: Tuple[int, int] - ) -> Generator[Tuple[int, int], None, None]: - """Gets the chains reachable from the given chain/sequence number. - - Yields: - The chain ID and sequence number the link points to. - """ - src_chain, src_seq = src_tuple - for target_id, sequence_numbers in self.maps.get(src_chain, {}).items(): - for link_src_seq, target_seq in sequence_numbers.items(): - if link_src_seq <= src_seq: - yield target_id, target_seq - - def get_links_between( - self, source_chain: int, target_chain: int - ) -> Generator[Tuple[int, int], None, None]: - """Gets the links between two chains. - - Yields: - The source and target sequence numbers. - """ - - yield from self.maps.get(source_chain, {}).get(target_chain, {}).items() - def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]: """Gets any newly added links. @@ -2502,9 +2453,24 @@ class _LinkMap: if src_chain == target_chain: return target_seq <= src_seq - links = self.get_links_between(src_chain, target_chain) - for link_start_seq, link_end_seq in links: - if link_start_seq <= src_seq and target_seq <= link_end_seq: - return True + # We have to graph traverse the links to check for indirect paths. + visited_chains = collections.Counter() + search = [(src_chain, src_seq)] + while search: + chain, seq = search.pop() + visited_chains[chain] = max(seq, visited_chains[chain]) + for tc, links in self.maps.get(chain, {}).items(): + for ss, ts in links.items(): + # Don't revisit chains we've already seen, unless the target + # sequence number is higher than last time. + if ts <= visited_chains.get(tc, 0): + continue + + if ss <= seq: + if tc == target_chain: + if target_seq <= ts: + return True + else: + search.append((tc, ts)) return False diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index c0b925444..039aa91b9 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -132,12 +132,16 @@ Changes in SCHEMA_VERSION = 82 Changes in SCHEMA_VERSION = 83 - The event_txn_id is no longer used. + +Changes in SCHEMA_VERSION = 84 + - No longer assumes that `event_auth_chain_links` holds transitive links, and + so read operations must do graph traversal. """ SCHEMA_COMPAT_VERSION = ( - # The event_txn_id table and tables from MSC2716 no longer exist. - 83 + # Transitive links are no longer written to `event_auth_chain_links` + 84 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 9e4e73832..27d5b0125 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -21,6 +21,8 @@ from typing import Dict, List, Set, Tuple, cast +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest @@ -45,7 +47,8 @@ class EventChainStoreTestCase(HomeserverTestCase): self.store = hs.get_datastores().main self._next_stream_ordering = 1 - def test_simple(self) -> None: + @parameterized.expand([(False,), (True,)]) + def test_simple(self, batched: bool) -> None: """Test that the example in `docs/auth_chain_difference_algorithm.md` works. """ @@ -53,6 +56,7 @@ class EventChainStoreTestCase(HomeserverTestCase): event_factory = self.hs.get_event_builder_factory() bob = "@creator:test" alice = "@alice:test" + charlie = "@charlie:test" room_id = "!room:test" # Ensure that we have a rooms entry so that we generate the chain index. @@ -191,6 +195,26 @@ class EventChainStoreTestCase(HomeserverTestCase): ) ) + charlie_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": charlie, + "sender": alice, + "room_id": room_id, + "content": {"tag": "charlie_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[ + create.event_id, + alice_join2.event_id, + power_2.event_id, + ], + ) + ) + events = [ create, bob_join, @@ -200,33 +224,41 @@ class EventChainStoreTestCase(HomeserverTestCase): bob_join_2, power_2, alice_join2, + charlie_invite, ] expected_links = [ (bob_join, create), - (power, create), (power, bob_join), - (alice_invite, create), (alice_invite, power), - (alice_invite, bob_join), (bob_join_2, power), (alice_join2, power_2), + (charlie_invite, alice_join2), ] - self.persist(events) + # We either persist as a batch or one-by-one depending on test + # parameter. + if batched: + self.persist(events) + else: + for event in events: + self.persist([event]) + chain_map, link_map = self.fetch_chains(events) # Check that the expected links and only the expected links have been # added. - self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + event_map = {e.event_id: e for e in events} + reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()} - for start, end in expected_links: - start_id, start_seq = chain_map[start.event_id] - end_id, end_seq = chain_map[end.event_id] - - self.assertIn( - (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) - ) + self.maxDiff = None + self.assertCountEqual( + expected_links, + [ + (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)]) + for s1, s2, t1, t2 in link_map.get_additions() + ], + ) # Test that everything can reach the create event, but the create event # can't reach anything. @@ -368,24 +400,23 @@ class EventChainStoreTestCase(HomeserverTestCase): expected_links = [ (bob_join, create), - (power, create), (power, bob_join), - (alice_invite, create), (alice_invite, power), - (alice_invite, bob_join), ] # Check that the expected links and only the expected links have been # added. - self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + event_map = {e.event_id: e for e in events} + reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()} - for start, end in expected_links: - start_id, start_seq = chain_map[start.event_id] - end_id, end_seq = chain_map[end.event_id] - - self.assertIn( - (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) - ) + self.maxDiff = None + self.assertCountEqual( + expected_links, + [ + (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)]) + for s1, s2, t1, t2 in link_map.get_additions() + ], + ) def persist( self, @@ -489,8 +520,6 @@ class LinkMapTestCase(unittest.TestCase): link_map = _LinkMap() link_map.add_link((1, 1), (2, 1), new=False) - self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) - self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)]) self.assertCountEqual(link_map.get_additions(), []) self.assertTrue(link_map.exists_path_from((1, 5), (2, 1))) self.assertFalse(link_map.exists_path_from((1, 5), (2, 2))) @@ -499,18 +528,31 @@ class LinkMapTestCase(unittest.TestCase): # Attempting to add a redundant link is ignored. self.assertFalse(link_map.add_link((1, 4), (2, 1))) - self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + self.assertCountEqual(link_map.get_additions(), []) # Adding new non-redundant links works self.assertTrue(link_map.add_link((1, 3), (2, 3))) - self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3)]) self.assertTrue(link_map.add_link((2, 5), (1, 3))) - self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)]) - self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) - self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) + def test_exists_path_from(self) -> None: + "Check that `exists_path_from` can handle non-direct links" + link_map = _LinkMap() + + link_map.add_link((1, 1), (2, 1), new=False) + link_map.add_link((2, 1), (3, 1), new=False) + + self.assertTrue(link_map.exists_path_from((1, 4), (3, 1))) + self.assertFalse(link_map.exists_path_from((1, 4), (3, 2))) + + link_map.add_link((1, 5), (2, 3), new=False) + link_map.add_link((2, 2), (3, 3), new=False) + + self.assertTrue(link_map.exists_path_from((1, 6), (3, 2))) + self.assertFalse(link_map.exists_path_from((1, 4), (3, 2))) + class EventChainBackgroundUpdateTestCase(HomeserverTestCase): servlets = [