Add ability to wait for replication streams (#7542)

The idea here is that if an instance persists an event via the replication HTTP API it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. current state of the room).

Most of Synapse doesn't hit such races, so we don't do the waiting automagically, instead we do so where necessary to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on.

People probably want to look at this commit by commit.
This commit is contained in:
Erik Johnston 2020-05-22 14:21:54 +01:00 committed by GitHub
parent 06a02bc1ce
commit 1531b214fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 304 additions and 112 deletions

1
changelog.d/7542.misc Normal file
View file

@ -0,0 +1 @@
Add ability to wait for replication streams.

View file

@ -126,6 +126,7 @@ class FederationHandler(BaseHandler):
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
self._replication = hs.get_replication_data_handler()
self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
hs
@ -1221,7 +1222,7 @@ class FederationHandler(BaseHandler):
async def do_invite_join(
self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
) -> None:
) -> Tuple[str, int]:
""" Attempts to join the `joinee` to the room `room_id` via the
servers contained in `target_hosts`.
@ -1304,15 +1305,23 @@ class FederationHandler(BaseHandler):
room_id=room_id, room_version=room_version_obj,
)
await self._persist_auth_tree(
max_stream_id = await self._persist_auth_tree(
origin, auth_chain, state, event, room_version_obj
)
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
"master", "events", max_stream_id
)
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = await self.store.get_room_predecessor(room_id)
if not predecessor or not isinstance(predecessor.get("room_id"), str):
return
return event.event_id, max_stream_id
old_room_id = predecessor["room_id"]
logger.debug(
"Found predecessor for %s during remote join: %s", room_id, old_room_id
@ -1325,6 +1334,7 @@ class FederationHandler(BaseHandler):
)
logger.debug("Finished joining %s to %s", joinee, room_id)
return event.event_id, max_stream_id
finally:
room_queue = self.room_queues[room_id]
del self.room_queues[room_id]
@ -1554,7 +1564,7 @@ class FederationHandler(BaseHandler):
async def do_remotely_reject_invite(
self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict
) -> EventBase:
) -> Tuple[EventBase, int]:
origin, event, room_version = await self._make_and_verify_event(
target_hosts, room_id, user_id, "leave", content=content
)
@ -1574,9 +1584,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(target_hosts, event)
context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify([(event, context)])
stream_id = await self.persist_events_and_notify([(event, context)])
return event
return event, stream_id
async def _make_and_verify_event(
self,
@ -1888,7 +1898,7 @@ class FederationHandler(BaseHandler):
state: List[EventBase],
event: EventBase,
room_version: RoomVersion,
) -> None:
) -> int:
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event separately. Notifies about the persisted events
@ -1982,7 +1992,7 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
await self.persist_events_and_notify([(event, new_event_context)])
return await self.persist_events_and_notify([(event, new_event_context)])
async def _prep_event(
self,
@ -2835,7 +2845,7 @@ class FederationHandler(BaseHandler):
self,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> None:
) -> int:
"""Persists events and tells the notifier/pushers about them, if
necessary.
@ -2845,11 +2855,12 @@ class FederationHandler(BaseHandler):
backfilling or not
"""
if self.config.worker_app:
await self._send_events_to_master(
result = await self._send_events_to_master(
store=self.store,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)
return result["max_stream_id"]
else:
max_stream_id = await self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
@ -2864,6 +2875,8 @@ class FederationHandler(BaseHandler):
for event, _ in event_and_contexts:
await self._notify_persisted_event(event, max_stream_id)
return max_stream_id
async def _notify_persisted_event(
self, event: EventBase, max_stream_id: int
) -> None:

View file

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import Optional, Tuple
from six import iteritems, itervalues, string_types
@ -42,6 +42,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events import EventBase
from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@ -630,7 +631,9 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
async def send_nonmember_event(self, requester, event, context, ratelimit=True):
async def send_nonmember_event(
self, requester, event, context, ratelimit=True
) -> int:
"""
Persists and notifies local clients and federation of an event.
@ -639,6 +642,9 @@ class EventCreationHandler(object):
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
Return:
The stream_id of the persisted event.
"""
if event.type == EventTypes.Member:
raise SynapseError(
@ -659,7 +665,7 @@ class EventCreationHandler(object):
)
return prev_state
await self.handle_new_client_event(
return await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@ -688,7 +694,7 @@ class EventCreationHandler(object):
async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None
):
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
@ -711,10 +717,10 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
await self.send_nonmember_event(
stream_id = await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
return event
return event, stream_id
@measure_func("create_new_client_event")
@defer.inlineCallbacks
@ -774,7 +780,7 @@ class EventCreationHandler(object):
@measure_func("handle_new_client_event")
async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
) -> int:
"""Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
@ -787,6 +793,9 @@ class EventCreationHandler(object):
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
Return:
The stream_id of the persisted event.
"""
if event.is_state() and (event.type, event.state_key) == (
@ -827,7 +836,7 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
await self.send_event_to_master(
result = await self.send_event_to_master(
event_id=event.event_id,
store=self.store,
requester=requester,
@ -836,14 +845,17 @@ class EventCreationHandler(object):
ratelimit=ratelimit,
extra_users=extra_users,
)
stream_id = result["stream_id"]
event.internal_metadata.stream_ordering = stream_id
success = True
return
return stream_id
await self.persist_and_notify_client_event(
stream_id = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
success = True
return stream_id
finally:
if not success:
# Ensure that we actually remove the entries in the push actions
@ -886,7 +898,7 @@ class EventCreationHandler(object):
async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
) -> int:
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
@ -1076,6 +1088,8 @@ class EventCreationHandler(object):
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
return event_stream_id
async def _bump_active_time(self, user):
try:
presence = self.hs.get_presence_handler()

View file

@ -22,6 +22,7 @@ import logging
import math
import string
from collections import OrderedDict
from typing import Tuple
from six import iteritems, string_types
@ -518,7 +519,7 @@ class RoomCreationHandler(BaseHandler):
async def create_room(
self, requester, config, ratelimit=True, creator_join_profile=None
):
) -> Tuple[dict, int]:
""" Creates a new room.
Args:
@ -535,9 +536,9 @@ class RoomCreationHandler(BaseHandler):
`avatar_url` and/or `displayname`.
Returns:
Deferred[dict]:
a dict containing the keys `room_id` and, if an alias was
requested, `room_alias`.
First, a dict containing the keys `room_id` and, if an alias
was, requested, `room_alias`. Secondly, the stream_id of the
last persisted event.
Raises:
SynapseError if the room ID couldn't be stored, or something went
horribly wrong.
@ -669,7 +670,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
await self._send_events_for_new_room(
last_stream_id = await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@ -683,7 +684,10 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
await self.event_creation_handler.create_and_send_nonmember_event(
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@ -697,7 +701,10 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
await self.event_creation_handler.create_and_send_nonmember_event(
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@ -715,7 +722,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
await self.room_member_handler.update_membership(
_, last_stream_id = await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@ -729,7 +736,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
await self.hs.get_room_member_handler().do_3pid_invite(
last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@ -745,7 +752,7 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
return result
return result, last_stream_id
async def _send_events_for_new_room(
self,
@ -758,7 +765,13 @@ class RoomCreationHandler(BaseHandler):
room_alias=None,
power_level_content_override=None, # Doesn't apply when initial state has power level state event content
creator_join_profile=None,
):
) -> int:
"""Sends the initial events into a new room.
Returns:
The stream_id of the last event persisted.
"""
def create(etype, content, **kwargs):
e = {"type": etype, "content": content}
@ -767,12 +780,16 @@ class RoomCreationHandler(BaseHandler):
return e
async def send(etype, content, **kwargs):
async def send(etype, content, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
await self.event_creation_handler.create_and_send_nonmember_event(
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
)
return last_stream_id
config = RoomCreationHandler.PRESETS_DICT[preset_config]
@ -797,7 +814,9 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
await send(etype=EventTypes.PowerLevels, content=pl_content)
last_sent_stream_id = await send(
etype=EventTypes.PowerLevels, content=pl_content
)
else:
power_level_content = {
"users": {creator_id: 100},
@ -830,33 +849,39 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
await send(etype=EventTypes.PowerLevels, content=power_level_content)
last_sent_stream_id = await send(
etype=EventTypes.PowerLevels, content=power_level_content
)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
await send(
last_sent_stream_id = await send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
if (EventTypes.JoinRules, "") not in initial_state:
await send(
last_sent_stream_id = await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
await send(
last_sent_stream_id = await send(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
await send(
last_sent_stream_id = await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
await send(etype=etype, state_key=state_key, content=content)
last_sent_stream_id = await send(
etype=etype, state_key=state_key, content=content
)
return last_sent_stream_id
async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,

View file

@ -17,7 +17,7 @@
import abc
import logging
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple
from six.moves import http_client
@ -84,7 +84,7 @@ class RoomMemberHandler(object):
room_id: str,
user: UserID,
content: dict,
) -> Optional[dict]:
) -> Tuple[str, int]:
"""Try and join a room that this server is not in
Args:
@ -104,7 +104,7 @@ class RoomMemberHandler(object):
room_id: str,
target: UserID,
content: dict,
) -> dict:
) -> Tuple[Optional[str], int]:
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
@ -154,7 +154,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> EventBase:
) -> Tuple[str, int]:
user_id = target.to_string()
if content is None:
@ -187,9 +187,10 @@ class RoomMemberHandler(object):
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return duplicate
_, stream_id = await self.store.get_event_ordering(duplicate.event_id)
return duplicate.event_id, stream_id
await self.event_creation_handler.handle_new_client_event(
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
)
@ -213,7 +214,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target, room_id)
return event
return event.event_id, stream_id
async def copy_room_tags_and_direct_to_room(
self, old_room_id, new_room_id, user_id
@ -263,7 +264,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> Union[EventBase, Optional[dict]]:
) -> Tuple[Optional[str], int]:
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@ -294,7 +295,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> Union[EventBase, Optional[dict]]:
) -> Tuple[Optional[str], int]:
content_specified = bool(content)
if content is None:
content = {}
@ -398,7 +399,13 @@ class RoomMemberHandler(object):
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
return old_state
_, stream_id = await self.store.get_event_ordering(
old_state.event_id
)
return (
old_state.event_id,
stream_id,
)
if old_membership in ["ban", "leave"] and action == "kick":
raise AuthError(403, "The target user is not in the room")
@ -705,7 +712,7 @@ class RoomMemberHandler(object):
requester: Requester,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> None:
) -> int:
if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
@ -737,11 +744,11 @@ class RoomMemberHandler(object):
)
if invitee:
await self.update_membership(
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
await self._make_and_store_3pid_invite(
stream_id = await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@ -752,6 +759,8 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
return stream_id
async def _make_and_store_3pid_invite(
self,
requester: Requester,
@ -762,7 +771,7 @@ class RoomMemberHandler(object):
user: UserID,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> None:
) -> int:
room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
@ -817,7 +826,10 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
await self.event_creation_handler.create_and_send_nonmember_event(
(
event,
stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@ -835,6 +847,7 @@ class RoomMemberHandler(object):
ratelimit=False,
txn_id=txn_id,
)
return stream_id
async def _is_host_in_room(
self, current_state_ids: Dict[Tuple[str, str], str]
@ -916,7 +929,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id: str,
user: UserID,
content: dict,
) -> None:
) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@ -945,7 +958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
await self.federation_handler.do_invite_join(
event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
await self._user_joined_room(user, room_id)
@ -955,14 +968,14 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled:
if too_complex is False:
# We checked, and we're under the limit.
return
return event_id, stream_id
# Check again, but with the local state events
too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False:
# We're under the limit.
return
return event_id, stream_id
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
@ -975,6 +988,8 @@ class RoomMemberMasterHandler(RoomMemberHandler):
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
)
return event_id, stream_id
async def _remote_reject_invite(
self,
requester: Requester,
@ -982,15 +997,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id: str,
target: UserID,
content: dict,
) -> dict:
) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite
"""
fed_handler = self.federation_handler
try:
ret = await fed_handler.do_remotely_reject_invite(
event, stream_id = await fed_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, target.to_string(), content=content,
)
return ret
return event.event_id, stream_id
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
@ -1000,8 +1015,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warning("Failed to reject invite: %s", e)
await self.store.locally_reject_invite(target.to_string(), room_id)
return {}
stream_id = await self.store.locally_reject_invite(
target.to_string(), room_id
)
return None, stream_id
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import List, Optional
from typing import List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
room_id: str,
user: UserID,
content: dict,
) -> Optional[dict]:
) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join
"""
if len(remote_room_hosts) == 0:
@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
await self._user_joined_room(user, room_id)
return ret
return ret["event_id"], ret["stream_id"]
async def _remote_reject_invite(
self,
@ -68,16 +68,17 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
room_id: str,
target: UserID,
content: dict,
) -> dict:
) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite
"""
return await self._remote_reject_client(
ret = await self._remote_reject_client(
requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user_id=target.to_string(),
content=content,
)
return ret["event_id"], ret["stream_id"]
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room

View file

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""Handles events newly received from federation, including persisting and
notifying.
notifying. Returns the maximum stream ID of the persisted events.
The API looks like:
@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"context": { .. serialized event context .. },
}],
"backfilled": false
}
200 OK
{
"max_stream_id": 32443,
}
"""
NAME = "fed_send_events"
@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
await self.federation_handler.persist_events_and_notify(
max_stream_id = await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
)
return 200, {}
return 200, {"max_stream_id": max_stream_id}
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):

View file

@ -76,11 +76,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
logger.info("remote_join: %s into room: %s", user_id, room_id)
await self.federation_handler.do_invite_join(
event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user_id, event_content
)
return 200, {}
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@ -136,10 +136,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
event = await self.federation_handler.do_remotely_reject_invite(
event, stream_id = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id, event_content,
)
ret = event.get_pdu_json()
event_id = event.event_id
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
@ -149,10 +149,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
#
logger.warning("Failed to reject invite: %s", e)
await self.store.locally_reject_invite(user_id, room_id)
ret = {}
stream_id = await self.store.locally_reject_invite(user_id, room_id)
event_id = None
return 200, ret
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):

View file

@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
await self.event_creation_handler.persist_and_notify_client_event(
stream_id = await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
return 200, {}
return 200, {"stream_id": stream_id}
def register_servlets(hs, http_server):

View file

@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
super().__init__(hs)
self._instance_name = hs.get_instance_name()
# We pull the streams from the replication handler (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_tcp_replication().get_streams()
self.streams = hs.get_replication_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token):

View file

@ -14,19 +14,23 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import heapq
import logging
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple
from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.api.constants import EventTypes
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -35,6 +39,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# How long we allow callers to wait for replication updates before timing out.
_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.
@ -92,6 +100,16 @@ class ReplicationDataHandler:
self.store = hs.get_datastore()
self.pusher_pool = hs.get_pusherpool()
self.notifier = hs.get_notifier()
self._reactor = hs.get_reactor()
self._clock = hs.get_clock()
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
self._streams_to_waiters = (
{}
) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@ -131,8 +149,76 @@ class ReplicationDataHandler:
await self.pusher_pool.on_new_notifications(token, token)
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
waiting_list = self._streams_to_waiters.get(stream_name, [])
# Index of first item with a position after the current token, i.e we
# have called all deferreds before this index. If not overwritten by
# loop below means either a) no items in list so no-op or b) all items
# in list were called and so the list should be cleared. Setting it to
# `len(list)` works for both cases.
index_of_first_deferred_not_called = len(waiting_list)
for idx, (position, deferred) in enumerate(waiting_list):
if position <= token:
try:
with PreserveLoggingContext():
deferred.callback(None)
except Exception:
# The deferred has been cancelled or timed out.
pass
else:
# The list is sorted by position so we don't need to continue
# checking any futher entries in the list.
index_of_first_deferred_not_called = idx
break
# Drop all entries in the waiting list that were called in the above
# loop. (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
async def wait_for_stream_position(
self, instance_name: str, stream_name: str, position: int
):
"""Wait until this instance has received updates up to and including
the given stream position.
"""
if instance_name == self._instance_name:
# We don't get told about updates written by this process, and
# anyway in that case we don't need to wait.
return
current_position = self._streams[stream_name].current_token(self._instance_name)
if position <= current_position:
# We're already past the position
return
# Create a new deferred that times out after N seconds, as we don't want
# to wedge here forever.
deferred = Deferred()
deferred = timeout_deferred(
deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
)
waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
# We insert into the list using heapq as it is more efficient than
# pushing then resorting each time.
heapq.heappush(waiting_list, (position, deferred))
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
await make_deferred_yieldable(deferred)
logger.info(
"Finished waiting for repl stream %r to reach %s", stream_name, position
)

View file

@ -59,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet):
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
self._replication = hs.get_replication_data_handler()
async def on_POST(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
@ -73,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
info = await self._room_creation_handler.create_room(
info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
@ -94,6 +95,13 @@ class ShutdownRoomRestServlet(RestServlet):
# desirable in case the first attempt at blocking the room failed below.
await self.store.block_room(room_id, requester_user_id)
# We now wait for the create room to come back in via replication so
# that we can assume that all the joins/invites have propogated before
# we try and auto join below.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position("master", "events", stream_id)
users = await self.state.get_current_users_in_room(room_id)
kicked_users = []
failed_to_kick_users = []

View file

@ -93,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request)
info = await self._room_creation_handler.create_room(
info, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@ -202,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
event = await self.room_member_handler.update_membership(
event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@ -210,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content,
)
else:
event = await self.event_creation_handler.create_and_send_nonmember_event(
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
event_id = event.event_id
ret = {} # type: dict
if event:
set_tag("event_id", event.event_id)
ret = {"event_id": event.event_id}
if event_id:
set_tag("event_id", event_id)
ret = {"event_id": event_id}
return 200, ret
@ -247,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
event = await self.event_creation_handler.create_and_send_nonmember_event(
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
@ -781,7 +785,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
event = await self.event_creation_handler.create_and_send_nonmember_event(
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,

View file

@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
event = await self.event_creation_handler.create_and_send_nonmember_event(
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)

View file

@ -90,6 +90,7 @@ from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
@ -210,6 +211,7 @@ class HomeServer(object):
"storage",
"replication_streamer",
"replication_data_handler",
"replication_streams",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@ -583,6 +585,9 @@ class HomeServer(object):
def build_replication_data_handler(self):
return ReplicationDataHandler(self)
def build_replication_streams(self):
return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View file

@ -1,3 +1,5 @@
from typing import Dict
import twisted.internet
import synapse.api.auth
@ -28,6 +30,7 @@ import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
from synapse.replication.tcp.streams import Stream
class HomeServer(object):
@property
@ -136,3 +139,5 @@ class HomeServer(object):
pass
def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool:
pass
def get_replication_streams(self) -> Dict[str, Stream]:
pass

View file

@ -83,10 +83,10 @@ class ServerNoticesManager(object):
if state_key is not None:
event_dict["state_key"] = state_key
res = await self._event_creation_handler.create_and_send_nonmember_event(
event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
)
return res
return event
@cached()
async def get_or_create_notice_room_for_user(self, user_id):
@ -143,7 +143,7 @@ class ServerNoticesManager(object):
}
requester = create_requester(self.server_notices_mxid)
info = await self._room_creation_handler.create_room(
info, _ = await self._room_creation_handler.create_room(
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,

View file

@ -1289,12 +1289,12 @@ class EventsWorkerStore(SQLBaseStore):
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
to_1, so_1 = await self._get_event_ordering(event_id1)
to_2, so_2 = await self._get_event_ordering(event_id2)
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
def get_event_ordering(self, event_id):
res = yield self.db.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],

View file

@ -1069,6 +1069,8 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
with self._stream_id_gen.get_next() as stream_ordering:
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
return stream_ordering
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""

View file

@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
)
d = handler._remote_join(
None,
@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
)
# Artificially raise the complexity
self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(

View file

@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
reactor.pump((1000,))
hs = self.setup_test_homeserver(
notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
notifier=Mock(),
http_client=mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
hs.datastores = datastores

View file

@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
self.requester = Requester(self.user, None, False, None, None)
info = self.get_success(self.room_creator.create_room(self.requester, {}))
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
def run_background_update(self):
@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
self.requester = Requester(self.user, None, False, None, None)
info = self.get_success(self.room_creator.create_room(self.requester, {}))
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.user_consent_version = self.CONSENT_VERSION

View file

@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
events = [(3, 2), (6, 2), (4, 6)]
for event_count, extrems in events:
info = self.get_success(room_creator.create_room(requester, {}))
info, _ = self.get_success(room_creator.create_room(requester, {}))
room_id = info["room_id"]
last_event = None

View file

@ -28,13 +28,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
room = ensureDeferred(
room_deferred = ensureDeferred(
room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
)
)
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"]
self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
self.store = self.homeserver.get_datastore()