Enable mypy checking for unreachable code and fix instances. (#8432)

Patrick Cloke 2020-10-01 08:09:18 -04:00 committed by GitHub
parent c1ef579b63
commit 4ff0201e62
17 changed files with 38 additions and 53 deletions

changelog.d/8432.misc (new file)

@@ -0,0 +1 @@
+Check for unreachable code with mypy.


@@ -6,6 +6,7 @@ check_untyped_defs = True
 show_error_codes = True
 show_traceback = True
 mypy_path = stubs
+warn_unreachable = True
 files =
   synapse/api,
   synapse/appservice,
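
As a quick illustration (not part of the commit), the hypothetical helper below shows the kind of code the new warn_unreachable flag rejects; it mirrors the bytes check deleted from the room-creation handler further down.

def room_id_to_str(room_id: str) -> str:
    # With warn_unreachable = True, mypy reports the body of this branch as
    # unreachable: a value annotated as str can never be an instance of bytes.
    if isinstance(room_id, bytes):
        return room_id.decode("utf-8")
    return room_id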


@@ -18,7 +18,7 @@ import os
 import warnings
 from datetime import datetime
 from hashlib import sha256
-from typing import List
+from typing import List, Optional

 from unpaddedbase64 import encode_base64
@@ -177,8 +177,8 @@ class TlsConfig(Config):
             "use_insecure_ssl_client_just_for_testing_do_not_use"
         )
-        self.tls_certificate = None
-        self.tls_private_key = None
+        self.tls_certificate = None  # type: Optional[crypto.X509]
+        self.tls_private_key = None  # type: Optional[crypto.PKey]

     def is_disk_cert_valid(self, allow_self_signed=True):
         """
@@ -226,12 +226,12 @@ class TlsConfig(Config):
         days_remaining = (expires_on - now).days
         return days_remaining

-    def read_certificate_from_disk(self, require_cert_and_key):
+    def read_certificate_from_disk(self, require_cert_and_key: bool):
         """
         Read the certificates and private key from disk.

         Args:
-            require_cert_and_key (bool): set to True to throw an error if the certificate
+            require_cert_and_key: set to True to throw an error if the certificate
                 and key file are not given
         """
         if require_cert_and_key:
@@ -479,13 +479,13 @@ class TlsConfig(Config):
             }
         )

-    def read_tls_certificate(self):
+    def read_tls_certificate(self) -> crypto.X509:
         """Reads the TLS certificate from the configured file, and returns it

         Also checks if it is self-signed, and warns if so

         Returns:
-            OpenSSL.crypto.X509: the certificate
+            The certificate
         """
         cert_path = self.tls_certificate_file
         logger.info("Loading TLS certificate from %s", cert_path)
@@ -504,11 +504,11 @@ class TlsConfig(Config):
         return cert

-    def read_tls_private_key(self):
+    def read_tls_private_key(self) -> crypto.PKey:
         """Reads the TLS private key from the configured file, and returns it

         Returns:
-            OpenSSL.crypto.PKey: the private key
+            The private key
         """
         private_key_path = self.tls_private_key_file
         logger.info("Loading TLS key from %s", private_key_path)


@@ -22,7 +22,6 @@ from typing import (
     Callable,
     Dict,
     List,
-    Match,
     Optional,
     Tuple,
     Union,
@@ -825,14 +824,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
     return False


-def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
+def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
     if not isinstance(acl_entry, str):
         logger.warning(
             "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
         )
         return False

     regex = glob_to_regex(acl_entry)
-    return regex.match(server_name)
+    return bool(regex.match(server_name))


 class FederationHandlerRegistry:
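
Aside (not from the commit text): acl_entry comes from untrusted event content, so the runtime isinstance guard is kept and the parameter is annotated as Any; with the old str annotation, warn_unreachable would flag the guard's body as dead code. Pattern.match also returns Optional[Match] rather than bool, hence the bool(...) wrapper. A minimal sketch of the same pattern, with hypothetical names:

import re
from typing import Any, Pattern

def entry_matches(pattern: Pattern[str], value: Any) -> bool:
    # value is untrusted input, so the runtime type check stays; annotating it
    # as Any rather than str keeps mypy from marking this branch unreachable.
    if not isinstance(value, str):
        return False
    # Pattern.match returns Optional[re.Match], not bool, so coerce it.
    return bool(pattern.match(value))

# e.g. entry_matches(re.compile(r"[a-z.]*example\.com"), "matrix.example.com")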


@@ -383,7 +383,7 @@ class DirectoryHandler(BaseHandler):
         """
         creator = await self.store.get_room_alias_creator(alias.to_string())

-        if creator is not None and creator == user_id:
+        if creator == user_id:
             return True

         # Resolve the alias to the corresponding room.


@@ -962,8 +962,6 @@ class RoomCreationHandler(BaseHandler):
             try:
                 random_string = stringutils.random_string(18)
                 gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
-                if isinstance(gen_room_id, bytes):
-                    gen_room_id = gen_room_id.decode("utf-8")
                 await self.store.store_room(
                     room_id=gen_room_id,
                     room_creator_user_id=creator_id,


@@ -642,7 +642,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
     async def send_membership_event(
         self,
-        requester: Requester,
+        requester: Optional[Requester],
         event: EventBase,
         context: EventContext,
         ratelimit: bool = True,


@@ -87,7 +87,7 @@ class SyncConfig:
 class TimelineBatch:
     prev_batch = attr.ib(type=StreamToken)
     events = attr.ib(type=List[EventBase])
-    limited = attr.ib(bool)
+    limited = attr.ib(type=bool)

     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used


@@ -257,7 +257,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
         if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
             callback_return = await raw_callback_return
         else:
-            callback_return = raw_callback_return
+            callback_return = raw_callback_return  # type: ignore

         return callback_return
@@ -406,7 +406,7 @@ class JsonResource(DirectServeJsonResource):
         if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
             callback_return = await raw_callback_return
         else:
-            callback_return = raw_callback_return
+            callback_return = raw_callback_return  # type: ignore

         return callback_return


@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
 import os.path
 import sys
@@ -89,14 +88,7 @@ class LogContextObserver:
         context = current_context()

         # Copy the context information to the log event.
-        if context is not None:
-            context.copy_to_twisted_log_entry(event)
-        else:
-            # If there's no logging context, not even the root one, we might be
-            # starting up or it might be from non-Synapse code. Log it as if it
-            # came from the root logger.
-            event["request"] = None
-            event["scope"] = None
+        context.copy_to_twisted_log_entry(event)

         self.observer(event)


@@ -16,7 +16,7 @@
 import logging
 import re
-from typing import Any, Dict, List, Pattern, Union
+from typing import Any, Dict, List, Optional, Pattern, Union

 from synapse.events import EventBase
 from synapse.types import UserID
@@ -181,7 +181,7 @@ class PushRuleEvaluatorForEvent:
         return r.search(body)

-    def _get_value(self, dotted_key: str) -> str:
+    def _get_value(self, dotted_key: str) -> Optional[str]:
         return self._value_cache.get(dotted_key, None)


@@ -51,10 +51,11 @@ import fcntl
 import logging
 import struct
 from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

 from prometheus_client import Counter

+from twisted.internet import task
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
-        self.time_we_closed = None  # When we requested the connection be closed
-        self.received_ping = False  # Have we reecived a ping from the other side
+        # When we requested the connection be closed
+        self.time_we_closed = None  # type: Optional[int]
+        self.received_ping = False  # Have we received a ping from the other side

         self.state = ConnectionStates.CONNECTING
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.pending_commands = []  # type: List[Command]

         # The LoopingCall for sending pings.
-        self._send_ping_loop = None
+        self._send_ping_loop = None  # type: Optional[task.LoopingCall]

         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.


@@ -738,7 +738,7 @@ def _make_state_cache_entry(
     # failing that, look for the closest match.
     prev_group = None
-    delta_ids = None
+    delta_ids = None  # type: Optional[StateMap[str]]

     for old_group, old_state in state_groups_ids.items():
         n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}


@@ -21,8 +21,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events import encode_json
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.frozenutils import frozendict_json_encoder

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             and original_event.internal_metadata.is_redacted()
         ):
             # Redaction was allowed
-            pruned_json = encode_json(
+            pruned_json = frozendict_json_encoder.encode(
                 prune_event_dict(
                     original_event.room_version, original_event.get_dict()
                 )
@@ -171,7 +171,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             return

         # Prune the event's dict then convert it to JSON.
-        pruned_json = encode_json(
+        pruned_json = frozendict_json_encoder.encode(
             prune_event_dict(event.room_version, event.get_dict())
         )


@@ -52,16 +52,6 @@ event_counter = Counter(
 )


-def encode_json(json_object):
-    """
-    Encode a Python object as JSON and return it in a Unicode string.
-    """
-    out = frozendict_json_encoder.encode(json_object)
-    if isinstance(out, bytes):
-        out = out.decode("utf8")
-    return out
-
-
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@@ -743,7 +733,9 @@ class PersistEventsStore:
                 logger.exception("")
                 raise

-            metadata_json = encode_json(event.internal_metadata.get_dict())
+            metadata_json = frozendict_json_encoder.encode(
+                event.internal_metadata.get_dict()
+            )

             sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
             txn.execute(sql, (metadata_json, event.event_id))
@@ -797,10 +789,10 @@ class PersistEventsStore:
                 {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
-                    "internal_metadata": encode_json(
+                    "internal_metadata": frozendict_json_encoder.encode(
                         event.internal_metadata.get_dict()
                     ),
-                    "json": encode_json(event_dict(event)),
+                    "json": frozendict_json_encoder.encode(event_dict(event)),
                     "format_version": event.format_version,
                 }
                 for event, _ in events_and_contexts


@@ -546,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_room_event_before_stream_ordering(
         self, room_id: str, stream_ordering: int
-    ) -> Tuple[int, int, str]:
+    ) -> Optional[Tuple[int, int, str]]:
         """Gets details of the first event in a room at or before a stream ordering

         Args:


@@ -421,7 +421,7 @@ class MultiWriterIdGenerator:
             self._unfinished_ids.discard(next_id)
             self._finished_ids.add(next_id)

-            new_cur = None
+            new_cur = None  # type: Optional[int]

             if self._unfinished_ids:
                 # If there are unfinished IDs then the new position will be the