Make Filter a Generic.

This commit is contained in:
Patrick Cloke 2021-10-26 13:28:32 -04:00
parent c7a5e49664
commit 86c6dc5540
2 changed files with 27 additions and 13 deletions

View file

@ -19,6 +19,7 @@ from typing import (
TYPE_CHECKING,
Awaitable,
Container,
Generic,
Iterable,
List,
Optional,
@ -186,6 +187,7 @@ class Filtering:
# Filters work across events, presence EDUs, and account data.
FilterRoomEvent = TypeVar("FilterRoomEvent", EventBase, JsonDict)
FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
@ -195,16 +197,28 @@ class FilterCollection:
room_filter_json = self._filter_json.get("room", {})
self._room_filter = Filter(
self._room_filter: Filter[FilterRoomEvent] = Filter(
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
)
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(room_filter_json.get("state", {}))
self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
self._room_account_data = Filter(room_filter_json.get("account_data", {}))
self._presence_filter = Filter(filter_json.get("presence", {}))
self._account_data = Filter(filter_json.get("account_data", {}))
self._room_timeline_filter: Filter[EventBase] = Filter(
room_filter_json.get("timeline", {})
)
self._room_state_filter: Filter[EventBase] = Filter(
room_filter_json.get("state", {})
)
self._room_ephemeral_filter: Filter[JsonDict] = Filter(
room_filter_json.get("ephemeral", {})
)
self._room_account_data: Filter[JsonDict] = Filter(
room_filter_json.get("account_data", {})
)
self._presence_filter: Filter[UserPresenceState] = Filter(
filter_json.get("presence", {})
)
self._account_data: Filter[JsonDict] = Filter(
filter_json.get("account_data", {})
)
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", [])
@ -272,7 +286,7 @@ class FilterCollection:
)
class Filter:
class Filter(Generic[FilterEvent]):
def __init__(self, filter_json: JsonDict):
self.filter_json = filter_json
@ -406,7 +420,7 @@ class Filter:
def include_redundant_members(self) -> bool:
return self.filter_json.get("include_redundant_members", False)
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter[FilterEvent]":
"""Returns a new filter with the given room IDs appended.
Args:
@ -416,9 +430,9 @@ class Filter:
filter: A new filter including the given rooms and the old
filter's rooms.
"""
newFilter = Filter(self.filter_json)
newFilter.rooms += room_ids
return newFilter
new_filter: Filter[FilterEvent] = Filter(self.filter_json)
new_filter.rooms += room_ids
return new_filter
def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:

View file

@ -180,7 +180,7 @@ class SearchHandler:
% (set(group_keys) - {"room_id", "sender"},),
)
search_filter = Filter(filter_dict)
search_filter: Filter[EventBase] = Filter(filter_dict)
# TODO: Search through left rooms too
rooms = await self.store.get_rooms_for_local_user_where_membership_is(