Merge branch 'develop' into matrix-org-hotfixes
This commit is contained in:
commit
cc23d81a74
|
@ -1 +0,0 @@
|
|||
Add experimental support for sharding event persister.
|
1
changelog.d/8233.misc
Normal file
1
changelog.d/8233.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactor queries for device keys and cross-signatures.
|
1
changelog.d/8240.misc
Normal file
1
changelog.d/8240.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix type hints for functions decorated with `@cached`.
|
1
changelog.d/8242.feature
Normal file
1
changelog.d/8242.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Back out experimental support for sharding event persister. **PLEASE REMOVE THIS LINE FROM THE FINAL CHANGELOG**
|
1
changelog.d/8244.misc
Normal file
1
changelog.d/8244.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add type hints to pagination, initial sync and events handlers.
|
1
changelog.d/8245.misc
Normal file
1
changelog.d/8245.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Remove obsolete `order` field from federation send queues.
|
6
mypy.ini
6
mypy.ini
|
@ -1,6 +1,6 @@
|
|||
[mypy]
|
||||
namespace_packages = True
|
||||
plugins = mypy_zope:plugin
|
||||
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
|
||||
follow_imports = silent
|
||||
check_untyped_defs = True
|
||||
show_error_codes = True
|
||||
|
@ -17,10 +17,13 @@ files =
|
|||
synapse/handlers/auth.py,
|
||||
synapse/handlers/cas_handler.py,
|
||||
synapse/handlers/directory.py,
|
||||
synapse/handlers/events.py,
|
||||
synapse/handlers/federation.py,
|
||||
synapse/handlers/identity.py,
|
||||
synapse/handlers/initial_sync.py,
|
||||
synapse/handlers/message.py,
|
||||
synapse/handlers/oidc_handler.py,
|
||||
synapse/handlers/pagination.py,
|
||||
synapse/handlers/presence.py,
|
||||
synapse/handlers/room.py,
|
||||
synapse/handlers/room_member.py,
|
||||
|
@ -51,6 +54,7 @@ files =
|
|||
synapse/storage/util,
|
||||
synapse/streams,
|
||||
synapse/types.py,
|
||||
synapse/util/caches/descriptors.py,
|
||||
synapse/util/caches/stream_change_cache.py,
|
||||
synapse/util/metrics.py,
|
||||
tests/replication,
|
||||
|
|
85
scripts-dev/mypy_synapse_plugin.py
Normal file
85
scripts-dev/mypy_synapse_plugin.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This is a mypy plugin for Synpase to deal with some of the funky typing that
|
||||
can crop up, e.g the cache descriptors.
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from mypy.plugin import MethodSigContext, Plugin
|
||||
from mypy.typeops import bind_self
|
||||
from mypy.types import CallableType
|
||||
|
||||
|
||||
class SynapsePlugin(Plugin):
|
||||
def get_method_signature_hook(
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
||||
if fullname.startswith(
|
||||
"synapse.util.caches.descriptors._CachedFunction.__call__"
|
||||
):
|
||||
return cached_function_method_signature
|
||||
return None
|
||||
|
||||
|
||||
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
||||
"""Fixes the `_CachedFunction.__call__` signature to be correct.
|
||||
|
||||
It already has *almost* the correct signature, except:
|
||||
|
||||
1. the `self` argument needs to be marked as "bound"; and
|
||||
2. any `cache_context` argument should be removed.
|
||||
"""
|
||||
|
||||
# First we mark this as a bound function signature.
|
||||
signature = bind_self(ctx.default_signature)
|
||||
|
||||
# Secondly, we remove any "cache_context" args.
|
||||
#
|
||||
# Note: We should be only doing this if `cache_context=True` is set, but if
|
||||
# it isn't then the code will raise an exception when its called anyway, so
|
||||
# its not the end of the world.
|
||||
context_arg_index = None
|
||||
for idx, name in enumerate(signature.arg_names):
|
||||
if name == "cache_context":
|
||||
context_arg_index = idx
|
||||
break
|
||||
|
||||
if context_arg_index:
|
||||
arg_types = list(signature.arg_types)
|
||||
arg_types.pop(context_arg_index)
|
||||
|
||||
arg_names = list(signature.arg_names)
|
||||
arg_names.pop(context_arg_index)
|
||||
|
||||
arg_kinds = list(signature.arg_kinds)
|
||||
arg_kinds.pop(context_arg_index)
|
||||
|
||||
signature = signature.copy_modified(
|
||||
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
|
||||
)
|
||||
|
||||
return signature
|
||||
|
||||
|
||||
def plugin(version: str):
|
||||
# This is the entry point of the plugin, and let's us deal with the fact
|
||||
# that the mypy plugin interface is *not* stable by looking at the version
|
||||
# string.
|
||||
#
|
||||
# However, since we pin the version of mypy Synapse uses in CI, we don't
|
||||
# really care.
|
||||
return SynapsePlugin
|
|
@ -832,26 +832,11 @@ class ShardedWorkerHandlingConfig:
|
|||
def should_handle(self, instance_name: str, key: str) -> bool:
|
||||
"""Whether this instance is responsible for handling the given key.
|
||||
"""
|
||||
# If multiple instances are not defined we always return true
|
||||
|
||||
# If multiple instances are not defined we always return true.
|
||||
if not self.instances or len(self.instances) == 1:
|
||||
return True
|
||||
|
||||
return self.get_instance(key) == instance_name
|
||||
|
||||
def get_instance(self, key: str) -> str:
|
||||
"""Get the instance responsible for handling the given key.
|
||||
|
||||
Note: For things like federation sending the config for which instance
|
||||
is sending is known only to the sender instance if there is only one.
|
||||
Therefore `should_handle` should be used where possible.
|
||||
"""
|
||||
|
||||
if not self.instances:
|
||||
return "master"
|
||||
|
||||
if len(self.instances) == 1:
|
||||
return self.instances[0]
|
||||
|
||||
# We shard by taking the hash, modulo it by the number of instances and
|
||||
# then checking whether this instance matches the instance at that
|
||||
# index.
|
||||
|
@ -861,7 +846,7 @@ class ShardedWorkerHandlingConfig:
|
|||
dest_hash = sha256(key.encode("utf8")).digest()
|
||||
dest_int = int.from_bytes(dest_hash, byteorder="little")
|
||||
remainder = dest_int % (len(self.instances))
|
||||
return self.instances[remainder]
|
||||
return self.instances[remainder] == instance_name
|
||||
|
||||
|
||||
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
|
||||
|
|
|
@ -142,4 +142,3 @@ class ShardedWorkerHandlingConfig:
|
|||
instances: List[str]
|
||||
def __init__(self, instances: List[str]) -> None: ...
|
||||
def should_handle(self, instance_name: str, key: str) -> bool: ...
|
||||
def get_instance(self, key: str) -> str: ...
|
||||
|
|
|
@ -13,24 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
|
||||
from .server import ListenerConfig, parse_listener_def
|
||||
|
||||
|
||||
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
|
||||
"""Helper for allowing parsing a string or list of strings to a config
|
||||
option expecting a list of strings.
|
||||
"""
|
||||
|
||||
if isinstance(obj, str):
|
||||
return [obj]
|
||||
return obj
|
||||
|
||||
|
||||
@attr.s
|
||||
class InstanceLocationConfig:
|
||||
"""The host and port to talk to an instance via HTTP replication.
|
||||
|
@ -45,13 +33,11 @@ class WriterLocations:
|
|||
"""Specifies the instances that write various streams.
|
||||
|
||||
Attributes:
|
||||
events: The instances that write to the event and backfill streams.
|
||||
typing: The instance that writes to the typing stream.
|
||||
events: The instance that writes to the event and backfill streams.
|
||||
events: The instance that writes to the typing stream.
|
||||
"""
|
||||
|
||||
events = attr.ib(
|
||||
default=["master"], type=List[str], converter=_instance_to_list_converter
|
||||
)
|
||||
events = attr.ib(default="master", type=str)
|
||||
typing = attr.ib(default="master", type=str)
|
||||
|
||||
|
||||
|
@ -119,18 +105,15 @@ class WorkerConfig(Config):
|
|||
writers = config.get("stream_writers") or {}
|
||||
self.writers = WriterLocations(**writers)
|
||||
|
||||
# Check that the configured writers for events and typing also appears in
|
||||
# Check that the configured writer for events and typing also appears in
|
||||
# `instance_map`.
|
||||
for stream in ("events", "typing"):
|
||||
instances = _instance_to_list_converter(getattr(self.writers, stream))
|
||||
for instance in instances:
|
||||
if instance != "master" and instance not in self.instance_map:
|
||||
raise ConfigError(
|
||||
"Instance %r is configured to write %s but does not appear in `instance_map` config."
|
||||
% (instance, stream)
|
||||
)
|
||||
|
||||
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
|
||||
instance = getattr(self.writers, stream)
|
||||
if instance != "master" and instance not in self.instance_map:
|
||||
raise ConfigError(
|
||||
"Instance %r is configured to write %s but does not appear in `instance_map` config."
|
||||
% (instance, stream)
|
||||
)
|
||||
|
||||
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||
return """\
|
||||
|
|
|
@ -108,8 +108,6 @@ class FederationSender(object):
|
|||
),
|
||||
)
|
||||
|
||||
self._order = 1
|
||||
|
||||
self._is_processing = False
|
||||
self._last_poked_id = -1
|
||||
|
||||
|
@ -290,9 +288,6 @@ class FederationSender(object):
|
|||
# a transaction in progress. If we do, stick it in the pending_pdus
|
||||
# table and we'll get back to it later.
|
||||
|
||||
order = self._order
|
||||
self._order += 1
|
||||
|
||||
destinations = set(destinations)
|
||||
destinations.discard(self.server_name)
|
||||
logger.debug("Sending to: %s", str(destinations))
|
||||
|
@ -304,7 +299,7 @@ class FederationSender(object):
|
|||
sent_pdus_destination_dist_count.inc()
|
||||
|
||||
for destination in destinations:
|
||||
self._get_per_destination_queue(destination).send_pdu(pdu, order)
|
||||
self._get_per_destination_queue(destination).send_pdu(pdu)
|
||||
|
||||
async def send_read_receipt(self, receipt: ReadReceipt) -> None:
|
||||
"""Send a RR to any other servers in the room
|
||||
|
|
|
@ -95,8 +95,8 @@ class PerDestinationQueue(object):
|
|||
self._destination = destination
|
||||
self.transmission_loop_running = False
|
||||
|
||||
# a list of tuples of (pending pdu, order)
|
||||
self._pending_pdus = [] # type: List[Tuple[EventBase, int]]
|
||||
# a list of pending PDUs
|
||||
self._pending_pdus = [] # type: List[EventBase]
|
||||
|
||||
# XXX this is never actually used: see
|
||||
# https://github.com/matrix-org/synapse/issues/7549
|
||||
|
@ -135,14 +135,13 @@ class PerDestinationQueue(object):
|
|||
+ len(self._pending_edus_keyed)
|
||||
)
|
||||
|
||||
def send_pdu(self, pdu: EventBase, order: int) -> None:
|
||||
def send_pdu(self, pdu: EventBase) -> None:
|
||||
"""Add a PDU to the queue, and start the transmission loop if necessary
|
||||
|
||||
Args:
|
||||
pdu: pdu to send
|
||||
order
|
||||
"""
|
||||
self._pending_pdus.append((pdu, order))
|
||||
self._pending_pdus.append(pdu)
|
||||
self.attempt_new_transaction()
|
||||
|
||||
def send_presence(self, states: Iterable[UserPresenceState]) -> None:
|
||||
|
@ -188,7 +187,7 @@ class PerDestinationQueue(object):
|
|||
returns immediately. Otherwise kicks off the process of sending a
|
||||
transaction in the background.
|
||||
"""
|
||||
# list of (pending_pdu, deferred, order)
|
||||
|
||||
if self.transmission_loop_running:
|
||||
# XXX: this can get stuck on by a never-ending
|
||||
# request at which point pending_pdus just keeps growing.
|
||||
|
@ -213,7 +212,7 @@ class PerDestinationQueue(object):
|
|||
)
|
||||
|
||||
async def _transaction_transmission_loop(self) -> None:
|
||||
pending_pdus = [] # type: List[Tuple[EventBase, int]]
|
||||
pending_pdus = [] # type: List[EventBase]
|
||||
try:
|
||||
self.transmission_loop_running = True
|
||||
|
||||
|
@ -388,13 +387,13 @@ class PerDestinationQueue(object):
|
|||
"TX [%s] Failed to send transaction: %s", self._destination, e
|
||||
)
|
||||
|
||||
for p, _ in pending_pdus:
|
||||
for p in pending_pdus:
|
||||
logger.info(
|
||||
"Failed to send event %s to %s", p.event_id, self._destination
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("TX [%s] Failed to send transaction", self._destination)
|
||||
for p, _ in pending_pdus:
|
||||
for p in pending_pdus:
|
||||
logger.info(
|
||||
"Failed to send event %s to %s", p.event_id, self._destination
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.events import EventBase
|
||||
|
@ -57,11 +57,17 @@ class TransactionManager(object):
|
|||
|
||||
@measure_func("_send_new_transaction")
|
||||
async def send_new_transaction(
|
||||
self,
|
||||
destination: str,
|
||||
pending_pdus: List[Tuple[EventBase, int]],
|
||||
pending_edus: List[Edu],
|
||||
):
|
||||
self, destination: str, pdus: List[EventBase], edus: List[Edu],
|
||||
) -> bool:
|
||||
"""
|
||||
Args:
|
||||
destination: The destination to send to (e.g. 'example.org')
|
||||
pdus: In-order list of PDUs to send
|
||||
edus: List of EDUs to send
|
||||
|
||||
Returns:
|
||||
True iff the transaction was successful
|
||||
"""
|
||||
|
||||
# Make a transaction-sending opentracing span. This span follows on from
|
||||
# all the edus in that transaction. This needs to be done since there is
|
||||
|
@ -71,7 +77,7 @@ class TransactionManager(object):
|
|||
span_contexts = []
|
||||
keep_destination = whitelisted_homeserver(destination)
|
||||
|
||||
for edu in pending_edus:
|
||||
for edu in edus:
|
||||
context = edu.get_context()
|
||||
if context:
|
||||
span_contexts.append(extract_text_map(json_decoder.decode(context)))
|
||||
|
@ -79,12 +85,6 @@ class TransactionManager(object):
|
|||
edu.strip_context()
|
||||
|
||||
with start_active_span_follows_from("send_transaction", span_contexts):
|
||||
|
||||
# Sort based on the order field
|
||||
pending_pdus.sort(key=lambda t: t[1])
|
||||
pdus = [x[0] for x in pending_pdus]
|
||||
edus = pending_edus
|
||||
|
||||
success = True
|
||||
|
||||
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
||||
|
|
|
@ -15,29 +15,30 @@
|
|||
|
||||
import logging
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.logging.utils import log_function
|
||||
from synapse.types import UserID
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventStreamHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super(EventStreamHandler, self).__init__(hs)
|
||||
|
||||
# Count of active streams per user
|
||||
self._streams_per_user = {}
|
||||
# Grace timers per user to delay the "stopped" signal
|
||||
self._stop_timer_per_user = {}
|
||||
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("started_user_eventstream")
|
||||
self.distributor.declare("stopped_user_eventstream")
|
||||
|
@ -52,14 +53,14 @@ class EventStreamHandler(BaseHandler):
|
|||
@log_function
|
||||
async def get_stream(
|
||||
self,
|
||||
auth_user_id,
|
||||
pagin_config,
|
||||
timeout=0,
|
||||
as_client_event=True,
|
||||
affect_presence=True,
|
||||
room_id=None,
|
||||
is_guest=False,
|
||||
):
|
||||
auth_user_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
timeout: int = 0,
|
||||
as_client_event: bool = True,
|
||||
affect_presence: bool = True,
|
||||
room_id: Optional[str] = None,
|
||||
is_guest: bool = False,
|
||||
) -> JsonDict:
|
||||
"""Fetches the events stream for a given user.
|
||||
"""
|
||||
|
||||
|
@ -98,7 +99,7 @@ class EventStreamHandler(BaseHandler):
|
|||
|
||||
# When the user joins a new room, or another user joins a currently
|
||||
# joined room, we need to send down presence for those users.
|
||||
to_add = []
|
||||
to_add = [] # type: List[JsonDict]
|
||||
for event in events:
|
||||
if not isinstance(event, EventBase):
|
||||
continue
|
||||
|
@ -110,7 +111,7 @@ class EventStreamHandler(BaseHandler):
|
|||
# Send down presence for everyone in the room.
|
||||
users = await self.state.get_current_users_in_room(
|
||||
event.room_id
|
||||
)
|
||||
) # type: Iterable[str]
|
||||
else:
|
||||
users = [event.state_key]
|
||||
|
||||
|
@ -144,20 +145,22 @@ class EventStreamHandler(BaseHandler):
|
|||
|
||||
|
||||
class EventHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super(EventHandler, self).__init__(hs)
|
||||
self.storage = hs.get_storage()
|
||||
|
||||
async def get_event(self, user, room_id, event_id):
|
||||
async def get_event(
|
||||
self, user: UserID, room_id: Optional[str], event_id: str
|
||||
) -> Optional[EventBase]:
|
||||
"""Retrieve a single specified event.
|
||||
|
||||
Args:
|
||||
user (synapse.types.UserID): The user requesting the event
|
||||
room_id (str|None): The expected room id. We'll return None if the
|
||||
user: The user requesting the event
|
||||
room_id: The expected room id. We'll return None if the
|
||||
event's room does not match.
|
||||
event_id (str): The event ID to obtain.
|
||||
event_id: The event ID to obtain.
|
||||
Returns:
|
||||
dict: An event, or None if there is no event matching this ID.
|
||||
An event, or None if there is no event matching this ID.
|
||||
Raises:
|
||||
SynapseError if there was a problem retrieving this event, or
|
||||
AuthError if the user does not have the rights to inspect this
|
||||
|
|
|
@ -440,11 +440,11 @@ class FederationHandler(BaseHandler):
|
|||
if not prevs - seen:
|
||||
return
|
||||
|
||||
latest = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
latest_list = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
latest = set(latest)
|
||||
latest = set(latest_list)
|
||||
latest |= seen
|
||||
|
||||
logger.info(
|
||||
|
@ -781,7 +781,7 @@ class FederationHandler(BaseHandler):
|
|||
# keys across all devices.
|
||||
current_keys = [
|
||||
key
|
||||
for device in cached_devices
|
||||
for device in cached_devices.values()
|
||||
for key in device.get("keys", {}).get("keys", {}).values()
|
||||
]
|
||||
|
||||
|
@ -923,8 +923,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
)
|
||||
|
||||
if ev_infos:
|
||||
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
|
||||
await self._handle_new_events(dest, ev_infos, backfilled=True)
|
||||
|
||||
# Step 2: Persist the rest of the events in the chunk one by one
|
||||
events.sort(key=lambda e: e.depth)
|
||||
|
@ -1217,7 +1216,7 @@ class FederationHandler(BaseHandler):
|
|||
event_infos.append(_NewEventInfo(event, None, auth))
|
||||
|
||||
await self._handle_new_events(
|
||||
destination, room_id, event_infos,
|
||||
destination, event_infos,
|
||||
)
|
||||
|
||||
def _sanity_check_event(self, ev):
|
||||
|
@ -1364,15 +1363,15 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
max_stream_id = await self._persist_auth_tree(
|
||||
origin, room_id, auth_chain, state, event, room_version_obj
|
||||
origin, auth_chain, state, event, room_version_obj
|
||||
)
|
||||
|
||||
# We wait here until this instance has seen the events come down
|
||||
# replication (if we're using replication) as the below uses caches.
|
||||
#
|
||||
# TODO: Currently the events stream is written to from master
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.config.worker.events_shard_config.get_instance(room_id),
|
||||
"events",
|
||||
max_stream_id,
|
||||
self.config.worker.writers.events, "events", max_stream_id
|
||||
)
|
||||
|
||||
# Check whether this room is the result of an upgrade of a room we already know
|
||||
|
@ -1626,7 +1625,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
context = await self.state_handler.compute_event_context(event)
|
||||
await self.persist_events_and_notify(event.room_id, [(event, context)])
|
||||
await self.persist_events_and_notify([(event, context)])
|
||||
|
||||
return event
|
||||
|
||||
|
@ -1653,9 +1652,7 @@ class FederationHandler(BaseHandler):
|
|||
await self.federation_client.send_leave(host_list, event)
|
||||
|
||||
context = await self.state_handler.compute_event_context(event)
|
||||
stream_id = await self.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
stream_id = await self.persist_events_and_notify([(event, context)])
|
||||
|
||||
return event, stream_id
|
||||
|
||||
|
@ -1903,7 +1900,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
await self.persist_events_and_notify(
|
||||
event.room_id, [(event, context)], backfilled=backfilled
|
||||
[(event, context)], backfilled=backfilled
|
||||
)
|
||||
except Exception:
|
||||
run_in_background(
|
||||
|
@ -1916,7 +1913,6 @@ class FederationHandler(BaseHandler):
|
|||
async def _handle_new_events(
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
event_infos: Iterable[_NewEventInfo],
|
||||
backfilled: bool = False,
|
||||
) -> None:
|
||||
|
@ -1948,7 +1944,6 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
await self.persist_events_and_notify(
|
||||
room_id,
|
||||
[
|
||||
(ev_info.event, context)
|
||||
for ev_info, context in zip(event_infos, contexts)
|
||||
|
@ -1959,7 +1954,6 @@ class FederationHandler(BaseHandler):
|
|||
async def _persist_auth_tree(
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
auth_events: List[EventBase],
|
||||
state: List[EventBase],
|
||||
event: EventBase,
|
||||
|
@ -1974,7 +1968,6 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
Args:
|
||||
origin: Where the events came from
|
||||
room_id,
|
||||
auth_events
|
||||
state
|
||||
event
|
||||
|
@ -2049,20 +2042,17 @@ class FederationHandler(BaseHandler):
|
|||
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
await self.persist_events_and_notify(
|
||||
room_id,
|
||||
[
|
||||
(e, events_to_context[e.event_id])
|
||||
for e in itertools.chain(auth_events, state)
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
new_event_context = await self.state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
)
|
||||
|
||||
return await self.persist_events_and_notify(
|
||||
room_id, [(event, new_event_context)]
|
||||
)
|
||||
return await self.persist_events_and_notify([(event, new_event_context)])
|
||||
|
||||
async def _prep_event(
|
||||
self,
|
||||
|
@ -2119,8 +2109,8 @@ class FederationHandler(BaseHandler):
|
|||
if backfilled or event.internal_metadata.is_outlier():
|
||||
return
|
||||
|
||||
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
|
||||
extrem_ids = set(extrem_ids)
|
||||
extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
|
||||
extrem_ids = set(extrem_ids_list)
|
||||
prev_event_ids = set(event.prev_event_ids())
|
||||
|
||||
if extrem_ids == prev_event_ids:
|
||||
|
@ -2913,7 +2903,6 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
async def persist_events_and_notify(
|
||||
self,
|
||||
room_id: str,
|
||||
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
) -> int:
|
||||
|
@ -2921,19 +2910,14 @@ class FederationHandler(BaseHandler):
|
|||
necessary.
|
||||
|
||||
Args:
|
||||
room_id: The room ID of events being persisted.
|
||||
event_and_contexts: Sequence of events with their associated
|
||||
context that should be persisted. All events must belong to
|
||||
the same room.
|
||||
event_and_contexts:
|
||||
backfilled: Whether these events are a result of
|
||||
backfilling or not
|
||||
"""
|
||||
instance = self.config.worker.events_shard_config.get_instance(room_id)
|
||||
if instance != self._instance_name:
|
||||
if self.config.worker.writers.events != self._instance_name:
|
||||
result = await self._send_events(
|
||||
instance_name=instance,
|
||||
instance_name=self.config.worker.writers.events,
|
||||
store=self.store,
|
||||
room_id=room_id,
|
||||
event_and_contexts=event_and_contexts,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -22,8 +23,9 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.events.validator import EventValidator
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import StreamToken, UserID
|
||||
from synapse.types import JsonDict, Requester, StreamToken, UserID
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
@ -31,11 +33,15 @@ from synapse.visibility import filter_events_for_client
|
|||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InitialSyncHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super(InitialSyncHandler, self).__init__(hs)
|
||||
self.hs = hs
|
||||
self.state = hs.get_state_handler()
|
||||
|
@ -48,27 +54,25 @@ class InitialSyncHandler(BaseHandler):
|
|||
|
||||
def snapshot_all_rooms(
|
||||
self,
|
||||
user_id=None,
|
||||
pagin_config=None,
|
||||
as_client_event=True,
|
||||
include_archived=False,
|
||||
):
|
||||
user_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
as_client_event: bool = True,
|
||||
include_archived: bool = False,
|
||||
) -> JsonDict:
|
||||
"""Retrieve a snapshot of all rooms the user is invited or has joined.
|
||||
|
||||
This snapshot may include messages for all rooms where the user is
|
||||
joined, depending on the pagination config.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user making the request.
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config used to determine how many messages *PER ROOM* to return.
|
||||
as_client_event (bool): True to get events in client-server format.
|
||||
include_archived (bool): True to get rooms that the user has left
|
||||
user_id: The ID of the user making the request.
|
||||
pagin_config: The pagination config used to determine how many
|
||||
messages *PER ROOM* to return.
|
||||
as_client_event: True to get events in client-server format.
|
||||
include_archived: True to get rooms that the user has left
|
||||
Returns:
|
||||
A list of dicts with "room_id" and "membership" keys for all rooms
|
||||
the user is currently invited or joined in on. Rooms where the user
|
||||
is joined on, may return a "messages" key with messages, depending
|
||||
on the specified PaginationConfig.
|
||||
A JsonDict with the same format as the response to `/intialSync`
|
||||
API
|
||||
"""
|
||||
key = (
|
||||
user_id,
|
||||
|
@ -91,11 +95,11 @@ class InitialSyncHandler(BaseHandler):
|
|||
|
||||
async def _snapshot_all_rooms(
|
||||
self,
|
||||
user_id=None,
|
||||
pagin_config=None,
|
||||
as_client_event=True,
|
||||
include_archived=False,
|
||||
):
|
||||
user_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
as_client_event: bool = True,
|
||||
include_archived: bool = False,
|
||||
) -> JsonDict:
|
||||
|
||||
memberships = [Membership.INVITE, Membership.JOIN]
|
||||
if include_archived:
|
||||
|
@ -134,7 +138,7 @@ class InitialSyncHandler(BaseHandler):
|
|||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
async def handle_room(event):
|
||||
async def handle_room(event: RoomsForUser):
|
||||
d = {
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
|
@ -251,17 +255,18 @@ class InitialSyncHandler(BaseHandler):
|
|||
|
||||
return ret
|
||||
|
||||
async def room_initial_sync(self, requester, room_id, pagin_config=None):
|
||||
async def room_initial_sync(
|
||||
self, requester: Requester, room_id: str, pagin_config: PaginationConfig
|
||||
) -> JsonDict:
|
||||
"""Capture the a snapshot of a room. If user is currently a member of
|
||||
the room this will be what is currently in the room. If the user left
|
||||
the room this will be what was in the room when they left.
|
||||
|
||||
Args:
|
||||
requester(Requester): The user to get a snapshot for.
|
||||
room_id(str): The room to get a snapshot of.
|
||||
pagin_config(synapse.streams.config.PaginationConfig):
|
||||
The pagination config used to determine how many messages to
|
||||
return.
|
||||
requester: The user to get a snapshot for.
|
||||
room_id: The room to get a snapshot of.
|
||||
pagin_config: The pagination config used to determine how many
|
||||
messages to return.
|
||||
Raises:
|
||||
AuthError if the user wasn't in the room.
|
||||
Returns:
|
||||
|
@ -305,8 +310,14 @@ class InitialSyncHandler(BaseHandler):
|
|||
return result
|
||||
|
||||
async def _room_initial_sync_parted(
|
||||
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
|
||||
):
|
||||
self,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
membership: Membership,
|
||||
member_event_id: str,
|
||||
is_peeking: bool,
|
||||
) -> JsonDict:
|
||||
room_state = await self.state_store.get_state_for_events([member_event_id])
|
||||
|
||||
room_state = room_state[member_event_id]
|
||||
|
@ -350,8 +361,13 @@ class InitialSyncHandler(BaseHandler):
|
|||
}
|
||||
|
||||
async def _room_initial_sync_joined(
|
||||
self, user_id, room_id, pagin_config, membership, is_peeking
|
||||
):
|
||||
self,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
membership: Membership,
|
||||
is_peeking: bool,
|
||||
) -> JsonDict:
|
||||
current_state = await self.state.get_current_state(room_id=room_id)
|
||||
|
||||
# TODO: These concurrently
|
||||
|
|
|
@ -376,8 +376,9 @@ class EventCreationHandler(object):
|
|||
self.notifier = hs.get_notifier()
|
||||
self.config = hs.config
|
||||
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
|
||||
self._events_shard_config = self.config.worker.events_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self._is_event_writer = (
|
||||
self.config.worker.writers.events == hs.get_instance_name()
|
||||
)
|
||||
|
||||
self.room_invite_state_types = self.hs.config.room_invite_state_types
|
||||
|
||||
|
@ -905,10 +906,9 @@ class EventCreationHandler(object):
|
|||
|
||||
try:
|
||||
# If we're a worker we need to hit out to the master.
|
||||
writer_instance = self._events_shard_config.get_instance(event.room_id)
|
||||
if writer_instance != self._instance_name:
|
||||
if not self._is_event_writer:
|
||||
result = await self.send_event(
|
||||
instance_name=writer_instance,
|
||||
instance_name=self.config.worker.writers.events,
|
||||
event_id=event.event_id,
|
||||
store=self.store,
|
||||
requester=requester,
|
||||
|
@ -976,9 +976,7 @@ class EventCreationHandler(object):
|
|||
|
||||
This should only be run on the instance in charge of persisting events.
|
||||
"""
|
||||
assert self._events_shard_config.should_handle(
|
||||
self._instance_name, event.room_id
|
||||
)
|
||||
assert self._is_event_writer
|
||||
|
||||
if ratelimit:
|
||||
# We check if this is a room admin redacting an event so that we
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
|
@ -30,6 +30,10 @@ from synapse.util.async_helpers import ReadWriteLock
|
|||
from synapse.util.stringutils import random_string
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -68,7 +72,7 @@ class PaginationHandler(object):
|
|||
paginating during a purge.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -78,9 +82,9 @@ class PaginationHandler(object):
|
|||
self._server_name = hs.hostname
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
self._purges_in_progress_by_room = set()
|
||||
self._purges_in_progress_by_room = set() # type: Set[str]
|
||||
# map from purge id to PurgeStatus
|
||||
self._purges_by_id = {}
|
||||
self._purges_by_id = {} # type: Dict[str, PurgeStatus]
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
|
||||
|
@ -102,7 +106,9 @@ class PaginationHandler(object):
|
|||
job["longest_max_lifetime"],
|
||||
)
|
||||
|
||||
async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
|
||||
async def purge_history_for_rooms_in_range(
|
||||
self, min_ms: Optional[int], max_ms: Optional[int]
|
||||
):
|
||||
"""Purge outdated events from rooms within the given retention range.
|
||||
|
||||
If a default retention policy is defined in the server's configuration and its
|
||||
|
@ -110,10 +116,10 @@ class PaginationHandler(object):
|
|||
retention policy.
|
||||
|
||||
Args:
|
||||
min_ms (int|None): Duration in milliseconds that define the lower limit of
|
||||
min_ms: Duration in milliseconds that define the lower limit of
|
||||
the range to handle (exclusive). If None, it means that the range has no
|
||||
lower limit.
|
||||
max_ms (int|None): Duration in milliseconds that define the upper limit of
|
||||
max_ms: Duration in milliseconds that define the upper limit of
|
||||
the range to handle (inclusive). If None, it means that the range has no
|
||||
upper limit.
|
||||
"""
|
||||
|
@ -220,18 +226,19 @@ class PaginationHandler(object):
|
|||
"_purge_history", self._purge_history, purge_id, room_id, token, True,
|
||||
)
|
||||
|
||||
def start_purge_history(self, room_id, token, delete_local_events=False):
|
||||
def start_purge_history(
|
||||
self, room_id: str, token: str, delete_local_events: bool = False
|
||||
) -> str:
|
||||
"""Start off a history purge on a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The room to purge from
|
||||
|
||||
token (str): topological token to delete events before
|
||||
delete_local_events (bool): True to delete local events as well as
|
||||
room_id: The room to purge from
|
||||
token: topological token to delete events before
|
||||
delete_local_events: True to delete local events as well as
|
||||
remote ones
|
||||
|
||||
Returns:
|
||||
str: unique ID for this purge transaction.
|
||||
unique ID for this purge transaction.
|
||||
"""
|
||||
if room_id in self._purges_in_progress_by_room:
|
||||
raise SynapseError(
|
||||
|
@ -284,14 +291,11 @@ class PaginationHandler(object):
|
|||
|
||||
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
|
||||
|
||||
def get_purge_status(self, purge_id):
|
||||
def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]:
|
||||
"""Get the current status of an active purge
|
||||
|
||||
Args:
|
||||
purge_id (str): purge_id returned by start_purge_history
|
||||
|
||||
Returns:
|
||||
PurgeStatus|None
|
||||
purge_id: purge_id returned by start_purge_history
|
||||
"""
|
||||
return self._purges_by_id.get(purge_id)
|
||||
|
||||
|
@ -312,8 +316,8 @@ class PaginationHandler(object):
|
|||
async def get_messages(
|
||||
self,
|
||||
requester: Requester,
|
||||
room_id: Optional[str] = None,
|
||||
pagin_config: Optional[PaginationConfig] = None,
|
||||
room_id: str,
|
||||
pagin_config: PaginationConfig,
|
||||
as_client_event: bool = True,
|
||||
event_filter: Optional[Filter] = None,
|
||||
) -> Dict[str, Any]:
|
||||
|
@ -368,11 +372,15 @@ class PaginationHandler(object):
|
|||
# If they have left the room then clamp the token to be before
|
||||
# they left the room, to save the effort of loading from the
|
||||
# database.
|
||||
|
||||
# This is only None if the room is world_readable, in which
|
||||
# case "JOIN" would have been returned.
|
||||
assert member_event_id
|
||||
|
||||
leave_token = await self.store.get_topological_token_for_event(
|
||||
member_event_id
|
||||
)
|
||||
leave_token = RoomStreamToken.parse(leave_token)
|
||||
if leave_token.topological < max_topo:
|
||||
if RoomStreamToken.parse(leave_token).topological < max_topo:
|
||||
source_config.from_key = str(leave_token)
|
||||
|
||||
await self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
|
@ -419,8 +427,8 @@ class PaginationHandler(object):
|
|||
)
|
||||
|
||||
if state_ids:
|
||||
state = await self.store.get_events(list(state_ids.values()))
|
||||
state = state.values()
|
||||
state_dict = await self.store.get_events(list(state_ids.values()))
|
||||
state = state_dict.values()
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
|
|
|
@ -804,9 +804,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
# Always wait for room creation to progate before returning
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.hs.config.worker.events_shard_config.get_instance(room_id),
|
||||
"events",
|
||||
last_stream_id,
|
||||
self.hs.config.worker.writers.events, "events", last_stream_id
|
||||
)
|
||||
|
||||
return result, last_stream_id
|
||||
|
@ -1262,10 +1260,10 @@ class RoomShutdownHandler(object):
|
|||
# We now wait for the create room to come back in via replication so
|
||||
# that we can assume that all the joins/invites have propogated before
|
||||
# we try and auto join below.
|
||||
#
|
||||
# TODO: Currently the events stream is written to from master
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.hs.config.worker.events_shard_config.get_instance(new_room_id),
|
||||
"events",
|
||||
stream_id,
|
||||
self.hs.config.worker.writers.events, "events", stream_id
|
||||
)
|
||||
else:
|
||||
new_room_id = None
|
||||
|
@ -1295,9 +1293,7 @@ class RoomShutdownHandler(object):
|
|||
|
||||
# Wait for leave to come in over replication before trying to forget.
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.hs.config.worker.events_shard_config.get_instance(room_id),
|
||||
"events",
|
||||
stream_id,
|
||||
self.hs.config.worker.writers.events, "events", stream_id
|
||||
)
|
||||
|
||||
await self.room_member_handler.forget(target_requester.user, room_id)
|
||||
|
|
|
@ -83,6 +83,13 @@ class RoomMemberHandler(object):
|
|||
self._enable_lookup = hs.config.enable_3pid_lookup
|
||||
self.allow_per_room_profiles = self.config.allow_per_room_profiles
|
||||
|
||||
self._event_stream_writer_instance = hs.config.worker.writers.events
|
||||
self._is_on_event_persistence_instance = (
|
||||
self._event_stream_writer_instance == hs.get_instance_name()
|
||||
)
|
||||
if self._is_on_event_persistence_instance:
|
||||
self.persist_event_storage = hs.get_storage().persistence
|
||||
|
||||
self._join_rate_limiter_local = Ratelimiter(
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
||||
|
|
|
@ -65,11 +65,10 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
self.federation_handler = hs.get_handlers().federation_handler
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
|
||||
async def _serialize_payload(store, event_and_contexts, backfilled):
|
||||
"""
|
||||
Args:
|
||||
store
|
||||
room_id (str)
|
||||
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
|
||||
backfilled (bool): Whether or not the events are the result of
|
||||
backfilling
|
||||
|
@ -89,11 +88,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
}
|
||||
)
|
||||
|
||||
payload = {
|
||||
"events": event_payloads,
|
||||
"backfilled": backfilled,
|
||||
"room_id": room_id,
|
||||
}
|
||||
payload = {"events": event_payloads, "backfilled": backfilled}
|
||||
|
||||
return payload
|
||||
|
||||
|
@ -101,7 +96,6 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
with Measure(self.clock, "repl_fed_send_events_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
room_id = content["room_id"]
|
||||
backfilled = content["backfilled"]
|
||||
|
||||
event_payloads = content["events"]
|
||||
|
@ -126,7 +120,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
logger.info("Got %d events from federation", len(event_and_contexts))
|
||||
|
||||
max_stream_id = await self.federation_handler.persist_events_and_notify(
|
||||
room_id, event_and_contexts, backfilled
|
||||
event_and_contexts, backfilled
|
||||
)
|
||||
|
||||
return 200, {"max_stream_id": max_stream_id}
|
||||
|
|
|
@ -109,7 +109,7 @@ class ReplicationCommandHandler:
|
|||
if isinstance(stream, (EventsStream, BackfillStream)):
|
||||
# Only add EventStream and BackfillStream as a source on the
|
||||
# instance in charge of event persistence.
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
if hs.config.worker.writers.events == hs.get_instance_name():
|
||||
self._streams_to_replicate.append(stream)
|
||||
|
||||
continue
|
||||
|
|
|
@ -19,7 +19,7 @@ from typing import List, Tuple, Type
|
|||
|
||||
import attr
|
||||
|
||||
from ._base import Stream, StreamUpdateResult, Token
|
||||
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
|
||||
|
@ -117,7 +117,7 @@ class EventsStream(Stream):
|
|||
self._store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
self._store._stream_id_gen.get_current_token_for_writer,
|
||||
current_token_without_instance(self._store.get_current_events_token),
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ class Databases(object):
|
|||
|
||||
# If we're on a process that can persist events also
|
||||
# instantiate a `PersistEventsStore`
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
if hs.config.worker.writers.events == hs.get_instance_name():
|
||||
persist_events = PersistEventsStore(hs, database, main)
|
||||
|
||||
if "state" in database_config.databases:
|
||||
|
|
|
@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
|
|||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import make_in_list_sql_clause
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
|
|||
# key) and "signatures" (a signature of the structure by the ed25519 key)
|
||||
key_json = attr.ib(type=Optional[str])
|
||||
|
||||
# cross-signing sigs
|
||||
signatures = attr.ib(type=Optional[Dict], default=None)
|
||||
# cross-signing sigs on this device.
|
||||
# dict from (signing user_id)->(signing device_id)->sig
|
||||
signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)
|
||||
|
||||
|
||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
|
@ -133,7 +135,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
include_all_devices: bool = False,
|
||||
include_deleted_devices: bool = False,
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
"""Fetch a list of device keys, together with their cross-signatures.
|
||||
"""Fetch a list of device keys
|
||||
|
||||
Any cross-signatures made on the keys by the owner of the device are also
|
||||
included.
|
||||
|
||||
Args:
|
||||
query_list: List of pairs of user_ids and device_ids. Device id can be None
|
||||
|
@ -154,22 +159,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
|
||||
result = await self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys",
|
||||
self._get_e2e_device_keys_and_signatures_txn,
|
||||
self._get_e2e_device_keys_txn,
|
||||
query_list,
|
||||
include_all_devices,
|
||||
include_deleted_devices,
|
||||
)
|
||||
|
||||
# get the (user_id, device_id) tuples to look up cross-signatures for
|
||||
signature_query = (
|
||||
(user_id, device_id)
|
||||
for user_id, dev in result.items()
|
||||
for device_id, d in dev.items()
|
||||
if d is not None
|
||||
)
|
||||
|
||||
for batch in batch_iter(signature_query, 50):
|
||||
cross_sigs_result = await self.db_pool.runInteraction(
|
||||
"get_e2e_cross_signing_signatures",
|
||||
self._get_e2e_cross_signing_signatures_for_devices_txn,
|
||||
batch,
|
||||
)
|
||||
|
||||
# add each cross-signing signature to the correct device in the result dict.
|
||||
for (user_id, key_id, device_id, signature) in cross_sigs_result:
|
||||
target_device_result = result[user_id][device_id]
|
||||
target_device_signatures = target_device_result.signatures
|
||||
|
||||
signing_user_signatures = target_device_signatures.setdefault(
|
||||
user_id, {}
|
||||
)
|
||||
signing_user_signatures[key_id] = signature
|
||||
|
||||
log_kv(result)
|
||||
return result
|
||||
|
||||
def _get_e2e_device_keys_and_signatures_txn(
|
||||
def _get_e2e_device_keys_txn(
|
||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
"""Get information on devices from the database
|
||||
|
||||
The results include the device's keys and self-signatures, but *not* any
|
||||
cross-signing signatures which have been added subsequently (for which, see
|
||||
get_e2e_device_keys_and_signatures)
|
||||
"""
|
||||
query_clauses = []
|
||||
query_params = []
|
||||
signature_query_clauses = []
|
||||
signature_query_params = []
|
||||
|
||||
if include_all_devices is False:
|
||||
include_deleted_devices = False
|
||||
|
@ -180,20 +214,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
for (user_id, device_id) in query_list:
|
||||
query_clause = "user_id = ?"
|
||||
query_params.append(user_id)
|
||||
signature_query_clause = "target_user_id = ?"
|
||||
signature_query_params.append(user_id)
|
||||
|
||||
if device_id is not None:
|
||||
query_clause += " AND device_id = ?"
|
||||
query_params.append(device_id)
|
||||
signature_query_clause += " AND target_device_id = ?"
|
||||
signature_query_params.append(device_id)
|
||||
|
||||
signature_query_clause += " AND user_id = ?"
|
||||
signature_query_params.append(user_id)
|
||||
|
||||
query_clauses.append(query_clause)
|
||||
signature_query_clauses.append(signature_query_clause)
|
||||
|
||||
sql = (
|
||||
"SELECT user_id, device_id, "
|
||||
|
@ -221,41 +247,36 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
for user_id, device_id in deleted_devices:
|
||||
result.setdefault(user_id, {})[device_id] = None
|
||||
|
||||
# get signatures on the device
|
||||
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
|
||||
return result
|
||||
|
||||
def _get_e2e_cross_signing_signatures_for_devices_txn(
|
||||
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
|
||||
) -> List[Tuple[str, str, str, str]]:
|
||||
"""Get cross-signing signatures for a given list of devices
|
||||
|
||||
Returns signatures made by the owners of the devices.
|
||||
|
||||
Returns: a list of results; each entry in the list is a tuple of
|
||||
(user_id, key_id, target_device_id, signature).
|
||||
"""
|
||||
signature_query_clauses = []
|
||||
signature_query_params = []
|
||||
|
||||
for (user_id, device_id) in device_query:
|
||||
signature_query_clauses.append(
|
||||
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
|
||||
)
|
||||
signature_query_params.extend([user_id, device_id, user_id])
|
||||
|
||||
signature_sql = """
|
||||
SELECT user_id, key_id, target_device_id, signature
|
||||
FROM e2e_cross_signing_signatures WHERE %s
|
||||
""" % (
|
||||
" OR ".join("(" + q + ")" for q in signature_query_clauses)
|
||||
)
|
||||
|
||||
txn.execute(signature_sql, signature_query_params)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
# add each cross-signing signature to the correct device in the result dict.
|
||||
for row in rows:
|
||||
signing_user_id = row["user_id"]
|
||||
signing_key_id = row["key_id"]
|
||||
target_user_id = row["target_user_id"]
|
||||
target_device_id = row["target_device_id"]
|
||||
signature = row["signature"]
|
||||
|
||||
target_user_result = result.get(target_user_id)
|
||||
if not target_user_result:
|
||||
continue
|
||||
|
||||
target_device_result = target_user_result.get(target_device_id)
|
||||
if not target_device_result:
|
||||
# note that target_device_result will be None for deleted devices.
|
||||
continue
|
||||
|
||||
target_device_signatures = target_device_result.signatures
|
||||
if target_device_signatures is None:
|
||||
target_device_signatures = target_device_result.signatures = {}
|
||||
|
||||
signing_user_signatures = target_device_signatures.setdefault(
|
||||
signing_user_id, {}
|
||||
)
|
||||
signing_user_signatures[signing_key_id] = signature
|
||||
|
||||
return result
|
||||
return txn.fetchall()
|
||||
|
||||
async def get_e2e_one_time_keys(
|
||||
self, user_id: str, device_id: str, key_ids: List[str]
|
||||
|
|
|
@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
"""
|
||||
|
||||
if stream_ordering <= self.stream_ordering_month_ago:
|
||||
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
|
||||
raise StoreError(400, "stream_ordering too old")
|
||||
|
||||
sql = """
|
||||
SELECT event_id FROM stream_ordering_to_exterm
|
||||
|
|
|
@ -97,7 +97,6 @@ class PersistEventsStore:
|
|||
self.store = main_data_store
|
||||
self.database_engine = db.engine
|
||||
self._clock = hs.get_clock()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
@ -109,7 +108,7 @@ class PersistEventsStore:
|
|||
|
||||
# This should only exist on instances that are configured to write
|
||||
assert (
|
||||
hs.get_instance_name() in hs.config.worker.writers.events
|
||||
hs.config.worker.writers.events == hs.get_instance_name()
|
||||
), "Can only instantiate EventsStore on master"
|
||||
|
||||
async def _persist_events_and_state_updates(
|
||||
|
@ -801,7 +800,6 @@ class PersistEventsStore:
|
|||
table="events",
|
||||
values=[
|
||||
{
|
||||
"instance_name": self._instance_name,
|
||||
"stream_ordering": event.internal_metadata.stream_ordering,
|
||||
"topological_ordering": event.depth,
|
||||
"depth": event.depth,
|
||||
|
|
|
@ -42,8 +42,7 @@ from synapse.replication.tcp.streams import BackfillStream
|
|||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
@ -79,54 +78,27 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
# If we're using Postgres than we can use `MultiWriterIdGenerator`
|
||||
# regardless of whether this process writes to the streams or not.
|
||||
self._stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_ordering",
|
||||
sequence_name="events_stream_seq",
|
||||
if hs.config.worker.writers.events == hs.get_instance_name():
|
||||
# We are the process in charge of generating stream ids for events,
|
||||
# so instantiate ID generators based on the database
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
)
|
||||
self._backfill_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_ordering",
|
||||
sequence_name="events_backfill_stream_seq",
|
||||
positive=False,
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
||||
)
|
||||
else:
|
||||
# We shouldn't be running in worker mode with SQLite, but its useful
|
||||
# to support it for unit tests.
|
||||
#
|
||||
# If this process is the writer than we need to use
|
||||
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
|
||||
# updated over replication. (Multiple writers are not supported for
|
||||
# SQLite).
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
||||
)
|
||||
else:
|
||||
self._stream_id_gen = SlavedIdTracker(
|
||||
db_conn, "events", "stream_ordering"
|
||||
)
|
||||
self._backfill_id_gen = SlavedIdTracker(
|
||||
db_conn, "events", "stream_ordering", step=-1
|
||||
)
|
||||
# Another process is in charge of persisting events and generating
|
||||
# stream IDs: rely on the replication streams to let us know which
|
||||
# IDs we can process.
|
||||
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
|
||||
self._backfill_id_gen = SlavedIdTracker(
|
||||
db_conn, "events", "stream_ordering", step=-1
|
||||
)
|
||||
|
||||
self._get_event_cache = Cache(
|
||||
"*getEvent*",
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE events ADD COLUMN instance_name TEXT;
|
|
@ -1,26 +0,0 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
|
||||
|
||||
SELECT setval('events_stream_seq', (
|
||||
SELECT COALESCE(MAX(stream_ordering), 1) FROM events
|
||||
));
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
|
||||
|
||||
SELECT setval('events_backfill_stream_seq', (
|
||||
SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
|
||||
));
|
|
@ -231,12 +231,8 @@ class MultiWriterIdGenerator:
|
|||
# gaps should be relatively rare it's still worth doing the book keeping
|
||||
# that allows us to skip forwards when there are gapless runs of
|
||||
# positions.
|
||||
#
|
||||
# We start at 1 here as a) the first generated stream ID will be 2, and
|
||||
# b) other parts of the code assume that stream IDs are strictly greater
|
||||
# than 0.
|
||||
self._persisted_upto_position = (
|
||||
min(self._current_positions.values()) if self._current_positions else 1
|
||||
min(self._current_positions.values()) if self._current_positions else 0
|
||||
)
|
||||
self._known_persisted_positions = [] # type: List[int]
|
||||
|
||||
|
@ -366,7 +362,9 @@ class MultiWriterIdGenerator:
|
|||
equal to it have been successfully persisted.
|
||||
"""
|
||||
|
||||
return self.get_persisted_upto_position()
|
||||
# Currently we don't support this operation, as it's not obvious how to
|
||||
# condense the stream positions of multiple writers into a single int.
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer.
|
||||
|
|
|
@ -18,11 +18,10 @@ import functools
|
|||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Tuple, Union, cast
|
||||
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from prometheus_client import Gauge
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -38,8 +37,10 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
CacheKey = Union[Tuple, Any]
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
class _CachedFunction(Protocol):
|
||||
|
||||
class _CachedFunction(Generic[F]):
|
||||
invalidate = None # type: Any
|
||||
invalidate_all = None # type: Any
|
||||
invalidate_many = None # type: Any
|
||||
|
@ -47,8 +48,11 @@ class _CachedFunction(Protocol):
|
|||
cache = None # type: Any
|
||||
num_args = None # type: Any
|
||||
|
||||
def __name__(self):
|
||||
...
|
||||
__name__ = None # type: str
|
||||
|
||||
# Note: This function signature is actually fiddled with by the synapse mypy
|
||||
# plugin to a) make it a bound method, and b) remove any `cache_context` arg.
|
||||
__call__ = None # type: F
|
||||
|
||||
|
||||
cache_pending_metric = Gauge(
|
||||
|
@ -123,7 +127,7 @@ class Cache(object):
|
|||
|
||||
self.name = name
|
||||
self.keylen = keylen
|
||||
self.thread = None
|
||||
self.thread = None # type: Optional[threading.Thread]
|
||||
self.metrics = register_cache(
|
||||
"cache",
|
||||
name,
|
||||
|
@ -662,9 +666,13 @@ class _CacheContext:
|
|||
|
||||
|
||||
def cached(
|
||||
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
|
||||
):
|
||||
return lambda orig: CacheDescriptor(
|
||||
max_entries: int = 1000,
|
||||
num_args: Optional[int] = None,
|
||||
tree: bool = False,
|
||||
cache_context: bool = False,
|
||||
iterable: bool = False,
|
||||
) -> Callable[[F], _CachedFunction[F]]:
|
||||
func = lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
|
@ -673,8 +681,12 @@ def cached(
|
|||
iterable=iterable,
|
||||
)
|
||||
|
||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||
|
||||
def cachedList(cached_method_name, list_name, num_args=None):
|
||||
|
||||
def cachedList(
|
||||
cached_method_name: str, list_name: str, num_args: Optional[int] = None
|
||||
) -> Callable[[F], _CachedFunction[F]]:
|
||||
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
||||
|
||||
Used to do batch lookups for an already created cache. A single argument
|
||||
|
@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
|
|||
cache.
|
||||
|
||||
Args:
|
||||
cached_method_name (str): The name of the single-item lookup method.
|
||||
cached_method_name: The name of the single-item lookup method.
|
||||
This is only used to find the cache to use.
|
||||
list_name (str): The name of the argument that is the list to use to
|
||||
list_name: The name of the argument that is the list to use to
|
||||
do batch lookups in the cache.
|
||||
num_args (int): Number of arguments to use as the key in the cache
|
||||
num_args: Number of arguments to use as the key in the cache
|
||||
(including list_name). Defaults to all named parameters.
|
||||
|
||||
Example:
|
||||
|
@ -702,9 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
|
|||
def batch_do_something(self, first_arg, second_args):
|
||||
...
|
||||
"""
|
||||
return lambda orig: CacheListDescriptor(
|
||||
func = lambda orig: CacheListDescriptor(
|
||||
orig,
|
||||
cached_method_name=cached_method_name,
|
||||
list_name=list_name,
|
||||
num_args=num_args,
|
||||
)
|
||||
|
||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||
|
|
Loading…
Reference in a new issue