Merge branch 'develop' into matrix-org-hotfixes

This commit is contained in:
Brendan Abolivier 2020-03-09 15:06:56 +00:00
commit 74050d0c1c
83 changed files with 1912 additions and 491 deletions

View file

@ -6,12 +6,7 @@
set -ex
apt-get update
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev
# workaround for https://github.com/jaraco/zipp/issues/40
python3.5 -m pip install 'setuptools>=34.4.0'
python3.5 -m pip install tox
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
export LANG="C.UTF-8"

View file

@ -1,3 +1,18 @@
Synapse 1.11.1 (2020-03-03)
===========================
This release includes a security fix impacting installations using Single Sign-On (i.e. SAML2 or CAS) for authentication. Administrators of such installations are encouraged to upgrade as soon as possible.
The release also includes fixes for a couple of other bugs.
Bugfixes
--------
- Add a confirmation step to the SSO login flow before redirecting users to the redirect URL. ([b2bd54a2](https://github.com/matrix-org/synapse/commit/b2bd54a2e31d9a248f73fadb184ae9b4cbdb49f9), [65c73cdf](https://github.com/matrix-org/synapse/commit/65c73cdfec1876a9fec2fd2c3a74923cd146fe0b), [a0178df1](https://github.com/matrix-org/synapse/commit/a0178df10422a76fd403b82d2b2a4ed28a9a9d1e))
- Fixed set a user as an admin with the admin API `PUT /_synapse/admin/v2/users/<user_id>`. Contributed by @dklimpel. ([\#6910](https://github.com/matrix-org/synapse/issues/6910))
- Fix bug introduced in Synapse 1.11.0 which sometimes caused errors when joining rooms over federation, with `'coroutine' object has no attribute 'event_id'`. ([\#6996](https://github.com/matrix-org/synapse/issues/6996))
Synapse 1.11.0 (2020-02-21)
===========================

View file

@ -418,7 +418,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
for having Synapse automatically provision and renew federation
certificates through ACME can be found at [ACME.md](docs/ACME.md).
Note that, as pointed out in that document, this feature will not
work with installs set up after November 2020.
work with installs set up after November 2019.
If you are using your own certificate, be sure to use a `.pem` file that
includes the full certificate chain including any intermediate certificates

1
changelog.d/6309.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `logging/context.py`.

1
changelog.d/6315.feature Normal file
View file

@ -0,0 +1 @@
Expose the `synctl`, `hash_password` and `generate_config` commands in the snapcraft package. Contributed by @devec0.

1
changelog.d/6874.misc Normal file
View file

@ -0,0 +1 @@
Refactoring work in preparation for changing the event redaction algorithm.

1
changelog.d/6875.misc Normal file
View file

@ -0,0 +1 @@
Refactoring work in preparation for changing the event redaction algorithm.

1
changelog.d/6971.feature Normal file
View file

@ -0,0 +1 @@
Validate the alt_aliases property of canonical alias events.

1
changelog.d/6986.feature Normal file
View file

@ -0,0 +1 @@
Users with a power level sufficient to modify the canonical alias of a room can now delete room aliases.

1
changelog.d/6987.misc Normal file
View file

@ -0,0 +1 @@
Add some type annotations to the database storage classes.

1
changelog.d/6995.misc Normal file
View file

@ -0,0 +1 @@
Add some type annotations to the federation base & client classes.

1
changelog.d/7002.misc Normal file
View file

@ -0,0 +1 @@
Merge worker apps together.

1
changelog.d/7003.misc Normal file
View file

@ -0,0 +1 @@
Refactoring work in preparation for changing the event redaction algorithm.

1
changelog.d/7015.misc Normal file
View file

@ -0,0 +1 @@
Change date in INSTALL.md#tls-certificates for last date of getting TLS certificates to November 2019.

1
changelog.d/7018.bugfix Normal file
View file

@ -0,0 +1 @@
Fix py35-old CI by using native tox package.

1
changelog.d/7019.misc Normal file
View file

@ -0,0 +1 @@
Port `synapse.handlers.presence` to async/await.

1
changelog.d/7020.misc Normal file
View file

@ -0,0 +1 @@
Port `synapse.rest.keys` to async/await.

1
changelog.d/7030.feature Normal file
View file

@ -0,0 +1 @@
Break down monthly active users by `appservice_id` and emit via Prometheus.

1
changelog.d/7035.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug causing `org.matrix.dummy_event` to be included in responses from `/sync`.

1
changelog.d/7037.feature Normal file
View file

@ -0,0 +1 @@
Implement updated authorization rules and redaction rules for aliases events, from [MSC2261](https://github.com/matrix-org/matrix-doc/pull/2261) and [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432).

1
changelog.d/7045.misc Normal file
View file

@ -0,0 +1 @@
Add a type check to `is_verified` when processing room keys.

1
changelog.d/7048.doc Normal file
View file

@ -0,0 +1 @@
Document that the fallback auth endpoints must be routed to the same worker node as the register endpoints.

1
changelog.d/7055.misc Normal file
View file

@ -0,0 +1 @@
Merge worker apps together.

View file

@ -15,10 +15,9 @@ services:
restart: unless-stopped
# See the readme for a full documentation of the environment settings
environment:
- SYNAPSE_CONFIG_PATH=/etc/homeserver.yaml
- SYNAPSE_CONFIG_PATH=/data/homeserver.yaml
volumes:
# You may either store all the files in a local folder
- ./matrix-config/homeserver.yaml:/etc/homeserver.yaml
- ./files:/data
# .. or you may split this between different storage points
# - ./files:/data

View file

@ -1,6 +1,6 @@
# Using the Synapse Grafana dashboard
0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
3. Set up additional recording rules

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.11.1) stable; urgency=medium
* New synapse release 1.11.1.
-- Synapse Packaging team <packages@matrix.org> Tue, 03 Mar 2020 15:01:22 +0000
matrix-synapse-py3 (1.11.0) stable; urgency=medium
* New synapse release 1.11.0.

View file

@ -1360,6 +1360,56 @@ saml2_config:
# # name: value
# Additional settings to use with single-sign on systems such as SAML2 and CAS.
#
sso:
# A list of client URLs which are whitelisted so that the user does not
# have to confirm giving access to their account to the URL. Any client
# whose URL starts with an entry in the following list will not be subject
# to an additional confirmation step after the SSO login is completed.
#
# WARNING: An entry such as "https://my.client" is insecure, because it
# will also match "https://my.client.evil.site", exposing your users to
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
# By default, this list is empty.
#
#client_whitelist:
# - https://riot.im/develop
# - https://my.custom.client/
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page for a confirmation step before redirecting back to the client
# with the login token: 'sso_redirect_confirm.html'.
#
# When rendering, this template is given three variables:
# * redirect_url: the URL the user is about to be redirected to. Needs
# manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * display_url: the same as `redirect_url`, but with the query
# parameters stripped. The intention is to have a
# human-readable URL to show to users, not to use it as
# the final address to redirect to. Needs manual escaping
# (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * server_name: the homeserver's name.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
# The JWT needs to contain a globally unique "sub" (subject) claim.
#
#jwt_config:

View file

@ -273,6 +273,7 @@ Additionally, the following REST endpoints can be handled, but all requests must
be routed to the same instance:
^/_matrix/client/(r0|unstable)/register$
^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$
Pagination requests can also be handled, but all requests with the same path
room must be routed to the same instance. Additionally, care must be taken to

View file

@ -1,20 +1,31 @@
name: matrix-synapse
base: core18
version: git
version: git
summary: Reference Matrix homeserver
description: |
Synapse is the reference Matrix homeserver.
Matrix is a federated and decentralised instant messaging and VoIP system.
grade: stable
confinement: strict
grade: stable
confinement: strict
apps:
matrix-synapse:
matrix-synapse:
command: synctl --no-daemonize start $SNAP_COMMON/homeserver.yaml
stop-command: synctl -c $SNAP_COMMON stop
plugs: [network-bind, network]
daemon: simple
daemon: simple
hash-password:
command: hash_password
generate-config:
command: generate_config
generate-signing-key:
command: generate_signing_key.py
register-new-matrix-user:
command: register_new_matrix_user
plugs: [network]
synctl:
command: synctl
parts:
matrix-synapse:
source: .

View file

@ -36,7 +36,7 @@ try:
except ImportError:
pass
__version__ = "1.11.0"
__version__ = "1.11.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -539,7 +539,7 @@ class Auth(object):
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id: str, user: UserID):
"""Check if the user is allowed to edit the room's entry in the
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
Args:
@ -570,12 +570,7 @@ class Auth(object):
)
user_level = event_auth.get_user_power_level(user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"This server requires you to be a moderator in the room to"
" edit its room list entry",
)
return user_level >= send_level
@staticmethod
def has_access_token(request):

View file

@ -66,6 +66,7 @@ class Codes(object):
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
BAD_ALIAS = "M_BAD_ALIAS"
class CodeMessageException(RuntimeError):

View file

@ -57,7 +57,7 @@ class RoomVersion(object):
state_res = attr.ib() # int; one of the StateResolutionVersions
enforce_key_validity = attr.ib() # bool
# bool: before MSC2260, anyone was allowed to send an aliases event
# bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool, default=False)
@ -102,12 +102,13 @@ class RoomVersions(object):
enforce_key_validity=True,
special_case_aliases_auth=True,
)
MSC2260_DEV = RoomVersion(
"org.matrix.msc2260",
MSC2432_DEV = RoomVersion(
"org.matrix.msc2432",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
)
@ -119,6 +120,6 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V3,
RoomVersions.V4,
RoomVersions.V5,
RoomVersions.MSC2260_DEV,
RoomVersions.MSC2432_DEV,
)
} # type: Dict[str, RoomVersion]

View file

@ -494,20 +494,26 @@ class GenericWorkerServer(HomeServer):
elif name == "federation":
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
elif name == "media":
media_repo = self.get_media_repository_resource()
if self.config.can_load_media_repo:
media_repo = self.get_media_repository_resource()
# We need to serve the admin servlets for media on the
# worker.
admin_resource = JsonResource(self, canonical_json=False)
register_servlets_for_media_repo(self, admin_resource)
# We need to serve the admin servlets for media on the
# worker.
admin_resource = JsonResource(self, canonical_json=False)
register_servlets_for_media_repo(self, admin_resource)
resources.update(
{
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
"/_synapse/admin": admin_resource,
}
)
resources.update(
{
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
"/_synapse/admin": admin_resource,
}
)
else:
logger.warning(
"A 'media' listener is configured but the media"
" repository is disabled. Ignoring."
)
if name == "openid" and "federation" not in res["names"]:
# Only load the openid resource separately if federation resource

View file

@ -298,6 +298,11 @@ class SynapseHomeServer(HomeServer):
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
current_mau_by_service_gauge = Gauge(
"synapse_admin_mau_current_mau_by_service",
"Current MAU by service",
["app_service"],
)
max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
registered_reserved_users_mau_gauge = Gauge(
"synapse_admin_mau:registered_reserved_users",
@ -585,12 +590,20 @@ def run(hs):
@defer.inlineCallbacks
def generate_monthly_active_users():
current_mau_count = 0
current_mau_count_by_service = {}
reserved_users = ()
store = hs.get_datastore()
if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
current_mau_count = yield store.get_monthly_active_count()
current_mau_count_by_service = (
yield store.get_monthly_active_count_by_service()
)
reserved_users = yield store.get_registered_reserved_users()
current_mau_gauge.set(float(current_mau_count))
for app_service, count in current_mau_count_by_service.items():
current_mau_by_service_gauge.labels(app_service).set(float(count))
registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
max_mau_gauge.set(float(hs.config.max_mau_value))

View file

@ -24,6 +24,7 @@ from synapse.config import (
server,
server_notices_config,
spam_checker,
sso,
stats,
third_party_event_rules,
tls,
@ -57,6 +58,7 @@ class RootConfig:
key: key.KeyConfig
saml2: saml2_config.SAML2Config
cas: cas.CasConfig
sso: sso.SSOConfig
jwt: jwt_config.JWTConfig
password: password.PasswordConfig
email: emailconfig.EmailConfig

View file

@ -38,6 +38,7 @@ from .saml2_config import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
from .sso import SSOConfig
from .stats import StatsConfig
from .third_party_event_rules import ThirdPartyRulesConfig
from .tls import TlsConfig
@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig):
KeyConfig,
SAML2Config,
CasConfig,
SSOConfig,
JWTConfig,
PasswordConfig,
EmailConfig,

92
synapse/config/sso.py Normal file
View file

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
import pkg_resources
from ._base import Config
class SSOConfig(Config):
"""SSO Configuration
"""
section = "sso"
def read_config(self, config, **kwargs):
sso_config = config.get("sso") or {} # type: Dict[str, Any]
# Pick a template directory in order of:
# * The sso-specific template_dir
# * /path/to/synapse/install/res/templates
template_dir = sso_config.get("template_dir")
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
self.sso_redirect_confirm_template_dir = template_dir
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
def generate_config_section(self, **kwargs):
return """\
# Additional settings to use with single-sign on systems such as SAML2 and CAS.
#
sso:
# A list of client URLs which are whitelisted so that the user does not
# have to confirm giving access to their account to the URL. Any client
# whose URL starts with an entry in the following list will not be subject
# to an additional confirmation step after the SSO login is completed.
#
# WARNING: An entry such as "https://my.client" is insecure, because it
# will also match "https://my.client.evil.site", exposing your users to
# phishing attacks from evil.site. To avoid this, include a slash after the
# hostname: "https://my.client/".
#
# By default, this list is empty.
#
#client_whitelist:
# - https://riot.im/develop
# - https://my.custom.client/
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page for a confirmation step before redirecting back to the client
# with the login token: 'sso_redirect_confirm.html'.
#
# When rendering, this template is given three variables:
# * redirect_url: the URL the user is about to be redirected to. Needs
# manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * display_url: the same as `redirect_url`, but with the query
# parameters stripped. The intention is to have a
# human-readable URL to show to users, not to use it as
# the final address to redirect to. Needs manual escaping
# (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * server_name: the homeserver's name.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
"""

View file

@ -140,7 +140,7 @@ def compute_event_signature(
Returns:
a dictionary in the same format of an event's signatures field.
"""
redact_json = prune_event_dict(event_dict)
redact_json = prune_event_dict(room_version, event_dict)
redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None)
if logger.isEnabledFor(logging.DEBUG):

View file

@ -137,7 +137,7 @@ def check(
raise AuthError(403, "This room has been marked as unfederatable.")
# 4. If type is m.room.aliases
if event.type == EventTypes.Aliases:
if event.type == EventTypes.Aliases and room_version_obj.special_case_aliases_auth:
# 4a. If event has no state_key, reject
if not event.is_state():
raise AuthError(403, "Alias event must be a state event")
@ -152,10 +152,8 @@ def check(
)
# 4c. Otherwise, allow.
# This is removed by https://github.com/matrix-org/matrix-doc/pull/2260
if room_version_obj.special_case_aliases_auth:
logger.debug("Allowing! %s", event)
return
logger.debug("Allowing! %s", event)
return
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])

View file

@ -15,9 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import os
from distutils.util import strtobool
from typing import Optional, Type
from typing import Dict, Optional, Type
import six
@ -199,15 +200,25 @@ class _EventInternalMetadata(object):
return self._dict.get("redacted", False)
class EventBase(object):
class EventBase(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def format_version(self) -> int:
"""The EventFormatVersion implemented by this event"""
...
def __init__(
self,
event_dict,
signatures={},
unsigned={},
internal_metadata_dict={},
rejected_reason=None,
event_dict: JsonDict,
room_version: RoomVersion,
signatures: Dict[str, Dict[str, str]],
unsigned: JsonDict,
internal_metadata_dict: JsonDict,
rejected_reason: Optional[str],
):
assert room_version.event_format == self.format_version
self.room_version = room_version
self.signatures = signatures
self.unsigned = unsigned
self.rejected_reason = rejected_reason
@ -303,7 +314,13 @@ class EventBase(object):
class FrozenEvent(EventBase):
format_version = EventFormatVersions.V1 # All events of this type are V1
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
def __init__(
self,
event_dict: JsonDict,
room_version: RoomVersion,
internal_metadata_dict: JsonDict = {},
rejected_reason: Optional[str] = None,
):
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@ -326,8 +343,9 @@ class FrozenEvent(EventBase):
self._event_id = event_dict["event_id"]
super(FrozenEvent, self).__init__(
super().__init__(
frozen_dict,
room_version=room_version,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
@ -352,7 +370,13 @@ class FrozenEvent(EventBase):
class FrozenEventV2(EventBase):
format_version = EventFormatVersions.V2 # All events of this type are V2
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
def __init__(
self,
event_dict: JsonDict,
room_version: RoomVersion,
internal_metadata_dict: JsonDict = {},
rejected_reason: Optional[str] = None,
):
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@ -377,8 +401,9 @@ class FrozenEventV2(EventBase):
self._event_id = None
super(FrozenEventV2, self).__init__(
super().__init__(
frozen_dict,
room_version=room_version,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
@ -445,7 +470,7 @@ class FrozenEventV3(FrozenEventV2):
return self._event_id
def event_type_from_format_version(format_version: int) -> Type[EventBase]:
def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
"""Returns the python type to use to construct an Event object for the
given event format version.
@ -474,5 +499,5 @@ def make_event_from_dict(
rejected_reason: Optional[str] = None,
) -> EventBase:
"""Construct an EventBase from the given event dict"""
event_type = event_type_from_format_version(room_version.event_format)
return event_type(event_dict, internal_metadata_dict, rejected_reason)
event_type = _event_type_from_format_version(room_version.event_format)
return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason)

View file

@ -23,6 +23,7 @@ from frozendict import frozendict
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersion
from synapse.util.async_helpers import yieldable_gather_results
from . import EventBase
@ -35,26 +36,20 @@ from . import EventBase
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
def prune_event(event):
def prune_event(event: EventBase) -> EventBase:
""" Returns a pruned version of the given event, which removes all keys we
don't know about or think could potentially be dodgy.
This is used when we "redact" an event. We want to remove all fields that
the user has specified, but we do want to keep necessary information like
type, state_key etc.
Args:
event (FrozenEvent)
Returns:
FrozenEvent
"""
pruned_event_dict = prune_event_dict(event.get_dict())
pruned_event_dict = prune_event_dict(event.room_version, event.get_dict())
from . import event_type_from_format_version
from . import make_event_from_dict
pruned_event = event_type_from_format_version(event.format_version)(
pruned_event_dict, event.internal_metadata.get_dict()
pruned_event = make_event_from_dict(
pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
)
# Mark the event as redacted
@ -63,15 +58,12 @@ def prune_event(event):
return pruned_event
def prune_event_dict(event_dict):
def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
"""Redacts the event_dict in the same way as `prune_event`, except it
operates on dicts rather than event objects
Args:
event_dict (dict)
Returns:
dict: A copy of the pruned event dict
A copy of the pruned event dict
"""
allowed_keys = [
@ -118,7 +110,7 @@ def prune_event_dict(event_dict):
"kick",
"redact",
)
elif event_type == EventTypes.Aliases:
elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")

View file

@ -15,11 +15,13 @@
# limitations under the License.
import logging
from collections import namedtuple
from typing import Iterable, List
import six
from twisted.internet import defer
from twisted.internet.defer import DeferredList
from twisted.internet.defer import Deferred, DeferredList
from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@ -29,6 +31,7 @@ from synapse.api.room_versions import (
RoomVersion,
)
from synapse.crypto.event_signing import check_event_content_hash
from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
@ -56,7 +59,12 @@ class FederationBase(object):
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(
self, origin, pdus, room_version, outlier=False, include_none=False
self,
origin: str,
pdus: List[EventBase],
room_version: str,
outlier: bool = False,
include_none: bool = False,
):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
@ -69,11 +77,11 @@ class FederationBase(object):
a new list.
Args:
origin (str)
pdu (list)
room_version (str)
outlier (bool): Whether the events are outliers or not
include_none (str): Whether to include None in the returned list
origin
pdu
room_version
outlier: Whether the events are outliers or not
include_none: Whether to include None in the returned list
for events that have failed their checks
Returns:
@ -82,7 +90,7 @@ class FederationBase(object):
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
def handle_check_result(pdu, deferred):
def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
res = yield make_deferred_yieldable(deferred)
except SynapseError:
@ -96,12 +104,16 @@ class FederationBase(object):
if not res and pdu.origin != origin:
try:
res = yield self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
timeout=10000,
# This should not exist in the base implementation, until
# this is fixed, ignore it for typing. See issue #6997.
res = yield defer.ensureDeferred(
self.get_pdu( # type: ignore
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
timeout=10000,
)
)
except SynapseError:
pass
@ -125,21 +137,23 @@ class FederationBase(object):
else:
return [p for p in valid_pdus if p]
def _check_sigs_and_hash(self, room_version, pdu):
def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
def _check_sigs_and_hashes(self, room_version, pdus):
def _check_sigs_and_hashes(
self, room_version: str, pdus: List[EventBase]
) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
Args:
room_version (str): The room version of the PDUs
pdus (list[FrozenEvent]): the events to be checked
room_version: The room version of the PDUs
pdus: the events to be checked
Returns:
list[Deferred]: for each input event, a deferred which:
For each input event, a deferred which:
* returns the original event if the checks pass
* returns a redacted version of the event (if the signature
matched but the hash did not)
@ -150,7 +164,7 @@ class FederationBase(object):
ctx = LoggingContext.current_context()
def callback(_, pdu):
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
# let's try to distinguish between failures because the event was
@ -187,7 +201,7 @@ class FederationBase(object):
return pdu
def errback(failure, pdu):
def errback(failure: Failure, pdu: EventBase):
failure.trap(SynapseError)
with PreserveLoggingContext(ctx):
logger.warning(
@ -213,16 +227,18 @@ class PduToCheckSig(
pass
def _check_sigs_on_pdus(keyring, room_version, pdus):
def _check_sigs_on_pdus(
keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
) -> List[Deferred]:
"""Check that the given events are correctly signed
Args:
keyring (synapse.crypto.Keyring): keyring object to do the checks
room_version (str): the room version of the PDUs
pdus (Collection[EventBase]): the events to be checked
keyring: keyring object to do the checks
room_version: the room version of the PDUs
pdus: the events to be checked
Returns:
List[Deferred]: a Deferred for each event in pdus, which will either succeed if
A Deferred for each event in pdus, which will either succeed if
the signatures are valid, or fail (with a SynapseError) if not.
"""
@ -327,7 +343,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
def _flatten_deferred_list(deferreds):
def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
"""Given a list of deferreds, either return the single deferred,
combine into a DeferredList, or return an already resolved deferred.
"""
@ -339,7 +355,7 @@ def _flatten_deferred_list(deferreds):
return defer.succeed(None)
def _is_invite_via_3pid(event):
def _is_invite_via_3pid(event: EventBase) -> bool:
return (
event.type == EventTypes.Member
and event.membership == Membership.INVITE

View file

@ -187,7 +187,7 @@ class FederationClient(FederationBase):
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
) -> List[EventBase]:
) -> Optional[List[EventBase]]:
"""Requests some more historic PDUs for the given room from the
given destination server.
@ -199,9 +199,9 @@ class FederationClient(FederationBase):
"""
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
# If there are no extremities then we've (probably) reached the start.
if not extremities:
return
return None
transaction_data = await self.transport_layer.backfill(
dest, room_id, extremities, limit
@ -284,7 +284,7 @@ class FederationClient(FederationBase):
pdu_list = [
event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"]
]
] # type: List[EventBase]
if pdu_list and pdu_list[0]:
pdu = pdu_list[0]
@ -615,7 +615,7 @@ class FederationClient(FederationBase):
]
if auth_chain_create_events != [create_event.event_id]:
raise InvalidResponseError(
"Unexpected create event(s) in auth chain"
"Unexpected create event(s) in auth chain: %s"
% (auth_chain_create_events,)
)

View file

@ -17,6 +17,8 @@
import logging
import time
import unicodedata
import urllib.parse
from typing import Any
import attr
import bcrypt
@ -38,8 +40,11 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import UserID
from synapse.util.caches.expiringcache import ExpiringCache
@ -108,6 +113,16 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
# Load the SSO redirect confirmation page HTML template
self._sso_redirect_confirm_template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
)[0]
self._server_name = hs.config.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
@defer.inlineCallbacks
def validate_user_via_ui_auth(self, requester, request_body, clientip):
"""
@ -927,6 +942,65 @@ class AuthHandler(BaseHandler):
else:
return defer.succeed(False)
def complete_sso_login(
self,
registered_user_id: str,
request: SynapseRequest,
client_redirect_url: str,
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id: The registered user ID to complete SSO login for.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
"""
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
)
# Append the login token to the original redirect URL (i.e. with its query
# parameters kept intact) to build the URL to which the template needs to
# redirect the users once they have clicked on the confirmation link.
redirect_url = self.add_query_param_to_url(
client_redirect_url, "loginToken", login_token
)
# if the client is whitelisted, we can redirect straight to it
if client_redirect_url.startswith(self._whitelisted_sso_clients):
request.redirect(redirect_url)
finish_request(request)
return
# Otherwise, serve the redirect confirmation page.
# Remove the query parameters from the redirect URL to get a shorter version of
# it. This is only to display a human-readable URL in the template, but not the
# URL we redirect users to.
redirect_url_no_params = client_redirect_url.split("?")[0]
html = self._sso_redirect_confirm_template.render(
display_url=redirect_url_no_params,
redirect_url=redirect_url,
server_name=self._server_name,
).encode("utf-8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(html),))
request.write(html)
finish_request(request)
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({param_name: param})
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
@attr.s
class MacaroonGenerator(object):

View file

@ -13,11 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import string
from typing import List
from typing import Iterable, List, Optional
from twisted.internet import defer
@ -30,6 +28,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
from synapse.appservice import ApplicationService
from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
from ._base import BaseHandler
@ -57,7 +56,13 @@ class DirectoryHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None, creator=None):
def _create_association(
self,
room_alias: RoomAlias,
room_id: str,
servers: Optional[Iterable[str]] = None,
creator: Optional[str] = None,
):
# general association creation for both human users and app services
for wchar in string.whitespace:
@ -83,17 +88,21 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks
def create_association(
self, requester, room_alias, room_id, servers=None, check_membership=True,
self,
requester: Requester,
room_alias: RoomAlias,
room_id: str,
servers: Optional[List[str]] = None,
check_membership: bool = True,
):
"""Attempt to create a new alias
Args:
requester (Requester)
room_alias (RoomAlias)
room_id (str)
servers (list[str]|None): List of servers that others servers
should try and join via
check_membership (bool): Whether to check if the user is in the room
requester
room_alias
room_id
servers: Iterable of servers that others servers should try and join via
check_membership: Whether to check if the user is in the room
before the alias can be set (if the server's config requires it).
Returns:
@ -147,15 +156,15 @@ class DirectoryHandler(BaseHandler):
yield self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks
def delete_association(self, requester, room_alias):
def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
delete_appservice_association)
Args:
requester (Requester):
room_alias (RoomAlias):
requester
room_alias
Returns:
Deferred[unicode]: room id that the alias used to point to
@ -191,16 +200,16 @@ class DirectoryHandler(BaseHandler):
room_id = yield self._delete_association(room_alias)
try:
yield self._update_canonical_alias(
requester, requester.user.to_string(), room_id, room_alias
)
yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
return room_id
@defer.inlineCallbacks
def delete_appservice_association(self, service, room_alias):
def delete_appservice_association(
self, service: ApplicationService, room_alias: RoomAlias
):
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400,
@ -210,7 +219,7 @@ class DirectoryHandler(BaseHandler):
yield self._delete_association(room_alias)
@defer.inlineCallbacks
def _delete_association(self, room_alias):
def _delete_association(self, room_alias: RoomAlias):
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
@ -219,7 +228,7 @@ class DirectoryHandler(BaseHandler):
return room_id
@defer.inlineCallbacks
def get_association(self, room_alias):
def get_association(self, room_alias: RoomAlias):
room_id = None
if self.hs.is_mine(room_alias):
result = yield self.get_association_from_room_alias(room_alias)
@ -284,7 +293,9 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
def _update_canonical_alias(self, requester, user_id, room_id, room_alias):
def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
@ -307,15 +318,17 @@ class DirectoryHandler(BaseHandler):
send_update = True
content.pop("alias", "")
# Filter alt_aliases for the removed alias.
alt_aliases = content.pop("alt_aliases", None)
# If the aliases are not a list (or not found) do not attempt to modify
# the list.
if isinstance(alt_aliases, collections.Sequence):
# Filter the alt_aliases property for the removed alias. Note that the
# value is not modified if alt_aliases is of an unexpected form.
alt_aliases = content.get("alt_aliases")
if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases:
send_update = True
alt_aliases = [alias for alias in alt_aliases if alias != alias_str]
if alt_aliases:
content["alt_aliases"] = alt_aliases
else:
del content["alt_aliases"]
if send_update:
yield self.event_creation_handler.create_and_send_nonmember_event(
@ -331,7 +344,7 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
def get_association_from_room_alias(self, room_alias: RoomAlias):
result = yield self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
@ -339,7 +352,7 @@ class DirectoryHandler(BaseHandler):
result = yield as_handler.query_room_alias_exists(room_alias)
return result
def can_modify_alias(self, alias, user_id=None):
def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None):
# Any application service "interested" in an alias they are regexing on
# can modify the alias.
# Users can only modify the alias if ALL the interested services have
@ -360,22 +373,42 @@ class DirectoryHandler(BaseHandler):
return defer.succeed(True)
@defer.inlineCallbacks
def _user_can_delete_alias(self, alias, user_id):
def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias.
One of the following must be true:
1. The user created the alias.
2. The user is a server administrator.
3. The user has a power-level sufficient to send a canonical alias event
for the current room.
"""
creator = yield self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id:
return True
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
return is_admin
# Resolve the alias to the corresponding room.
room_mapping = yield self.get_association(alias)
room_id = room_mapping["room_id"]
if not room_id:
return False
res = yield self.auth.check_can_change_room_list(
room_id, UserID.from_string(user_id)
)
return res
@defer.inlineCallbacks
def edit_published_room_list(self, requester, room_id, visibility):
def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str
):
"""Edit the entry of the room in the published room list.
requester
room_id (str)
visibility (str): "public" or "private"
room_id
visibility: "public" or "private"
"""
user_id = requester.user.to_string()
@ -400,7 +433,15 @@ class DirectoryHandler(BaseHandler):
if room is None:
raise SynapseError(400, "Unknown room")
yield self.auth.check_can_change_room_list(room_id, requester.user)
can_change_room_list = yield self.auth.check_can_change_room_list(
room_id, requester.user
)
if not can_change_room_list:
raise AuthError(
403,
"This server requires you to be a moderator in the room to"
" edit its room list entry",
)
making_public = visibility == "public"
if making_public:
@ -421,16 +462,16 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks
def edit_published_appservice_room_list(
self, appservice_id, network_id, room_id, visibility
self, appservice_id: str, network_id: str, room_id: str, visibility: str
):
"""Add or remove a room from the appservice/network specific public
room list.
Args:
appservice_id (str): ID of the appservice that owns the list
network_id (str): The ID of the network the list is associated with
room_id (str)
visibility (str): either "public" or "private"
appservice_id: ID of the appservice that owns the list
network_id: The ID of the network the list is associated with
room_id
visibility: either "public" or "private"
"""
if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting")

View file

@ -207,6 +207,13 @@ class E2eRoomKeysHandler(object):
changed = False # if anything has changed, we need to update the etag
for room_id, room in iteritems(room_keys["rooms"]):
for session_id, room_key in iteritems(room["sessions"]):
if not isinstance(room_key["is_verified"], bool):
msg = (
"is_verified must be a boolean in keys for session %s in"
"room %s" % (session_id, room_id)
)
raise SynapseError(400, msg, Codes.INVALID_PARAM)
log_kv(
{
"message": "Trying to upload room key",

View file

@ -888,19 +888,60 @@ class EventCreationHandler(object):
yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
# Validate a newly added alias or newly added alt_aliases.
original_alias = None
original_alt_aliases = set()
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
original_event = yield self.store.get_event(original_event_id)
if original_event:
original_alias = original_event.content.get("alias", None)
original_alt_aliases = original_event.content.get("alt_aliases", [])
# Check the alias is currently valid (if it has changed).
room_alias_str = event.content.get("alias", None)
if room_alias_str:
directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
room_alias = RoomAlias.from_string(room_alias_str)
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
Codes.BAD_ALIAS,
)
# Check that alt_aliases is the proper form.
alt_aliases = event.content.get("alt_aliases", [])
if not isinstance(alt_aliases, (list, tuple)):
raise SynapseError(
400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM
)
# If the old version of alt_aliases is of an unknown form,
# completely replace it.
if not isinstance(original_alt_aliases, (list, tuple)):
original_alt_aliases = []
# Check that each alias is currently valid.
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
room_alias = RoomAlias.from_string(alias_str)
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room"
% (room_alias_str,),
Codes.BAD_ALIAS,
)
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:

View file

@ -25,7 +25,6 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import (
UserID,
map_username_to_mxid_localpart,
@ -48,7 +47,7 @@ class Saml2SessionData:
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._sso_auth_handler = SSOAuthHandler(hs)
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._clock = hs.get_clock()
@ -116,7 +115,7 @@ class SamlHandler:
self.expire_sessions()
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
try:

View file

@ -27,10 +27,15 @@ import inspect
import logging
import threading
import types
from typing import Any, List
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from typing_extensions import Literal
from twisted.internet import defer, threads
if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
logger = logging.getLogger(__name__)
try:
@ -91,7 +96,7 @@ class ContextResourceUsage(object):
"evt_db_fetch_count",
]
def __init__(self, copy_from=None):
def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None:
"""Create a new ContextResourceUsage
Args:
@ -101,27 +106,28 @@ class ContextResourceUsage(object):
if copy_from is None:
self.reset()
else:
self.ru_utime = copy_from.ru_utime
self.ru_stime = copy_from.ru_stime
self.db_txn_count = copy_from.db_txn_count
# FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
self.ru_utime = copy_from.ru_utime # type: float
self.ru_stime = copy_from.ru_stime # type: float
self.db_txn_count = copy_from.db_txn_count # type: int
self.db_txn_duration_sec = copy_from.db_txn_duration_sec
self.db_sched_duration_sec = copy_from.db_sched_duration_sec
self.evt_db_fetch_count = copy_from.evt_db_fetch_count
self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float
self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float
self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int
def copy(self):
def copy(self) -> "ContextResourceUsage":
return ContextResourceUsage(copy_from=self)
def reset(self):
def reset(self) -> None:
self.ru_stime = 0.0
self.ru_utime = 0.0
self.db_txn_count = 0
self.db_txn_duration_sec = 0
self.db_sched_duration_sec = 0
self.db_txn_duration_sec = 0.0
self.db_sched_duration_sec = 0.0
self.evt_db_fetch_count = 0
def __repr__(self):
def __repr__(self) -> str:
return (
"<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
"db_txn_count='%r', db_txn_duration_sec='%r', "
@ -135,7 +141,7 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count,
)
def __iadd__(self, other):
def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
"""Add another ContextResourceUsage's stats to this one's.
Args:
@ -149,7 +155,7 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count += other.evt_db_fetch_count
return self
def __isub__(self, other):
def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
self.ru_utime -= other.ru_utime
self.ru_stime -= other.ru_stime
self.db_txn_count -= other.db_txn_count
@ -158,17 +164,20 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count -= other.evt_db_fetch_count
return self
def __add__(self, other):
def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self)
res += other
return res
def __sub__(self, other):
def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
res = ContextResourceUsage(copy_from=self)
res -= other
return res
LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
@ -201,7 +210,14 @@ class LoggingContext(object):
class Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = [] # type: List[Any]
__slots__ = ["previous_context", "alive", "request", "scope"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.alive = None
self.request = None
self.scope = None
def __str__(self):
return "sentinel"
@ -235,7 +251,7 @@ class LoggingContext(object):
sentinel = Sentinel()
def __init__(self, name=None, parent_context=None, request=None):
def __init__(self, name=None, parent_context=None, request=None) -> None:
self.previous_context = LoggingContext.current_context()
self.name = name
@ -250,7 +266,7 @@ class LoggingContext(object):
self.request = None
self.tag = ""
self.alive = True
self.scope = None
self.scope = None # type: Optional[_LogContextScope]
self.parent_context = parent_context
@ -261,13 +277,13 @@ class LoggingContext(object):
# the request param overrides the request from the parent context
self.request = request
def __str__(self):
def __str__(self) -> str:
if self.request:
return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
def current_context(cls):
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
Returns:
@ -276,7 +292,9 @@ class LoggingContext(object):
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
def set_current_context(cls, context):
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
@ -291,7 +309,7 @@ class LoggingContext(object):
context.start()
return current
def __enter__(self):
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
old_context = self.set_current_context(self)
if self.previous_context != old_context:
@ -304,7 +322,7 @@ class LoggingContext(object):
return self
def __exit__(self, type, value, traceback):
def __exit__(self, type, value, traceback) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
@ -318,7 +336,6 @@ class LoggingContext(object):
logger.warning(
"Expected logging context %s but found %s", self, current
)
self.previous_context = None
self.alive = False
# if we have a parent, pass our CPU usage stats on
@ -330,7 +347,7 @@ class LoggingContext(object):
# reset them in case we get entered again
self._resource_usage.reset()
def copy_to(self, record):
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
another LoggingContext
"""
@ -341,14 +358,14 @@ class LoggingContext(object):
# we also track the current scope:
record.scope = self.scope
def copy_to_twisted_log_entry(self, record):
def copy_to_twisted_log_entry(self, record) -> None:
"""
Copy logging fields from this context to a Twisted log record.
"""
record["request"] = self.request
record["scope"] = self.scope
def start(self):
def start(self) -> None:
if get_thread_id() != self.main_thread:
logger.warning("Started logcontext %s on different thread", self)
return
@ -358,7 +375,7 @@ class LoggingContext(object):
if not self.usage_start:
self.usage_start = get_thread_resource_usage()
def stop(self):
def stop(self) -> None:
if get_thread_id() != self.main_thread:
logger.warning("Stopped logcontext %s on different thread", self)
return
@ -378,7 +395,7 @@ class LoggingContext(object):
self.usage_start = None
def get_resource_usage(self):
def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far.
Returns:
@ -398,11 +415,13 @@ class LoggingContext(object):
return res
def _get_cputime(self):
def _get_cputime(self) -> Tuple[float, float]:
"""Get the cpu usage time so far
Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
"""
assert self.usage_start is not None
current = get_thread_resource_usage()
# Indicate to mypy that we know that self.usage_start is None.
@ -430,13 +449,13 @@ class LoggingContext(object):
return utime_delta, stime_delta
def add_database_transaction(self, duration_sec):
def add_database_transaction(self, duration_sec: float) -> None:
if duration_sec < 0:
raise ValueError("DB txn time can only be non-negative")
self._resource_usage.db_txn_count += 1
self._resource_usage.db_txn_duration_sec += duration_sec
def add_database_scheduled(self, sched_sec):
def add_database_scheduled(self, sched_sec: float) -> None:
"""Record a use of the database pool
Args:
@ -447,7 +466,7 @@ class LoggingContext(object):
raise ValueError("DB scheduling time can only be non-negative")
self._resource_usage.db_sched_duration_sec += sched_sec
def record_event_fetch(self, event_count):
def record_event_fetch(self, event_count: int) -> None:
"""Record a number of events being fetched from the db
Args:
@ -464,10 +483,10 @@ class LoggingContextFilter(logging.Filter):
missing fields
"""
def __init__(self, **defaults):
def __init__(self, **defaults) -> None:
self.defaults = defaults
def filter(self, record):
def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
Returns:
True to include the record in the log output.
@ -492,12 +511,13 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=None):
def __init__(self, new_context: Optional[LoggingContext] = None) -> None:
if new_context is None:
new_context = LoggingContext.sentinel
self.new_context = new_context
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
else:
self.new_context = new_context
def __enter__(self):
def __enter__(self) -> None:
"""Captures the current logging context"""
self.current_context = LoggingContext.set_current_context(self.new_context)
@ -506,7 +526,7 @@ class PreserveLoggingContext(object):
if not self.current_context.alive:
logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback):
def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context"""
context = LoggingContext.set_current_context(self.current_context)
@ -525,7 +545,9 @@ class PreserveLoggingContext(object):
logger.debug("Restoring dead context: %s", self.current_context)
def nested_logging_context(suffix, parent_context=None):
def nested_logging_context(
suffix: str, parent_context: Optional[LoggingContext] = None
) -> LoggingContext:
"""Creates a new logging context as a child of another.
The nested logging context will have a 'request' made up of the parent context's
@ -546,10 +568,12 @@ def nested_logging_context(suffix, parent_context=None):
Returns:
LoggingContext: new logging context.
"""
if parent_context is None:
parent_context = LoggingContext.current_context()
if parent_context is not None:
context = parent_context # type: LoggingContextOrSentinel
else:
context = LoggingContext.current_context()
return LoggingContext(
parent_context=parent_context, request=parent_context.request + "-" + suffix
parent_context=context, request=str(context.request) + "-" + suffix
)
@ -654,7 +678,10 @@ def make_deferred_yieldable(deferred):
return deferred
def _set_context_cb(result, context):
ResultT = TypeVar("ResultT")
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context"""
LoggingContext.set_current_context(context)
return result

View file

@ -17,6 +17,7 @@ import logging
from twisted.internet import defer
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import UserID
@ -211,3 +212,21 @@ class ModuleApi(object):
Deferred[object]: result of func
"""
return self._store.db.runInteraction(desc, func, *args, **kwargs)
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
URL with a token directly if the URL matches with one of the whitelisted clients.
Args:
registered_user_id: The MXID that has been registered as a previous step of
of this SSO login.
request: The request to respond to.
client_redirect_url: The URL to which to offer to redirect the user (or to
redirect them directly if whitelisted).
"""
self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url,
)

View file

@ -555,10 +555,12 @@ class Mailer(object):
else:
# If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
room_id = reason["room_id"]
sender_ids = list(
{
notif_events[n["event_id"]].sender
for n in notifs_by_room[reason["room_id"]]
for n in notifs_by_room[room_id]
}
)

View file

@ -18,7 +18,7 @@ import logging
from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import event_type_from_format_version
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@ -38,6 +38,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
{
"events": [{
"event": { .. serialized event .. },
"room_version": .., // "1", "2", "3", etc: the version of the room
// containing the event
"event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
@ -73,6 +76,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_payloads.append(
{
"event": event.get_pdu_json(),
"room_version": event.room_version.identifier,
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
@ -95,12 +99,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts = []
for event_payload in event_payloads:
event_dict = event_payload["event"]
format_ver = event_payload["event_format_version"]
room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]]
internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"]
EventType = event_type_from_format_version(format_ver)
event = EventType(event_dict, internal_metadata, rejected_reason)
event = make_event_from_dict(
event_dict, room_ver, internal_metadata, rejected_reason
)
context = EventContext.deserialize(
self.storage, event_payload["context"]

View file

@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
from synapse.events import event_type_from_format_version
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@ -37,6 +38,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
{
"event": { .. serialized event .. },
"room_version": .., // "1", "2", "3", etc: the version of the room
// containing the event
"event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
@ -77,6 +81,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
payload = {
"event": event.get_pdu_json(),
"room_version": event.room_version.identifier,
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
@ -93,12 +98,13 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
content = parse_json_object_from_request(request)
event_dict = content["event"]
format_ver = content["event_format_version"]
room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
internal_metadata = content["internal_metadata"]
rejected_reason = content["rejected_reason"]
EventType = event_type_from_format_version(format_ver)
event = EventType(event_dict, internal_metadata, rejected_reason)
event = make_event_from_dict(
event_dict, room_ver, internal_metadata, rejected_reason
)
requester = Requester.deserialize(self.store, content["requester"])
context = EventContext.deserialize(self.storage, content["context"])

View file

@ -0,0 +1,14 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>SSO redirect confirmation</title>
</head>
<body>
<p>The application at <span style="font-weight:bold">{{ display_url | e }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p>
<p>If you don't recognise this address, you should ignore this and close this tab.</p>
<p>
<a href="{{ redirect_url | e }}">I trust this address</a>
</p>
</body>
</html>

View file

@ -211,9 +211,7 @@ class UserRestServletV2(RestServlet):
if target_user == auth_user and not set_admin_to:
raise SynapseError(400, "You may not demote yourself.")
await self.admin_handler.set_user_server_admin(
target_user, set_admin_to
)
await self.store.set_server_admin(target_user, set_admin_to)
if "password" in body:
if (
@ -651,6 +649,6 @@ class UserAdminServlet(RestServlet):
if target_user == auth_user and not set_admin_to:
raise SynapseError(400, "You may not demote yourself.")
await self.store.set_user_server_admin(target_user, set_admin_to)
await self.store.set_server_admin(target_user, set_admin_to)
return 200, {}

View file

@ -28,7 +28,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.push.mailer import load_jinja2_templates
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
@ -548,6 +548,16 @@ class SSOAuthHandler(object):
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()
# Load the redirect page HTML template
self._template = load_jinja2_templates(
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
)[0]
self._server_name = hs.config.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
async def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None
):
@ -580,36 +590,9 @@ class SSOAuthHandler(object):
localpart=localpart, default_display_name=user_display_name
)
self.complete_sso_login(registered_user_id, request, client_redirect_url)
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id:
request:
client_redirect_url:
"""
login_token = self._macaroon_gen.generate_short_term_login_token(
registered_user_id
self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url
)
redirect_url = self._add_login_token_to_redirect_url(
client_redirect_url, login_token
)
# Load page
request.redirect(redirect_url)
finish_request(request)
@staticmethod
def _add_login_token_to_redirect_url(url, token):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
def register_servlets(hs, http_server):

View file

@ -18,8 +18,6 @@ from typing import Dict, Set
from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import (
@ -125,8 +123,7 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
@defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False):
async def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query)
store_queries = []
@ -143,7 +140,7 @@ class RemoteKey(DirectServeResource):
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
cached = yield self.store.get_server_keys_json(store_queries)
cached = await self.store.get_server_keys_json(store_queries)
json_results = set()
@ -215,8 +212,8 @@ class RemoteKey(DirectServeResource):
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
yield self.fetcher.get_keys(cache_misses)
yield self.query_keys(request, query, query_remote_on_cache_miss=False)
await self.fetcher.get_keys(cache_misses)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json in json_results:

View file

@ -608,6 +608,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
" WHERE ep.stream_ordering > ?"
" ORDER BY ep.stream_ordering ASC"
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@ -735,23 +752,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
" WHERE ep.stream_ordering > ?"
" ORDER BY ep.stream_ordering ASC"
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
@defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
def f(txn):

View file

@ -1168,7 +1168,11 @@ class EventsStore(
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
pruned_json = encode_json(prune_event_dict(original_event.get_dict()))
pruned_json = encode_json(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
)
else:
# Redaction wasn't allowed
pruned_json = None
@ -1929,7 +1933,9 @@ class EventsStore(
return
# Prune the event's dict then convert it to JSON.
pruned_json = encode_json(prune_event_dict(event.get_dict()))
pruned_json = encode_json(
prune_event_dict(event.room_version, event.get_dict())
)
# Update the event_json table to replace the event's JSON with the pruned
# JSON.

View file

@ -28,9 +28,12 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError
from synapse.api.room_versions import EventFormatVersions
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersions,
)
from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
@ -580,8 +583,49 @@ class EventsWorkerStore(SQLBaseStore):
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1
original_ev = event_type_from_format_version(format_version)(
room_version_id = row["room_version_id"]
if not room_version_id:
# this should only happen for out-of-band membership events
if not internal_metadata.get("out_of_band_membership"):
logger.warning(
"Room %s for event %s is unknown", d["room_id"], event_id
)
continue
# take a wild stab at the room version based on the event format
if format_version == EventFormatVersions.V1:
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
room_version = RoomVersions.V3
else:
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
logger.error(
"Event %s in room %s has unknown room version %s",
event_id,
d["room_id"],
room_version_id,
)
continue
if room_version.event_format != format_version:
logger.error(
"Event %s in room %s with version %s has wrong format: "
"expected %s, was %s",
event_id,
d["room_id"],
room_version_id,
room_version.event_format,
format_version,
)
continue
original_ev = make_event_from_dict(
event_dict=d,
room_version=room_version,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
@ -661,6 +705,12 @@ class EventsWorkerStore(SQLBaseStore):
of EventFormatVersions. 'None' means the event predates
EventFormatVersions (so the event is format V1).
* room_version_id (str|None): The version of the room which contains the event.
Hopefully one of RoomVersions.
Due to historical reasons, there may be a few events in the database which
do not have an associated room; in this case None will be returned here.
* rejected_reason (str|None): if the event was rejected, the reason
why.
@ -676,17 +726,18 @@ class EventsWorkerStore(SQLBaseStore):
"""
event_dict = {}
for evs in batch_iter(event_ids, 200):
sql = (
"SELECT "
" e.event_id, "
" e.internal_metadata,"
" e.json,"
" e.format_version, "
" rej.reason "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" WHERE "
)
sql = """\
SELECT
e.event_id,
e.internal_metadata,
e.json,
e.format_version,
r.room_version,
rej.reason
FROM event_json as e
LEFT JOIN rooms r USING (room_id)
LEFT JOIN rejections as rej USING (event_id)
WHERE """
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", evs
@ -701,7 +752,8 @@ class EventsWorkerStore(SQLBaseStore):
"internal_metadata": row[1],
"json": row[2],
"format_version": row[3],
"rejected_reason": row[4],
"room_version_id": row[4],
"rejected_reason": row[5],
"redactions": [],
}

View file

@ -43,13 +43,40 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def _count_users(txn):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql)
(count,) = txn.fetchone()
return count
return self.db.runInteraction("count_users", _count_users)
@cached(num_args=0)
def get_monthly_active_count_by_service(self):
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table
`config.track_appservice_user_ips` must be set to `true` for this
method to return anything other than native matrix users.
Returns:
Deferred[dict]: dict that includes a mapping between app_service_id
and the number of occurrences.
"""
def _count_users_by_service(txn):
sql = """
SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
FROM monthly_active_users
LEFT JOIN users ON monthly_active_users.user_id=users.name
GROUP BY appservice_id;
"""
txn.execute(sql)
result = txn.fetchall()
return dict(result)
return self.db.runInteraction("count_users_by_service", _count_users_by_service)
@defer.inlineCallbacks
def get_registered_reserved_users(self):
"""Of the reserved threepids defined in config, which are associated
@ -291,6 +318,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
self._invalidate_cache_and_stream(
txn, self.get_monthly_active_count_by_service, ()
)
self._invalidate_cache_and_stream(
txn, self.user_last_seen_monthly_active, (user_id,)
)

View file

@ -301,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore):
admin (bool): true iff the user is to be a server admin,
false otherwise.
"""
return self.db.simple_update_one(
table="users",
keyvalues={"name": user.to_string()},
updatevalues={"admin": 1 if admin else 0},
desc="set_server_admin",
)
def set_server_admin_txn(txn):
self.db.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user.to_string(),)
)
return self.db.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (

View file

@ -15,6 +15,8 @@
import logging
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
@ -56,7 +58,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
return max_stream_id, []
return defer.succeed((max_stream_id, []))
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than

View file

@ -15,9 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import time
from typing import Iterable, Tuple
from time import monotonic as monotonic_time
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.stringutils import exception_to_unicode
# import a function which will return a monotonic time, in seconds
try:
# on python 3, use time.monotonic, since time.clock can go backwards
from time import monotonic as monotonic_time
except ImportError:
# ... but python 2 doesn't have it
from time import clock as monotonic_time
logger = logging.getLogger(__name__)
try:
MAX_TXN_ID = sys.maxint - 1
except AttributeError:
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database.
"""
@ -90,7 +80,9 @@ def make_pool(
)
def make_conn(db_config: DatabaseConnectionConfig, engine):
def make_conn(
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> Connection:
"""Make a new connection to the database and return it.
Returns:
@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
return db_conn
class LoggingTransaction(object):
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
# that mypy sees the type but the runtime python doesn't.
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method.
Args:
txn: The database transcation object to wrap.
name (str): The name of this transactions for logging.
database_engine (Sqlite3Engine|PostgresEngine)
after_callbacks(list|None): A list that callbacks will be appended to
name: The name of this transactions for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
exception_callbacks(list|None): A list that callbacks will be appended
exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
should be allowed to be scheduled to run.
@ -135,46 +134,67 @@ class LoggingTransaction(object):
]
def __init__(
self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
self,
txn: Cursor,
name: str,
database_engine: BaseDatabaseEngine,
after_callbacks: Optional[List[_CallbackListEntry]] = None,
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
object.__setattr__(self, "exception_callbacks", exception_callbacks)
self.txn = txn
self.name = name
self.database_engine = database_engine
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(self, callback, *args, **kwargs):
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
# if self.after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(self, callback, *args, **kwargs):
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
def __getattr__(self, name):
return getattr(self.txn, name)
def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()
def __setattr__(self, name, value):
setattr(self.txn, name, value)
def fetchone(self) -> Tuple:
return self.txn.fetchone()
def __iter__(self):
def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
@property
def rowcount(self) -> int:
return self.txn.rowcount
@property
def description(self) -> Any:
return self.txn.description
def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch
from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
self.execute(sql, val)
def execute(self, sql, *args):
def execute(self, sql: str, *args: Any):
self._do_execute(self.txn.execute, sql, *args)
def executemany(self, sql, *args):
def executemany(self, sql: str, *args: Any):
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql):
@ -207,6 +227,9 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
def close(self):
self.txn.close()
class PerformanceCounters(object):
def __init__(self):
@ -251,7 +274,9 @@ class Database(object):
_TXN_ID = 0
def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
def __init__(
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
):
self.hs = hs
self._clock = hs.get_clock()
self._database_config = database_config
@ -259,9 +284,9 @@ class Database(object):
self.updates = BackgroundUpdater(hs, self)
self._previous_txn_total_time = 0
self._current_txn_total_time = 0
self._previous_loop_ts = 0
self._previous_txn_total_time = 0.0
self._current_txn_total_time = 0.0
self._previous_loop_ts = 0.0
# TODO(paul): These can eventually be removed once the metrics code
# is running in mainline, and we have some nice monitoring frontends
@ -463,23 +488,23 @@ class Database(object):
sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
"""Starts a transaction on the database and runs a given function
Arguments:
desc (str): description of the transaction, for logging and metrics
func (func): callback function, which will be called with a
desc: description of the transaction, for logging and metrics
func: callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
args: positional args to pass to `func`
kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
"""
after_callbacks = []
exception_callbacks = []
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warning("Starting db txn '%s' from sentinel context", desc)
@ -505,15 +530,15 @@ class Database(object):
return result
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
func (func): callback function, which will be called with a
func: callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
args: positional args to pass to `func`
kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
@ -800,7 +825,7 @@ class Database(object):
return False
# We didn't find any existing rows, so insert a new one
allvalues = {}
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
@ -829,7 +854,7 @@ class Database(object):
Returns:
None
"""
allvalues = {}
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(insertion_values)
@ -916,7 +941,7 @@ class Database(object):
Returns:
None
"""
allnames = []
allnames = [] # type: List[str]
allnames.extend(key_names)
allnames.extend(value_names)
@ -1100,7 +1125,7 @@ class Database(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
results = []
results = [] # type: List[Dict[str, Any]]
if not iterable:
return results
@ -1439,7 +1464,7 @@ class Database(object):
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues else ""
arg_list = []
arg_list = [] # type: List[Any]
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())

View file

@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import platform
from ._base import IncorrectDatabaseSetup
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
def create_engine(database_config):
def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
if name == "sqlite3":
import sqlite3
return Sqlite3Engine(sqlite3, database_config)
if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2
if name == "psycopg2" and platform.python_implementation() == "PyPy":
name = "psycopg2cffi"
module = importlib.import_module(name)
return engine_class(module, database_config)
if platform.python_implementation() == "PyPy":
import psycopg2cffi as psycopg2 # type: ignore
else:
import psycopg2 # type: ignore
return PostgresEngine(psycopg2, database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
__all__ = ["create_engine", "IncorrectDatabaseSetup"]
__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]

View file

@ -12,7 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Generic, TypeVar
from synapse.storage.types import Connection
class IncorrectDatabaseSetup(RuntimeError):
pass
ConnectionType = TypeVar("ConnectionType", bound=Connection)
class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def __init__(self, module, database_config: dict):
self.module = module
@property
@abc.abstractmethod
def single_threaded(self) -> bool:
...
@property
@abc.abstractmethod
def can_native_upsert(self) -> bool:
"""
Do we support native UPSERTs?
"""
...
@property
@abc.abstractmethod
def supports_tuple_comparison(self) -> bool:
"""
Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
"""
...
@property
@abc.abstractmethod
def supports_using_any_list(self) -> bool:
"""
Do we support using `a = ANY(?)` and passing a list
"""
...
@abc.abstractmethod
def check_database(
self, db_conn: ConnectionType, allow_outdated_version: bool = False
) -> None:
...
@abc.abstractmethod
def check_new_database(self, txn) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
...
@abc.abstractmethod
def convert_param_style(self, sql: str) -> str:
...
@abc.abstractmethod
def on_new_connection(self, db_conn: ConnectionType) -> None:
...
@abc.abstractmethod
def is_deadlock(self, error: Exception) -> bool:
...
@abc.abstractmethod
def is_connection_closed(self, conn: ConnectionType) -> bool:
...
@abc.abstractmethod
def lock_table(self, txn, table: str) -> None:
...
@abc.abstractmethod
def get_next_state_group_id(self, txn) -> int:
"""Returns an int that can be used as a new state_group ID
"""
...
@property
@abc.abstractmethod
def server_version(self) -> str:
"""Gets a string giving the server version. For example: '3.22.0'
"""
...

View file

@ -15,16 +15,14 @@
import logging
from ._base import IncorrectDatabaseSetup
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
logger = logging.getLogger(__name__)
class PostgresEngine(object):
single_threaded = False
class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
self.module = database_module
super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
@ -36,6 +34,10 @@ class PostgresEngine(object):
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
@property
def single_threaded(self) -> bool:
return False
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and

View file

@ -12,16 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
import struct
import threading
from synapse.storage.engines import BaseDatabaseEngine
class Sqlite3Engine(object):
single_threaded = True
class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
def __init__(self, database_module, database_config):
self.module = database_module
super().__init__(database_module, database_config)
database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (None, ":memory:",)
@ -31,6 +31,10 @@ class Sqlite3Engine(object):
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
@property
def single_threaded(self) -> bool:
return True
@property
def can_native_upsert(self):
"""
@ -68,7 +72,6 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
# We need to import here to avoid an import loop.
from synapse.storage.prepare_database import prepare_database

65
synapse/storage/types.py Normal file
View file

@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Iterator, List, Tuple
from typing_extensions import Protocol
"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
class Cursor(Protocol):
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
...
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
...
def fetchall(self) -> List[Tuple]:
...
def fetchone(self) -> Tuple:
...
@property
def description(self) -> Any:
return None
@property
def rowcount(self) -> int:
return 0
def __iter__(self) -> Iterator[Tuple]:
...
def close(self) -> None:
...
class Connection(Protocol):
def cursor(self) -> Cursor:
...
def close(self) -> None:
...
def commit(self) -> None:
...
def rollback(self, *args, **kwargs) -> None:
...

View file

@ -23,7 +23,7 @@ import attr
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
# define a version of typing.Collection that works on python 3.5
if sys.version_info[:3] >= (3, 6, 0):
@ -166,11 +166,13 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
return self
@classmethod
def from_string(cls, s):
def from_string(cls, s: str):
"""Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(
400, "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL)
400,
"Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL),
Codes.INVALID_PARAM,
)
parts = s[1:].split(":", 1)
@ -179,6 +181,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
400,
"Expected %s of the form '%slocalname:domain'"
% (cls.__name__, cls.SIGIL),
Codes.INVALID_PARAM,
)
domain = parts[1]
@ -235,11 +238,13 @@ class GroupID(DomainSpecificString):
def from_string(cls, s):
group_id = super(GroupID, cls).from_string(s)
if not group_id.localpart:
raise SynapseError(400, "Group ID cannot be empty")
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
if contains_invalid_mxid_characters(group_id.localpart):
raise SynapseError(
400, "Group ID can only contain characters a-z, 0-9, or '=_-./'"
400,
"Group ID can only contain characters a-z, 0-9, or '=_-./'",
Codes.INVALID_PARAM,
)
return group_id

View file

@ -119,6 +119,9 @@ def filter_events_for_client(
the original event if they can see it as normal.
"""
if event.type == "org.matrix.dummy_event":
return None
if not event.is_state() and event.sender in ignore_list:
return None

View file

@ -27,6 +27,11 @@ class FrontendProxyTests(HomeserverTestCase):
return hs
def default_config(self, name="test"):
c = super().default_config(name)
c["worker_app"] = "synapse.app.frontend_proxy"
return c
def test_listen_http_with_presence_enabled(self):
"""
When presence is on, the stub servlet will not register.

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.utils import (
copy_power_levels_contents,
@ -36,9 +37,9 @@ class PruneEventTestCase(unittest.TestCase):
""" Asserts that a new event constructed with `evdict` will look like
`matchdict` when it is redacted. """
def run_test(self, evdict, matchdict):
def run_test(self, evdict, matchdict, **kwargs):
self.assertEquals(
prune_event(make_event_from_dict(evdict)).get_dict(), matchdict
prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
)
def test_minimal(self):
@ -128,6 +129,36 @@ class PruneEventTestCase(unittest.TestCase):
},
)
def test_alias_event(self):
"""Alias events have special behavior up through room version 6."""
self.run_test(
{
"type": "m.room.aliases",
"event_id": "$test:domain",
"content": {"aliases": ["test"]},
},
{
"type": "m.room.aliases",
"event_id": "$test:domain",
"content": {"aliases": ["test"]},
"signatures": {},
"unsigned": {},
},
)
def test_msc2432_alias_event(self):
"""After MSC2432, alias events have no special behavior."""
self.run_test(
{"type": "m.room.aliases", "content": {"aliases": ["test"]}},
{
"type": "m.room.aliases",
"content": {},
"signatures": {},
"unsigned": {},
},
room_version=RoomVersions.MSC2432_DEV,
)
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):

View file

@ -18,6 +18,7 @@ from mock import Mock
from twisted.internet import defer
import synapse
import synapse.api.errors
from synapse.api.constants import EventTypes
from synapse.config.room_directory import RoomDirectoryConfig
@ -87,38 +88,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
def test_delete_alias_not_allowed(self):
room_id = "!8765qwer:test"
self.get_success(
self.store.create_room_alias_association(self.my_room, room_id, ["test"])
)
self.get_failure(
self.handler.delete_association(
create_requester("@user:test"), self.my_room
),
synapse.api.errors.AuthError,
)
def test_delete_alias(self):
room_id = "!8765qwer:test"
user_id = "@user:test"
self.get_success(
self.store.create_room_alias_association(
self.my_room, room_id, ["test"], user_id
)
)
result = self.get_success(
self.handler.delete_association(create_requester(user_id), self.my_room)
)
self.assertEquals(room_id, result)
# The alias should not be found.
self.get_failure(
self.handler.get_association(self.my_room), synapse.api.errors.SynapseError
)
def test_incoming_fed_query(self):
self.get_success(
self.store.create_room_alias_association(
@ -133,6 +102,119 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
class TestDeleteAlias(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
directory.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.handler = hs.get_handlers().directory_handler
self.state_handler = hs.get_state_handler()
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
# Create a test room
self.room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
)
self.test_alias = "#test:test"
self.room_alias = RoomAlias.from_string(self.test_alias)
# Create a test user.
self.test_user = self.register_user("user", "pass", admin=False)
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
def _create_alias(self, user):
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
self.room_alias, self.room_id, ["test"], user
)
)
def test_delete_alias_not_allowed(self):
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
self._create_alias(self.admin_user)
self.get_failure(
self.handler.delete_association(
create_requester(self.test_user), self.room_alias
),
synapse.api.errors.AuthError,
)
def test_delete_alias_creator(self):
"""An alias creator can delete their own alias."""
# Create an alias from a different user.
self._create_alias(self.test_user)
# Delete the user's alias.
result = self.get_success(
self.handler.delete_association(
create_requester(self.test_user), self.room_alias
)
)
self.assertEquals(self.room_id, result)
# Confirm the alias is gone.
self.get_failure(
self.handler.get_association(self.room_alias),
synapse.api.errors.SynapseError,
)
def test_delete_alias_admin(self):
"""A server admin can delete an alias created by another user."""
# Create an alias from a different user.
self._create_alias(self.test_user)
# Delete the user's alias as the admin.
result = self.get_success(
self.handler.delete_association(
create_requester(self.admin_user), self.room_alias
)
)
self.assertEquals(self.room_id, result)
# Confirm the alias is gone.
self.get_failure(
self.handler.get_association(self.room_alias),
synapse.api.errors.SynapseError,
)
def test_delete_alias_sufficient_power(self):
"""A user with a sufficient power level should be able to delete an alias."""
self._create_alias(self.admin_user)
# Increase the user's power level.
self.helper.send_state(
self.room_id,
"m.room.power_levels",
{"users": {self.test_user: 100}},
tok=self.admin_user_tok,
)
# They can now delete the alias.
result = self.get_success(
self.handler.delete_association(
create_requester(self.test_user), self.room_alias
)
)
self.assertEquals(self.room_id, result)
# Confirm the alias is gone.
self.get_failure(
self.handler.get_association(self.room_alias),
synapse.api.errors.SynapseError,
)
class CanonicalAliasTestCase(unittest.HomeserverTestCase):
"""Test modifications of the canonical alias when delete aliases.
"""
@ -159,30 +241,42 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
self.test_alias = "#test:test"
self.room_alias = RoomAlias.from_string(self.test_alias)
self.room_alias = self._add_alias(self.test_alias)
def _add_alias(self, alias: str) -> RoomAlias:
"""Add an alias to the test room."""
room_alias = RoomAlias.from_string(alias)
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
self.room_alias, self.room_id, ["test"], self.admin_user
room_alias, self.room_id, ["test"], self.admin_user
)
)
return room_alias
def _set_canonical_alias(self, content):
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
)
def _get_canonical_alias(self):
"""Get the canonical alias state of the room."""
return self.get_success(
self.state_handler.get_current_state(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
def test_remove_alias(self):
"""Removing an alias that is the canonical alias should remove it there too."""
# Set this new alias as the canonical alias for this room
self.helper.send_state(
self.room_id,
"m.room.canonical_alias",
{"alias": self.test_alias, "alt_aliases": [self.test_alias]},
tok=self.admin_user_tok,
self._set_canonical_alias(
{"alias": self.test_alias, "alt_aliases": [self.test_alias]}
)
data = self.get_success(
self.state_handler.get_current_state(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias)
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
@ -193,11 +287,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
)
data = self.get_success(
self.state_handler.get_current_state(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
data = self._get_canonical_alias()
self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"])
@ -205,29 +295,17 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
"""Removing an alias listed as in alt_aliases should remove it there too."""
# Create a second alias.
other_test_alias = "#test2:test"
other_room_alias = RoomAlias.from_string(other_test_alias)
self.get_success(
self.store.create_room_alias_association(
other_room_alias, self.room_id, ["test"], self.admin_user
)
)
other_room_alias = self._add_alias(other_test_alias)
# Set the alias as the canonical alias for this room.
self.helper.send_state(
self.room_id,
"m.room.canonical_alias",
self._set_canonical_alias(
{
"alias": self.test_alias,
"alt_aliases": [self.test_alias, other_test_alias],
},
tok=self.admin_user_tok,
}
)
data = self.get_success(
self.state_handler.get_current_state(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias)
self.assertEqual(
data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
@ -240,11 +318,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
)
data = self.get_success(
self.state_handler.get_current_state(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias)
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])

View file

@ -15,6 +15,7 @@ import logging
from canonicaljson import encode_canonical_json
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
@ -58,6 +59,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
return super(SlavedEventStoreTestCase, self).setUp()
def prepare(self, *args, **kwargs):
super().prepare(*args, **kwargs)
self.get_success(
self.master_store.store_room(
ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1,
)
)
def tearDown(self):
[unpatch() for unpatch in self.unpatches]

View file

@ -16,6 +16,7 @@
import hashlib
import hmac
import json
import urllib.parse
from mock import Mock
@ -371,22 +372,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.url = "/_synapse/admin/v2/users/@bob:test"
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.other_user_token = self.login("user", "pass")
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user
)
def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
request, channel = self.make_request(
"GET", self.url, access_token=self.other_user_token,
"GET", url, access_token=self.other_user_token,
)
self.render(request)
@ -394,7 +397,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("You are not a server admin", channel.json_body["error"])
request, channel = self.make_request(
"PUT", self.url, access_token=self.other_user_token, content=b"{}",
"PUT", url, access_token=self.other_user_token, content=b"{}",
)
self.render(request)
@ -417,24 +420,26 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
def test_requester_is_admin(self):
def test_create_server_admin(self):
"""
If the user is a server admin, a new user is created.
Check that a new admin user is created successfully.
"""
self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user (server admin)
body = json.dumps(
{
"password": "abc123",
"admin": True,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
}
)
# Create user
request, channel = self.make_request(
"PUT",
self.url,
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
@ -442,29 +447,85 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"])
# Get user
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
self.assertEqual(1, channel.json_body["admin"])
self.assertEqual(0, channel.json_body["is_guest"])
self.assertEqual(0, channel.json_body["deactivated"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
def test_create_user(self):
"""
Check that a new regular user is created successfully.
"""
self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
body = json.dumps(
{
"password": "abc123",
"admin": False,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
}
)
request, channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"])
# Get user
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"])
def test_set_password(self):
"""
Test setting a new password for another user.
"""
self.hs.config.registration_shared_secret = None
# Change password
body = json.dumps({"password": "hahaha"})
request, channel = self.make_request(
"PUT",
self.url,
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
@ -472,41 +533,133 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def test_set_displayname(self):
"""
Test setting the displayname of another user.
"""
self.hs.config.registration_shared_secret = None
# Modify user
body = json.dumps({"displayname": "foobar"})
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
# Get user
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
def test_set_threepid(self):
"""
Test setting threepid for an other user.
"""
self.hs.config.registration_shared_secret = None
# Delete old and add new threepid to user
body = json.dumps(
{
"displayname": "foobar",
"deactivated": True,
"threepids": [{"medium": "email", "address": "bob2@bob.bob"}],
}
{"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
)
request, channel = self.make_request(
"PUT",
self.url,
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
# Get user
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
def test_deactivate_user(self):
"""
Test deactivating another user.
"""
# Deactivate user
body = json.dumps({"deactivated": True})
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
# the user is deactivated, the threepid will be deleted
# Get user
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
self.assertEqual(1, channel.json_body["admin"])
self.assertEqual(0, channel.json_body["is_guest"])
self.assertEqual(1, channel.json_body["deactivated"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
def test_set_user_as_admin(self):
"""
Test setting the admin flag on a user.
"""
self.hs.config.registration_shared_secret = None
# Set a user as an admin
body = json.dumps({"admin": True})
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["admin"])
# Get user
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["admin"])
def test_accidental_deactivation_prevention(self):
"""
@ -514,13 +667,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
for the deactivated body parameter
"""
self.hs.config.registration_shared_secret = None
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
body = json.dumps({"password": "abc123"})
request, channel = self.make_request(
"PUT",
self.url,
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
@ -532,7 +686,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Get user
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
@ -546,7 +700,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT",
self.url,
url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
@ -556,7 +710,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Check user is not deactivated
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)

View file

@ -1,4 +1,7 @@
import json
import urllib.parse
from mock import Mock
import synapse.rest.admin
from synapse.rest.client.v1 import login
@ -252,3 +255,111 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
self.base_url = "https://matrix.goodserver.com/"
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
config = self.default_config()
config["cas_config"] = {
"enabled": True,
"server_url": "https://fake.test",
"service_url": "https://matrix.goodserver.com:8448",
}
async def get_raw(uri, args):
"""Return an example response payload from a call to the `/proxyValidate`
endpoint of a CAS server, copied from
https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
This needs to be returned by an async function (as opposed to set as the
mock's return value) because the corresponding Synapse code awaits on it.
"""
return """
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
<cas:authenticationSuccess>
<cas:user>username</cas:user>
<cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
<cas:proxies>
<cas:proxy>https://proxy2/pgtUrl</cas:proxy>
<cas:proxy>https://proxy1/pgtUrl</cas:proxy>
</cas:proxies>
</cas:authenticationSuccess>
</cas:serviceResponse>
"""
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
self.hs = self.setup_test_homeserver(
config=config, proxied_http_client=mocked_http_client,
)
return self.hs
def test_cas_redirect_confirm(self):
"""Tests that the SSO login flow serves a confirmation page before redirecting a
user to the redirect URL.
"""
base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl"
redirect_url = "https://dodgy-site.com/"
url_parts = list(urllib.parse.urlparse(base_url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"redirectUrl": redirect_url})
query.update({"ticket": "ticket"})
url_parts[4] = urllib.parse.urlencode(query)
cas_ticket_url = urllib.parse.urlunparse(url_parts)
# Get Synapse to call the fake CAS and serve the template.
request, channel = self.make_request("GET", cas_ticket_url)
self.render(request)
# Test that the response is HTML.
self.assertEqual(channel.code, 200)
content_type_header_value = ""
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
content_type_header_value = header[1].decode("utf8")
self.assertTrue(content_type_header_value.startswith("text/html"))
# Test that the body isn't empty.
self.assertTrue(len(channel.result["body"]) > 0)
# And that it contains our redirect link
self.assertIn(redirect_url, channel.result["body"].decode("UTF-8"))
@override_config(
{
"sso": {
"client_whitelist": [
"https://legit-site.com/",
"https://other-site.com/",
]
}
}
)
def test_cas_redirect_whitelisted(self):
"""Tests that the SSO login flow serves a redirect to a whitelisted url
"""
redirect_url = "https://legit-site.com/"
cas_ticket_url = (
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
% (urllib.parse.quote(redirect_url))
)
# Get Synapse to call the fake CAS and serve the template.
request, channel = self.make_request("GET", cas_ticket_url)
self.render(request)
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)

View file

@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
self.room_id = self.helper.create_room_as(
self.room_owner, tok=self.room_owner_tok
)
self.alias = "#alias:test"
self._set_alias_via_directory(self.alias)
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
url = "/_matrix/client/r0/directory/room/" + alias
data = {"room_id": self.room_id}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
request, channel = self.make_request(
"GET",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
access_token=self.room_owner_tok,
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
return res
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
request, channel = self.make_request(
"PUT",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
json.dumps(content),
access_token=self.room_owner_tok,
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
return res
def test_canonical_alias(self):
"""Test a basic alias message."""
# There is no canonical alias to start with.
self._get_canonical_alias(expected_code=404)
# Create an alias.
self._set_canonical_alias({"alias": self.alias})
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias})
# Now remove the alias.
self._set_canonical_alias({})
# There is an alias event, but it is empty.
res = self._get_canonical_alias()
self.assertEqual(res, {})
def test_alt_aliases(self):
"""Test a canonical alias message with alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alt_aliases": [self.alias]})
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(res, {"alt_aliases": [self.alias]})
# Now remove the alt_aliases.
self._set_canonical_alias({})
# There is an alias event, but it is empty.
res = self._get_canonical_alias()
self.assertEqual(res, {})
def test_alias_alt_aliases(self):
"""Test a canonical alias message with an alias and alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
# Now remove the alias and alt_aliases.
self._set_canonical_alias({})
# There is an alias event, but it is empty.
res = self._get_canonical_alias()
self.assertEqual(res, {})
def test_partial_modify(self):
"""Test removing only the alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
# Now remove the alt_aliases.
self._set_canonical_alias({"alias": self.alias})
# There is an alias event, but it is empty.
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias})
def test_add_alias(self):
"""Test removing only the alt_aliases."""
# Create an additional alias.
second_alias = "#second:test"
self._set_alias_via_directory(second_alias)
# Add the canonical alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
# Then add the second alias.
self._set_canonical_alias(
{"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
)
# Canonical alias now exists!
res = self._get_canonical_alias()
self.assertEqual(
res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
)
def test_bad_data(self):
"""Invalid data for alt_aliases should cause errors."""
self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
def test_bad_alias(self):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)

View file

@ -303,3 +303,45 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.pump()
self.store.upsert_monthly_active_user.assert_not_called()
def test_get_monthly_active_count_by_service(self):
appservice1_user1 = "@appservice1_user1:example.com"
appservice1_user2 = "@appservice1_user2:example.com"
appservice2_user1 = "@appservice2_user1:example.com"
native_user1 = "@native_user1:example.com"
service1 = "service1"
service2 = "service2"
native = "native"
self.store.register_user(
user_id=appservice1_user1, password_hash=None, appservice_id=service1
)
self.store.register_user(
user_id=appservice1_user2, password_hash=None, appservice_id=service1
)
self.store.register_user(
user_id=appservice2_user1, password_hash=None, appservice_id=service2
)
self.store.register_user(user_id=native_user1, password_hash=None)
self.pump()
count = self.store.get_monthly_active_count_by_service()
self.assertEqual({}, self.get_success(count))
self.store.upsert_monthly_active_user(native_user1)
self.store.upsert_monthly_active_user(appservice1_user1)
self.store.upsert_monthly_active_user(appservice1_user2)
self.store.upsert_monthly_active_user(appservice2_user1)
self.pump()
count = self.store.get_monthly_active_count()
self.assertEqual(4, self.get_success(count))
count = self.store.get_monthly_active_count_by_service()
result = self.get_success(count)
self.assertEqual(2, result[service1])
self.assertEqual(1, result[service2])
self.assertEqual(1, result[native])

View file

@ -19,6 +19,7 @@ from synapse import event_auth
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.types import get_domain_from_id
class EventAuthTestCase(unittest.TestCase):
@ -51,7 +52,7 @@ class EventAuthTestCase(unittest.TestCase):
_random_state_event(joiner),
auth_events,
do_sig_check=False,
),
)
def test_state_default_level(self):
"""
@ -87,6 +88,83 @@ class EventAuthTestCase(unittest.TestCase):
RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False,
)
def test_alias_event(self):
"""Alias events have special behavior up through room version 6."""
creator = "@creator:example.com"
other = "@other:example.com"
auth_events = {
("m.room.create", ""): _create_event(creator),
("m.room.member", creator): _join_event(creator),
}
# creator should be able to send aliases
event_auth.check(
RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False,
)
# Reject an event with no state key.
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.V1,
_alias_event(creator, state_key=""),
auth_events,
do_sig_check=False,
)
# If the domain of the sender does not match the state key, reject.
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.V1,
_alias_event(creator, state_key="test.com"),
auth_events,
do_sig_check=False,
)
# Note that the member does *not* need to be in the room.
event_auth.check(
RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False,
)
def test_msc2432_alias_event(self):
"""After MSC2432, alias events have no special behavior."""
creator = "@creator:example.com"
other = "@other:example.com"
auth_events = {
("m.room.create", ""): _create_event(creator),
("m.room.member", creator): _join_event(creator),
}
# creator should be able to send aliases
event_auth.check(
RoomVersions.MSC2432_DEV,
_alias_event(creator),
auth_events,
do_sig_check=False,
)
# No particular checks are done on the state key.
event_auth.check(
RoomVersions.MSC2432_DEV,
_alias_event(creator, state_key=""),
auth_events,
do_sig_check=False,
)
event_auth.check(
RoomVersions.MSC2432_DEV,
_alias_event(creator, state_key="test.com"),
auth_events,
do_sig_check=False,
)
# Per standard auth rules, the member must be in the room.
with self.assertRaises(AuthError):
event_auth.check(
RoomVersions.MSC2432_DEV,
_alias_event(other),
auth_events,
do_sig_check=False,
)
# helpers for making events
@ -131,6 +209,19 @@ def _power_levels_event(sender, content):
)
def _alias_event(sender, **kwargs):
data = {
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
"type": "m.room.aliases",
"sender": sender,
"state_key": get_domain_from_id(sender),
"content": {"aliases": []},
}
data.update(**kwargs)
return make_event_from_dict(data)
def _random_state_event(sender):
return make_event_from_dict(
{

View file

@ -75,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase):
self.fail("Parsing '%s' should raise exception" % id_string)
except SynapseError as exc:
self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode)
self.assertEqual("M_INVALID_PARAM", exc.errcode)
class MapUsernameTestCase(unittest.TestCase):

View file

@ -168,7 +168,6 @@ commands=
coverage html
[testenv:mypy]
basepython = python3.7
skip_install = True
deps =
{[base]deps}
@ -179,10 +178,14 @@ env =
extras = all
commands = mypy \
synapse/api \
synapse/config/ \
synapse/appservice \
synapse/config \
synapse/events/spamcheck.py \
synapse/federation/federation_base.py \
synapse/federation/federation_client.py \
synapse/federation/sender \
synapse/federation/transport \
synapse/handlers/directory.py \
synapse/handlers/presence.py \
synapse/handlers/sync.py \
synapse/handlers/ui_auth \
@ -192,6 +195,7 @@ commands = mypy \
synapse/rest \
synapse/spam_checker_api \
synapse/storage/engines \
synapse/storage/database.py \
synapse/streams
# To find all folders that pass mypy you run: