mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 02:43:54 +01:00
1538 lines
54 KiB
Python
1538 lines
54 KiB
Python
# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
|
|
# Copyright 2020 Sorunome
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import copy
|
|
import itertools
|
|
import logging
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Awaitable,
|
|
Callable,
|
|
Collection,
|
|
Container,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
import attr
|
|
from prometheus_client import Counter
|
|
|
|
from synapse.api.constants import EventTypes, Membership
|
|
from synapse.api.errors import (
|
|
CodeMessageException,
|
|
Codes,
|
|
FederationDeniedError,
|
|
HttpResponseException,
|
|
RequestSendFailed,
|
|
SynapseError,
|
|
UnsupportedRoomVersionError,
|
|
)
|
|
from synapse.api.room_versions import (
|
|
KNOWN_ROOM_VERSIONS,
|
|
EventFormatVersions,
|
|
RoomVersion,
|
|
RoomVersions,
|
|
)
|
|
from synapse.events import EventBase, builder
|
|
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
|
from synapse.federation.transport.client import SendJoinResponse
|
|
from synapse.logging.utils import log_function
|
|
from synapse.types import JsonDict, get_domain_from_id
|
|
from synapse.util.async_helpers import concurrently_execute
|
|
from synapse.util.caches.expiringcache import ExpiringCache
|
|
from synapse.util.retryutils import NotRetryingDestination
|
|
|
|
if TYPE_CHECKING:
|
|
from synapse.server import HomeServer
|
|
|
|
logger = logging.getLogger(__name__)

# Counts outgoing federation queries, labelled by query type.
sent_queries_counter = Counter(
    "synapse_federation_client_sent_queries", "", labelnames=["type"]
)

# Minimum interval, in milliseconds, before get_pdu will ask the same
# destination for the same event again.
PDU_RETRY_TIME_MS = 60 * 1000

# Generic result type for the callbacks passed to _try_destination_list.
T = TypeVar("T")
|
|
|
|
|
|
class InvalidResponseError(RuntimeError):
    """Signals that a remote server sent back a response we could not make sense of.

    Intended for use with _try_destination_list, which logs this error
    without a stack trace and then moves on to the next destination.
    """
|
|
|
|
|
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
class SendJoinResult:
    # The join event to persist.
    event: EventBase
    # The server_name of the server the join was sent to.
    origin: str
    # The room state returned by that server, restricted to events that
    # could be verified.
    state: List[EventBase]
    # The auth chain returned by that server, likewise restricted to events
    # that could be verified.
    auth_chain: List[EventBase]
|
|
|
|
|
|
class FederationClient(FederationBase):
|
|
def __init__(self, hs: "HomeServer"):
    """Set up the federation client for a homeserver.

    Args:
        hs: the homeserver whose state handler, federation transport client,
            hostname and signing key this client uses.
    """
    super().__init__(hs)

    # event_id -> {destination: time_msec() of our last attempt to fetch that
    # event from that destination}. get_pdu consults it to avoid retrying too
    # soon; _clear_tried_cache, run periodically below, prunes it.
    self.pdu_destination_tried: Dict[str, Dict[str, int]] = {}
    self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
    self.state = hs.get_state_handler()
    self.transport_layer = hs.get_federation_transport_client()

    self.hostname = hs.hostname
    self.signing_key = hs.signing_key

    # Events recently fetched by get_pdu, keyed by event ID.
    pdu_cache_expiry_ms = 120 * 1000
    self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache(
        cache_name="get_pdu_cache",
        clock=self._clock,
        max_len=1000,
        expiry_ms=pdu_cache_expiry_ms,
        reset_expiry_on_get=False,
    )

    # A cache for fetching the room hierarchy over federation.
    #
    # Some stale data over federation is OK, but must be refreshed
    # periodically since the local server is in the room.
    #
    # It is a map of (room ID, suggested-only) -> the response of
    # get_room_hierarchy.
    hierarchy_cache_expiry_ms = 5 * 60 * 1000
    self._get_room_hierarchy_cache: ExpiringCache[
        Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]
    ] = ExpiringCache(
        cache_name="get_room_hierarchy_cache",
        clock=self._clock,
        max_len=1000,
        expiry_ms=hierarchy_cache_expiry_ms,
        reset_expiry_on_get=False,
    )
|
|
|
|
def _clear_tried_cache(self):
|
|
"""Clear pdu_destination_tried cache"""
|
|
now = self._clock.time_msec()
|
|
|
|
old_dict = self.pdu_destination_tried
|
|
self.pdu_destination_tried = {}
|
|
|
|
for event_id, destination_dict in old_dict.items():
|
|
destination_dict = {
|
|
dest: time
|
|
for dest, time in destination_dict.items()
|
|
if time + PDU_RETRY_TIME_MS > now
|
|
}
|
|
if destination_dict:
|
|
self.pdu_destination_tried[event_id] = destination_dict
|
|
|
|
@log_function
async def make_query(
    self,
    destination: str,
    query_type: str,
    args: dict,
    retry_on_dns_fail: bool = False,
    ignore_backoff: bool = False,
) -> JsonDict:
    """Send a federation query of type ``query_type`` to ``destination``.

    The ``sent_queries_counter`` metric is bumped for ``query_type`` before
    the request is made.

    Args:
        destination: server name of the remote homeserver.
        query_type: category of the query; should match the handler name
            used in register_query_handler().
        args: mapping of strings to strings giving the details of the query.
        retry_on_dns_fail: passed straight through to the transport layer's
            make_query.
        ignore_backoff: true to ignore the historical backoff data for the
            destination and try the request anyway.

    Returns:
        The JSON object from the response.
    """
    sent_queries_counter.labels(query_type).inc()

    transport = self.transport_layer
    return await transport.make_query(
        destination,
        query_type,
        args,
        retry_on_dns_fail=retry_on_dns_fail,
        ignore_backoff=ignore_backoff,
    )
