0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-07-06 03:08:58 +02:00

Merge pull request #38 from matrix-org/new_state_resolution

New state resolution
This commit is contained in:
Mark Haines 2015-01-30 14:38:30 +00:00
commit 1251d017c1
2 changed files with 441 additions and 121 deletions

View file

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from collections import namedtuple from collections import namedtuple
@ -36,12 +37,16 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
""" """
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
@ -210,64 +215,93 @@ class StateHandler(object):
else: else:
prev_states = [] prev_states = []
auth_events = {
k: e for k, e in unconflicted_state.items()
if k[0] in AuthEventTypes
}
try: try:
new_state = {} resolved_state = self._resolve_state_events(
new_state.update(unconflicted_state) conflicted_state, auth_events
for key, events in conflicted_state.items(): )
new_state[key] = self._resolve_state_events(events)
except: except:
logger.exception("Failed to resolve state") logger.exception("Failed to resolve state")
raise raise
new_state = unconflicted_state
new_state.update(resolved_state)
defer.returnValue((None, new_state, prev_states)) defer.returnValue((None, new_state, prev_states))
def _get_power_level_from_event_state(self, event, user_id):
if hasattr(event, "old_state_events") and event.old_state_events:
key = (EventTypes.PowerLevels, "", )
power_level_event = event.old_state_events.get(key)
level = None
if power_level_event:
level = power_level_event.content.get("users", {}).get(
user_id
)
if not level:
level = power_level_event.content.get("users_default", 0)
return level
else:
return 0
@log_function @log_function
def _resolve_state_events(self, events): def _resolve_state_events(self, conflicted_state, auth_events):
curr_events = events """ This is where we actually decide which of the conflicted state to
use.
new_powers = [ We resolve conflicts in the following order:
self._get_power_level_from_event_state(e, e.user_id) 1. power levels
for e in curr_events 2. memberships
] 3. other events.
"""
resolved_state = {}
power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state.items():
power_levels = conflicted_state[power_key]
resolved_state[power_key] = self._resolve_auth_events(power_levels)
new_powers = [ auth_events.update(resolved_state)
int(p) if p else 0 for p in new_powers
]
max_power = max(new_powers) for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
curr_events = [ auth_events.update(resolved_state)
z[0] for z in zip(curr_events, new_powers)
if z[1] == max_power
]
if not curr_events: for key, events in conflicted_state.items():
raise RuntimeError("Max didn't get a max?") if key not in resolved_state:
elif len(curr_events) == 1: resolved_state[key] = self._resolve_normal_events(
return curr_events[0] events, auth_events
)
# TODO: For now, just choose the one with the largest event_id. return resolved_state
return (
sorted( def _resolve_auth_events(self, events, auth_events):
curr_events, reverse = [i for i in reversed(self._ordered_events(events))]
key=lambda e: hashlib.sha1(
e.event_id + e.user_id + e.room_id + e.type auth_events = dict(auth_events)
).hexdigest()
)[0] prev_event = reverse[0]
) for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(self, events, auth_events):
for event in self._ordered_events(events):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event
def _ordered_events(self, events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
return sorted(events, key=key_func)

View file

@ -16,11 +16,120 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.state import StateHandler from synapse.state import StateHandler
from mock import Mock from mock import Mock
_next_event_id = 1000
def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
prev_events=[], **kwargs):
global _next_event_id
if not event_id:
_next_event_id += 1
event_id = str(_next_event_id)
if not name:
if state_key is not None:
name = "<%s-%s, %s>" % (type, state_key, event_id,)
else:
name = "<%s, %s>" % (type, event_id,)
d = {
"event_id": event_id,
"type": type,
"sender": "@user_id:example.com",
"room_id": "!room_id:example.com",
"depth": depth,
"prev_events": prev_events,
}
if state_key is not None:
d["state_key"] = state_key
d.update(kwargs)
event = FrozenEvent(d)
return event
class StateGroupStore(object):
def __init__(self):
self._event_to_state_group = {}
self._group_to_state = {}
self._next_group = 1
def get_state_groups(self, event_ids):
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
if group:
groups[group] = self._group_to_state[group]
return defer.succeed(groups)
def store_state_groups(self, event, context):
if context.current_state is None:
return
state_events = context.current_state
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group:
state_group = self._next_group
self._next_group += 1
self._group_to_state[state_group] = state_events.values()
self._event_to_state_group[event.event_id] = state_group
class DictObj(dict):
def __init__(self, **kwargs):
super(DictObj, self).__init__(kwargs)
self.__dict__ = self
class Graph(object):
def __init__(self, nodes, edges):
events = {}
clobbered = set(events.keys())
for event_id, fields in nodes.items():
refs = edges.get(event_id)
if refs:
clobbered.difference_update(refs)
prev_events = [(r, {}) for r in refs]
else:
prev_events = []
events[event_id] = create_event(
event_id=event_id,
prev_events=prev_events,
**fields
)
self._leaves = clobbered
self._events = sorted(events.values(), key=lambda e: e.depth)
def walk(self):
return iter(self._events)
def get_leaves(self):
return (self._events[i] for i in self._leaves)
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = Mock( self.store = Mock(
@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
"add_event_hashes", "add_event_hashes",
] ]
) )
hs = Mock(spec=["get_datastore"]) hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"])
hs.get_datastore.return_value = self.store hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_auth.return_value = Auth(hs)
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0
@defer.inlineCallbacks
def test_branch_no_conflict(self):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="",
depth=1,
),
"A": DictObj(
type=EventTypes.Message,
depth=2,
),
"B": DictObj(
type=EventTypes.Message,
depth=3,
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"D": DictObj(
type=EventTypes.Message,
depth=4,
),
},
edges={
"A": ["START"],
"B": ["A"],
"C": ["A"],
"D": ["B", "C"]
}
)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].current_state))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="creator",
content={"membership": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
depth=2,
),
"B": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
depth=4,
),
"D": DictObj(
type=EventTypes.Message,
depth=5,
),
},
edges={
"A": ["START"],
"B": ["A"],
"C": ["A"],
"D": ["B", "C"]
}
)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
{"START", "A", "C"},
{e.event_id for e in context_store["D"].current_state.values()}
)
@defer.inlineCallbacks
def test_branch_have_banned_conflict(self):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="creator",
content={"membership": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
depth=2,
),
"B": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"C": DictObj(
type=EventTypes.Member,
state_key="@user_id_2:example.com",
content={"membership": Membership.BAN},
membership=Membership.BAN,
depth=4,
),
"D": DictObj(
type=EventTypes.Name,
state_key="",
depth=4,
sender="@user_id_2:example.com",
),
"E": DictObj(
type=EventTypes.Message,
depth=5,
),
},
edges={
"A": ["START"],
"B": ["A"],
"C": ["B"],
"D": ["B"],
"E": ["C", "D"]
}
)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
{"START", "A", "B", "C"},
{e.event_id for e in context_store["E"].current_state.values()}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_message(self): def test_annotate_with_old_message(self):
event = self.create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
context = yield self.state.compute_event_context( context = yield self.state.compute_event_context(
@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
event = self.create_event(type="state", state_key="", name="event") event = create_event(type="state", state_key="", name="event")
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
context = yield self.state.compute_event_context( context = yield self.state.compute_event_context(
@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
event = self.create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
event.prev_events = []
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = "group_name_1" group_name = "group_name_1"
@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_state(self): def test_trivial_annotate_state(self):
event = self.create_event(type="state", state_key="", name="event") event = create_event(type="state", state_key="", name="event")
event.prev_events = []
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = "group_name_1" group_name = "group_name_1"
@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
event = self.create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
event.prev_events = []
old_state_1 = [ old_state_1 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
old_state_2 = [ old_state_2 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"), create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
group_name_1 = "group_name_1" context = yield self._get_context(event, old_state_1, old_state_2)
group_name_2 = "group_name_2"
self.store.get_state_groups.return_value = {
group_name_1: old_state_1,
group_name_2: old_state_2,
}
context = yield self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5) self.assertEqual(len(context.current_state), 5)
@ -181,21 +447,70 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
event = self.create_event(type="test4", state_key="", name="event") event = create_event(type="test4", state_key="", name="event")
event.prev_events = []
old_state_1 = [ old_state_1 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
old_state_2 = [ old_state_2 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"), create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 5)
self.assertIsNone(context.state_group)
@defer.inlineCallbacks
def test_standard_depth_conflict(self):
event = create_event(type="test4", name="event")
member_event = create_event(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={
"membership": Membership.JOIN,
}
)
old_state_1 = [
member_event,
create_event(type="test1", state_key="1", depth=1),
]
old_state_2 = [
member_event,
create_event(type="test1", state_key="1", depth=2),
]
context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
old_state_1 = [
member_event,
create_event(type="test1", state_key="1", depth=2),
]
old_state_2 = [
member_event,
create_event(type="test1", state_key="1", depth=1),
]
context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1" group_name_1 = "group_name_1"
group_name_2 = "group_name_2" group_name_2 = "group_name_2"
@ -204,33 +519,4 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2, group_name_2: old_state_2,
} }
context = yield self.state.compute_event_context(event) return self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5)
self.assertIsNone(context.state_group)
def create_event(self, name=None, type=None, state_key=None):
self.event_id += 1
event_id = str(self.event_id)
if not name:
if state_key is not None:
name = "<%s-%s>" % (type, state_key)
else:
name = "<%s>" % (type, )
event = Mock(name=name, spec=[])
event.type = type
if state_key is not None:
event.state_key = state_key
event.event_id = event_id
event.is_state = lambda: (state_key is not None)
event.unsigned = {}
event.user_id = "@user_id:example.com"
event.room_id = "!room_id:example.com"
return event