0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-02 18:59:04 +02:00

Add type hints to admin and room list handlers. (#8973)

This commit is contained in:
Patrick Cloke 2020-12-29 17:42:10 -05:00 committed by GitHub
parent 14a7371375
commit 9999eb2d02
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 100 additions and 70 deletions

1
changelog.d/8973.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to the admin and room list handlers.

View file

@ -25,6 +25,7 @@ files =
synapse/handlers/_base.py, synapse/handlers/_base.py,
synapse/handlers/account_data.py, synapse/handlers/account_data.py,
synapse/handlers/account_validity.py, synapse/handlers/account_validity.py,
synapse/handlers/admin.py,
synapse/handlers/appservice.py, synapse/handlers/appservice.py,
synapse/handlers/auth.py, synapse/handlers/auth.py,
synapse/handlers/cas_handler.py, synapse/handlers/cas_handler.py,
@ -45,6 +46,7 @@ files =
synapse/handlers/read_marker.py, synapse/handlers/read_marker.py,
synapse/handlers/register.py, synapse/handlers/register.py,
synapse/handlers/room.py, synapse/handlers/room.py,
synapse/handlers/room_list.py,
synapse/handlers/room_member.py, synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py, synapse/handlers/room_member_worker.py,
synapse/handlers/saml_handler.py, synapse/handlers/saml_handler.py,
@ -114,6 +116,9 @@ ignore_missing_imports = True
[mypy-h11] [mypy-h11]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-msgpack]
ignore_missing_imports = True
[mypy-opentracing] [mypy-opentracing]
ignore_missing_imports = True ignore_missing_imports = True

View file

@ -13,27 +13,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
from typing import List from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.events import FrozenEvent from synapse.events import EventBase
from synapse.types import RoomStreamToken, StateMap from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler): class AdminHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
async def get_whois(self, user): async def get_whois(self, user: UserID) -> JsonDict:
connections = [] connections = []
sessions = await self.store.get_user_ip_and_agents(user) sessions = await self.store.get_user_ip_and_agents(user)
@ -53,7 +57,7 @@ class AdminHandler(BaseHandler):
return ret return ret
async def get_user(self, user): async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details""" """Function to get user details"""
ret = await self.store.get_user_by_id(user.to_string()) ret = await self.store.get_user_by_id(user.to_string())
if ret: if ret:
@ -64,12 +68,12 @@ class AdminHandler(BaseHandler):
ret["threepids"] = threepids ret["threepids"] = threepids
return ret return ret
async def export_user_data(self, user_id, writer): async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
"""Write all data we have on the user to the given writer. """Write all data we have on the user to the given writer.
Args: Args:
user_id (str) user_id: The user ID to fetch data of.
writer (ExfiltrationWriter) writer: The writer to write to.
Returns: Returns:
Resolves when all data for a user has been written. Resolves when all data for a user has been written.
@ -128,7 +132,8 @@ class AdminHandler(BaseHandler):
from_key = RoomStreamToken(0, 0) from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering) to_key = RoomStreamToken(None, stream_ordering)
written_events = set() # Events that we've processed in this room # Events that we've processed in this room
written_events = set() # type: Set[str]
# We need to track gaps in the events stream so that we can then # We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track # write out the state at those events. We do this by keeping track
@ -140,8 +145,8 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events # The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen # that have the unseen event in their prev_events, i.e. the unseen
# events "children". dict[str, set[str]] # events "children".
unseen_to_child_events = {} unseen_to_child_events = {} # type: Dict[str, Set[str]]
# We fetch events in the room the user could see by fetching *all* # We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most # events that we have and then filtering, this isn't the most
@ -197,38 +202,46 @@ class AdminHandler(BaseHandler):
return writer.finished() return writer.finished()
class ExfiltrationWriter: class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data. """Interface used to specify how to write exported data.
""" """
def write_events(self, room_id: str, events: List[FrozenEvent]): @abc.abstractmethod
def write_events(self, room_id: str, events: List[EventBase]) -> None:
"""Write a batch of events for a room. """Write a batch of events for a room.
""" """
pass raise NotImplementedError()
def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]): @abc.abstractmethod
def write_state(
self, room_id: str, event_id: str, state: StateMap[EventBase]
) -> None:
"""Write the state at the given event in the room. """Write the state at the given event in the room.
This only gets called for backward extremities rather than for each This only gets called for backward extremities rather than for each
event. event.
""" """
pass raise NotImplementedError()
def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]): @abc.abstractmethod
def write_invite(
self, room_id: str, event: EventBase, state: StateMap[dict]
) -> None:
"""Write an invite for the room, with associated invite state. """Write an invite for the room, with associated invite state.
Args: Args:
room_id room_id: The room ID the invite is for.
event event: The invite event.
state: A subset of the state at the state: A subset of the state at the invite, with a subset of the
invite, with a subset of the event keys (type, state_key event keys (type, state_key content and sender).
content and sender)
""" """
raise NotImplementedError()
def finished(self): @abc.abstractmethod
def finished(self) -> Any:
"""Called when all data has successfully been exported and written. """Called when all data has successfully been exported and written.
This functions return value is passed to the caller of This functions return value is passed to the caller of
`export_user_data`. `export_user_data`.
""" """
pass raise NotImplementedError()

