Type hints for the remaining two files in synapse.http. (#11164)

* Teach MyPy that the sentinel context is False

This means that if `ctx: LoggingContextOrSentinel`
then `bool(ctx)` narrows us to `ctx:LoggingContext`, which is a really
neat find!

* Annotate RequestMetrics

- Raise errors for sentry if we use the sentinel context
- Ensure we don't raise an error and carry on, but not recording stats
- Include stack trace in the error case to lower Sean's blood pressure

* Make mypy pass for synapse.http.request_metrics

* Make synapse.http.connectproxyclient pass mypy

Co-authored-by: reivilibre <oliverw@matrix.org>
This commit is contained in:
David Robertson 2021-10-28 14:14:42 +01:00 committed by GitHub
parent a19bf32a03
commit 1bfd141205
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 56 additions and 37 deletions

1
changelog.d/11164.misc Normal file
View file

@ -0,0 +1 @@
Add type hints so that `synapse.http` passes `mypy` checks.

View file

@ -16,6 +16,7 @@ no_implicit_optional = True
files = files =
scripts-dev/sign_json, scripts-dev/sign_json,
synapse/__init__.py,
synapse/api, synapse/api,
synapse/appservice, synapse/appservice,
synapse/config, synapse/config,
@ -31,16 +32,7 @@ files =
synapse/federation, synapse/federation,
synapse/groups, synapse/groups,
synapse/handlers, synapse/handlers,
synapse/http/additional_resource.py, synapse/http,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/srv_resolver.py,
synapse/http/federation/well_known_resolver.py,
synapse/http/matrixfederationclient.py,
synapse/http/proxyagent.py,
synapse/http/servlet.py,
synapse/http/server.py,
synapse/http/site.py,
synapse/logging, synapse/logging,
synapse/metrics, synapse/metrics,
synapse/module_api, synapse/module_api,

View file

@ -84,7 +84,11 @@ class HTTPConnectProxyEndpoint:
def __repr__(self): def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
def connect(self, protocolFactory: ClientFactory): # Mypy encounters a false positive here: it complains that ClientFactory
# is incompatible with IProtocolFactory. But ClientFactory inherits from
# Factory, which implements IProtocolFactory. So I think this is a bug
# in mypy-zope.
def connect(self, protocolFactory: ClientFactory): # type: ignore[override]
f = HTTPProxiedClientFactory( f = HTTPProxiedClientFactory(
self._host, self._port, protocolFactory, self._proxy_creds self._host, self._port, protocolFactory, self._proxy_creds
) )
@ -119,13 +123,15 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.dst_port = dst_port self.dst_port = dst_port
self.wrapped_factory = wrapped_factory self.wrapped_factory = wrapped_factory
self.proxy_creds = proxy_creds self.proxy_creds = proxy_creds
self.on_connection = defer.Deferred() self.on_connection: "defer.Deferred[None]" = defer.Deferred()
def startedConnecting(self, connector): def startedConnecting(self, connector):
return self.wrapped_factory.startedConnecting(connector) return self.wrapped_factory.startedConnecting(connector)
def buildProtocol(self, addr): def buildProtocol(self, addr):
wrapped_protocol = self.wrapped_factory.buildProtocol(addr) wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
if wrapped_protocol is None:
raise TypeError("buildProtocol produced None instead of a Protocol")
return HTTPConnectProtocol( return HTTPConnectProtocol(
self.dst_host, self.dst_host,
@ -235,7 +241,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
self.host = host self.host = host
self.port = port self.port = port
self.proxy_creds = proxy_creds self.proxy_creds = proxy_creds
self.on_connected = defer.Deferred() self.on_connected: "defer.Deferred[None]" = defer.Deferred()
def connectionMade(self): def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT") logger.debug("Connected to proxy, sending CONNECT")

View file

@ -15,6 +15,8 @@
import logging import logging
import threading import threading
import traceback
from typing import Dict, Mapping, Set, Tuple
from prometheus_client.core import Counter, Histogram from prometheus_client.core import Counter, Histogram
@ -105,19 +107,14 @@ in_flight_requests_db_sched_duration = Counter(
["method", "servlet"], ["method", "servlet"],
) )
# The set of all in flight requests, set[RequestMetrics] _in_flight_requests: Set["RequestMetrics"] = set()
_in_flight_requests = set()
# Protects the _in_flight_requests set from concurrent access # Protects the _in_flight_requests set from concurrent access
_in_flight_requests_lock = threading.Lock() _in_flight_requests_lock = threading.Lock()
def _get_in_flight_counts(): def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
"""Returns a count of all in flight requests by (method, server_name) """Returns a count of all in flight requests by (method, server_name)"""
Returns:
dict[tuple[str, str], int]
"""
# Cast to a list to prevent it changing while the Prometheus # Cast to a list to prevent it changing while the Prometheus
# thread is collecting metrics # thread is collecting metrics
with _in_flight_requests_lock: with _in_flight_requests_lock:
@ -127,8 +124,9 @@ def _get_in_flight_counts():
rm.update_metrics() rm.update_metrics()
# Map from (method, name) -> int, the number of in flight requests of that # Map from (method, name) -> int, the number of in flight requests of that
# type # type. The key type is Tuple[str, str], but we leave the length unspecified
counts = {} # for compatability with LaterGauge's annotations.
counts: Dict[Tuple[str, ...], int] = {}
for rm in reqs: for rm in reqs:
key = (rm.method, rm.name) key = (rm.method, rm.name)
counts[key] = counts.get(key, 0) + 1 counts[key] = counts.get(key, 0) + 1
@ -145,15 +143,21 @@ LaterGauge(
class RequestMetrics: class RequestMetrics:
def start(self, time_sec, name, method): def start(self, time_sec: float, name: str, method: str) -> None:
self.start = time_sec self.start_ts = time_sec
self.start_context = current_context() self.start_context = current_context()
self.name = name self.name = name
self.method = method self.method = method
if self.start_context:
# _request_stats records resource usage that we have already added # _request_stats records resource usage that we have already added
# to the "in flight" metrics. # to the "in flight" metrics.
self._request_stats = self.start_context.get_resource_usage() self._request_stats = self.start_context.get_resource_usage()
else:
logger.error(
"Tried to start a RequestMetric from the sentinel context.\n%s",
"".join(traceback.format_stack()),
)
with _in_flight_requests_lock: with _in_flight_requests_lock:
_in_flight_requests.add(self) _in_flight_requests.add(self)
@ -169,12 +173,18 @@ class RequestMetrics:
tag = context.tag tag = context.tag
if context != self.start_context: if context != self.start_context:
logger.warning( logger.error(
"Context have unexpectedly changed %r, %r", "Context have unexpectedly changed %r, %r",
context, context,
self.start_context, self.start_context,
) )
return return
else:
logger.error(
"Trying to stop RequestMetrics in the sentinel context.\n%s",
"".join(traceback.format_stack()),
)
return
response_code = str(response_code) response_code = str(response_code)
@ -183,7 +193,7 @@ class RequestMetrics:
response_count.labels(self.method, self.name, tag).inc() response_count.labels(self.method, self.name, tag).inc()
response_timer.labels(self.method, self.name, tag, response_code).observe( response_timer.labels(self.method, self.name, tag, response_code).observe(
time_sec - self.start time_sec - self.start_ts
) )
resource_usage = context.get_resource_usage() resource_usage = context.get_resource_usage()
@ -213,6 +223,12 @@ class RequestMetrics:
def update_metrics(self): def update_metrics(self):
"""Updates the in flight metrics with values from this request.""" """Updates the in flight metrics with values from this request."""
if not self.start_context:
logger.error(
"Tried to update a RequestMetric from the sentinel context.\n%s",
"".join(traceback.format_stack()),
)
return
new_stats = self.start_context.get_resource_usage() new_stats = self.start_context.get_resource_usage()
diff = new_stats - self._request_stats diff = new_stats - self._request_stats

View file

@ -220,7 +220,7 @@ class _Sentinel:
self.scope = None self.scope = None
self.tag = None self.tag = None
def __str__(self): def __str__(self) -> str:
return "sentinel" return "sentinel"
def copy_to(self, record): def copy_to(self, record):
@ -241,7 +241,7 @@ class _Sentinel:
def record_event_fetch(self, event_count): def record_event_fetch(self, event_count):
pass pass
def __bool__(self): def __bool__(self) -> Literal[False]:
return False return False

View file

@ -20,7 +20,7 @@ import os
import platform import platform
import threading import threading
import time import time
from typing import Callable, Dict, Iterable, Optional, Tuple, Union from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
import attr import attr
from prometheus_client import Counter, Gauge, Histogram from prometheus_client import Counter, Gauge, Histogram
@ -67,7 +67,11 @@ class LaterGauge:
labels = attr.ib(hash=False, type=Optional[Iterable[str]]) labels = attr.ib(hash=False, type=Optional[Iterable[str]])
# callback: should either return a value (if there are no labels for this metric), # callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value # or dict mapping from a label tuple to a value
caller = attr.ib(type=Callable[[], Union[Dict[Tuple[str, ...], float], float]]) caller = attr.ib(
type=Callable[
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
]
)
def collect(self): def collect(self):
@ -80,11 +84,11 @@ class LaterGauge:
yield g yield g
return return
if isinstance(calls, dict): if isinstance(calls, (int, float)):
g.add_metric([], calls)
else:
for k, v in calls.items(): for k, v in calls.items():
g.add_metric(k, v) g.add_metric(k, v)
else:
g.add_metric([], calls)
yield g yield g