0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 17:23:53 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into erikj/split_fed_store

This commit is contained in:
Erik Johnston 2018-07-30 09:56:18 +01:00
commit 143f1a2532
37 changed files with 787 additions and 108 deletions

View file

@ -362,6 +362,19 @@ Synapse is in the Fedora repositories as ``matrix-synapse``::
Oleg Girko provides Fedora RPMs at
https://obs.infoserver.lv/project/monitor/matrix-synapse
OpenSUSE
--------
Synapse is in the OpenSUSE repositories as ``matrix-synapse``::
sudo zypper install matrix-synapse
SUSE Linux Enterprise Server
----------------------------
Unofficial package are built for SLES 15 in the openSUSE:Backports:SLE-15 repository at
https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/
ArchLinux
---------

1
changelog.d/2970.feature Normal file
View file

@ -0,0 +1 @@
add support for the lazy_loaded_members filter as per MSC1227

1
changelog.d/3331.feature Normal file
View file

@ -0,0 +1 @@
add support for the include_redundant_members filter param as per MSC1227

1
changelog.d/3391.bugfix Normal file
View file

@ -0,0 +1 @@
Default inviter_display_name to mxid for email invites

1
changelog.d/3567.feature Normal file
View file

@ -0,0 +1 @@
make the /context API filter & lazy-load aware as per MSC1227

1
changelog.d/3609.misc Normal file
View file

@ -0,0 +1 @@
Fix a documentation typo in on_make_leave_request

1
changelog.d/3610.feature Normal file
View file

@ -0,0 +1 @@
Add metrics to track resource usage by background processes

1
changelog.d/3613.misc Normal file
View file

@ -0,0 +1 @@
Remove some redundant joins on event_edges.room_id

1
changelog.d/3614.misc Normal file
View file

@ -0,0 +1 @@
Stop populating events.content

1
changelog.d/3616.misc Normal file
View file

@ -0,0 +1 @@
Update the /send_leave path registration to use event_id rather than a transaction ID.

View file

@ -113,7 +113,13 @@ ROOM_EVENT_FILTER_SCHEMA = {
},
"contains_url": {
"type": "boolean"
}
},
"lazy_load_members": {
"type": "boolean"
},
"include_redundant_members": {
"type": "boolean"
},
}
}
@ -261,6 +267,12 @@ class FilterCollection(object):
def ephemeral_limit(self):
return self._room_ephemeral_filter.limit()
def lazy_load_members(self):
return self._room_state_filter.lazy_load_members()
def include_redundant_members(self):
return self._room_state_filter.include_redundant_members()
def filter_presence(self, events):
return self._presence_filter.filter(events)
@ -417,6 +429,12 @@ class Filter(object):
def limit(self):
return self.filter_json.get("limit", 10)
def lazy_load_members(self):
return self.filter_json.get("lazy_load_members", False)
def include_redundant_members(self):
return self.filter_json.get("include_redundant_members", False)
def _matches_wildcard(actual_value, filter_value):
if filter_value.endswith("*"):

View file

@ -429,7 +429,7 @@ def run(hs):
stats_process = []
def start_phone_stats_home():
run_as_background_process("phone_stats_home", phone_stats_home)
return run_as_background_process("phone_stats_home", phone_stats_home)
@defer.inlineCallbacks
def phone_stats_home():
@ -502,7 +502,7 @@ def run(hs):
)
def generate_user_daily_visit_stats():
run_as_background_process(
return run_as_background_process(
"generate_user_daily_visits",
hs.get_datastore().generate_user_daily_visits,
)

View file

@ -404,10 +404,10 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
class FederationSendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)"
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, txid):
def on_PUT(self, origin, content, query, room_id, event_id):
content = yield self.handler.on_send_leave_request(origin, content)
defer.returnValue((200, content))

View file

@ -153,7 +153,7 @@ class GroupAttestionRenewer(object):
defer.returnValue({})
def _start_renew_attestations(self):
run_as_background_process("renew_attestations", self._renew_attestations)
return run_as_background_process("renew_attestations", self._renew_attestations)
@defer.inlineCallbacks
def _renew_attestations(self):

