mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 00:53:52 +01:00
Add types to StreamToken and RoomStreamToken (#8279)
The intention here is to change `StreamToken.room_key` to be a `RoomStreamToken` in a future PR, but that is a big enough change without this refactoring too.
This commit is contained in:
parent
094896a69d
commit
63c0e9e195
5 changed files with 95 additions and 91 deletions
1
changelog.d/8279.misc
Normal file
1
changelog.d/8279.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to `StreamToken` and `RoomStreamToken` classes.
|
|
@ -1310,12 +1310,11 @@ class SyncHandler:
|
||||||
presence_source = self.event_sources.sources["presence"]
|
presence_source = self.event_sources.sources["presence"]
|
||||||
|
|
||||||
since_token = sync_result_builder.since_token
|
since_token = sync_result_builder.since_token
|
||||||
|
presence_key = None
|
||||||
|
include_offline = False
|
||||||
if since_token and not sync_result_builder.full_state:
|
if since_token and not sync_result_builder.full_state:
|
||||||
presence_key = since_token.presence_key
|
presence_key = since_token.presence_key
|
||||||
include_offline = True
|
include_offline = True
|
||||||
else:
|
|
||||||
presence_key = None
|
|
||||||
include_offline = False
|
|
||||||
|
|
||||||
presence, presence_key = await presence_source.get_new_events(
|
presence, presence_key = await presence_source.get_new_events(
|
||||||
user=user,
|
user=user,
|
||||||
|
|
|
@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_users_whose_devices_changed(
|
async def get_users_whose_devices_changed(
|
||||||
self, from_key: str, user_ids: Iterable[str]
|
self, from_key: int, user_ids: Iterable[str]
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Get set of users whose devices have changed since `from_key` that
|
"""Get set of users whose devices have changed since `from_key` that
|
||||||
are in the given list of user_ids.
|
are in the given list of user_ids.
|
||||||
|
@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
The set of user_ids whose devices have changed since `from_key`
|
The set of user_ids whose devices have changed since `from_key`
|
||||||
"""
|
"""
|
||||||
from_key = int(from_key)
|
|
||||||
|
|
||||||
# Get set of users who *may* have changed. Users not in the returned
|
# Get set of users who *may* have changed. Users not in the returned
|
||||||
# list have definitely not changed.
|
# list have definitely not changed.
|
||||||
|
@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_users_whose_signatures_changed(
|
async def get_users_whose_signatures_changed(
|
||||||
self, user_id: str, from_key: str
|
self, user_id: str, from_key: int
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Get the users who have new cross-signing signatures made by `user_id` since
|
"""Get the users who have new cross-signing signatures made by `user_id` since
|
||||||
`from_key`.
|
`from_key`.
|
||||||
|
@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
A set of user IDs with updated signatures.
|
A set of user IDs with updated signatures.
|
||||||
"""
|
"""
|
||||||
from_key = int(from_key)
|
|
||||||
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
|
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
|
||||||
sql = """
|
sql = """
|
||||||
SELECT DISTINCT user_ids FROM user_signature_stream
|
SELECT DISTINCT user_ids FROM user_signature_stream
|
||||||
|
|
|
@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
|
||||||
def generate_pagination_where_clause(
|
def generate_pagination_where_clause(
|
||||||
direction: str,
|
direction: str,
|
||||||
column_names: Tuple[str, str],
|
column_names: Tuple[str, str],
|
||||||
from_token: Optional[Tuple[int, int]],
|
from_token: Optional[Tuple[Optional[int], int]],
|
||||||
to_token: Optional[Tuple[int, int]],
|
to_token: Optional[Tuple[Optional[int], int]],
|
||||||
engine: BaseDatabaseEngine,
|
engine: BaseDatabaseEngine,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Creates an SQL expression to bound the columns by the pagination
|
"""Creates an SQL expression to bound the columns by the pagination
|
||||||
|
@ -535,13 +535,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
if limit == 0:
|
if limit == 0:
|
||||||
return [], end_token
|
return [], end_token
|
||||||
|
|
||||||
end_token = RoomStreamToken.parse(end_token)
|
parsed_end_token = RoomStreamToken.parse(end_token)
|
||||||
|
|
||||||
rows, token = await self.db_pool.runInteraction(
|
rows, token = await self.db_pool.runInteraction(
|
||||||
"get_recent_event_ids_for_room",
|
"get_recent_event_ids_for_room",
|
||||||
self._paginate_room_events_txn,
|
self._paginate_room_events_txn,
|
||||||
room_id,
|
room_id,
|
||||||
from_token=end_token,
|
from_token=parsed_end_token,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -989,8 +989,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
bounds = generate_pagination_where_clause(
|
bounds = generate_pagination_where_clause(
|
||||||
direction=direction,
|
direction=direction,
|
||||||
column_names=("topological_ordering", "stream_ordering"),
|
column_names=("topological_ordering", "stream_ordering"),
|
||||||
from_token=from_token,
|
from_token=from_token.as_tuple(),
|
||||||
to_token=to_token,
|
to_token=to_token.as_tuple() if to_token else None,
|
||||||
engine=self.database_engine,
|
engine=self.database_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1083,16 +1083,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
and `to_key`).
|
and `to_key`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from_key = RoomStreamToken.parse(from_key)
|
parsed_from_key = RoomStreamToken.parse(from_key)
|
||||||
|
parsed_to_key = None
|
||||||
if to_key:
|
if to_key:
|
||||||
to_key = RoomStreamToken.parse(to_key)
|
parsed_to_key = RoomStreamToken.parse(to_key)
|
||||||
|
|
||||||
rows, token = await self.db_pool.runInteraction(
|
rows, token = await self.db_pool.runInteraction(
|
||||||
"paginate_room_events",
|
"paginate_room_events",
|
||||||
self._paginate_room_events_txn,
|
self._paginate_room_events_txn,
|
||||||
room_id,
|
room_id,
|
||||||
from_key,
|
parsed_from_key,
|
||||||
to_key,
|
parsed_to_key,
|
||||||
direction,
|
direction,
|
||||||
limit,
|
limit,
|
||||||
event_filter,
|
event_filter,
|
||||||
|
|
152
synapse/types.py
152
synapse/types.py
|
@ -18,7 +18,7 @@ import re
|
||||||
import string
|
import string
|
||||||
import sys
|
import sys
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
|
from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
@ -362,22 +362,79 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
|
||||||
return username.decode("ascii")
|
return username.decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
class StreamToken(
|
@attr.s(frozen=True, slots=True)
|
||||||
namedtuple(
|
class RoomStreamToken:
|
||||||
"Token",
|
"""Tokens are positions between events. The token "s1" comes after event 1.
|
||||||
(
|
|
||||||
"room_key",
|
s0 s1
|
||||||
"presence_key",
|
| |
|
||||||
"typing_key",
|
[0] V [1] V [2]
|
||||||
"receipt_key",
|
|
||||||
"account_data_key",
|
Tokens can either be a point in the live event stream or a cursor going
|
||||||
"push_rules_key",
|
through historic events.
|
||||||
"to_device_key",
|
|
||||||
"device_list_key",
|
When traversing the live event stream events are ordered by when they
|
||||||
"groups_key",
|
arrived at the homeserver.
|
||||||
),
|
|
||||||
|
When traversing historic events the events are ordered by their depth in
|
||||||
|
the event graph "topological_ordering" and then by when they arrived at the
|
||||||
|
homeserver "stream_ordering".
|
||||||
|
|
||||||
|
Live tokens start with an "s" followed by the "stream_ordering" id of the
|
||||||
|
event it comes after. Historic tokens start with a "t" followed by the
|
||||||
|
"topological_ordering" id of the event it comes after, followed by "-",
|
||||||
|
followed by the "stream_ordering" id of the event it comes after.
|
||||||
|
"""
|
||||||
|
|
||||||
|
topological = attr.ib(
|
||||||
|
type=Optional[int],
|
||||||
|
validator=attr.validators.optional(attr.validators.instance_of(int)),
|
||||||
)
|
)
|
||||||
):
|
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse(cls, string: str) -> "RoomStreamToken":
|
||||||
|
try:
|
||||||
|
if string[0] == "s":
|
||||||
|
return cls(topological=None, stream=int(string[1:]))
|
||||||
|
if string[0] == "t":
|
||||||
|
parts = string[1:].split("-", 1)
|
||||||
|
return cls(topological=int(parts[0]), stream=int(parts[1]))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise SynapseError(400, "Invalid token %r" % (string,))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_stream_token(cls, string: str) -> "RoomStreamToken":
|
||||||
|
try:
|
||||||
|
if string[0] == "s":
|
||||||
|
return cls(topological=None, stream=int(string[1:]))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise SynapseError(400, "Invalid token %r" % (string,))
|
||||||
|
|
||||||
|
def as_tuple(self) -> Tuple[Optional[int], int]:
|
||||||
|
return (self.topological, self.stream)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
if self.topological is not None:
|
||||||
|
return "t%d-%d" % (self.topological, self.stream)
|
||||||
|
else:
|
||||||
|
return "s%d" % (self.stream,)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True)
|
||||||
|
class StreamToken:
|
||||||
|
room_key = attr.ib(type=str)
|
||||||
|
presence_key = attr.ib(type=int)
|
||||||
|
typing_key = attr.ib(type=int)
|
||||||
|
receipt_key = attr.ib(type=int)
|
||||||
|
account_data_key = attr.ib(type=int)
|
||||||
|
push_rules_key = attr.ib(type=int)
|
||||||
|
to_device_key = attr.ib(type=int)
|
||||||
|
device_list_key = attr.ib(type=int)
|
||||||
|
groups_key = attr.ib(type=int)
|
||||||
|
|
||||||
_SEPARATOR = "_"
|
_SEPARATOR = "_"
|
||||||
START = None # type: StreamToken
|
START = None # type: StreamToken
|
||||||
|
|
||||||
|
@ -385,15 +442,15 @@ class StreamToken(
|
||||||
def from_string(cls, string):
|
def from_string(cls, string):
|
||||||
try:
|
try:
|
||||||
keys = string.split(cls._SEPARATOR)
|
keys = string.split(cls._SEPARATOR)
|
||||||
while len(keys) < len(cls._fields):
|
while len(keys) < len(attr.fields(cls)):
|
||||||
# i.e. old token from before receipt_key
|
# i.e. old token from before receipt_key
|
||||||
keys.append("0")
|
keys.append("0")
|
||||||
return cls(*keys)
|
return cls(keys[0], *(int(k) for k in keys[1:]))
|
||||||
except Exception:
|
except Exception:
|
||||||
raise SynapseError(400, "Invalid Token")
|
raise SynapseError(400, "Invalid Token")
|
||||||
|
|
||||||
def to_string(self):
|
def to_string(self):
|
||||||
return self._SEPARATOR.join([str(k) for k in self])
|
return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def room_stream_id(self):
|
def room_stream_id(self):
|
||||||
|
@ -435,63 +492,10 @@ class StreamToken(
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def copy_and_replace(self, key, new_value):
|
def copy_and_replace(self, key, new_value):
|
||||||
return self._replace(**{key: new_value})
|
return attr.evolve(self, **{key: new_value})
|
||||||
|
|
||||||
|
|
||||||
StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
|
StreamToken.START = StreamToken.from_string("s0_0")
|
||||||
|
|
||||||
|
|
||||||
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|
||||||
"""Tokens are positions between events. The token "s1" comes after event 1.
|
|
||||||
|
|
||||||
s0 s1
|
|
||||||
| |
|
|
||||||
[0] V [1] V [2]
|
|
||||||
|
|
||||||
Tokens can either be a point in the live event stream or a cursor going
|
|
||||||
through historic events.
|
|
||||||
|
|
||||||
When traversing the live event stream events are ordered by when they
|
|
||||||
arrived at the homeserver.
|
|
||||||
|
|
||||||
When traversing historic events the events are ordered by their depth in
|
|
||||||
the event graph "topological_ordering" and then by when they arrived at the
|
|
||||||
homeserver "stream_ordering".
|
|
||||||
|
|
||||||
Live tokens start with an "s" followed by the "stream_ordering" id of the
|
|
||||||
event it comes after. Historic tokens start with a "t" followed by the
|
|
||||||
"topological_ordering" id of the event it comes after, followed by "-",
|
|
||||||
followed by the "stream_ordering" id of the event it comes after.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = [] # type: list
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse(cls, string):
|
|
||||||
try:
|
|
||||||
if string[0] == "s":
|
|
||||||
return cls(topological=None, stream=int(string[1:]))
|
|
||||||
if string[0] == "t":
|
|
||||||
parts = string[1:].split("-", 1)
|
|
||||||
return cls(topological=int(parts[0]), stream=int(parts[1]))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise SynapseError(400, "Invalid token %r" % (string,))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse_stream_token(cls, string):
|
|
||||||
try:
|
|
||||||
if string[0] == "s":
|
|
||||||
return cls(topological=None, stream=int(string[1:]))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise SynapseError(400, "Invalid token %r" % (string,))
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
if self.topological is not None:
|
|
||||||
return "t%d-%d" % (self.topological, self.stream)
|
|
||||||
else:
|
|
||||||
return "s%d" % (self.stream,)
|
|
||||||
|
|
||||||
|
|
||||||
class ThirdPartyInstanceID(
|
class ThirdPartyInstanceID(
|
||||||
|
|
Loading…
Reference in a new issue