|
|
|
|
@log_function
async def query_client_keys(
    self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
    """Ask a remote server for the device keys of devices it hosts.

    Args:
        destination: server name of the remote homeserver.
        content: the body of the key query.
        timeout: passed through unchanged to the transport layer
            (presumably milliseconds; confirm against the transport client).

    Returns:
        The JSON object from the response.
    """
    sent_queries_counter.labels("client_device_keys").inc()
    response = await self.transport_layer.query_client_keys(
        destination, content, timeout
    )
    return response
|
|
|
|
@log_function
async def query_user_devices(
    self, destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
    """Ask a remote server for the device keys of a single user it hosts.

    Args:
        destination: server name of the remote homeserver.
        user_id: the user whose devices are being queried.
        timeout: passed through unchanged to the transport layer
            (presumably milliseconds; confirm against the transport client).

    Returns:
        The JSON object from the response.
    """
    sent_queries_counter.labels("user_devices").inc()
    response = await self.transport_layer.query_user_devices(
        destination, user_id, timeout
    )
    return response
|
|
|
|
@log_function
async def claim_client_keys(
    self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
    """Claim one-time keys for devices hosted on a remote server.

    Args:
        destination: server name of the remote homeserver.
        content: the body of the claim request.
        timeout: passed through unchanged to the transport layer
            (presumably milliseconds; confirm against the transport client).

    Returns:
        The JSON object from the response.
    """
    sent_queries_counter.labels("client_one_time_keys").inc()
    response = await self.transport_layer.claim_client_keys(
        destination, content, timeout
    )
    return response
|
|
|
|
async def backfill(
    self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
) -> Optional[List[EventBase]]:
    """Fetch older (historic) PDUs for a room from a single remote server.

    Args:
        dest: the remote homeserver to ask.
        room_id: the room to backfill.
        limit: the maximum number of events to request.
        extremities: our current backwards extremities, to backfill from.

    Returns:
        None if ``extremities`` is empty. Otherwise, the returned PDUs that
        pass (or are recovered by) signature and hash checks.
    """
    logger.debug("backfill extrem=%s", extremities)

    if not extremities:
        # Nothing to backfill from: we've (probably) reached the start.
        return None

    transaction_data = await self.transport_layer.backfill(
        dest, room_id, extremities, limit
    )
    logger.debug("backfill transaction_data=%r", transaction_data)

    room_version = await self.store.get_room_version(room_id)
    received = [
        event_from_pdu_json(raw, room_version, outlier=False)
        for raw in transaction_data["pdus"]
    ]

    # Keep only events whose signatures and hashes check out.
    return await self._check_sigs_and_hash_and_fetch(
        dest, received, outlier=True, room_version=room_version
    )
|
|
|
|
async def get_pdu(
    self,
    destinations: Iterable[str],
    event_id: str,
    room_version: RoomVersion,
    outlier: bool = False,
    timeout: Optional[int] = None,
) -> Optional[EventBase]:
    """Requests the PDU with the given event ID from the remote home
    servers.

    Will attempt to get the PDU from each destination in the list until
    one succeeds. A result found in the short-lived ``_get_pdu_cache`` is
    returned without contacting any server. Destinations we asked for this
    event less than PDU_RETRY_TIME_MS ago are skipped.

    Args:
        destinations: Which homeservers to query
        event_id: event to fetch
        room_version: version of the room
        outlier: Indicates whether the PDU is an `outlier`, i.e. if
            it's from an arbitrary point in the context as opposed to part
            of the current block of PDUs. Defaults to `False`
        timeout: How long to try (in ms) each destination for before
            moving to the next destination. None indicates no timeout.

    Returns:
        The requested PDU, or None if we were unable to find it.
    """

    # TODO: Rate limit the number of times we try and get the same event.

    ev = self._get_pdu_cache.get(event_id)
    if ev:
        return ev

    # This dict lives inside self.pdu_destination_tried, so attempts recorded
    # here are seen by later calls until _clear_tried_cache prunes them.
    pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})

    signed_pdu = None
    for destination in destinations:
        now = self._clock.time_msec()
        last_attempt = pdu_attempts.get(destination, 0)
        # Back off from destinations we asked too recently.
        if last_attempt + PDU_RETRY_TIME_MS > now:
            continue

        try:
            transaction_data = await self.transport_layer.get_event(
                destination, event_id, timeout=timeout
            )

            logger.debug(
                "retrieved event id %s from %s: %r",
                event_id,
                destination,
                transaction_data,
            )

            pdu_list: List[EventBase] = [
                event_from_pdu_json(p, room_version, outlier=outlier)
                for p in transaction_data["pdus"]
            ]

            if pdu_list and pdu_list[0]:
                pdu = pdu_list[0]

                # Check signatures are correct.
                signed_pdu = await self._check_sigs_and_hash(room_version, pdu)

                # Stop at the first destination that returned a PDU, whatever
                # _check_sigs_and_hash returned for it.
                break

            # Only reached when the response held no usable PDU: record the
            # attempt so this destination is backed off.
            pdu_attempts[destination] = now

        except SynapseError as e:
            # Not recorded in pdu_attempts, so this destination may be asked
            # again by a later call straight away.
            logger.info(
                "Failed to get PDU %s from %s because %s", event_id, destination, e
            )
            continue
        except NotRetryingDestination as e:
            logger.info(str(e))
            continue
        except FederationDeniedError as e:
            # NOTE(review): if FederationDeniedError is a subclass of
            # SynapseError, the handler above catches it first and this
            # branch is unreachable. Confirm.
            logger.info(str(e))
            continue
        except Exception as e:
            # Unexpected failure: record the attempt so we back off.
            pdu_attempts[destination] = now

            logger.info(
                "Failed to get PDU %s from %s because %s", event_id, destination, e
            )
            continue

    if signed_pdu:
        self._get_pdu_cache[event_id] = signed_pdu

    return signed_pdu
|
|
|
|
async def get_room_state_ids(
    self, destination: str, room_id: str, event_id: str
) -> Tuple[List[str], List[str]]:
    """Fetch the room state at an event, and that event's auth chain, as event IDs.

    Uses the remote server's /state_ids endpoint.

    Args:
        destination: the remote homeserver to query.
        room_id: the room the event is in.
        event_id: the event whose state is wanted.

    Returns:
        A tuple of (state event_ids, auth event_ids). The auth list is empty
        if the response has no ``auth_chain_ids`` field.

    Raises:
        Exception: if either field in the response is not a list.
    """
    response = await self.transport_layer.get_room_state_ids(
        destination, room_id, event_id=event_id
    )

    state_ids = response["pdu_ids"]
    auth_ids = response.get("auth_chain_ids", [])

    well_formed = isinstance(state_ids, list) and isinstance(auth_ids, list)
    if not well_formed:
        raise Exception("invalid response from /state_ids")

    return state_ids, auth_ids
|
|
|
|
async def _check_sigs_and_hash_and_fetch(
    self,
    origin: str,
    pdus: Collection[EventBase],
    room_version: RoomVersion,
    outlier: bool = False,
) -> List[EventBase]:
    """Verify a batch of PDUs, keeping those that pass or can be recovered.

    Each PDU goes through _check_sigs_and_hash_and_fetch_one. That method
    checks signatures and hashes, and if they fail looks for a valid copy
    in the local database or on the sending server. PDUs with no valid
    copy are dropped.

    ``pdus`` itself is not modified; a new list is returned.

    NOTE(review): results are appended as each check completes. If
    concurrently_execute interleaves the checks, the output order may differ
    from the input order. Confirm whether callers rely on ordering.

    Args:
        origin: the server the PDUs were received from.
        pdus: the PDUs to check.
        room_version: version of the room the PDUs belong to.
        outlier: whether the events are outliers or not.

    Returns:
        A list of PDUs that have valid signatures and hashes.
    """
    verified: List[EventBase] = []

    async def _verify(candidate: EventBase) -> None:
        checked = await self._check_sigs_and_hash_and_fetch_one(
            pdu=candidate,
            origin=origin,
            outlier=outlier,
            room_version=room_version,
        )
        if checked:
            verified.append(checked)

    # Bound how many PDUs are checked at once: checking hundreds of
    # thousands simultaneously causes large memory spikes.
    await concurrently_execute(_verify, pdus, 10000)

    return verified
|
|
|
|
async def _check_sigs_and_hash_and_fetch_one(
    self,
    pdu: EventBase,
    origin: str,
    room_version: RoomVersion,
    outlier: bool = False,
) -> Optional[EventBase]:
    """Verify one PDU, falling back to other copies of it if verification fails.

    The PDU's signatures and hashes are checked first. If that check
    returns nothing or raises SynapseError, we look the event up in the
    local database (rejected events included). If it is still not found and
    the sender's server is not ``origin``, we ask the sender's server for it
    via get_pdu, with a 10000 ms timeout.

    If the PDU fails its content hash check then it is redacted.

    Args:
        origin: the server we received the PDU from.
        pdu: the PDU to verify.
        room_version: version of the room the PDU belongs to.
        outlier: whether the event is an outlier; passed through to get_pdu.

    Returns:
        The PDU (possibly redacted) if it has valid signatures and hashes,
        otherwise a copy from the local database or the sender's server, or
        None if no copy could be found (a warning is logged).
    """
    try:
        checked = await self._check_sigs_and_hash(room_version, pdu)
    except SynapseError:
        checked = None

    if not checked:
        # Fall back to a copy we already hold locally.
        checked = await self.store.get_event(
            pdu.event_id, allow_rejected=True, allow_none=True
        )

    sender_server = get_domain_from_id(pdu.sender)
    # As a last resort, ask the sender's own server, unless that is the
    # server which gave us this (bad) copy in the first place.
    if not checked and sender_server != origin:
        try:
            checked = await self.get_pdu(
                destinations=[sender_server],
                event_id=pdu.event_id,
                room_version=room_version,
                outlier=outlier,
                timeout=10000,
            )
        except SynapseError:
            pass

    if not checked:
        logger.warning(
            "Failed to find copy of %s with valid signature", pdu.event_id
        )

    return checked
|
|
|
|
async def get_event_auth(
    self, destination: str, room_id: str, event_id: str
) -> List[EventBase]:
    """Fetch an event's auth chain from a remote server and verify it.

    Args:
        destination: the remote homeserver to ask.
        room_id: the room the event is in.
        event_id: the event whose auth chain is wanted.

    Returns:
        The auth events that pass (or are recovered by) signature and hash
        checks, ordered by ascending depth.
    """
    response = await self.transport_layer.get_event_auth(
        destination, room_id, event_id
    )

    room_version = await self.store.get_room_version(room_id)

    unverified = [
        event_from_pdu_json(raw, room_version, outlier=True)
        for raw in response["auth_chain"]
    ]

    verified = await self._check_sigs_and_hash_and_fetch(
        destination, unverified, outlier=True, room_version=room_version
    )

    return sorted(verified, key=lambda auth_event: auth_event.depth)
|
|
|
|
def _is_unknown_endpoint(
    self, e: HttpResponseException, synapse_error: Optional[SynapseError] = None
) -> bool:
    """Decide whether an HTTP error means the remote doesn't implement the endpoint.

    There is no reliable signal for this. Dendrite answers with a bare 404,
    while Synapse answers with a 400 carrying M_UNRECOGNIZED, so both are
    treated as "unknown endpoint".

    Args:
        e: the error response received from the remote server.
        synapse_error: ``e`` already converted to a SynapseError. If omitted,
            it is generated from ``e``.

    Returns:
        True if the response looks like an unimplemented endpoint.
    """
    if synapse_error is None:
        synapse_error = e.to_synapse_error()

    if e.code == 404:
        return True

    return e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED
|
|
|
|
async def _try_destination_list(
    self,
    description: str,
    destinations: Iterable[str],
    callback: Callable[[str], Awaitable[T]],
    failover_errcodes: Optional[Container[str]] = None,
    failover_on_unknown_endpoint: bool = False,
) -> T:
    """Run ``callback`` against each destination in turn until one succeeds.

    Our own server name is skipped.

    Failure handling per destination:

    * UnsupportedRoomVersionError is re-raised at once.
    * An HttpResponseException is re-raised (as a SynapseError) unless it
      qualifies for failover. It qualifies if:
        - it is a 5xx, or
        - it is a 400 whose errcode is in ``failover_errcodes``, or
        - ``failover_on_unknown_endpoint`` is set and the remote appears not
          to implement the endpoint.
    * RequestSendFailed, InvalidResponseError and NotRetryingDestination are
      logged briefly (no stack trace) and the next destination is tried.
    * Any other exception is logged with its stack trace and the next
      destination is tried.

    Args:
        description: what we are doing, for log messages.
        destinations: server_names to try, in order.
        callback: called with a single server_name; its result is returned
            on success.
        failover_errcodes: endpoint-specific error codes that should trigger
            failover when received as part of an HTTP 400.
        failover_on_unknown_endpoint: if True, also fail over when a server
            doesn't seem to support the endpoint. This is typically useful
            for new or experimental endpoints.

    Returns:
        The result of the first successful callback.

    Raises:
        SynapseError: if a remote returns a non-failover HTTP error, or if
            no destination succeeded (502).
    """
    if failover_errcodes is None:
        failover_errcodes = ()

    for destination in destinations:
        if destination == self.server_name:
            continue

        try:
            return await callback(destination)
        except (
            RequestSendFailed,
            InvalidResponseError,
            NotRetryingDestination,
        ) as e:
            logger.warning("Failed to %s via %s: %s", description, destination, e)
        except UnsupportedRoomVersionError:
            raise
        except HttpResponseException as e:
            synapse_error = e.to_synapse_error()

            # Fail over on server errors, on endpoint-specific "can't do
            # that here" codes, or (optionally) on an unimplemented endpoint.
            should_failover = (
                500 <= e.code < 600
                or (e.code == 400 and synapse_error.errcode in failover_errcodes)
                or (
                    failover_on_unknown_endpoint
                    and self._is_unknown_endpoint(e, synapse_error)
                )
            )
            if not should_failover:
                raise synapse_error from e

            logger.warning(
                "Failed to %s via %s: %i %s",
                description,
                destination,
                e.code,
                e.args[0],
            )
        except Exception:
            logger.warning(
                "Failed to %s via %s", description, destination, exc_info=True
            )

    raise SynapseError(502, "Failed to %s via any server" % (description,))
|
|
|
|
async def make_membership_event(
    self,
    destinations: Iterable[str],
    room_id: str,
    user_id: str,
    membership: str,
    content: dict,
    params: Optional[Mapping[str, Union[str, Iterable[str]]]],
) -> Tuple[str, EventBase, RoomVersion]:
    """
    Creates an m.room.member event, with context, without participating in the room.

    Does so by asking one of the already participating servers to create an
    event with proper context.

    Returns a fully signed and hashed event.

    Note that this does not append any events to any graphs.

    Args:
        destinations: Candidate homeservers which are probably
            participating in the room.
        room_id: The room in which the event will happen.
        user_id: The user whose membership is being evented.
        membership: The "membership" property of the event. Must be one of
            "join", "leave" or "knock".
        content: Any additional data to put into the content field of the
            event. It is merged over the content proposed by the remote.
        params: Query parameters to include in the request.

    Returns:
        `(origin, event, room_version)` where origin is the remote
        homeserver which generated the event, and room_version is the
        version of the room.

    Raises:
        RuntimeError: if `membership` is not one of the values listed above.

        UnsupportedRoomVersionError: if remote responds with
            a room version we don't understand.

        SynapseError: if the chosen remote server returns a 300/400 code, or
            no servers successfully handle the request.
    """
    valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}

    if membership not in valid_memberships:
        # Sorted so the message doesn't depend on set iteration order.
        raise RuntimeError(
            "make_membership_event called with membership='%s', must be one of %s"
            % (membership, ",".join(sorted(valid_memberships)))
        )

    async def send_request(destination: str) -> Tuple[str, EventBase, RoomVersion]:
        ret = await self.transport_layer.make_membership_event(
            destination, room_id, user_id, membership, params
        )

        # Note: If not supplied, the room version may be either v1 or v2,
        # however either way the event format version will be v1.
        room_version_id = ret.get("room_version", RoomVersions.V1.identifier)
        room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
        if not room_version:
            raise UnsupportedRoomVersionError()

        if not room_version.msc2403_knocking and membership == Membership.KNOCK:
            # NOTE(review): _try_destination_list only re-raises
            # UnsupportedRoomVersionError and non-failover
            # HttpResponseExceptions directly. Unless SynapseError is one of
            # those, this error makes it try the next server rather than
            # propagating. Confirm this is intended.
            raise SynapseError(
                400,
                "This room version does not support knocking",
                errcode=Codes.FORBIDDEN,
            )

        pdu_dict = ret.get("event", None)
        if not isinstance(pdu_dict, dict):
            raise InvalidResponseError("Bad 'event' field in response")

        logger.debug("Got response to make_%s: %s", membership, pdu_dict)

        # Validate the proto-event's content the same way as the 'event'
        # field. A malformed response then fails over cleanly instead of
        # surfacing as an unexpected KeyError/AttributeError with a stack
        # trace.
        pdu_content = pdu_dict.get("content")
        if not isinstance(pdu_content, dict):
            raise InvalidResponseError("Bad 'content' field in response event")
        pdu_content.update(content)

        # The protoevent received over the JSON wire may not have all
        # the required fields. Lets just gloss over that because
        # there's some we never care about
        if "prev_state" not in pdu_dict:
            pdu_dict["prev_state"] = []

        ev = builder.create_local_event_from_event_dict(
            self._clock,
            self.hostname,
            self.signing_key,
            room_version=room_version,
            event_dict=pdu_dict,
        )

        return destination, ev, room_version

    # MSC3083 defines additional error codes for room joins. Unfortunately
    # we do not yet know the room version, assume these will only be returned
    # by valid room versions.
    failover_errcodes = (
        (Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN)
        if membership == Membership.JOIN
        else None
    )

    return await self._try_destination_list(
        "make_" + membership,
        destinations,
        send_request,
        failover_errcodes=failover_errcodes,
    )
