Merge pull request #6294 from matrix-org/erikj/add_state_storage

Add StateGroupStorage interface

Commit dfe0cd71b6
23 changed files with 453 additions and 116 deletions
changelog.d/6294.misc (new file)
@@ -0,0 +1 @@
+Split out state storage into separate data store.

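The change splits state-group access out of the monolithic datastore behind a storage facade. As a rough orientation before the per-file hunks below, here is a minimal, runnable sketch of the shape callers now program against; the classes here are hypothetical stand-ins, not Synapse's real implementations.

# Hypothetical stand-ins illustrating the new storage split; the real classes
# live under synapse/storage/ and are asynchronous.
class FakeMainStore:
    """Plays the role of the main data store (events, account data, ...)."""
    def get_event(self, event_id):
        return {"event_id": event_id, "type": "m.room.message"}

class FakeStateGroupStorage:
    """Plays the role of the new StateGroupStorage interface."""
    def get_state_ids_for_event(self, event_id):
        return {("m.room.name", ""): "$name-event-id"}

class Storage:
    """Facade bundling the per-concern storage interfaces, as callers see it."""
    def __init__(self, main, state):
        self.main = main    # classic datastore methods
        self.state = state  # state-group methods

storage = Storage(FakeMainStore(), FakeStateGroupStorage())
print(storage.state.get_state_ids_for_event("$evt"))
print(storage.main.get_event("$evt"))
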
@@ -30,6 +30,9 @@ class AdminHandler(BaseHandler):
 def __init__(self, hs):
     super(AdminHandler, self).__init__(hs)
 
+    self.storage = hs.get_storage()
+    self.state_store = self.storage.state
+
 @defer.inlineCallbacks
 def get_whois(self, user):
     connections = []
@@ -205,7 +208,7 @@ class AdminHandler(BaseHandler):
 
 from_key = events[-1].internal_metadata.after
 
-events = yield filter_events_for_client(self.store, user_id, events)
+events = yield filter_events_for_client(self.storage, user_id, events)
 
 writer.write_events(room_id, events)
 
@@ -241,7 +244,7 @@ class AdminHandler(BaseHandler):
 for event_id in extremities:
     if not event_to_unseen_prevs[event_id]:
         continue
-    state = yield self.store.get_state_for_event(event_id)
+    state = yield self.state_store.get_state_for_event(event_id)
     writer.write_state(room_id, event_id, state)
 
 return writer.finished()

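The handler hunks above and below all follow the same wiring pattern: fetch the storage facade from the homeserver object and keep a shortcut to its state interface. A minimal sketch using a mocked hs (the real object is Synapse's HomeServer; the mock is an assumption to keep this runnable):

from unittest.mock import Mock

# Mocked homeserver: get_storage() returns an object with .state and .main,
# mirroring the wiring the diff adds to AdminHandler.__init__.
hs = Mock()
hs.get_storage.return_value = Mock(state=Mock(), main=Mock())

class ExampleHandler:
    def __init__(self, hs):
        self.storage = hs.get_storage()
        self.state_store = self.storage.state  # state-group queries go here

handler = ExampleHandler(hs)
handler.state_store.get_state_for_event("$event")  # resolves against the mock
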
@@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler):
 
 self.hs = hs
 self.state = hs.get_state_handler()
+self.state_store = hs.get_storage().state
 self._auth_handler = hs.get_auth_handler()
 
 @trace
@@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler):
 continue
 
 # mapping from event_id -> state_dict
-prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
 
 # Check if we've joined the room? If so we just blindly add all the users to
 # the "possibly changed" users.

@@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler):
 
 
 class EventHandler(BaseHandler):
+    def __init__(self, hs):
+        super(EventHandler, self).__init__(hs)
+        self.storage = hs.get_storage()
+
     @defer.inlineCallbacks
     def get_event(self, user, room_id, event_id):
         """Retrieve a single specified event.
@@ -172,7 +176,7 @@ class EventHandler(BaseHandler):
 is_peeking = user.to_string() not in users
 
 filtered = yield filter_events_for_client(
-    self.store, user.to_string(), [event], is_peeking=is_peeking
+    self.storage, user.to_string(), [event], is_peeking=is_peeking
 )
 
 if not filtered:

@@ -110,6 +110,7 @@ class FederationHandler(BaseHandler):
 
 self.store = hs.get_datastore()
 self.storage = hs.get_storage()
+self.state_store = self.storage.state
 self.federation_client = hs.get_federation_client()
 self.state_handler = hs.get_state_handler()
 self.server_name = hs.hostname
@@ -325,7 +326,7 @@ class FederationHandler(BaseHandler):
 event_map = {event_id: pdu}
 try:
     # Get the state of the events we know about
-    ours = yield self.store.get_state_groups_ids(room_id, seen)
+    ours = yield self.state_store.get_state_groups_ids(room_id, seen)
 
     # state_maps is a list of mappings from (type, state_key) to event_id
     state_maps = list(
@@ -891,7 +892,7 @@ class FederationHandler(BaseHandler):
 # We set `check_history_visibility_only` as we might otherwise get false
 # positives from users having been erased.
 filtered_extremities = yield filter_events_for_server(
-    self.store,
+    self.storage,
     self.server_name,
     list(extremities_events.values()),
     redact=False,
@@ -1552,7 +1553,7 @@ class FederationHandler(BaseHandler):
     event_id, allow_none=False, check_room_id=room_id
 )
 
-state_groups = yield self.store.get_state_groups(room_id, [event_id])
+state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
 
 if state_groups:
     _, state = list(iteritems(state_groups)).pop()
@@ -1581,7 +1582,7 @@ class FederationHandler(BaseHandler):
     event_id, allow_none=False, check_room_id=room_id
 )
 
-state_groups = yield self.store.get_state_groups_ids(room_id, [event_id])
+state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
 
 if state_groups:
     _, state = list(state_groups.items()).pop()
@@ -1609,7 +1610,7 @@ class FederationHandler(BaseHandler):
 
 events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
 
-events = yield filter_events_for_server(self.store, origin, events)
+events = yield filter_events_for_server(self.storage, origin, events)
 
 return events
 
@@ -1639,7 +1640,7 @@ class FederationHandler(BaseHandler):
 if not in_room:
     raise AuthError(403, "Host not in room.")
 
-events = yield filter_events_for_server(self.store, origin, [event])
+events = yield filter_events_for_server(self.storage, origin, [event])
 event = events[0]
 return event
 else:
@@ -1907,7 +1908,7 @@ class FederationHandler(BaseHandler):
 # given state at the event. This should correctly handle cases
 # like bans, especially with state res v2.
 
-state_sets = yield self.store.get_state_groups(
+state_sets = yield self.state_store.get_state_groups(
     event.room_id, extrem_ids
 )
 state_sets = list(state_sets.values())
@@ -1998,7 +1999,7 @@ class FederationHandler(BaseHandler):
 )
 
 missing_events = yield filter_events_for_server(
-    self.store, origin, missing_events
+    self.storage, origin, missing_events
 )
 
 return missing_events
@@ -2239,7 +2240,7 @@ class FederationHandler(BaseHandler):
 
 # create a new state group as a delta from the existing one.
 prev_group = context.state_group
-state_group = yield self.store.store_state_group(
+state_group = yield self.state_store.store_state_group(
    event.event_id,
    event.room_id,
    prev_group=prev_group,

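Note that the federation hunks move only state-group reads onto self.state_store; plain event reads such as get_backfill_events stay on self.store. A tiny sketch of that split, with both stores mocked (an assumption to make it self-contained):

from unittest.mock import Mock

store = Mock()        # main datastore: events, backfill, etc. (unchanged calls)
state_store = Mock()  # new StateGroupStorage interface: state groups only

def backfill_sketch(room_id, event_ids):
    events = store.get_backfill_events(room_id, event_ids, limit=10)  # stays on store
    state = state_store.get_state_groups_ids(room_id, event_ids)      # moved
    return events, state

backfill_sketch("!room:example.org", ["$e1"])
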
@@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler):
 self.validator = EventValidator()
 self.snapshot_cache = SnapshotCache()
 self._event_serializer = hs.get_event_client_serializer()
+self.storage = hs.get_storage()
+self.state_store = self.storage.state
 
 def snapshot_all_rooms(
     self,
@@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler):
 elif event.membership == Membership.LEAVE:
     room_end_token = "s%d" % (event.stream_ordering,)
     deferred_room_state = run_in_background(
-        self.store.get_state_for_events, [event.event_id]
+        self.state_store.get_state_for_events, [event.event_id]
     )
     deferred_room_state.addCallback(
         lambda states: states[event.event_id]
@@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler):
     )
 ).addErrback(unwrapFirstError)
 
-messages = yield filter_events_for_client(self.store, user_id, messages)
+messages = yield filter_events_for_client(
+    self.storage, user_id, messages
+)
 
 start_token = now_token.copy_and_replace("room_key", token)
 end_token = now_token.copy_and_replace("room_key", room_end_token)
@@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler):
 def _room_initial_sync_parted(
     self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
 ):
-    room_state = yield self.store.get_state_for_events([member_event_id])
+    room_state = yield self.state_store.get_state_for_events([member_event_id])
 
     room_state = room_state[member_event_id]
 
@@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler):
 )
 
 messages = yield filter_events_for_client(
-    self.store, user_id, messages, is_peeking=is_peeking
+    self.storage, user_id, messages, is_peeking=is_peeking
 )
 
 start_token = StreamToken.START.copy_and_replace("room_key", token)
@@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler):
 )
 
 messages = yield filter_events_for_client(
-    self.store, user_id, messages, is_peeking=is_peeking
+    self.storage, user_id, messages, is_peeking=is_peeking
 )
 
 start_token = now_token.copy_and_replace("room_key", token)

@@ -59,6 +59,8 @@ class MessageHandler(object):
 self.clock = hs.get_clock()
 self.state = hs.get_state_handler()
 self.store = hs.get_datastore()
+self.storage = hs.get_storage()
+self.state_store = self.storage.state
 self._event_serializer = hs.get_event_client_serializer()
 
 @defer.inlineCallbacks
@@ -82,7 +84,7 @@ class MessageHandler(object):
 data = yield self.state.get_current_state(room_id, event_type, state_key)
 elif membership == Membership.LEAVE:
     key = (event_type, state_key)
-    room_state = yield self.store.get_state_for_events(
+    room_state = yield self.state_store.get_state_for_events(
         [membership_event_id], StateFilter.from_types([key])
     )
     data = room_state[membership_event_id].get(key)
@@ -135,12 +137,12 @@ class MessageHandler(object):
 raise NotFoundError("Can't find event for token %s" % (at_token,))
 
 visible_events = yield filter_events_for_client(
-    self.store, user_id, last_events
+    self.storage, user_id, last_events
 )
 
 event = last_events[0]
 if visible_events:
-    room_state = yield self.store.get_state_for_events(
+    room_state = yield self.state_store.get_state_for_events(
         [event.event_id], state_filter=state_filter
     )
     room_state = room_state[event.event_id]
@@ -161,7 +163,7 @@ class MessageHandler(object):
 )
 room_state = yield self.store.get_events(state_ids.values())
 elif membership == Membership.LEAVE:
-    room_state = yield self.store.get_state_for_events(
+    room_state = yield self.state_store.get_state_for_events(
         [membership_event_id], state_filter=state_filter
     )
     room_state = room_state[membership_event_id]

@@ -69,6 +69,8 @@ class PaginationHandler(object):
 self.hs = hs
 self.auth = hs.get_auth()
 self.store = hs.get_datastore()
+self.storage = hs.get_storage()
+self.state_store = self.storage.state
 self.clock = hs.get_clock()
 self._server_name = hs.hostname
 
@@ -255,7 +257,7 @@ class PaginationHandler(object):
 events = event_filter.filter(events)
 
 events = yield filter_events_for_client(
-    self.store, user_id, events, is_peeking=(member_event_id is None)
+    self.storage, user_id, events, is_peeking=(member_event_id is None)
 )
 
 if not events:
@@ -274,7 +276,7 @@ class PaginationHandler(object):
     (EventTypes.Member, event.sender) for event in events
 )
 
-state_ids = yield self.store.get_state_ids_for_event(
+state_ids = yield self.state_store.get_state_ids_for_event(
     events[0].event_id, state_filter=state_filter
 )
 

@@ -822,6 +822,8 @@ class RoomContextHandler(object):
 def __init__(self, hs):
     self.hs = hs
     self.store = hs.get_datastore()
+    self.storage = hs.get_storage()
+    self.state_store = self.storage.state
 
 @defer.inlineCallbacks
 def get_event_context(self, user, room_id, event_id, limit, event_filter):
@@ -848,7 +850,7 @@ class RoomContextHandler(object):
 
 def filter_evts(events):
     return filter_events_for_client(
-        self.store, user.to_string(), events, is_peeking=is_peeking
+        self.storage, user.to_string(), events, is_peeking=is_peeking
     )
 
 event = yield self.store.get_event(
@@ -890,7 +892,7 @@ class RoomContextHandler(object):
 # first? Shouldn't we be consistent with /sync?
 # https://github.com/matrix-org/matrix-doc/issues/687
 
-state = yield self.store.get_state_for_events(
+state = yield self.state_store.get_state_for_events(
     [last_event_id], state_filter=state_filter
 )
 results["state"] = list(state[last_event_id].values())

@@ -35,6 +35,8 @@ class SearchHandler(BaseHandler):
 def __init__(self, hs):
     super(SearchHandler, self).__init__(hs)
     self._event_serializer = hs.get_event_client_serializer()
+    self.storage = hs.get_storage()
+    self.state_store = self.storage.state
 
 @defer.inlineCallbacks
 def get_old_rooms_from_upgraded_room(self, room_id):
@@ -221,7 +223,7 @@ class SearchHandler(BaseHandler):
 filtered_events = search_filter.filter([r["event"] for r in results])
 
 events = yield filter_events_for_client(
-    self.store, user.to_string(), filtered_events
+    self.storage, user.to_string(), filtered_events
 )
 
 events.sort(key=lambda e: -rank_map[e.event_id])
@@ -271,7 +273,7 @@ class SearchHandler(BaseHandler):
 filtered_events = search_filter.filter([r["event"] for r in results])
 
 events = yield filter_events_for_client(
-    self.store, user.to_string(), filtered_events
+    self.storage, user.to_string(), filtered_events
 )
 
 room_events.extend(events)
@@ -340,11 +342,11 @@ class SearchHandler(BaseHandler):
 )
 
 res["events_before"] = yield filter_events_for_client(
-    self.store, user.to_string(), res["events_before"]
+    self.storage, user.to_string(), res["events_before"]
 )
 
 res["events_after"] = yield filter_events_for_client(
-    self.store, user.to_string(), res["events_after"]
+    self.storage, user.to_string(), res["events_after"]
 )
 
 res["start"] = now_token.copy_and_replace(
@@ -372,7 +374,7 @@ class SearchHandler(BaseHandler):
     [(EventTypes.Member, sender) for sender in senders]
 )
 
-state = yield self.store.get_state_for_event(
+state = yield self.state_store.get_state_for_event(
     last_event_id, state_filter
 )
 

@@ -230,6 +230,8 @@ class SyncHandler(object):
 self.response_cache = ResponseCache(hs, "sync")
 self.state = hs.get_state_handler()
 self.auth = hs.get_auth()
+self.storage = hs.get_storage()
+self.state_store = self.storage.state
 
 # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
 self.lazy_loaded_members_cache = ExpiringCache(
@@ -417,7 +419,7 @@ class SyncHandler(object):
 current_state_ids = frozenset(itervalues(current_state_ids))
 
 recents = yield filter_events_for_client(
-    self.store,
+    self.storage,
     sync_config.user.to_string(),
     recents,
     always_include_ids=current_state_ids,
@@ -470,7 +472,7 @@ class SyncHandler(object):
 current_state_ids = frozenset(itervalues(current_state_ids))
 
 loaded_recents = yield filter_events_for_client(
-    self.store,
+    self.storage,
     sync_config.user.to_string(),
     loaded_recents,
     always_include_ids=current_state_ids,
@@ -509,7 +511,7 @@ class SyncHandler(object):
 Returns:
     A Deferred map from ((type, state_key)->Event)
 """
-state_ids = yield self.store.get_state_ids_for_event(
+state_ids = yield self.state_store.get_state_ids_for_event(
     event.event_id, state_filter=state_filter
 )
 if event.is_state():
@@ -580,7 +582,7 @@ class SyncHandler(object):
 return None
 
 last_event = last_events[-1]
-state_ids = yield self.store.get_state_ids_for_event(
+state_ids = yield self.state_store.get_state_ids_for_event(
     last_event.event_id,
     state_filter=StateFilter.from_types(
         [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -757,11 +759,11 @@ class SyncHandler(object):
 
 if full_state:
     if batch:
-        current_state_ids = yield self.store.get_state_ids_for_event(
+        current_state_ids = yield self.state_store.get_state_ids_for_event(
             batch.events[-1].event_id, state_filter=state_filter
         )
 
-        state_ids = yield self.store.get_state_ids_for_event(
+        state_ids = yield self.state_store.get_state_ids_for_event(
            batch.events[0].event_id, state_filter=state_filter
        )
 
@@ -781,7 +783,7 @@ class SyncHandler(object):
 )
 elif batch.limited:
     if batch:
-        state_at_timeline_start = yield self.store.get_state_ids_for_event(
+        state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
             batch.events[0].event_id, state_filter=state_filter
         )
     else:
@@ -810,7 +812,7 @@ class SyncHandler(object):
 )
 
 if batch:
-    current_state_ids = yield self.store.get_state_ids_for_event(
+    current_state_ids = yield self.state_store.get_state_ids_for_event(
         batch.events[-1].event_id, state_filter=state_filter
     )
 else:
@@ -841,7 +843,7 @@ class SyncHandler(object):
 # So we fish out all the member events corresponding to the
 # timeline here, and then dedupe any redundant ones below.
 
-state_ids = yield self.store.get_state_ids_for_event(
+state_ids = yield self.state_store.get_state_ids_for_event(
     batch.events[0].event_id,
     # we only want members!
     state_filter=StateFilter.from_types(

@@ -159,6 +159,7 @@ class Notifier(object):
 self.room_to_user_streams = {}
 
 self.hs = hs
+self.storage = hs.get_storage()
 self.event_sources = hs.get_event_sources()
 self.store = hs.get_datastore()
 self.pending_new_room_events = []
@@ -425,7 +426,10 @@ class Notifier(object):
 
 if name == "room":
     new_events = yield filter_events_for_client(
-        self.store, user.to_string(), new_events, is_peeking=is_peeking
+        self.storage,
+        user.to_string(),
+        new_events,
+        is_peeking=is_peeking,
     )
 elif name == "presence":
     now = self.clock.time_msec()

@@ -64,6 +64,7 @@ class HttpPusher(object):
 def __init__(self, hs, pusherdict):
     self.hs = hs
     self.store = self.hs.get_datastore()
+    self.storage = self.hs.get_storage()
     self.clock = self.hs.get_clock()
     self.state_handler = self.hs.get_state_handler()
     self.user_id = pusherdict["user_name"]
@@ -329,7 +330,7 @@ class HttpPusher(object):
 return d
 
 ctx = yield push_tools.get_context_for_event(
-    self.store, self.state_handler, event, self.user_id
+    self.storage, self.state_handler, event, self.user_id
 )
 
 d = {

@@ -119,6 +119,7 @@ class Mailer(object):
 self.store = self.hs.get_datastore()
 self.macaroon_gen = self.hs.get_macaroon_generator()
 self.state_handler = self.hs.get_state_handler()
+self.storage = hs.get_storage()
 self.app_name = app_name
 
 logger.info("Created Mailer for app_name %s" % app_name)
@@ -389,7 +390,7 @@ class Mailer(object):
 }
 
 the_events = yield filter_events_for_client(
-    self.store, user_id, results["events_before"]
+    self.storage, user_id, results["events_before"]
 )
 the_events.append(notif_event)
 

@@ -16,6 +16,7 @@
 from twisted.internet import defer
 
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
+from synapse.storage import Storage
 
 
 @defer.inlineCallbacks
@@ -43,22 +44,22 @@ def get_badge_count(store, user_id):
 
 
 @defer.inlineCallbacks
-def get_context_for_event(store, state_handler, ev, user_id):
+def get_context_for_event(storage: Storage, state_handler, ev, user_id):
     ctx = {}
 
-    room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
+    room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
 
     # we no longer bother setting room_alias, and make room_name the
     # human-readable name instead, be that m.room.name, an alias or
     # a list of people in the room
     name = yield calculate_room_name(
-        store, room_state_ids, user_id, fallback_to_single_member=False
+        storage.main, room_state_ids, user_id, fallback_to_single_member=False
    )
    if name:
        ctx["name"] = name

    sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
-    sender_state_event = yield store.get_event(sender_state_event_id)
+    sender_state_event = yield storage.main.get_event(sender_state_event_id)
    ctx["sender_display_name"] = name_from_member_event(sender_state_event)

    return ctx

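The get_context_for_event change shows the call-site pattern for free functions: take the whole Storage facade and pick storage.state or storage.main per call. A hedged, synchronous sketch of that shape (the real function is a Twisted inlineCallbacks generator; the mocks and names below are assumptions):

from unittest.mock import Mock

def get_context_sketch(storage, ev_id, sender):
    # State-group lookups go through storage.state ...
    room_state_ids = storage.state.get_state_ids_for_event(ev_id)
    # ... while plain event fetches stay on the main store.
    member_event_id = room_state_ids[("m.room.member", sender)]
    sender_event = storage.main.get_event(member_event_id)
    return {"sender_event": sender_event}

storage = Mock()
storage.state.get_state_ids_for_event.return_value = {
    ("m.room.member", "@alice:example.org"): "$member"
}
print(get_context_sketch(storage, "$evt", "@alice:example.org"))
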
@@ -103,6 +103,7 @@ class StateHandler(object):
 def __init__(self, hs):
     self.clock = hs.get_clock()
     self.store = hs.get_datastore()
+    self.state_store = hs.get_storage().state
     self.hs = hs
     self._state_resolution_handler = hs.get_state_resolution_handler()
 
@@ -271,7 +272,7 @@ class StateHandler(object):
 else:
     current_state_ids = prev_state_ids
 
-state_group = yield self.store.store_state_group(
+state_group = yield self.state_store.store_state_group(
     event.event_id,
     event.room_id,
     prev_group=None,
@@ -321,7 +322,7 @@ class StateHandler(object):
 delta_ids = dict(entry.delta_ids)
 delta_ids[key] = event.event_id
 
-state_group = yield self.store.store_state_group(
+state_group = yield self.state_store.store_state_group(
     event.event_id,
     event.room_id,
     prev_group=prev_group,
@@ -334,7 +335,7 @@ class StateHandler(object):
 delta_ids = entry.delta_ids
 
 if entry.state_group is None:
-    entry.state_group = yield self.store.store_state_group(
+    entry.state_group = yield self.state_store.store_state_group(
         event.event_id,
         event.room_id,
         prev_group=entry.prev_group,
@@ -376,14 +377,16 @@ class StateHandler(object):
 # map from state group id to the state in that state group (where
 # 'state' is a map from state key to event id)
 # dict[int, dict[(str, str), str]]
-state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids)
+state_groups_ids = yield self.state_store.get_state_groups_ids(
+    room_id, event_ids
+)
 
 if len(state_groups_ids) == 0:
     return _StateCacheEntry(state={}, state_group=None)
 elif len(state_groups_ids) == 1:
     name, state_list = list(state_groups_ids.items()).pop()
 
-    prev_group, delta_ids = yield self.store.get_state_group_delta(name)
+    prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
 
     return _StateCacheEntry(
         state=state_list,

@@ -30,6 +30,7 @@ stored in `synapse.storage.schema`.
 from synapse.storage.data_stores import DataStores
 from synapse.storage.data_stores.main import DataStore
 from synapse.storage.persist_events import EventsPersistenceStorage
+from synapse.storage.state import StateGroupStorage
 
 __all__ = ["DataStores", "DataStore"]
 
@@ -45,6 +46,7 @@ class Storage(object):
 self.main = stores.main
 
 self.persistence = EventsPersistenceStorage(hs, stores)
+self.state = StateGroupStorage(hs, stores)
 
 
 def are_all_users_on_domain(txn, database_engine, domain):

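The Storage object now composes the per-concern interfaces from the underlying data stores, as the hunk above shows. A simplified, runnable mirror of that constructor shape (class and argument names other than those in the diff are assumptions):

from unittest.mock import Mock

class EventsPersistenceStorageSketch:
    def __init__(self, hs, stores):
        self.stores = stores

class StateGroupStorageSketch:
    def __init__(self, hs, stores):
        self.stores = stores

class StorageSketch:
    """Rough stand-in for synapse.storage.Storage after this change."""
    def __init__(self, hs, stores):
        self.main = stores.main                                   # main data store
        self.persistence = EventsPersistenceStorageSketch(hs, stores)
        self.state = StateGroupStorageSketch(hs, stores)          # new in this PR

stores = Mock(main=Mock())
storage = StorageSketch(hs=Mock(), stores=stores)
assert storage.state.stores is stores
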
@@ -547,7 +547,7 @@ class EventsPersistenceStorage(object):
 
 if missing_event_ids:
     # Now pull out the state groups for any missing events from DB
-    event_to_groups = yield self.state_store._get_state_group_for_events(
+    event_to_groups = yield self.main_store._get_state_group_for_events(
         missing_event_ids
     )
     event_id_to_state_group.update(event_to_groups)

@@ -19,6 +19,8 @@ from six import iteritems, itervalues
 
 import attr
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes
 
 logger = logging.getLogger(__name__)
@@ -322,3 +324,234 @@ class StateFilter(object):
     )
 
     return member_filter, non_member_filter
 
 
+class StateGroupStorage(object):
+    """High level interface to fetching state for event.
+    """
+
+    def __init__(self, hs, stores):
+        self.stores = stores
+
+    def get_state_group_delta(self, state_group):
+        """Given a state group try to return a previous group and a delta between
+        the old and the new.
+
+        Returns:
+            Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+                (prev_group, delta_ids)
+        """
+
+        return self.stores.main.get_state_group_delta(state_group)
+
+    @defer.inlineCallbacks
+    def get_state_groups_ids(self, _room_id, event_ids):
+        """Get the event IDs of all the state for the state groups for the given events
+
+        Args:
+            _room_id (str): id of the room for these events
+            event_ids (iterable[str]): ids of the events
+
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+        if not event_ids:
+            return {}
+
+        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+        groups = set(itervalues(event_to_groups))
+        group_to_state = yield self.stores.main._get_state_for_groups(groups)
+
+        return group_to_state
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_group(self, state_group):
+        """Get the event IDs of all the state in the given state group
+
+        Args:
+            state_group (int)
+
+        Returns:
+            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+        """
+        group_to_state = yield self._get_state_for_groups((state_group,))
+
+        return group_to_state[state_group]
+
+    @defer.inlineCallbacks
+    def get_state_groups(self, room_id, event_ids):
+        """ Get the state groups for the given list of event_ids
+        Returns:
+            Deferred[dict[int, list[EventBase]]]:
+                dict of state_group_id -> list of state events.
+        """
+        if not event_ids:
+            return {}
+
+        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+        state_event_map = yield self.stores.main.get_events(
+            [
+                ev_id
+                for group_ids in itervalues(group_to_ids)
+                for ev_id in itervalues(group_ids)
+            ],
+            get_prev_content=False,
+        )
+
+        return {
+            group: [
+                state_event_map[v]
+                for v in itervalues(event_id_map)
+                if v in state_event_map
+            ]
+            for group, event_id_map in iteritems(group_to_ids)
+        }
+
+    def _get_state_groups_from_groups(self, groups, state_filter):
+        """Returns the state groups for a given set of groups, filtering on
+        types of state events.
+
+        Args:
+            groups(list[int]): list of state group IDs to query
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+
+        return self.stores.main._get_state_groups_from_groups(groups, state_filter)
+
+    @defer.inlineCallbacks
+    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
+        """Given a list of event_ids and type tuples, return a list of state
+        dicts for each event.
+        Args:
+            event_ids (list[string])
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
+        """
+        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+        groups = set(itervalues(event_to_groups))
+        group_to_state = yield self.stores.main._get_state_for_groups(
+            groups, state_filter
+        )
+
+        state_event_map = yield self.stores.main.get_events(
+            [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
+            get_prev_content=False,
+        )
+
+        event_to_state = {
+            event_id: {
+                k: state_event_map[v]
+                for k, v in iteritems(group_to_state[group])
+                if v in state_event_map
+            }
+            for event_id, group in iteritems(event_to_groups)
+        }
+
+        return {event: event_to_state[event] for event in event_ids}
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
+        """
+        Get the state dicts corresponding to a list of events, containing the event_ids
+        of the state events (as opposed to the events themselves)
+
+        Args:
+            event_ids(list(str)): events whose state should be returned
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            A deferred dict from event_id -> (type, state_key) -> event_id
+        """
+        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+
+        groups = set(itervalues(event_to_groups))
+        group_to_state = yield self.stores.main._get_state_for_groups(
+            groups, state_filter
+        )
+
+        event_to_state = {
+            event_id: group_to_state[group]
+            for event_id, group in iteritems(event_to_groups)
+        }
+
+        return {event: event_to_state[event] for event in event_ids}
+
+    @defer.inlineCallbacks
+    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
+        """
+        Get the state dict corresponding to a particular event
+
+        Args:
+            event_id(str): event whose state should be returned
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            A deferred dict from (type, state_key) -> state_event
+        """
+        state_map = yield self.get_state_for_events([event_id], state_filter)
+        return state_map[event_id]
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
+        """
+        Get the state dict corresponding to a particular event
+
+        Args:
+            event_id(str): event whose state should be returned
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            A deferred dict from (type, state_key) -> state_event
+        """
+        state_map = yield self.get_state_ids_for_events([event_id], state_filter)
+        return state_map[event_id]
+
+    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+        return self.stores.main._get_state_for_groups(groups, state_filter)
+
+    def store_state_group(
+        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+    ):
+        """Store a new set of state, returning a newly assigned state group.
+
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
+
+        Returns:
+            Deferred[int]: The state group ID
+        """
+        return self.stores.main.store_state_group(
+            event_id, room_id, prev_group, delta_ids, current_state_ids
+        )

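As the new class shows, StateGroupStorage is currently a thin delegation layer: every method forwards to stores.main, which keeps call sites decoupled from where state groups physically live. A small, runnable sketch of that pattern with a mocked main store (the mock is an assumption; only two methods are shown):

from unittest.mock import Mock

class StateGroupStorageSketch:
    """Delegation-only sketch of the interface added in this PR."""
    def __init__(self, stores):
        self.stores = stores

    def get_state_group_delta(self, state_group):
        # Forwarded unchanged to the main data store.
        return self.stores.main.get_state_group_delta(state_group)

    def store_state_group(self, event_id, room_id, prev_group, delta_ids, current_state_ids):
        return self.stores.main.store_state_group(
            event_id, room_id, prev_group, delta_ids, current_state_ids
        )

stores = Mock()
stores.main.get_state_group_delta.return_value = (None, None)
state_store = StateGroupStorageSketch(stores)
print(state_store.get_state_group_delta(42))  # -> (None, None), via stores.main
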
@@ -23,6 +23,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.utils import prune_event
+from synapse.storage import Storage
 from synapse.storage.state import StateFilter
 from synapse.types import get_domain_from_id
 
@@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = (
 
 @defer.inlineCallbacks
 def filter_events_for_client(
-    store, user_id, events, is_peeking=False, always_include_ids=frozenset()
+    storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset()
 ):
     """
     Check which events a user is allowed to see
 
     Args:
-        store (synapse.storage.DataStore): our datastore (can also be a worker
-            store)
+        storage
         user_id(str): user id to be checked
         events(list[synapse.events.EventBase]): sequence of events to be checked
         is_peeking(bool): should be True if:
@@ -68,12 +68,12 @@ def filter_events_for_client(
 events = list(e for e in events if not e.internal_metadata.is_soft_failed())
 
 types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
-event_id_to_state = yield store.get_state_for_events(
+event_id_to_state = yield storage.state.get_state_for_events(
     frozenset(e.event_id for e in events),
     state_filter=StateFilter.from_types(types),
 )
 
-ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
+ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
     "m.ignored_user_list", user_id
 )
 
@@ -84,7 +84,7 @@ def filter_events_for_client(
     else []
 )
 
-erased_senders = yield store.are_users_erased((e.sender for e in events))
+erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
 
 def allowed(event):
     """
@@ -213,13 +213,17 @@ def filter_events_for_client(
 
 @defer.inlineCallbacks
 def filter_events_for_server(
-    store, server_name, events, redact=True, check_history_visibility_only=False
+    storage: Storage,
+    server_name,
+    events,
+    redact=True,
+    check_history_visibility_only=False,
 ):
     """Filter a list of events based on whether given server is allowed to
     see them.
 
     Args:
-        store (DataStore)
+        storage
         server_name (str)
         events (iterable[FrozenEvent])
         redact (bool): Whether to return a redacted version of the event, or
@@ -274,7 +278,7 @@ def filter_events_for_server(
 # Lets check to see if all the events have a history visibility
 # of "shared" or "world_readable". If thats the case then we don't
 # need to check membership (as we know the server is in the room).
-event_to_state_ids = yield store.get_state_ids_for_events(
+event_to_state_ids = yield storage.state.get_state_ids_for_events(
     frozenset(e.event_id for e in events),
     state_filter=StateFilter.from_types(
         types=((EventTypes.RoomHistoryVisibility, ""),)
@@ -292,14 +296,14 @@ def filter_events_for_server(
 if not visibility_ids:
     all_open = True
 else:
-    event_map = yield store.get_events(visibility_ids)
+    event_map = yield storage.main.get_events(visibility_ids)
     all_open = all(
         e.content.get("history_visibility") in (None, "shared", "world_readable")
         for e in itervalues(event_map)
     )
 
 if not check_history_visibility_only:
-    erased_senders = yield store.are_users_erased((e.sender for e in events))
+    erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
 else:
     # We don't want to check whether users are erased, which is equivalent
     # to no users having been erased.
@@ -328,7 +332,7 @@ def filter_events_for_server(
 
 # first, for each event we're wanting to return, get the event_ids
 # of the history vis and membership state at those events.
-event_to_state_ids = yield store.get_state_ids_for_events(
+event_to_state_ids = yield storage.state.get_state_ids_for_events(
     frozenset(e.event_id for e in events),
     state_filter=StateFilter.from_types(
         types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
@@ -358,7 +362,7 @@ def filter_events_for_server(
     return False
 return state_key[idx + 1 :] == server_name
 
-event_map = yield store.get_events(
+event_map = yield storage.main.get_events(
     [
         e_id
         for e_id, key in iteritems(event_id_to_state_key)

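The visibility helpers now take the Storage facade rather than a bare datastore, so call sites change from filter_events_for_client(self.store, ...) to filter_events_for_client(self.storage, ...). A hedged, synchronous sketch of how a caller and a helper line up after the change (not the real filtering logic, which also checks history visibility and ignored users):

from unittest.mock import Mock

def filter_events_sketch(storage, user_id, events):
    # History-visibility and membership state come from the state interface ...
    storage.state.get_state_for_events([e["event_id"] for e in events])
    # ... erasure and account data stay on the main store.
    erased = storage.main.are_users_erased([e["sender"] for e in events])
    return [e for e in events if not erased.get(e["sender"], False)]

storage = Mock()
storage.main.are_users_erased.return_value = {"@bob:example.org": False}
events = [{"event_id": "$1", "sender": "@bob:example.org"}]
print(filter_events_sketch(storage, "@alice:example.org", events))
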
@@ -35,6 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
 self.store = hs.get_datastore()
+self.storage = hs.get_storage()
 self.state_datastore = self.store
 self.event_builder_factory = hs.get_event_builder_factory()
 self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -83,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
     self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
 )
 
-state_group_map = yield self.store.get_state_groups_ids(
+state_group_map = yield self.storage.state.get_state_groups_ids(
     self.room, [e2.event_id]
 )
 self.assertEqual(len(state_group_map), 1)
@@ -102,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
     self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
 )
 
-state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
+state_group_map = yield self.storage.state.get_state_groups(
+    self.room, [e2.event_id]
+)
 self.assertEqual(len(state_group_map), 1)
 state_list = list(state_group_map.values())[0]
 
@@ -142,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 )
 
 # check we get the full state as of the final event
-state = yield self.store.get_state_for_event(e5.event_id)
+state = yield self.storage.state.get_state_for_event(e5.event_id)
 
 self.assertIsNotNone(e4)
 
@@ -158,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
 )
 
 # check we can filter to the m.room.name event (with a '' state key)
-state = yield self.store.get_state_for_event(
+state = yield self.storage.state.get_state_for_event(
     e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
 )
 
 self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
 # check we can filter to the m.room.name event (with a wildcard None state key)
-state = yield self.store.get_state_for_event(
+state = yield self.storage.state.get_state_for_event(
     e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
 )
 
 self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
 # check we can grab the m.room.member events (with a wildcard None state key)
-state = yield self.store.get_state_for_event(
+state = yield self.storage.state.get_state_for_event(
     e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
 )
 
@@ -182,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
 # check we can grab a specific room member without filtering out the
 # other event types
-state = yield self.store.get_state_for_event(
+state = yield self.storage.state.get_state_for_event(
     e5.event_id,
     state_filter=StateFilter(
         types={EventTypes.Member: {self.u_alice.to_string()}},
@@ -200,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 )
 
 # check that we can grab everything except members
-state = yield self.store.get_state_for_event(
+state = yield self.storage.state.get_state_for_event(
     e5.event_id,
     state_filter=StateFilter(
         types={EventTypes.Member: set()}, include_others=True
@@ -216,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
 #######################################################
 
 room_id = self.room.to_string()
-group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+group_ids = yield self.storage.state.get_state_groups_ids(
+    room_id, [e5.event_id]
+)
 group = list(group_ids.keys())[0]
 
 # test _get_state_for_group_using_cache correctly filters out members
 # with types=[]
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: set()}, include_others=True
@@ -238,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
     state_dict,
 )
 
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_members_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_members_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: set()}, include_others=True
@@ -251,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
 # test _get_state_for_group_using_cache correctly filters in members
 # with wildcard types
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: None}, include_others=True
@@ -268,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
     state_dict,
 )
 
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_members_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_members_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: None}, include_others=True
@@ -288,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
 # test _get_state_for_group_using_cache correctly filters in members
 # with specific types
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -305,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
     state_dict,
 )
 
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_members_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_members_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -318,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
 # test _get_state_for_group_using_cache correctly filters in members
 # with specific types
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_members_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_members_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -332,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 #######################################################
 # deliberately remove e2 (room name) from the _state_group_cache
 
-(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
-    group
-)
+(
+    is_all,
+    known_absent,
+    state_dict_ids,
+) = self.state_datastore._state_group_cache.get(group)
 
 self.assertEqual(is_all, True)
 self.assertEqual(known_absent, set())
@@ -347,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
 )
 
 state_dict_ids.pop((e2.type, e2.state_key))
-self.store._state_group_cache.invalidate(group)
-self.store._state_group_cache.update(
-    sequence=self.store._state_group_cache.sequence,
+self.state_datastore._state_group_cache.invalidate(group)
+self.state_datastore._state_group_cache.update(
+    sequence=self.state_datastore._state_group_cache.sequence,
     key=group,
     value=state_dict_ids,
     # list fetched keys so it knows it's partial
     fetched_keys=((e1.type, e1.state_key),),
 )
 
-(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
-    group
-)
+(
+    is_all,
+    known_absent,
+    state_dict_ids,
+) = self.state_datastore._state_group_cache.get(group)
 
 self.assertEqual(is_all, False)
 self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
@@ -370,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 # test _get_state_for_group_using_cache correctly filters out members
 # with types=[]
 room_id = self.room.to_string()
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: set()}, include_others=True
@@ -382,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
 room_id = self.room.to_string()
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_members_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_members_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: set()}, include_others=True
@@ -395,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
 # test _get_state_for_group_using_cache correctly filters in members
 # wildcard types
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: None}, include_others=True
@@ -406,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 self.assertEqual(is_all, False)
 self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_members_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_members_cache,
    group,
    state_filter=StateFilter(
        types={EventTypes.Member: None}, include_others=True
@@ -425,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
 # test _get_state_for_group_using_cache correctly filters in members
 # with specific types
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -436,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 self.assertEqual(is_all, False)
 self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_members_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_members_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -449,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
 # test _get_state_for_group_using_cache correctly filters in members
 # with specific types
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -460,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 self.assertEqual(is_all, False)
 self.assertDictEqual({}, state_dict)
 
-(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-    self.store._state_group_members_cache,
+(
+    state_dict,
+    is_all,
+) = yield self.state_datastore._get_state_for_group_using_cache(
+    self.state_datastore._state_group_members_cache,
     group,
     state_filter=StateFilter(
         types={EventTypes.Member: {e5.state_key}}, include_others=False

@@ -158,10 +158,12 @@ class Graph(object):
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = StateGroupStore()
+        storage = Mock(main=self.store, state=self.store)
         hs = Mock(
             spec_set=[
                 "config",
                 "get_datastore",
+                "get_storage",
                 "get_auth",
                 "get_state_handler",
                 "get_clock",
@@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase):
 hs.get_clock.return_value = MockClock()
 hs.get_auth.return_value = Auth(hs)
 hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
+hs.get_storage.return_value = storage
 
 self.state = StateHandler(hs)
 self.event_id = 0

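The test changes above show the simplest way to satisfy the new dependency in unit tests: hand the consumer a Mock whose main and state attributes both point at the test store. A self-contained sketch of that wiring (the StateHandler itself is omitted; only the mock plumbing is shown):

from unittest.mock import Mock

test_store = Mock()                       # stand-in for the test's state-group store
storage = Mock(main=test_store, state=test_store)

hs = Mock()
hs.get_datastore.return_value = test_store
hs.get_storage.return_value = storage     # what StateHandler now asks the hs for

# Anything built on hs can now reach state via hs.get_storage().state
assert hs.get_storage().state is test_store
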
@@ -14,6 +14,8 @@
 # limitations under the License.
 import logging
 
+from mock import Mock
+
 from twisted.internet import defer
 from twisted.internet.defer import succeed
 
@@ -63,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 events_to_filter.append(evt)
 
 filtered = yield filter_events_for_server(
-    self.store, "test_server", events_to_filter
+    self.storage, "test_server", events_to_filter
 )
 
 # the result should be 5 redacted events, and 5 unredacted events.
@@ -101,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 
 # ... and the filtering happens.
 filtered = yield filter_events_for_server(
-    self.store, "test_server", events_to_filter
+    self.storage, "test_server", events_to_filter
 )
 
 for i in range(0, len(events_to_filter)):
@@ -258,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 
 logger.info("Starting filtering")
 start = time.time()
 
+storage = Mock()
+storage.main = test_store
+storage.state = test_store
+
 filtered = yield filter_events_for_server(
     test_store, "test_server", events_to_filter
 )
 