Make token serializing/deserializing async (#8427)

The idea is that, in the future, tokens will encode a mapping of instance to position. However, we don't want to include the full instance name in the string representation, so we'll have a mapping between instance name and an immutable integer ID in the DB that we can use instead. We'll then do the lookup when we serialize/deserialize the token (we could alternatively pass around an `Instance` type that includes both the name and ID, but that turns out to be a lot more invasive).
This commit is contained in:
Erik Johnston 2020-09-30 20:29:19 +01:00 committed by GitHub
parent a0a1ba6973
commit 7941372ec8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 115 additions and 59 deletions

1
changelog.d/8427.misc Normal file
View file

@ -0,0 +1 @@
Make stream token serializing/deserializing async.

View file

@ -133,8 +133,8 @@ class EventStreamHandler(BaseHandler):
chunk = { chunk = {
"chunk": chunks, "chunk": chunks,
"start": tokens[0].to_string(), "start": await tokens[0].to_string(self.store),
"end": tokens[1].to_string(), "end": await tokens[1].to_string(self.store),
} }
return chunk return chunk

View file

@ -203,8 +203,8 @@ class InitialSyncHandler(BaseHandler):
messages, time_now=time_now, as_client_event=as_client_event messages, time_now=time_now, as_client_event=as_client_event
) )
), ),
"start": start_token.to_string(), "start": await start_token.to_string(self.store),
"end": end_token.to_string(), "end": await end_token.to_string(self.store),
} }
d["state"] = await self._event_serializer.serialize_events( d["state"] = await self._event_serializer.serialize_events(
@ -249,7 +249,7 @@ class InitialSyncHandler(BaseHandler):
], ],
"account_data": account_data_events, "account_data": account_data_events,
"receipts": receipt, "receipts": receipt,
"end": now_token.to_string(), "end": await now_token.to_string(self.store),
} }
return ret return ret
@ -348,8 +348,8 @@ class InitialSyncHandler(BaseHandler):
"chunk": ( "chunk": (
await self._event_serializer.serialize_events(messages, time_now) await self._event_serializer.serialize_events(messages, time_now)
), ),
"start": start_token.to_string(), "start": await start_token.to_string(self.store),
"end": end_token.to_string(), "end": await end_token.to_string(self.store),
}, },
"state": ( "state": (
await self._event_serializer.serialize_events( await self._event_serializer.serialize_events(
@ -447,8 +447,8 @@ class InitialSyncHandler(BaseHandler):
"chunk": ( "chunk": (
await self._event_serializer.serialize_events(messages, time_now) await self._event_serializer.serialize_events(messages, time_now)
), ),
"start": start_token.to_string(), "start": await start_token.to_string(self.store),
"end": end_token.to_string(), "end": await end_token.to_string(self.store),
}, },
"state": state, "state": state,
"presence": presence, "presence": presence,

View file

@ -413,8 +413,8 @@ class PaginationHandler:
if not events: if not events:
return { return {
"chunk": [], "chunk": [],
"start": from_token.to_string(), "start": await from_token.to_string(self.store),
"end": next_token.to_string(), "end": await next_token.to_string(self.store),
} }
state = None state = None
@ -442,8 +442,8 @@ class PaginationHandler:
events, time_now, as_client_event=as_client_event events, time_now, as_client_event=as_client_event
) )
), ),
"start": from_token.to_string(), "start": await from_token.to_string(self.store),
"end": next_token.to_string(), "end": await next_token.to_string(self.store),
} }
if state: if state:

View file

@ -1077,11 +1077,13 @@ class RoomContextHandler:
# the token, which we replace. # the token, which we replace.
token = StreamToken.START token = StreamToken.START
results["start"] = token.copy_and_replace( results["start"] = await token.copy_and_replace(
"room_key", results["start"] "room_key", results["start"]
).to_string() ).to_string(self.store)
results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() results["end"] = await token.copy_and_replace(
"room_key", results["end"]
).to_string(self.store)
return results return results

View file

@ -362,13 +362,13 @@ class SearchHandler(BaseHandler):
self.storage, user.to_string(), res["events_after"] self.storage, user.to_string(), res["events_after"]
) )
res["start"] = now_token.copy_and_replace( res["start"] = await now_token.copy_and_replace(
"room_key", res["start"] "room_key", res["start"]
).to_string() ).to_string(self.store)
res["end"] = now_token.copy_and_replace( res["end"] = await now_token.copy_and_replace(
"room_key", res["end"] "room_key", res["end"]
).to_string() ).to_string(self.store)
if include_profile: if include_profile:
senders = { senders = {

View file

@ -110,7 +110,7 @@ class PurgeHistoryRestServlet(RestServlet):
raise SynapseError(400, "Event is for wrong room.") raise SynapseError(400, "Event is for wrong room.")
room_token = await self.store.get_topological_token_for_event(event_id) room_token = await self.store.get_topological_token_for_event(event_id)
token = str(room_token) token = await room_token.to_string(self.store)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body: elif "purge_up_to_ts" in body:

View file

@ -33,6 +33,7 @@ class EventStreamRestServlet(RestServlet):
super().__init__() super().__init__()
self.event_stream_handler = hs.get_event_stream_handler() self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request): async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet):
if b"room_id" in request.args: if b"room_id" in request.args:
room_id = request.args[b"room_id"][0].decode("ascii") room_id = request.args[b"room_id"][0].decode("ascii")
pagin_config = PaginationConfig.from_request(request) pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in request.args: if b"timeout" in request.args:
try: try:

View file

@ -27,11 +27,12 @@ class InitialSyncRestServlet(RestServlet):
super().__init__() super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request): async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False) include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms( content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(), user_id=requester.user.to_string(),

View file

@ -451,6 +451,7 @@ class RoomMemberListRestServlet(RestServlet):
super().__init__() super().__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id): async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet):
if at_token_string is None: if at_token_string is None:
at_token = None at_token = None
else: else:
at_token = StreamToken.from_string(at_token_string) at_token = await StreamToken.from_string(self.store, at_token_string)
# let you filter down on particular memberships. # let you filter down on particular memberships.
# XXX: this may not be the best shape for this API - we could pass in a filter # XXX: this may not be the best shape for this API - we could pass in a filter
@ -521,10 +522,13 @@ class RoomMessageListRestServlet(RestServlet):
super().__init__() super().__init__()
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id): async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10) pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10
)
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
filter_str = parse_string(request, b"filter", encoding="utf-8") filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str: if filter_str:
@ -580,10 +584,11 @@ class RoomInitialSyncRestServlet(RestServlet):
super().__init__() super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id): async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request) pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync( content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config room_id=room_id, requester=requester, pagin_config=pagination_config
) )

View file

@ -180,6 +180,7 @@ class KeyChangesServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
async def on_GET(self, request): async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet):
# changes after the "to" as well as before. # changes after the "to" as well as before.
set_tag("to", parse_string(request, "to")) set_tag("to", parse_string(request, "to"))
from_token = StreamToken.from_string(from_token_string) from_token = await StreamToken.from_string(self.store, from_token_string)
user_id = requester.user.to_string() user_id = requester.user.to_string()

View file

@ -77,6 +77,7 @@ class SyncRestServlet(RestServlet):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.sync_handler = hs.get_sync_handler() self.sync_handler = hs.get_sync_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet):
device_id=device_id, device_id=device_id,
) )
since_token = None
if since is not None: if since is not None:
since_token = StreamToken.from_string(since) since_token = await StreamToken.from_string(self.store, since)
else:
since_token = None
# send any outstanding server notices to the user. # send any outstanding server notices to the user.
await self._server_notices_sender.on_user_syncing(user.to_string()) await self._server_notices_sender.on_user_syncing(user.to_string())
@ -236,7 +236,7 @@ class SyncRestServlet(RestServlet):
"leave": sync_result.groups.leave, "leave": sync_result.groups.leave,
}, },
"device_one_time_keys_count": sync_result.device_one_time_keys_count, "device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(), "next_batch": await sync_result.next_batch.to_string(self.store),
} }
@staticmethod @staticmethod
@ -413,7 +413,7 @@ class SyncRestServlet(RestServlet):
result = { result = {
"timeline": { "timeline": {
"events": serialized_timeline, "events": serialized_timeline,
"prev_batch": room.timeline.prev_batch.to_string(), "prev_batch": await room.timeline.prev_batch.to_string(self.store),
"limited": room.timeline.limited, "limited": room.timeline.limited,
}, },
"state": {"events": serialized_state}, "state": {"events": serialized_state},

View file

@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
The set of state groups that are referenced by deleted events. The set of state groups that are referenced by deleted events.
""" """
parsed_token = await RoomStreamToken.parse(self, token)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"purge_history", "purge_history",
self._purge_history_txn, self._purge_history_txn,
room_id, room_id,
token, parsed_token,
delete_local_events, delete_local_events,
) )
def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): def _purge_history_txn(self, txn, room_id, token, delete_local_events):
token = RoomStreamToken.parse(token_str)
# Tables that should be pruned: # Tables that should be pruned:
# event_auth # event_auth
# event_backward_extremities # event_backward_extremities

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional from typing import Optional
@ -21,6 +20,7 @@ import attr
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.storage.databases.main import DataStore
from synapse.types import StreamToken from synapse.types import StreamToken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,8 +39,9 @@ class PaginationConfig:
limit = attr.ib(type=Optional[int]) limit = attr.ib(type=Optional[int])
@classmethod @classmethod
def from_request( async def from_request(
cls, cls,
store: "DataStore",
request: SynapseRequest, request: SynapseRequest,
raise_invalid_params: bool = True, raise_invalid_params: bool = True,
default_limit: Optional[int] = None, default_limit: Optional[int] = None,
@ -54,13 +55,13 @@ class PaginationConfig:
if from_tok == "END": if from_tok == "END":
from_tok = None # For backwards compat. from_tok = None # For backwards compat.
elif from_tok: elif from_tok:
from_tok = StreamToken.from_string(from_tok) from_tok = await StreamToken.from_string(store, from_tok)
except Exception: except Exception:
raise SynapseError(400, "'from' parameter is invalid") raise SynapseError(400, "'from' parameter is invalid")
try: try:
if to_tok: if to_tok:
to_tok = StreamToken.from_string(to_tok) to_tok = await StreamToken.from_string(store, to_tok)
except Exception: except Exception:
raise SynapseError(400, "'to' parameter is invalid") raise SynapseError(400, "'to' parameter is invalid")

View file

@ -18,7 +18,17 @@ import re
import string import string
import sys import sys
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar from typing import (
TYPE_CHECKING,
Any,
Dict,
Mapping,
MutableMapping,
Optional,
Tuple,
Type,
TypeVar,
)
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -26,6 +36,9 @@ from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
# define a version of typing.Collection that works on python 3.5 # define a version of typing.Collection that works on python 3.5
if sys.version_info[:3] >= (3, 6, 0): if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection from typing import Collection
@ -393,7 +406,7 @@ class RoomStreamToken:
stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
@classmethod @classmethod
def parse(cls, string: str) -> "RoomStreamToken": async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
try: try:
if string[0] == "s": if string[0] == "s":
return cls(topological=None, stream=int(string[1:])) return cls(topological=None, stream=int(string[1:]))
@ -428,7 +441,7 @@ class RoomStreamToken:
def as_tuple(self) -> Tuple[Optional[int], int]: def as_tuple(self) -> Tuple[Optional[int], int]:
return (self.topological, self.stream) return (self.topological, self.stream)
def __str__(self) -> str: async def to_string(self, store: "DataStore") -> str:
if self.topological is not None: if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
@ -453,18 +466,32 @@ class StreamToken:
START = None # type: StreamToken START = None # type: StreamToken
@classmethod @classmethod
def from_string(cls, string): async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
while len(keys) < len(attr.fields(cls)): while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key # i.e. old token from before receipt_key
keys.append("0") keys.append("0")
return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:])) return cls(
await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
)
except Exception: except Exception:
raise SynapseError(400, "Invalid Token") raise SynapseError(400, "Invalid Token")
def to_string(self): async def to_string(self, store: "DataStore") -> str:
return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)]) return self._SEPARATOR.join(
[
await self.room_key.to_string(store),
str(self.presence_key),
str(self.typing_key),
str(self.receipt_key),
str(self.account_data_key),
str(self.push_rules_key),
str(self.to_device_key),
str(self.device_list_key),
str(self.groups_key),
]
)
@property @property
def room_stream_id(self): def room_stream_id(self):
@ -493,7 +520,7 @@ class StreamToken:
return attr.evolve(self, **{key: new_value}) return attr.evolve(self, **{key: new_value})
StreamToken.START = StreamToken.from_string("s0_0") StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)

