Fix bug which caused rejected events to be stored with the wrong room state (#6320)

Fixes a bug where rejected events were persisted with the wrong state group.

Also fixes an occasional internal-server-error when receiving events over
federation which are rejected and (possibly because they are
backwards-extremities) have no prev_group.

Fixes #6289.
This commit is contained in:
Richard van der Hoff 2019-11-06 10:01:39 +00:00 committed by GitHub
parent 0e3ab8afdc
commit 807ec3bd99
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 283 additions and 101 deletions

1
changelog.d/6320.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug which casued rejected events to be persisted with the wrong room state.

View file

@ -48,10 +48,21 @@ class EventContext:
Note that this is a private attribute: it should be accessed via Note that this is a private attribute: it should be accessed via
the ``state_group`` property. the ``state_group`` property.
state_group_before_event: The ID of the state group representing the state
of the room before this event.
If this is a non-state event, this will be the same as ``state_group``. If
it's a state event, it will be the same as ``prev_group``.
If ``state_group`` is None (ie, the event is an outlier),
``state_group_before_event`` will always also be ``None``.
prev_group: If it is known, ``state_group``'s prev_group. Note that this being prev_group: If it is known, ``state_group``'s prev_group. Note that this being
None does not necessarily mean that ``state_group`` does not have None does not necessarily mean that ``state_group`` does not have
a prev_group! a prev_group!
If the event is a state event, this is normally the same as ``prev_group``.
If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
will always also be ``None``. will always also be ``None``.
@ -77,7 +88,8 @@ class EventContext:
``get_current_state_ids``. _AsyncEventContext impl calculates this ``get_current_state_ids``. _AsyncEventContext impl calculates this
on-demand: it will be None until that happens. on-demand: it will be None until that happens.
_prev_state_ids: The room state map, excluding this event. For a non-state _prev_state_ids: The room state map, excluding this event - ie, the state
in ``state_group_before_event``. For a non-state
event, this will be the same as _current_state_events. event, this will be the same as _current_state_events.
Note that it is a completely different thing to prev_group! Note that it is a completely different thing to prev_group!
@ -92,6 +104,7 @@ class EventContext:
rejected = attr.ib(default=False, type=Union[bool, str]) rejected = attr.ib(default=False, type=Union[bool, str])
_state_group = attr.ib(default=None, type=Optional[int]) _state_group = attr.ib(default=None, type=Optional[int])
state_group_before_event = attr.ib(default=None, type=Optional[int])
prev_group = attr.ib(default=None, type=Optional[int]) prev_group = attr.ib(default=None, type=Optional[int])
delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
app_service = attr.ib(default=None, type=Optional[ApplicationService]) app_service = attr.ib(default=None, type=Optional[ApplicationService])
@ -103,12 +116,18 @@ class EventContext:
@staticmethod @staticmethod
def with_state( def with_state(
state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None state_group,
state_group_before_event,
current_state_ids,
prev_state_ids,
prev_group=None,
delta_ids=None,
): ):
return EventContext( return EventContext(
current_state_ids=current_state_ids, current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids, prev_state_ids=prev_state_ids,
state_group=state_group, state_group=state_group,
state_group_before_event=state_group_before_event,
prev_group=prev_group, prev_group=prev_group,
delta_ids=delta_ids, delta_ids=delta_ids,
) )
@ -140,6 +159,7 @@ class EventContext:
"event_type": event.type, "event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None, "event_state_key": event.state_key if event.is_state() else None,
"state_group": self._state_group, "state_group": self._state_group,
"state_group_before_event": self.state_group_before_event,
"rejected": self.rejected, "rejected": self.rejected,
"prev_group": self.prev_group, "prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids), "delta_ids": _encode_state_dict(self.delta_ids),
@ -165,6 +185,7 @@ class EventContext:
event_type=input["event_type"], event_type=input["event_type"],
event_state_key=input["event_state_key"], event_state_key=input["event_state_key"],
state_group=input["state_group"], state_group=input["state_group"],
state_group_before_event=input["state_group_before_event"],
prev_group=input["prev_group"], prev_group=input["prev_group"],
delta_ids=_decode_state_dict(input["delta_ids"]), delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"], rejected=input["rejected"],

View file

@ -2280,6 +2280,7 @@ class FederationHandler(BaseHandler):
return EventContext.with_state( return EventContext.with_state(
state_group=state_group, state_group=state_group,
state_group_before_event=context.state_group_before_event,
current_state_ids=current_state_ids, current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids, prev_state_ids=prev_state_ids,
prev_group=prev_group, prev_group=prev_group,

View file

@ -16,6 +16,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Iterable, Optional
from six import iteritems, itervalues from six import iteritems, itervalues
@ -27,6 +28,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.state import v1, v2 from synapse.state import v1, v2
@ -212,15 +214,17 @@ class StateHandler(object):
return joined_hosts return joined_hosts
@defer.inlineCallbacks @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None): def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
"""Build an EventContext structure for the event. """Build an EventContext structure for the event.
This works out what the current state should be for the event, and This works out what the current state should be for the event, and
generates a new state group if necessary. generates a new state group if necessary.
Args: Args:
event (synapse.events.EventBase): event:
old_state (dict|None): The state at the event if it can't be old_state: The state at the event if it can't be
calculated from existing events. This is normally only specified calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling. prev events for, e.g. when backfilling.
@ -251,114 +255,104 @@ class StateHandler(object):
# group for it. # group for it.
context = EventContext.with_state( context = EventContext.with_state(
state_group=None, state_group=None,
state_group_before_event=None,
current_state_ids=current_state_ids, current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids, prev_state_ids=prev_state_ids,
) )
return context return context
#
# first of all, figure out the state before the event
#
if old_state: if old_state:
# We already have the state, so we don't need to calculate it. # if we're given the state before the event, then we use that
# Let's just correctly fill out the context and create a state_ids_before_event = {
# new state group for it. (s.type, s.state_key): s.event_id for s in old_state
}
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
if event.is_state():
key = (event.type, event.state_key)
if key in prev_state_ids:
replaces = prev_state_ids[key]
if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces
current_state_ids = dict(prev_state_ids)
current_state_ids[key] = event.event_id
else: else:
current_state_ids = prev_state_ids # otherwise, we'll need to resolve the state across the prev_events.
state_group = yield self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
delta_ids=None,
current_state_ids=current_state_ids,
)
context = EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
)
return context
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups_for_events( entry = yield self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids() event.room_id, event.prev_event_ids()
) )
prev_state_ids = entry.state state_ids_before_event = entry.state
prev_group = None state_group_before_event = entry.state_group
delta_ids = None state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
if event.is_state(): #
# If this is a state event then we need to create a new state # make sure that we have a state group at that point. If it's not a state event,
# group for the state after this event. # that will be the state group for the new event. If it *is* a state event,
# it might get rejected (in which case we'll need to persist it with the
# previous state group)
#
if not state_group_before_event:
state_group_before_event = yield self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
current_state_ids=state_ids_before_event,
)
# XXX: can we update the state cache entry for the new state group? or
# could we set a flag on resolve_state_groups_for_events to tell it to
# always make a state group?
#
# now if it's not a state event, we're done
#
if not event.is_state():
return EventContext.with_state(
state_group_before_event=state_group_before_event,
state_group=state_group_before_event,
current_state_ids=state_ids_before_event,
prev_state_ids=state_ids_before_event,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
)
#
# otherwise, we'll need to create a new state group for after the event
#
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in prev_state_ids: if key in state_ids_before_event:
replaces = prev_state_ids[key] replaces = state_ids_before_event[key]
if replaces != event.event_id:
event.unsigned["replaces_state"] = replaces event.unsigned["replaces_state"] = replaces
current_state_ids = dict(prev_state_ids) state_ids_after_event = dict(state_ids_before_event)
current_state_ids[key] = event.event_id state_ids_after_event[key] = event.event_id
if entry.state_group:
# If the state at the event has a state group assigned then
# we can use that as the prev group
prev_group = entry.state_group
delta_ids = {key: event.event_id} delta_ids = {key: event.event_id}
elif entry.prev_group:
# If the state at the event only has a prev group, then we can
# use that as a prev group too.
prev_group = entry.prev_group
delta_ids = dict(entry.delta_ids)
delta_ids[key] = event.event_id
state_group = yield self.state_store.store_state_group( state_group_after_event = yield self.state_store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=prev_group, prev_group=state_group_before_event,
delta_ids=delta_ids, delta_ids=delta_ids,
current_state_ids=current_state_ids, current_state_ids=state_ids_after_event,
) )
else:
current_state_ids = prev_state_ids
prev_group = entry.prev_group
delta_ids = entry.delta_ids
if entry.state_group is None: return EventContext.with_state(
entry.state_group = yield self.state_store.store_state_group( state_group=state_group_after_event,
event.event_id, state_group_before_event=state_group_before_event,
event.room_id, current_state_ids=state_ids_after_event,
prev_group=entry.prev_group, prev_state_ids=state_ids_before_event,
delta_ids=entry.delta_ids, prev_group=state_group_before_event,
current_state_ids=current_state_ids,
)
entry.state_id = entry.state_group
state_group = entry.state_group
context = EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
prev_group=prev_group,
delta_ids=delta_ids, delta_ids=delta_ids,
) )
return context
@measure_func() @measure_func()
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids): def resolve_state_groups_for_events(self, room_id, event_ids):

