0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 06:13:52 +01:00

Add option to move event persistence off master (#7517)

This commit is contained in:
Erik Johnston 2020-05-22 16:11:35 +01:00 committed by GitHub
parent 4429764c9f
commit e5c67d04db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 382 additions and 73 deletions

1
changelog.d/7517.feature Normal file
View file

@ -0,0 +1 @@
Add option to move event persistence off master.

View file

@ -196,6 +196,9 @@ class MockHomeserver:
def get_reactor(self): def get_reactor(self):
return reactor return reactor
def get_instance_name(self):
return "master"
class Porter(object): class Porter(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):

View file

@ -39,7 +39,11 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging from synapse.config.logger import setup_logging
from synapse.federation import send_queue from synapse.federation import send_queue
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.presence import BasePresenceHandler, get_interested_parties from synapse.handlers.presence import (
BasePresenceHandler,
PresenceState,
get_interested_parties,
)
from synapse.http.server import JsonResource, OptionsResource from synapse.http.server import JsonResource, OptionsResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseSite from synapse.http.site import SynapseSite
@ -47,6 +51,10 @@ from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.http.presence import (
ReplicationBumpPresenceActiveTime,
ReplicationPresenceSetState,
)
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@ -247,6 +255,9 @@ class GenericWorkerPresence(BasePresenceHandler):
# but we haven't notified the master of that yet # but we haven't notified the master of that yet
self.users_going_offline = {} self.users_going_offline = {}
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
self._set_state_client = ReplicationPresenceSetState.make_client(hs)
self._send_stop_syncing_loop = self.clock.looping_call( self._send_stop_syncing_loop = self.clock.looping_call(
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
) )
@ -304,10 +315,6 @@ class GenericWorkerPresence(BasePresenceHandler):
self.users_going_offline.pop(user_id, None) self.users_going_offline.pop(user_id, None)
self.send_user_sync(user_id, False, last_sync_ms) self.send_user_sync(user_id, False, last_sync_ms)
def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work?
return defer.succeed(None)
async def user_syncing( async def user_syncing(
self, user_id: str, affect_presence: bool self, user_id: str, affect_presence: bool
) -> ContextManager[None]: ) -> ContextManager[None]:
@ -386,6 +393,42 @@ class GenericWorkerPresence(BasePresenceHandler):
if count > 0 if count > 0
] ]
async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
presence = state["presence"]
valid_presence = (
PresenceState.ONLINE,
PresenceState.UNAVAILABLE,
PresenceState.OFFLINE,
)
if presence not in valid_presence:
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
# If presence is disabled, no-op
if not self.hs.config.use_presence:
return
# Proxy request to master
await self._set_state_client(
user_id=user_id, state=state, ignore_status_msg=ignore_status_msg
)
async def bump_presence_active_time(self, user):
"""We've seen the user do something that indicates they're interacting
with the app.
"""
# If presence is disabled, no-op
if not self.hs.config.use_presence:
return
# Proxy request to master
user_id = user.to_string()
await self._bump_active_client(user_id=user_id)
class GenericWorkerTyping(object): class GenericWorkerTyping(object):
def __init__(self, hs): def __init__(self, hs):

View file

@ -257,5 +257,6 @@ def setup_logging(
logging.warning("***** STARTING SERVER *****") logging.warning("***** STARTING SERVER *****")
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name) logging.info("Server hostname: %s", config.server_name)
logging.info("Instance name: %s", hs.get_instance_name())
return logger return logger

View file

