Merge pull request #3541 from matrix-org/rav/optimize_filter_events_for_server

Refactor and optimze filter_events_for_server
This commit is contained in:
Richard van der Hoff 2018-07-17 14:01:39 +01:00 committed by GitHub
commit 9c04b4abf9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 409 additions and 138 deletions

1
changelog.d/3541.feature Normal file
View file

@ -0,0 +1 @@
Optimisation to make handling incoming federation requests more efficient.

View file

@ -43,7 +43,6 @@ from synapse.crypto.event_signing import (
add_hashes_and_signatures, add_hashes_and_signatures,
compute_event_signature, compute_event_signature,
) )
from synapse.events.utils import prune_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.state import resolve_events_with_factory from synapse.state import resolve_events_with_factory
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
@ -52,8 +51,8 @@ from synapse.util.async import Linearizer
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
from ._base import BaseHandler from ._base import BaseHandler
@ -501,137 +500,6 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
@measure_func("_filter_events_for_server")
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
"""Filter the given events for the given server, redacting those the
server can't see.
Assumes the server is currently in the room.
Returns
list[FrozenEvent]
"""
# First lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
)
)
visibility_ids = set()
for sids in event_to_state_ids.itervalues():
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist:
visibility_ids.add(hist)
# If we failed to find any history visibility events then the default
# is "shared" visiblity.
if not visibility_ids:
defer.returnValue(events)
event_map = yield self.store.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.itervalues()
)
if all_open:
defer.returnValue(events)
# Ok, so we're dealing with events that have non-trivial visibility
# rules, so we need to also get the memberships of the room.
event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
# We only want to pull out member events that correspond to the
# server's domain.
def check_match(id):
try:
return server_name == get_domain_from_id(id)
except Exception:
return False
# Parses mapping `event_id -> (type, state_key) -> state event_id`
# to get all state ids that we're interested in.
event_map = yield self.store.get_events([
e_id
for key_to_eid in list(event_to_state_ids.values())
for key, e_id in key_to_eid.items()
if key[0] != EventTypes.Member or check_match(key[1])
])
event_to_state = {
e_id: {
key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.iteritems()
if inner_e_id in event_map
}
for e_id, key_to_eid in event_to_state_ids.iteritems()
}
erased_senders = yield self.store.are_users_erased(
e.sender for e in events,
)
def redact_disallowed(event, state):
# if the sender has been gdpr17ed, always return a redacted
# copy of the event.
if erased_senders[event.sender]:
logger.info(
"Sender of %s has been erased, redacting",
event.event_id,
)
return prune_event(event)
if not state:
return event
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
for ev in state.itervalues():
if ev.type != EventTypes.Member:
continue
try:
domain = get_domain_from_id(ev.state_key)
except Exception:
continue
if domain != server_name:
continue
memtype = ev.membership
if memtype == Membership.JOIN:
return event
elif memtype == Membership.INVITE:
if visibility == "invited":
return event
else:
return prune_event(event)
return event
defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
for e in events
])
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities): def backfill(self, dest, room_id, limit, extremities):
@ -1558,7 +1426,7 @@ class FederationHandler(BaseHandler):
limit limit
) )
events = yield self._filter_events_for_server(origin, room_id, events) events = yield filter_events_for_server(self.store, origin, events)
defer.returnValue(events) defer.returnValue(events)
@ -1605,8 +1473,8 @@ class FederationHandler(BaseHandler):
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
events = yield self._filter_events_for_server( events = yield filter_events_for_server(
origin, event.room_id, [event] self.store, origin, [event],
) )
event = events[0] event = events[0]
defer.returnValue(event) defer.returnValue(event)
@ -1896,8 +1764,8 @@ class FederationHandler(BaseHandler):
min_depth=min_depth, min_depth=min_depth,
) )
missing_events = yield self._filter_events_for_server( missing_events = yield filter_events_for_server(
origin, room_id, missing_events, self.store, origin, missing_events,
) )
defer.returnValue(missing_events) defer.returnValue(missing_events)

View file