View file

@ -1231,7 +1231,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
# if the event was rejected, just give it the same state as its # if the event was rejected, just give it the same state as its
# predecessor. # predecessor.
if context.rejected: if context.rejected:
state_groups[event.event_id] = context.prev_group state_groups[event.event_id] = context.state_group_before_event
continue continue
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group

View file

@ -12,13 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes
from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from tests import unittest from tests import unittest
logger = logging.getLogger(__name__)
class FederationTestCase(unittest.HomeserverTestCase): class FederationTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
@ -79,3 +85,123 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.code, 403, failure) self.assertEqual(failure.code, 403, failure)
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.") self.assertEqual(failure.msg, "You are not invited to this room.")
def test_rejected_message_event_state(self):
"""
Check that we store the state group correctly for rejected non-state events.
Regression test for #6289.
"""
OTHER_SERVER = "otherserver"
OTHER_USER = "@otheruser:" + OTHER_SERVER
# create the room
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
# pretend that another server has joined
join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
# check the state group
sg = self.successResultOf(
self.store._get_state_group_for_event(join_event.event_id)
)
# build and send an event which will be rejected
ev = event_from_pdu_json(
{
"type": EventTypes.Message,
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
"depth": join_event["depth"] + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
},
join_event.format_version,
)
with LoggingContext(request="send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
# that should have been rejected
e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
self.assertIsNotNone(e.rejected_reason)
# ... and the state group should be the same as before
sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
self.assertEqual(sg, sg2)
def test_rejected_state_event_state(self):
"""
Check that we store the state group correctly for rejected state events.
Regression test for #6289.
"""
OTHER_SERVER = "otherserver"
OTHER_USER = "@otheruser:" + OTHER_SERVER
# create the room
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
# pretend that another server has joined
join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
# check the state group
sg = self.successResultOf(
self.store._get_state_group_for_event(join_event.event_id)
)
# build and send an event which will be rejected
ev = event_from_pdu_json(
{
"type": "org.matrix.test",
"state_key": "test_key",
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
"depth": join_event["depth"] + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
},
join_event.format_version,
)
with LoggingContext(request="send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
# that should have been rejected
e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
self.assertIsNotNone(e.rejected_reason)
# ... and the state group should be the same as before
sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
self.assertEqual(sg, sg2)
def _build_and_send_join_event(self, other_server, other_user, room_id):
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
)
# the auth code requires that a signature exists, but doesn't check that
# signature... go figure.
join_event.signatures[other_server] = {"x": "y"}
with LoggingContext(request="send_join"):
d = run_in_background(
self.handler.on_send_join_request, other_server, join_event
)
self.get_success(d)
# sanity-check: the room should show that the new user is a member
r = self.get_success(self.store.get_current_state_ids(room_id))
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
return join_event

View file

@ -21,6 +21,7 @@ from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler from synapse.state import StateHandler, StateResolutionHandler
from tests import unittest from tests import unittest
@ -198,16 +199,22 @@ class StateTestCase(unittest.TestCase):
self.store.register_events(graph.walk()) self.store.register_events(graph.walk())
context_store = {} context_store = {} # type: dict[str, EventContext]
for event in graph.walk(): for event in graph.walk():
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
self.store.register_event_context(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) ctx_c = context_store["C"]
ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertEqual(2, len(prev_state_ids)) self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_branch_basic_conflict(self): def test_branch_basic_conflict(self):
graph = Graph( graph = Graph(
@ -241,12 +248,19 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) # C ends up winning the resolution between B and C
ctx_c = context_store["C"]
ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
) )
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_branch_have_banned_conflict(self): def test_branch_have_banned_conflict(self):
graph = Graph( graph = Graph(
@ -292,11 +306,18 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store) # C ends up winning the resolution between C and D because bans win over other
# changes
ctx_c = context_store["C"]
ctx_e = context_store["E"]
prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
) )
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_branch_have_perms_conflict(self): def test_branch_have_perms_conflict(self):
@ -360,12 +381,20 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) # B ends up winning the resolution between B and C because power levels
# win over other changes.
ctx_b = context_store["B"]
ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual( self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
) )
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
def _add_depths(self, nodes, edges): def _add_depths(self, nodes, edges):
def _get_depth(ev): def _get_depth(ev):
node = nodes[ev] node = nodes[ev]
@ -390,13 +419,16 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state) context = yield self.state.compute_event_context(event, old_state=old_state)
current_state_ids = yield context.get_current_state_ids(self.store) prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
self.assertEqual( current_state_ids = yield context.get_current_state_ids(self.store)
set(e.event_id for e in old_state), set(current_state_ids.values()) self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
) )
self.assertIsNotNone(context.state_group) self.assertIsNotNone(context.state_group_before_event)
self.assertEqual(context.state_group_before_event, context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
@ -411,11 +443,18 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state) context = yield self.state.compute_event_context(event, old_state=old_state)
prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
self.assertEqual( current_state_ids = yield context.get_current_state_ids(self.store)
set(e.event_id for e in old_state), set(prev_state_ids.values()) self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
) )
self.assertIsNotNone(context.state_group_before_event)
self.assertNotEqual(context.state_group_before_event, context.state_group)
self.assertEqual(context.state_group_before_event, context.prev_group)
self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
prev_event_id = "prev_event_id" prev_event_id = "prev_event_id"