@ -15,7 +15,7 @@
import attr import attr
from ._base import Config from ._base import Config, ConfigError
@attr.s @attr.s
@ -27,6 +27,17 @@ class InstanceLocationConfig:
port = attr.ib(type=int) port = attr.ib(type=int)
@attr.s
class WriterLocations:
"""Specifies the instances that write various streams.
Attributes:
events: The instance that writes to the event and backfill streams.
"""
events = attr.ib(default="master", type=str)
class WorkerConfig(Config): class WorkerConfig(Config):
"""The workers are processes run separately to the main synapse process. """The workers are processes run separately to the main synapse process.
They have their own pid_file and listener configuration. They use the They have their own pid_file and listener configuration. They use the
@ -83,11 +94,26 @@ class WorkerConfig(Config):
bind_addresses.append("") bind_addresses.append("")
# A map from instance name to host/port of their HTTP replication endpoint. # A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map", {}) or {} instance_map = config.get("instance_map") or {}
self.instance_map = { self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items() name: InstanceLocationConfig(**c) for name, c in instance_map.items()
} }
# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
# Check that the configured writer for events also appears in
# `instance_map`.
if (
self.writers.events != "master"
and self.writers.events not in self.instance_map
):
raise ConfigError(
"Instance %r is configured to write events but does not appear in `instance_map` config."
% (self.writers.events,)
)
def read_arguments(self, args): def read_arguments(self, args):
# We support a bunch of command line arguments that override options in # We support a bunch of command line arguments that override options in
# the config. A lot of these options have a worker_* prefix when running # the config. A lot of these options have a worker_* prefix when running

View file

@ -126,11 +126,10 @@ class FederationHandler(BaseHandler):
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config self.config = hs.config
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler() self._replication = hs.get_replication_data_handler()
self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
hs
)
self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
hs hs
) )
@ -1243,6 +1242,10 @@ class FederationHandler(BaseHandler):
content: The event content to use for the join event. content: The event content to use for the join event.
""" """
# TODO: We should be able to call this on workers, but the upgrading of
# room stuff after join currently doesn't work on workers.
assert self.config.worker.worker_app is None
logger.debug("Joining %s to %s", joinee, room_id) logger.debug("Joining %s to %s", joinee, room_id)
origin, event, room_version_obj = await self._make_and_verify_event( origin, event, room_version_obj = await self._make_and_verify_event(
@ -1314,7 +1317,7 @@ class FederationHandler(BaseHandler):
# #
# TODO: Currently the events stream is written to from master # TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position( await self._replication.wait_for_stream_position(
"master", "events", max_stream_id self.config.worker.writers.events, "events", max_stream_id
) )
# Check whether this room is the result of an upgrade of a room we already know # Check whether this room is the result of an upgrade of a room we already know
@ -2854,8 +2857,9 @@ class FederationHandler(BaseHandler):
backfilled: Whether these events are a result of backfilled: Whether these events are a result of
backfilling or not backfilling or not
""" """
if self.config.worker_app: if self.config.worker.writers.events != self._instance_name:
result = await self._send_events_to_master( result = await self._send_events(
instance_name=self.config.worker.writers.events,
store=self.store, store=self.store,
event_and_contexts=event_and_contexts, event_and_contexts=event_and_contexts,
backfilled=backfilled, backfilled=backfilled,

View file

@ -366,10 +366,11 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.config = hs.config self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.room_invite_state_types self.room_invite_state_types = self.hs.config.room_invite_state_types
self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) self.send_event = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function, and maybe_kick_guest_users # This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs) self.base_handler = BaseHandler(hs)
@ -835,8 +836,9 @@ class EventCreationHandler(object):
success = False success = False
try: try:
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.
if self.config.worker_app: if self.config.worker.writers.events != self._instance_name:
result = await self.send_event_to_master( result = await self.send_event(
instance_name=self.config.worker.writers.events,
event_id=event.event_id, event_id=event.event_id,
store=self.store, store=self.store,
requester=requester, requester=requester,
@ -902,9 +904,9 @@ class EventCreationHandler(object):
"""Called when we have fully built the event, have already """Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth. calculated the push actions for the event, and checked auth.
This should only be run on master. This should only be run on the instance in charge of persisting events.
""" """
assert not self.config.worker_app assert self.config.worker.writers.events == self._instance_name
if ratelimit: if ratelimit:
# We check if this is a room admin redacting an event so that we # We check if this is a room admin redacting an event so that we

View file

@ -193,6 +193,12 @@ class BasePresenceHandler(abc.ABC):
) -> None: ) -> None:
"""Set the presence state of the user. """ """Set the presence state of the user. """
@abc.abstractmethod
async def bump_presence_active_time(self, user: UserID):
"""We've seen the user do something that indicates they're interacting
with the app.
"""
class PresenceHandler(BasePresenceHandler): class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "synapse.server.HomeServer"):

