From 3503f42741d438b67490bc774ee2c3a856f6bc81 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 27 May 2022 11:17:33 +0100 Subject: [PATCH] Easy type hints in synapse.logging.opentracing (#12894) --- changelog.d/12894.misc | 1 + synapse/config/tracer.py | 6 +- synapse/logging/opentracing.py | 114 ++++++++++-------- synapse/metrics/background_process_metrics.py | 9 +- 4 files changed, 73 insertions(+), 57 deletions(-) create mode 100644 changelog.d/12894.misc diff --git a/changelog.d/12894.misc b/changelog.d/12894.misc new file mode 100644 index 000000000..646a62fcc --- /dev/null +++ b/changelog.d/12894.misc @@ -0,0 +1 @@ +Add type annotations to `synapse.logging.opentracing`. diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 3472a9a01..ae68a3dd1 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Set +from typing import Any, List, Set from synapse.types import JsonDict from synapse.util.check_dependencies import DependencyException, check_requirements @@ -49,7 +49,9 @@ class TracerConfig(Config): # The tracer is enabled so sanitize the config - self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", []) + self.opentracer_whitelist: List[str] = opentracing_config.get( + "homeserver_whitelist", [] + ) if not isinstance(self.opentracer_whitelist, list): raise ConfigError("Tracer homeserver_whitelist config is malformed") diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index a02b5bf6b..903ec40c8 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -168,9 +168,24 @@ import inspect import logging import re from functools import wraps -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Collection, + Dict, + Generator, + Iterable, + List, + Optional, + Pattern, + Type, + TypeVar, + Union, +) import attr +from typing_extensions import ParamSpec from twisted.internet import defer from twisted.web.http import Request @@ -256,7 +271,7 @@ try: def set_process(self, *args, **kwargs): return self._reporter.set_process(*args, **kwargs) - def report_span(self, span): + def report_span(self, span: "opentracing.Span") -> None: try: return self._reporter.report_span(span) except Exception: @@ -307,15 +322,19 @@ _homeserver_whitelist: Optional[Pattern[str]] = None Sentinel = object() -def only_if_tracing(func): +P = ParamSpec("P") +R = TypeVar("R") + + +def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]: """Executes the function only if we're tracing. Otherwise returns None.""" @wraps(func) - def _only_if_tracing_inner(*args, **kwargs): + def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: if opentracing: return func(*args, **kwargs) else: - return + return None return _only_if_tracing_inner @@ -356,17 +375,10 @@ def ensure_active_span(message, ret=None): return ensure_active_span_inner_1 -@contextlib.contextmanager -def noop_context_manager(*args, **kwargs): - """Does exactly what it says on the tin""" - # TODO: replace with contextlib.nullcontext once we drop support for Python 3.6 - yield - - # Setup -def init_tracer(hs: "HomeServer"): +def init_tracer(hs: "HomeServer") -> None: """Set the whitelists and initialise the JaegerClient tracer""" global opentracing if not hs.config.tracing.opentracer_enabled: @@ -408,11 +420,11 @@ def init_tracer(hs: "HomeServer"): @only_if_tracing -def set_homeserver_whitelist(homeserver_whitelist): +def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None: """Sets the homeserver whitelist Args: - homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers + homeserver_whitelist: regexes specifying whitelisted homeservers """ global _homeserver_whitelist if homeserver_whitelist: @@ -423,15 +435,15 @@ def set_homeserver_whitelist(homeserver_whitelist): @only_if_tracing -def whitelisted_homeserver(destination): +def whitelisted_homeserver(destination: str) -> bool: """Checks if a destination matches the whitelist Args: - destination (str) + destination """ if _homeserver_whitelist: - return _homeserver_whitelist.match(destination) + return _homeserver_whitelist.match(destination) is not None return False @@ -457,11 +469,11 @@ def start_active_span( Args: See opentracing.tracer Returns: - scope (Scope) or noop_context_manager + scope (Scope) or contextlib.nullcontext """ if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] + return contextlib.nullcontext() # type: ignore[unreachable] if tracer is None: # use the global tracer by default @@ -505,7 +517,7 @@ def start_active_span_follows_from( tracer: override the opentracing tracer. By default the global tracer is used. """ if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] + return contextlib.nullcontext() # type: ignore[unreachable] references = [opentracing.follows_from(context) for context in contexts] scope = start_active_span( @@ -525,19 +537,19 @@ def start_active_span_follows_from( def start_active_span_from_edu( - edu_content, - operation_name, - references: Optional[list] = None, - tags=None, - start_time=None, - ignore_active_span=False, - finish_on_close=True, -): + edu_content: Dict[str, Any], + operation_name: str, + references: Optional[List["opentracing.Reference"]] = None, + tags: Optional[Dict] = None, + start_time: Optional[float] = None, + ignore_active_span: bool = False, + finish_on_close: bool = True, +) -> "opentracing.Scope": """ Extracts a span context from an edu and uses it to start a new active span Args: - edu_content (dict): and edu_content with a `context` field whose value is + edu_content: an edu_content with a `context` field whose value is canonical json for a dict which contains opentracing information. For the other args see opentracing.tracer @@ -545,7 +557,7 @@ def start_active_span_from_edu( references = references or [] if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] + return contextlib.nullcontext() # type: ignore[unreachable] carrier = json_decoder.decode(edu_content.get("context", "{}")).get( "opentracing", {} @@ -578,27 +590,27 @@ def start_active_span_from_edu( # Opentracing setters for tags, logs, etc @only_if_tracing -def active_span(): +def active_span() -> Optional["opentracing.Span"]: """Get the currently active span, if any""" return opentracing.tracer.active_span @ensure_active_span("set a tag") -def set_tag(key, value): +def set_tag(key: str, value: Union[str, bool, int, float]) -> None: """Sets a tag on the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_tag(key, value) @ensure_active_span("log") -def log_kv(key_values, timestamp=None): +def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None: """Log to the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.log_kv(key_values, timestamp) @ensure_active_span("set the traces operation name") -def set_operation_name(operation_name): +def set_operation_name(operation_name: str) -> None: """Sets the operation name of the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_operation_name(operation_name) @@ -624,7 +636,9 @@ def force_tracing(span=Sentinel) -> None: span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1") -def is_context_forced_tracing(span_context) -> bool: +def is_context_forced_tracing( + span_context: Optional["opentracing.SpanContext"], +) -> bool: """Check if sampling has been force for the given span context.""" if span_context is None: return False @@ -696,13 +710,13 @@ def inject_response_headers(response_headers: Headers) -> None: @ensure_active_span("get the active span context as a dict", ret={}) -def get_active_span_text_map(destination=None): +def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]: """ Gets a span context as a dict. This can be used instead of manually injecting a span into an empty carrier. Args: - destination (str): the name of the remote server. + destination: the name of the remote server. Returns: dict: the active span's context if opentracing is enabled, otherwise empty. @@ -721,7 +735,7 @@ def get_active_span_text_map(destination=None): @ensure_active_span("get the span context as a string.", ret={}) -def active_span_context_as_string(): +def active_span_context_as_string() -> str: """ Returns: The active span context encoded as a string. @@ -750,21 +764,21 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon @only_if_tracing -def span_context_from_string(carrier): +def span_context_from_string(carrier: str) -> Optional["opentracing.SpanContext"]: """ Returns: The active span context decoded from a string. """ - carrier = json_decoder.decode(carrier) - return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) + payload: Dict[str, str] = json_decoder.decode(carrier) + return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, payload) @only_if_tracing -def extract_text_map(carrier): +def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanContext"]: """ Wrapper method for opentracing's tracer.extract for TEXT_MAP. Args: - carrier (dict): a dict possibly containing a span context. + carrier: a dict possibly containing a span context. Returns: The active span context extracted from carrier. @@ -843,7 +857,7 @@ def trace(func=None, opname=None): return decorator -def tag_args(func): +def tag_args(func: Callable[P, R]) -> Callable[P, R]: """ Tags all of the args to the active span. """ @@ -852,11 +866,11 @@ def tag_args(func): return func @wraps(func) - def _tag_args_inner(*args, **kwargs): + def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R: argspec = inspect.getfullargspec(func) for i, arg in enumerate(argspec.args[1:]): - set_tag("ARG_" + arg, args[i]) - set_tag("args", args[len(argspec.args) :]) + set_tag("ARG_" + arg, args[i]) # type: ignore[index] + set_tag("args", args[len(argspec.args) :]) # type: ignore[index] set_tag("kwargs", kwargs) return func(*args, **kwargs) @@ -864,7 +878,9 @@ def tag_args(func): @contextlib.contextmanager -def trace_servlet(request: "SynapseRequest", extract_context: bool = False): +def trace_servlet( + request: "SynapseRequest", extract_context: bool = False +) -> Generator[None, None, None]: """Returns a context manager which traces a request. It starts a span with some servlet specific tags such as the request metrics name and request information. diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 298809742..eef3462e1 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -14,6 +14,7 @@ import logging import threading +from contextlib import nullcontext from functools import wraps from types import TracebackType from typing import ( @@ -41,11 +42,7 @@ from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, ) -from synapse.logging.opentracing import ( - SynapseTags, - noop_context_manager, - start_active_span, -) +from synapse.logging.opentracing import SynapseTags, start_active_span from synapse.metrics._types import Collector if TYPE_CHECKING: @@ -238,7 +235,7 @@ def run_as_background_process( f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)} ) else: - ctx = noop_context_manager() + ctx = nullcontext() with ctx: return await func(*args, **kwargs) except Exception: