From faeb369f158a3ca6ba8f48ca1d551b2b53f4c53a Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Wed, 21 Feb 2018 15:19:54 +0000
Subject: [PATCH] Fix missing invalidations for receipt storage

---
 synapse/replication/slave/storage/receipts.py |  2 ++
 synapse/storage/receipts.py                   | 28 +++++++++----------
 2 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index f0e29e983..1647072f6 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -53,6 +53,8 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
+        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+        self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "receipts":
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 40530632c..eac8694e0 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -292,20 +292,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "get_all_updated_receipts", get_all_updated_receipts_txn
         )
 
-
-class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = StreamIdGenerator(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-
-        super(ReceiptsStore, self).__init__(db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
     def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
                                                     user_id):
         if receipt_type != "m.read":
@@ -326,6 +312,20 @@ class ReceiptsStore(ReceiptsWorkerStore):
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
+
+class ReceiptsStore(ReceiptsWorkerStore):
+    def __init__(self, db_conn, hs):
+        # We instantiate this first as the ReceiptsWorkerStore constructor
+        # needs to be able to call get_max_receipt_stream_id
+        self._receipts_id_gen = StreamIdGenerator(
+            db_conn, "receipts_linearized", "stream_id"
+        )
+
+        super(ReceiptsStore, self).__init__(db_conn, hs)
+
+    def get_max_receipt_stream_id(self):
+        return self._receipts_id_gen.get_current_token()
+
     def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                       user_id, event_id, data, stream_id):
         txn.call_after(