Port receipt and read markers to async/wait

This commit is contained in:
Erik Johnston 2019-10-29 15:08:22 +00:00
parent 09a135b039
commit 2c35ffead2
6 changed files with 32 additions and 53 deletions

View file

@ -36,6 +36,8 @@ from six import iteritems
from sortedcontainers import SortedDict from sortedcontainers import SortedDict
from twisted.internet import defer
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -212,7 +214,7 @@ class FederationRemoteSendQueue(object):
receipt (synapse.types.ReadReceipt): receipt (synapse.types.ReadReceipt):
""" """
# nothing to do here: the replication listener will handle it. # nothing to do here: the replication listener will handle it.
pass return defer.succeed(None)
def send_presence(self, states): def send_presence(self, states):
"""As per FederationSender """As per FederationSender

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler from ._base import BaseHandler
@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler):
self.read_marker_linearizer = Linearizer(name="read_marker") self.read_marker_linearizer = Linearizer(name="read_marker")
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks async def received_client_read_marker(self, room_id, user_id, event_id):
def received_client_read_marker(self, room_id, user_id, event_id):
"""Updates the read marker for a given user in a given room if the event ID given """Updates the read marker for a given user in a given room if the event ID given
is ahead in the stream relative to the current read marker. is ahead in the stream relative to the current read marker.
@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler):
the read marker has changed. the read marker has changed.
""" """
with (yield self.read_marker_linearizer.queue((room_id, user_id))): with await self.read_marker_linearizer.queue((room_id, user_id)):
existing_read_marker = yield self.store.get_account_data_for_room_and_type( existing_read_marker = await self.store.get_account_data_for_room_and_type(
user_id, room_id, "m.fully_read" user_id, room_id, "m.fully_read"
) )
@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler):
if existing_read_marker: if existing_read_marker:
# Only update if the new marker is ahead in the stream # Only update if the new marker is ahead in the stream
should_update = yield self.store.is_event_after( should_update = await self.store.is_event_after(
event_id, existing_read_marker["event_id"] event_id, existing_read_marker["event_id"]
) )
if should_update: if should_update:
content = {"event_id": event_id} content = {"event_id": event_id}
max_id = yield self.store.add_account_data_to_room( max_id = await self.store.add_account_data_to_room(
user_id, room_id, "m.fully_read", content user_id, room_id, "m.fully_read", content
) )
self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) self.notifier.on_new_event("account_data_key", max_id, users=[user_id])

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id from synapse.types import ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler):
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@defer.inlineCallbacks async def _received_remote_receipt(self, origin, content):
def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS. """Called when we receive an EDU of type m.receipt from a remote HS.
""" """
receipts = [] receipts = []
@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler):
) )
) )
yield self._handle_new_receipts(receipts) await self._handle_new_receipts(receipts)
@defer.inlineCallbacks async def _handle_new_receipts(self, receipts):
def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier. """Takes a list of receipts, stores them and informs the notifier.
""" """
min_batch_id = None min_batch_id = None
max_batch_id = None max_batch_id = None
for receipt in receipts: for receipt in receipts:
res = yield self.store.insert_receipt( res = await self.store.insert_receipt(
receipt.room_id, receipt.room_id,
receipt.receipt_type, receipt.receipt_type,
receipt.user_id, receipt.user_id,
@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate. # Note that the min here shouldn't be relied upon to be accurate.
yield self.hs.get_pusherpool().on_new_receipts( await maybe_awaitable(
self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids min_batch_id, max_batch_id, affected_room_ids
) )
)
return True return True
@defer.inlineCallbacks async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
"""Called when a client tells us a local user has read up to the given """Called when a client tells us a local user has read up to the given
event_id in the room. event_id in the room.
""" """
@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler):
data={"ts": int(self.clock.time_msec())}, data={"ts": int(self.clock.time_msec())},
) )
is_new = yield self._handle_new_receipts([receipt]) is_new = await self._handle_new_receipts([receipt])
if not is_new: if not is_new:
return return
yield self.federation.send_read_receipt(receipt) await self.federation.send_read_receipt(receipt)
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
"""Gets all receipts for a room, upto the given key.
"""
result = yield self.store.get_linearized_receipts_for_room(
room_id, to_key=to_key
)
if not result:
return []
return result
class ReceiptEventSource(object): class ReceiptEventSource(object):

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_patterns from ._base import client_patterns
@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler() self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks async def on_POST(self, request, room_id):
def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
yield self.presence_handler.bump_presence_active_time(requester.user) await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
read_event_id = body.get("m.read", None) read_event_id = body.get("m.read", None)
if read_event_id: if read_event_id:
yield self.receipts_handler.received_client_receipt( await self.receipts_handler.received_client_receipt(
room_id, room_id,
"m.read", "m.read",
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet):
read_marker_event_id = body.get("m.fully_read", None) read_marker_event_id = body.get("m.fully_read", None)
if read_marker_event_id: if read_marker_event_id:
yield self.read_marker_handler.received_client_read_marker( await self.read_marker_handler.received_client_read_marker(
room_id, room_id,
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
event_id=read_marker_event_id, event_id=read_marker_event_id,

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler() self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks async def on_POST(self, request, room_id, receipt_type, event_id):
def on_POST(self, request, room_id, receipt_type, event_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read": if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'") raise SynapseError(400, "Receipt type must be 'm.read'")
yield self.presence_handler.bump_presence_active_time(requester.user) await self.presence_handler.bump_presence_active_time(requester.user)
yield self.receipts_handler.received_client_receipt( await self.receipts_handler.received_client_receipt(
room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
) )

View file

@ -2439,12 +2439,11 @@ class EventsStore(
logger.info("[purge] done") logger.info("[purge] done")
@defer.inlineCallbacks async def is_event_after(self, event_id1, event_id2):
def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream """Returns True if event_id1 is after event_id2 in the stream
""" """
to_1, so_1 = yield self._get_event_ordering(event_id1) to_1, so_1 = await self._get_event_ordering(event_id1)
to_2, so_2 = yield self._get_event_ordering(event_id2) to_2, so_2 = await self._get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2) return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000) @cachedInlineCallbacks(max_entries=5000)