0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-26 06:28:20 +02:00

Port to use state storage

This commit is contained in:
Erik Johnston 2019-10-23 17:25:54 +01:00
parent 5db03535d5
commit 69f0054ce6
19 changed files with 216 additions and 115 deletions

View file

@ -30,6 +30,9 @@ class AdminHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(AdminHandler, self).__init__(hs) super(AdminHandler, self).__init__(hs)
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks @defer.inlineCallbacks
def get_whois(self, user): def get_whois(self, user):
connections = [] connections = []
@ -205,7 +208,7 @@ class AdminHandler(BaseHandler):
from_key = events[-1].internal_metadata.after from_key = events[-1].internal_metadata.after
events = yield filter_events_for_client(self.store, user_id, events) events = yield filter_events_for_client(self.storage, user_id, events)
writer.write_events(room_id, events) writer.write_events(room_id, events)
@ -241,7 +244,7 @@ class AdminHandler(BaseHandler):
for event_id in extremities: for event_id in extremities:
if not event_to_unseen_prevs[event_id]: if not event_to_unseen_prevs[event_id]:
continue continue
state = yield self.store.get_state_for_event(event_id) state = yield self.state_store.get_state_for_event(event_id)
writer.write_state(room_id, event_id, state) writer.write_state(room_id, event_id, state)
return writer.finished() return writer.finished()

View file

@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler):
self.hs = hs self.hs = hs
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.state_store = hs.get_storage().state
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@trace @trace
@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler):
continue continue
# mapping from event_id -> state_dict # mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to # Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users. # the "possibly changed" users.

View file

