Fix up some typechecking (#6150)

* type checking fixes

* changelog
This commit is contained in:
Amber Brown 2019-10-02 05:29:01 -07:00 committed by GitHub
parent 2a1470cd05
commit 864f144543
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 104 additions and 40 deletions

1
.gitignore vendored
View file

@ -10,6 +10,7 @@
*.tac
_trial_temp/
_trial_temp*/
/out
# stuff that is likely to exist when you run a server locally
/*.db

1
changelog.d/6150.misc Normal file
View file

@ -0,0 +1 @@
Expand type-checking on modules imported by synapse.config.

View file

@ -17,6 +17,7 @@
"""Contains exceptions and error codes."""
import logging
from typing import Dict
from six import iteritems
from six.moves import http_client
@ -111,7 +112,7 @@ class ProxiedRequestError(SynapseError):
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None:
self._additional_fields = {}
self._additional_fields = {} # type: Dict
else:
self._additional_fields = dict(additional_fields)

View file

@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import attr
@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V4,
RoomVersions.V5,
)
} # type: dict[str, RoomVersion]
} # type: Dict[str, RoomVersion]

View file

@ -263,7 +263,9 @@ def start(hs, listeners=None):
refresh_certificate(hs)
# Start the tracer
synapse.logging.opentracing.init_tracer(hs.config)
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs.config
)
# It is now safe to start your Synapse.
hs.start_listening(listeners)

View file

@ -13,6 +13,7 @@
# limitations under the License.
import logging
from typing import Dict
from six import string_types
from six.moves.urllib import parse as urlparse
@ -56,8 +57,8 @@ def load_appservices(hostname, config_files):
return []
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
seen_as_tokens = {} # type: Dict[str, str]
seen_ids = {} # type: Dict[str, str]
appservices = []

View file

@ -73,8 +73,8 @@ DEFAULT_CONFIG = """\
class ConsentConfig(Config):
def __init__(self):
super(ConsentConfig, self).__init__()
def __init__(self, *args):
super(ConsentConfig, self).__init__(*args)
self.user_consent_version = None
self.user_consent_template_dir = None

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from synapse.util.module_loader import load_module
from ._base import Config
@ -22,7 +24,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
def read_config(self, config, **kwargs):
self.password_providers = []
self.password_providers = [] # type: List[Any]
providers = []
# We want to be backwards compatible with the old `ldap_config`

View file

@ -15,6 +15,7 @@
import os
from collections import namedtuple
from typing import Dict, List
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module
@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
requirements = {}
requirements = {} # type: Dict[str, List]
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
@ -130,7 +131,7 @@ class ContentRepositoryConfig(Config):
#
# We don't create the storage providers here as not all workers need
# them to be started.
self.media_storage_providers = []
self.media_storage_providers = [] # type: List[tuple]
for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to

View file

@ -19,6 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
from typing import List
import attr
import yaml
@ -243,7 +244,7 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
self.listeners = []
self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
raise ConfigError(
@ -287,7 +288,10 @@ class ServerConfig(Config):
validator=attr.validators.instance_of(bool), default=False
)
complexity = attr.ib(
validator=attr.validators.instance_of((int, float)), default=1.0
validator=attr.validators.instance_of(
(float, int) # type: ignore[arg-type] # noqa
),
default=1.0,
)
complexity_error = attr.ib(
validator=attr.validators.instance_of(str),
@ -366,7 +370,7 @@ class ServerConfig(Config):
"cleanup_extremities_with_dummy_events", True
)
def has_tls_listener(self):
def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners)
def generate_config_section(

View file

@ -59,8 +59,8 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled.
"""
def __init__(self):
super(ServerNoticesConfig, self).__init__()
def __init__(self, *args):
super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None
self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None

View file

@ -170,6 +170,7 @@ import inspect
import logging
import re
from functools import wraps
from typing import Dict
from canonicaljson import json
@ -547,7 +548,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return
span = opentracing.tracer.active_span
carrier = {}
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@ -584,7 +585,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span
carrier = {}
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@ -639,7 +640,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination):
return {}
carrier = {}
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
@ -653,7 +654,7 @@ def active_span_context_as_string():
Returns:
The active span context encoded as a string.
"""
carrier = {}
carrier = {} # type: Dict[str, str]
if opentracing:
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier

View file

@ -119,7 +119,11 @@ def trace_function(f):
logger = logging.getLogger(name)
level = logging.DEBUG
s = inspect.currentframe().f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back
to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s"
@ -144,7 +148,7 @@ def trace_function(f):
pathname=pathname,
lineno=lineno,
msg=msg,
args=None,
args=tuple(),
exc_info=None,
)
@ -157,7 +161,12 @@ def trace_function(f):
def get_previous_frames():
s = inspect.currentframe().f_back.f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
@ -174,7 +183,10 @@ def get_previous_frames():
def get_previous_frame(ignore=[]):
s = inspect.currentframe().f_back.f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):

View file

@ -125,7 +125,7 @@ class InFlightGauge(object):
)
# Counts number of in flight blocks for a given set of label values
self._registrations = {}
self._registrations = {} # type: Dict
# Protects access to _registrations
self._lock = threading.Lock()
@ -226,7 +226,7 @@ class BucketCollector(object):
# Fetch the data -- this must be synchronous!
data = self.data_collector()
buckets = {}
buckets = {} # type: Dict[float, int]
res = []
for x in data.keys():

View file

@ -36,9 +36,9 @@ from twisted.web.resource import Resource
try:
from prometheus_client.samples import Sample
except ImportError:
Sample = namedtuple(
Sample = namedtuple( # type: ignore[no-redef] # noqa
"Sample", ["name", "labels", "value", "timestamp", "exemplar"]
) # type: ignore
)
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")

View file

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import Set
from typing import List, Set
from pkg_resources import (
DistributionNotFound,
@ -73,6 +73,7 @@ REQUIREMENTS = [
"netaddr>=0.7.18",
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
]
CONDITIONAL_REQUIREMENTS = {
@ -144,7 +145,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version)
% (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
)
except DistributionNotFound:
deps_needed.append(dependency)
@ -159,7 +164,7 @@ def check_requirements(for_feature=None):
if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be
# installed.
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), [])
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
for dependency in OPTS:
try:
@ -168,7 +173,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed optional %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version)
% (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
)
except DistributionNotFound:
# If it's not found, we don't care

View file

@ -318,6 +318,7 @@ class StreamToken(
)
):
_SEPARATOR = "_"
START = None # type: StreamToken
@classmethod
def from_string(cls, string):
@ -402,7 +403,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
__slots__ = [] # type: list
@classmethod
def parse(cls, string):

View file

@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
from six.moves import range
@ -213,7 +215,9 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
self.key_to_defer = {}
self.key_to_defer = (
{}
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@ -340,10 +344,10 @@ class ReadWriteLock(object):
def __init__(self):
# Latest readers queued
self.key_to_current_readers = {}
self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued
self.key_to_current_writer = {}
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):

View file

@ -16,6 +16,7 @@
import logging
import os
from typing import Dict
import six
from six.moves import intern
@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {}
collectors_by_name = {}
collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])

View file

@ -18,10 +18,12 @@ import inspect
import logging
import threading
from collections import namedtuple
from typing import Any, cast
from six import itervalues
from prometheus_client import Gauge
from typing_extensions import Protocol
from twisted.internet import defer
@ -37,6 +39,18 @@ from . import register_cache
logger = logging.getLogger(__name__)
class _CachedFunction(Protocol):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
prefill = None # type: Any
cache = None # type: Any
num_args = None # type: Any
def __name__(self):
...
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
@ -245,7 +259,9 @@ class Cache(object):
class _CacheDescriptorBase(object):
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
def __init__(
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
):
self.orig = orig
if inlineCallbacks:
@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(observer)
wrapped = cast(_CachedFunction, _wrapped)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)

View file

@ -1,3 +1,5 @@
from typing import Dict
from six import itervalues
SENTINEL = object()
@ -12,7 +14,7 @@ class TreeCache(object):
def __init__(self):
self.size = 0
self.root = {}
self.root = {} # type: Dict
def __setitem__(self, key, value):
return self.set(key, value)

View file

@ -54,5 +54,5 @@ def load_python_module(location: str):
if spec is None:
raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
spec.loader.exec_module(mod) # type: ignore
return mod