|
|
|
|
async def send_join(
    self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
) -> SendJoinResult:
    """Sends a join event to one of a list of homeservers.

    Doing so will cause the remote server to add the event to the graph,
    and send the event out to the rest of the federation.

    The returned state and auth chain are checked before being accepted:
      - the state must contain a create event whose room version matches
        ``room_version``;
      - only events passing signature/hash checks are kept;
      - exactly that create event must appear in the verified auth chain.

    If a check fails, InvalidResponseError is raised inside the per-server
    request, and _try_destination_list moves on to the next destination.

    Args:
        destinations: Candidate homeservers which are probably
            participating in the room. For MSC3083 room versions where the
            join carries "join_authorised_via_users_server", this is replaced
            by that user's server only.
        pdu: event to be sent
        room_version: the version of the room (according to the server that
            did the make_join)

    Returns:
        The result of the send join request.

    Raises:
        SynapseError: if the chosen remote server returns a 300/400 code, or
            no servers successfully handle the request.
    """

    async def send_request(destination) -> SendJoinResult:
        response = await self._do_send_join(room_version, destination, pdu)

        # If an event was returned (and expected to be returned):
        #
        # * Ensure it has the same event ID (note that the event ID is a hash
        #   of the event fields for versions which support MSC3083).
        # * Ensure the signatures are good.
        #
        # Otherwise, fallback to the provided event.
        if room_version.msc3083_join_rules and response.event:
            event = response.event

            valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
                pdu=event,
                origin=destination,
                outlier=True,
                room_version=room_version,
            )

            # NOTE(review): the returned `event` is used below, not
            # `valid_pdu`. _check_sigs_and_hash_and_fetch_one documents that
            # it may return a redacted copy. Confirm that using the unchecked
            # object here is intended.
            if valid_pdu is None or event.event_id != pdu.event_id:
                raise InvalidResponseError("Returned an invalid join event")
        else:
            event = pdu

        state = response.state
        auth_chain = response.auth_events

        create_event = None
        for e in state:
            if (e.type, e.state_key) == (EventTypes.Create, ""):
                create_event = e
                break

        if create_event is None:
            # If the state doesn't have a create event then the room is
            # invalid, and it would fail auth checks anyway.
            raise InvalidResponseError("No create event in state")

        # the room version should be sane.
        create_room_version = create_event.content.get(
            "room_version", RoomVersions.V1.identifier
        )
        if create_room_version != room_version.identifier:
            # either the server that fulfilled the make_join, or the server that is
            # handling the send_join, is lying.
            raise InvalidResponseError(
                "Unexpected room version %s in create event"
                % (create_room_version,)
            )

        logger.info(
            "Processing from send_join %d events", len(state) + len(auth_chain)
        )

        # We now go and check the signatures and hashes for the event. Note
        # that we limit how many events we process at a time to keep the
        # memory overhead from exploding.
        # Maps event_id -> verified event, for both state and auth chain.
        valid_pdus_map: Dict[str, EventBase] = {}

        # Note: the `pdu` parameter here shadows the outer join event `pdu`
        # within this helper only.
        async def _execute(pdu: EventBase) -> None:
            valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
                pdu=pdu,
                origin=destination,
                outlier=True,
                room_version=room_version,
            )

            if valid_pdu:
                valid_pdus_map[valid_pdu.event_id] = valid_pdu

        await concurrently_execute(
            _execute, itertools.chain(state, auth_chain), 10000
        )

        # NB: We *need* to copy to ensure that we don't have multiple
        # references being passed on, as that causes... issues.
        # (Preserves the order of the response's state, dropping unverified
        # events.)
        signed_state = [
            copy.copy(valid_pdus_map[p.event_id])
            for p in state
            if p.event_id in valid_pdus_map
        ]

        signed_auth = [
            valid_pdus_map[p.event_id]
            for p in auth_chain
            if p.event_id in valid_pdus_map
        ]

        # NB: We *need* to copy to ensure that we don't have multiple
        # references being passed on, as that causes... issues.
        for s in signed_state:
            s.internal_metadata = copy.deepcopy(s.internal_metadata)

        # double-check that the same create event has ended up in the auth chain
        auth_chain_create_events = [
            e.event_id
            for e in signed_auth
            if (e.type, e.state_key) == (EventTypes.Create, "")
        ]
        if auth_chain_create_events != [create_event.event_id]:
            raise InvalidResponseError(
                "Unexpected create event(s) in auth chain: %s"
                % (auth_chain_create_events,)
            )

        return SendJoinResult(
            event=event,
            state=signed_state,
            auth_chain=signed_auth,
            origin=destination,
        )

    # MSC3083 defines additional error codes for room joins.
    failover_errcodes = None
    if room_version.msc3083_join_rules:
        failover_errcodes = (
            Codes.UNABLE_AUTHORISE_JOIN,
            Codes.UNABLE_TO_GRANT_JOIN,
        )

        # If the join is being authorised via allow rules, we need to send
        # the /send_join back to the same server that was originally used
        # with /make_join.
        if "join_authorised_via_users_server" in pdu.content:
            destinations = [
                get_domain_from_id(pdu.content["join_authorised_via_users_server"])
            ]

    return await self._try_destination_list(
        "send_join", destinations, send_request, failover_errcodes=failover_errcodes
    )