@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler): class EventHandler(BaseHandler):
def __init__(self, hs):
super(EventHandler, self).__init__(hs)
self.storage = hs.get_storage()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, user, room_id, event_id): def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event. """Retrieve a single specified event.
@ -172,7 +176,7 @@ class EventHandler(BaseHandler):
is_peeking = user.to_string() not in users is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client( filtered = yield filter_events_for_client(
self.store, user.to_string(), [event], is_peeking=is_peeking self.storage, user.to_string(), [event], is_peeking=is_peeking
) )
if not filtered: if not filtered:

View file

@ -110,6 +110,7 @@ class FederationHandler(BaseHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state
self.federation_client = hs.get_federation_client() self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
@ -325,7 +326,7 @@ class FederationHandler(BaseHandler):
event_map = {event_id: pdu} event_map = {event_id: pdu}
try: try:
# Get the state of the events we know about # Get the state of the events we know about
ours = yield self.store.get_state_groups_ids(room_id, seen) ours = yield self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id # state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list( state_maps = list(
@ -889,7 +890,7 @@ class FederationHandler(BaseHandler):
# We set `check_history_visibility_only` as we might otherwise get false # We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased. # positives from users having been erased.
filtered_extremities = yield filter_events_for_server( filtered_extremities = yield filter_events_for_server(
self.store, self.storage,
self.server_name, self.server_name,
list(extremities_events.values()), list(extremities_events.values()),
redact=False, redact=False,
@ -1550,7 +1551,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id event_id, allow_none=False, check_room_id=room_id
) )
state_groups = yield self.store.get_state_groups(room_id, [event_id]) state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
if state_groups: if state_groups:
_, state = list(iteritems(state_groups)).pop() _, state = list(iteritems(state_groups)).pop()
@ -1579,7 +1580,7 @@ class FederationHandler(BaseHandler):
event_id, allow_none=False, check_room_id=room_id event_id, allow_none=False, check_room_id=room_id
) )
state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups: if state_groups:
_, state = list(state_groups.items()).pop() _, state = list(state_groups.items()).pop()
@ -1607,7 +1608,7 @@ class FederationHandler(BaseHandler):
events = yield self.store.get_backfill_events(room_id, pdu_list, limit) events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
events = yield filter_events_for_server(self.store, origin, events) events = yield filter_events_for_server(self.storage, origin, events)
return events return events
@ -1637,7 +1638,7 @@ class FederationHandler(BaseHandler):
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
events = yield filter_events_for_server(self.store, origin, [event]) events = yield filter_events_for_server(self.storage, origin, [event])
event = events[0] event = events[0]
return event return event
else: else:
@ -1903,7 +1904,7 @@ class FederationHandler(BaseHandler):
# given state at the event. This should correctly handle cases # given state at the event. This should correctly handle cases
# like bans, especially with state res v2. # like bans, especially with state res v2.
state_sets = yield self.store.get_state_groups( state_sets = yield self.state_store.get_state_groups(
event.room_id, extrem_ids event.room_id, extrem_ids
) )
state_sets = list(state_sets.values()) state_sets = list(state_sets.values())
@ -1994,7 +1995,7 @@ class FederationHandler(BaseHandler):
) )
missing_events = yield filter_events_for_server( missing_events = yield filter_events_for_server(
self.store, origin, missing_events self.storage, origin, missing_events
) )
return missing_events return missing_events
@ -2235,7 +2236,7 @@ class FederationHandler(BaseHandler):
# create a new state group as a delta from the existing one. # create a new state group as a delta from the existing one.
prev_group = context.state_group prev_group = context.state_group
state_group = yield self.store.store_state_group( state_group = yield self.state_store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=prev_group, prev_group=prev_group,

View file

@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler):
self.validator = EventValidator() self.validator = EventValidator()
self.snapshot_cache = SnapshotCache() self.snapshot_cache = SnapshotCache()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
def snapshot_all_rooms( def snapshot_all_rooms(
self, self,
@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler):
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,) room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background( deferred_room_state = run_in_background(
self.store.get_state_for_events, [event.event_id] self.state_store.get_state_for_events, [event.event_id]
) )
deferred_room_state.addCallback( deferred_room_state.addCallback(
lambda states: states[event.event_id] lambda states: states[event.event_id]
@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler):
) )
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(self.store, user_id, messages) messages = yield filter_events_for_client(
self.storage, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token) start_token = now_token.copy_and_replace("room_key", token)
end_token = now_token.copy_and_replace("room_key", room_end_token) end_token = now_token.copy_and_replace("room_key", room_end_token)
@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler):
def _room_initial_sync_parted( def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
): ):
room_state = yield self.store.get_state_for_events([member_event_id]) room_state = yield self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id] room_state = room_state[member_event_id]
@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler):
) )
messages = yield filter_events_for_client( messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking self.storage, user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken.START.copy_and_replace("room_key", token) start_token = StreamToken.START.copy_and_replace("room_key", token)
@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler):
) )
messages = yield filter_events_for_client( messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking self.storage, user_id, messages, is_peeking=is_peeking
) )
start_token = now_token.copy_and_replace("room_key", token) start_token = now_token.copy_and_replace("room_key", token)

View file

@ -59,6 +59,8 @@ class MessageHandler(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks @defer.inlineCallbacks
@ -82,7 +84,7 @@ class MessageHandler(object):
data = yield self.state.get_current_state(room_id, event_type, state_key) data = yield self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
room_state = yield self.store.get_state_for_events( room_state = yield self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key]) [membership_event_id], StateFilter.from_types([key])
) )
data = room_state[membership_event_id].get(key) data = room_state[membership_event_id].get(key)
@ -135,12 +137,12 @@ class MessageHandler(object):
raise NotFoundError("Can't find event for token %s" % (at_token,)) raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client( visible_events = yield filter_events_for_client(
self.store, user_id, last_events self.storage, user_id, last_events
) )
event = last_events[0] event = last_events[0]
if visible_events: if visible_events:
room_state = yield self.store.get_state_for_events( room_state = yield self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter [event.event_id], state_filter=state_filter
) )
room_state = room_state[event.event_id] room_state = room_state[event.event_id]
@ -161,7 +163,7 @@ class MessageHandler(object):
) )
room_state = yield self.store.get_events(state_ids.values()) room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events( room_state = yield self.state_store.get_state_for_events(
[membership_event_id], state_filter=state_filter [membership_event_id], state_filter=state_filter
) )
room_state = room_state[membership_event_id] room_state = room_state[membership_event_id]

View file

@ -69,6 +69,8 @@ class PaginationHandler(object):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._server_name = hs.hostname self._server_name = hs.hostname
@ -255,7 +257,7 @@ class PaginationHandler(object):
events = event_filter.filter(events) events = event_filter.filter(events)
events = yield filter_events_for_client( events = yield filter_events_for_client(
self.store, user_id, events, is_peeking=(member_event_id is None) self.storage, user_id, events, is_peeking=(member_event_id is None)
) )
if not events: if not events:
@ -274,7 +276,7 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events (EventTypes.Member, event.sender) for event in events
) )
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter events[0].event_id, state_filter=state_filter
) )

View file

@ -822,6 +822,8 @@ class RoomContextHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, event_filter): def get_event_context(self, user, room_id, event_id, limit, event_filter):
@ -848,7 +850,7 @@ class RoomContextHandler(object):
def filter_evts(events): def filter_evts(events):
return filter_events_for_client( return filter_events_for_client(
self.store, user.to_string(), events, is_peeking=is_peeking self.storage, user.to_string(), events, is_peeking=is_peeking
) )
event = yield self.store.get_event( event = yield self.store.get_event(
@ -890,7 +892,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync? # first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687 # https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events( state = yield self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter [last_event_id], state_filter=state_filter
) )
results["state"] = list(state[last_event_id].values()) results["state"] = list(state[last_event_id].values())

View file

@ -35,6 +35,8 @@ class SearchHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(SearchHandler, self).__init__(hs) super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@defer.inlineCallbacks @defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id): def get_old_rooms_from_upgraded_room(self, room_id):
@ -221,7 +223,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results]) filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client( events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events self.storage, user.to_string(), filtered_events
) )
events.sort(key=lambda e: -rank_map[e.event_id]) events.sort(key=lambda e: -rank_map[e.event_id])
@ -271,7 +273,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results]) filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client( events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events self.storage, user.to_string(), filtered_events
) )
room_events.extend(events) room_events.extend(events)
@ -340,11 +342,11 @@ class SearchHandler(BaseHandler):
) )
res["events_before"] = yield filter_events_for_client( res["events_before"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_before"] self.storage, user.to_string(), res["events_before"]
) )
res["events_after"] = yield filter_events_for_client( res["events_after"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_after"] self.storage, user.to_string(), res["events_after"]
) )
res["start"] = now_token.copy_and_replace( res["start"] = now_token.copy_and_replace(
@ -372,7 +374,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders] [(EventTypes.Member, sender) for sender in senders]
) )
state = yield self.store.get_state_for_event( state = yield self.state_store.get_state_for_event(
last_event_id, state_filter last_event_id, state_filter
) )

View file

@ -230,6 +230,8 @@ class SyncHandler(object):
self.response_cache = ResponseCache(hs, "sync") self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.storage = hs.get_storage()
self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id) # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache( self.lazy_loaded_members_cache = ExpiringCache(
@ -417,7 +419,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids)) current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client( recents = yield filter_events_for_client(
self.store, self.storage,
sync_config.user.to_string(), sync_config.user.to_string(),
recents, recents,
always_include_ids=current_state_ids, always_include_ids=current_state_ids,
@ -470,7 +472,7 @@ class SyncHandler(object):
current_state_ids = frozenset(itervalues(current_state_ids)) current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client( loaded_recents = yield filter_events_for_client(
self.store, self.storage,
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
always_include_ids=current_state_ids, always_include_ids=current_state_ids,
@ -509,7 +511,7 @@ class SyncHandler(object):
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter event.event_id, state_filter=state_filter
) )
if event.is_state(): if event.is_state():
@ -580,7 +582,7 @@ class SyncHandler(object):
return None return None
last_event = last_events[-1] last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
last_event.event_id, last_event.event_id,
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@ -757,11 +759,11 @@ class SyncHandler(object):
if full_state: if full_state:
if batch: if batch:
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id, state_filter=state_filter
) )
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id, state_filter=state_filter
) )
@ -781,7 +783,7 @@ class SyncHandler(object):
) )
elif batch.limited: elif batch.limited:
if batch: if batch:
state_at_timeline_start = yield self.store.get_state_ids_for_event( state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id, state_filter=state_filter
) )
else: else:
@ -810,7 +812,7 @@ class SyncHandler(object):
) )
if batch: if batch:
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id, state_filter=state_filter
) )
else: else:
@ -841,7 +843,7 @@ class SyncHandler(object):
# So we fish out all the member events corresponding to the # So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below. # timeline here, and then dedupe any redundant ones below.
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.state_store.get_state_ids_for_event(
batch.events[0].event_id, batch.events[0].event_id,
# we only want members! # we only want members!
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(

View file

@ -159,6 +159,7 @@ class Notifier(object):
self.room_to_user_streams = {} self.room_to_user_streams = {}
self.hs = hs self.hs = hs
self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.pending_new_room_events = [] self.pending_new_room_events = []
@ -425,7 +426,10 @@ class Notifier(object):
if name == "room": if name == "room":
new_events = yield filter_events_for_client( new_events = yield filter_events_for_client(
self.store, user.to_string(), new_events, is_peeking=is_peeking self.storage,
user.to_string(),
new_events,
is_peeking=is_peeking,
) )
elif name == "presence": elif name == "presence":
now = self.clock.time_msec() now = self.clock.time_msec()

View file

@ -64,6 +64,7 @@ class HttpPusher(object):
def __init__(self, hs, pusherdict): def __init__(self, hs, pusherdict):
self.hs = hs self.hs = hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict["user_name"] self.user_id = pusherdict["user_name"]
@ -329,7 +330,7 @@ class HttpPusher(object):
return d return d
ctx = yield push_tools.get_context_for_event( ctx = yield push_tools.get_context_for_event(
self.store, self.state_handler, event, self.user_id self.storage, self.state_handler, event, self.user_id
) )
d = { d = {

View file

@ -119,6 +119,7 @@ class Mailer(object):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator() self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
self.app_name = app_name self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name) logger.info("Created Mailer for app_name %s" % app_name)
@ -389,7 +390,7 @@ class Mailer(object):
} }
the_events = yield filter_events_for_client( the_events = yield filter_events_for_client(
self.store, user_id, results["events_before"] self.storage, user_id, results["events_before"]
) )
the_events.append(notif_event) the_events.append(notif_event)

View file

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
@defer.inlineCallbacks @defer.inlineCallbacks
@ -43,22 +44,22 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_context_for_event(store, state_handler, ev, user_id): def get_context_for_event(storage: Storage, state_handler, ev, user_id):
ctx = {} ctx = {}
room_state_ids = yield store.get_state_ids_for_event(ev.event_id) room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the # we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or # human-readable name instead, be that m.room.name, an alias or
# a list of people in the room # a list of people in the room
name = yield calculate_room_name( name = yield calculate_room_name(
store, room_state_ids, user_id, fallback_to_single_member=False storage.main, room_state_ids, user_id, fallback_to_single_member=False
) )
if name: if name:
ctx["name"] = name ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield store.get_event(sender_state_event_id) sender_state_event = yield storage.main.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event) ctx["sender_display_name"] = name_from_member_event(sender_state_event)
return ctx return ctx

View file

@ -103,6 +103,7 @@ class StateHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state_store = hs.get_storage().state
self.hs = hs self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
@ -271,7 +272,7 @@ class StateHandler(object):
else: else:
current_state_ids = prev_state_ids current_state_ids = prev_state_ids
state_group = yield self.store.store_state_group( state_group = yield self.state_store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=None, prev_group=None,
@ -321,7 +322,7 @@ class StateHandler(object):
delta_ids = dict(entry.delta_ids) delta_ids = dict(entry.delta_ids)
delta_ids[key] = event.event_id delta_ids[key] = event.event_id
state_group = yield self.store.store_state_group( state_group = yield self.state_store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=prev_group, prev_group=prev_group,
@ -334,7 +335,7 @@ class StateHandler(object):
delta_ids = entry.delta_ids delta_ids = entry.delta_ids
if entry.state_group is None: if entry.state_group is None:
entry.state_group = yield self.store.store_state_group( entry.state_group = yield self.state_store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=entry.prev_group, prev_group=entry.prev_group,
@ -376,14 +377,16 @@ class StateHandler(object):
# map from state group id to the state in that state group (where # map from state group id to the state in that state group (where
# 'state' is a map from state key to event id) # 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]] # dict[int, dict[(str, str), str]]
state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) state_groups_ids = yield self.state_store.get_state_groups_ids(
room_id, event_ids
)
if len(state_groups_ids) == 0: if len(state_groups_ids) == 0:
return _StateCacheEntry(state={}, state_group=None) return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1: elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop() name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
return _StateCacheEntry( return _StateCacheEntry(
state=state_list, state=state_list,

View file

@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.storage import Storage
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks @defer.inlineCallbacks
def filter_events_for_client( def filter_events_for_client(
store, user_id, events, is_peeking=False, always_include_ids=frozenset() storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset()
): ):
""" """
Check which events a user is allowed to see Check which events a user is allowed to see
Args: Args:
store (synapse.storage.DataStore): our datastore (can also be a worker storage
store)
user_id(str): user id to be checked user_id(str): user id to be checked
events(list[synapse.events.EventBase]): sequence of events to be checked events(list[synapse.events.EventBase]): sequence of events to be checked
is_peeking(bool): should be True if: is_peeking(bool): should be True if:
@ -68,12 +68,12 @@ def filter_events_for_client(
events = list(e for e in events if not e.internal_metadata.is_soft_failed()) events = list(e for e in events if not e.internal_metadata.is_soft_failed())
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
event_id_to_state = yield store.get_state_for_events( event_id_to_state = yield storage.state.get_state_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types), state_filter=StateFilter.from_types(types),
) )
ignore_dict_content = yield store.get_global_account_data_by_type_for_user( ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id "m.ignored_user_list", user_id
) )
@ -84,7 +84,7 @@ def filter_events_for_client(
else [] else []
) )
erased_senders = yield store.are_users_erased((e.sender for e in events)) erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
def allowed(event): def allowed(event):
""" """
@ -213,13 +213,17 @@ def filter_events_for_client(
@defer.inlineCallbacks @defer.inlineCallbacks
def filter_events_for_server( def filter_events_for_server(
store, server_name, events, redact=True, check_history_visibility_only=False storage: Storage,
server_name,
events,
redact=True,
check_history_visibility_only=False,
): ):
"""Filter a list of events based on whether given server is allowed to """Filter a list of events based on whether given server is allowed to
see them. see them.
Args: Args:
store (DataStore) storage
server_name (str) server_name (str)
events (iterable[FrozenEvent]) events (iterable[FrozenEvent])
redact (bool): Whether to return a redacted version of the event, or redact (bool): Whether to return a redacted version of the event, or
@ -274,7 +278,7 @@ def filter_events_for_server(
# Lets check to see if all the events have a history visibility # Lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't # of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room). # need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events( event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""),) types=((EventTypes.RoomHistoryVisibility, ""),)
@ -292,14 +296,14 @@ def filter_events_for_server(
if not visibility_ids: if not visibility_ids:
all_open = True all_open = True
else: else:
event_map = yield store.get_events(visibility_ids) event_map = yield storage.main.get_events(visibility_ids)
all_open = all( all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable") e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in itervalues(event_map) for e in itervalues(event_map)
) )
if not check_history_visibility_only: if not check_history_visibility_only:
erased_senders = yield store.are_users_erased((e.sender for e in events)) erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
else: else:
# We don't want to check whether users are erased, which is equivalent # We don't want to check whether users are erased, which is equivalent
# to no users having been erased. # to no users having been erased.
@ -328,7 +332,7 @@ def filter_events_for_server(
# first, for each event we're wanting to return, get the event_ids # first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events. # of the history vis and membership state at those events.
event_to_state_ids = yield store.get_state_ids_for_events( event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
@ -358,7 +362,7 @@ def filter_events_for_server(
return False return False
return state_key[idx + 1 :] == server_name return state_key[idx + 1 :] == server_name
event_map = yield store.get_events( event_map = yield storage.main.get_events(
[ [
e_id e_id
for e_id, key in iteritems(event_id_to_state_key) for e_id, key in iteritems(event_id_to_state_key)

View file

@ -35,6 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_datastore = self.store
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -83,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield self.store.get_state_groups_ids( state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id] self.room, [e2.event_id]
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
@ -102,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) state_group_map = yield self.storage.state.get_state_groups(
self.room, [e2.event_id]
)
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0] state_list = list(state_group_map.values())[0]
@ -142,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we get the full state as of the final event # check we get the full state as of the final event
state = yield self.store.get_state_for_event(e5.event_id) state = yield self.storage.state.get_state_for_event(e5.event_id)
self.assertIsNotNone(e4) self.assertIsNotNone(e4)
@ -158,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we can filter to the m.room.name event (with a '' state key) # check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
) )
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key) # check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
) )
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key) # check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
) )
@ -182,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the # check we can grab a specific room member without filtering out the
# other event types # other event types
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}}, types={EventTypes.Member: {self.u_alice.to_string()}},
@ -200,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check that we can grab everything except members # check that we can grab everything except members
state = yield self.store.get_state_for_event( state = yield self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -216,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
####################################################### #######################################################
room_id = self.room.to_string() room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) group_ids = yield self.storage.state.get_state_groups_ids(
room_id, [e5.event_id]
)
group = list(group_ids.keys())[0] group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members # test _get_state_for_group_using_cache correctly filters out members
# with types=[] # with types=[]
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -238,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -251,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with wildcard types # with wildcard types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types={EventTypes.Member: None}, include_others=True
@ -268,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types={EventTypes.Member: None}, include_others=True
@ -288,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types={EventTypes.Member: {e5.state_key}}, include_others=True
@ -305,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict, state_dict,
) )
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types={EventTypes.Member: {e5.state_key}}, include_others=True
@ -318,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types={EventTypes.Member: {e5.state_key}}, include_others=False
@ -332,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
####################################################### #######################################################
# deliberately remove e2 (room name) from the _state_group_cache # deliberately remove e2 (room name) from the _state_group_cache
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( (
group is_all,
) known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertEqual(known_absent, set()) self.assertEqual(known_absent, set())
@ -347,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
state_dict_ids.pop((e2.type, e2.state_key)) state_dict_ids.pop((e2.type, e2.state_key))
self.store._state_group_cache.invalidate(group) self.state_datastore._state_group_cache.invalidate(group)
self.store._state_group_cache.update( self.state_datastore._state_group_cache.update(
sequence=self.store._state_group_cache.sequence, sequence=self.state_datastore._state_group_cache.sequence,
key=group, key=group,
value=state_dict_ids, value=state_dict_ids,
# list fetched keys so it knows it's partial # list fetched keys so it knows it's partial
fetched_keys=((e1.type, e1.state_key),), fetched_keys=((e1.type, e1.state_key),),
) )
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( (
group is_all,
) known_absent,
state_dict_ids,
) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
@ -370,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members # test _get_state_for_group_using_cache correctly filters out members
# with types=[] # with types=[]
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -382,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
@ -395,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# wildcard types # wildcard types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types={EventTypes.Member: None}, include_others=True
@ -406,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types={EventTypes.Member: None}, include_others=True
@ -425,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types={EventTypes.Member: {e5.state_key}}, include_others=True
@ -436,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types={EventTypes.Member: {e5.state_key}}, include_others=True
@ -449,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members # test _get_state_for_group_using_cache correctly filters in members
# with specific types # with specific types
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types={EventTypes.Member: {e5.state_key}}, include_others=False
@ -460,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict) self.assertDictEqual({}, state_dict)
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache( (
self.store._state_group_members_cache, state_dict,
is_all,
) = yield self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types={EventTypes.Member: {e5.state_key}}, include_others=False

View file

@ -158,10 +158,12 @@ class Graph(object):
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = StateGroupStore() self.store = StateGroupStore()
storage = Mock(main=self.store, state=self.store)
hs = Mock( hs = Mock(
spec_set=[ spec_set=[
"config", "config",
"get_datastore", "get_datastore",
"get_storage",
"get_auth", "get_auth",
"get_state_handler", "get_state_handler",
"get_clock", "get_clock",
@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
hs.get_storage.return_value = storage
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0

View file

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
import logging import logging
from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
@ -63,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt) events_to_filter.append(evt)
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(
self.store, "test_server", events_to_filter self.storage, "test_server", events_to_filter
) )
# the result should be 5 redacted events, and 5 unredacted events. # the result should be 5 redacted events, and 5 unredacted events.
@ -101,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# ... and the filtering happens. # ... and the filtering happens.
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(
self.store, "test_server", events_to_filter self.storage, "test_server", events_to_filter
) )
for i in range(0, len(events_to_filter)): for i in range(0, len(events_to_filter)):
@ -258,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
logger.info("Starting filtering") logger.info("Starting filtering")
start = time.time() start = time.time()
storage = Mock()
storage.main = test_store
storage.state = test_store
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter test_store, "test_server", events_to_filter
) )