0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-22 07:00:05 +01:00

Easy type hints in synapse.logging.opentracing (#12894)

This commit is contained in:
David Robertson 2022-05-27 11:17:33 +01:00 committed by GitHub
parent f1605b7447
commit 3503f42741
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 73 additions and 57 deletions

1
changelog.d/12894.misc Normal file
View file

@ -0,0 +1 @@
Add type annotations to `synapse.logging.opentracing`.

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Set from typing import Any, List, Set
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.check_dependencies import DependencyException, check_requirements from synapse.util.check_dependencies import DependencyException, check_requirements
@ -49,7 +49,9 @@ class TracerConfig(Config):
# The tracer is enabled so sanitize the config # The tracer is enabled so sanitize the config
self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", []) self.opentracer_whitelist: List[str] = opentracing_config.get(
"homeserver_whitelist", []
)
if not isinstance(self.opentracer_whitelist, list): if not isinstance(self.opentracer_whitelist, list):
raise ConfigError("Tracer homeserver_whitelist config is malformed") raise ConfigError("Tracer homeserver_whitelist config is malformed")

View file

@ -168,9 +168,24 @@ import inspect
import logging import logging
import re import re
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Optional,
Pattern,
Type,
TypeVar,
Union,
)
import attr import attr
from typing_extensions import ParamSpec
from twisted.internet import defer from twisted.internet import defer
from twisted.web.http import Request from twisted.web.http import Request
@ -256,7 +271,7 @@ try:
def set_process(self, *args, **kwargs): def set_process(self, *args, **kwargs):
return self._reporter.set_process(*args, **kwargs) return self._reporter.set_process(*args, **kwargs)
def report_span(self, span): def report_span(self, span: "opentracing.Span") -> None:
try: try:
return self._reporter.report_span(span) return self._reporter.report_span(span)
except Exception: except Exception:
@ -307,15 +322,19 @@ _homeserver_whitelist: Optional[Pattern[str]] = None
Sentinel = object() Sentinel = object()
def only_if_tracing(func): P = ParamSpec("P")
R = TypeVar("R")
def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
"""Executes the function only if we're tracing. Otherwise returns None.""" """Executes the function only if we're tracing. Otherwise returns None."""
@wraps(func) @wraps(func)
def _only_if_tracing_inner(*args, **kwargs): def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
if opentracing: if opentracing:
return func(*args, **kwargs) return func(*args, **kwargs)
else: else:
return return None
return _only_if_tracing_inner return _only_if_tracing_inner
@ -356,17 +375,10 @@ def ensure_active_span(message, ret=None):
return ensure_active_span_inner_1 return ensure_active_span_inner_1
@contextlib.contextmanager
def noop_context_manager(*args, **kwargs):
"""Does exactly what it says on the tin"""
# TODO: replace with contextlib.nullcontext once we drop support for Python 3.6
yield
# Setup # Setup
def init_tracer(hs: "HomeServer"): def init_tracer(hs: "HomeServer") -> None:
"""Set the whitelists and initialise the JaegerClient tracer""" """Set the whitelists and initialise the JaegerClient tracer"""
global opentracing global opentracing
if not hs.config.tracing.opentracer_enabled: if not hs.config.tracing.opentracer_enabled:
@ -408,11 +420,11 @@ def init_tracer(hs: "HomeServer"):
@only_if_tracing @only_if_tracing
def set_homeserver_whitelist(homeserver_whitelist): def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None:
"""Sets the homeserver whitelist """Sets the homeserver whitelist
Args: Args:
homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers homeserver_whitelist: regexes specifying whitelisted homeservers
""" """
global _homeserver_whitelist global _homeserver_whitelist
if homeserver_whitelist: if homeserver_whitelist:
@ -423,15 +435,15 @@ def set_homeserver_whitelist(homeserver_whitelist):
@only_if_tracing @only_if_tracing
def whitelisted_homeserver(destination): def whitelisted_homeserver(destination: str) -> bool:
"""Checks if a destination matches the whitelist """Checks if a destination matches the whitelist
Args: Args:
destination (str) destination
""" """
if _homeserver_whitelist: if _homeserver_whitelist:
return _homeserver_whitelist.match(destination) return _homeserver_whitelist.match(destination) is not None
return False return False
@ -457,11 +469,11 @@ def start_active_span(
Args: Args:
See opentracing.tracer See opentracing.tracer
Returns: Returns:
scope (Scope) or noop_context_manager scope (Scope) or contextlib.nullcontext
""" """
if opentracing is None: if opentracing is None:
return noop_context_manager() # type: ignore[unreachable] return contextlib.nullcontext() # type: ignore[unreachable]
if tracer is None: if tracer is None:
# use the global tracer by default # use the global tracer by default
@ -505,7 +517,7 @@ def start_active_span_follows_from(
tracer: override the opentracing tracer. By default the global tracer is used. tracer: override the opentracing tracer. By default the global tracer is used.
""" """
if opentracing is None: if opentracing is None:
return noop_context_manager() # type: ignore[unreachable] return contextlib.nullcontext() # type: ignore[unreachable]
references = [opentracing.follows_from(context) for context in contexts] references = [opentracing.follows_from(context) for context in contexts]
scope = start_active_span( scope = start_active_span(
@ -525,19 +537,19 @@ def start_active_span_follows_from(
def start_active_span_from_edu( def start_active_span_from_edu(
edu_content, edu_content: Dict[str, Any],
operation_name, operation_name: str,
references: Optional[list] = None, references: Optional[List["opentracing.Reference"]] = None,
tags=None, tags: Optional[Dict] = None,
start_time=None, start_time: Optional[float] = None,
ignore_active_span=False, ignore_active_span: bool = False,
finish_on_close=True, finish_on_close: bool = True,
): ) -> "opentracing.Scope":
""" """
Extracts a span context from an edu and uses it to start a new active span Extracts a span context from an edu and uses it to start a new active span
Args: Args:
edu_content (dict): and edu_content with a `context` field whose value is edu_content: an edu_content with a `context` field whose value is
canonical json for a dict which contains opentracing information. canonical json for a dict which contains opentracing information.
For the other args see opentracing.tracer For the other args see opentracing.tracer
@ -545,7 +557,7 @@ def start_active_span_from_edu(
references = references or [] references = references or []
if opentracing is None: if opentracing is None:
return noop_context_manager() # type: ignore[unreachable] return contextlib.nullcontext() # type: ignore[unreachable]
carrier = json_decoder.decode(edu_content.get("context", "{}")).get( carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
"opentracing", {} "opentracing", {}
@ -578,27 +590,27 @@ def start_active_span_from_edu(
# Opentracing setters for tags, logs, etc # Opentracing setters for tags, logs, etc
@only_if_tracing @only_if_tracing
def active_span(): def active_span() -> Optional["opentracing.Span"]:
"""Get the currently active span, if any""" """Get the currently active span, if any"""
return opentracing.tracer.active_span return opentracing.tracer.active_span
@ensure_active_span("set a tag") @ensure_active_span("set a tag")
def set_tag(key, value): def set_tag(key: str, value: Union[str, bool, int, float]) -> None:
"""Sets a tag on the active span""" """Sets a tag on the active span"""
assert opentracing.tracer.active_span is not None assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_tag(key, value) opentracing.tracer.active_span.set_tag(key, value)
@ensure_active_span("log") @ensure_active_span("log")
def log_kv(key_values, timestamp=None): def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None:
"""Log to the active span""" """Log to the active span"""
assert opentracing.tracer.active_span is not None assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.log_kv(key_values, timestamp) opentracing.tracer.active_span.log_kv(key_values, timestamp)
@ensure_active_span("set the traces operation name") @ensure_active_span("set the traces operation name")
def set_operation_name(operation_name): def set_operation_name(operation_name: str) -> None:
"""Sets the operation name of the active span""" """Sets the operation name of the active span"""
assert opentracing.tracer.active_span is not None assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_operation_name(operation_name) opentracing.tracer.active_span.set_operation_name(operation_name)
@ -624,7 +636,9 @@ def force_tracing(span=Sentinel) -> None:
span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1") span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
def is_context_forced_tracing(span_context) -> bool: def is_context_forced_tracing(
span_context: Optional["opentracing.SpanContext"],
) -> bool:
"""Check if sampling has been force for the given span context.""" """Check if sampling has been force for the given span context."""
if span_context is None: if span_context is None:
return False return False
@ -696,13 +710,13 @@ def inject_response_headers(response_headers: Headers) -> None:
@ensure_active_span("get the active span context as a dict", ret={}) @ensure_active_span("get the active span context as a dict", ret={})
def get_active_span_text_map(destination=None): def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
""" """
Gets a span context as a dict. This can be used instead of manually Gets a span context as a dict. This can be used instead of manually
injecting a span into an empty carrier. injecting a span into an empty carrier.
Args: Args:
destination (str): the name of the remote server. destination: the name of the remote server.
Returns: Returns:
dict: the active span's context if opentracing is enabled, otherwise empty. dict: the active span's context if opentracing is enabled, otherwise empty.
@ -721,7 +735,7 @@ def get_active_span_text_map(destination=None):
@ensure_active_span("get the span context as a string.", ret={}) @ensure_active_span("get the span context as a string.", ret={})
def active_span_context_as_string(): def active_span_context_as_string() -> str:
""" """
Returns: Returns:
The active span context encoded as a string. The active span context encoded as a string.
@ -750,21 +764,21 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon
@only_if_tracing @only_if_tracing
def span_context_from_string(carrier): def span_context_from_string(carrier: str) -> Optional["opentracing.SpanContext"]:
""" """
Returns: Returns:
The active span context decoded from a string. The active span context decoded from a string.
""" """
carrier = json_decoder.decode(carrier) payload: Dict[str, str] = json_decoder.decode(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, payload)
@only_if_tracing @only_if_tracing
def extract_text_map(carrier): def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanContext"]:
""" """
Wrapper method for opentracing's tracer.extract for TEXT_MAP. Wrapper method for opentracing's tracer.extract for TEXT_MAP.
Args: Args:
carrier (dict): a dict possibly containing a span context. carrier: a dict possibly containing a span context.
Returns: Returns:
The active span context extracted from carrier. The active span context extracted from carrier.
@ -843,7 +857,7 @@ def trace(func=None, opname=None):
return decorator return decorator
def tag_args(func): def tag_args(func: Callable[P, R]) -> Callable[P, R]:
""" """
Tags all of the args to the active span. Tags all of the args to the active span.
""" """
@ -852,11 +866,11 @@ def tag_args(func):
return func return func
@wraps(func) @wraps(func)
def _tag_args_inner(*args, **kwargs): def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
argspec = inspect.getfullargspec(func) argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]): for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i]) set_tag("ARG_" + arg, args[i]) # type: ignore[index]
set_tag("args", args[len(argspec.args) :]) set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
set_tag("kwargs", kwargs) set_tag("kwargs", kwargs)
return func(*args, **kwargs) return func(*args, **kwargs)
@ -864,7 +878,9 @@ def tag_args(func):
@contextlib.contextmanager @contextlib.contextmanager
def trace_servlet(request: "SynapseRequest", extract_context: bool = False): def trace_servlet(
request: "SynapseRequest", extract_context: bool = False
) -> Generator[None, None, None]:
"""Returns a context manager which traces a request. It starts a span """Returns a context manager which traces a request. It starts a span
with some servlet specific tags such as the request metrics name and with some servlet specific tags such as the request metrics name and
request information. request information.

View file

@ -14,6 +14,7 @@
import logging import logging
import threading import threading
from contextlib import nullcontext
from functools import wraps from functools import wraps
from types import TracebackType from types import TracebackType
from typing import ( from typing import (
@ -41,11 +42,7 @@ from synapse.logging.context import (
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
) )
from synapse.logging.opentracing import ( from synapse.logging.opentracing import SynapseTags, start_active_span
SynapseTags,
noop_context_manager,
start_active_span,
)
from synapse.metrics._types import Collector from synapse.metrics._types import Collector
if TYPE_CHECKING: if TYPE_CHECKING:
@ -238,7 +235,7 @@ def run_as_background_process(
f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)} f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
) )
else: else:
ctx = noop_context_manager() ctx = nullcontext()
with ctx: with ctx:
return await func(*args, **kwargs) return await func(*args, **kwargs)
except Exception: except Exception: