Merge pull request #3673 from matrix-org/erikj/refactor_state_handler

Refactor state module to support multiple room versions
This commit is contained in:
Erik Johnston 2018-08-22 10:04:55 +01:00 committed by GitHub
commit 3bf8bab8f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 437 additions and 277 deletions

1
changelog.d/3673.misc Normal file
View file

@ -0,0 +1 @@
Refactor state module to support multiple room versions

View file

@ -97,9 +97,14 @@ class ThirdPartyEntityKind(object):
LOCATION = "location" LOCATION = "location"
class RoomVersions(object):
V1 = "1"
VDH_TEST = "vdh-test-version"
# the version we will give rooms which are created on this server # the version we will give rooms which are created on this server
DEFAULT_ROOM_VERSION = "1" DEFAULT_ROOM_VERSION = RoomVersions.V1
# vdh-test-version is a placeholder to get room versioning support working and tested # vdh-test-version is a placeholder to get room versioning support working and tested
# until we have a working v2. # until we have a working v2.
KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"} KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST}

View file

@ -291,8 +291,9 @@ class FederationHandler(BaseHandler):
ev_ids, get_prev_content=False, check_redacted=False ev_ids, get_prev_content=False, check_redacted=False
) )
room_version = yield self.store.get_room_version(pdu.room_id)
state_map = yield resolve_events_with_factory( state_map = yield resolve_events_with_factory(
state_groups, {pdu.event_id: pdu}, fetch room_version, state_groups, {pdu.event_id: pdu}, fetch
) )
state = (yield self.store.get_events(state_map.values())).values() state = (yield self.store.get_events(state_map.values())).values()
@ -1828,7 +1829,10 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d (d.type, d.state_key): d for d in different_events if d
}) })
room_version = yield self.store.get_room_version(event.room_id)
new_state = self.state_handler.resolve_events( new_state = self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())], [list(local_view.values()), list(remote_view.values())],
event event
) )

View file

@ -344,6 +344,7 @@ class RoomMemberHandler(object):
latest_event_ids = ( latest_event_ids = (
event_id for (event_id, _, _) in prev_events_and_hashes event_id for (event_id, _, _) in prev_events_and_hashes
) )
current_state_ids = yield self.state_handler.get_current_state_ids( current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids, room_id, latest_event_ids=latest_event_ids,
) )

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,21 +14,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib
import logging import logging
from collections import namedtuple from collections import namedtuple
from six import iteritems, iterkeys, itervalues from six import iteritems, itervalues
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer from twisted.internet import defer
from synapse import event_auth from synapse.api.constants import EventTypes, RoomVersions
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.state import v1
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -264,6 +262,7 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups_for_events( entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
@ -338,8 +337,11 @@ class StateHandler(object):
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
Args: Args:
room_id (str): room_id (str)
event_ids (list[str]): event_ids (list[str])
explicit_room_version (str|None): If set uses the the given room
version to choose the resolution algorithm. If None, then
checks the database for room version.
Returns: Returns:
Deferred[_StateCacheEntry]: resolved state Deferred[_StateCacheEntry]: resolved state
@ -353,7 +355,12 @@ class StateHandler(object):
room_id, event_ids room_id, event_ids
) )
if len(state_groups_ids) == 1: if len(state_groups_ids) == 0:
defer.returnValue(_StateCacheEntry(
state={},
state_group=None,
))
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop() name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@ -365,8 +372,11 @@ class StateHandler(object):
delta_ids=delta_ids, delta_ids=delta_ids,
)) ))
room_version = yield self.store.get_room_version(room_id)
result = yield self._state_resolution_handler.resolve_state_groups( result = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups_ids, None, self._state_map_factory, room_id, room_version, state_groups_ids, None,
self._state_map_factory,
) )
defer.returnValue(result) defer.returnValue(result)
@ -375,7 +385,7 @@ class StateHandler(object):
ev_ids, get_prev_content=False, check_redacted=False, ev_ids, get_prev_content=False, check_redacted=False,
) )
def resolve_events(self, state_sets, event): def resolve_events(self, room_version, state_sets, event):
logger.info( logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets) "Resolving state for %s with %d groups", event.room_id, len(state_sets)
) )
@ -391,7 +401,9 @@ class StateHandler(object):
} }
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map) new_state = resolve_events_with_state_map(
room_version, state_set_ids, state_map,
)
new_state = { new_state = {
key: state_map[ev_id] for key, ev_id in iteritems(new_state) key: state_map[ev_id] for key, ev_id in iteritems(new_state)
@ -430,7 +442,7 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups( def resolve_state_groups(
self, room_id, state_groups_ids, event_map, state_map_factory, self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
): ):
"""Resolves conflicts between a set of state groups """Resolves conflicts between a set of state groups
@ -439,6 +451,7 @@ class StateResolutionHandler(object):
Args: Args:
room_id (str): room we are resolving for (used for logging) room_id (str): room we are resolving for (used for logging)
room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]): state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group map from state group id to the state in that state group
(where 'state' is a map from state key to event id) (where 'state' is a map from state key to event id)
@ -492,6 +505,7 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_factory(
room_version,
list(itervalues(state_groups_ids)), list(itervalues(state_groups_ids)),
event_map=event_map, event_map=event_map,
state_map_factory=state_map_factory, state_map_factory=state_map_factory,
@ -575,16 +589,10 @@ def _make_state_cache_entry(
) )
def _ordered_events(events): def resolve_events_with_state_map(room_version, state_sets, state_map):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
return sorted(events, key=key_func)
def resolve_events_with_state_map(state_sets, state_map):
""" """
Args: Args:
room_version(str): Version of the room
state_sets(list): List of dicts of (type, state_key) -> event_id, state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in state_map(dict): a dict from event_id to event, for all events in
@ -594,75 +602,23 @@ def resolve_events_with_state_map(state_sets, state_map):
dict[(str, str), str]: dict[(str, str), str]:
a map from (type, state_key) to event_id. a map from (type, state_key) to event_id.
""" """
if len(state_sets) == 1: if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
return state_sets[0] return v1.resolve_events_with_state_map(
state_sets, state_map,
unconflicted_state, conflicted_state = _seperate( )
state_sets, else:
) # This should only happen if we added a version but forgot to add it to
# the list above.
auth_events = _create_auth_events_from_maps( raise Exception(
unconflicted_state, conflicted_state, state_map "No state resolution algorithm defined for version %r" % (room_version,)
) )
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
def _seperate(state_sets): def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
state_sets(iterable[dict[(str, str), str]]):
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {}
for state_set in state_set_iterator:
for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else:
# This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
return unconflicted_state, conflicted_state
@defer.inlineCallbacks
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
""" """
Args: Args:
room_version(str): Version of the room
state_sets(list): List of dicts of (type, state_key) -> event_id, state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
@ -682,185 +638,13 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
Deferred[dict[(str, str), str]]: Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id. a map from (type, state_key) to event_id.
""" """
if len(state_sets) == 1: if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
defer.returnValue(state_sets[0]) return v1.resolve_events_with_factory(
state_sets, event_map, state_map_factory,
unconflicted_state, conflicted_state = _seperate( )
state_sets, else:
) # This should only happen if we added a version but forgot to add it to
# the list above.
needed_events = set( raise Exception(
event_id "No state resolution algorithm defined for version %r" % (room_version,)
for event_ids in itervalues(conflicted_state)
for event_id in event_ids
)
if event_map is not None:
needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d auth events", len(new_needed_events))
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in itervalues(conflicted_state):
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in iteritems(conflicted_state_ids):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = {
key: state_map[ev_id]
for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map
}
try:
resolved_state = _resolve_state_events(
conflicted_state, auth_events
) )
except Exception:
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state_ids
for key, event in iteritems(resolved_state):
new_state[key] = event.event_id
return new_state
def _resolve_state_events(conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
events, auth_events
)
return resolved_state
def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))]
auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event

321
synapse/state/v1.py Normal file
View file

@ -0,0 +1,321 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import logging
from six import iteritems, iterkeys, itervalues
from twisted.internet import defer
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in
state_sets.
Returns
dict[(str, str), str]:
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
@defer.inlineCallbacks
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_map_factory.
state_map_factory(func): will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
Returns
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
defer.returnValue(state_sets[0])
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
needed_events = set(
event_id
for event_ids in itervalues(conflicted_state)
for event_id in event_ids
)
if event_map is not None:
needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d auth events", len(new_needed_events))
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
state_sets(iterable[dict[(str, str), str]]):
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {}
for state_set in state_set_iterator:
for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else:
# This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
return unconflicted_state, conflicted_state
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in itervalues(conflicted_state):
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in iteritems(conflicted_state_ids):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = {
key: state_map[ev_id]
for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map
}
try:
resolved_state = _resolve_state_events(
conflicted_state, auth_events
)
except Exception:
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state_ids
for key, event in iteritems(resolved_state):
new_state[key] = event.event_id
return new_state
def _resolve_state_events(conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
events, auth_events
)
return resolved_state
def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))]
auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event
def _ordered_events(events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
return sorted(events, key=key_func)

View file

@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
} }
events_map = {ev.event_id: ev for ev, _ in events_context} events_map = {ev.event_id: ev for ev, _ in events_context}
room_version = yield self.get_room_version(room_id)
logger.debug("calling resolve_state_groups from preserve_events") logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups( res = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups, events_map, get_events room_id, room_version, state_groups, events_map, get_events
) )
defer.returnValue((res.state, None)) defer.returnValue((res.state, None))

