forked from MirrorHub/synapse
Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.6.0
Conflicts: synapse/state.py
This commit is contained in:
commit
f76269392b
15 changed files with 99 additions and 322 deletions
|
@ -36,7 +36,7 @@ class LoggingConfig(Config):
|
|||
help="The verbosity level."
|
||||
)
|
||||
logging_group.add_argument(
|
||||
'-f', '--log-file', dest="log_file", default=None,
|
||||
'-f', '--log-file', dest="log_file", default="homeserver.log",
|
||||
help="File to log to."
|
||||
)
|
||||
logging_group.add_argument(
|
||||
|
|
|
@ -30,20 +30,20 @@ class _EventInternalMetadata(object):
|
|||
|
||||
|
||||
def _event_dict_property(key):
|
||||
def getter(self):
|
||||
return self._event_dict[key]
|
||||
def getter(self):
|
||||
return self._event_dict[key]
|
||||
|
||||
def setter(self, v):
|
||||
self._event_dict[key] = v
|
||||
def setter(self, v):
|
||||
self._event_dict[key] = v
|
||||
|
||||
def delete(self):
|
||||
del self._event_dict[key]
|
||||
def delete(self):
|
||||
del self._event_dict[key]
|
||||
|
||||
return property(
|
||||
getter,
|
||||
setter,
|
||||
delete,
|
||||
)
|
||||
return property(
|
||||
getter,
|
||||
setter,
|
||||
delete,
|
||||
)
|
||||
|
||||
|
||||
class EventBase(object):
|
||||
|
|
|
@ -20,8 +20,6 @@ from synapse.util.async import run_on_reactor
|
|||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
|
||||
from synapse.events.snapshot import EventContext
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
|
@ -61,8 +59,6 @@ class BaseHandler(object):
|
|||
def _create_new_client_event(self, builder):
|
||||
yield run_on_reactor()
|
||||
|
||||
context = EventContext()
|
||||
|
||||
latest_ret = yield self.store.get_latest_events_in_room(
|
||||
builder.room_id,
|
||||
)
|
||||
|
@ -78,14 +74,11 @@ class BaseHandler(object):
|
|||
builder.depth = depth
|
||||
|
||||
state_handler = self.state_handler
|
||||
ret = yield state_handler.annotate_context_with_state(
|
||||
builder,
|
||||
context,
|
||||
)
|
||||
prev_state = ret
|
||||
|
||||
context = yield state_handler.compute_event_context(builder)
|
||||
|
||||
if builder.is_state():
|
||||
builder.prev_state = prev_state
|
||||
builder.prev_state = context.prev_state_events
|
||||
|
||||
yield self.auth.add_auth_events(builder, context)
|
||||
|
||||
|
|
|
@ -156,4 +156,3 @@ class DirectoryHandler(BaseHandler):
|
|||
"sender": user_id,
|
||||
"content": {"aliases": aliases},
|
||||
})
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.api.errors import (
|
||||
AuthError, FederationError, SynapseError, StoreError,
|
||||
|
@ -202,7 +201,7 @@ class FederationHandler(BaseHandler):
|
|||
e.msg,
|
||||
affected=event.event_id,
|
||||
)
|
||||
|
||||
|
||||
# if we're receiving valid events from an origin,
|
||||
# it's probably a good idea to mark it as not in retry-state
|
||||
# for sending (although this is a bit of a leap)
|
||||
|
@ -260,12 +259,9 @@ class FederationHandler(BaseHandler):
|
|||
event = pdu
|
||||
|
||||
# FIXME (erikj): Not sure this actually works :/
|
||||
context = EventContext()
|
||||
yield self.state_handler.annotate_context_with_state(event, context)
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
|
||||
events.append(
|
||||
(event, context)
|
||||
)
|
||||
events.append((event, context))
|
||||
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
|
@ -547,8 +543,6 @@ class FederationHandler(BaseHandler):
|
|||
"""
|
||||
event = pdu
|
||||
|
||||
context = EventContext()
|
||||
|
||||
event.internal_metadata.outlier = True
|
||||
|
||||
event.signatures.update(
|
||||
|
@ -559,7 +553,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
)
|
||||
|
||||
yield self.state_handler.annotate_context_with_state(event, context)
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
|
@ -685,17 +679,14 @@ class FederationHandler(BaseHandler):
|
|||
@defer.inlineCallbacks
|
||||
def _handle_new_event(self, event, state=None, backfilled=False,
|
||||
current_state=None, fetch_missing=True):
|
||||
context = EventContext()
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before annotate: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
yield self.state_handler.annotate_context_with_state(
|
||||
event,
|
||||
context,
|
||||
old_state=state
|
||||
context = yield self.state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
|
|
|
@ -18,8 +18,6 @@ from twisted.internet import defer
|
|||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import RoomError
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
from synapse.events.validator import EventValidator
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
@ -66,35 +64,6 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_message(self, event=None, suppress_auth=False):
|
||||
""" Send a message.
|
||||
|
||||
Args:
|
||||
event : The message event to store.
|
||||
suppress_auth (bool) : True to suppress auth for this message. This
|
||||
is primarily so the home server can inject messages into rooms at
|
||||
will.
|
||||
Raises:
|
||||
SynapseError if something went wrong.
|
||||
"""
|
||||
|
||||
self.ratelimit(event.user_id)
|
||||
# TODO(paul): Why does 'event' not have a 'user' object?
|
||||
user = self.hs.parse_userid(event.user_id)
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
snapshot = yield self.store.snapshot_room(event)
|
||||
|
||||
yield self._on_new_room_event(
|
||||
event, snapshot, suppress_auth=suppress_auth
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.hs.get_handlers().presence_handler.bump_presence_active_time(
|
||||
user
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
|
||||
feedback=False):
|
||||
|
@ -217,13 +186,6 @@ class MessageHandler(BaseHandler):
|
|||
defer.returnValue(fb)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_feedback(self, event):
|
||||
snapshot = yield self.store.snapshot_room(event)
|
||||
|
||||
# store message in db
|
||||
yield self._on_new_room_event(event, snapshot)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_events(self, user_id, room_id):
|
||||
"""Retrieve all state events for a given room.
|
||||
|
|
|
@ -45,8 +45,8 @@ class TypingNotificationHandler(BaseHandler):
|
|||
|
||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||
|
||||
self._member_typing_until = {} # clock time we expect to stop
|
||||
self._member_typing_timer = {} # deferreds to manage theabove
|
||||
self._member_typing_until = {} # clock time we expect to stop
|
||||
self._member_typing_timer = {} # deferreds to manage theabove
|
||||
|
||||
# map room IDs to serial numbers
|
||||
self._room_serials = {}
|
||||
|
|
|
@ -135,6 +135,7 @@ class BaseMediaResource(Resource):
|
|||
if download is None:
|
||||
download = self._get_remote_media_impl(server_name, media_id)
|
||||
self.downloads[key] = download
|
||||
|
||||
@download.addBoth
|
||||
def callback(media_info):
|
||||
del self.downloads[key]
|
||||
|
|
|
@ -310,8 +310,8 @@ class RoomMessageListRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
pagination_config = PaginationConfig.from_request(request,
|
||||
default_limit=10,
|
||||
pagination_config = PaginationConfig.from_request(
|
||||
request, default_limit=10,
|
||||
)
|
||||
with_feedback = "feedback" in request.args
|
||||
handler = self.handlers.message_handler
|
||||
|
@ -466,7 +466,9 @@ class RoomRedactEventRestServlet(RestServlet):
|
|||
|
||||
|
||||
class RoomTypingRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$")
|
||||
PATTERN = client_path_pattern(
|
||||
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, user_id):
|
||||
|
|
|
@ -19,6 +19,7 @@ import logging
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# FIXME: elsewhere we use FooStore to indicate something in the storage layer...
|
||||
class HttpTransactionStore(object):
|
||||
|
||||
|
|
|
@ -19,10 +19,10 @@ from twisted.internet import defer
|
|||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events.snapshot import EventContext
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import hashlib
|
||||
|
||||
|
@ -43,71 +43,6 @@ class StateHandler(object):
|
|||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def annotate_event_with_state(self, event, old_state=None):
|
||||
""" Annotates the event with the current state events as of that event.
|
||||
|
||||
This method adds three new attributes to the event:
|
||||
* `state_events`: The state up to and including the event. Encoded
|
||||
as a dict mapping tuple (type, state_key) -> event.
|
||||
* `old_state_events`: The state up to, but excluding, the event.
|
||||
Encoded similarly as `state_events`.
|
||||
* `state_group`: If there is an existing state group that can be
|
||||
used, then return that. Otherwise return `None`. See state
|
||||
storage for more information.
|
||||
|
||||
If the argument `old_state` is given (in the form of a list of
|
||||
events), then they are used as a the values for `old_state_events` and
|
||||
the value for `state_events` is generated from it. `state_group` is
|
||||
set to None.
|
||||
|
||||
This needs to be called before persisting the event.
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
|
||||
if old_state:
|
||||
event.state_group = None
|
||||
event.old_state_events = {
|
||||
(s.type, s.state_key): s for s in old_state
|
||||
}
|
||||
event.state_events = event.old_state_events
|
||||
|
||||
if hasattr(event, "state_key"):
|
||||
event.state_events[(event.type, event.state_key)] = event
|
||||
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
if event.is_outlier():
|
||||
event.state_group = None
|
||||
event.old_state_events = None
|
||||
event.state_events = None
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
ids = [e for e, _ in event.prev_events]
|
||||
|
||||
ret = yield self.resolve_state_groups(ids)
|
||||
state_group, new_state, _ = ret
|
||||
|
||||
event.old_state_events = copy.deepcopy(new_state)
|
||||
|
||||
if hasattr(event, "state_key"):
|
||||
key = (event.type, event.state_key)
|
||||
if key in new_state:
|
||||
event.replaces_state = new_state[key].event_id
|
||||
new_state[key] = event
|
||||
elif state_group:
|
||||
event.state_group = state_group
|
||||
event.state_events = new_state
|
||||
defer.returnValue(False)
|
||||
|
||||
event.state_group = None
|
||||
event.state_events = new_state
|
||||
|
||||
defer.returnValue(hasattr(event, "state_key"))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
||||
""" Returns the current state for the room as a list. This is done by
|
||||
|
@ -136,7 +71,7 @@ class StateHandler(object):
|
|||
defer.returnValue(res[1].values())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def annotate_context_with_state(self, event, context, old_state=None):
|
||||
def compute_event_context(self, event, old_state=None):
|
||||
""" Fills out the context with the `current state` of the graph. The
|
||||
`current state` here is defined to be the state of the event graph
|
||||
just before the event - i.e. it never includes `event`
|
||||
|
@ -146,8 +81,11 @@ class StateHandler(object):
|
|||
|
||||
Args:
|
||||
event (EventBase)
|
||||
context (EventContext)
|
||||
Returns:
|
||||
an EventContext
|
||||
"""
|
||||
context = EventContext()
|
||||
|
||||
yield run_on_reactor()
|
||||
|
||||
if old_state:
|
||||
|
@ -173,7 +111,8 @@ class StateHandler(object):
|
|||
if replaces.event_id != event.event_id: # Paranoia check
|
||||
event.unsigned["replaces_state"] = replaces.event_id
|
||||
|
||||
defer.returnValue([])
|
||||
context.prev_state_events = []
|
||||
defer.returnValue(context)
|
||||
|
||||
if event.is_state():
|
||||
ret = yield self.resolve_state_groups(
|
||||
|
@ -211,7 +150,8 @@ class StateHandler(object):
|
|||
else:
|
||||
context.auth_events = {}
|
||||
|
||||
defer.returnValue(prev_state)
|
||||
context.prev_state_events = prev_state
|
||||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
|
|
@ -417,86 +417,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
],
|
||||
)
|
||||
|
||||
def snapshot_room(self, event):
|
||||
"""Snapshot the room for an update by a user
|
||||
Args:
|
||||
room_id (synapse.types.RoomId): The room to snapshot.
|
||||
user_id (synapse.types.UserId): The user to snapshot the room for.
|
||||
state_type (str): Optional state type to snapshot.
|
||||
state_key (str): Optional state key to snapshot.
|
||||
Returns:
|
||||
synapse.storage.Snapshot: A snapshot of the state of the room.
|
||||
"""
|
||||
def _snapshot(txn):
|
||||
prev_events = self._get_latest_events_in_room(
|
||||
txn,
|
||||
event.room_id
|
||||
)
|
||||
|
||||
prev_state = None
|
||||
state_key = None
|
||||
if hasattr(event, "state_key"):
|
||||
state_key = event.state_key
|
||||
prev_state = self._get_latest_state_in_room(
|
||||
txn,
|
||||
event.room_id,
|
||||
type=event.type,
|
||||
state_key=state_key,
|
||||
)
|
||||
|
||||
return Snapshot(
|
||||
store=self,
|
||||
room_id=event.room_id,
|
||||
user_id=event.user_id,
|
||||
prev_events=prev_events,
|
||||
prev_state=prev_state,
|
||||
state_type=event.type,
|
||||
state_key=state_key,
|
||||
)
|
||||
|
||||
return self.runInteraction("snapshot_room", _snapshot)
|
||||
|
||||
|
||||
class Snapshot(object):
|
||||
"""Snapshot of the state of a room
|
||||
Args:
|
||||
store (DataStore): The datastore.
|
||||
room_id (RoomId): The room of the snapshot.
|
||||
user_id (UserId): The user this snapshot is for.
|
||||
prev_events (list): The list of event ids this snapshot is after.
|
||||
membership_state (RoomMemberEvent): The current state of the user in
|
||||
the room.
|
||||
state_type (str, optional): State type captured by the snapshot
|
||||
state_key (str, optional): State key captured by the snapshot
|
||||
prev_state_pdu (PduEntry, optional): pdu id of
|
||||
the previous value of the state type and key in the room.
|
||||
"""
|
||||
|
||||
def __init__(self, store, room_id, user_id, prev_events,
|
||||
prev_state, state_type=None, state_key=None):
|
||||
self.store = store
|
||||
self.room_id = room_id
|
||||
self.user_id = user_id
|
||||
self.prev_events = prev_events
|
||||
self.prev_state = prev_state
|
||||
self.state_type = state_type
|
||||
self.state_key = state_key
|
||||
|
||||
def fill_out_prev_events(self, event):
|
||||
if not hasattr(event, "prev_events"):
|
||||
event.prev_events = [
|
||||
(event_id, hashes)
|
||||
for event_id, hashes, _ in self.prev_events
|
||||
]
|
||||
|
||||
if self.prev_events:
|
||||
event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
|
||||
else:
|
||||
event.depth = 0
|
||||
|
||||
if not hasattr(event, "prev_state") and self.prev_state is not None:
|
||||
event.prev_state = self.prev_state
|
||||
|
||||
|
||||
def schema_path(schema):
|
||||
""" Get a filesystem path for the named database schema
|
||||
|
|
|
@ -34,7 +34,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
self.state_handler = NonCallableMock(spec_set=[
|
||||
"annotate_context_with_state",
|
||||
"compute_event_context",
|
||||
])
|
||||
|
||||
self.auth = NonCallableMock(spec_set=[
|
||||
|
@ -91,11 +91,12 @@ class FederationTestCase(unittest.TestCase):
|
|||
self.datastore.get_room.return_value = defer.succeed(True)
|
||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
||||
|
||||
def annotate(ev, context, old_state=None):
|
||||
def annotate(ev, old_state=None):
|
||||
context = Mock()
|
||||
context.current_state = {}
|
||||
context.auth_events = {}
|
||||
return defer.succeed(False)
|
||||
self.state_handler.annotate_context_with_state.side_effect = annotate
|
||||
return defer.succeed(context)
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
yield self.handlers.federation_handler.on_receive_pdu(
|
||||
"fo", pdu, False
|
||||
|
@ -109,15 +110,12 @@ class FederationTestCase(unittest.TestCase):
|
|||
context=ANY,
|
||||
)
|
||||
|
||||
self.state_handler.annotate_context_with_state.assert_called_once_with(
|
||||
ANY,
|
||||
ANY,
|
||||
old_state=None,
|
||||
self.state_handler.compute_event_context.assert_called_once_with(
|
||||
ANY, old_state=None,
|
||||
)
|
||||
|
||||
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
||||
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
ANY,
|
||||
extra_users=[]
|
||||
ANY, extra_users=[]
|
||||
)
|
||||
|
|
|
@ -60,7 +60,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
"check_host_in_room",
|
||||
]),
|
||||
state_handler=NonCallableMock(spec_set=[
|
||||
"annotate_context_with_state",
|
||||
"compute_event_context",
|
||||
"get_current_state",
|
||||
]),
|
||||
config=self.mock_config,
|
||||
|
@ -110,7 +110,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
defer.succeed([])
|
||||
)
|
||||
|
||||
def annotate(_, ctx):
|
||||
def annotate(_):
|
||||
ctx = Mock()
|
||||
ctx.current_state = {
|
||||
(EventTypes.Member, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
|
@ -121,10 +122,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
room_id=room_id,
|
||||
),
|
||||
}
|
||||
ctx.prev_state_events = []
|
||||
|
||||
return defer.succeed(True)
|
||||
return defer.succeed(ctx)
|
||||
|
||||
self.state_handler.annotate_context_with_state.side_effect = annotate
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
def add_auth(_, ctx):
|
||||
ctx.auth_events = ctx.current_state[
|
||||
|
@ -146,8 +148,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
|
||||
yield room_handler.change_membership(event, context)
|
||||
|
||||
self.state_handler.annotate_context_with_state.assert_called_once_with(
|
||||
builder, context
|
||||
self.state_handler.compute_event_context.assert_called_once_with(
|
||||
builder
|
||||
)
|
||||
|
||||
self.auth.add_auth_events.assert_called_once_with(
|
||||
|
@ -189,7 +191,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
defer.succeed([])
|
||||
)
|
||||
|
||||
def annotate(_, ctx):
|
||||
def annotate(_):
|
||||
ctx = Mock()
|
||||
ctx.current_state = {
|
||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
|
@ -197,10 +200,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
membership=Membership.INVITE
|
||||
),
|
||||
}
|
||||
ctx.prev_state_events = []
|
||||
|
||||
return defer.succeed(True)
|
||||
return defer.succeed(ctx)
|
||||
|
||||
self.state_handler.annotate_context_with_state.side_effect = annotate
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
def add_auth(_, ctx):
|
||||
ctx.auth_events = ctx.current_state[
|
||||
|
@ -262,7 +266,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
defer.succeed([])
|
||||
)
|
||||
|
||||
def annotate(_, ctx):
|
||||
def annotate(_):
|
||||
ctx = Mock()
|
||||
ctx.current_state = {
|
||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
|
@ -270,10 +275,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
membership=Membership.JOIN
|
||||
),
|
||||
}
|
||||
ctx.prev_state_events = []
|
||||
|
||||
return defer.succeed(True)
|
||||
return defer.succeed(ctx)
|
||||
|
||||
self.state_handler.annotate_context_with_state.side_effect = annotate
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
def add_auth(_, ctx):
|
||||
ctx.auth_events = ctx.current_state[
|
||||
|
|
|
@ -26,6 +26,7 @@ class StateTestCase(unittest.TestCase):
|
|||
self.store = Mock(
|
||||
spec_set=[
|
||||
"get_state_groups",
|
||||
"add_event_hashes",
|
||||
]
|
||||
)
|
||||
hs = Mock(spec=["get_datastore"])
|
||||
|
@ -44,17 +45,20 @@ class StateTestCase(unittest.TestCase):
|
|||
self.create_event(type="test2", state_key=""),
|
||||
]
|
||||
|
||||
yield self.state.annotate_event_with_state(event, old_state=old_state)
|
||||
context = yield self.state.compute_event_context(
|
||||
event, old_state=old_state
|
||||
)
|
||||
|
||||
for k, v in event.old_state_events.items():
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(set(old_state), set(event.old_state_events.values()))
|
||||
self.assertDictEqual(event.old_state_events, event.state_events)
|
||||
self.assertEqual(
|
||||
set(old_state), set(context.current_state.values())
|
||||
)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_annotate_with_old_state(self):
|
||||
|
@ -66,21 +70,21 @@ class StateTestCase(unittest.TestCase):
|
|||
self.create_event(type="test2", state_key=""),
|
||||
]
|
||||
|
||||
yield self.state.annotate_event_with_state(event, old_state=old_state)
|
||||
context = yield self.state.compute_event_context(
|
||||
event, old_state=old_state
|
||||
)
|
||||
|
||||
for k, v in event.old_state_events.items():
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set(old_state + [event]),
|
||||
set(event.old_state_events.values())
|
||||
set(old_state),
|
||||
set(context.current_state.values())
|
||||
)
|
||||
|
||||
self.assertDictEqual(event.old_state_events, event.state_events)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_trivial_annotate_message(self):
|
||||
|
@ -99,30 +103,19 @@ class StateTestCase(unittest.TestCase):
|
|||
group_name: old_state,
|
||||
}
|
||||
|
||||
yield self.state.annotate_event_with_state(event)
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
for k, v in event.old_state_events.items():
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set([e.event_id for e in event.old_state_events.values()])
|
||||
set([e.event_id for e in context.current_state.values()])
|
||||
)
|
||||
|
||||
self.assertDictEqual(
|
||||
{
|
||||
k: v.event_id
|
||||
for k, v in event.old_state_events.items()
|
||||
},
|
||||
{
|
||||
k: v.event_id
|
||||
for k, v in event.state_events.items()
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(group_name, event.state_group)
|
||||
self.assertEqual(group_name, context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_trivial_annotate_state(self):
|
||||
|
@ -141,38 +134,19 @@ class StateTestCase(unittest.TestCase):
|
|||
group_name: old_state,
|
||||
}
|
||||
|
||||
yield self.state.annotate_event_with_state(event)
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
for k, v in event.old_state_events.items():
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set([e.event_id for e in event.old_state_events.values()])
|
||||
set([e.event_id for e in context.current_state.values()])
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state] + [event.event_id]),
|
||||
set([e.event_id for e in event.state_events.values()])
|
||||
)
|
||||
|
||||
new_state = {
|
||||
k: v.event_id
|
||||
for k, v in event.state_events.items()
|
||||
}
|
||||
old_state = {
|
||||
k: v.event_id
|
||||
for k, v in event.old_state_events.items()
|
||||
}
|
||||
old_state[(event.type, event.state_key)] = event.event_id
|
||||
self.assertDictEqual(
|
||||
old_state,
|
||||
new_state
|
||||
)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_resolve_message_conflict(self):
|
||||
|
@ -199,16 +173,11 @@ class StateTestCase(unittest.TestCase):
|
|||
group_name_2: old_state_2,
|
||||
}
|
||||
|
||||
yield self.state.annotate_event_with_state(event)
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
self.assertEqual(len(event.old_state_events), 5)
|
||||
self.assertEqual(len(context.current_state), 5)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in event.state_events.values()]),
|
||||
set([e.event_id for e in event.old_state_events.values()])
|
||||
)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_resolve_state_conflict(self):
|
||||
|
@ -235,19 +204,11 @@ class StateTestCase(unittest.TestCase):
|
|||
group_name_2: old_state_2,
|
||||
}
|
||||
|
||||
yield self.state.annotate_event_with_state(event)
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
self.assertEqual(len(event.old_state_events), 5)
|
||||
self.assertEqual(len(context.current_state), 5)
|
||||
|
||||
expected_new = event.old_state_events
|
||||
expected_new[(event.type, event.state_key)] = event
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in expected_new.values()]),
|
||||
set([e.event_id for e in event.state_events.values()]),
|
||||
)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
def create_event(self, name=None, type=None, state_key=None):
|
||||
self.event_id += 1
|
||||
|
@ -266,6 +227,9 @@ class StateTestCase(unittest.TestCase):
|
|||
event.state_key = state_key
|
||||
event.event_id = event_id
|
||||
|
||||
event.is_state = lambda: (state_key is not None)
|
||||
event.unsigned = {}
|
||||
|
||||
event.user_id = "@user_id:example.com"
|
||||
event.room_id = "!room_id:example.com"
|
||||
|
||||
|
|
Loading…
Reference in a new issue