From d4f72a5bfb95d07d5af3f49c736823840659101a Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Wed, 3 Feb 2016 13:51:25 +0000
Subject: [PATCH] Allowing tagging log contexts

---
 synapse/handlers/sync.py   | 10 ++++++++++
 synapse/http/server.py     | 41 ++++++++++++++++++++++++--------------
 synapse/util/logcontext.py |  7 ++++++-
 3 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index dc686db54..72ccaf1e3 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -18,6 +18,7 @@ from ._base import BaseHandler
 from synapse.streams.config import PaginationConfig
 from synapse.api.constants import Membership, EventTypes
 from synapse.util import unwrapFirstError
+from synapse.util.logcontext import LoggingContext
 
 from twisted.internet import defer
 
@@ -140,6 +141,15 @@ class SyncHandler(BaseHandler):
             A Deferred SyncResult.
         """
 
+        context = LoggingContext.current_context()
+        if context:
+            if since_token is None:
+                context.tag = "initial_sync"
+            elif full_state:
+                context.tag = "full_state_sync"
+            else:
+                context.tag = "incremental_sync"
+
         if timeout == 0 or since_token is None or full_state:
             # we are going to return immediately, so don't bother calling
             # notifier.wait_for_events.
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 10d1fcd3f..c250a4604 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -41,7 +41,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
 
 incoming_requests_counter = metrics.register_counter(
     "requests",
-    labels=["method", "servlet"],
+    labels=["method", "servlet", "tag"],
 )
 outgoing_responses_counter = metrics.register_counter(
     "responses",
@@ -50,23 +50,23 @@ outgoing_responses_counter = metrics.register_counter(
 
 response_timer = metrics.register_distribution(
     "response_time",
-    labels=["method", "servlet"]
+    labels=["method", "servlet", "tag"]
 )
 
 response_ru_utime = metrics.register_distribution(
-    "response_ru_utime", labels=["method", "servlet"]
+    "response_ru_utime", labels=["method", "servlet", "tag"]
 )
 
 response_ru_stime = metrics.register_distribution(
-    "response_ru_stime", labels=["method", "servlet"]
+    "response_ru_stime", labels=["method", "servlet", "tag"]
 )
 
 response_db_txn_count = metrics.register_distribution(
-    "response_db_txn_count", labels=["method", "servlet"]
+    "response_db_txn_count", labels=["method", "servlet", "tag"]
 )
 
 response_db_txn_duration = metrics.register_distribution(
-    "response_db_txn_duration", labels=["method", "servlet"]
+    "response_db_txn_duration", labels=["method", "servlet", "tag"]
 )
 
 
@@ -226,7 +226,6 @@ class JsonResource(HttpServer, resource.Resource):
                 servlet_classname = servlet_instance.__class__.__name__
             else:
                 servlet_classname = "%r" % callback
-            incoming_requests_counter.inc(request.method, servlet_classname)
 
             args = [
                 urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
@@ -237,21 +236,33 @@ class JsonResource(HttpServer, resource.Resource):
                 code, response = callback_return
                 self._send_response(request, code, response)
 
-            response_timer.inc_by(
-                self.clock.time_msec() - start, request.method, servlet_classname
-            )
-
             try:
                 context = LoggingContext.current_context()
+
+                tag = ""
+                if context:
+                    tag = context.tag
+
+                incoming_requests_counter.inc(request.method, servlet_classname, tag)
+
+                response_timer.inc_by(
+                    self.clock.time_msec() - start, request.method,
+                    servlet_classname, tag
+                )
+
                 ru_utime, ru_stime = context.get_resource_usage()
 
-                response_ru_utime.inc_by(ru_utime, request.method, servlet_classname)
-                response_ru_stime.inc_by(ru_stime, request.method, servlet_classname)
+                response_ru_utime.inc_by(
+                    ru_utime, request.method, servlet_classname, tag
+                )
+                response_ru_stime.inc_by(
+                    ru_stime, request.method, servlet_classname, tag
+                )
                 response_db_txn_count.inc_by(
-                    context.db_txn_count, request.method, servlet_classname
+                    context.db_txn_count, request.method, servlet_classname, tag
                 )
                 response_db_txn_duration.inc_by(
-                    context.db_txn_duration, request.method, servlet_classname
+                    context.db_txn_duration, request.method, servlet_classname, tag
                 )
             except:
                 pass
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 0595c0fa4..e701092cd 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -47,7 +47,8 @@ class LoggingContext(object):
     """
 
     __slots__ = [
-        "parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__"
+        "parent_context", "name", "usage_start", "usage_end", "main_thread",
+        "__dict__", "tag",
     ]
 
     thread_local = threading.local()
@@ -72,6 +73,9 @@ class LoggingContext(object):
         def add_database_transaction(self, duration_ms):
             pass
 
+        def __nonzero__(self):
+            return False
+
     sentinel = Sentinel()
 
     def __init__(self, name=None):
@@ -83,6 +87,7 @@ class LoggingContext(object):
         self.db_txn_duration = 0.
         self.usage_start = None
         self.main_thread = threading.current_thread()
+        self.tag = ""
 
     def __str__(self):
         return "%s@%x" % (self.name, id(self))