forked from MirrorHub/synapse
Port receipt and read markers to async/wait
This commit is contained in:
parent
09a135b039
commit
2c35ffead2
6 changed files with 32 additions and 53 deletions
|
@ -36,6 +36,8 @@ from six import iteritems
|
||||||
|
|
||||||
from sortedcontainers import SortedDict
|
from sortedcontainers import SortedDict
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.storage.presence import UserPresenceState
|
from synapse.storage.presence import UserPresenceState
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
@ -212,7 +214,7 @@ class FederationRemoteSendQueue(object):
|
||||||
receipt (synapse.types.ReadReceipt):
|
receipt (synapse.types.ReadReceipt):
|
||||||
"""
|
"""
|
||||||
# nothing to do here: the replication listener will handle it.
|
# nothing to do here: the replication listener will handle it.
|
||||||
pass
|
return defer.succeed(None)
|
||||||
|
|
||||||
def send_presence(self, states):
|
def send_presence(self, states):
|
||||||
"""As per FederationSender
|
"""As per FederationSender
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler):
|
||||||
self.read_marker_linearizer = Linearizer(name="read_marker")
|
self.read_marker_linearizer = Linearizer(name="read_marker")
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def received_client_read_marker(self, room_id, user_id, event_id):
|
||||||
def received_client_read_marker(self, room_id, user_id, event_id):
|
|
||||||
"""Updates the read marker for a given user in a given room if the event ID given
|
"""Updates the read marker for a given user in a given room if the event ID given
|
||||||
is ahead in the stream relative to the current read marker.
|
is ahead in the stream relative to the current read marker.
|
||||||
|
|
||||||
|
@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler):
|
||||||
the read marker has changed.
|
the read marker has changed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with (yield self.read_marker_linearizer.queue((room_id, user_id))):
|
with await self.read_marker_linearizer.queue((room_id, user_id)):
|
||||||
existing_read_marker = yield self.store.get_account_data_for_room_and_type(
|
existing_read_marker = await self.store.get_account_data_for_room_and_type(
|
||||||
user_id, room_id, "m.fully_read"
|
user_id, room_id, "m.fully_read"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler):
|
||||||
|
|
||||||
if existing_read_marker:
|
if existing_read_marker:
|
||||||
# Only update if the new marker is ahead in the stream
|
# Only update if the new marker is ahead in the stream
|
||||||
should_update = yield self.store.is_event_after(
|
should_update = await self.store.is_event_after(
|
||||||
event_id, existing_read_marker["event_id"]
|
event_id, existing_read_marker["event_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
if should_update:
|
if should_update:
|
||||||
content = {"event_id": event_id}
|
content = {"event_id": event_id}
|
||||||
max_id = yield self.store.add_account_data_to_room(
|
max_id = await self.store.add_account_data_to_room(
|
||||||
user_id, room_id, "m.fully_read", content
|
user_id, room_id, "m.fully_read", content
|
||||||
)
|
)
|
||||||
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
||||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.handlers._base import BaseHandler
|
from synapse.handlers._base import BaseHandler
|
||||||
from synapse.types import ReadReceipt, get_domain_from_id
|
from synapse.types import ReadReceipt, get_domain_from_id
|
||||||
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler):
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _received_remote_receipt(self, origin, content):
|
||||||
def _received_remote_receipt(self, origin, content):
|
|
||||||
"""Called when we receive an EDU of type m.receipt from a remote HS.
|
"""Called when we receive an EDU of type m.receipt from a remote HS.
|
||||||
"""
|
"""
|
||||||
receipts = []
|
receipts = []
|
||||||
|
@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._handle_new_receipts(receipts)
|
await self._handle_new_receipts(receipts)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _handle_new_receipts(self, receipts):
|
||||||
def _handle_new_receipts(self, receipts):
|
|
||||||
"""Takes a list of receipts, stores them and informs the notifier.
|
"""Takes a list of receipts, stores them and informs the notifier.
|
||||||
"""
|
"""
|
||||||
min_batch_id = None
|
min_batch_id = None
|
||||||
max_batch_id = None
|
max_batch_id = None
|
||||||
|
|
||||||
for receipt in receipts:
|
for receipt in receipts:
|
||||||
res = yield self.store.insert_receipt(
|
res = await self.store.insert_receipt(
|
||||||
receipt.room_id,
|
receipt.room_id,
|
||||||
receipt.receipt_type,
|
receipt.receipt_type,
|
||||||
receipt.user_id,
|
receipt.user_id,
|
||||||
|
@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler):
|
||||||
|
|
||||||
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
|
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
|
||||||
# Note that the min here shouldn't be relied upon to be accurate.
|
# Note that the min here shouldn't be relied upon to be accurate.
|
||||||
yield self.hs.get_pusherpool().on_new_receipts(
|
await maybe_awaitable(
|
||||||
min_batch_id, max_batch_id, affected_room_ids
|
self.hs.get_pusherpool().on_new_receipts(
|
||||||
|
min_batch_id, max_batch_id, affected_room_ids
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
|
||||||
def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
|
|
||||||
"""Called when a client tells us a local user has read up to the given
|
"""Called when a client tells us a local user has read up to the given
|
||||||
event_id in the room.
|
event_id in the room.
|
||||||
"""
|
"""
|
||||||
|
@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler):
|
||||||
data={"ts": int(self.clock.time_msec())},
|
data={"ts": int(self.clock.time_msec())},
|
||||||
)
|
)
|
||||||
|
|
||||||
is_new = yield self._handle_new_receipts([receipt])
|
is_new = await self._handle_new_receipts([receipt])
|
||||||
if not is_new:
|
if not is_new:
|
||||||
return
|
return
|
||||||
|
|
||||||
yield self.federation.send_read_receipt(receipt)
|
await self.federation.send_read_receipt(receipt)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_receipts_for_room(self, room_id, to_key):
|
|
||||||
"""Gets all receipts for a room, upto the given key.
|
|
||||||
"""
|
|
||||||
result = yield self.store.get_linearized_receipts_for_room(
|
|
||||||
room_id, to_key=to_key
|
|
||||||
)
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class ReceiptEventSource(object):
|
class ReceiptEventSource(object):
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
|
|
||||||
from ._base import client_patterns
|
from ._base import client_patterns
|
||||||
|
@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet):
|
||||||
self.read_marker_handler = hs.get_read_marker_handler()
|
self.read_marker_handler = hs.get_read_marker_handler()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_POST(self, request, room_id):
|
||||||
def on_POST(self, request, room_id):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
|
|
||||||
yield self.presence_handler.bump_presence_active_time(requester.user)
|
await self.presence_handler.bump_presence_active_time(requester.user)
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
read_event_id = body.get("m.read", None)
|
read_event_id = body.get("m.read", None)
|
||||||
if read_event_id:
|
if read_event_id:
|
||||||
yield self.receipts_handler.received_client_receipt(
|
await self.receipts_handler.received_client_receipt(
|
||||||
room_id,
|
room_id,
|
||||||
"m.read",
|
"m.read",
|
||||||
user_id=requester.user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
|
@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet):
|
||||||
|
|
||||||
read_marker_event_id = body.get("m.fully_read", None)
|
read_marker_event_id = body.get("m.fully_read", None)
|
||||||
if read_marker_event_id:
|
if read_marker_event_id:
|
||||||
yield self.read_marker_handler.received_client_read_marker(
|
await self.read_marker_handler.received_client_read_marker(
|
||||||
room_id,
|
room_id,
|
||||||
user_id=requester.user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
event_id=read_marker_event_id,
|
event_id=read_marker_event_id,
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet
|
||||||
|
|
||||||
|
@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet):
|
||||||
self.receipts_handler = hs.get_receipts_handler()
|
self.receipts_handler = hs.get_receipts_handler()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_POST(self, request, room_id, receipt_type, event_id):
|
||||||
def on_POST(self, request, room_id, receipt_type, event_id):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
|
|
||||||
if receipt_type != "m.read":
|
if receipt_type != "m.read":
|
||||||
raise SynapseError(400, "Receipt type must be 'm.read'")
|
raise SynapseError(400, "Receipt type must be 'm.read'")
|
||||||
|
|
||||||
yield self.presence_handler.bump_presence_active_time(requester.user)
|
await self.presence_handler.bump_presence_active_time(requester.user)
|
||||||
|
|
||||||
yield self.receipts_handler.received_client_receipt(
|
await self.receipts_handler.received_client_receipt(
|
||||||
room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
|
room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -2439,12 +2439,11 @@ class EventsStore(
|
||||||
|
|
||||||
logger.info("[purge] done")
|
logger.info("[purge] done")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def is_event_after(self, event_id1, event_id2):
|
||||||
def is_event_after(self, event_id1, event_id2):
|
|
||||||
"""Returns True if event_id1 is after event_id2 in the stream
|
"""Returns True if event_id1 is after event_id2 in the stream
|
||||||
"""
|
"""
|
||||||
to_1, so_1 = yield self._get_event_ordering(event_id1)
|
to_1, so_1 = await self._get_event_ordering(event_id1)
|
||||||
to_2, so_2 = yield self._get_event_ordering(event_id2)
|
to_2, so_2 = await self._get_event_ordering(event_id2)
|
||||||
return (to_1, so_1) > (to_2, so_2)
|
return (to_1, so_1) > (to_2, so_2)
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=5000)
|
@cachedInlineCallbacks(max_entries=5000)
|
||||||
|
|
Loading…
Reference in a new issue