Allow filtering events for multiple users at once

This commit is contained in:
Erik Johnston 2016-01-18 10:45:09 +00:00
parent 5de1563997
commit cc66a9a5e3
2 changed files with 67 additions and 39 deletions

View file

@ -53,16 +53,54 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False): def _filter_events_for_clients(self, users, events):
# Assumes that user has at some point joined the room if not is_guest. """ Returns dict of user_id -> list of events that user is allowed to
see.
"""
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
room_id,
)
for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True)
# Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset(
row["event_id"] for rows in forgotten for row in rows
)
def allowed(event, user_id, is_guest):
state = event_id_to_state[event.event_id]
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
def allowed(event, membership, visibility):
if visibility == "world_readable": if visibility == "world_readable":
return True return True
if is_guest: if is_guest:
return False return False
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id in event_id_forgotten:
membership = None
else:
membership = membership_event.membership
else:
membership = None
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return True
@ -78,43 +116,20 @@ class BaseHandler(object):
return True return True
event_id_to_state = yield self.store.get_state_for_events( defer.returnValue({
frozenset(e.event_id for e in events), user_id: [
types=( event
(EventTypes.RoomHistoryVisibility, ""), for event in events
(EventTypes.Member, user_id), if allowed(event, user_id, is_guest)
) ]
) for user_id, is_guest in users
})
events_to_return = [] @defer.inlineCallbacks
for event in events: def _filter_events_for_client(self, user_id, events, is_guest=False):
state = event_id_to_state[event.event_id] # Assumes that user has at some point joined the room if not is_guest.
res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
membership_event = state.get((EventTypes.Member, user_id), None) defer.returnValue(res.get(user_id, []))
if membership_event:
was_forgotten_at_event = yield self.store.was_forgotten_at(
membership_event.state_key,
membership_event.room_id,
membership_event.event_id
)
if was_forgotten_at_event:
membership = None
else:
membership = membership_event.membership
else:
membership = None
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
should_include = allowed(event, membership, visibility)
if should_include:
events_to_return.append(event)
defer.returnValue(events_to_return)
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()

View file

@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f) yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all() self.was_forgotten_at.invalidate_all()
self.who_forgot_in_room.invalidate_all()
self.did_forget.invalidate((user_id, room_id)) self.did_forget.invalidate((user_id, room_id))
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2)
@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
return rows[0][0] return rows[0][0]
forgot = yield self.runInteraction("did_forget_membership_at", f) forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1) defer.returnValue(forgot == 1)
@cached()
def who_forgot_in_room(self, room_id):
return self._simple_select_list(
table="room_memberships",
retcols=("user_id", "event_id"),
keyvalues={
"room_id": room_id,
"forgotten": 1,
},
desc="who_forgot"
)