@ -16,10 +16,13 @@ import itertools
import logging import logging
import operator import operator
import six
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.types import get_domain_from_id
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -225,3 +228,141 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
# we turn it into a list before returning it. # we turn it into a list before returning it.
defer.returnValue(list(filtered_events)) defer.returnValue(list(filtered_events))
@defer.inlineCallbacks
def filter_events_for_server(store, server_name, events):
# First lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
)
)
visibility_ids = set()
for sids in event_to_state_ids.itervalues():
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist:
visibility_ids.add(hist)
# If we failed to find any history visibility events then the default
# is "shared" visiblity.
if not visibility_ids:
defer.returnValue(events)
event_map = yield store.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.itervalues()
)
if all_open:
defer.returnValue(events)
# Ok, so we're dealing with events that have non-trivial visibility
# rules, so we need to also get the memberships of the room.
# first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events.
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
# We only want to pull out member events that correspond to the
# server's domain.
#
# event_to_state_ids contains lots of duplicates, so it turns out to be
# cheaper to build a complete set of unique
# ((type, state_key), event_id) tuples, and then filter out the ones we
# don't want.
#
state_key_to_event_id_set = {
e
for key_to_eid in six.itervalues(event_to_state_ids)
for e in key_to_eid.items()
}
def include(typ, state_key):
if typ != EventTypes.Member:
return True
# we avoid using get_domain_from_id here for efficiency.
idx = state_key.find(":")
if idx == -1:
return False
return state_key[idx + 1:] == server_name
event_map = yield store.get_events([
e_id
for key, e_id in state_key_to_event_id_set
if include(key[0], key[1])
])
event_to_state = {
e_id: {
key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.iteritems()
if inner_e_id in event_map
}
for e_id, key_to_eid in event_to_state_ids.iteritems()
}
erased_senders = yield store.are_users_erased(
e.sender for e in events,
)
def redact_disallowed(event, state):
# if the sender has been gdpr17ed, always return a redacted
# copy of the event.
if erased_senders[event.sender]:
logger.info(
"Sender of %s has been erased, redacting",
event.event_id,
)
return prune_event(event)
if not state:
return event
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
for ev in state.itervalues():
if ev.type != EventTypes.Member:
continue
try:
domain = get_domain_from_id(ev.state_key)
except Exception:
continue
if domain != server_name:
continue
memtype = ev.membership
if memtype == Membership.JOIN:
return event
elif memtype == Membership.INVITE:
if visibility == "invited":
return event
else:
# server has no users in the room: redact
return prune_event(event)
return event
defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
for e in events
])

261
tests/test_visibility.py Normal file
View file