|
|
|
|
async def _do_send_join(
    self, room_version: RoomVersion, destination: str, pdu: EventBase
) -> SendJoinResponse:
    """Send a join event to one server, preferring the v2 send_join endpoint.

    Falls back to the v1 endpoint only when the v2 request fails with an
    error that _is_unknown_endpoint classifies as an unrecognised endpoint.
    Any other HTTP error is treated as legitimate and re-raised.

    Args:
        room_version: version of the room being joined.
        destination: the server to send the join to.
        pdu: the join event.

    Returns:
        The parsed send_join response from whichever endpoint answered.
    """
    now_ms = self._clock.time_msec()

    try:
        return await self.transport_layer.send_join_v2(
            room_version=room_version,
            destination=destination,
            room_id=pdu.room_id,
            event_id=pdu.event_id,
            content=pdu.get_pdu_json(now_ms),
        )
    except HttpResponseException as e:
        if not self._is_unknown_endpoint(e):
            raise

    logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")

    return await self.transport_layer.send_join_v1(
        room_version=room_version,
        destination=destination,
        room_id=pdu.room_id,
        event_id=pdu.event_id,
        content=pdu.get_pdu_json(now_ms),
    )
|
|
|
|
async def send_invite(
    self,
    destination: str,
    room_id: str,
    event_id: str,
    pdu: EventBase,
) -> EventBase:
    """Send an invite event to a remote server and return the event it sends back.

    Args:
        destination: the invitee's homeserver.
        room_id: the room the invite is for; used to look up the room version.
        event_id: not used by this method (NOTE(review): presumably kept
            for interface compatibility).
        pdu: the invite event to send.

    Returns:
        The invite event as returned by the remote server, after signature
        and hash checks.
    """
    room_version = await self.store.get_room_version(room_id)

    response = await self._do_send_invite(destination, pdu, room_version)

    returned_event_json = response["event"]
    logger.debug("Got response to send_invite: %s", returned_event_json)

    returned_event = event_from_pdu_json(returned_event_json, room_version)

    # Check signatures are correct.
    # FIXME: We should handle signature failures more gracefully.
    return await self._check_sigs_and_hash(room_version, returned_event)