View file

@ -1233,7 +1233,7 @@ class FederationHandler(BaseHandler):
@log_function
def on_make_leave_request(self, room_id, user_id):
""" We've received a /make_leave/ request, so we create a partial
join event for the room and return that. We do *not* persist or
leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
"""
builder = self.event_builder_factory.new({

View file

@ -256,7 +256,7 @@ class ProfileHandler(BaseHandler):
)
def _start_update_remote_profile_cache(self):
run_as_background_process(
return run_as_background_process(
"Update remote profile", self._update_remote_profile_cache,
)

View file

@ -15,6 +15,7 @@
# limitations under the License.
"""Contains functions for performing events on rooms."""
import itertools
import logging
import math
import string
@ -401,7 +402,7 @@ class RoomContextHandler(object):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit):
def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event
in a room.
@ -411,6 +412,8 @@ class RoomContextHandler(object):
event_id (str)
limit (int): The maximum number of events to return in total
(excluding state).
event_filter (Filter|None): the filter to apply to the events returned
(excluding the target event_id)
Returns:
dict, or None if the event isn't found
@ -443,7 +446,7 @@ class RoomContextHandler(object):
)
results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit
room_id, event_id, before_limit, after_limit, event_filter
)
results["events_before"] = yield filter_evts(results["events_before"])
@ -455,8 +458,23 @@ class RoomContextHandler(object):
else:
last_event_id = event_id
types = None
filtered_types = None
if event_filter and event_filter.lazy_load_members():
members = set(ev.sender for ev in itertools.chain(
results["events_before"],
(results["event"],),
results["events_after"],
))
filtered_types = [EventTypes.Member]
types = [(EventTypes.Member, member) for member in members]
# XXX: why do we return the state as of the last event rather than the
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events(
[last_event_id], None
[last_event_id], types, filtered_types=filtered_types,
)
results["state"] = list(state[last_event_id].values())

View file

@ -708,6 +708,10 @@ class RoomMemberHandler(object):
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
# if user has no display name, default to their MXID
if not inviter_display_name:
inviter_display_name = user.to_string()
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:

View file