View file

@ -112,6 +112,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invites(self): def test_invites(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.check("get_invited_rooms_for_user", [USER_ID_2], []) yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
event = yield self.persist( event = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="invite" type="m.room.member", key=USER_ID_2, membership="invite"
@ -133,7 +134,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_actions_for_user(self): def test_push_actions_for_user(self):
yield self.persist(type="m.room.create", creator=USER_ID) yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.join", key=USER_ID, membership="join") yield self.persist(type="m.room.join", key=USER_ID, membership="join")
yield self.persist( yield self.persist(
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"

View file

@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver from tests.utils import create_room, setup_test_homeserver
class RedactionTestCase(unittest.TestCase): class RedactionTestCase(unittest.TestCase):
@ -41,6 +41,8 @@ class RedactionTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test") self.room1 = RoomID.from_string("!abc123:test")
yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
self.depth = 1 self.depth = 1
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver from tests.utils import create_room, setup_test_homeserver
class RoomMemberStoreTestCase(unittest.TestCase): class RoomMemberStoreTestCase(unittest.TestCase):
@ -45,6 +45,8 @@ class RoomMemberStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abc123:test") self.room = RoomID.from_string("!abc123:test")
yield create_room(hs, self.room.to_string(), self.u_alice.to_string())
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None): def inject_room_member(self, room, user, membership, replaces_state=None):
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(

View file

@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.state import StateHandler, StateResolutionHandler from synapse.state import StateHandler, StateResolutionHandler
@ -117,6 +117,9 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group): def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group self._event_to_state_group[event_id] = state_group
def get_room_version(self, room_id):
return RoomVersions.V1
class DictObj(dict): class DictObj(dict):
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -176,7 +179,9 @@ class StateTestCase(unittest.TestCase):
def test_branch_no_conflict(self): def test_branch_no_conflict(self):
graph = Graph( graph = Graph(
nodes={ nodes={
"START": DictObj(type=EventTypes.Create, state_key="", depth=1), "START": DictObj(
type=EventTypes.Create, state_key="", content={}, depth=1,
),
"A": DictObj(type=EventTypes.Message, depth=2), "A": DictObj(type=EventTypes.Message, depth=2),
"B": DictObj(type=EventTypes.Message, depth=3), "B": DictObj(type=EventTypes.Message, depth=3),
"C": DictObj(type=EventTypes.Name, state_key="", depth=3), "C": DictObj(type=EventTypes.Name, state_key="", depth=3),

View file

@ -21,7 +21,7 @@ from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server from synapse.visibility import filter_events_for_server
import tests.unittest import tests.unittest
from tests.utils import setup_test_homeserver from tests.utils import create_room, setup_test_homeserver
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,6 +36,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_builder_factory = self.hs.get_event_builder_factory() self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filtering(self): def test_filtering(self):
# #

View file

@ -24,6 +24,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error from synapse.api.errors import CodeMessageException, cs_error
from synapse.federation.transport import server from synapse.federation.transport import server
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
@ -539,3 +540,32 @@ class DeferredMockCallable(object):
"Expected not to received any calls, got:\n" "Expected not to received any calls, got:\n"
+ "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls]) + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
) )
@defer.inlineCallbacks
def create_room(hs, room_id, creator_id):
"""Creates and persist a creation event for the given room
Args:
hs
room_id (str)
creator_id (str)
"""
store = hs.get_datastore()
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
builder = event_builder_factory.new({
"type": EventTypes.Create,
"state_key": "",
"sender": creator_id,
"room_id": room_id,
"content": {},
})
event, context = yield event_creation_handler.create_new_client_event(
builder
)
yield store.persist_event(event, context)