0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-07-05 02:38:54 +02:00

Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.6.0

Conflicts:
	synapse/state.py
This commit is contained in:
Erik Johnston 2014-12-16 18:35:46 +00:00
commit f76269392b
15 changed files with 99 additions and 322 deletions

View file

@ -36,7 +36,7 @@ class LoggingConfig(Config):
help="The verbosity level." help="The verbosity level."
) )
logging_group.add_argument( logging_group.add_argument(
'-f', '--log-file', dest="log_file", default=None, '-f', '--log-file', dest="log_file", default="homeserver.log",
help="File to log to." help="File to log to."
) )
logging_group.add_argument( logging_group.add_argument(

View file

@ -30,20 +30,20 @@ class _EventInternalMetadata(object):
def _event_dict_property(key): def _event_dict_property(key):
def getter(self): def getter(self):
return self._event_dict[key] return self._event_dict[key]
def setter(self, v): def setter(self, v):
self._event_dict[key] = v self._event_dict[key] = v
def delete(self): def delete(self):
del self._event_dict[key] del self._event_dict[key]
return property( return property(
getter, getter,
setter, setter,
delete, delete,
) )
class EventBase(object): class EventBase(object):

View file

@ -20,8 +20,6 @@ from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.events.snapshot import EventContext
import logging import logging
@ -61,8 +59,6 @@ class BaseHandler(object):
def _create_new_client_event(self, builder): def _create_new_client_event(self, builder):
yield run_on_reactor() yield run_on_reactor()
context = EventContext()
latest_ret = yield self.store.get_latest_events_in_room( latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id, builder.room_id,
) )
@ -78,14 +74,11 @@ class BaseHandler(object):
builder.depth = depth builder.depth = depth
state_handler = self.state_handler state_handler = self.state_handler
ret = yield state_handler.annotate_context_with_state(
builder, context = yield state_handler.compute_event_context(builder)
context,
)
prev_state = ret
if builder.is_state(): if builder.is_state():
builder.prev_state = prev_state builder.prev_state = context.prev_state_events
yield self.auth.add_auth_events(builder, context) yield self.auth.add_auth_events(builder, context)

View file

@ -156,4 +156,3 @@ class DirectoryHandler(BaseHandler):
"sender": user_id, "sender": user_id,
"content": {"aliases": aliases}, "content": {"aliases": aliases},
}) })

View file

@ -17,7 +17,6 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, SynapseError, StoreError, AuthError, FederationError, SynapseError, StoreError,
@ -202,7 +201,7 @@ class FederationHandler(BaseHandler):
e.msg, e.msg,
affected=event.event_id, affected=event.event_id,
) )
# if we're receiving valid events from an origin, # if we're receiving valid events from an origin,
# it's probably a good idea to mark it as not in retry-state # it's probably a good idea to mark it as not in retry-state
# for sending (although this is a bit of a leap) # for sending (although this is a bit of a leap)
@ -260,12 +259,9 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
# FIXME (erikj): Not sure this actually works :/ # FIXME (erikj): Not sure this actually works :/
context = EventContext() context = yield self.state_handler.compute_event_context(event)
yield self.state_handler.annotate_context_with_state(event, context)
events.append( events.append((event, context))
(event, context)
)
yield self.store.persist_event( yield self.store.persist_event(
event, event,
@ -547,8 +543,6 @@ class FederationHandler(BaseHandler):
""" """
event = pdu event = pdu
context = EventContext()
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.signatures.update( event.signatures.update(
@ -559,7 +553,7 @@ class FederationHandler(BaseHandler):
) )
) )
yield self.state_handler.annotate_context_with_state(event, context) context = yield self.state_handler.compute_event_context(event)
yield self.store.persist_event( yield self.store.persist_event(
event, event,
@ -685,17 +679,14 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False, def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None, fetch_missing=True): current_state=None, fetch_missing=True):
context = EventContext()
logger.debug( logger.debug(
"_handle_new_event: Before annotate: %s, sigs: %s", "_handle_new_event: Before annotate: %s, sigs: %s",
event.event_id, event.signatures, event.event_id, event.signatures,
) )
yield self.state_handler.annotate_context_with_state( context = yield self.state_handler.compute_event_context(
event, event, old_state=state
context,
old_state=state
) )
logger.debug( logger.debug(

View file

@ -18,8 +18,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError from synapse.api.errors import RoomError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from ._base import BaseHandler from ._base import BaseHandler
@ -66,35 +64,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def send_message(self, event=None, suppress_auth=False):
""" Send a message.
Args:
event : The message event to store.
suppress_auth (bool) : True to suppress auth for this message. This
is primarily so the home server can inject messages into rooms at
will.
Raises:
SynapseError if something went wrong.
"""
self.ratelimit(event.user_id)
# TODO(paul): Why does 'event' not have a 'user' object?
user = self.hs.parse_userid(event.user_id)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(
event, snapshot, suppress_auth=suppress_auth
)
with PreserveLoggingContext():
self.hs.get_handlers().presence_handler.bump_presence_active_time(
user
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, user_id=None, room_id=None, pagin_config=None,
feedback=False): feedback=False):
@ -217,13 +186,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(fb) defer.returnValue(fb)
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def send_feedback(self, event):
snapshot = yield self.store.snapshot_room(event)
# store message in db
yield self._on_new_room_event(event, snapshot)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id): def get_state_events(self, user_id, room_id):
"""Retrieve all state events for a given room. """Retrieve all state events for a given room.