|
|
|
|
async def _do_send_invite(
    self, destination: str, pdu: EventBase, room_version: RoomVersion
) -> JsonDict:
    """Send an invite to one server, trying the v2 API first and falling back to v1.

    The v1 fallback is used only when the v2 request fails with an
    unrecognised-endpoint error, and only for rooms with v1-format events.

    Returns:
        The event as a dict as returned by the remote server.

    Raises:
        SynapseError: if the remote server returns an error, or if it only
            supports the v1 endpoint and the room does not use v1-format
            events.
    """
    now_ms = self._clock.time_msec()

    try:
        return await self.transport_layer.send_invite_v2(
            destination=destination,
            room_id=pdu.room_id,
            event_id=pdu.event_id,
            content={
                "event": pdu.get_pdu_json(now_ms),
                "room_version": room_version.identifier,
                "invite_room_state": pdu.unsigned.get("invite_room_state", []),
            },
        )
    except HttpResponseException as e:
        err = e.to_synapse_error()
        # A genuine error from a server that does implement v2.
        if not self._is_unknown_endpoint(e, err):
            raise err
        # The v1 API can only carry old-style (v1 format) events.
        if room_version.event_format != EventFormatVersions.V1:
            raise SynapseError(
                400,
                "User's homeserver does not support this room version",
                Codes.UNSUPPORTED_ROOM_VERSION,
            )

    # Didn't work, try v1 API.
    # Note the v1 API returns a tuple of `(200, content)`
    _, content = await self.transport_layer.send_invite_v1(
        destination=destination,
        room_id=pdu.room_id,
        event_id=pdu.event_id,
        content=pdu.get_pdu_json(now_ms),
    )
    return content