@ -287,7 +287,7 @@ class SearchHandler(BaseHandler):
contexts = {}
for event in allowed_events:
res = yield self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
event.room_id, event.event_id, before_limit, after_limit,
)
logger.info(

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015 - 2016 OpenMarket Ltd
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -25,6 +26,8 @@ from synapse.api.constants import EventTypes, Membership
from synapse.push.clientformat import format_push_rules_for_user
from synapse.types import RoomStreamToken
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure, measure_func
@ -32,6 +35,14 @@ from synapse.visibility import filter_events_for_client
logger = logging.getLogger(__name__)
# Store the cache that tracks which lazy-loaded members have been sent to a given
# client for no more than 30 minutes.
LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000
# Remember the last 100 members we sent to a client for the purposes of
# avoiding redundantly sending the same lazy-loaded members to the client
LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
SyncConfig = collections.namedtuple("SyncConfig", [
"user",
@ -181,6 +192,12 @@ class SyncHandler(object):
self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
"lazy_loaded_members_cache", self.clock,
max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
"""Get the sync for a client if we have new data for it now. Otherwise
@ -416,29 +433,44 @@ class SyncHandler(object):
))
@defer.inlineCallbacks
def get_state_after_event(self, event):
def get_state_after_event(self, event, types=None, filtered_types=None):
"""
Get the room state after the given event
Args:
event(synapse.events.EventBase): event of interest
types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns:
A Deferred map from ((type, state_key)->Event)
"""
state_ids = yield self.store.get_state_ids_for_event(event.event_id)
state_ids = yield self.store.get_state_ids_for_event(
event.event_id, types, filtered_types=filtered_types,
)
if event.is_state():
state_ids = state_ids.copy()
state_ids[(event.type, event.state_key)] = event.event_id
defer.returnValue(state_ids)
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
def get_state_at(self, room_id, stream_position, types=None, filtered_types=None):
""" Get the room state at a particular stream position
Args:
room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state
types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns:
A Deferred map from ((type, state_key)->Event)
@ -453,7 +485,9 @@ class SyncHandler(object):
if last_events:
last_event = last_events[-1]
state = yield self.get_state_after_event(last_event)
state = yield self.get_state_after_event(
last_event, types, filtered_types=filtered_types,
)
else:
# no events in this room - so presumably no state
@ -485,59 +519,129 @@ class SyncHandler(object):
# TODO(mjark) Check for new redactions in the state events.
with Measure(self.clock, "compute_state_delta"):
types = None
filtered_types = None
lazy_load_members = sync_config.filter_collection.lazy_load_members()
include_redundant_members = (
sync_config.filter_collection.include_redundant_members()
)
if lazy_load_members:
# We only request state for the members needed to display the
# timeline:
types = [
(EventTypes.Member, state_key)
for state_key in set(
event.sender # FIXME: we also care about invite targets etc.
for event in batch.events
)
]
# only apply the filtering to room members
filtered_types = [EventTypes.Member]
timeline_state = {
(event.type, event.state_key): event.event_id
for event in batch.events if event.is_state()
}
if full_state:
if batch:
current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id
batch.events[-1].event_id, types=types,
filtered_types=filtered_types,
)
state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id
batch.events[0].event_id, types=types,
filtered_types=filtered_types,
)
else:
current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token
room_id, stream_position=now_token, types=types,
filtered_types=filtered_types,
)
state_ids = current_state_ids
timeline_state = {
(event.type, event.state_key): event.event_id
for event in batch.events if event.is_state()
}
state_ids = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_ids,
previous={},
current=current_state_ids,
lazy_load_members=lazy_load_members,
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
room_id, stream_position=since_token, types=types,
filtered_types=filtered_types,
)
current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id
batch.events[-1].event_id, types=types,
filtered_types=filtered_types,
)
state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id
batch.events[0].event_id, types=types,
filtered_types=filtered_types,
)
timeline_state = {
(event.type, event.state_key): event.event_id
for event in batch.events if event.is_state()
}
state_ids = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
current=current_state_ids,
lazy_load_members=lazy_load_members,
)
else:
state_ids = {}
if lazy_load_members:
if types:
state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types,
filtered_types=filtered_types,
)
if lazy_load_members and not include_redundant_members:
cache_key = (sync_config.user.to_string(), sync_config.device_id)
cache = self.lazy_loaded_members_cache.get(cache_key)
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
self.lazy_loaded_members_cache[cache_key] = cache
else:
logger.debug("found LruCache for %r", cache_key)
# if it's a new sync sequence, then assume the client has had
# amnesia and doesn't want any recent lazy-loaded members
# de-duplicated.
if since_token is None:
logger.debug("clearing LruCache for %r", cache_key)
cache.clear()
else:
# only send members which aren't in our LruCache (either
# because they're new to this client or have been pushed out
# of the cache)
logger.debug("filtering state from %r...", state_ids)
state_ids = {
t: event_id
for t, event_id in state_ids.iteritems()
if cache.get(t[1]) != event_id
}
logger.debug("...to %r", state_ids)
# add any member IDs we are about to send into our LruCache
for t, event_id in itertools.chain(
state_ids.items(),
timeline_state.items(),
):
if t[0] == EventTypes.Member:
cache.set(t[1], event_id)
state = {}
if state_ids:
@ -1448,7 +1552,9 @@ def _action_has_highlight(actions):
return False
def _calculate_state(timeline_contains, timeline_start, previous, current):
def _calculate_state(
timeline_contains, timeline_start, previous, current, lazy_load_members,
):
"""Works out what state to include in a sync response.
Args:
@ -1457,6 +1563,9 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync)
current (dict): state at the end of the timeline
lazy_load_members (bool): whether to return members from timeline_start
or not. assumes that timeline_start has already been filtered to
include only the members the client needs to know about.
Returns:
dict
@ -1472,9 +1581,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
}
c_ids = set(e for e in current.values())
tc_ids = set(e for e in timeline_contains.values())
p_ids = set(e for e in previous.values())
ts_ids = set(e for e in timeline_start.values())
p_ids = set(e for e in previous.values())
tc_ids = set(e for e in timeline_contains.values())
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
# as we may not have sent them to the client before. We find these membership
# events by filtering them out of timeline_start, which has already been filtered
# to only include membership events for the senders in the timeline.
# In practice, we can do this by removing them from the p_ids list,
# which is the list of relevant state we know we have already sent to the client.
# see https://github.com/matrix-org/synapse/pull/2970
# /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
if lazy_load_members:
p_ids.difference_update(
e for t, e in timeline_start.iteritems()
if t[0] == EventTypes.Member
)
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids

View file

@ -151,13 +151,19 @@ def run_as_background_process(desc, func, *args, **kwargs):
This should be used to wrap processes which are fired off to run in the
background, instead of being associated with a particular request.
It returns a Deferred which completes when the function completes, but it doesn't
follow the synapse logcontext rules, which makes it appropriate for passing to
clock.looping_call and friends (or for firing-and-forgetting in the middle of a
normal synapse inlineCallbacks function).
Args:
desc (str): a description for this background process type
func: a function, which may return a Deferred
args: positional args for func
kwargs: keyword args for func
Returns: None
Returns: Deferred which returns the result of func, but note that it does not
follow the synapse logcontext rules.
"""
@defer.inlineCallbacks
def run():
@ -176,4 +182,4 @@ def run_as_background_process(desc, func, *args, **kwargs):
_background_processes[desc].remove(proc)
with PreserveLoggingContext():
run()
return run()

