0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 15:01:23 +01:00

Add is_encrypted filtering to Sliding Sync /sync (#17281)

Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync
This commit is contained in:
Eric Eastwood 2024-06-17 12:06:18 -05:00 committed by GitHub
parent e5b8a3e37f
commit a5485437cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 189 additions and 100 deletions

View file

@ -0,0 +1 @@
Add `is_encrypted` filtering to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.

View file

@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from immutabledict import immutabledict from immutabledict import immutabledict
from synapse.api.constants import AccountDataTypes, Membership from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.types import ( from synapse.types import (
@ -33,6 +33,7 @@ from synapse.types import (
UserID, UserID,
) )
from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult
from synapse.types.state import StateFilter
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -85,6 +86,7 @@ class SlidingSyncHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
self.auth_blocking = hs.get_auth_blocking() self.auth_blocking = hs.get_auth_blocking()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
@ -570,8 +572,26 @@ class SlidingSyncHandler:
if filters.spaces: if filters.spaces:
raise NotImplementedError() raise NotImplementedError()
if filters.is_encrypted: # Filter for encrypted rooms
raise NotImplementedError() if filters.is_encrypted is not None:
# Make a copy so we don't run into an error: `Set changed size during
# iteration`, when we filter out and remove items
for room_id in list(filtered_room_id_set):
state_at_to_token = await self.storage_controllers.state.get_state_at(
room_id,
to_token,
state_filter=StateFilter.from_types(
[(EventTypes.RoomEncryption, "")]
),
)
is_encrypted = state_at_to_token.get((EventTypes.RoomEncryption, ""))
# If we're looking for encrypted rooms, filter out rooms that are not
# encrypted and vice versa
if (filters.is_encrypted and not is_encrypted) or (
not filters.is_encrypted and is_encrypted
):
filtered_room_id_set.remove(room_id)
if filters.is_invite: if filters.is_invite:
raise NotImplementedError() raise NotImplementedError()

View file

@ -979,91 +979,6 @@ class SyncHandler:
bundled_aggregations=bundled_aggregations, bundled_aggregations=bundled_aggregations,
) )
async def get_state_after_event(
self,
event_id: str,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the room state after the given event
Args:
event_id: event of interest
state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the event and `state_filter` is not satisfied by partial state.
Defaults to `True`.
"""
state_ids = await self._state_storage_controller.get_state_ids_for_event(
event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
)
# using get_metadata_for_events here (instead of get_event) sidesteps an issue
# with redactions: if `event_id` is a redaction event, and we don't have the
# original (possibly because it got purged), get_event will refuse to return
# the redaction event, which isn't terribly helpful here.
#
# (To be fair, in that case we could assume it's *not* a state event, and
# therefore we don't need to worry about it. But still, it seems cleaner just
# to pull the metadata.)
m = (await self.store.get_metadata_for_events([event_id]))[event_id]
if m.state_key is not None and m.rejection_reason is None:
state_ids = dict(state_ids)
state_ids[(m.event_type, m.state_key)] = event_id
return state_ids
async def get_state_at(
self,
room_id: str,
stream_position: StreamToken,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
"""Get the room state at a particular stream position
Args:
room_id: room for which to get state
stream_position: point at which to get state
state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the last event in the room before `stream_position` and
`state_filter` is not satisfied by partial state. Defaults to `True`.
"""
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time.
last_event_id = (
await self.store.get_last_event_id_in_room_before_stream_ordering(
room_id,
end_token=stream_position.room_key,
)
)
if last_event_id:
state = await self.get_state_after_event(
last_event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
)
else:
# no events in this room - so presumably no state
state = {}
# (erikj) This should be rarely hit, but we've had some reports that
# we get more state down gappy syncs than we should, so let's add
# some logging.
logger.info(
"Failed to find any events in room %s at %s",
room_id,
stream_position.room_key,
)
return state
async def compute_summary( async def compute_summary(
self, self,
room_id: str, room_id: str,
@ -1437,7 +1352,7 @@ class SyncHandler:
await_full_state = True await_full_state = True
lazy_load_members = False lazy_load_members = False
state_at_timeline_end = await self.get_state_at( state_at_timeline_end = await self._state_storage_controller.get_state_at(
room_id, room_id,
stream_position=end_token, stream_position=end_token,
state_filter=state_filter, state_filter=state_filter,
@ -1565,7 +1480,7 @@ class SyncHandler:
else: else:
# We can get here if the user has ignored the senders of all # We can get here if the user has ignored the senders of all
# the recent events. # the recent events.
state_at_timeline_start = await self.get_state_at( state_at_timeline_start = await self._state_storage_controller.get_state_at(
room_id, room_id,
stream_position=end_token, stream_position=end_token,
state_filter=state_filter, state_filter=state_filter,
@ -1587,14 +1502,14 @@ class SyncHandler:
# about them). # about them).
state_filter = StateFilter.all() state_filter = StateFilter.all()
state_at_previous_sync = await self.get_state_at( state_at_previous_sync = await self._state_storage_controller.get_state_at(
room_id, room_id,
stream_position=since_token, stream_position=since_token,
state_filter=state_filter, state_filter=state_filter,
await_full_state=await_full_state, await_full_state=await_full_state,
) )
state_at_timeline_end = await self.get_state_at( state_at_timeline_end = await self._state_storage_controller.get_state_at(
room_id, room_id,
stream_position=end_token, stream_position=end_token,
state_filter=state_filter, state_filter=state_filter,
@ -2593,7 +2508,7 @@ class SyncHandler:
continue continue
if room_id in sync_result_builder.joined_room_ids or has_join: if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = await self.get_state_at( old_state_ids = await self._state_storage_controller.get_state_at(
room_id, room_id,
since_token, since_token,
state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]), state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
@ -2623,12 +2538,14 @@ class SyncHandler:
newly_left_rooms.append(room_id) newly_left_rooms.append(room_id)
else: else:
if not old_state_ids: if not old_state_ids:
old_state_ids = await self.get_state_at( old_state_ids = (
room_id, await self._state_storage_controller.get_state_at(
since_token, room_id,
state_filter=StateFilter.from_types( since_token,
[(EventTypes.Member, user_id)] state_filter=StateFilter.from_types(
), [(EventTypes.Member, user_id)]
),
)
) )
old_mem_ev_id = old_state_ids.get( old_mem_ev_id = old_state_ids.get(
(EventTypes.Member, user_id), None (EventTypes.Member, user_id), None

View file

@ -45,7 +45,7 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialStateEventsTracker, PartialStateEventsTracker,
) )
from synapse.synapse_rust.acl import ServerAclEvaluator from synapse.synapse_rust.acl import ServerAclEvaluator
from synapse.types import MutableStateMap, StateMap, get_domain_from_id from synapse.types import MutableStateMap, StateMap, StreamToken, get_domain_from_id
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
@ -372,6 +372,91 @@ class StateStorageController:
) )
return state_map[event_id] return state_map[event_id]
async def get_state_after_event(
self,
event_id: str,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the room state after the given event
Args:
event_id: event of interest
state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the event and `state_filter` is not satisfied by partial state.
Defaults to `True`.
"""
state_ids = await self.get_state_ids_for_event(
event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
)
# using get_metadata_for_events here (instead of get_event) sidesteps an issue
# with redactions: if `event_id` is a redaction event, and we don't have the
# original (possibly because it got purged), get_event will refuse to return
# the redaction event, which isn't terribly helpful here.
#
# (To be fair, in that case we could assume it's *not* a state event, and
# therefore we don't need to worry about it. But still, it seems cleaner just
# to pull the metadata.)
m = (await self.stores.main.get_metadata_for_events([event_id]))[event_id]
if m.state_key is not None and m.rejection_reason is None:
state_ids = dict(state_ids)
state_ids[(m.event_type, m.state_key)] = event_id
return state_ids
async def get_state_at(
self,
room_id: str,
stream_position: StreamToken,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
"""Get the room state at a particular stream position
Args:
room_id: room for which to get state
stream_position: point at which to get state
state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the last event in the room before `stream_position` and
`state_filter` is not satisfied by partial state. Defaults to `True`.
"""
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time.
last_event_id = (
await self.stores.main.get_last_event_id_in_room_before_stream_ordering(
room_id,
end_token=stream_position.room_key,
)
)
if last_event_id:
state = await self.get_state_after_event(
last_event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
)
else:
# no events in this room - so presumably no state
state = {}
# (erikj) This should be rarely hit, but we've had some reports that
# we get more state down gappy syncs than we should, so let's add
# some logging.
logger.info(
"Failed to find any events in room %s at %s",
room_id,
stream_position.room_key,
)
return state
@trace @trace
@tag_args @tag_args
async def get_state_for_groups( async def get_state_for_groups(

View file

@ -1253,6 +1253,72 @@ class FilterRoomsTestCase(HomeserverTestCase):
self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
def test_filter_encrypted_rooms(self) -> None:
"""
Test `filter.is_encrypted` for encrypted rooms
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create a normal room
room_id = self.helper.create_room_as(
user1_id,
is_public=False,
tok=user1_tok,
)
# Create an encrypted room
encrypted_room_id = self.helper.create_room_as(
user1_id,
is_public=False,
tok=user1_tok,
)
self.helper.send_state(
encrypted_room_id,
EventTypes.RoomEncryption,
{"algorithm": "m.megolm.v1.aes-sha2"},
tok=user1_tok,
)
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
sync_room_map = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=None,
to_token=after_rooms_token,
)
)
# Try with `is_encrypted=True`
truthy_filtered_room_map = self.get_success(
self.sliding_sync_handler.filter_rooms(
UserID.from_string(user1_id),
sync_room_map,
SlidingSyncConfig.SlidingSyncList.Filters(
is_encrypted=True,
),
after_rooms_token,
)
)
self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
# Try with `is_encrypted=False`
falsy_filtered_room_map = self.get_success(
self.sliding_sync_handler.filter_rooms(
UserID.from_string(user1_id),
sync_room_map,
SlidingSyncConfig.SlidingSyncList.Filters(
is_encrypted=False,
),
after_rooms_token,
)
)
self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
class SortRoomsTestCase(HomeserverTestCase): class SortRoomsTestCase(HomeserverTestCase):
""" """