|
|
|
|
async def send_leave(self, destinations: Iterable[str], pdu: EventBase) -> None:
    """Send a leave event through the first of ``destinations`` that accepts it.

    Doing so will cause the remote server to add the event to the graph,
    and send the event out to the rest of the federation. This is mostly
    useful to reject received invites.

    Args:
        destinations: candidate homeservers which are probably participating
            in the room, tried in order.
        pdu: the leave event to send.

    Raises:
        SynapseError: if the chosen remote server returns a 300/400 code, or
            no servers successfully handle the request.
    """

    async def _leave_via(server_name: str) -> None:
        response = await self._do_send_leave(server_name, pdu)
        logger.debug("Got content: %s", response)

    return await self._try_destination_list("send_leave", destinations, _leave_via)
|
|
|
|
async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
    """Send a leave event to one server, preferring the v2 send_leave endpoint.

    Falls back to the v1 endpoint when the v2 request fails with an error
    that _is_unknown_endpoint classifies as an unrecognised endpoint. Any
    other HTTP error is treated as legitimate and re-raised.

    Returns:
        The response body. The v1 endpoint replies with ``[200, content]``,
        of which only ``content`` is returned.
    """
    now_ms = self._clock.time_msec()

    try:
        return await self.transport_layer.send_leave_v2(
            destination=destination,
            room_id=pdu.room_id,
            event_id=pdu.event_id,
            content=pdu.get_pdu_json(now_ms),
        )
    except HttpResponseException as e:
        if not self._is_unknown_endpoint(e):
            raise

    logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")

    v1_response = await self.transport_layer.send_leave_v1(
        destination=destination,
        room_id=pdu.room_id,
        event_id=pdu.event_id,
        content=pdu.get_pdu_json(now_ms),
    )
    # Drop the status code from the [200, content] pair.
    return v1_response[1]
|
|
|
|
async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
    """Send a knock event to one of a list of candidate servers.

    The servers are tried one after another (via ``_try_destination_list``)
    until an attempt succeeds. The server that accepts the knock adds the
    event to the room graph and sends it on to the rest of the federation.

    Args:
        destinations: Candidate homeservers which are likely to be
            participating in the room.
        pdu: The knock event to send.

    Returns:
        Room state returned by the remote homeserver, shaped like:

            {"knock_state_events": [<state event dict>, ...]}

        The list of state events may be empty.

    Raises:
        SynapseError: If the chosen remote server returns a 3xx/4xx code.
            NOTE(review): an earlier version of this docstring said a
            RuntimeError is raised when no server is reachable. However,
            ``get_room_hierarchy`` treats a SynapseError with code 502 from
            ``_try_destination_list`` as that case. Confirm which applies.
    """

    async def _knock_via(server: str) -> JsonDict:
        return await self._do_send_knock(server, pdu)

    return await self._try_destination_list("send_knock", destinations, _knock_via)
|
|
|
|
async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict:
    """Deliver a knock event to a single remote homeserver over the v1 API.

    Args:
        destination: The homeserver to send to.
        pdu: The knock event to send.

    Returns:
        The remote homeserver's response. It may optionally carry room state,
        shaped like:

            {"knock_state_events": [<state event dict>, ...]}

        The list of state events may be empty.
    """
    now_ms = self._clock.time_msec()
    knock_json = pdu.get_pdu_json(now_ms)

    return await self.transport_layer.send_knock_v1(
        destination=destination,
        room_id=pdu.room_id,
        event_id=pdu.event_id,
        content=knock_json,
    )
|
|
|
|
async def get_public_rooms(
    self,
    remote_server: str,
    limit: Optional[int] = None,
    since_token: Optional[str] = None,
    search_filter: Optional[Dict] = None,
    include_all_networks: bool = False,
    third_party_instance_id: Optional[str] = None,
) -> JsonDict:
    """Fetch a remote homeserver's public room directory.

    This is a thin pass-through to ``self.transport_layer.get_public_rooms``.

    Args:
        remote_server: Name of the server to query.
        limit: Upper bound on the number of rooms returned.
        since_token: Pagination token from a previous response.
        search_filter: Filter dictionary forwarded to the remote homeserver
            so that it narrows the result set.
        include_all_networks: Whether to include rooms from every third
            party instance.
        third_party_instance_id: If given, restrict results to this specific
            third party instance.

    Returns:
        The remote server's response, unmodified.

    Raises:
        HttpResponseException / RequestSendFailed: the remote server returned
            an error, or the request could not be sent.
        M_FORBIDDEN error (presumably a SynapseError; the previous docstring
            said "SynapseException"): the remote server does not allow
            publicRooms requests over federation.
    """
    transport = self.transport_layer
    return await transport.get_public_rooms(
        remote_server,
        limit,
        since_token,
        search_filter,
        include_all_networks=include_all_networks,
        third_party_instance_id=third_party_instance_id,
    )