View file

@ -45,8 +45,8 @@ class TypingNotificationHandler(BaseHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop self._member_typing_until = {} # clock time we expect to stop
self._member_typing_timer = {} # deferreds to manage theabove self._member_typing_timer = {} # deferreds to manage theabove
# map room IDs to serial numbers # map room IDs to serial numbers
self._room_serials = {} self._room_serials = {}

View file

@ -135,6 +135,7 @@ class BaseMediaResource(Resource):
if download is None: if download is None:
download = self._get_remote_media_impl(server_name, media_id) download = self._get_remote_media_impl(server_name, media_id)
self.downloads[key] = download self.downloads[key] = download
@download.addBoth @download.addBoth
def callback(media_info): def callback(media_info):
del self.downloads[key] del self.downloads[key]

View file

@ -310,8 +310,8 @@ class RoomMessageListRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request) user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request, pagination_config = PaginationConfig.from_request(
default_limit=10, request, default_limit=10,
) )
with_feedback = "feedback" in request.args with_feedback = "feedback" in request.args
handler = self.handlers.message_handler handler = self.handlers.message_handler
@ -466,7 +466,9 @@ class RoomRedactEventRestServlet(RestServlet):
class RoomTypingRestServlet(RestServlet): class RoomTypingRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$") PATTERN = client_path_pattern(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):

View file

@ -19,6 +19,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# FIXME: elsewhere we use FooStore to indicate something in the storage layer... # FIXME: elsewhere we use FooStore to indicate something in the storage layer...
class HttpTransactionStore(object): class HttpTransactionStore(object):

View file

