mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 04:43:46 +01:00
Correctly handle the difference between prev and current state
This commit is contained in:
parent
1bb8ec296d
commit
c10cb581c6
12 changed files with 102 additions and 69 deletions
|
@ -66,7 +66,7 @@ class Auth(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_from_context(self, event, context, do_sig_check=True):
|
def check_from_context(self, event, context, do_sig_check=True):
|
||||||
auth_events_ids = yield self.compute_auth_events(
|
auth_events_ids = yield self.compute_auth_events(
|
||||||
event, context.current_state_ids, for_verification=True,
|
event, context.prev_state_ids, for_verification=True,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
|
@ -852,7 +852,7 @@ class Auth(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_auth_events(self, builder, context):
|
def add_auth_events(self, builder, context):
|
||||||
auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
|
auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
|
||||||
|
|
||||||
auth_events_entries = yield self.store.add_event_hashes(
|
auth_events_entries = yield self.store.add_event_hashes(
|
||||||
auth_ids
|
auth_ids
|
||||||
|
|
|
@ -15,8 +15,9 @@
|
||||||
|
|
||||||
|
|
||||||
class EventContext(object):
|
class EventContext(object):
|
||||||
def __init__(self, current_state_ids=None):
|
def __init__(self):
|
||||||
self.current_state_ids = current_state_ids
|
self.current_state_ids = None
|
||||||
|
self.prev_state_ids = None
|
||||||
self.state_group = None
|
self.state_group = None
|
||||||
self.rejected = False
|
self.rejected = False
|
||||||
self.push_actions = []
|
self.push_actions = []
|
||||||
|
|
|
@ -222,7 +222,7 @@ class FederationHandler(BaseHandler):
|
||||||
# joined the room. Don't bother if the user is just
|
# joined the room. Don't bother if the user is just
|
||||||
# changing their profile info.
|
# changing their profile info.
|
||||||
newly_joined = True
|
newly_joined = True
|
||||||
prev_state_id = context.current_state_ids.get(
|
prev_state_id = context.prev_state_ids.get(
|
||||||
(event.type, event.state_key)
|
(event.type, event.state_key)
|
||||||
)
|
)
|
||||||
if prev_state_id:
|
if prev_state_id:
|
||||||
|
@ -835,12 +835,12 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
self.replication_layer.send_pdu(new_pdu, destinations)
|
||||||
|
|
||||||
state_ids = context.current_state_ids.values()
|
state_ids = context.prev_state_ids.values()
|
||||||
auth_chain = yield self.store.get_auth_chain(set(
|
auth_chain = yield self.store.get_auth_chain(set(
|
||||||
[event.event_id] + state_ids
|
[event.event_id] + state_ids
|
||||||
))
|
))
|
||||||
|
|
||||||
state = yield self.store.get_events(context.current_state_ids.values())
|
state = yield self.store.get_events(context.prev_state_ids.values())
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"state": state.values(),
|
"state": state.values(),
|
||||||
|
@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
if not auth_events:
|
if not auth_events:
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.current_state_ids, for_verification=True,
|
event, context.prev_state_ids, for_verification=True,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
|
@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler):
|
||||||
current_state = set(e.event_id for e in auth_events.values())
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||||
|
|
||||||
|
if event.is_state():
|
||||||
|
event_key = (event.type, event.state_key)
|
||||||
|
else:
|
||||||
|
event_key = None
|
||||||
|
|
||||||
if event_auth_events - current_state:
|
if event_auth_events - current_state:
|
||||||
have_events = yield self.store.have_events(
|
have_events = yield self.store.have_events(
|
||||||
event_auth_events - current_state
|
event_auth_events - current_state
|
||||||
|
@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
context.current_state_ids.update({
|
context.current_state_ids.update({
|
||||||
k: a.event_id for k, a in auth_events.items()
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
if k != event_key
|
||||||
})
|
})
|
||||||
context.state_group = None
|
context.prev_state_ids.update({
|
||||||
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
})
|
||||||
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
if different_auth and not event.internal_metadata.is_outlier():
|
if different_auth and not event.internal_metadata.is_outlier():
|
||||||
logger.info("Different auth after resolution: %s", different_auth)
|
logger.info("Different auth after resolution: %s", different_auth)
|
||||||
|
@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler):
|
||||||
if do_resolution:
|
if do_resolution:
|
||||||
# 1. Get what we think is the auth chain.
|
# 1. Get what we think is the auth chain.
|
||||||
auth_ids = yield self.auth.compute_auth_events(
|
auth_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.current_state_ids
|
event, context.prev_state_ids
|
||||||
)
|
)
|
||||||
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
||||||
|
|
||||||
|
@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
context.current_state_ids.update({
|
context.current_state_ids.update({
|
||||||
k: a.event_id for k, a in auth_events.items()
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
if k != event_key
|
||||||
})
|
})
|
||||||
context.state_group = None
|
context.prev_state_ids.update({
|
||||||
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
})
|
||||||
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, auth_events=auth_events)
|
self.auth.check(event, auth_events=auth_events)
|
||||||
|
@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler):
|
||||||
event.content["third_party_invite"]["signed"]["token"]
|
event.content["third_party_invite"]["signed"]["token"]
|
||||||
)
|
)
|
||||||
original_invite = None
|
original_invite = None
|
||||||
original_invite_id = context.current_state_ids.get(key)
|
original_invite_id = context.prev_state_ids.get(key)
|
||||||
if original_invite_id:
|
if original_invite_id:
|
||||||
original_invite = yield self.store.get_event(
|
original_invite = yield self.store.get_event(
|
||||||
original_invite_id, allow_none=True
|
original_invite_id, allow_none=True
|
||||||
|
@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler):
|
||||||
signed = event.content["third_party_invite"]["signed"]
|
signed = event.content["third_party_invite"]["signed"]
|
||||||
token = signed["token"]
|
token = signed["token"]
|
||||||
|
|
||||||
invite_event_id = context.current_state_ids.get(
|
invite_event_id = context.prev_state_ids.get(
|
||||||
(EventTypes.ThirdPartyInvite, token,)
|
(EventTypes.ThirdPartyInvite, token,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -272,7 +272,7 @@ class MessageHandler(BaseHandler):
|
||||||
If so, returns the version of the event in context.
|
If so, returns the version of the event in context.
|
||||||
Otherwise, returns None.
|
Otherwise, returns None.
|
||||||
"""
|
"""
|
||||||
prev_event_id = context.current_state_ids.get((event.type, event.state_key))
|
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
|
||||||
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||||
if not prev_event:
|
if not prev_event:
|
||||||
return
|
return
|
||||||
|
@ -808,8 +808,8 @@ class MessageHandler(BaseHandler):
|
||||||
event = builder.build()
|
event = builder.build()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created event %s with current state: %s",
|
"Created event %s with state: %s",
|
||||||
event.event_id, context.current_state_ids,
|
event.event_id, context.prev_state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
|
@ -904,7 +904,7 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction:
|
if event.type == EventTypes.Redaction:
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.current_state_ids, for_verification=True,
|
event, context.prev_state_ids, for_verification=True,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
|
@ -924,7 +924,7 @@ class MessageHandler(BaseHandler):
|
||||||
"You don't have permission to redact events"
|
"You don't have permission to redact events"
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Create and context.current_state_ids:
|
if event.type == EventTypes.Create and context.prev_state_ids:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
"Changing the room create event is forbidden",
|
"Changing the room create event is forbidden",
|
||||||
|
|
|
@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event_id = context.current_state_ids.get(
|
prev_member_event_id = context.prev_state_ids.get(
|
||||||
(EventTypes.Member, target.to_string()),
|
(EventTypes.Member, target.to_string()),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
|
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
guest_can_join = yield self._can_guest_join(context.current_state_ids)
|
guest_can_join = yield self._can_guest_join(context.prev_state_ids)
|
||||||
if not guest_can_join:
|
if not guest_can_join:
|
||||||
# This should be an auth check, but guests are a local concept,
|
# This should be an auth check, but guests are a local concept,
|
||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
|
@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event_id = context.current_state_ids.get(
|
prev_member_event_id = context.prev_state_ids.get(
|
||||||
(EventTypes.Member, event.state_key),
|
(EventTypes.Member, event.state_key),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
|
|
@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
|
||||||
)
|
)
|
||||||
|
|
||||||
room_members = yield self.store.get_joined_users_from_context(
|
room_members = yield self.store.get_joined_users_from_context(
|
||||||
event.room_id, context.state_group, context.current_state_ids
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||||
|
|
|
@ -128,7 +128,7 @@ class StateHandler(object):
|
||||||
def get_current_user_in_room(self, room_id):
|
def get_current_user_in_room(self, room_id):
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
|
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
joined_users = yield self.store.get_joined_users_from_context(
|
joined_users = yield self.store.get_joined_users_from_state(
|
||||||
room_id, group, state_ids
|
room_id, group, state_ids
|
||||||
)
|
)
|
||||||
defer.returnValue(joined_users)
|
defer.returnValue(joined_users)
|
||||||
|
@ -154,27 +154,38 @@ class StateHandler(object):
|
||||||
# state. Certainly store.get_current_state won't return any, and
|
# state. Certainly store.get_current_state won't return any, and
|
||||||
# persisting the event won't store the state group.
|
# persisting the event won't store the state group.
|
||||||
if old_state:
|
if old_state:
|
||||||
context.current_state_ids = {
|
context.prev_state_ids = {
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
|
if event.is_state():
|
||||||
|
context.current_state_events = dict(context.prev_state_ids)
|
||||||
|
key = (event.type, event.state_key)
|
||||||
|
context.current_state_events[key] = event.event_id
|
||||||
|
else:
|
||||||
|
context.current_state_events = context.prev_state_ids
|
||||||
else:
|
else:
|
||||||
context.current_state_ids = {}
|
context.current_state_ids = {}
|
||||||
|
context.prev_state_ids = {}
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
context.state_group = self.store.get_next_state_group()
|
context.state_group = self.store.get_next_state_group()
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
if old_state:
|
if old_state:
|
||||||
context.current_state_ids = {
|
context.prev_state_ids = {
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
context.state_group = self.store.get_next_state_group()
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.current_state_ids:
|
if key in context.prev_state_ids:
|
||||||
replaces = context.current_state_ids[key]
|
replaces = context.prev_state_ids[key]
|
||||||
if replaces != event.event_id: # Paranoia check
|
if replaces != event.event_id: # Paranoia check
|
||||||
event.unsigned["replaces_state"] = replaces
|
event.unsigned["replaces_state"] = replaces
|
||||||
|
context.current_state_ids = dict(context.prev_state_ids)
|
||||||
|
context.current_state_ids[key] = event.event_id
|
||||||
|
else:
|
||||||
|
context.current_state_ids = context.prev_state_ids
|
||||||
|
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
@ -192,7 +203,7 @@ class StateHandler(object):
|
||||||
|
|
||||||
group, curr_state = ret
|
group, curr_state = ret
|
||||||
|
|
||||||
context.current_state_ids = curr_state
|
context.prev_state_ids = curr_state
|
||||||
if event.is_state() or group is None:
|
if event.is_state() or group is None:
|
||||||
context.state_group = self.store.get_next_state_group()
|
context.state_group = self.store.get_next_state_group()
|
||||||
else:
|
else:
|
||||||
|
@ -200,9 +211,13 @@ class StateHandler(object):
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.current_state_ids:
|
if key in context.prev_state_ids:
|
||||||
replaces = context.current_state_ids[key]
|
replaces = context.prev_state_ids[key]
|
||||||
event.unsigned["replaces_state"] = replaces
|
event.unsigned["replaces_state"] = replaces
|
||||||
|
context.current_state_ids = dict(context.prev_state_ids)
|
||||||
|
context.current_state_ids[key] = event.event_id
|
||||||
|
else:
|
||||||
|
context.current_state_ids = context.prev_state_ids
|
||||||
|
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
|
@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
desc="who_forgot"
|
desc="who_forgot"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_joined_users_from_context(self, room_id, state_group, state_ids):
|
def get_joined_users_from_context(self, event, context):
|
||||||
|
state_group = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
return self._get_joined_users_from_context(
|
return self._get_joined_users_from_context(
|
||||||
room_id, state_group, state_ids
|
event.room_id, state_group, context.current_state_ids, event=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_joined_users_from_state(self, room_id, state_group, state_ids):
|
||||||
|
if not state_group:
|
||||||
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
# of None don't hit previous cached calls with a None state_group.
|
||||||
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
|
state_group = object()
|
||||||
|
|
||||||
|
return self._get_joined_users_from_context(
|
||||||
|
room_id, state_group, state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
|
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
|
||||||
cache_context):
|
cache_context, event=None):
|
||||||
# We don't use `state_group`, its there so that we can cache based
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
# on it. However, its important that its never None, since two current_state's
|
# on it. However, its important that its never None, since two current_state's
|
||||||
# with a state_group of None are likely to be different.
|
# with a state_group of None are likely to be different.
|
||||||
|
@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
desc="_get_joined_users_from_context",
|
desc="_get_joined_users_from_context",
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(set(row["user_id"] for row in rows))
|
users_in_room = set(row["user_id"] for row in rows)
|
||||||
|
if event is not None and event.type == EventTypes.Member:
|
||||||
|
if event.membership == Membership.JOIN:
|
||||||
|
if event.event_id in member_event_ids:
|
||||||
|
users_in_room.add(event.state_key)
|
||||||
|
|
||||||
|
defer.returnValue(users_in_room)
|
||||||
|
|
||||||
def is_host_joined(self, room_id, host, state_group, state_ids):
|
def is_host_joined(self, room_id, host, state_group, state_ids):
|
||||||
if not state_group:
|
if not state_group:
|
||||||
|
|
|
@ -108,9 +108,6 @@ class StateStore(SQLBaseStore):
|
||||||
|
|
||||||
state_event_ids = dict(context.current_state_ids)
|
state_event_ids = dict(context.current_state_ids)
|
||||||
|
|
||||||
if event.is_state():
|
|
||||||
state_event_ids[(event.type, event.state_key)] = event.event_id
|
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="state_groups",
|
table="state_groups",
|
||||||
|
|
|
@ -312,7 +312,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
else:
|
else:
|
||||||
state_ids = None
|
state_ids = None
|
||||||
|
|
||||||
context = EventContext(current_state_ids=state_ids)
|
context = EventContext()
|
||||||
|
context.current_state_ids = state_ids
|
||||||
|
context.prev_state_ids = state_ids
|
||||||
context.push_actions = push_actions
|
context.push_actions = push_actions
|
||||||
|
|
||||||
ordering = None
|
ordering = None
|
||||||
|
|
|
@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
self.assertEquals(body, {})
|
self.assertEquals(body, {})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_events_and_state(self):
|
def test_events(self):
|
||||||
get = self.get(events="-1", state="-1", timeout="0")
|
get = self.get(events="-1", timeout="0")
|
||||||
yield self.hs.get_handlers().room_creation_handler.create_room(
|
yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||||
synapse.types.create_requester(self.user), {}
|
synapse.types.create_requester(self.user), {}
|
||||||
)
|
)
|
||||||
|
@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
self.assertEquals(body["events"]["field_names"], [
|
self.assertEquals(body["events"]["field_names"], [
|
||||||
"position", "internal", "json", "state_group"
|
"position", "internal", "json", "state_group"
|
||||||
])
|
])
|
||||||
self.assertEquals(body["state_groups"]["field_names"], [
|
|
||||||
"position", "room_id", "event_id"
|
|
||||||
])
|
|
||||||
self.assertEquals(body["state_group_state"]["field_names"], [
|
|
||||||
"position", "type", "state_key", "event_id"
|
|
||||||
])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_presence(self):
|
def test_presence(self):
|
||||||
|
|
|
@ -86,17 +86,8 @@ class StateGroupStore(object):
|
||||||
|
|
||||||
state_events = dict(context.current_state_ids)
|
state_events = dict(context.current_state_ids)
|
||||||
|
|
||||||
if event.is_state():
|
self._group_to_state[context.state_group] = state_events
|
||||||
state_events[(event.type, event.state_key)] = event.event_id
|
self._event_to_state_group[event.event_id] = context.state_group
|
||||||
|
|
||||||
state_group = context.state_group
|
|
||||||
if not state_group:
|
|
||||||
state_group = self._next_group
|
|
||||||
self._next_group += 1
|
|
||||||
|
|
||||||
self._group_to_state[state_group] = state_events
|
|
||||||
|
|
||||||
self._event_to_state_group[event.event_id] = state_group
|
|
||||||
|
|
||||||
def get_events(self, event_ids, **kwargs):
|
def get_events(self, event_ids, **kwargs):
|
||||||
return {
|
return {
|
||||||
|
@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
"get_state_groups_ids",
|
"get_state_groups_ids",
|
||||||
"add_event_hashes",
|
"add_event_hashes",
|
||||||
"get_events",
|
"get_events",
|
||||||
|
"get_next_state_group",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
hs = Mock(spec_set=[
|
hs = Mock(spec_set=[
|
||||||
|
@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
hs.get_clock.return_value = MockClock()
|
hs.get_clock.return_value = MockClock()
|
||||||
hs.get_auth.return_value = Auth(hs)
|
hs.get_auth.return_value = Auth(hs)
|
||||||
|
|
||||||
|
self.store.get_next_state_group.side_effect = Mock
|
||||||
|
|
||||||
self.state = StateHandler(hs)
|
self.state = StateHandler(hs)
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
|
@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
store.store_state_groups(event, context)
|
store.store_state_groups(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
self.assertEqual(2, len(context_store["D"].current_state_ids))
|
self.assertEqual(2, len(context_store["D"].prev_state_ids))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_branch_basic_conflict(self):
|
def test_branch_basic_conflict(self):
|
||||||
|
@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "C"},
|
{"START", "A", "C"},
|
||||||
{e_id for e_id in context_store["D"].current_state_ids.values()}
|
{e_id for e_id in context_store["D"].prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "B", "C"},
|
{"START", "A", "B", "C"},
|
||||||
{e for e in context_store["E"].current_state_ids.values()}
|
{e for e in context_store["E"].prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"A1", "A2", "A3", "A5", "B"},
|
{"A1", "A2", "A3", "A5", "B"},
|
||||||
{e for e in context_store["D"].current_state_ids.values()}
|
{e for e in context_store["D"].prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_depths(self, nodes, edges):
|
def _add_depths(self, nodes, edges):
|
||||||
|
@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_annotate_with_old_state(self):
|
def test_annotate_with_old_state(self):
|
||||||
|
@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
set(e.event_id for e in old_state), set(context.prev_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_trivial_annotate_message(self):
|
def test_trivial_annotate_message(self):
|
||||||
event = create_event(type="test_message", name="event")
|
event = create_event(type="test_message", name="event")
|
||||||
|
@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]),
|
set([e.event_id for e in old_state]),
|
||||||
set(context.current_state_ids.values())
|
set(context.prev_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_resolve_message_conflict(self):
|
def test_resolve_message_conflict(self):
|
||||||
|
@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_resolve_state_conflict(self):
|
def test_resolve_state_conflict(self):
|
||||||
|
@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_standard_depth_conflict(self):
|
def test_standard_depth_conflict(self):
|
||||||
|
|
Loading…
Reference in a new issue