@ -0,0 +1,261 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from twisted.internet.defer import succeed
from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server
import tests.unittest
from tests.utils import setup_test_homeserver
logger = logging.getLogger(__name__)
TEST_ROOM_ID = "!TEST:ROOM"
class FilterEventsForServerTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_filtering(self):
#
# The events to be filtered consist of 10 membership events (it doesn't
# really matter if they are joins or leaves, so let's make them joins).
# One of those membership events is going to be for a user on the
# server we are filtering for (so we can check the filtering is doing
# the right thing).
#
# before we do that, we persist some other events to act as state.
self.inject_visibility("@admin:hs", "joined")
for i in range(0, 10):
yield self.inject_room_member("@resident%i:hs" % i)
events_to_filter = []
for i in range(0, 10):
user = "@user%i:%s" % (
i, "test_server" if i == 5 else "other_server"
)
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
filtered = yield filter_events_for_server(
self.store, "test_server", events_to_filter,
)
# the result should be 5 redacted events, and 5 unredacted events.
for i in range(0, 5):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertNotIn("a", filtered[i].content)
for i in range(5, 10):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["a"], "b")
@defer.inlineCallbacks
def inject_visibility(self, user_id, visibility):
content = {"history_visibility": visibility}
builder = self.event_builder_factory.new({
"type": "m.room.history_visibility",
"sender": user_id,
"state_key": "",
"room_id": TEST_ROOM_ID,
"content": content,
})
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
yield self.hs.get_datastore().persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_room_member(self, user_id, membership="join", extra_content={}):
content = {"membership": membership}
content.update(extra_content)
builder = self.event_builder_factory.new({
"type": "m.room.member",
"sender": user_id,
"state_key": user_id,
"room_id": TEST_ROOM_ID,
"content": content,
})
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
yield self.hs.get_datastore().persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def test_large_room(self):
# see what happens when we have a large room with hundreds of thousands
# of membership events
# As above, the events to be filtered consist of 10 membership events,
# where one of them is for a user on the server we are filtering for.
import cProfile
import pstats
import time
# we stub out the store, because building up all that state the normal
# way is very slow.
test_store = _TestStore()
# our initial state is 100000 membership events and one
# history_visibility event.
room_state = []
history_visibility_evt = FrozenEvent({
"event_id": "$history_vis",
"type": "m.room.history_visibility",
"sender": "@resident_user_0:test.com",
"state_key": "",
"room_id": TEST_ROOM_ID,
"content": {"history_visibility": "joined"},
})
room_state.append(history_visibility_evt)
test_store.add_event(history_visibility_evt)
for i in range(0, 100000):
user = "@resident_user_%i:test.com" % (i, )
evt = FrozenEvent({
"event_id": "$res_event_%i" % (i, ),
"type": "m.room.member",
"state_key": user,
"sender": user,
"room_id": TEST_ROOM_ID,
"content": {
"membership": "join",
"extra": "zzz,"
},
})
room_state.append(evt)
test_store.add_event(evt)
events_to_filter = []
for i in range(0, 10):
user = "@user%i:%s" % (
i, "test_server" if i == 5 else "other_server"
)
evt = FrozenEvent({
"event_id": "$evt%i" % (i, ),
"type": "m.room.member",
"state_key": user,
"sender": user,
"room_id": TEST_ROOM_ID,
"content": {
"membership": "join",
"extra": "zzz",
},
})
events_to_filter.append(evt)
room_state.append(evt)
test_store.add_event(evt)
test_store.set_state_ids_for_event(evt, {
(e.type, e.state_key): e.event_id for e in room_state
})
pr = cProfile.Profile()
pr.enable()
logger.info("Starting filtering")
start = time.time()
filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter,
)
logger.info("Filtering took %f seconds", time.time() - start)
pr.disable()
with open("filter_events_for_server.profile", "w+") as f:
ps = pstats.Stats(pr, stream=f).sort_stats('cumulative')
ps.print_stats()
# the result should be 5 redacted events, and 5 unredacted events.
for i in range(0, 5):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertNotIn("extra", filtered[i].content)
for i in range(5, 10):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["extra"], "zzz")
test_large_room.skip = "Disabled by default because it's slow"
class _TestStore(object):
"""Implements a few methods of the DataStore, so that we can test
filter_events_for_server
"""
def __init__(self):
# data for get_events: a map from event_id to event
self.events = {}
# data for get_state_ids_for_events mock: a map from event_id to
# a map from (type_state_key) -> event_id for the state at that
# event
self.state_ids_for_events = {}
def add_event(self, event):
self.events[event.event_id] = event
def set_state_ids_for_event(self, event, state):
self.state_ids_for_events[event.event_id] = state
def get_state_ids_for_events(self, events, types):
res = {}
include_memberships = False
for (type, state_key) in types:
if type == "m.room.history_visibility":
continue
if type != "m.room.member" or state_key is not None:
raise RuntimeError(
"Unimplemented: get_state_ids with type (%s, %s)" %
(type, state_key),
)
include_memberships = True
if include_memberships:
for event_id in events:
res[event_id] = self.state_ids_for_events[event_id]
else:
k = ("m.room.history_visibility", "")
for event_id in events:
hve = self.state_ids_for_events[event_id][k]
res[event_id] = {k: hve}
return succeed(res)
def get_events(self, events):
return succeed({
event_id: self.events[event_id] for event_id in events
})
def are_users_erased(self, users):
return succeed({u: False for u in users})