Fix GHSA-3h7q-rfh9-xm4v

Weakness in auth chain indexing allows DoS from remote room members
through disk fill and high CPU usage.

A remote Matrix user with malicious intent, sharing a room with Synapse
instances before 1.104.1, can dispatch specially crafted events to
exploit a weakness in how the auth chain cover index is calculated. This
can induce high CPU consumption and accumulate excessive data in the
database of such instances, resulting in a denial of service.

Servers in private federations, or those that do not federate, are not
affected.
This commit is contained in:
Erik Johnston 2024-04-23 15:24:08 +01:00
parent fbb2573525
commit 55b0aa847a
4 changed files with 117 additions and 104 deletions

1
changelog.d/17044.misc Normal file
View File

@ -0,0 +1 @@
Refactor auth chain fetching to reduce duplication.

View File

@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
import collections
import itertools
import logging
from collections import OrderedDict
@ -53,6 +54,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.event_federation import EventFederationStore
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines import PostgresEngine
@ -768,40 +770,26 @@ class PersistEventsStore:
# that have the same chain ID as the event.
# 2. For each retained auth event we:
# a. Add a link from the event's to the auth event's chain
# ID/sequence number; and
# b. Add a link from the event to every chain reachable by the
# auth event.
# ID/sequence number
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
auth_chain_rows = cast(
List[Tuple[int, int, int, int]],
db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_links",
column="origin_chain_id",
iterable={chain_id for chain_id, _ in chain_map.values()},
keyvalues={},
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
),
),
)
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in auth_chain_rows:
chain_links.add_link(
(origin_chain_id, origin_sequence_number),
(target_chain_id, target_sequence_number),
new=False,
)
for links in EventFederationStore._get_chain_links(
txn, {chain_id for chain_id, _ in chain_map.values()}
):
for origin_chain_id, inner_links in links.items():
for (
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in inner_links:
chain_links.add_link(
(origin_chain_id, origin_sequence_number),
(target_chain_id, target_sequence_number),
new=False,
)
# We do this in toplogical order to avoid adding redundant links.
for event_id in sorted_topologically(
@ -836,18 +824,6 @@ class PersistEventsStore:
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)
# Step 2b, add a link to chains reachable from the auth
# event.
for target_id, target_seq in chain_links.get_links_from(
(auth_chain_id, auth_sequence_number)
):
if target_id == chain_id:
continue
chain_links.add_link(
(chain_id, sequence_number), (target_id, target_seq)
)
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
@ -2451,31 +2427,6 @@ class _LinkMap:
current_links[src_seq] = target_seq
return True
def get_links_from(
self, src_tuple: Tuple[int, int]
) -> Generator[Tuple[int, int], None, None]:
"""Gets the chains reachable from the given chain/sequence number.
Yields:
The chain ID and sequence number the link points to.
"""
src_chain, src_seq = src_tuple
for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
for link_src_seq, target_seq in sequence_numbers.items():
if link_src_seq <= src_seq:
yield target_id, target_seq
def get_links_between(
self, source_chain: int, target_chain: int
) -> Generator[Tuple[int, int], None, None]:
"""Gets the links between two chains.
Yields:
The source and target sequence numbers.
"""
yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
"""Gets any newly added links.
@ -2502,9 +2453,24 @@ class _LinkMap:
if src_chain == target_chain:
return target_seq <= src_seq
links = self.get_links_between(src_chain, target_chain)
for link_start_seq, link_end_seq in links:
if link_start_seq <= src_seq and target_seq <= link_end_seq:
return True
# We have to graph traverse the links to check for indirect paths.
visited_chains = collections.Counter()
search = [(src_chain, src_seq)]
while search:
chain, seq = search.pop()
visited_chains[chain] = max(seq, visited_chains[chain])
for tc, links in self.maps.get(chain, {}).items():
for ss, ts in links.items():
# Don't revisit chains we've already seen, unless the target
# sequence number is higher than last time.
if ts <= visited_chains.get(tc, 0):
continue
if ss <= seq:
if tc == target_chain:
if target_seq <= ts:
return True
else:
search.append((tc, ts))
return False

View File

@ -132,12 +132,16 @@ Changes in SCHEMA_VERSION = 82
Changes in SCHEMA_VERSION = 83
- The event_txn_id is no longer used.
Changes in SCHEMA_VERSION = 84
- No longer assumes that `event_auth_chain_links` holds transitive links, and
so read operations must do graph traversal.
"""
SCHEMA_COMPAT_VERSION = (
# The event_txn_id table and tables from MSC2716 no longer exist.
83
# Transitive links are no longer written to `event_auth_chain_links`
84
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat

View File

@ -21,6 +21,8 @@
from typing import Dict, List, Set, Tuple, cast
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
@ -45,7 +47,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
self.store = hs.get_datastores().main
self._next_stream_ordering = 1
def test_simple(self) -> None:
@parameterized.expand([(False,), (True,)])
def test_simple(self, batched: bool) -> None:
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""
@ -53,6 +56,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
event_factory = self.hs.get_event_builder_factory()
bob = "@creator:test"
alice = "@alice:test"
charlie = "@charlie:test"
room_id = "!room:test"
# Ensure that we have a rooms entry so that we generate the chain index.
@ -191,6 +195,26 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
)
charlie_invite = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": charlie,
"sender": alice,
"room_id": room_id,
"content": {"tag": "charlie_invite"},
},
).build(
prev_event_ids=[],
auth_event_ids=[
create.event_id,
alice_join2.event_id,
power_2.event_id,
],
)
)
events = [
create,
bob_join,
@ -200,33 +224,41 @@ class EventChainStoreTestCase(HomeserverTestCase):
bob_join_2,
power_2,
alice_join2,
charlie_invite,
]
expected_links = [
(bob_join, create),
(power, create),
(power, bob_join),
(alice_invite, create),
(alice_invite, power),
(alice_invite, bob_join),
(bob_join_2, power),
(alice_join2, power_2),
(charlie_invite, alice_join2),
]
self.persist(events)
# We either persist as a batch or one-by-one depending on test
# parameter.
if batched:
self.persist(events)
else:
for event in events:
self.persist([event])
chain_map, link_map = self.fetch_chains(events)
# Check that the expected links and only the expected links have been
# added.
self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
event_map = {e.event_id: e for e in events}
reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
for start, end in expected_links:
start_id, start_seq = chain_map[start.event_id]
end_id, end_seq = chain_map[end.event_id]
self.assertIn(
(start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
)
self.maxDiff = None
self.assertCountEqual(
expected_links,
[
(reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
for s1, s2, t1, t2 in link_map.get_additions()
],
)
# Test that everything can reach the create event, but the create event
# can't reach anything.
@ -368,24 +400,23 @@ class EventChainStoreTestCase(HomeserverTestCase):
expected_links = [
(bob_join, create),
(power, create),
(power, bob_join),
(alice_invite, create),
(alice_invite, power),
(alice_invite, bob_join),
]
# Check that the expected links and only the expected links have been
# added.
self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
event_map = {e.event_id: e for e in events}
reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
for start, end in expected_links:
start_id, start_seq = chain_map[start.event_id]
end_id, end_seq = chain_map[end.event_id]
self.assertIn(
(start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
)
self.maxDiff = None
self.assertCountEqual(
expected_links,
[
(reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
for s1, s2, t1, t2 in link_map.get_additions()
],
)
def persist(
self,
@ -489,8 +520,6 @@ class LinkMapTestCase(unittest.TestCase):
link_map = _LinkMap()
link_map.add_link((1, 1), (2, 1), new=False)
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
self.assertCountEqual(link_map.get_additions(), [])
self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
@ -499,18 +528,31 @@ class LinkMapTestCase(unittest.TestCase):
# Attempting to add a redundant link is ignored.
self.assertFalse(link_map.add_link((1, 4), (2, 1)))
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
self.assertCountEqual(link_map.get_additions(), [])
# Adding new non-redundant links works
self.assertTrue(link_map.add_link((1, 3), (2, 3)))
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3)])
self.assertTrue(link_map.add_link((2, 5), (1, 3)))
self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
def test_exists_path_from(self) -> None:
"Check that `exists_path_from` can handle non-direct links"
link_map = _LinkMap()
link_map.add_link((1, 1), (2, 1), new=False)
link_map.add_link((2, 1), (3, 1), new=False)
self.assertTrue(link_map.exists_path_from((1, 4), (3, 1)))
self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
link_map.add_link((1, 5), (2, 3), new=False)
link_map.add_link((2, 2), (3, 3), new=False)
self.assertTrue(link_map.exists_path_from((1, 6), (3, 2)))
self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
servlets = [