|
|
|
|
async def get_missing_events(
    self,
    destination: str,
    room_id: str,
    earliest_events_ids: Iterable[str],
    latest_events: Iterable[EventBase],
    limit: int,
    min_depth: int,
    timeout: int,
) -> List[EventBase]:
    """Ask a remote server for events we are missing.

    Used when we receive an event but do not yet have all of its ancestors.
    The returned PDUs are parsed for the room's version, then passed through
    ``_check_sigs_and_hash_and_fetch`` (as non-outliers) before being
    returned.

    Args:
        destination: The server to ask.
        room_id: The room the events belong to.
        earliest_events_ids: Event IDs we expected to receive but haven't;
            `get_missing_events` should only return events that did not
            happen before these.
        latest_events: Events we have received that we do not have all
            previous events for.
        limit: Maximum number of events to return.
        min_depth: Minimum depth of events to return.
        timeout: Max time to wait in ms.

    Returns:
        The checked events. The list is empty if the remote answered with an
        HTTP 400.
    """
    try:
        response = await self.transport_layer.get_missing_events(
            destination=destination,
            room_id=room_id,
            earliest_events=earliest_events_ids,
            latest_events=[ev.event_id for ev in latest_events],
            limit=limit,
            min_depth=min_depth,
            timeout=timeout,
        )

        room_version = await self.store.get_room_version(room_id)

        parsed_events = [
            event_from_pdu_json(pdu_json, room_version)
            for pdu_json in response.get("events", [])
        ]

        return await self._check_sigs_and_hash_and_fetch(
            destination, parsed_events, outlier=False, room_version=room_version
        )
    except HttpResponseException as err:
        if err.code != 400:
            raise

        # A 400 most likely means an old server that doesn't support
        # get_missing_events, so treat it as "nothing found".
        return []
