Fix typing for notifier (#8064)

This commit is contained in:
Erik Johnston 2020-08-12 14:03:08 +01:00 committed by GitHub
parent 6ba621d786
commit 9d1e4942ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 38 additions and 16 deletions

1
changelog.d/8064.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `Notifier`.

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List, Tuple
from canonicaljson import json from canonicaljson import json
@ -54,7 +54,10 @@ class TransactionManager(object):
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
async def send_new_transaction( async def send_new_transaction(
self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] self,
destination: str,
pending_pdus: List[Tuple[EventBase, int]],
pending_edus: List[Edu],
): ):
# Make a transaction-sending opentracing span. This span follows on from # Make a transaction-sending opentracing span. This span follows on from

View file

@ -25,6 +25,7 @@ from typing import (
Set, Set,
Tuple, Tuple,
TypeVar, TypeVar,
Union,
) )
from prometheus_client import Counter from prometheus_client import Counter
@ -186,7 +187,7 @@ class Notifier(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.pending_new_room_events = ( self.pending_new_room_events = (
[] []
) # type: List[Tuple[int, EventBase, Collection[str]]] ) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]]
# Called when there are new things to stream over replication # Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]] self.replication_callbacks = [] # type: List[Callable[[], None]]
@ -246,7 +247,7 @@ class Notifier(object):
event: EventBase, event: EventBase,
room_stream_id: int, room_stream_id: int,
max_room_stream_id: int, max_room_stream_id: int,
extra_users: Collection[str] = [], extra_users: Collection[Union[str, UserID]] = [],
): ):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
in the room, room event wise. in the room, room event wise.
@ -282,7 +283,10 @@ class Notifier(object):
self._on_new_room_event(event, room_stream_id, extra_users) self._on_new_room_event(event, room_stream_id, extra_users)
def _on_new_room_event( def _on_new_room_event(
self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = [] self,
event: EventBase,
room_stream_id: int,
extra_users: Collection[Union[str, UserID]] = [],
): ):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
@ -310,7 +314,7 @@ class Notifier(object):
self, self,
stream_key: str, stream_key: str,
new_token: int, new_token: int,
users: Collection[str] = [], users: Collection[Union[str, UserID]] = [],
rooms: Collection[str] = [], rooms: Collection[str] = [],
): ):
""" Used to inform listeners that something has happened event wise. """ Used to inform listeners that something has happened event wise.

View file

@ -13,11 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import re import re
import string import string
import sys import sys
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, Tuple, TypeVar from typing import Any, Dict, Tuple, Type, TypeVar
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -33,7 +34,7 @@ else:
T_co = TypeVar("T_co", covariant=True) T_co = TypeVar("T_co", covariant=True)
class Collection(Iterable[T_co], Container[T_co], Sized): class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore
__slots__ = () __slots__ = ()
@ -141,6 +142,9 @@ def get_localpart_from_id(string):
return string[1:idx] return string[1:idx]
DS = TypeVar("DS", bound="DomainSpecificString")
class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))): class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))):
"""Common base class among ID/name strings that have a local part and a """Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil. domain name, prefixed with a sigil.
@ -151,6 +155,10 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
'domain' : The domain part of the name 'domain' : The domain part of the name
""" """
__metaclass__ = abc.ABCMeta
SIGIL = abc.abstractproperty() # type: str # type: ignore
# Deny iteration because it will bite you if you try to create a singleton # Deny iteration because it will bite you if you try to create a singleton
# set by: # set by:
# users = set(user) # users = set(user)
@ -166,7 +174,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
return self return self
@classmethod @classmethod
def from_string(cls, s: str): def from_string(cls: Type[DS], s: str) -> DS:
"""Parse the string given by 's' into a structure object.""" """Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0:1] != cls.SIGIL: if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError( raise SynapseError(
@ -190,12 +198,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
# names on one HS # names on one HS
return cls(localpart=parts[0], domain=domain) return cls(localpart=parts[0], domain=domain)
def to_string(self): def to_string(self) -> str:
"""Return a string encoding the fields of the structure object.""" """Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
@classmethod @classmethod
def is_valid(cls, s): def is_valid(cls: Type[DS], s: str) -> bool:
try: try:
cls.from_string(s) cls.from_string(s)
return True return True
@ -235,8 +243,9 @@ class GroupID(DomainSpecificString):
SIGIL = "+" SIGIL = "+"
@classmethod @classmethod
def from_string(cls, s): def from_string(cls: Type[DS], s: str) -> DS:
group_id = super(GroupID, cls).from_string(s) group_id = super().from_string(s) # type: DS # type: ignore
if not group_id.localpart: if not group_id.localpart:
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)

View file

@ -15,6 +15,7 @@
import logging import logging
from functools import wraps from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast
from prometheus_client import Counter from prometheus_client import Counter
@ -57,8 +58,10 @@ in_flight = InFlightGauge(
sub_metrics=["real_time_max", "real_time_sum"], sub_metrics=["real_time_max", "real_time_sum"],
) )
T = TypeVar("T", bound=Callable[..., Any])
def measure_func(name=None):
def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
""" """
Used to decorate an async function with a `Measure` context manager. Used to decorate an async function with a `Measure` context manager.
@ -76,7 +79,7 @@ def measure_func(name=None):
""" """
def wrapper(func): def wrapper(func: T) -> T:
block_name = func.__name__ if name is None else name block_name = func.__name__ if name is None else name
@wraps(func) @wraps(func)
@ -85,7 +88,7 @@ def measure_func(name=None):
r = await func(self, *args, **kwargs) r = await func(self, *args, **kwargs)
return r return r
return measured_func return cast(T, measured_func)
return wrapper return wrapper

View file

@ -212,7 +212,9 @@ commands = mypy \
synapse/storage/state.py \ synapse/storage/state.py \
synapse/storage/util \ synapse/storage/util \
synapse/streams \ synapse/streams \
synapse/types.py \
synapse/util/caches/stream_change_cache.py \ synapse/util/caches/stream_change_cache.py \
synapse/util/metrics.py \
tests/replication \ tests/replication \
tests/test_utils \ tests/test_utils \
tests/rest/client/v2_alpha/test_auth.py \ tests/rest/client/v2_alpha/test_auth.py \