View file

@ -89,6 +89,8 @@ class RoomCreationHandler(BaseHandler):
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.config = hs.config self.config = hs.config
self._replication = hs.get_replication_data_handler()
# linearizer to stop two upgrades happening at once # linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer") self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
@ -752,6 +754,11 @@ class RoomCreationHandler(BaseHandler):
if room_alias: if room_alias:
result["room_alias"] = room_alias.to_string() result["room_alias"] = room_alias.to_string()
# Always wait for room creation to progate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", last_stream_id
)
return result, last_stream_id return result, last_stream_id
async def _send_events_for_new_room( async def _send_events_for_new_room(

View file

@ -26,6 +26,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.replication.http.membership import (
ReplicationLocallyRejectInviteRestServlet,
)
from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room from synapse.util.distributor import user_joined_room, user_left_room
@ -44,11 +47,6 @@ class RoomMemberHandler(object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, hs): def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -71,6 +69,17 @@ class RoomMemberHandler(object):
self._enable_lookup = hs.config.enable_3pid_lookup self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._event_stream_writer_instance = hs.config.worker.writers.events
self._is_on_event_persistence_instance = (
self._event_stream_writer_instance == hs.get_instance_name()
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
else:
self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client(
hs
)
# This is only used to get at ratelimit function, and # This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as # maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state. # it doesn't store state.
@ -121,6 +130,22 @@ class RoomMemberHandler(object):
""" """
raise NotImplementedError() raise NotImplementedError()
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
"""Mark the invite has having been rejected even though we failed to
create a leave event for it.
"""
if self._is_on_event_persistence_instance:
return await self.persist_event_storage.locally_reject_invite(
user_id, room_id
)
else:
result = await self._locally_reject_client(
instance_name=self._event_stream_writer_instance,
user_id=user_id,
room_id=room_id,
)
return result["stream_id"]
@abc.abstractmethod @abc.abstractmethod
async def _user_joined_room(self, target: UserID, room_id: str) -> None: async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has joined the """Notifies distributor on master process that the user has joined the
@ -1015,9 +1040,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# #
logger.warning("Failed to reject invite: %s", e) logger.warning("Failed to reject invite: %s", e)
stream_id = await self.store.locally_reject_invite( stream_id = await self.locally_reject_invite(target.to_string(), room_id)
target.to_string(), room_id
)
return None, stream_id return None, stream_id
async def _user_joined_room(self, target: UserID, room_id: str) -> None: async def _user_joined_room(self, target: UserID, room_id: str) -> None:

View file

@ -19,6 +19,7 @@ from synapse.replication.http import (
federation, federation,
login, login,
membership, membership,
presence,
register, register,
send_event, send_event,
streams, streams,
@ -35,10 +36,11 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs): def register_servlets(self, hs):
send_event.register_servlets(hs, self) send_event.register_servlets(hs, self)
federation.register_servlets(hs, self) federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)
# The following can't currently be instantiated on workers. # The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:
membership.register_servlets(hs, self)
login.register_servlets(hs, self) login.register_servlets(hs, self)
register.register_servlets(hs, self) register.register_servlets(hs, self)
devices.register_servlets(hs, self) devices.register_servlets(hs, self)

View file

@ -142,6 +142,7 @@ class ReplicationEndpoint(object):
""" """
clock = hs.get_clock() clock = hs.get_clock()
client = hs.get_simple_http_client() client = hs.get_simple_http_client()
local_instance_name = hs.get_instance_name()
master_host = hs.config.worker_replication_host master_host = hs.config.worker_replication_host
master_port = hs.config.worker_replication_http_port master_port = hs.config.worker_replication_http_port
@ -151,6 +152,8 @@ class ReplicationEndpoint(object):
@trace(opname="outgoing_replication_request") @trace(opname="outgoing_replication_request")
@defer.inlineCallbacks @defer.inlineCallbacks
def send_request(instance_name="master", **kwargs): def send_request(instance_name="master", **kwargs):
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master": if instance_name == "master":
host = master_host host = master_host
port = master_port port = master_port

View file

@ -14,12 +14,16 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID from synapse.types import Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room from synapse.util.distributor import user_joined_room, user_left_room
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -106,6 +110,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.member_handler = hs.get_room_member_handler()
@staticmethod @staticmethod
def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
@ -149,12 +154,44 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
# #
logger.warning("Failed to reject invite: %s", e) logger.warning("Failed to reject invite: %s", e)
stream_id = await self.store.locally_reject_invite(user_id, room_id) stream_id = await self.member_handler.locally_reject_invite(
user_id, room_id
)
event_id = None event_id = None
return 200, {"event_id": event_id, "stream_id": stream_id} return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects the invite for the user and room locally.
Request format:
POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
{}
"""
NAME = "locally_reject_invite"
PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.member_handler = hs.get_room_member_handler()
@staticmethod
def _serialize_payload(room_id, user_id):
return {}
async def _handle_request(self, request, room_id, user_id):
logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
return 200, {"stream_id": stream_id}
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room """Notifies that a user has joined or left the room
@ -208,3 +245,4 @@ def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)

View file

@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
"""We've seen the user do something that indicates they're interacting
with the app.
The POST looks like:
POST /_synapse/replication/bump_presence_active_time/<user_id>
200 OK
{}
"""
NAME = "bump_presence_active_time"
PATH_ARGS = ("user_id",)
METHOD = "POST"
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._presence_handler = hs.get_presence_handler()
@staticmethod
def _serialize_payload(user_id):
return {}
async def _handle_request(self, request, user_id):
await self._presence_handler.bump_presence_active_time(
UserID.from_string(user_id)
)
return (
200,
{},
)
class ReplicationPresenceSetState(ReplicationEndpoint):
"""Set the presence state for a user.
The POST looks like:
POST /_synapse/replication/presence_set_state/<user_id>
{
"state": { ... },
"ignore_status_msg": false,
}
200 OK
{}
"""
NAME = "presence_set_state"
PATH_ARGS = ("user_id",)
METHOD = "POST"
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._presence_handler = hs.get_presence_handler()
@staticmethod
def _serialize_payload(user_id, state, ignore_status_msg=False):
return {
"state": state,
"ignore_status_msg": ignore_status_msg,
}
async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
await self._presence_handler.set_state(
UserID.from_string(user_id), content["state"], content["ignore_status_msg"]
)
return (
200,
{},
)
def register_servlets(hs, http_server):
ReplicationBumpPresenceActiveTime(hs).register(http_server)
ReplicationPresenceSetState(hs).register(http_server)

View file

@ -38,7 +38,9 @@ from synapse.replication.tcp.commands import (
from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import ( from synapse.replication.tcp.streams import (
STREAMS_MAP, STREAMS_MAP,
BackfillStream,
CachesStream, CachesStream,
EventsStream,
FederationStream, FederationStream,
Stream, Stream,
) )
@ -87,6 +89,14 @@ class ReplicationCommandHandler:
self._streams_to_replicate.append(stream) self._streams_to_replicate.append(stream)
continue continue
if isinstance(stream, (EventsStream, BackfillStream)):
# Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence.
if hs.config.worker.writers.events == hs.get_instance_name():
self._streams_to_replicate.append(stream)
continue
# Only add any other streams if we're on master. # Only add any other streams if we're on master.
if hs.config.worker_app is not None: if hs.config.worker_app is not None:
continue continue

View file

@ -100,7 +100,9 @@ class ShutdownRoomRestServlet(RestServlet):
# we try and auto join below. # we try and auto join below.
# #
# TODO: Currently the events stream is written to from master # TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position("master", "events", stream_id) await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
)
users = await self.state.get_current_users_in_room(room_id) users = await self.state.get_current_users_in_room(room_id)
kicked_users = [] kicked_users = []
@ -113,7 +115,7 @@ class ShutdownRoomRestServlet(RestServlet):
try: try:
target_requester = create_requester(user_id) target_requester = create_requester(user_id)
await self.room_member_handler.update_membership( _, stream_id = await self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,
room_id=room_id, room_id=room_id,
@ -123,6 +125,11 @@ class ShutdownRoomRestServlet(RestServlet):
require_consent=False, require_consent=False,
) )
# Wait for leave to come in over replication before trying to forget.
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
)
await self.room_member_handler.forget(target_requester.user, room_id) await self.room_member_handler.forget(target_requester.user, room_id)
await self.room_member_handler.update_membership( await self.room_member_handler.update_membership(

View file

@ -66,9 +66,9 @@ class DataStores(object):
self.main = main_store_class(database, db_conn, hs) self.main = main_store_class(database, db_conn, hs)
# If we're on a process that can persist events (currently # If we're on a process that can persist events also
# master), also instantiate a `PersistEventsStore` # instantiate a `PersistEventsStore`
if hs.config.worker.worker_app is None: if hs.config.worker.writers.events == hs.get_instance_name():
self.persist_events = PersistEventsStore( self.persist_events = PersistEventsStore(
hs, database, self.main hs, database, self.main
) )

View file

@ -689,6 +689,25 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale", desc="make_remote_user_device_cache_as_stale",
) )
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
self.db.simple_delete_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
)
self._invalidate_cache_and_stream(
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
return self.db.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: Database, db_conn, hs):
@ -969,17 +988,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
desc="update_device", desc="update_device",
) )
@defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
yield self.db.simple_delete(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
)
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
def update_remote_device_list_cache_entry( def update_remote_device_list_cache_entry(
self, user_id, device_id, content, stream_id self, user_id, device_id, content, stream_id
): ):

View file

@ -138,10 +138,10 @@ class PersistEventsStore:
self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator
# This should only exist on master for now # This should only exist on instances that are configured to write
assert ( assert (
hs.config.worker.worker_app is None hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate PersistEventsStore on master" ), "Can only instantiate EventsStore on master"
@_retry_on_integrity_error @_retry_on_integrity_error
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1590,3 +1590,31 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier() if not ev.internal_metadata.is_outlier()
], ],
) )
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
"""Mark the invite has having been rejected even though we failed to
create a leave event for it.
"""
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
# We also clear this entry from `local_current_membership`.
# Ideally we'd point to a leave event, but we don't have one, so
# nevermind.
self.db.simple_delete_txn(
txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
)
with self._stream_id_gen.get_next() as stream_ordering:
await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
return stream_ordering

View file

@ -76,7 +76,7 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs) super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker_app is None: if hs.config.worker.writers.events == hs.get_instance_name():
# We are the process in charge of generating stream ids for events, # We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database # so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator( self._stream_id_gen = StreamIdGenerator(

View file

@ -1046,31 +1046,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: Database, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs) super(RoomMemberStore, self).__init__(database, db_conn, hs)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
# We also clear this entry from `local_current_membership`.
# Ideally we'd point to a leave event, but we don't have one, so
# nevermind.
self.db.simple_delete_txn(
txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
)
with self._stream_id_gen.get_next() as stream_ordering:
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
return stream_ordering
def forget(self, user_id, room_id): def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id.""" """Indicate that user_id wishes to discard history for room_id."""

View file

@ -786,3 +786,9 @@ class EventsPersistenceStorage(object):
for user_id in left_users: for user_id in left_users:
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id) await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
"""Mark the invite has having been rejected even though we failed to
create a leave event for it.
"""
return await self.persist_events_store.locally_reject_invite(user_id, room_id)