
Fix GHSA-3h7q-rfh9-xm4v

Weakness in auth chain indexing allows DoS from remote room members
through disk fill and high CPU usage.

A remote Matrix user with malicious intent, sharing a room with Synapse
instances before 1.105.1, can dispatch specially crafted events to
exploit a weakness in how the auth chain cover index is calculated. This
can induce high CPU consumption and accumulate excessive data in the
database of such instances, resulting in a denial of service.

Servers in private federations, or those that do not federate, are not
affected.
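The root of the issue is that the `event_auth_chain_links` cover index used to store the transitive closure of chain links: when persisting an event, the old code also added a link to every chain reachable from each of its auth events (the "Step 2b" block removed below). A crafted sequence of events can therefore make the number of stored links, and the work needed to compute them, grow roughly quadratically. A rough sketch of that blowup under a simplified worst-case model (each new event starting its own chain that points at the previous one; not Synapse's actual schema or classes):

# Rough sketch of the storage blowup (illustration only; simplified model,
# not Synapse's actual schema or classes). Assume a worst case where every
# new event starts a new chain whose single auth event lives in the
# previous chain.

def links_stored_transitively(num_chains: int) -> int:
    """Old behaviour: besides the direct link, each chain also gained a link
    to every chain reachable from its auth event, i.e. all earlier chains."""
    return num_chains * (num_chains - 1) // 2

def links_stored_directly(num_chains: int) -> int:
    """New behaviour: only the direct link to the auth event's chain is
    written; reachability is recomputed by graph traversal at read time."""
    return num_chains - 1

if __name__ == "__main__":
    for n in (10, 1_000, 100_000):
        print(n, links_stored_transitively(n), links_stored_directly(n))
    # At 100,000 chains: ~5 * 10**9 stored rows before vs. 99,999 after.

That write amplification is what the removed "Step 2b" block below produced, and it is how the attack can fill the database and burn CPU while the extra links are computed.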
commit 55b0aa847a
parent fbb2573525
Author: Erik Johnston
Date:   2024-04-23 15:24:08 +01:00

4 changed files with 117 additions and 104 deletions

changelog.d/17044.misc (new file)

@@ -0,0 +1 @@
+Refactor auth chain fetching to reduce duplication.


@@ -19,6 +19,7 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
+import collections
 import itertools
 import logging
 from collections import OrderedDict
@@ -53,6 +54,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.storage.databases.main.event_federation import EventFederationStore
 from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines import PostgresEngine
@@ -768,40 +770,26 @@ class PersistEventsStore:
         #      that have the same chain ID as the event.
         #   2. For each retained auth event we:
         #       a. Add a link from the event's to the auth event's chain
-        #          ID/sequence number; and
-        #       b. Add a link from the event to every chain reachable by the
-        #          auth event.
+        #          ID/sequence number

         # Step 1, fetch all existing links from all the chains we've seen
         # referenced.
         chain_links = _LinkMap()
-        auth_chain_rows = cast(
-            List[Tuple[int, int, int, int]],
-            db_pool.simple_select_many_txn(
-                txn,
-                table="event_auth_chain_links",
-                column="origin_chain_id",
-                iterable={chain_id for chain_id, _ in chain_map.values()},
-                keyvalues={},
-                retcols=(
-                    "origin_chain_id",
-                    "origin_sequence_number",
-                    "target_chain_id",
-                    "target_sequence_number",
-                ),
-            ),
-        )
-        for (
-            origin_chain_id,
-            origin_sequence_number,
-            target_chain_id,
-            target_sequence_number,
-        ) in auth_chain_rows:
-            chain_links.add_link(
-                (origin_chain_id, origin_sequence_number),
-                (target_chain_id, target_sequence_number),
-                new=False,
-            )
+
+        for links in EventFederationStore._get_chain_links(
+            txn, {chain_id for chain_id, _ in chain_map.values()}
+        ):
+            for origin_chain_id, inner_links in links.items():
+                for (
+                    origin_sequence_number,
+                    target_chain_id,
+                    target_sequence_number,
+                ) in inner_links:
+                    chain_links.add_link(
+                        (origin_chain_id, origin_sequence_number),
+                        (target_chain_id, target_sequence_number),
+                        new=False,
+                    )

         # We do this in toplogical order to avoid adding redundant links.
         for event_id in sorted_topologically(
@@ -836,18 +824,6 @@ class PersistEventsStore:
                     (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
                 )

-                # Step 2b, add a link to chains reachable from the auth
-                # event.
-                for target_id, target_seq in chain_links.get_links_from(
-                    (auth_chain_id, auth_sequence_number)
-                ):
-                    if target_id == chain_id:
-                        continue
-
-                    chain_links.add_link(
-                        (chain_id, sequence_number), (target_id, target_seq)
-                    )
-
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
@@ -2451,31 +2427,6 @@ class _LinkMap:
         current_links[src_seq] = target_seq
         return True

-    def get_links_from(
-        self, src_tuple: Tuple[int, int]
-    ) -> Generator[Tuple[int, int], None, None]:
-        """Gets the chains reachable from the given chain/sequence number.
-
-        Yields:
-            The chain ID and sequence number the link points to.
-        """
-
-        src_chain, src_seq = src_tuple
-        for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
-            for link_src_seq, target_seq in sequence_numbers.items():
-                if link_src_seq <= src_seq:
-                    yield target_id, target_seq
-
-    def get_links_between(
-        self, source_chain: int, target_chain: int
-    ) -> Generator[Tuple[int, int], None, None]:
-        """Gets the links between two chains.
-
-        Yields:
-            The source and target sequence numbers.
-        """
-
-        yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
-
     def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
         """Gets any newly added links.
@@ -2502,9 +2453,24 @@ class _LinkMap:
         if src_chain == target_chain:
             return target_seq <= src_seq

-        links = self.get_links_between(src_chain, target_chain)
-        for link_start_seq, link_end_seq in links:
-            if link_start_seq <= src_seq and target_seq <= link_end_seq:
-                return True
+        # We have to graph traverse the links to check for indirect paths.
+        visited_chains = collections.Counter()
+        search = [(src_chain, src_seq)]
+        while search:
+            chain, seq = search.pop()
+            visited_chains[chain] = max(seq, visited_chains[chain])
+
+            for tc, links in self.maps.get(chain, {}).items():
+                for ss, ts in links.items():
+                    # Don't revisit chains we've already seen, unless the target
+                    # sequence number is higher than last time.
+                    if ts <= visited_chains.get(tc, 0):
+                        continue
+
+                    if ss <= seq:
+                        if tc == target_chain:
+                            if target_seq <= ts:
+                                return True
+                        else:
+                            search.append((tc, ts))

         return False
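With only direct links stored, `exists_path_from` has to discover indirect paths itself, as in the traversal added above. Below is a self-contained sketch of the same idea, where `maps` is a simplified stand-in for `_LinkMap.maps` (origin chain -> target chain -> {origin sequence: target sequence}) rather than the real class:

import collections
from typing import Dict, List, Tuple

# Simplified stand-in for _LinkMap.maps: only *direct* links are stored.
Links = Dict[int, Dict[int, Dict[int, int]]]

def exists_path_from(
    maps: Links, src: Tuple[int, int], target: Tuple[int, int]
) -> bool:
    """Return True if (target_chain, target_seq) is reachable from
    (src_chain, src_seq) via the stored direct links."""
    src_chain, src_seq = src
    target_chain, target_seq = target
    if src_chain == target_chain:
        return target_seq <= src_seq

    # Track the highest sequence number reached per chain so a chain is only
    # re-expanded if we can now reach a higher sequence number than before.
    visited_chains = collections.Counter()
    search: List[Tuple[int, int]] = [(src_chain, src_seq)]
    while search:
        chain, seq = search.pop()
        visited_chains[chain] = max(seq, visited_chains[chain])
        for tc, links in maps.get(chain, {}).items():  # tc: target chain
            for ss, ts in links.items():  # ss/ts: origin/target sequence numbers
                if ts <= visited_chains.get(tc, 0):
                    continue  # already reached this chain at least as far
                if ss <= seq:
                    if tc == target_chain:
                        if target_seq <= ts:
                            return True
                    else:
                        search.append((tc, ts))
    return False

# Mirrors the new test case: chains 1 -> 2 -> 3 with only direct links stored.
maps: Links = {1: {2: {1: 1}}, 2: {3: {1: 1}}}
assert exists_path_from(maps, (1, 4), (3, 1))      # indirect path via chain 2
assert not exists_path_from(maps, (1, 4), (3, 2))  # (3, 2) is not reachable

The per-chain high-water mark in `visited_chains` bounds the traversal, which is what keeps the read-side check cheap even without precomputed transitive rows.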


@@ -132,12 +132,16 @@ Changes in SCHEMA_VERSION = 82

 Changes in SCHEMA_VERSION = 83
     - The event_txn_id is no longer used.
+
+Changes in SCHEMA_VERSION = 84
+    - No longer assumes that `event_auth_chain_links` holds transitive links, and
+      so read operations must do graph traversal.
 """


 SCHEMA_COMPAT_VERSION = (
-    # The event_txn_id table and tables from MSC2716 no longer exist.
-    83
+    # Transitive links are no longer written to `event_auth_chain_links`
+    84
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat


@@ -21,6 +21,8 @@
 from typing import Dict, List, Set, Tuple, cast

+from parameterized import parameterized
+
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.trial import unittest
@@ -45,7 +47,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
         self.store = hs.get_datastores().main
         self._next_stream_ordering = 1

-    def test_simple(self) -> None:
+    @parameterized.expand([(False,), (True,)])
+    def test_simple(self, batched: bool) -> None:
         """Test that the example in `docs/auth_chain_difference_algorithm.md`
         works.
         """
@@ -53,6 +56,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
         event_factory = self.hs.get_event_builder_factory()
         bob = "@creator:test"
         alice = "@alice:test"
+        charlie = "@charlie:test"
         room_id = "!room:test"

         # Ensure that we have a rooms entry so that we generate the chain index.
@@ -191,6 +195,26 @@
             )
         )

+        charlie_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": charlie,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "charlie_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[
+                    create.event_id,
+                    alice_join2.event_id,
+                    power_2.event_id,
+                ],
+            )
+        )
+
         events = [
             create,
             bob_join,
@@ -200,33 +224,41 @@
             bob_join_2,
             power_2,
             alice_join2,
+            charlie_invite,
         ]

         expected_links = [
             (bob_join, create),
-            (power, create),
             (power, bob_join),
-            (alice_invite, create),
             (alice_invite, power),
-            (alice_invite, bob_join),
             (bob_join_2, power),
             (alice_join2, power_2),
+            (charlie_invite, alice_join2),
         ]

-        self.persist(events)
+        # We either persist as a batch or one-by-one depending on test
+        # parameter.
+        if batched:
+            self.persist(events)
+        else:
+            for event in events:
+                self.persist([event])

         chain_map, link_map = self.fetch_chains(events)

         # Check that the expected links and only the expected links have been
         # added.
-        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
-
-        for start, end in expected_links:
-            start_id, start_seq = chain_map[start.event_id]
-            end_id, end_seq = chain_map[end.event_id]
-
-            self.assertIn(
-                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
-            )
+        event_map = {e.event_id: e for e in events}
+        reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
+
+        self.maxDiff = None
+        self.assertCountEqual(
+            expected_links,
+            [
+                (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
+                for s1, s2, t1, t2 in link_map.get_additions()
+            ],
+        )

         # Test that everything can reach the create event, but the create event
         # can't reach anything.
@@ -368,24 +400,23 @@
         expected_links = [
             (bob_join, create),
-            (power, create),
             (power, bob_join),
-            (alice_invite, create),
             (alice_invite, power),
-            (alice_invite, bob_join),
         ]

         # Check that the expected links and only the expected links have been
         # added.
-        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
-
-        for start, end in expected_links:
-            start_id, start_seq = chain_map[start.event_id]
-            end_id, end_seq = chain_map[end.event_id]
-
-            self.assertIn(
-                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
-            )
+        event_map = {e.event_id: e for e in events}
+        reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
+
+        self.maxDiff = None
+        self.assertCountEqual(
+            expected_links,
+            [
+                (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
+                for s1, s2, t1, t2 in link_map.get_additions()
+            ],
+        )

     def persist(
         self,
@@ -489,8 +520,6 @@ class LinkMapTestCase(unittest.TestCase):
         link_map = _LinkMap()
         link_map.add_link((1, 1), (2, 1), new=False)
-        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
-        self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
         self.assertCountEqual(link_map.get_additions(), [])
         self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
         self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
@@ -499,18 +528,31 @@
         # Attempting to add a redundant link is ignored.
         self.assertFalse(link_map.add_link((1, 4), (2, 1)))
-        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+        self.assertCountEqual(link_map.get_additions(), [])

         # Adding new non-redundant links works
         self.assertTrue(link_map.add_link((1, 3), (2, 3)))
-        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+        self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3)])

         self.assertTrue(link_map.add_link((2, 5), (1, 3)))
-        self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
-        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
         self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])

+    def test_exists_path_from(self) -> None:
+        "Check that `exists_path_from` can handle non-direct links"
+        link_map = _LinkMap()
+
+        link_map.add_link((1, 1), (2, 1), new=False)
+        link_map.add_link((2, 1), (3, 1), new=False)
+
+        self.assertTrue(link_map.exists_path_from((1, 4), (3, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
+
+        link_map.add_link((1, 5), (2, 3), new=False)
+        link_map.add_link((2, 2), (3, 3), new=False)
+
+        self.assertTrue(link_map.exists_path_from((1, 6), (3, 2)))
+        self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
+

 class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
     servlets = [