@ -19,10 +19,10 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events.snapshot import EventContext
from collections import namedtuple from collections import namedtuple
import copy
import logging import logging
import hashlib import hashlib
@ -43,71 +43,6 @@ class StateHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks
@log_function
def annotate_event_with_state(self, event, old_state=None):
""" Annotates the event with the current state events as of that event.
This method adds three new attributes to the event:
* `state_events`: The state up to and including the event. Encoded
as a dict mapping tuple (type, state_key) -> event.
* `old_state_events`: The state up to, but excluding, the event.
Encoded similarly as `state_events`.
* `state_group`: If there is an existing state group that can be
used, then return that. Otherwise return `None`. See state
storage for more information.
If the argument `old_state` is given (in the form of a list of
events), then they are used as a the values for `old_state_events` and
the value for `state_events` is generated from it. `state_group` is
set to None.
This needs to be called before persisting the event.
"""
yield run_on_reactor()
if old_state:
event.state_group = None
event.old_state_events = {
(s.type, s.state_key): s for s in old_state
}
event.state_events = event.old_state_events
if hasattr(event, "state_key"):
event.state_events[(event.type, event.state_key)] = event
defer.returnValue(False)
return
if event.is_outlier():
event.state_group = None
event.old_state_events = None
event.state_events = None
defer.returnValue(False)
return
ids = [e for e, _ in event.prev_events]
ret = yield self.resolve_state_groups(ids)
state_group, new_state, _ = ret
event.old_state_events = copy.deepcopy(new_state)
if hasattr(event, "state_key"):
key = (event.type, event.state_key)
if key in new_state:
event.replaces_state = new_state[key].event_id
new_state[key] = event
elif state_group:
event.state_group = state_group
event.state_events = new_state
defer.returnValue(False)
event.state_group = None
event.state_events = new_state
defer.returnValue(hasattr(event, "state_key"))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
""" Returns the current state for the room as a list. This is done by """ Returns the current state for the room as a list. This is done by
@ -136,7 +71,7 @@ class StateHandler(object):
defer.returnValue(res[1].values()) defer.returnValue(res[1].values())
@defer.inlineCallbacks @defer.inlineCallbacks
def annotate_context_with_state(self, event, context, old_state=None): def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The """ Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph `current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event` just before the event - i.e. it never includes `event`
@ -146,8 +81,11 @@ class StateHandler(object):
Args: Args:
event (EventBase) event (EventBase)
context (EventContext) Returns:
an EventContext
""" """
context = EventContext()
yield run_on_reactor() yield run_on_reactor()
if old_state: if old_state:
@ -173,7 +111,8 @@ class StateHandler(object):
if replaces.event_id != event.event_id: # Paranoia check if replaces.event_id != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces.event_id
defer.returnValue([]) context.prev_state_events = []
defer.returnValue(context)
if event.is_state(): if event.is_state():
ret = yield self.resolve_state_groups( ret = yield self.resolve_state_groups(
@ -211,7 +150,8 @@ class StateHandler(object):
else: else:
context.auth_events = {} context.auth_events = {}
defer.returnValue(prev_state) context.prev_state_events = prev_state
defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function

View file

@ -417,86 +417,6 @@ class DataStore(RoomMemberStore, RoomStore,
], ],
) )
def snapshot_room(self, event):
"""Snapshot the room for an update by a user
Args:
room_id (synapse.types.RoomId): The room to snapshot.
user_id (synapse.types.UserId): The user to snapshot the room for.
state_type (str): Optional state type to snapshot.
state_key (str): Optional state key to snapshot.
Returns:
synapse.storage.Snapshot: A snapshot of the state of the room.
"""
def _snapshot(txn):
prev_events = self._get_latest_events_in_room(
txn,
event.room_id
)
prev_state = None
state_key = None
if hasattr(event, "state_key"):
state_key = event.state_key
prev_state = self._get_latest_state_in_room(
txn,
event.room_id,
type=event.type,
state_key=state_key,
)
return Snapshot(
store=self,
room_id=event.room_id,
user_id=event.user_id,
prev_events=prev_events,
prev_state=prev_state,
state_type=event.type,
state_key=state_key,
)
return self.runInteraction("snapshot_room", _snapshot)
class Snapshot(object):
"""Snapshot of the state of a room
Args:
store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for.
prev_events (list): The list of event ids this snapshot is after.
membership_state (RoomMemberEvent): The current state of the user in
the room.
state_type (str, optional): State type captured by the snapshot
state_key (str, optional): State key captured by the snapshot
prev_state_pdu (PduEntry, optional): pdu id of
the previous value of the state type and key in the room.
"""
def __init__(self, store, room_id, user_id, prev_events,
prev_state, state_type=None, state_key=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
self.prev_events = prev_events
self.prev_state = prev_state
self.state_type = state_type
self.state_key = state_key
def fill_out_prev_events(self, event):
if not hasattr(event, "prev_events"):
event.prev_events = [
(event_id, hashes)
for event_id, hashes, _ in self.prev_events
]
if self.prev_events:
event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
else:
event.depth = 0
if not hasattr(event, "prev_state") and self.prev_state is not None:
event.prev_state = self.prev_state
def schema_path(schema): def schema_path(schema):
""" Get a filesystem path for the named database schema """ Get a filesystem path for the named database schema

View file

@ -34,7 +34,7 @@ class FederationTestCase(unittest.TestCase):
self.mock_config.signing_key = [MockKey()] self.mock_config.signing_key = [MockKey()]
self.state_handler = NonCallableMock(spec_set=[ self.state_handler = NonCallableMock(spec_set=[
"annotate_context_with_state", "compute_event_context",
]) ])
self.auth = NonCallableMock(spec_set=[ self.auth = NonCallableMock(spec_set=[
@ -91,11 +91,12 @@ class FederationTestCase(unittest.TestCase):
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True) self.auth.check_host_in_room.return_value = defer.succeed(True)
def annotate(ev, context, old_state=None): def annotate(ev, old_state=None):
context = Mock()
context.current_state = {} context.current_state = {}
context.auth_events = {} context.auth_events = {}
return defer.succeed(False) return defer.succeed(context)
self.state_handler.annotate_context_with_state.side_effect = annotate self.state_handler.compute_event_context.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu( yield self.handlers.federation_handler.on_receive_pdu(
"fo", pdu, False "fo", pdu, False
@ -109,15 +110,12 @@ class FederationTestCase(unittest.TestCase):
context=ANY, context=ANY,
) )
self.state_handler.annotate_context_with_state.assert_called_once_with( self.state_handler.compute_event_context.assert_called_once_with(
ANY, ANY, old_state=None,
ANY,
old_state=None,
) )
self.auth.check.assert_called_once_with(ANY, auth_events={}) self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
ANY, ANY, extra_users=[]
extra_users=[]
) )

View file

@ -60,7 +60,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"check_host_in_room", "check_host_in_room",
]), ]),
state_handler=NonCallableMock(spec_set=[ state_handler=NonCallableMock(spec_set=[
"annotate_context_with_state", "compute_event_context",
"get_current_state", "get_current_state",
]), ]),
config=self.mock_config, config=self.mock_config,
@ -110,7 +110,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
defer.succeed([]) defer.succeed([])
) )
def annotate(_, ctx): def annotate(_):
ctx = Mock()
ctx.current_state = { ctx.current_state = {
(EventTypes.Member, "@alice:green"): self._create_member( (EventTypes.Member, "@alice:green"): self._create_member(
user_id="@alice:green", user_id="@alice:green",
@ -121,10 +122,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
room_id=room_id, room_id=room_id,
), ),
} }
ctx.prev_state_events = []
return defer.succeed(True) return defer.succeed(ctx)
self.state_handler.annotate_context_with_state.side_effect = annotate self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx): def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[ ctx.auth_events = ctx.current_state[
@ -146,8 +148,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
yield room_handler.change_membership(event, context) yield room_handler.change_membership(event, context)
self.state_handler.annotate_context_with_state.assert_called_once_with( self.state_handler.compute_event_context.assert_called_once_with(
builder, context builder
) )
self.auth.add_auth_events.assert_called_once_with( self.auth.add_auth_events.assert_called_once_with(
@ -189,7 +191,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
defer.succeed([]) defer.succeed([])
) )
def annotate(_, ctx): def annotate(_):
ctx = Mock()
ctx.current_state = { ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member( (EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red", user_id="@bob:red",
@ -197,10 +200,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
membership=Membership.INVITE membership=Membership.INVITE
), ),
} }
ctx.prev_state_events = []
return defer.succeed(True) return defer.succeed(ctx)
self.state_handler.annotate_context_with_state.side_effect = annotate self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx): def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[ ctx.auth_events = ctx.current_state[
@ -262,7 +266,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
defer.succeed([]) defer.succeed([])
) )
def annotate(_, ctx): def annotate(_):
ctx = Mock()
ctx.current_state = { ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member( (EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red", user_id="@bob:red",
@ -270,10 +275,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
membership=Membership.JOIN membership=Membership.JOIN
), ),
} }
ctx.prev_state_events = []
return defer.succeed(True) return defer.succeed(ctx)
self.state_handler.annotate_context_with_state.side_effect = annotate self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx): def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[ ctx.auth_events = ctx.current_state[

View file

@ -26,6 +26,7 @@ class StateTestCase(unittest.TestCase):
self.store = Mock( self.store = Mock(
spec_set=[ spec_set=[
"get_state_groups", "get_state_groups",
"add_event_hashes",
] ]
) )
hs = Mock(spec=["get_datastore"]) hs = Mock(spec=["get_datastore"])
@ -44,17 +45,20 @@ class StateTestCase(unittest.TestCase):
self.create_event(type="test2", state_key=""), self.create_event(type="test2", state_key=""),
] ]
yield self.state.annotate_event_with_state(event, old_state=old_state) context = yield self.state.compute_event_context(
event, old_state=old_state
)
for k, v in event.old_state_events.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
self.assertEqual(type, v.type) self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key) self.assertEqual(state_key, v.state_key)
self.assertEqual(set(old_state), set(event.old_state_events.values())) self.assertEqual(
self.assertDictEqual(event.old_state_events, event.state_events) set(old_state), set(context.current_state.values())
)
self.assertIsNone(event.state_group) self.assertIsNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
@ -66,21 +70,21 @@ class StateTestCase(unittest.TestCase):
self.create_event(type="test2", state_key=""), self.create_event(type="test2", state_key=""),
] ]
yield self.state.annotate_event_with_state(event, old_state=old_state) context = yield self.state.compute_event_context(
event, old_state=old_state
)
for k, v in event.old_state_events.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
self.assertEqual(type, v.type) self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key) self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set(old_state + [event]), set(old_state),
set(event.old_state_events.values()) set(context.current_state.values())
) )
self.assertDictEqual(event.old_state_events, event.state_events) self.assertIsNone(context.state_group)
self.assertIsNone(event.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
@ -99,30 +103,19 @@ class StateTestCase(unittest.TestCase):
group_name: old_state, group_name: old_state,
} }
yield self.state.annotate_event_with_state(event) context = yield self.state.compute_event_context(event)
for k, v in event.old_state_events.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
self.assertEqual(type, v.type) self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key) self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set([e.event_id for e in event.old_state_events.values()]) set([e.event_id for e in context.current_state.values()])
) )
self.assertDictEqual( self.assertEqual(group_name, context.state_group)
{
k: v.event_id
for k, v in event.old_state_events.items()
},
{
k: v.event_id
for k, v in event.state_events.items()
}
)
self.assertEqual(group_name, event.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_state(self): def test_trivial_annotate_state(self):
@ -141,38 +134,19 @@ class StateTestCase(unittest.TestCase):
group_name: old_state, group_name: old_state,
} }
yield self.state.annotate_event_with_state(event) context = yield self.state.compute_event_context(event)
for k, v in event.old_state_events.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
self.assertEqual(type, v.type) self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key) self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set([e.event_id for e in event.old_state_events.values()]) set([e.event_id for e in context.current_state.values()])
) )
self.assertEqual( self.assertIsNone(context.state_group)
set([e.event_id for e in old_state] + [event.event_id]),
set([e.event_id for e in event.state_events.values()])
)
new_state = {
k: v.event_id
for k, v in event.state_events.items()
}
old_state = {
k: v.event_id
for k, v in event.old_state_events.items()
}
old_state[(event.type, event.state_key)] = event.event_id
self.assertDictEqual(
old_state,
new_state
)
self.assertIsNone(event.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
@ -199,16 +173,11 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2, group_name_2: old_state_2,
} }
yield self.state.annotate_event_with_state(event) context = yield self.state.compute_event_context(event)
self.assertEqual(len(event.old_state_events), 5) self.assertEqual(len(context.current_state), 5)
self.assertEqual( self.assertIsNone(context.state_group)
set([e.event_id for e in event.state_events.values()]),
set([e.event_id for e in event.old_state_events.values()])
)
self.assertIsNone(event.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
@ -235,19 +204,11 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2, group_name_2: old_state_2,
} }
yield self.state.annotate_event_with_state(event) context = yield self.state.compute_event_context(event)
self.assertEqual(len(event.old_state_events), 5) self.assertEqual(len(context.current_state), 5)
expected_new = event.old_state_events self.assertIsNone(context.state_group)
expected_new[(event.type, event.state_key)] = event
self.assertEqual(
set([e.event_id for e in expected_new.values()]),
set([e.event_id for e in event.state_events.values()]),
)
self.assertIsNone(event.state_group)
def create_event(self, name=None, type=None, state_key=None): def create_event(self, name=None, type=None, state_key=None):
self.event_id += 1 self.event_id += 1
@ -266,6 +227,9 @@ class StateTestCase(unittest.TestCase):
event.state_key = state_key event.state_key = state_key
event.event_id = event_id event.event_id = event_id
event.is_state = lambda: (state_key is not None)
event.unsigned = {}
event.user_id = "@user_id:example.com" event.user_id = "@user_id:example.com"
event.room_id = "!room_id:example.com" event.room_id = "!room_id:example.com"