View file

@ -902,16 +902,18 @@ class RoomMessageListTestCase(RoomBase):
# Send a first message in the room, which will be removed by the purge. # Send a first message in the room, which will be removed by the purge.
first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
first_token = str( first_token = self.get_success(
self.get_success(store.get_topological_token_for_event(first_event_id)) store.get_topological_token_for_event(first_event_id)
) )
first_token_str = self.get_success(first_token.to_string(store))
# Send a second message in the room, which won't be removed, and which we'll # Send a second message in the room, which won't be removed, and which we'll
# use as the marker to purge events before. # use as the marker to purge events before.
second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
second_token = str( second_token = self.get_success(
self.get_success(store.get_topological_token_for_event(second_event_id)) store.get_topological_token_for_event(second_event_id)
) )
second_token_str = self.get_success(second_token.to_string(store))
# Send a third event in the room to ensure we don't fall under any edge case # Send a third event in the room to ensure we don't fall under any edge case
# due to our marker being the latest forward extremity in the room. # due to our marker being the latest forward extremity in the room.
@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request( request, channel = self.make_request(
"GET", "GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), % (
self.room_id,
second_token_str,
json.dumps({"types": [EventTypes.Message]}),
),
) )
self.render(request) self.render(request)
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
pagination_handler._purge_history( pagination_handler._purge_history(
purge_id=purge_id, purge_id=purge_id,
room_id=self.room_id, room_id=self.room_id,
token=second_token, token=second_token_str,
delete_local_events=True, delete_local_events=True,
) )
) )
@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request( request, channel = self.make_request(
"GET", "GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), % (
self.room_id,
second_token_str,
json.dumps({"types": [EventTypes.Message]}),
),
) )
self.render(request) self.render(request)
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request( request, channel = self.make_request(
"GET", "GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), % (
self.room_id,
first_token_str,
json.dumps({"types": [EventTypes.Message]}),
),
) )
self.render(request) self.render(request)
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)

View file

@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage() storage = self.hs.get_storage()
# Get the topological token # Get the topological token
event = str( token = self.get_success(
self.get_success(store.get_topological_token_for_event(last["event_id"])) store.get_topological_token_for_event(last["event_id"])
) )
token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token # Purge everything before this topological token
self.get_success(storage.purge_events.purge_history(self.room_id, event, True)) self.get_success(
storage.purge_events.purge_history(self.room_id, token_str, True)
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not. # and last is not.