View file

@ -531,11 +531,20 @@ class RoomEventContextServlet(ClientV1RestServlet):
limit = parse_integer(request, "limit", default=10)
# picking the API shape for symmetry with /messages
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
results = yield self.room_context_handler.get_event_context(
requester.user,
room_id,
event_id,
limit,
event_filter,
)
if not results:

View file

@ -106,7 +106,7 @@ class MediaRepository(object):
)
def _start_update_recently_accessed(self):
run_as_background_process(
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed,
)

View file

@ -373,7 +373,7 @@ class PreviewUrlResource(Resource):
})
def _start_expire_url_cache_data(self):
run_as_background_process(
return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data,
)

View file

@ -102,7 +102,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
to_update,
)
run_as_background_process(
return run_as_background_process(
"update_client_ips", update,
)

View file

@ -712,7 +712,7 @@ class DeviceStore(SQLBaseStore):
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
run_as_background_process(
return run_as_background_process(
"prune_old_outbound_device_pokes",
self.runInteraction,
"_prune_old_outbound_device_pokes",

View file

@ -114,9 +114,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
sql = (
"SELECT b.event_id, MAX(e.depth) FROM events as e"
" INNER JOIN event_edges as g"
" ON g.event_id = e.event_id AND g.room_id = e.room_id"
" ON g.event_id = e.event_id"
" INNER JOIN event_backward_extremities as b"
" ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
" ON g.prev_event_id = b.event_id"
" WHERE b.room_id = ? AND g.is_state is ?"
" GROUP BY b.event_id"
)
@ -330,8 +330,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
"SELECT depth, prev_event_id FROM event_edges"
" INNER JOIN events"
" ON prev_event_id = events.event_id"
" AND event_edges.room_id = events.room_id"
" WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
" WHERE event_edges.event_id = ?"
" AND event_edges.is_state = ?"
" LIMIT ?"
)
@ -365,7 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
txn.execute(
query,
(room_id, event_id, False, limit - len(event_results))
(event_id, False, limit - len(event_results))
)
for row in txn:
@ -402,7 +401,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
query = (
"SELECT prev_event_id FROM event_edges "
"WHERE room_id = ? AND event_id = ? AND is_state = ? "
"WHERE event_id = ? AND is_state = ? "
"LIMIT ?"
)
@ -411,7 +410,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
for event_id in front:
txn.execute(
query,
(room_id, event_id, False, limit - len(event_results))
(event_id, False, limit - len(event_results))
)
for e_id, in txn:
@ -549,7 +548,7 @@ class EventFederationStore(EventFederationWorkerStore):
sql,
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
)
run_as_background_process(
return run_as_background_process(
"delete_old_forward_extrem_cache",
self.runInteraction,
"_delete_old_forward_extrem_cache",

View file

@ -460,7 +460,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _find_stream_orderings_for_times(self):
run_as_background_process(
return run_as_background_process(
"event_push_action_stream_orderings",
self.runInteraction,
"_find_stream_orderings_for_times",
@ -790,7 +790,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
""", (room_id, user_id, stream_ordering))
def _start_rotate_notifs(self):
run_as_background_process("rotate_notifs", self._rotate_notifs)
return run_as_background_process("rotate_notifs", self._rotate_notifs)
@defer.inlineCallbacks
def _rotate_notifs(self):

View file

@ -530,7 +530,6 @@ class EventsStore(EventsWorkerStore):
iterable=list(new_latest_event_ids),
retcols=["prev_event_id"],
keyvalues={
"room_id": room_id,
"is_state": False,
},
desc="_calculate_new_extremeties",
@ -1199,7 +1198,6 @@ class EventsStore(EventsWorkerStore):
"type": event.type,
"processed": True,
"outlier": event.internal_metadata.is_outlier(),
"content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts),
"received_ts": self._clock.time_msec(),
"sender": event.sender,

View file

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
We want to stop populating 'event.content', so we need to make it nullable.
If this has to be rolled back, then the following should populate the missing data:
Postgres:
UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej
WHERE ej.event_id = events.event_id AND
stream_ordering < (
SELECT stream_ordering FROM events WHERE content IS NOT NULL
ORDER BY stream_ordering LIMIT 1
);
UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej
WHERE ej.event_id = events.event_id AND
stream_ordering > (
SELECT stream_ordering FROM events WHERE content IS NOT NULL
ORDER BY stream_ordering DESC LIMIT 1
);
SQLite:
UPDATE events SET content=(
SELECT json_extract(json,'$.content') FROM event_json ej
WHERE ej.event_id = events.event_id
)
WHERE
stream_ordering < (
SELECT stream_ordering FROM events WHERE content IS NOT NULL
ORDER BY stream_ordering LIMIT 1
)
OR stream_ordering > (
SELECT stream_ordering FROM events WHERE content IS NOT NULL
ORDER BY stream_ordering DESC LIMIT 1
);
"""
import logging
from synapse.storage.engines import PostgresEngine
logger = logging.getLogger(__name__)
def run_create(cur, database_engine, *args, **kwargs):
pass
def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
cur.execute("""
ALTER TABLE events ALTER COLUMN content DROP NOT NULL;
""")
return
# sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html
cur.execute("SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'")
(oldsql,) = cur.fetchone()
sql = oldsql.replace("content TEXT NOT NULL", "content TEXT")
if sql == oldsql:
raise Exception("Couldn't find null constraint to drop in %s" % oldsql)
logger.info("Replacing definition of 'events' with: %s", sql)
cur.execute("PRAGMA schema_version")
(oldver,) = cur.fetchone()
cur.execute("PRAGMA writable_schema=ON")
cur.execute(
"UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'",
(sql, ),
)
cur.execute("PRAGMA schema_version=%i" % (oldver+1,))
cur.execute("PRAGMA writable_schema=OFF")

View file

@ -37,7 +37,8 @@ CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
is_state BOOL NOT NULL,
is_state BOOL NOT NULL, -- true if this is a prev_state edge rather than a regular
-- event dag edge.
UNIQUE (event_id, prev_event_id, room_id, is_state)
);

View file

@ -19,7 +19,12 @@ CREATE TABLE IF NOT EXISTS events(
event_id TEXT NOT NULL,
type TEXT NOT NULL,
room_id TEXT NOT NULL,
content TEXT NOT NULL,
-- 'content' used to be created NULLable, but as of delta 50 we drop that constraint.
-- the hack we use to drop the constraint doesn't work for an in-memory sqlite
-- database, which breaks the sytests. Hence, we no longer make it nullable.
content TEXT,
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,

View file

@ -186,7 +186,17 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups(list[int]): list of state group IDs to query
types (Iterable[str, str|None]|None): list of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`. If None, all types are returned.
Returns:
dictionary state_group -> (dict of (type, state_key) -> event id)
"""
results = {}
@ -200,8 +210,11 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue(results)
def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
def _get_state_groups_from_groups_txn(
self, txn, groups, types=None,
):
results = {group: {} for group in groups}
if types is not None:
types = list(set(types)) # deduplicate types list
@ -239,7 +252,7 @@ class StateGroupWorkerStore(SQLBaseStore):
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type seperately.
if types:
if types is not None:
clause_to_args = [
(
"AND type = ? AND state_key = ?",
@ -278,6 +291,7 @@ class StateGroupWorkerStore(SQLBaseStore):
else:
where_clauses.append("(type = ? AND state_key = ?)")
where_args.extend([typ[0], typ[1]])
where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else:
where_clause = ""
@ -332,16 +346,20 @@ class StateGroupWorkerStore(SQLBaseStore):
return results
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, types):
def get_state_for_events(self, event_ids, types, filtered_types=None):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event. The state dicts will only have the type/state_keys
that are in the `types` list.
Args:
event_ids (list)
types (list): List of (type, state_key) tuples which are used to
filter the state fetched. `state_key` may be None, which matches
any `state_key`
event_ids (list[string])
types (list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns:
deferred: A list of dicts corresponding to the event_ids given.
@ -352,7 +370,7 @@ class StateGroupWorkerStore(SQLBaseStore):
)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types)
group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
@ -371,15 +389,19 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types=None):
def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
"""
Get the state dicts corresponding to a list of events
Args:
event_ids(list(str)): events whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns:
A deferred dict from event_id -> (type, state_key) -> state_event
@ -389,7 +411,7 @@ class StateGroupWorkerStore(SQLBaseStore):
)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types)
group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
event_to_state = {
event_id: group_to_state[group]
@ -399,37 +421,45 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None):
def get_state_for_event(self, event_id, types=None, filtered_types=None):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], types)
state_map = yield self.get_state_for_events([event_id], types, filtered_types)
defer.returnValue(state_map[event_id])
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, types=None):
def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], types)
state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
defer.returnValue(state_map[event_id])
@cached(max_entries=50000)
@ -460,56 +490,73 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
def _get_some_state_from_cache(self, group, types):
def _get_some_state_from_cache(self, group, types, filtered_types=None):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
`missing_types` is the list of types that aren't in the cache for that
group. `got_all` is a bool indicating if we successfully retrieved all
Args:
group(int): The state group to lookup
types(list[str, str|None]): List of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns 2-tuple (`state_dict`, `got_all`).
`got_all` is a bool indicating if we successfully retrieved all
requests state from the cache, if False we need to query the DB for the
missing state.
Args:
group: The state group to lookup
types (list): List of 2-tuples of the form (`type`, `state_key`),
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
# tracks whether any of ourrequested types are missing from the cache
missing_types = False
for typ, state_key in types:
key = (typ, state_key)
if state_key is None:
if (
state_key is None or
(filtered_types is not None and typ not in filtered_types)
):
type_to_key[typ] = None
missing_types.add(key)
# we mark the type as missing from the cache because
# when the cache was populated it might have been done with a
# restricted set of state_keys, so the wildcard will not work
# and the cache may be incomplete.
missing_types = True
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
if key not in state_dict_ids and key not in known_absent:
missing_types.add(key)
missing_types = True
sentinel = object()
def include(typ, state_key):
valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel:
return False
return filtered_types is not None and typ not in filtered_types
if valid_state_keys is None:
return True
if state_key in valid_state_keys:
return True
return False
got_all = is_all or not missing_types
got_all = is_all
if not got_all:
# the cache is incomplete. We may still have got all the results we need, if
# we don't have any wildcards in the match list.
if not missing_types and filtered_types is None:
got_all = True
return {
k: v for k, v in iteritems(state_dict_ids)
if include(k[0], k[1])
}, missing_types, got_all
}, got_all
def _get_all_state_from_cache(self, group):
"""Checks if group is in cache. See `_get_state_for_groups`
@ -526,7 +573,7 @@ class StateGroupWorkerStore(SQLBaseStore):
return state_dict_ids, is_all
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
def _get_state_for_groups(self, groups, types=None, filtered_types=None):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@ -540,6 +587,9 @@ class StateGroupWorkerStore(SQLBaseStore):
Otherwise, each entry should be a `(type, state_key)` tuple to
include in the response. A `state_key` of None is a wildcard
meaning that we require all state with that type.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]]
@ -551,8 +601,8 @@ class StateGroupWorkerStore(SQLBaseStore):
missing_groups = []
if types is not None:
for group in set(groups):
state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types,
state_dict_ids, got_all = self._get_some_state_from_cache(
group, types, filtered_types
)
results[group] = state_dict_ids
@ -579,13 +629,13 @@ class StateGroupWorkerStore(SQLBaseStore):
# cache. Hence, if we are doing a wildcard lookup, populate the
# cache fully so that we can do an efficient lookup next time.
if types and any(k is None for (t, k) in types):
if filtered_types or (types and any(k is None for (t, k) in types)):
types_to_fetch = None
else:
types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups(
missing_groups, types_to_fetch,
missing_groups, types_to_fetch
)
for group, group_state_dict in iteritems(group_to_state_dict):
@ -595,7 +645,10 @@ class StateGroupWorkerStore(SQLBaseStore):
if types:
for k, v in iteritems(group_state_dict):
(typ, _) = k
if k in types or (typ, None) in types:
if (
(k in types or (typ, None) in types) or
(filtered_types and typ not in filtered_types)
):
state_dict[k] = v
else:
state_dict.update(group_state_dict)

View file

@ -527,7 +527,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
@defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit):
def get_events_around(
self, room_id, event_id, before_limit, after_limit, event_filter=None,
):
"""Retrieve events and pagination tokens around a given event in a
room.
@ -536,6 +538,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id (str)
before_limit (int)
after_limit (int)
event_filter (Filter|None)
Returns:
dict
@ -543,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = yield self.runInteraction(
"get_events_around", self._get_events_around_txn,
room_id, event_id, before_limit, after_limit
room_id, event_id, before_limit, after_limit, event_filter,
)
events_before = yield self._get_events(
@ -563,7 +566,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"end": results["after"]["token"],
})
def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit):
def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter,
):
"""Retrieves event_ids and pagination tokens around a given event in a
room.
@ -572,6 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id (str)
before_limit (int)
after_limit (int)
event_filter (Filter|None)
Returns:
dict
@ -601,11 +607,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows, start_token = self._paginate_room_events_txn(
txn, room_id, before_token, direction='b', limit=before_limit,
event_filter=event_filter,
)
events_before = [r.event_id for r in rows]
rows, end_token = self._paginate_room_events_txn(
txn, room_id, after_token, direction='f', limit=after_limit,
event_filter=event_filter,
)
events_after = [r.event_id for r in rows]

View file

@ -273,7 +273,9 @@ class TransactionStore(SQLBaseStore):
return self.cursor_to_dict(txn)
def _start_cleanup_transactions(self):
run_as_background_process("cleanup_transactions", self._cleanup_transactions)
return run_as_background_process(
"cleanup_transactions", self._cleanup_transactions,
)
def _cleanup_transactions(self):
now = self._clock.time_msec()

View file

@ -64,7 +64,7 @@ class ExpiringCache(object):
return
def f():
run_as_background_process(
return run_as_background_process(
"prune_cache_%s" % self._cache_name,
self._prune_cache,
)

319
tests/storage/test_state.py Normal file
View file

@ -0,0 +1,319 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
import tests.unittest
import tests.utils
logger = logging.getLogger(__name__)
class StateStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(StateStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test")
self.room = RoomID.from_string("!abc123:test")
yield self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
is_public=True
)
@defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.new({
"type": typ,
"sender": sender.to_string(),
"state_key": state_key,
"room_id": room.to_string(),
"content": content,
})
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
def assertStateMapEqual(self, s1, s2):
for t in s1:
# just compare event IDs for simplicity
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
@defer.inlineCallbacks
def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, '', {},
)
e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, '', {
"name": "test room"
},
)
e3 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), {
"membership": Membership.JOIN
},
)
e4 = yield self.inject_state_event(
self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
"membership": Membership.JOIN
},
)
e5 = yield self.inject_state_event(
self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
"membership": Membership.LEAVE
},
)
# check we get the full state as of the final event
state = yield self.store.get_state_for_event(
e5.event_id, None, filtered_types=None
)
self.assertIsNotNone(e4)
self.assertStateMapEqual({
(e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2,
(e3.type, e3.state_key): e3,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5,
}, state)
# check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Name, '')], filtered_types=None
)
self.assertStateMapEqual({
(e2.type, e2.state_key): e2,
}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Name, None)], filtered_types=None
)
self.assertStateMapEqual({
(e2.type, e2.state_key): e2,
}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Member, None)], filtered_types=None
)
self.assertStateMapEqual({
(e3.type, e3.state_key): e3,
(e5.type, e5.state_key): e5,
}, state)
# check we can use filter_types to grab a specific room member
# without filtering out the other event types
state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Member, self.u_alice.to_string())],
filtered_types=[EventTypes.Member],
)
self.assertStateMapEqual({
(e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2,
(e3.type, e3.state_key): e3,
}, state)
# check that types=[], filtered_types=[EventTypes.Member]
# doesn't return all members
state = yield self.store.get_state_for_event(
e5.event_id, [], filtered_types=[EventTypes.Member],
)
self.assertStateMapEqual({
(e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2,
}, state)
#######################################################
# _get_some_state_from_cache tests against a full cache
#######################################################
room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
group = group_ids.keys()[0]
# test _get_some_state_from_cache correctly filters out members with types=[]
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
(e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
self.assertEqual(is_all, True)
self.assertEqual(known_absent, set())
self.assertDictEqual(state_dict_ids, {
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
(e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
})
state_dict_ids.pop((e2.type, e2.state_key))
self.store._state_group_cache.invalidate(group)
self.store._state_group_cache.update(
sequence=self.store._state_group_cache.sequence,
key=group,
value=state_dict_ids,
# list fetched keys so it knows it's partial
fetched_keys=(
(e1.type, e1.state_key),
(e3.type, e3.state_key),
(e5.type, e5.state_key),
)
)
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
self.assertEqual(is_all, False)
self.assertEqual(known_absent, set([
(e1.type, e1.state_key),
(e3.type, e3.state_key),
(e5.type, e5.state_key),
]))
self.assertDictEqual(state_dict_ids, {
(e1.type, e1.state_key): e1.event_id,
(e3.type, e3.state_key): e3.event_id,
(e5.type, e5.state_key): e5.event_id,
})
############################################
# test that things work with a partial cache
# test _get_some_state_from_cache correctly filters out members with types=[]
room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, False)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, False)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, False)
self.assertDictEqual({
(e1.type, e1.state_key): e1.event_id,
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
)
self.assertEqual(is_all, True)
self.assertDictEqual({
(e5.type, e5.state_key): e5.event_id,
}, state_dict)