0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-21 02:24:03 +01:00
synapse/tests/storage/test_event_chain.py

587 lines
20 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple
from twisted.trial import unittest
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.storage.databases.main.events import _LinkMap
from synapse.types import create_requester
from tests.unittest import HomeserverTestCase
class EventChainStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self._next_stream_ordering = 1
def test_simple(self):
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""
event_factory = self.hs.get_event_builder_factory()
bob = "@creator:test"
alice = "@alice:test"
room_id = "!room:test"
# Ensure that we have a rooms entry so that we generate the chain index.
self.get_success(
self.store.store_room(
room_id=room_id,
room_creator_user_id="",
is_public=True,
room_version=RoomVersions.V6,
)
)
create = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Create,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "create"},
},
).build(prev_event_ids=[], auth_event_ids=[])
)
bob_join = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": bob,
"sender": bob,
"room_id": room_id,
"content": {"tag": "bob_join"},
},
).build(prev_event_ids=[], auth_event_ids=[create.event_id])
)
power = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.PowerLevels,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "power"},
},
).build(
prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
)
)
alice_invite = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": bob,
"room_id": room_id,
"content": {"tag": "alice_invite"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
)
)
alice_join = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": alice,
"room_id": room_id,
"content": {"tag": "alice_join"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
)
)
power_2 = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.PowerLevels,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "power_2"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
)
)
bob_join_2 = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": bob,
"sender": bob,
"room_id": room_id,
"content": {"tag": "bob_join_2"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
)
)
alice_join2 = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": alice,
"room_id": room_id,
"content": {"tag": "alice_join2"},
},
).build(
prev_event_ids=[],
auth_event_ids=[
create.event_id,
alice_join.event_id,
power_2.event_id,
],
)
)
events = [
create,
bob_join,
power,
alice_invite,
alice_join,
bob_join_2,
power_2,
alice_join2,
]
expected_links = [
(bob_join, create),
(power, create),
(power, bob_join),
(alice_invite, create),
(alice_invite, power),
(alice_invite, bob_join),
(bob_join_2, power),
(alice_join2, power_2),
]
self.persist(events)
chain_map, link_map = self.fetch_chains(events)
# Check that the expected links and only the expected links have been
# added.
self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
for start, end in expected_links:
start_id, start_seq = chain_map[start.event_id]
end_id, end_seq = chain_map[end.event_id]
self.assertIn(
(start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
)
# Test that everything can reach the create event, but the create event
# can't reach anything.
for event in events[1:]:
self.assertTrue(
link_map.exists_path_from(
chain_map[event.event_id], chain_map[create.event_id]
),
)
self.assertFalse(
link_map.exists_path_from(
chain_map[create.event_id], chain_map[event.event_id],
),
)
def test_out_of_order_events(self):
"""Test that we handle persisting events that we don't have the full
auth chain for yet (which should only happen for out of band memberships).
"""
event_factory = self.hs.get_event_builder_factory()
bob = "@creator:test"
alice = "@alice:test"
room_id = "!room:test"
# Ensure that we have a rooms entry so that we generate the chain index.
self.get_success(
self.store.store_room(
room_id=room_id,
room_creator_user_id="",
is_public=True,
room_version=RoomVersions.V6,
)
)
# First persist the base room.
create = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Create,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "create"},
},
).build(prev_event_ids=[], auth_event_ids=[])
)
bob_join = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": bob,
"sender": bob,
"room_id": room_id,
"content": {"tag": "bob_join"},
},
).build(prev_event_ids=[], auth_event_ids=[create.event_id])
)
power = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.PowerLevels,
"state_key": "",
"sender": bob,
"room_id": room_id,
"content": {"tag": "power"},
},
).build(
prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
)
)
self.persist([create, bob_join, power])
# Now persist an invite and a couple of memberships out of order.
alice_invite = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": bob,
"room_id": room_id,
"content": {"tag": "alice_invite"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
)
)
alice_join = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": alice,
"room_id": room_id,
"content": {"tag": "alice_join"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
)
)
alice_join2 = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": alice,
"sender": alice,
"room_id": room_id,
"content": {"tag": "alice_join2"},
},
).build(
prev_event_ids=[],
auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
)
)
self.persist([alice_join])
self.persist([alice_join2])
self.persist([alice_invite])
# The end result should be sane.
events = [create, bob_join, power, alice_invite, alice_join]
chain_map, link_map = self.fetch_chains(events)
expected_links = [
(bob_join, create),
(power, create),
(power, bob_join),
(alice_invite, create),
(alice_invite, power),
(alice_invite, bob_join),
]
# Check that the expected links and only the expected links have been
# added.
self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
for start, end in expected_links:
start_id, start_seq = chain_map[start.event_id]
end_id, end_seq = chain_map[end.event_id]
self.assertIn(
(start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
)
def persist(
self, events: List[EventBase],
):
"""Persist the given events and check that the links generated match
those given.
"""
persist_events_store = self.hs.get_datastores().persist_events
for e in events:
e.internal_metadata.stream_ordering = self._next_stream_ordering
self._next_stream_ordering += 1
def _persist(txn):
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
# Actually call the function that calculates the auth chain stuff.
persist_events_store._persist_event_auth_chain_txn(txn, events)
self.get_success(
persist_events_store.db_pool.runInteraction("_persist", _persist,)
)
def fetch_chains(
self, events: List[EventBase]
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
# Fetch the map from event ID -> (chain ID, sequence number)
rows = self.get_success(
self.store.db_pool.simple_select_many_batch(
table="event_auth_chains",
column="event_id",
iterable=[e.event_id for e in events],
retcols=("event_id", "chain_id", "sequence_number"),
keyvalues={},
)
)
chain_map = {
row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
}
# Fetch all the links and pass them to the _LinkMap.
rows = self.get_success(
self.store.db_pool.simple_select_many_batch(
table="event_auth_chain_links",
column="origin_chain_id",
iterable=[chain_id for chain_id, _ in chain_map.values()],
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
),
keyvalues={},
)
)
link_map = _LinkMap()
for row in rows:
added = link_map.add_link(
(row["origin_chain_id"], row["origin_sequence_number"]),
(row["target_chain_id"], row["target_sequence_number"]),
)
# We shouldn't have persisted any redundant links
self.assertTrue(added)
return chain_map, link_map
class LinkMapTestCase(unittest.TestCase):
def test_simple(self):
"""Basic tests for the LinkMap.
"""
link_map = _LinkMap()
link_map.add_link((1, 1), (2, 1), new=False)
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
self.assertCountEqual(link_map.get_additions(), [])
self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
# Attempting to add a redundant link is ignored.
self.assertFalse(link_map.add_link((1, 4), (2, 1)))
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
# Adding new non-redundant links works
self.assertTrue(link_map.add_link((1, 3), (2, 3)))
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertTrue(link_map.add_link((2, 5), (1, 3)))
self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def test_background_update(self):
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
# Create a room
user_id = self.register_user("foo", "pass")
token = self.login("foo", "pass")
room_id = self.helper.create_room_as(user_id, tok=token)
requester = create_requester(user_id)
store = self.hs.get_datastore()
# Mark the room as not having a chain cover index
self.get_success(
store.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": False},
desc="test",
)
)
# Create a fork in the DAG with different events.
event_handler = self.hs.get_event_creation_handler()
latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
event, context = self.get_success(
event_handler.create_event(
requester,
{
"type": "some_state_type",
"state_key": "",
"content": {},
"room_id": room_id,
"sender": user_id,
},
prev_event_ids=latest_event_ids,
)
)
self.get_success(
event_handler.handle_new_client_event(requester, event, context)
)
state1 = list(self.get_success(context.get_current_state_ids()).values())
event, context = self.get_success(
event_handler.create_event(
requester,
{
"type": "some_state_type",
"state_key": "",
"content": {},
"room_id": room_id,
"sender": user_id,
},
prev_event_ids=latest_event_ids,
)
)
self.get_success(
event_handler.handle_new_client_event(requester, event, context)
)
state2 = list(self.get_success(context.get_current_state_ids()).values())
# Delete the chain cover info.
def _delete_tables(txn):
txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links")
self.get_success(store.db_pool.runInteraction("test", _delete_tables))
# Insert and run the background update.
self.get_success(
store.db_pool.simple_insert(
"background_updates",
{"update_name": "chain_cover", "progress_json": "{}"},
)
)
# Ugh, have to reset this flag
store.db_pool.updates._all_done = False
while not self.get_success(
store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
# Test that calculating the auth chain difference using the newly
# calculated chain cover works.
self.get_success(
store.db_pool.runInteraction(
"test",
store._get_auth_chain_difference_using_cover_index_txn,
room_id,
[state1, state2],
)
)