View file

@ -15,19 +15,22 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, Optional from typing import TYPE_CHECKING, Optional, Tuple
import msgpack import msgpack
from unpaddedbase64 import decode_base64, encode_base64 from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.api.errors import Codes, HttpResponseException from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
@ -37,37 +40,38 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(hs, "room_list") self.response_cache = ResponseCache(
hs, "room_list"
) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache( self.remote_response_cache = ResponseCache(
hs, "remote_room_list", timeout_ms=30 * 1000 hs, "remote_room_list", timeout_ms=30 * 1000
) ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list( async def get_local_public_room_list(
self, self,
limit=None, limit: Optional[int] = None,
since_token=None, since_token: Optional[str] = None,
search_filter=None, search_filter: Optional[dict] = None,
network_tuple=EMPTY_THIRD_PARTY_ID, network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
from_federation=False, from_federation: bool = False,
): ) -> JsonDict:
"""Generate a local public room list. """Generate a local public room list.
There are multiple different lists: the main one plus one per third There are multiple different lists: the main one plus one per third
party network. A client can ask for a specific list or to return all. party network. A client can ask for a specific list or to return all.
Args: Args:
limit (int|None) limit
since_token (str|None) since_token
search_filter (dict|None) search_filter
network_tuple (ThirdPartyInstanceID): Which public list to use. network_tuple: Which public list to use.
This can be (None, None) to indicate the main list, or a particular This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one. appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists. Setting to None returns all public rooms across all lists.
from_federation (bool): true iff the request comes from the federation from_federation: true iff the request comes from the federation API
API
""" """
if not self.enable_room_list_search: if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0} return {"chunk": [], "total_room_count_estimate": 0}
@ -107,10 +111,10 @@ class RoomListHandler(BaseHandler):
self, self,
limit: Optional[int] = None, limit: Optional[int] = None,
since_token: Optional[str] = None, since_token: Optional[str] = None,
search_filter: Optional[Dict] = None, search_filter: Optional[dict] = None,
network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
from_federation: bool = False, from_federation: bool = False,
) -> Dict[str, Any]: ) -> JsonDict:
"""Generate a public room list. """Generate a public room list.
Args: Args:
limit: Maximum amount of rooms to return. limit: Maximum amount of rooms to return.
@ -131,13 +135,17 @@ class RoomListHandler(BaseHandler):
if since_token: if since_token:
batch_token = RoomListNextBatch.from_token(since_token) batch_token = RoomListNextBatch.from_token(since_token)
bounds = (batch_token.last_joined_members, batch_token.last_room_id) bounds = (
batch_token.last_joined_members,
batch_token.last_room_id,
) # type: Optional[Tuple[int, str]]
forwards = batch_token.direction_is_forward forwards = batch_token.direction_is_forward
has_batch_token = True
else: else:
batch_token = None
bounds = None bounds = None
forwards = True forwards = True
has_batch_token = False
# we request one more than wanted to see if there are more pages to come # we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None probing_limit = limit + 1 if limit is not None else None
@ -169,7 +177,7 @@ class RoomListHandler(BaseHandler):
results = [build_room_entry(r) for r in results] results = [build_room_entry(r) for r in results]
response = {} response = {} # type: JsonDict
num_results = len(results) num_results = len(results)
if limit is not None: if limit is not None:
more_to_come = num_results == probing_limit more_to_come = num_results == probing_limit
@ -187,7 +195,7 @@ class RoomListHandler(BaseHandler):
initial_entry = results[0] initial_entry = results[0]
if forwards: if forwards:
if batch_token: if has_batch_token:
# If there was a token given then we assume that there # If there was a token given then we assume that there
# must be previous results. # must be previous results.
response["prev_batch"] = RoomListNextBatch( response["prev_batch"] = RoomListNextBatch(
@ -203,7 +211,7 @@ class RoomListHandler(BaseHandler):
direction_is_forward=True, direction_is_forward=True,
).to_token() ).to_token()
else: else:
if batch_token: if has_batch_token:
response["next_batch"] = RoomListNextBatch( response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"], last_joined_members=final_entry["num_joined_members"],
last_room_id=final_entry["room_id"], last_room_id=final_entry["room_id"],
@ -293,7 +301,7 @@ class RoomListHandler(BaseHandler):
return None return None
# Return whether this room is open to federation users or not # Return whether this room is open to federation users or not
create_event = current_state.get((EventTypes.Create, "")) create_event = current_state[EventTypes.Create, ""]
result["m.federate"] = create_event.content.get("m.federate", True) result["m.federate"] = create_event.content.get("m.federate", True)
name_event = current_state.get((EventTypes.Name, "")) name_event = current_state.get((EventTypes.Name, ""))
@ -336,13 +344,13 @@ class RoomListHandler(BaseHandler):
async def get_remote_public_room_list( async def get_remote_public_room_list(
self, self,
server_name, server_name: str,
limit=None, limit: Optional[int] = None,
since_token=None, since_token: Optional[str] = None,
search_filter=None, search_filter: Optional[dict] = None,
include_all_networks=False, include_all_networks: bool = False,
third_party_instance_id=None, third_party_instance_id: Optional[str] = None,
): ) -> JsonDict:
if not self.enable_room_list_search: if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0} return {"chunk": [], "total_room_count_estimate": 0}
@ -399,13 +407,13 @@ class RoomListHandler(BaseHandler):
async def _get_remote_list_cached( async def _get_remote_list_cached(
self, self,
server_name, server_name: str,
limit=None, limit: Optional[int] = None,
since_token=None, since_token: Optional[str] = None,
search_filter=None, search_filter: Optional[dict] = None,
include_all_networks=False, include_all_networks: bool = False,
third_party_instance_id=None, third_party_instance_id: Optional[str] = None,
): ) -> JsonDict:
repl_layer = self.hs.get_federation_client() repl_layer = self.hs.get_federation_client()
if search_filter: if search_filter:
# We can't cache when asking for search # We can't cache when asking for search
@ -456,24 +464,24 @@ class RoomListNextBatch(
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()} REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod @classmethod
def from_token(cls, token): def from_token(cls, token: str) -> "RoomListNextBatch":
decoded = msgpack.loads(decode_base64(token), raw=False) decoded = msgpack.loads(decode_base64(token), raw=False)
return RoomListNextBatch( return RoomListNextBatch(
**{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
) )
def to_token(self): def to_token(self) -> str:
return encode_base64( return encode_base64(
msgpack.dumps( msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in self._asdict().items()} {self.KEY_DICT[key]: val for key, val in self._asdict().items()}
) )
) )
def copy_and_replace(self, **kwds): def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
return self._replace(**kwds) return self._replace(**kwds)
def _matches_room_entry(room_entry, search_filter): def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
if search_filter and search_filter.get("generic_search_term", None): if search_filter and search_filter.get("generic_search_term", None):
generic_search_term = search_filter["generic_search_term"].upper() generic_search_term = search_filter["generic_search_term"].upper()
if generic_search_term in room_entry.get("name", "").upper(): if generic_search_term in room_entry.get("name", "").upper():

View file

@ -14,11 +14,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Optional, Tuple from typing import Dict, List, Optional, Tuple, Union
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.types import UserID
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -546,7 +547,9 @@ class ClientIpStore(ClientIpWorkerStore):
} }
return ret return ret
async def get_user_ip_and_agents(self, user): async def get_user_ip_and_agents(
self, user: UserID
) -> List[Dict[str, Union[str, int]]]:
user_id = user.to_string() user_id = user.to_string()
results = {} results = {}