|
|
|
|
async def forward_third_party_invite(
    self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
    """Forward a third-party invite exchange to the first server that accepts it.

    Our own server (``self.server_name``) is skipped. A CodeMessageException
    from a remote server aborts immediately. Any other exception is logged
    and the next destination is tried.

    Args:
        destinations: Candidate servers, tried in order.
        room_id: The room the invite is for.
        event_dict: The event payload to pass to the remote server.

    Raises:
        CodeMessageException: propagated from the first server that returns one.
        RuntimeError: if no destination accepted the request.
    """
    for server in destinations:
        if server == self.server_name:
            continue

        try:
            await self.transport_layer.exchange_third_party_invite(
                destination=server, room_id=room_id, event_dict=event_dict
            )
        except CodeMessageException:
            raise
        except Exception as err:
            logger.exception(
                "Failed to send_third_party_invite via %s: %s", server, str(err)
            )
        else:
            return

    raise RuntimeError("Failed to send to any server.")
|
|
|
|
async def get_room_complexity(
    self, destination: str, room_id: str
) -> Optional[JsonDict]:
    """
    Ask another server how complex one of its rooms is.

    Failures are never raised to the caller. A CodeMessageException is
    logged at debug level (probably a 404: other servers are not obliged to
    provide this). Any other exception is logged with its traceback.

    Args:
        destination: The remote server.
        room_id: The room ID to ask about.

    Returns:
        A dict of complexity metric versions, or None if the complexity
        could not be fetched.
    """
    try:
        complexity = await self.transport_layer.get_room_complexity(
            destination=destination, room_id=room_id
        )
    except CodeMessageException as err:
        logger.debug(
            "Failed to fetch room complexity via %s for %s, got a %d",
            destination,
            room_id,
            err.code,
        )
    except Exception:
        logger.exception(
            "Failed to fetch room complexity via %s for %s", destination, room_id
        )
    else:
        return complexity

    # Not getting an answer is not an error; report "unknown" instead.
    return None
|
|
|
|
async def get_space_summary(
    self,
    destinations: Iterable[str],
    room_id: str,
    suggested_only: bool,
    max_rooms_per_space: Optional[int],
    exclude_rooms: List[str],
) -> "FederationSpaceSummaryResult":
    """
    Ask other servers for a summary of the given space.

    Each response is parsed with ``FederationSpaceSummaryResult.from_json_dict``.
    A ValueError during parsing is re-raised as an InvalidResponseError.

    Args:
        destinations: The remote servers. We will try them in turn, omitting any
            that have been blacklisted.

        room_id: ID of the space to be queried.

        suggested_only: If true, ask the remote server to only return children
            with the "suggested" flag set.

        max_rooms_per_space: A limit on the number of children to return for each
            space.

        exclude_rooms: A list of room IDs to tell the remote server to skip.

    Returns:
        A parsed FederationSpaceSummaryResult.

    Raises:
        SynapseError if we were unable to get a valid summary from any of the
        remote servers.
    """

    async def _fetch_summary(server: str) -> FederationSpaceSummaryResult:
        raw_response = await self.transport_layer.get_space_summary(
            destination=server,
            room_id=room_id,
            suggested_only=suggested_only,
            max_rooms_per_space=max_rooms_per_space,
            exclude_rooms=exclude_rooms,
        )

        try:
            return FederationSpaceSummaryResult.from_json_dict(raw_response)
        except ValueError as err:
            raise InvalidResponseError(str(err))

    return await self._try_destination_list(
        "fetch space summary",
        destinations,
        _fetch_summary,
        failover_on_unknown_endpoint=True,
    )
|
|
|
|
async def get_room_hierarchy(
    self,
    destinations: Iterable[str],
    room_id: str,
    suggested_only: bool,
) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
    """
    Call other servers to get a hierarchy of the given room.

    Performs simple data validation and parsing of the response. If no server
    implements the hierarchy API (``_try_destination_list`` fails with a 502
    SynapseError), this falls back to the legacy space summary API and
    translates its result.

    Results are cached in ``self._get_room_hierarchy_cache``, keyed on
    ``(room_id, suggested_only)``. A cached result is returned without
    contacting any server.

    Args:
        destinations: The remote servers. We will try them in turn, omitting any
            that have been blacklisted.
        room_id: ID of the space to be queried
        suggested_only: If true, ask the remote server to only return children
            with the "suggested" flag set

    Returns:
        A tuple of:
            The room as a JSON dictionary.
            A list of children rooms, as JSON dictionaries.
            A list of inaccessible children room IDs.

    Raises:
        SynapseError if we were unable to get a valid summary from any of the
        remote servers
    """

    def _is_list(value: object) -> bool:
        # JSON arrays decode to lists. A bare ``Sequence`` check is not enough:
        # a JSON string is also a ``Sequence`` (of one-character strings), so
        # e.g. ``"inaccessible_children": "!abc:hs"`` would slip past the
        # per-item ``str`` check and be returned as if it were a list.
        return isinstance(value, Sequence) and not isinstance(value, str)

    cached_result = self._get_room_hierarchy_cache.get((room_id, suggested_only))
    if cached_result is not None:
        return cached_result

    async def send_request(
        destination: str,
    ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
        res = await self.transport_layer.get_room_hierarchy(
            destination=destination,
            room_id=room_id,
            suggested_only=suggested_only,
        )

        room = res.get("room")
        if not isinstance(room, dict):
            raise InvalidResponseError("'room' must be a dict")

        # Validate children_state of the room.
        children_state = room.get("children_state", [])
        if not _is_list(children_state):
            raise InvalidResponseError("'room.children_state' must be a list")
        if any(not isinstance(e, dict) for e in children_state):
            raise InvalidResponseError("Invalid event in 'children_state' list")
        try:
            # Parse each event for validation only. The parsed objects are
            # discarded, and the raw dicts stay in ``room``.
            for event_dict in children_state:
                FederationSpaceSummaryEventResult.from_json_dict(event_dict)
        except ValueError as e:
            raise InvalidResponseError(str(e))

        # Validate the children rooms.
        children = res.get("children", [])
        if not _is_list(children):
            raise InvalidResponseError("'children' must be a list")
        if any(not isinstance(r, dict) for r in children):
            raise InvalidResponseError("Invalid room in 'children' list")

        # Validate the inaccessible children.
        inaccessible_children = res.get("inaccessible_children", [])
        if not _is_list(inaccessible_children):
            raise InvalidResponseError("'inaccessible_children' must be a list")
        if any(not isinstance(r, str) for r in inaccessible_children):
            raise InvalidResponseError(
                "Invalid room ID in 'inaccessible_children' list"
            )

        return room, children, inaccessible_children

    try:
        result = await self._try_destination_list(
            "fetch room hierarchy",
            destinations,
            send_request,
            failover_on_unknown_endpoint=True,
        )
    except SynapseError as e:
        # If an unexpected error occurred, re-raise it.
        if e.code != 502:
            raise

        # Fallback to the old federation API and translate the results if
        # no servers implement the new API.
        #
        # The algorithm below is a bit inefficient as it only attempts to
        # parse information for the requested room, but the legacy API may
        # return additional layers.
        legacy_result = await self.get_space_summary(
            destinations,
            room_id,
            suggested_only,
            max_rooms_per_space=None,
            exclude_rooms=[],
        )

        # Find the requested room in the response (and remove it).
        for index, room in enumerate(legacy_result.rooms):
            if room.get("room_id") == room_id:
                break
        else:
            # The requested room was not returned, nothing we can do: re-raise
            # the original 502 SynapseError.
            raise
        requested_room = legacy_result.rooms.pop(index)

        # Find any children events of the requested room.
        children_events = []
        children_room_ids = set()
        for event in legacy_result.events:
            if event.room_id == room_id:
                children_events.append(event.data)
                children_room_ids.add(event.state_key)
        # And add them under the requested room.
        requested_room["children_state"] = children_events

        # Find the children rooms.
        children = [
            room
            for room in legacy_result.rooms
            if room.get("room_id") in children_room_ids
        ]

        # It isn't clear from the response whether some of the rooms are
        # not accessible.
        result = (requested_room, children, ())

    # Cache the result to avoid fetching data over federation every time.
    self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
    return result
|
|
|
|
|
|
@attr.s(frozen=True, slots=True, auto_attribs=True)
class FederationSpaceSummaryEventResult:
    """Represents a single event in the result of a successful get_space_summary call.

    It's essentially just a serialised event object, but we do a bit of parsing and
    validation in `from_json_dict` and store some of the validated properties in
    object attributes.
    """

    # the event's "type" field
    event_type: str
    # the event's "room_id" field
    room_id: str
    # the event's "state_key" field
    state_key: str
    # the "via" list from the event's content (presumably server names)
    via: Sequence[str]

    # the raw data, including the above keys
    data: JsonDict

    @classmethod
    def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult":
        """Parse an event within the result of a /spaces/ request

        Args:
            d: json object to be parsed

        Returns:
            A new instance holding the validated fields. ``d`` itself is kept
            unmodified as ``data``.

        Raises:
            ValueError if d is not a valid event
        """

        event_type = d.get("type")
        if not isinstance(event_type, str):
            raise ValueError("Invalid event: 'type' must be a str")

        room_id = d.get("room_id")
        if not isinstance(room_id, str):
            raise ValueError("Invalid event: 'room_id' must be a str")

        state_key = d.get("state_key")
        if not isinstance(state_key, str):
            raise ValueError("Invalid event: 'state_key' must be a str")

        content = d.get("content")
        if not isinstance(content, dict):
            raise ValueError("Invalid event: 'content' must be a dict")

        via = content.get("via")
        # A JSON string is itself a Sequence (of one-character strings), so it
        # must be rejected explicitly. Otherwise "example.com" would pass the
        # per-item check below and be treated as a list of single characters.
        if not isinstance(via, Sequence) or isinstance(via, str):
            raise ValueError("Invalid event: 'via' must be a list")
        if any(not isinstance(v, str) for v in via):
            raise ValueError("Invalid event: 'via' must be a list of strings")

        return cls(event_type, room_id, state_key, via, d)
|
|
|
|
|
|
@attr.s(frozen=True, slots=True, auto_attribs=True)
class FederationSpaceSummaryResult:
    """Represents the data returned by a successful get_space_summary call."""

    # the raw room dicts from the response's "rooms" list
    rooms: List[JsonDict]
    # the parsed entries of the response's "events" list
    events: Sequence[FederationSpaceSummaryEventResult]

    @classmethod
    def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult":
        """Parse the result of a /spaces/ request

        Args:
            d: json object to be parsed

        Returns:
            A new instance. The room dicts are kept as-is; each event is
            parsed with ``FederationSpaceSummaryEventResult.from_json_dict``.

        Raises:
            ValueError if d is not a valid /spaces/ response
        """
        rooms = d.get("rooms")
        if not isinstance(rooms, list):
            raise ValueError("'rooms' must be a list")
        if any(not isinstance(r, dict) for r in rooms):
            raise ValueError("Invalid room in 'rooms' list")

        events = d.get("events")
        # A JSON string is also a Sequence, so reject it explicitly. Otherwise
        # an empty string would be silently accepted as "no events".
        if not isinstance(events, Sequence) or isinstance(events, str):
            raise ValueError("'events' must be a list")
        if any(not isinstance(e, dict) for e in events):
            raise ValueError("Invalid event in 'events' list")
        parsed_events = [
            FederationSpaceSummaryEventResult.from_json_dict(e) for e in events
        ]

        return cls(rooms, parsed_events)
|