forked from MirrorHub/synapse
Merge branch 'develop' into babolivier/msc3026
This commit is contained in:
commit
592d6305fd
18 changed files with 574 additions and 165 deletions
1
changelog.d/9636.bugfix
Normal file
1
changelog.d/9636.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Checks if passwords are allowed before setting it for the user.
|
1
changelog.d/9640.misc
Normal file
1
changelog.d/9640.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Improve performance of federation catch up by sending events the latest events in the room to the remote, rather than just the last event sent by the local server.
|
1
changelog.d/9643.feature
Normal file
1
changelog.d/9643.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add initial experimental support for a "space summary" API.
|
1
changelog.d/9645.misc
Normal file
1
changelog.d/9645.misc
Normal file
|
@ -0,0 +1 @@
|
|||
In the `federation_client` commandline client, stop automatically adding the URL prefix, so that servlets on other prefixes can be tested.
|
1
changelog.d/9647.misc
Normal file
1
changelog.d/9647.misc
Normal file
|
@ -0,0 +1 @@
|
|||
In the `federation_client` commandline client, handle inline `signing_key`s in `homeserver.yaml`.
|
|
@ -22,8 +22,8 @@ import sys
|
|||
from typing import Any, Optional
|
||||
from urllib import parse as urlparse
|
||||
|
||||
import nacl.signing
|
||||
import requests
|
||||
import signedjson.key
|
||||
import signedjson.types
|
||||
import srvlookup
|
||||
import yaml
|
||||
|
@ -44,18 +44,6 @@ def encode_base64(input_bytes):
|
|||
return output_string
|
||||
|
||||
|
||||
def decode_base64(input_string):
|
||||
"""Decode a base64 string to bytes inferring padding from the length of the
|
||||
string."""
|
||||
|
||||
input_bytes = input_string.encode("ascii")
|
||||
input_len = len(input_bytes)
|
||||
padding = b"=" * (3 - ((input_len + 3) % 4))
|
||||
output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2
|
||||
output_bytes = base64.b64decode(input_bytes + padding)
|
||||
return output_bytes[:output_len]
|
||||
|
||||
|
||||
def encode_canonical_json(value):
|
||||
return json.dumps(
|
||||
value,
|
||||
|
@ -88,42 +76,6 @@ def sign_json(
|
|||
return json_object
|
||||
|
||||
|
||||
NACL_ED25519 = "ed25519"
|
||||
|
||||
|
||||
def decode_signing_key_base64(algorithm, version, key_base64):
|
||||
"""Decode a base64 encoded signing key
|
||||
Args:
|
||||
algorithm (str): The algorithm the key is for (currently "ed25519").
|
||||
version (str): Identifies this key out of the keys for this entity.
|
||||
key_base64 (str): Base64 encoded bytes of the key.
|
||||
Returns:
|
||||
A SigningKey object.
|
||||
"""
|
||||
if algorithm == NACL_ED25519:
|
||||
key_bytes = decode_base64(key_base64)
|
||||
key = nacl.signing.SigningKey(key_bytes)
|
||||
key.version = version
|
||||
key.alg = NACL_ED25519
|
||||
return key
|
||||
else:
|
||||
raise ValueError("Unsupported algorithm %s" % (algorithm,))
|
||||
|
||||
|
||||
def read_signing_keys(stream):
|
||||
"""Reads a list of keys from a stream
|
||||
Args:
|
||||
stream : A stream to iterate for keys.
|
||||
Returns:
|
||||
list of SigningKey objects.
|
||||
"""
|
||||
keys = []
|
||||
for line in stream:
|
||||
algorithm, version, key_base64 = line.split()
|
||||
keys.append(decode_signing_key_base64(algorithm, version, key_base64))
|
||||
return keys
|
||||
|
||||
|
||||
def request(
|
||||
method: Optional[str],
|
||||
origin_name: str,
|
||||
|
@ -223,23 +175,28 @@ def main():
|
|||
parser.add_argument("--body", help="Data to send as the body of the HTTP request")
|
||||
|
||||
parser.add_argument(
|
||||
"path", help="request path. We will add '/_matrix/federation/v1/' to this."
|
||||
"path", help="request path, including the '/_matrix/federation/...' prefix."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.server_name or not args.signing_key_path:
|
||||
args.signing_key = None
|
||||
if args.signing_key_path:
|
||||
with open(args.signing_key_path) as f:
|
||||
args.signing_key = f.readline()
|
||||
|
||||
if not args.server_name or not args.signing_key:
|
||||
read_args_from_config(args)
|
||||
|
||||
with open(args.signing_key_path) as f:
|
||||
key = read_signing_keys(f)[0]
|
||||
algorithm, version, key_base64 = args.signing_key.split()
|
||||
key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
|
||||
|
||||
result = request(
|
||||
args.method,
|
||||
args.server_name,
|
||||
key,
|
||||
args.destination,
|
||||
"/_matrix/federation/v1/" + args.path,
|
||||
args.path,
|
||||
content=args.body,
|
||||
)
|
||||
|
||||
|
@ -255,10 +212,16 @@ def main():
|
|||
def read_args_from_config(args):
|
||||
with open(args.config, "r") as fh:
|
||||
config = yaml.safe_load(fh)
|
||||
|
||||
if not args.server_name:
|
||||
args.server_name = config["server_name"]
|
||||
if not args.signing_key_path:
|
||||
args.signing_key_path = config["signing_key_path"]
|
||||
|
||||
if not args.signing_key:
|
||||
if "signing_key" in config:
|
||||
args.signing_key = config["signing_key"]
|
||||
else:
|
||||
with open(config["signing_key_path"]) as f:
|
||||
args.signing_key = f.readline()
|
||||
|
||||
|
||||
class MatrixConnectionAdapter(HTTPAdapter):
|
||||
|
|
|
@ -101,6 +101,9 @@ class EventTypes:
|
|||
|
||||
Dummy = "org.matrix.dummy_event"
|
||||
|
||||
MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child"
|
||||
MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent"
|
||||
|
||||
|
||||
class EduTypes:
|
||||
Presence = "m.presence"
|
||||
|
@ -161,6 +164,9 @@ class EventContentFields:
|
|||
# cf https://github.com/matrix-org/matrix-doc/pull/2228
|
||||
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
|
||||
|
||||
# cf https://github.com/matrix-org/matrix-doc/pull/1772
|
||||
MSC1772_ROOM_TYPE = "org.matrix.msc1772.type"
|
||||
|
||||
|
||||
class RoomEncryptionAlgorithms:
|
||||
MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
|
||||
|
|
|
@ -27,5 +27,7 @@ class ExperimentalConfig(Config):
|
|||
|
||||
# MSC2858 (multiple SSO identity providers)
|
||||
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
|
||||
# Spaces (MSC1772, MSC2946, etc)
|
||||
self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
|
||||
# MSC3026 (busy presence state)
|
||||
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
|
||||
|
|
|
@ -35,7 +35,7 @@ from twisted.internet import defer
|
|||
from twisted.internet.abstract import isIPAddress
|
||||
from twisted.python import failure
|
||||
|
||||
from synapse.api.constants import EduTypes, EventTypes, Membership
|
||||
from synapse.api.constants import EduTypes, EventTypes
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
|
@ -63,7 +63,7 @@ from synapse.replication.http.federation import (
|
|||
ReplicationFederationSendEduRestServlet,
|
||||
ReplicationGetQueryRestServlet,
|
||||
)
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
@ -727,27 +727,6 @@ class FederationServer(FederationBase):
|
|||
if the event was unacceptable for any other reason (eg, too large,
|
||||
too many prev_events, couldn't find the prev_events)
|
||||
"""
|
||||
# check that it's actually being sent from a valid destination to
|
||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
||||
if origin != get_domain_from_id(pdu.sender):
|
||||
# We continue to accept join events from any server; this is
|
||||
# necessary for the federation join dance to work correctly.
|
||||
# (When we join over federation, the "helper" server is
|
||||
# responsible for sending out the join event, rather than the
|
||||
# origin. See bug #1893. This is also true for some third party
|
||||
# invites).
|
||||
if not (
|
||||
pdu.type == "m.room.member"
|
||||
and pdu.content
|
||||
and pdu.content.get("membership", None)
|
||||
in (Membership.JOIN, Membership.INVITE)
|
||||
):
|
||||
logger.info(
|
||||
"Discarding PDU %s from invalid origin %s", pdu.event_id, origin
|
||||
)
|
||||
return
|
||||
else:
|
||||
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
|
||||
|
||||
# We've already checked that we know the room version by this point
|
||||
room_version = await self.store.get_room_version(pdu.room_id)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# limitations under the License.
|
||||
import datetime
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
|
||||
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
@ -77,6 +77,7 @@ class PerDestinationQueue:
|
|||
self._transaction_manager = transaction_manager
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self._federation_shard_config = hs.config.worker.federation_shard_config
|
||||
self._state = hs.get_state_handler()
|
||||
|
||||
self._should_send_on_this_instance = True
|
||||
if not self._federation_shard_config.should_handle(
|
||||
|
@ -415,22 +416,95 @@ class PerDestinationQueue:
|
|||
"This should not happen." % event_ids
|
||||
)
|
||||
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
rooms = [p.room_id for p in catchup_pdus]
|
||||
logger.info("Catching up rooms to %s: %r", self._destination, rooms)
|
||||
# We send transactions with events from one room only, as its likely
|
||||
# that the remote will have to do additional processing, which may
|
||||
# take some time. It's better to give it small amounts of work
|
||||
# rather than risk the request timing out and repeatedly being
|
||||
# retried, and not making any progress.
|
||||
#
|
||||
# Note: `catchup_pdus` will have exactly one PDU per room.
|
||||
for pdu in catchup_pdus:
|
||||
# The PDU from the DB will be the last PDU in the room from
|
||||
# *this server* that wasn't sent to the remote. However, other
|
||||
# servers may have sent lots of events since then, and we want
|
||||
# to try and tell the remote only about the *latest* events in
|
||||
# the room. This is so that it doesn't get inundated by events
|
||||
# from various parts of the DAG, which all need to be processed.
|
||||
#
|
||||
# Note: this does mean that in large rooms a server coming back
|
||||
# online will get sent the same events from all the different
|
||||
# servers, but the remote will correctly deduplicate them and
|
||||
# handle it only once.
|
||||
|
||||
await self._transaction_manager.send_new_transaction(
|
||||
self._destination, catchup_pdus, []
|
||||
)
|
||||
# Step 1, fetch the current extremities
|
||||
extrems = await self._store.get_prev_events_for_room(pdu.room_id)
|
||||
|
||||
sent_transactions_counter.inc()
|
||||
final_pdu = catchup_pdus[-1]
|
||||
self._last_successful_stream_ordering = cast(
|
||||
int, final_pdu.internal_metadata.stream_ordering
|
||||
)
|
||||
await self._store.set_destination_last_successful_stream_ordering(
|
||||
self._destination, self._last_successful_stream_ordering
|
||||
)
|
||||
if pdu.event_id in extrems:
|
||||
# If the event is in the extremities, then great! We can just
|
||||
# use that without having to do further checks.
|
||||
room_catchup_pdus = [pdu]
|
||||
else:
|
||||
# If not, fetch the extremities and figure out which we can
|
||||
# send.
|
||||
extrem_events = await self._store.get_events_as_list(extrems)
|
||||
|
||||
new_pdus = []
|
||||
for p in extrem_events:
|
||||
# We pulled this from the DB, so it'll be non-null
|
||||
assert p.internal_metadata.stream_ordering
|
||||
|
||||
# Filter out events that happened before the remote went
|
||||
# offline
|
||||
if (
|
||||
p.internal_metadata.stream_ordering
|
||||
< self._last_successful_stream_ordering
|
||||
):
|
||||
continue
|
||||
|
||||
# Filter out events where the server is not in the room,
|
||||
# e.g. it may have left/been kicked. *Ideally* we'd pull
|
||||
# out the kick and send that, but it's a rare edge case
|
||||
# so we don't bother for now (the server that sent the
|
||||
# kick should send it out if its online).
|
||||
hosts = await self._state.get_hosts_in_room_at_events(
|
||||
p.room_id, [p.event_id]
|
||||
)
|
||||
if self._destination not in hosts:
|
||||
continue
|
||||
|
||||
new_pdus.append(p)
|
||||
|
||||
# If we've filtered out all the extremities, fall back to
|
||||
# sending the original event. This should ensure that the
|
||||
# server gets at least some of missed events (especially if
|
||||
# the other sending servers are up).
|
||||
if new_pdus:
|
||||
room_catchup_pdus = new_pdus
|
||||
|
||||
logger.info(
|
||||
"Catching up rooms to %s: %r", self._destination, pdu.room_id
|
||||
)
|
||||
|
||||
await self._transaction_manager.send_new_transaction(
|
||||
self._destination, room_catchup_pdus, []
|
||||
)
|
||||
|
||||
sent_transactions_counter.inc()
|
||||
|
||||
# We pulled this from the DB, so it'll be non-null
|
||||
assert pdu.internal_metadata.stream_ordering
|
||||
|
||||
# Note that we mark the last successful stream ordering as that
|
||||
# from the *original* PDU, rather than the PDU(s) we actually
|
||||
# send. This is because we use it to mark our position in the
|
||||
# queue of missed PDUs to process.
|
||||
self._last_successful_stream_ordering = (
|
||||
pdu.internal_metadata.stream_ordering
|
||||
)
|
||||
|
||||
await self._store.set_destination_last_successful_stream_ordering(
|
||||
self._destination, self._last_successful_stream_ordering
|
||||
)
|
||||
|
||||
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
|
||||
if not self._pending_rrs:
|
||||
|
|
|
@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler):
|
|||
logout_devices: bool,
|
||||
requester: Optional[Requester] = None,
|
||||
) -> None:
|
||||
if not self.hs.config.password_localdb_enabled:
|
||||
if not self._auth_handler.can_change_password():
|
||||
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
||||
|
||||
try:
|
||||
|
|
199
synapse/handlers/space_summary.py
Normal file
199
synapse/handlers/space_summary.py
Normal file
|
@ -0,0 +1,199 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Set
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import format_event_for_client_v2
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# number of rooms to return. We'll stop once we hit this limit.
|
||||
# TODO: allow clients to reduce this with a request param.
|
||||
MAX_ROOMS = 50
|
||||
|
||||
# max number of events to return per room.
|
||||
MAX_ROOMS_PER_SPACE = 50
|
||||
|
||||
|
||||
class SpaceSummaryHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._clock = hs.get_clock()
|
||||
self._auth = hs.get_auth()
|
||||
self._room_list_handler = hs.get_room_list_handler()
|
||||
self._state_handler = hs.get_state_handler()
|
||||
self._store = hs.get_datastore()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
async def get_space_summary(
|
||||
self,
|
||||
requester: str,
|
||||
room_id: str,
|
||||
suggested_only: bool = False,
|
||||
max_rooms_per_space: Optional[int] = None,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Implementation of the space summary API
|
||||
|
||||
Args:
|
||||
requester: user id of the user making this request
|
||||
|
||||
room_id: room id to start the summary at
|
||||
|
||||
suggested_only: whether we should only return children with the "suggested"
|
||||
flag set.
|
||||
|
||||
max_rooms_per_space: an optional limit on the number of child rooms we will
|
||||
return. This does not apply to the root room (ie, room_id), and
|
||||
is overridden by ROOMS_PER_SPACE_LIMIT.
|
||||
|
||||
Returns:
|
||||
summary dict to return
|
||||
"""
|
||||
# first of all, check that the user is in the room in question (or it's
|
||||
# world-readable)
|
||||
await self._auth.check_user_in_room_or_world_readable(room_id, requester)
|
||||
|
||||
# the queue of rooms to process
|
||||
room_queue = deque((room_id,))
|
||||
|
||||
processed_rooms = set() # type: Set[str]
|
||||
|
||||
rooms_result = [] # type: List[JsonDict]
|
||||
events_result = [] # type: List[JsonDict]
|
||||
|
||||
now = self._clock.time_msec()
|
||||
|
||||
while room_queue and len(rooms_result) < MAX_ROOMS:
|
||||
room_id = room_queue.popleft()
|
||||
logger.debug("Processing room %s", room_id)
|
||||
processed_rooms.add(room_id)
|
||||
|
||||
try:
|
||||
await self._auth.check_user_in_room_or_world_readable(
|
||||
room_id, requester
|
||||
)
|
||||
except AuthError:
|
||||
logger.info(
|
||||
"user %s cannot view room %s, omitting from summary",
|
||||
requester,
|
||||
room_id,
|
||||
)
|
||||
continue
|
||||
|
||||
room_entry = await self._build_room_entry(room_id)
|
||||
rooms_result.append(room_entry)
|
||||
|
||||
# look for child rooms/spaces.
|
||||
child_events = await self._get_child_events(room_id)
|
||||
|
||||
if suggested_only:
|
||||
# we only care about suggested children
|
||||
child_events = filter(_is_suggested_child_event, child_events)
|
||||
|
||||
# The client-specified max_rooms_per_space limit doesn't apply to the
|
||||
# room_id specified in the request, so we ignore it if this is the
|
||||
# first room we are processing. Otherwise, apply any client-specified
|
||||
# limit, capping to our built-in limit.
|
||||
if max_rooms_per_space is not None and len(processed_rooms) > 1:
|
||||
max_rooms = min(MAX_ROOMS_PER_SPACE, max_rooms_per_space)
|
||||
else:
|
||||
max_rooms = MAX_ROOMS_PER_SPACE
|
||||
|
||||
for edge_event in itertools.islice(child_events, max_rooms):
|
||||
edge_room_id = edge_event.state_key
|
||||
|
||||
events_result.append(
|
||||
await self._event_serializer.serialize_event(
|
||||
edge_event,
|
||||
time_now=now,
|
||||
event_format=format_event_for_client_v2,
|
||||
)
|
||||
)
|
||||
|
||||
# if we haven't yet visited the target of this link, add it to the queue
|
||||
if edge_room_id not in processed_rooms:
|
||||
room_queue.append(edge_room_id)
|
||||
|
||||
return {"rooms": rooms_result, "events": events_result}
|
||||
|
||||
async def _build_room_entry(self, room_id: str) -> JsonDict:
|
||||
"""Generate en entry suitable for the 'rooms' list in the summary response"""
|
||||
stats = await self._store.get_room_with_stats(room_id)
|
||||
|
||||
# currently this should be impossible because we call
|
||||
# check_user_in_room_or_world_readable on the room before we get here, so
|
||||
# there should always be an entry
|
||||
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
|
||||
|
||||
current_state_ids = await self._store.get_current_state_ids(room_id)
|
||||
create_event = await self._store.get_event(
|
||||
current_state_ids[(EventTypes.Create, "")]
|
||||
)
|
||||
|
||||
# TODO: update once MSC1772 lands
|
||||
room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
|
||||
|
||||
entry = {
|
||||
"room_id": stats["room_id"],
|
||||
"name": stats["name"],
|
||||
"topic": stats["topic"],
|
||||
"canonical_alias": stats["canonical_alias"],
|
||||
"num_joined_members": stats["joined_members"],
|
||||
"avatar_url": stats["avatar"],
|
||||
"world_readable": (
|
||||
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
|
||||
),
|
||||
"guest_can_join": stats["guest_access"] == "can_join",
|
||||
"room_type": room_type,
|
||||
}
|
||||
|
||||
# Filter out Nones – rather omit the field altogether
|
||||
room_entry = {k: v for k, v in entry.items() if v is not None}
|
||||
|
||||
return room_entry
|
||||
|
||||
async def _get_child_events(self, room_id: str) -> Iterable[EventBase]:
|
||||
# look for child rooms/spaces.
|
||||
current_state_ids = await self._store.get_current_state_ids(room_id)
|
||||
|
||||
events = await self._store.get_events_as_list(
|
||||
[
|
||||
event_id
|
||||
for key, event_id in current_state_ids.items()
|
||||
# TODO: update once MSC1772 lands
|
||||
if key[0] == EventTypes.MSC1772_SPACE_CHILD
|
||||
]
|
||||
)
|
||||
|
||||
# filter out any events without a "via" (which implies it has been redacted)
|
||||
return (e for e in events if e.content.get("via"))
|
||||
|
||||
|
||||
def _is_suggested_child_event(edge_event: EventBase) -> bool:
|
||||
suggested = edge_event.content.get("suggested")
|
||||
if isinstance(suggested, bool) and suggested:
|
||||
return True
|
||||
logger.debug("Ignorning not-suggested child %s", edge_event.state_key)
|
||||
return False
|
|
@ -271,7 +271,7 @@ class UserRestServletV2(RestServlet):
|
|||
elif not deactivate and user["deactivated"]:
|
||||
if (
|
||||
"password" not in body
|
||||
and self.hs.config.password_localdb_enabled
|
||||
and self.auth_handler.can_change_password()
|
||||
):
|
||||
raise SynapseError(
|
||||
400, "Must provide a password to re-activate an account."
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from urllib import parse as urlparse
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
@ -35,16 +35,25 @@ from synapse.events.utils import format_event_for_client_v2
|
|||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_boolean,
|
||||
parse_integer,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import set_tag
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
StreamToken,
|
||||
ThirdPartyInstanceID,
|
||||
UserID,
|
||||
)
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.stringutils import parse_and_validate_server_name, random_string
|
||||
|
||||
|
@ -987,7 +996,58 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
|
|||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server, is_worker=False):
|
||||
class RoomSpaceSummaryRestServlet(RestServlet):
|
||||
PATTERNS = (
|
||||
re.compile(
|
||||
"^/_matrix/client/unstable/org.matrix.msc2946"
|
||||
"/rooms/(?P<room_id>[^/]*)/spaces$"
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self._auth = hs.get_auth()
|
||||
self._space_summary_handler = hs.get_space_summary_handler()
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
return 200, await self._space_summary_handler.get_space_summary(
|
||||
requester.user.to_string(),
|
||||
room_id,
|
||||
suggested_only=parse_boolean(request, "suggested_only", default=False),
|
||||
max_rooms_per_space=parse_integer(request, "max_rooms_per_space"),
|
||||
)
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request, allow_guest=True)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
suggested_only = content.get("suggested_only", False)
|
||||
if not isinstance(suggested_only, bool):
|
||||
raise SynapseError(
|
||||
400, "'suggested_only' must be a boolean", Codes.BAD_JSON
|
||||
)
|
||||
|
||||
max_rooms_per_space = content.get("max_rooms_per_space")
|
||||
if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int):
|
||||
raise SynapseError(
|
||||
400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON
|
||||
)
|
||||
|
||||
return 200, await self._space_summary_handler.get_space_summary(
|
||||
requester.user.to_string(),
|
||||
room_id,
|
||||
suggested_only=suggested_only,
|
||||
max_rooms_per_space=max_rooms_per_space,
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server, is_worker=False):
|
||||
RoomStateEventRestServlet(hs).register(http_server)
|
||||
RoomMemberListRestServlet(hs).register(http_server)
|
||||
JoinedRoomMemberListRestServlet(hs).register(http_server)
|
||||
|
@ -1001,6 +1061,9 @@ def register_servlets(hs, http_server, is_worker=False):
|
|||
RoomTypingRestServlet(hs).register(http_server)
|
||||
RoomEventContextServlet(hs).register(http_server)
|
||||
|
||||
if hs.config.experimental.spaces_enabled:
|
||||
RoomSpaceSummaryRestServlet(hs).register(http_server)
|
||||
|
||||
# Some servlets only get registered for the main process.
|
||||
if not is_worker:
|
||||
RoomCreateRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -100,6 +100,7 @@ from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHand
|
|||
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
||||
from synapse.handlers.search import SearchHandler
|
||||
from synapse.handlers.set_password import SetPasswordHandler
|
||||
from synapse.handlers.space_summary import SpaceSummaryHandler
|
||||
from synapse.handlers.sso import SsoHandler
|
||||
from synapse.handlers.stats import StatsHandler
|
||||
from synapse.handlers.sync import SyncHandler
|
||||
|
@ -732,6 +733,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
def get_account_data_handler(self) -> AccountDataHandler:
|
||||
return AccountDataHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_space_summary_handler(self) -> SpaceSummaryHandler:
|
||||
return SpaceSummaryHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_external_cache(self) -> ExternalCache:
|
||||
return ExternalCache(self)
|
||||
|
|
|
@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_deactivated_status, (user_id,)
|
||||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
@cached()
|
||||
|
|
|
@ -2,6 +2,7 @@ from typing import List, Tuple
|
|||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.federation.sender import PerDestinationQueue, TransactionManager
|
||||
from synapse.federation.units import Edu
|
||||
|
@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
|||
self.assertNotIn("zzzerver", woken)
|
||||
# - all destinations are woken exactly once; they appear once in woken.
|
||||
self.assertCountEqual(woken, server_names[:-1])
|
||||
|
||||
@override_config({"send_federation": True})
|
||||
def test_not_latest_event(self):
|
||||
"""Test that we send the latest event in the room even if its not ours."""
|
||||
|
||||
per_dest_queue, sent_pdus = self.make_fake_destination_queue()
|
||||
|
||||
# Make a room with a local user, and two servers. One will go offline
|
||||
# and one will send some events.
|
||||
self.register_user("u1", "you the one")
|
||||
u1_token = self.login("u1", "you the one")
|
||||
room_1 = self.helper.create_room_as("u1", tok=u1_token)
|
||||
|
||||
self.get_success(
|
||||
event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
|
||||
)
|
||||
event_1 = self.get_success(
|
||||
event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
|
||||
)
|
||||
|
||||
# First we send something from the local server, so that we notice the
|
||||
# remote is down and go into catchup mode.
|
||||
self.helper.send(room_1, "you hear me!!", tok=u1_token)
|
||||
|
||||
# Now simulate us receiving an event from the still online remote.
|
||||
event_2 = self.get_success(
|
||||
event_injection.inject_event(
|
||||
self.hs,
|
||||
type=EventTypes.Message,
|
||||
sender="@user:host3",
|
||||
room_id=room_1,
|
||||
content={"msgtype": "m.text", "body": "Hello"},
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_datastore().set_destination_last_successful_stream_ordering(
|
||||
"host2", event_1.internal_metadata.stream_ordering
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(per_dest_queue._catch_up_transmission_loop())
|
||||
|
||||
# We expect only the last message from the remote, event_2, to have been
|
||||
# sent, rather than the last *local* event that was sent.
|
||||
self.assertEqual(len(sent_pdus), 1)
|
||||
self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
|
||||
self.assertFalse(per_dest_queue._catching_up)
|
||||
|
|
|
@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
# create users and get access tokens
|
||||
# regardless of whether password login or SSO is allowed
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
self.admin_user_tok = self.login("admin", "pass")
|
||||
self.admin_user_tok = self.get_success(
|
||||
self.auth_handler.get_access_token_for_user_id(
|
||||
self.admin_user, device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
|
||||
self.other_user = self.register_user("user", "pass", displayname="User")
|
||||
self.other_user_token = self.login("user", "pass")
|
||||
self.other_user_token = self.get_success(
|
||||
self.auth_handler.get_access_token_for_user_id(
|
||||
self.other_user, device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
|
||||
self.other_user
|
||||
)
|
||||
|
@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
self.assertTrue(channel.json_body["admin"])
|
||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||
|
||||
# Get user
|
||||
|
@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
self.assertEqual(False, channel.json_body["is_guest"])
|
||||
self.assertEqual(False, channel.json_body["deactivated"])
|
||||
self.assertTrue(channel.json_body["admin"])
|
||||
self.assertFalse(channel.json_body["is_guest"])
|
||||
self.assertFalse(channel.json_body["deactivated"])
|
||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||
|
||||
def test_create_user(self):
|
||||
|
@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(False, channel.json_body["admin"])
|
||||
self.assertFalse(channel.json_body["admin"])
|
||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||
|
||||
# Get user
|
||||
|
@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(False, channel.json_body["admin"])
|
||||
self.assertEqual(False, channel.json_body["is_guest"])
|
||||
self.assertEqual(False, channel.json_body["deactivated"])
|
||||
self.assertEqual(False, channel.json_body["shadow_banned"])
|
||||
self.assertFalse(channel.json_body["admin"])
|
||||
self.assertFalse(channel.json_body["is_guest"])
|
||||
self.assertFalse(channel.json_body["deactivated"])
|
||||
self.assertFalse(channel.json_body["shadow_banned"])
|
||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||
|
||||
@override_config(
|
||||
|
@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual(False, channel.json_body["admin"])
|
||||
self.assertFalse(channel.json_body["admin"])
|
||||
|
||||
@override_config(
|
||||
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
|
||||
|
@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
# Admin user is not blocked by mau anymore
|
||||
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual(False, channel.json_body["admin"])
|
||||
self.assertFalse(channel.json_body["admin"])
|
||||
|
||||
@override_config(
|
||||
{
|
||||
|
@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(False, channel.json_body["deactivated"])
|
||||
self.assertFalse(channel.json_body["deactivated"])
|
||||
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
|
||||
self.assertEqual("User", channel.json_body["displayname"])
|
||||
|
||||
# Deactivate user
|
||||
body = json.dumps({"deactivated": True})
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
content={"deactivated": True},
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["deactivated"])
|
||||
self.assertTrue(channel.json_body["deactivated"])
|
||||
self.assertIsNone(channel.json_body["password_hash"])
|
||||
self.assertEqual(0, len(channel.json_body["threepids"]))
|
||||
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
|
||||
self.assertEqual("User", channel.json_body["displayname"])
|
||||
|
@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["deactivated"])
|
||||
self.assertTrue(channel.json_body["deactivated"])
|
||||
self.assertIsNone(channel.json_body["password_hash"])
|
||||
self.assertEqual(0, len(channel.json_body["threepids"]))
|
||||
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
|
||||
self.assertEqual("User", channel.json_body["displayname"])
|
||||
|
@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
self.assertTrue(profile["display_name"] == "User")
|
||||
|
||||
# Deactivate user
|
||||
body = json.dumps({"deactivated": True})
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
content={"deactivated": True},
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["deactivated"])
|
||||
self.assertTrue(channel.json_body["deactivated"])
|
||||
|
||||
# is not in user directory
|
||||
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
|
||||
self.assertTrue(profile is None)
|
||||
self.assertIsNone(profile)
|
||||
|
||||
# Set new displayname user
|
||||
body = json.dumps({"displayname": "Foobar"})
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
content={"displayname": "Foobar"},
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["deactivated"])
|
||||
self.assertTrue(channel.json_body["deactivated"])
|
||||
self.assertEqual("Foobar", channel.json_body["displayname"])
|
||||
|
||||
# is not in user directory
|
||||
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
|
||||
self.assertTrue(profile is None)
|
||||
self.assertIsNone(profile)
|
||||
|
||||
def test_reactivate_user(self):
|
||||
"""
|
||||
|
@ -1520,24 +1527,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
"""
|
||||
|
||||
# Deactivate the user.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self._is_erased("@user:test", False)
|
||||
d = self.store.mark_user_erased("@user:test")
|
||||
self.assertIsNone(self.get_success(d))
|
||||
self._is_erased("@user:test", True)
|
||||
self._deactivate_user("@user:test")
|
||||
|
||||
# Attempt to reactivate the user (without a password).
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
|
||||
content={"deactivated": False},
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
||||
|
@ -1546,22 +1543,76 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=json.dumps({"deactivated": False, "password": "foo"}).encode(
|
||||
encoding="utf_8"
|
||||
),
|
||||
content={"deactivated": False, "password": "foo"},
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
||||
# Get user
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(False, channel.json_body["deactivated"])
|
||||
self.assertFalse(channel.json_body["deactivated"])
|
||||
self.assertIsNotNone(channel.json_body["password_hash"])
|
||||
self._is_erased("@user:test", False)
|
||||
|
||||
@override_config({"password_config": {"localdb_enabled": False}})
|
||||
def test_reactivate_user_localdb_disabled(self):
|
||||
"""
|
||||
Test reactivating another user when using SSO.
|
||||
"""
|
||||
|
||||
# Deactivate the user.
|
||||
self._deactivate_user("@user:test")
|
||||
|
||||
# Reactivate the user with a password
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content={"deactivated": False, "password": "foo"},
|
||||
)
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
# Reactivate the user without a password.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content={"deactivated": False},
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertFalse(channel.json_body["deactivated"])
|
||||
self.assertIsNone(channel.json_body["password_hash"])
|
||||
self._is_erased("@user:test", False)
|
||||
|
||||
@override_config({"password_config": {"enabled": False}})
|
||||
def test_reactivate_user_password_disabled(self):
|
||||
"""
|
||||
Test reactivating another user when using SSO.
|
||||
"""
|
||||
|
||||
# Deactivate the user.
|
||||
self._deactivate_user("@user:test")
|
||||
|
||||
# Reactivate the user with a password
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content={"deactivated": False, "password": "foo"},
|
||||
)
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
# Reactivate the user without a password.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content={"deactivated": False},
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertFalse(channel.json_body["deactivated"])
|
||||
self.assertIsNone(channel.json_body["password_hash"])
|
||||
self._is_erased("@user:test", False)
|
||||
|
||||
def test_set_user_as_admin(self):
|
||||
|
@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
"""
|
||||
|
||||
# Set a user as an admin
|
||||
body = json.dumps({"admin": True})
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
content={"admin": True},
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
self.assertTrue(channel.json_body["admin"])
|
||||
|
||||
# Get user
|
||||
channel = self.make_request(
|
||||
|
@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
self.assertTrue(channel.json_body["admin"])
|
||||
|
||||
def test_accidental_deactivation_prevention(self):
|
||||
"""
|
||||
|
@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
url = "/_synapse/admin/v2/users/@bob:test"
|
||||
|
||||
# Create user
|
||||
body = json.dumps({"password": "abc123"})
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
content={"password": "abc123"},
|
||||
)
|
||||
|
||||
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(0, channel.json_body["deactivated"])
|
||||
|
||||
# Change password (and use a str for deactivate instead of a bool)
|
||||
body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
content={"password": "abc123", "deactivated": "false"},
|
||||
)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
# Ensure they're still alive
|
||||
self.assertEqual(0, channel.json_body["deactivated"])
|
||||
|
||||
def _is_erased(self, user_id, expect):
|
||||
def _is_erased(self, user_id: str, expect: bool) -> None:
|
||||
"""Assert that the user is erased or not"""
|
||||
d = self.store.is_user_erased(user_id)
|
||||
if expect:
|
||||
|
@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
else:
|
||||
self.assertFalse(self.get_success(d))
|
||||
|
||||
def _deactivate_user(self, user_id: str) -> None:
|
||||
"""Deactivate user and set as erased"""
|
||||
|
||||
# Deactivate the user.
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
"/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
|
||||
access_token=self.admin_user_tok,
|
||||
content={"deactivated": True},
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertTrue(channel.json_body["deactivated"])
|
||||
self.assertIsNone(channel.json_body["password_hash"])
|
||||
self._is_erased(user_id, False)
|
||||
d = self.store.mark_user_erased(user_id)
|
||||
self.assertIsNone(self.get_success(d))
|
||||
self._is_erased(user_id, True)
|
||||
|
||||
|
||||
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
|
|
Loading…
Reference in a new issue