Merge branch 'develop' into matrix-org-hotfixes

Commit cc23d81a74 by Brendan Abolivier, 2020-09-04 11:02:10 +01:00
33 changed files with 400 additions and 379 deletions

View file

@ -1 +0,0 @@
Add experimental support for sharding event persister.

changelog.d/8233.misc Normal file
View file

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

changelog.d/8240.misc Normal file
View file

@ -0,0 +1 @@
Fix type hints for functions decorated with `@cached`.

changelog.d/8242.feature Normal file
View file

@ -0,0 +1 @@
Back out experimental support for sharding event persister. **PLEASE REMOVE THIS LINE FROM THE FINAL CHANGELOG**

changelog.d/8244.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to pagination, initial sync and events handlers.

changelog.d/8245.misc Normal file
View file

@ -0,0 +1 @@
Remove obsolete `order` field from federation send queues.

View file

@ -1,6 +1,6 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = silent
check_untyped_defs = True
show_error_codes = True
@ -17,10 +17,13 @@ files =
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
synapse/handlers/directory.py,
synapse/handlers/events.py,
synapse/handlers/federation.py,
synapse/handlers/identity.py,
synapse/handlers/initial_sync.py,
synapse/handlers/message.py,
synapse/handlers/oidc_handler.py,
synapse/handlers/pagination.py,
synapse/handlers/presence.py,
synapse/handlers/room.py,
synapse/handlers/room_member.py,
@ -51,6 +54,7 @@ files =
synapse/storage/util,
synapse/streams,
synapse/types.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py,
tests/replication,

View file

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is a mypy plugin for Synpase to deal with some of the funky typing that
can crop up, e.g the cache descriptors.
"""
from typing import Callable, Optional
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType
class SynapsePlugin(Plugin):
def get_method_signature_hook(
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__"
):
return cached_function_method_signature
return None
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
"""Fixes the `_CachedFunction.__call__` signature to be correct.
It already has *almost* the correct signature, except:
1. the `self` argument needs to be marked as "bound"; and
2. any `cache_context` argument should be removed.
"""
# First we mark this as a bound function signature.
signature = bind_self(ctx.default_signature)
# Secondly, we remove any "cache_context" args.
#
# Note: We should only be doing this if `cache_context=True` is set, but if
# it isn't then the code will raise an exception when it's called anyway, so
# it's not the end of the world.
context_arg_index = None
for idx, name in enumerate(signature.arg_names):
if name == "cache_context":
context_arg_index = idx
break
if context_arg_index is not None:
arg_types = list(signature.arg_types)
arg_types.pop(context_arg_index)
arg_names = list(signature.arg_names)
arg_names.pop(context_arg_index)
arg_kinds = list(signature.arg_kinds)
arg_kinds.pop(context_arg_index)
signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
)
return signature
def plugin(version: str):
# This is the entry point of the plugin, and lets us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version
# string.
#
# However, since we pin the version of mypy Synapse uses in CI, we don't
# really care.
return SynapsePlugin
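
For context, a hedged sketch of the kind of `@cached` method the plugin's signature hook targets; the store class and method names here are invented for illustration:

from synapse.util.caches.descriptors import cached

class ExampleStore:
    # With cache_context=True, the descriptor passes a `cache_context`
    # object into the method body, but callers never supply it, which is
    # why the plugin strips the argument from the __call__ signature.
    @cached(cache_context=True)
    async def get_thing(self, thing_id, cache_context):
        ...

# Call sites omit cache_context entirely:
#     await store.get_thing("some_id")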

View file

@ -832,26 +832,11 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.
"""
# If multiple instances are not defined we always return true
# If multiple instances are not defined we always return true.
if not self.instances or len(self.instances) == 1:
return True
return self.get_instance(key) == instance_name
def get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key.
Note: For things like federation sending, the config naming which instance
does the sending may be known only to the sender instance itself (when
there is only one). Therefore `should_handle` should be used where possible.
"""
if not self.instances:
return "master"
if len(self.instances) == 1:
return self.instances[0]
# We shard by taking the hash, modulo it by the number of instances and
# then checking whether this instance matches the instance at that
# index.
@ -861,7 +846,7 @@ class ShardedWorkerHandlingConfig:
dest_hash = sha256(key.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
return self.instances[remainder]
return self.instances[remainder] == instance_name
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]

View file

@ -142,4 +142,3 @@ class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...
def get_instance(self, key: str) -> str: ...

View file

@ -13,24 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import attr
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
from .server import ListenerConfig, parse_listener_def
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
"""Helper for allowing parsing a string or list of strings to a config
option expecting a list of strings.
"""
if isinstance(obj, str):
return [obj]
return obj
@attr.s
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication.
@ -45,13 +33,11 @@ class WriterLocations:
"""Specifies the instances that write various streams.
Attributes:
events: The instances that write to the event and backfill streams.
typing: The instance that writes to the typing stream.
events: The instance that writes to the event and backfill streams.
typing: The instance that writes to the typing stream.
"""
events = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter
)
events = attr.ib(default="master", type=str)
typing = attr.ib(default="master", type=str)
@ -119,18 +105,15 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
# Check that the configured writers for events and typing also appears in
# Check that the configured writer for events and typing also appears in
# `instance_map`.
for stream in ("events", "typing"):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
)
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
instance = getattr(self.writers, stream)
if instance != "master" and instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
)
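To illustrate what this validation enforces, a minimal sketch with a toy config; the worker names and map values are assumptions:

# A writer other than "master" must appear in instance_map.
instance_map = {"worker1": {"host": "localhost", "port": 8034}}
writers = {"events": "worker1", "typing": "worker2"}

for stream, instance in writers.items():
    if instance != "master" and instance not in instance_map:
        # "worker2" trips this check, mirroring the ConfigError raised above.
        print("Instance %r writes %s but is missing from instance_map" % (instance, stream))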
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\

View file

@ -108,8 +108,6 @@ class FederationSender(object):
),
)
self._order = 1
self._is_processing = False
self._last_poked_id = -1
@ -290,9 +288,6 @@ class FederationSender(object):
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
order = self._order
self._order += 1
destinations = set(destinations)
destinations.discard(self.server_name)
logger.debug("Sending to: %s", str(destinations))
@ -304,7 +299,7 @@ class FederationSender(object):
sent_pdus_destination_dist_count.inc()
for destination in destinations:
self._get_per_destination_queue(destination).send_pdu(pdu, order)
self._get_per_destination_queue(destination).send_pdu(pdu)
async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""Send a RR to any other servers in the room

View file

@ -95,8 +95,8 @@ class PerDestinationQueue(object):
self._destination = destination
self.transmission_loop_running = False
# a list of tuples of (pending pdu, order)
self._pending_pdus = [] # type: List[Tuple[EventBase, int]]
# a list of pending PDUs
self._pending_pdus = [] # type: List[EventBase]
# XXX this is never actually used: see
# https://github.com/matrix-org/synapse/issues/7549
@ -135,14 +135,13 @@ class PerDestinationQueue(object):
+ len(self._pending_edus_keyed)
)
def send_pdu(self, pdu: EventBase, order: int) -> None:
def send_pdu(self, pdu: EventBase) -> None:
"""Add a PDU to the queue, and start the transmission loop if necessary
Args:
pdu: pdu to send
order
"""
self._pending_pdus.append((pdu, order))
self._pending_pdus.append(pdu)
self.attempt_new_transaction()
def send_presence(self, states: Iterable[UserPresenceState]) -> None:
@ -188,7 +187,7 @@ class PerDestinationQueue(object):
returns immediately. Otherwise kicks off the process of sending a
transaction in the background.
"""
# list of (pending_pdu, deferred, order)
if self.transmission_loop_running:
# XXX: this can get stuck on by a never-ending
# request at which point pending_pdus just keeps growing.
@ -213,7 +212,7 @@ class PerDestinationQueue(object):
)
async def _transaction_transmission_loop(self) -> None:
pending_pdus = [] # type: List[Tuple[EventBase, int]]
pending_pdus = [] # type: List[EventBase]
try:
self.transmission_loop_running = True
@ -388,13 +387,13 @@ class PerDestinationQueue(object):
"TX [%s] Failed to send transaction: %s", self._destination, e
)
for p, _ in pending_pdus:
for p in pending_pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
except Exception:
logger.exception("TX [%s] Failed to send transaction", self._destination)
for p, _ in pending_pdus:
for p in pending_pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, List
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
@ -57,11 +57,17 @@ class TransactionManager(object):
@measure_func("_send_new_transaction")
async def send_new_transaction(
self,
destination: str,
pending_pdus: List[Tuple[EventBase, int]],
pending_edus: List[Edu],
):
self, destination: str, pdus: List[EventBase], edus: List[Edu],
) -> bool:
"""
Args:
destination: The destination to send to (e.g. 'example.org')
pdus: In-order list of PDUs to send
edus: List of EDUs to send
Returns:
True iff the transaction was successful
"""
# Make a transaction-sending opentracing span. This span follows on from
# all the edus in that transaction. This needs to be done since there is
@ -71,7 +77,7 @@ class TransactionManager(object):
span_contexts = []
keep_destination = whitelisted_homeserver(destination)
for edu in pending_edus:
for edu in edus:
context = edu.get_context()
if context:
span_contexts.append(extract_text_map(json_decoder.decode(context)))
@ -79,12 +85,6 @@ class TransactionManager(object):
edu.strip_context()
with start_active_span_follows_from("send_transaction", span_contexts):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
success = True
logger.debug("TX [%s] _attempt_new_transaction", destination)

View file

@ -15,29 +15,30 @@
import logging
import random
from typing import TYPE_CHECKING, Iterable, List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.utils import log_function
from synapse.types import UserID
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class EventStreamHandler(BaseHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super(EventStreamHandler, self).__init__(hs)
# Count of active streams per user
self._streams_per_user = {}
# Grace timers per user to delay the "stopped" signal
self._stop_timer_per_user = {}
self.distributor = hs.get_distributor()
self.distributor.declare("started_user_eventstream")
self.distributor.declare("stopped_user_eventstream")
@ -52,14 +53,14 @@ class EventStreamHandler(BaseHandler):
@log_function
async def get_stream(
self,
auth_user_id,
pagin_config,
timeout=0,
as_client_event=True,
affect_presence=True,
room_id=None,
is_guest=False,
):
auth_user_id: str,
pagin_config: PaginationConfig,
timeout: int = 0,
as_client_event: bool = True,
affect_presence: bool = True,
room_id: Optional[str] = None,
is_guest: bool = False,
) -> JsonDict:
"""Fetches the events stream for a given user.
"""
@ -98,7 +99,7 @@ class EventStreamHandler(BaseHandler):
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = []
to_add = [] # type: List[JsonDict]
for event in events:
if not isinstance(event, EventBase):
continue
@ -110,7 +111,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence for everyone in the room.
users = await self.state.get_current_users_in_room(
event.room_id
)
) # type: Iterable[str]
else:
users = [event.state_key]
@ -144,20 +145,22 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super(EventHandler, self).__init__(hs)
self.storage = hs.get_storage()
async def get_event(self, user, room_id, event_id):
async def get_event(
self, user: UserID, room_id: Optional[str], event_id: str
) -> Optional[EventBase]:
"""Retrieve a single specified event.
Args:
user (synapse.types.UserID): The user requesting the event
room_id (str|None): The expected room id. We'll return None if the
user: The user requesting the event
room_id: The expected room id. We'll return None if the
event's room does not match.
event_id (str): The event ID to obtain.
event_id: The event ID to obtain.
Returns:
dict: An event, or None if there is no event matching this ID.
An event, or None if there is no event matching this ID.
Raises:
SynapseError if there was a problem retrieving this event, or
AuthError if the user does not have the rights to inspect this

View file

@ -440,11 +440,11 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
return
latest = await self.store.get_latest_event_ids_in_room(room_id)
latest_list = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest)
latest = set(latest_list)
latest |= seen
logger.info(
@ -781,7 +781,7 @@ class FederationHandler(BaseHandler):
# keys across all devices.
current_keys = [
key
for device in cached_devices
for device in cached_devices.values()
for key in device.get("keys", {}).get("keys", {}).values()
]
@ -923,8 +923,7 @@ class FederationHandler(BaseHandler):
)
)
if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@ -1217,7 +1216,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
destination, room_id, event_infos,
destination, event_infos,
)
def _sanity_check_event(self, ev):
@ -1364,15 +1363,15 @@ class FederationHandler(BaseHandler):
)
max_stream_id = await self._persist_auth_tree(
origin, room_id, auth_chain, state, event, room_version_obj
origin, auth_chain, state, event, room_version_obj
)
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.config.worker.events_shard_config.get_instance(room_id),
"events",
max_stream_id,
self.config.worker.writers.events, "events", max_stream_id
)
# Check whether this room is the result of an upgrade of a room we already know
@ -1626,7 +1625,7 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify(event.room_id, [(event, context)])
await self.persist_events_and_notify([(event, context)])
return event
@ -1653,9 +1652,7 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify(
event.room_id, [(event, context)]
)
stream_id = await self.persist_events_and_notify([(event, context)])
return event, stream_id
@ -1903,7 +1900,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
event.room_id, [(event, context)], backfilled=backfilled
[(event, context)], backfilled=backfilled
)
except Exception:
run_in_background(
@ -1916,7 +1913,6 @@ class FederationHandler(BaseHandler):
async def _handle_new_events(
self,
origin: str,
room_id: str,
event_infos: Iterable[_NewEventInfo],
backfilled: bool = False,
) -> None:
@ -1948,7 +1944,6 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
room_id,
[
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
@ -1959,7 +1954,6 @@ class FederationHandler(BaseHandler):
async def _persist_auth_tree(
self,
origin: str,
room_id: str,
auth_events: List[EventBase],
state: List[EventBase],
event: EventBase,
@ -1974,7 +1968,6 @@ class FederationHandler(BaseHandler):
Args:
origin: Where the events came from
room_id,
auth_events
state
event
@ -2049,20 +2042,17 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
await self.persist_events_and_notify(
room_id,
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
],
]
)
new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
return await self.persist_events_and_notify(
room_id, [(event, new_event_context)]
)
return await self.persist_events_and_notify([(event, new_event_context)])
async def _prep_event(
self,
@ -2119,8 +2109,8 @@ class FederationHandler(BaseHandler):
if backfilled or event.internal_metadata.is_outlier():
return
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids)
extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids:
@ -2913,7 +2903,6 @@ class FederationHandler(BaseHandler):
async def persist_events_and_notify(
self,
room_id: str,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> int:
@ -2921,19 +2910,14 @@ class FederationHandler(BaseHandler):
necessary.
Args:
room_id: The room ID of events being persisted.
event_and_contexts: Sequence of events with their associated
context that should be persisted. All events must belong to
the same room.
event_and_contexts:
backfilled: Whether these events are a result of
backfilling or not
"""
instance = self.config.worker.events_shard_config.get_instance(room_id)
if instance != self._instance_name:
if self.config.worker.writers.events != self._instance_name:
result = await self._send_events(
instance_name=instance,
instance_name=self.config.worker.writers.events,
store=self.store,
room_id=room_id,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.internet import defer
@ -22,8 +23,9 @@ from synapse.api.errors import SynapseError
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@ -31,11 +33,15 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class InitialSyncHandler(BaseHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super(InitialSyncHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
@ -48,27 +54,25 @@ class InitialSyncHandler(BaseHandler):
def snapshot_all_rooms(
self,
user_id=None,
pagin_config=None,
as_client_event=True,
include_archived=False,
):
user_id: str,
pagin_config: PaginationConfig,
as_client_event: bool = True,
include_archived: bool = False,
) -> JsonDict:
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config.
Args:
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
user_id: The ID of the user making the request.
pagin_config: The pagination config used to determine how many
messages *PER ROOM* to return.
as_client_event: True to get events in client-server format.
include_archived: True to get rooms that the user has left
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
A JsonDict with the same format as the response to the `/initialSync`
API.
"""
key = (
user_id,
@ -91,11 +95,11 @@ class InitialSyncHandler(BaseHandler):
async def _snapshot_all_rooms(
self,
user_id=None,
pagin_config=None,
as_client_event=True,
include_archived=False,
):
user_id: str,
pagin_config: PaginationConfig,
as_client_event: bool = True,
include_archived: bool = False,
) -> JsonDict:
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
@ -134,7 +138,7 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
async def handle_room(event):
async def handle_room(event: RoomsForUser):
d = {
"room_id": event.room_id,
"membership": event.membership,
@ -251,17 +255,18 @@ class InitialSyncHandler(BaseHandler):
return ret
async def room_initial_sync(self, requester, room_id, pagin_config=None):
async def room_initial_sync(
self, requester: Requester, room_id: str, pagin_config: PaginationConfig
) -> JsonDict:
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
requester: The user to get a snapshot for.
room_id: The room to get a snapshot of.
pagin_config: The pagination config used to determine how many
messages to return.
Raises:
AuthError if the user wasn't in the room.
Returns:
@ -305,8 +310,14 @@ class InitialSyncHandler(BaseHandler):
return result
async def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
):
self,
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
room_state = await self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@ -350,8 +361,13 @@ class InitialSyncHandler(BaseHandler):
}
async def _room_initial_sync_joined(
self, user_id, room_id, pagin_config, membership, is_peeking
):
self,
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
membership: Membership,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
# TODO: These concurrently

View file

@ -376,8 +376,9 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._is_event_writer = (
self.config.worker.writers.events == hs.get_instance_name()
)
self.room_invite_state_types = self.hs.config.room_invite_state_types
@ -905,10 +906,9 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
if writer_instance != self._instance_name:
if not self._is_event_writer:
result = await self.send_event(
instance_name=writer_instance,
instance_name=self.config.worker.writers.events,
event_id=event.event_id,
store=self.store,
requester=requester,
@ -976,9 +976,7 @@ class EventCreationHandler(object):
This should only be run on the instance in charge of persisting events.
"""
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
)
assert self._is_event_writer
if ratelimit:
# We check if this is a room admin redacting an event so that we

View file

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
from twisted.python.failure import Failure
@ -30,6 +30,10 @@ from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -68,7 +72,7 @@ class PaginationHandler(object):
paginating during a purge.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@ -78,9 +82,9 @@ class PaginationHandler(object):
self._server_name = hs.hostname
self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set()
self._purges_in_progress_by_room = set() # type: Set[str]
# map from purge id to PurgeStatus
self._purges_by_id = {}
self._purges_by_id = {} # type: Dict[str, PurgeStatus]
self._event_serializer = hs.get_event_client_serializer()
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
@ -102,7 +106,9 @@ class PaginationHandler(object):
job["longest_max_lifetime"],
)
async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
async def purge_history_for_rooms_in_range(
self, min_ms: Optional[int], max_ms: Optional[int]
):
"""Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its
@ -110,10 +116,10 @@ class PaginationHandler(object):
retention policy.
Args:
min_ms (int|None): Duration in milliseconds that defines the lower limit of
min_ms: Duration in milliseconds that defines the lower limit of
the range to handle (exclusive). If None, it means that the range has no
lower limit.
max_ms (int|None): Duration in milliseconds that defines the upper limit of
max_ms: Duration in milliseconds that defines the upper limit of
the range to handle (inclusive). If None, it means that the range has no
upper limit.
"""
@ -220,18 +226,19 @@ class PaginationHandler(object):
"_purge_history", self._purge_history, purge_id, room_id, token, True,
)
def start_purge_history(self, room_id, token, delete_local_events=False):
def start_purge_history(
self, room_id: str, token: str, delete_local_events: bool = False
) -> str:
"""Start off a history purge on a room.
Args:
room_id (str): The room to purge from
token (str): topological token to delete events before
delete_local_events (bool): True to delete local events as well as
room_id: The room to purge from
token: topological token to delete events before
delete_local_events: True to delete local events as well as
remote ones
Returns:
str: unique ID for this purge transaction.
unique ID for this purge transaction.
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
@ -284,14 +291,11 @@ class PaginationHandler(object):
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]:
"""Get the current status of an active purge
Args:
purge_id (str): purge_id returned by start_purge_history
Returns:
PurgeStatus|None
purge_id: purge_id returned by start_purge_history
"""
return self._purges_by_id.get(purge_id)
@ -312,8 +316,8 @@ class PaginationHandler(object):
async def get_messages(
self,
requester: Requester,
room_id: Optional[str] = None,
pagin_config: Optional[PaginationConfig] = None,
room_id: str,
pagin_config: PaginationConfig,
as_client_event: bool = True,
event_filter: Optional[Filter] = None,
) -> Dict[str, Any]:
@ -368,11 +372,15 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
# This is only None if the room is world_readable, in which
# case "JOIN" would have been returned.
assert member_event_id
leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
if RoomStreamToken.parse(leave_token).topological < max_topo:
source_config.from_key = str(leave_token)
await self.hs.get_handlers().federation_handler.maybe_backfill(
@ -419,8 +427,8 @@ class PaginationHandler(object):
)
if state_ids:
state = await self.store.get_events(list(state_ids.values()))
state = state.values()
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
time_now = self.clock.time_msec()
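
For reference, a hedged sketch of the token clamp earlier in this file; the token value is invented, assuming the "t<topological>-<stream>" format that RoomStreamToken.parse accepts:

from synapse.types import RoomStreamToken

# The leave token's topological part is compared against max_topo above.
leave_token = "t123-456"
parsed = RoomStreamToken.parse(leave_token)
assert parsed.topological == 123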

View file

@ -804,9 +804,7 @@ class RoomCreationHandler(BaseHandler):
# Always wait for room creation to propagate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
last_stream_id,
self.hs.config.worker.writers.events, "events", last_stream_id
)
return result, last_stream_id
@ -1262,10 +1260,10 @@ class RoomShutdownHandler(object):
# We now wait for the room creation to come back in via replication so
# that we can assume that all the joins/invites have propagated before
# we try and auto join below.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(new_room_id),
"events",
stream_id,
self.hs.config.worker.writers.events, "events", stream_id
)
else:
new_room_id = None
@ -1295,9 +1293,7 @@ class RoomShutdownHandler(object):
# Wait for leave to come in over replication before trying to forget.
await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
stream_id,
self.hs.config.worker.writers.events, "events", stream_id
)
await self.room_member_handler.forget(target_requester.user, room_id)

View file

@ -83,6 +83,13 @@ class RoomMemberHandler(object):
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._event_stream_writer_instance = hs.config.worker.writers.events
self._is_on_event_persistence_instance = (
self._event_stream_writer_instance == hs.get_instance_name()
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
self._join_rate_limiter_local = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,

View file

@ -65,11 +65,10 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
async def _serialize_payload(store, event_and_contexts, backfilled):
"""
Args:
store
room_id (str)
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
@ -89,11 +88,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
}
)
payload = {
"events": event_payloads,
"backfilled": backfilled,
"room_id": room_id,
}
payload = {"events": event_payloads, "backfilled": backfilled}
return payload
@ -101,7 +96,6 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
room_id = content["room_id"]
backfilled = content["backfilled"]
event_payloads = content["events"]
@ -126,7 +120,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
max_stream_id = await self.federation_handler.persist_events_and_notify(
room_id, event_and_contexts, backfilled
event_and_contexts, backfilled
)
return 200, {"max_stream_id": max_stream_id}

View file

@ -109,7 +109,7 @@ class ReplicationCommandHandler:
if isinstance(stream, (EventsStream, BackfillStream)):
# Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence.
if hs.get_instance_name() in hs.config.worker.writers.events:
if hs.config.worker.writers.events == hs.get_instance_name():
self._streams_to_replicate.append(stream)
continue

View file

@ -19,7 +19,7 @@ from typing import List, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
@ -117,7 +117,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self._store._stream_id_gen.get_current_token_for_writer,
current_token_without_instance(self._store.get_current_events_token),
self._update_function,
)

View file

@ -68,7 +68,7 @@ class Databases(object):
# If we're on a process that can persist events, also
# instantiate a `PersistEventsStore`
if hs.get_instance_name() in hs.config.worker.writers.events:
if hs.config.worker.writers.events == hs.get_instance_name():
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:

View file

@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])
# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)
# cross-signing sigs on this device.
# dict from (signing user_id)->(signing device_id)->sig
signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)
class EndToEndKeyWorkerStore(SQLBaseStore):
@ -133,7 +135,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys, together with their cross-signatures.
"""Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also
included.
Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None
@ -154,22 +159,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_and_signatures_txn,
self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
)
# get the (user_id, device_id) tuples to look up cross-signatures for
signature_query = (
(user_id, device_id)
for user_id, dev in result.items()
for device_id, d in dev.items()
if d is not None
)
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)
# add each cross-signing signature to the correct device in the result dict.
for (user_id, key_id, device_id, signature) in cross_sigs_result:
target_device_result = result[user_id][device_id]
target_device_signatures = target_device_result.signatures
signing_user_signatures = target_device_signatures.setdefault(
user_id, {}
)
signing_user_signatures[key_id] = signature
log_kv(result)
return result
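
`batch_iter` here comes from `synapse.util.iterutils` (it appears in the imports elsewhere in this diff); it chunks an iterable into fixed-size tuples, e.g.:

from synapse.util.iterutils import batch_iter

# The loop above processes (user_id, device_id) pairs 50 at a time, keeping
# each signature query to a bounded size.
print(list(batch_iter(range(5), 2)))  # [(0, 1), (2, 3), (4,)]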
def _get_e2e_device_keys_and_signatures_txn(
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database
The results include the device's keys and self-signatures, but *not* any
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
query_clauses = []
query_params = []
signature_query_clauses = []
signature_query_params = []
if include_all_devices is False:
include_deleted_devices = False
@ -180,20 +214,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
signature_query_clause = "target_user_id = ?"
signature_query_params.append(user_id)
if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
signature_query_clause += " AND target_device_id = ?"
signature_query_params.append(device_id)
signature_query_clause += " AND user_id = ?"
signature_query_params.append(user_id)
query_clauses.append(query_clause)
signature_query_clauses.append(signature_query_clause)
sql = (
"SELECT user_id, device_id, "
@ -221,41 +247,36 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
# get signatures on the device
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
return result
def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
Returns signatures made by the owners of the devices.
Returns: a list of results; each entry in the list is a tuple of
(user_id, key_id, target_device_id, signature).
"""
signature_query_clauses = []
signature_query_params = []
for (user_id, device_id) in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
signature_query_params.extend([user_id, device_id, user_id])
signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)
txn.execute(signature_sql, signature_query_params)
rows = self.db_pool.cursor_to_dict(txn)
# add each cross-signing signature to the correct device in the result dict.
for row in rows:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]
target_user_result = result.get(target_user_id)
if not target_user_result:
continue
target_device_result = target_user_result.get(target_device_id)
if not target_device_result:
# note that target_device_result will be None for deleted devices.
continue
target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
signing_user_signatures[signing_key_id] = signature
return result
return txn.fetchall()
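
A small sketch of how the caller folds these rows back into each device's `signatures` mapping; the row values are invented:

rows = [("@alice:example.org", "ed25519:ABCDEFG", "DEVICE1", "signature+base64")]

signatures = {}  # stands in for DeviceKeyLookupResult.signatures
for (user_id, key_id, device_id, signature) in rows:
    # dict from (signing user_id) -> (signing key_id) -> signature
    signatures.setdefault(user_id, {})[key_id] = signature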
async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]

View file

@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
raise StoreError(400, "stream_ordering too old")
sql = """
SELECT event_id FROM stream_ordering_to_exterm

View file

@ -97,7 +97,6 @@ class PersistEventsStore:
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@ -109,7 +108,7 @@ class PersistEventsStore:
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@ -801,7 +800,6 @@ class PersistEventsStore:
table="events",
values=[
{
"instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"depth": event.depth,

View file

@ -42,8 +42,7 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
@ -79,54 +78,27 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_stream_seq",
if hs.config.worker.writers.events == hs.get_instance_name():
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_backfill_stream_seq",
positive=False,
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
else:
# We shouldn't be running in worker mode with SQLite, but it's useful
# to support it for unit tests.
#
# If this process is the writer then we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
else:
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering"
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
# Another process is in charge of persisting events and generating
# stream IDs: rely on the replication streams to let us know which
# IDs we can process.
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
self._get_event_cache = Cache(
"*getEvent*",

View file

@ -1,16 +0,0 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE events ADD COLUMN instance_name TEXT;

View file

@ -1,26 +0,0 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
SELECT setval('events_stream_seq', (
SELECT COALESCE(MAX(stream_ordering), 1) FROM events
));
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
SELECT setval('events_backfill_stream_seq', (
SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
));

View file

@ -231,12 +231,8 @@ class MultiWriterIdGenerator:
# gaps should be relatively rare, it's still worth doing the bookkeeping
# that allows us to skip forwards when there are gapless runs of
# positions.
#
# We start at 1 here as a) the first generated stream ID will be 2, and
# b) other parts of the code assume that stream IDs are strictly greater
# than 0.
self._persisted_upto_position = (
min(self._current_positions.values()) if self._current_positions else 1
min(self._current_positions.values()) if self._current_positions else 0
)
self._known_persisted_positions = [] # type: List[int]
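
A worked illustration of the bookkeeping described above; the writer positions are invented:

# With two writers at positions 5 and 3, everything up to min(...) == 3 is
# known to be persisted; 4 and 5 may still have gaps from the slower writer.
current_positions = {"writer1": 5, "writer2": 3}
persisted_upto_position = min(current_positions.values()) if current_positions else 0
assert persisted_upto_position == 3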
@ -366,7 +362,9 @@ class MultiWriterIdGenerator:
equal to it have been successfully persisted.
"""
return self.get_persisted_upto_position()
# Currently we don't support this operation, as it's not obvious how to
# condense the stream positions of multiple writers into a single int.
raise NotImplementedError()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.

View file

@ -18,11 +18,10 @@ import functools
import inspect
import logging
import threading
from typing import Any, Tuple, Union, cast
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from weakref import WeakValueDictionary
from prometheus_client import Gauge
from typing_extensions import Protocol
from twisted.internet import defer
@ -38,8 +37,10 @@ logger = logging.getLogger(__name__)
CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
class _CachedFunction(Protocol):
class _CachedFunction(Generic[F]):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
@ -47,8 +48,11 @@ class _CachedFunction(Protocol):
cache = None # type: Any
num_args = None # type: Any
def __name__(self):
...
__name__ = None # type: str
# Note: This function signature is actually fiddled with by the synapse mypy
# plugin to a) make it a bound method, and b) remove any `cache_context` arg.
__call__ = None # type: F
cache_pending_metric = Gauge(
@ -123,7 +127,7 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.thread = None
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
@ -662,9 +666,13 @@ class _CacheContext:
def cached(
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
):
return lambda orig: CacheDescriptor(
max_entries: int = 1000,
num_args: Optional[int] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
@ -673,8 +681,12 @@ def cached(
iterable=iterable,
)
return cast(Callable[[F], _CachedFunction[F]], func)
def cachedList(cached_method_name, list_name, num_args=None):
def cachedList(
cached_method_name: str, list_name: str, num_args: Optional[int] = None
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
cache.
Args:
cached_method_name (str): The name of the single-item lookup method.
cached_method_name: The name of the single-item lookup method.
This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to
list_name: The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
num_args: Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
Example:
@ -702,9 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
def batch_do_something(self, first_arg, second_args):
...
"""
return lambda orig: CacheListDescriptor(
func = lambda orig: CacheListDescriptor(
orig,
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
)
return cast(Callable[[F], _CachedFunction[F]], func)
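
Tying the two decorators together, a hedged usage sketch along the lines of the docstring example above; the store and method names are illustrative:

class ExampleStore:
    @cached(num_args=1)
    async def do_something(self, first_arg):
        ...

    # Batch variant: looks up many second_args at once against the
    # do_something cache, as the cachedList docstring describes.
    @cachedList(cached_method_name="do_something", list_name="second_args", num_args=1)
    async def batch_do_something(self, first_arg, second_args):
        ...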