Merge branch 'develop' into matrix-org-hotfixes
This commit is contained in:
commit
74050d0c1c
|
@ -6,12 +6,7 @@
|
|||
set -ex
|
||||
|
||||
apt-get update
|
||||
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev
|
||||
|
||||
# workaround for https://github.com/jaraco/zipp/issues/40
|
||||
python3.5 -m pip install 'setuptools>=34.4.0'
|
||||
|
||||
python3.5 -m pip install tox
|
||||
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
|
||||
|
||||
export LANG="C.UTF-8"
|
||||
|
||||
|
|
15
CHANGES.md
15
CHANGES.md
|
@ -1,3 +1,18 @@
|
|||
Synapse 1.11.1 (2020-03-03)
|
||||
===========================
|
||||
|
||||
This release includes a security fix impacting installations using Single Sign-On (i.e. SAML2 or CAS) for authentication. Administrators of such installations are encouraged to upgrade as soon as possible.
|
||||
|
||||
The release also includes fixes for a couple of other bugs.
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Add a confirmation step to the SSO login flow before redirecting users to the redirect URL. ([b2bd54a2](https://github.com/matrix-org/synapse/commit/b2bd54a2e31d9a248f73fadb184ae9b4cbdb49f9), [65c73cdf](https://github.com/matrix-org/synapse/commit/65c73cdfec1876a9fec2fd2c3a74923cd146fe0b), [a0178df1](https://github.com/matrix-org/synapse/commit/a0178df10422a76fd403b82d2b2a4ed28a9a9d1e))
|
||||
- Fixed set a user as an admin with the admin API `PUT /_synapse/admin/v2/users/<user_id>`. Contributed by @dklimpel. ([\#6910](https://github.com/matrix-org/synapse/issues/6910))
|
||||
- Fix bug introduced in Synapse 1.11.0 which sometimes caused errors when joining rooms over federation, with `'coroutine' object has no attribute 'event_id'`. ([\#6996](https://github.com/matrix-org/synapse/issues/6996))
|
||||
|
||||
|
||||
Synapse 1.11.0 (2020-02-21)
|
||||
===========================
|
||||
|
||||
|
|
|
@ -418,7 +418,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
|
|||
for having Synapse automatically provision and renew federation
|
||||
certificates through ACME can be found at [ACME.md](docs/ACME.md).
|
||||
Note that, as pointed out in that document, this feature will not
|
||||
work with installs set up after November 2020.
|
||||
work with installs set up after November 2019.
|
||||
|
||||
If you are using your own certificate, be sure to use a `.pem` file that
|
||||
includes the full certificate chain including any intermediate certificates
|
||||
|
|
1
changelog.d/6309.misc
Normal file
1
changelog.d/6309.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add type hints to `logging/context.py`.
|
1
changelog.d/6315.feature
Normal file
1
changelog.d/6315.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Expose the `synctl`, `hash_password` and `generate_config` commands in the snapcraft package. Contributed by @devec0.
|
1
changelog.d/6874.misc
Normal file
1
changelog.d/6874.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactoring work in preparation for changing the event redaction algorithm.
|
1
changelog.d/6875.misc
Normal file
1
changelog.d/6875.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactoring work in preparation for changing the event redaction algorithm.
|
1
changelog.d/6971.feature
Normal file
1
changelog.d/6971.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Validate the alt_aliases property of canonical alias events.
|
1
changelog.d/6986.feature
Normal file
1
changelog.d/6986.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Users with a power level sufficient to modify the canonical alias of a room can now delete room aliases.
|
1
changelog.d/6987.misc
Normal file
1
changelog.d/6987.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add some type annotations to the database storage classes.
|
1
changelog.d/6995.misc
Normal file
1
changelog.d/6995.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add some type annotations to the federation base & client classes.
|
1
changelog.d/7002.misc
Normal file
1
changelog.d/7002.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Merge worker apps together.
|
1
changelog.d/7003.misc
Normal file
1
changelog.d/7003.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactoring work in preparation for changing the event redaction algorithm.
|
1
changelog.d/7015.misc
Normal file
1
changelog.d/7015.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Change date in INSTALL.md#tls-certificates for last date of getting TLS certificates to November 2019.
|
1
changelog.d/7018.bugfix
Normal file
1
changelog.d/7018.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix py35-old CI by using native tox package.
|
1
changelog.d/7019.misc
Normal file
1
changelog.d/7019.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Port `synapse.handlers.presence` to async/await.
|
1
changelog.d/7020.misc
Normal file
1
changelog.d/7020.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Port `synapse.rest.keys` to async/await.
|
1
changelog.d/7030.feature
Normal file
1
changelog.d/7030.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Break down monthly active users by `appservice_id` and emit via Prometheus.
|
1
changelog.d/7035.bugfix
Normal file
1
changelog.d/7035.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a bug causing `org.matrix.dummy_event` to be included in responses from `/sync`.
|
1
changelog.d/7037.feature
Normal file
1
changelog.d/7037.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Implement updated authorization rules and redaction rules for aliases events, from [MSC2261](https://github.com/matrix-org/matrix-doc/pull/2261) and [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432).
|
1
changelog.d/7045.misc
Normal file
1
changelog.d/7045.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add a type check to `is_verified` when processing room keys.
|
1
changelog.d/7048.doc
Normal file
1
changelog.d/7048.doc
Normal file
|
@ -0,0 +1 @@
|
|||
Document that the fallback auth endpoints must be routed to the same worker node as the register endpoints.
|
1
changelog.d/7055.misc
Normal file
1
changelog.d/7055.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Merge worker apps together.
|
|
@ -15,10 +15,9 @@ services:
|
|||
restart: unless-stopped
|
||||
# See the readme for a full documentation of the environment settings
|
||||
environment:
|
||||
- SYNAPSE_CONFIG_PATH=/etc/homeserver.yaml
|
||||
- SYNAPSE_CONFIG_PATH=/data/homeserver.yaml
|
||||
volumes:
|
||||
# You may either store all the files in a local folder
|
||||
- ./matrix-config/homeserver.yaml:/etc/homeserver.yaml
|
||||
- ./files:/data
|
||||
# .. or you may split this between different storage points
|
||||
# - ./files:/data
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Using the Synapse Grafana dashboard
|
||||
|
||||
0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/
|
||||
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst
|
||||
1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
|
||||
2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/
|
||||
3. Set up additional recording rules
|
||||
|
|
6
debian/changelog
vendored
6
debian/changelog
vendored
|
@ -1,3 +1,9 @@
|
|||
matrix-synapse-py3 (1.11.1) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.11.1.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Tue, 03 Mar 2020 15:01:22 +0000
|
||||
|
||||
matrix-synapse-py3 (1.11.0) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.11.0.
|
||||
|
|
|
@ -1360,6 +1360,56 @@ saml2_config:
|
|||
# # name: value
|
||||
|
||||
|
||||
# Additional settings to use with single-sign on systems such as SAML2 and CAS.
|
||||
#
|
||||
sso:
|
||||
# A list of client URLs which are whitelisted so that the user does not
|
||||
# have to confirm giving access to their account to the URL. Any client
|
||||
# whose URL starts with an entry in the following list will not be subject
|
||||
# to an additional confirmation step after the SSO login is completed.
|
||||
#
|
||||
# WARNING: An entry such as "https://my.client" is insecure, because it
|
||||
# will also match "https://my.client.evil.site", exposing your users to
|
||||
# phishing attacks from evil.site. To avoid this, include a slash after the
|
||||
# hostname: "https://my.client/".
|
||||
#
|
||||
# By default, this list is empty.
|
||||
#
|
||||
#client_whitelist:
|
||||
# - https://riot.im/develop
|
||||
# - https://my.custom.client/
|
||||
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
|
||||
# If you *do* uncomment it, you will need to make sure that all the templates
|
||||
# below are in the directory.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
# * HTML page for a confirmation step before redirecting back to the client
|
||||
# with the login token: 'sso_redirect_confirm.html'.
|
||||
#
|
||||
# When rendering, this template is given three variables:
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
#
|
||||
# * display_url: the same as `redirect_url`, but with the query
|
||||
# parameters stripped. The intention is to have a
|
||||
# human-readable URL to show to users, not to use it as
|
||||
# the final address to redirect to. Needs manual escaping
|
||||
# (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
# You can see the default templates at:
|
||||
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
|
||||
#
|
||||
#template_dir: "res/templates"
|
||||
|
||||
|
||||
# The JWT needs to contain a globally unique "sub" (subject) claim.
|
||||
#
|
||||
#jwt_config:
|
||||
|
|
|
@ -273,6 +273,7 @@ Additionally, the following REST endpoints can be handled, but all requests must
|
|||
be routed to the same instance:
|
||||
|
||||
^/_matrix/client/(r0|unstable)/register$
|
||||
^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$
|
||||
|
||||
Pagination requests can also be handled, but all requests with the same path
|
||||
room must be routed to the same instance. Additionally, care must be taken to
|
||||
|
|
|
@ -1,20 +1,31 @@
|
|||
name: matrix-synapse
|
||||
base: core18
|
||||
version: git
|
||||
version: git
|
||||
summary: Reference Matrix homeserver
|
||||
description: |
|
||||
Synapse is the reference Matrix homeserver.
|
||||
Matrix is a federated and decentralised instant messaging and VoIP system.
|
||||
|
||||
grade: stable
|
||||
confinement: strict
|
||||
grade: stable
|
||||
confinement: strict
|
||||
|
||||
apps:
|
||||
matrix-synapse:
|
||||
matrix-synapse:
|
||||
command: synctl --no-daemonize start $SNAP_COMMON/homeserver.yaml
|
||||
stop-command: synctl -c $SNAP_COMMON stop
|
||||
plugs: [network-bind, network]
|
||||
daemon: simple
|
||||
daemon: simple
|
||||
hash-password:
|
||||
command: hash_password
|
||||
generate-config:
|
||||
command: generate_config
|
||||
generate-signing-key:
|
||||
command: generate_signing_key.py
|
||||
register-new-matrix-user:
|
||||
command: register_new_matrix_user
|
||||
plugs: [network]
|
||||
synctl:
|
||||
command: synctl
|
||||
parts:
|
||||
matrix-synapse:
|
||||
source: .
|
||||
|
|
|
@ -36,7 +36,7 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "1.11.0"
|
||||
__version__ = "1.11.1"
|
||||
|
||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||
# We import here so that we don't have to install a bunch of deps when
|
||||
|
|
|
@ -539,7 +539,7 @@ class Auth(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def check_can_change_room_list(self, room_id: str, user: UserID):
|
||||
"""Check if the user is allowed to edit the room's entry in the
|
||||
"""Determine whether the user is allowed to edit the room's entry in the
|
||||
published room list.
|
||||
|
||||
Args:
|
||||
|
@ -570,12 +570,7 @@ class Auth(object):
|
|||
)
|
||||
user_level = event_auth.get_user_power_level(user_id, auth_events)
|
||||
|
||||
if user_level < send_level:
|
||||
raise AuthError(
|
||||
403,
|
||||
"This server requires you to be a moderator in the room to"
|
||||
" edit its room list entry",
|
||||
)
|
||||
return user_level >= send_level
|
||||
|
||||
@staticmethod
|
||||
def has_access_token(request):
|
||||
|
|
|
@ -66,6 +66,7 @@ class Codes(object):
|
|||
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
|
||||
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
||||
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
||||
BAD_ALIAS = "M_BAD_ALIAS"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
|
|
|
@ -57,7 +57,7 @@ class RoomVersion(object):
|
|||
state_res = attr.ib() # int; one of the StateResolutionVersions
|
||||
enforce_key_validity = attr.ib() # bool
|
||||
|
||||
# bool: before MSC2260, anyone was allowed to send an aliases event
|
||||
# bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
|
||||
special_case_aliases_auth = attr.ib(type=bool, default=False)
|
||||
|
||||
|
||||
|
@ -102,12 +102,13 @@ class RoomVersions(object):
|
|||
enforce_key_validity=True,
|
||||
special_case_aliases_auth=True,
|
||||
)
|
||||
MSC2260_DEV = RoomVersion(
|
||||
"org.matrix.msc2260",
|
||||
MSC2432_DEV = RoomVersion(
|
||||
"org.matrix.msc2432",
|
||||
RoomDisposition.UNSTABLE,
|
||||
EventFormatVersions.V3,
|
||||
StateResolutionVersions.V2,
|
||||
enforce_key_validity=True,
|
||||
special_case_aliases_auth=False,
|
||||
)
|
||||
|
||||
|
||||
|
@ -119,6 +120,6 @@ KNOWN_ROOM_VERSIONS = {
|
|||
RoomVersions.V3,
|
||||
RoomVersions.V4,
|
||||
RoomVersions.V5,
|
||||
RoomVersions.MSC2260_DEV,
|
||||
RoomVersions.MSC2432_DEV,
|
||||
)
|
||||
} # type: Dict[str, RoomVersion]
|
||||
|
|
|
@ -494,20 +494,26 @@ class GenericWorkerServer(HomeServer):
|
|||
elif name == "federation":
|
||||
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
|
||||
elif name == "media":
|
||||
media_repo = self.get_media_repository_resource()
|
||||
if self.config.can_load_media_repo:
|
||||
media_repo = self.get_media_repository_resource()
|
||||
|
||||
# We need to serve the admin servlets for media on the
|
||||
# worker.
|
||||
admin_resource = JsonResource(self, canonical_json=False)
|
||||
register_servlets_for_media_repo(self, admin_resource)
|
||||
# We need to serve the admin servlets for media on the
|
||||
# worker.
|
||||
admin_resource = JsonResource(self, canonical_json=False)
|
||||
register_servlets_for_media_repo(self, admin_resource)
|
||||
|
||||
resources.update(
|
||||
{
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
"/_synapse/admin": admin_resource,
|
||||
}
|
||||
)
|
||||
resources.update(
|
||||
{
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
"/_synapse/admin": admin_resource,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"A 'media' listener is configured but the media"
|
||||
" repository is disabled. Ignoring."
|
||||
)
|
||||
|
||||
if name == "openid" and "federation" not in res["names"]:
|
||||
# Only load the openid resource separately if federation resource
|
||||
|
|
|
@ -298,6 +298,11 @@ class SynapseHomeServer(HomeServer):
|
|||
|
||||
# Gauges to expose monthly active user control metrics
|
||||
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
|
||||
current_mau_by_service_gauge = Gauge(
|
||||
"synapse_admin_mau_current_mau_by_service",
|
||||
"Current MAU by service",
|
||||
["app_service"],
|
||||
)
|
||||
max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
|
||||
registered_reserved_users_mau_gauge = Gauge(
|
||||
"synapse_admin_mau:registered_reserved_users",
|
||||
|
@ -585,12 +590,20 @@ def run(hs):
|
|||
@defer.inlineCallbacks
|
||||
def generate_monthly_active_users():
|
||||
current_mau_count = 0
|
||||
current_mau_count_by_service = {}
|
||||
reserved_users = ()
|
||||
store = hs.get_datastore()
|
||||
if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
|
||||
current_mau_count = yield store.get_monthly_active_count()
|
||||
current_mau_count_by_service = (
|
||||
yield store.get_monthly_active_count_by_service()
|
||||
)
|
||||
reserved_users = yield store.get_registered_reserved_users()
|
||||
current_mau_gauge.set(float(current_mau_count))
|
||||
|
||||
for app_service, count in current_mau_count_by_service.items():
|
||||
current_mau_by_service_gauge.labels(app_service).set(float(count))
|
||||
|
||||
registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
|
||||
max_mau_gauge.set(float(hs.config.max_mau_value))
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from synapse.config import (
|
|||
server,
|
||||
server_notices_config,
|
||||
spam_checker,
|
||||
sso,
|
||||
stats,
|
||||
third_party_event_rules,
|
||||
tls,
|
||||
|
@ -57,6 +58,7 @@ class RootConfig:
|
|||
key: key.KeyConfig
|
||||
saml2: saml2_config.SAML2Config
|
||||
cas: cas.CasConfig
|
||||
sso: sso.SSOConfig
|
||||
jwt: jwt_config.JWTConfig
|
||||
password: password.PasswordConfig
|
||||
email: emailconfig.EmailConfig
|
||||
|
|
|
@ -38,6 +38,7 @@ from .saml2_config import SAML2Config
|
|||
from .server import ServerConfig
|
||||
from .server_notices_config import ServerNoticesConfig
|
||||
from .spam_checker import SpamCheckerConfig
|
||||
from .sso import SSOConfig
|
||||
from .stats import StatsConfig
|
||||
from .third_party_event_rules import ThirdPartyRulesConfig
|
||||
from .tls import TlsConfig
|
||||
|
@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig):
|
|||
KeyConfig,
|
||||
SAML2Config,
|
||||
CasConfig,
|
||||
SSOConfig,
|
||||
JWTConfig,
|
||||
PasswordConfig,
|
||||
EmailConfig,
|
||||
|
|
92
synapse/config/sso.py
Normal file
92
synapse/config/sso.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class SSOConfig(Config):
|
||||
"""SSO Configuration
|
||||
"""
|
||||
|
||||
section = "sso"
|
||||
|
||||
def read_config(self, config, **kwargs):
|
||||
sso_config = config.get("sso") or {} # type: Dict[str, Any]
|
||||
|
||||
# Pick a template directory in order of:
|
||||
# * The sso-specific template_dir
|
||||
# * /path/to/synapse/install/res/templates
|
||||
template_dir = sso_config.get("template_dir")
|
||||
if not template_dir:
|
||||
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
|
||||
|
||||
self.sso_redirect_confirm_template_dir = template_dir
|
||||
|
||||
self.sso_client_whitelist = sso_config.get("client_whitelist") or []
|
||||
|
||||
def generate_config_section(self, **kwargs):
|
||||
return """\
|
||||
# Additional settings to use with single-sign on systems such as SAML2 and CAS.
|
||||
#
|
||||
sso:
|
||||
# A list of client URLs which are whitelisted so that the user does not
|
||||
# have to confirm giving access to their account to the URL. Any client
|
||||
# whose URL starts with an entry in the following list will not be subject
|
||||
# to an additional confirmation step after the SSO login is completed.
|
||||
#
|
||||
# WARNING: An entry such as "https://my.client" is insecure, because it
|
||||
# will also match "https://my.client.evil.site", exposing your users to
|
||||
# phishing attacks from evil.site. To avoid this, include a slash after the
|
||||
# hostname: "https://my.client/".
|
||||
#
|
||||
# By default, this list is empty.
|
||||
#
|
||||
#client_whitelist:
|
||||
# - https://riot.im/develop
|
||||
# - https://my.custom.client/
|
||||
|
||||
# Directory in which Synapse will try to find the template files below.
|
||||
# If not set, default templates from within the Synapse package will be used.
|
||||
#
|
||||
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
|
||||
# If you *do* uncomment it, you will need to make sure that all the templates
|
||||
# below are in the directory.
|
||||
#
|
||||
# Synapse will look for the following templates in this directory:
|
||||
#
|
||||
# * HTML page for a confirmation step before redirecting back to the client
|
||||
# with the login token: 'sso_redirect_confirm.html'.
|
||||
#
|
||||
# When rendering, this template is given three variables:
|
||||
# * redirect_url: the URL the user is about to be redirected to. Needs
|
||||
# manual escaping (see
|
||||
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
#
|
||||
# * display_url: the same as `redirect_url`, but with the query
|
||||
# parameters stripped. The intention is to have a
|
||||
# human-readable URL to show to users, not to use it as
|
||||
# the final address to redirect to. Needs manual escaping
|
||||
# (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
|
||||
#
|
||||
# * server_name: the homeserver's name.
|
||||
#
|
||||
# You can see the default templates at:
|
||||
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
|
||||
#
|
||||
#template_dir: "res/templates"
|
||||
"""
|
|
@ -140,7 +140,7 @@ def compute_event_signature(
|
|||
Returns:
|
||||
a dictionary in the same format of an event's signatures field.
|
||||
"""
|
||||
redact_json = prune_event_dict(event_dict)
|
||||
redact_json = prune_event_dict(room_version, event_dict)
|
||||
redact_json.pop("age_ts", None)
|
||||
redact_json.pop("unsigned", None)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
|
|
|
@ -137,7 +137,7 @@ def check(
|
|||
raise AuthError(403, "This room has been marked as unfederatable.")
|
||||
|
||||
# 4. If type is m.room.aliases
|
||||
if event.type == EventTypes.Aliases:
|
||||
if event.type == EventTypes.Aliases and room_version_obj.special_case_aliases_auth:
|
||||
# 4a. If event has no state_key, reject
|
||||
if not event.is_state():
|
||||
raise AuthError(403, "Alias event must be a state event")
|
||||
|
@ -152,10 +152,8 @@ def check(
|
|||
)
|
||||
|
||||
# 4c. Otherwise, allow.
|
||||
# This is removed by https://github.com/matrix-org/matrix-doc/pull/2260
|
||||
if room_version_obj.special_case_aliases_auth:
|
||||
logger.debug("Allowing! %s", event)
|
||||
return
|
||||
logger.debug("Allowing! %s", event)
|
||||
return
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
|
||||
|
|
|
@ -15,9 +15,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from typing import Optional, Type
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
import six
|
||||
|
||||
|
@ -199,15 +200,25 @@ class _EventInternalMetadata(object):
|
|||
return self._dict.get("redacted", False)
|
||||
|
||||
|
||||
class EventBase(object):
|
||||
class EventBase(metaclass=abc.ABCMeta):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def format_version(self) -> int:
|
||||
"""The EventFormatVersion implemented by this event"""
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_dict,
|
||||
signatures={},
|
||||
unsigned={},
|
||||
internal_metadata_dict={},
|
||||
rejected_reason=None,
|
||||
event_dict: JsonDict,
|
||||
room_version: RoomVersion,
|
||||
signatures: Dict[str, Dict[str, str]],
|
||||
unsigned: JsonDict,
|
||||
internal_metadata_dict: JsonDict,
|
||||
rejected_reason: Optional[str],
|
||||
):
|
||||
assert room_version.event_format == self.format_version
|
||||
|
||||
self.room_version = room_version
|
||||
self.signatures = signatures
|
||||
self.unsigned = unsigned
|
||||
self.rejected_reason = rejected_reason
|
||||
|
@ -303,7 +314,13 @@ class EventBase(object):
|
|||
class FrozenEvent(EventBase):
|
||||
format_version = EventFormatVersions.V1 # All events of this type are V1
|
||||
|
||||
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
|
||||
def __init__(
|
||||
self,
|
||||
event_dict: JsonDict,
|
||||
room_version: RoomVersion,
|
||||
internal_metadata_dict: JsonDict = {},
|
||||
rejected_reason: Optional[str] = None,
|
||||
):
|
||||
event_dict = dict(event_dict)
|
||||
|
||||
# Signatures is a dict of dicts, and this is faster than doing a
|
||||
|
@ -326,8 +343,9 @@ class FrozenEvent(EventBase):
|
|||
|
||||
self._event_id = event_dict["event_id"]
|
||||
|
||||
super(FrozenEvent, self).__init__(
|
||||
super().__init__(
|
||||
frozen_dict,
|
||||
room_version=room_version,
|
||||
signatures=signatures,
|
||||
unsigned=unsigned,
|
||||
internal_metadata_dict=internal_metadata_dict,
|
||||
|
@ -352,7 +370,13 @@ class FrozenEvent(EventBase):
|
|||
class FrozenEventV2(EventBase):
|
||||
format_version = EventFormatVersions.V2 # All events of this type are V2
|
||||
|
||||
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
|
||||
def __init__(
|
||||
self,
|
||||
event_dict: JsonDict,
|
||||
room_version: RoomVersion,
|
||||
internal_metadata_dict: JsonDict = {},
|
||||
rejected_reason: Optional[str] = None,
|
||||
):
|
||||
event_dict = dict(event_dict)
|
||||
|
||||
# Signatures is a dict of dicts, and this is faster than doing a
|
||||
|
@ -377,8 +401,9 @@ class FrozenEventV2(EventBase):
|
|||
|
||||
self._event_id = None
|
||||
|
||||
super(FrozenEventV2, self).__init__(
|
||||
super().__init__(
|
||||
frozen_dict,
|
||||
room_version=room_version,
|
||||
signatures=signatures,
|
||||
unsigned=unsigned,
|
||||
internal_metadata_dict=internal_metadata_dict,
|
||||
|
@ -445,7 +470,7 @@ class FrozenEventV3(FrozenEventV2):
|
|||
return self._event_id
|
||||
|
||||
|
||||
def event_type_from_format_version(format_version: int) -> Type[EventBase]:
|
||||
def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
|
||||
"""Returns the python type to use to construct an Event object for the
|
||||
given event format version.
|
||||
|
||||
|
@ -474,5 +499,5 @@ def make_event_from_dict(
|
|||
rejected_reason: Optional[str] = None,
|
||||
) -> EventBase:
|
||||
"""Construct an EventBase from the given event dict"""
|
||||
event_type = event_type_from_format_version(room_version.event_format)
|
||||
return event_type(event_dict, internal_metadata_dict, rejected_reason)
|
||||
event_type = _event_type_from_format_version(room_version.event_format)
|
||||
return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason)
|
||||
|
|
|
@ -23,6 +23,7 @@ from frozendict import frozendict
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.util.async_helpers import yieldable_gather_results
|
||||
|
||||
from . import EventBase
|
||||
|
@ -35,26 +36,20 @@ from . import EventBase
|
|||
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
|
||||
|
||||
|
||||
def prune_event(event):
|
||||
def prune_event(event: EventBase) -> EventBase:
|
||||
""" Returns a pruned version of the given event, which removes all keys we
|
||||
don't know about or think could potentially be dodgy.
|
||||
|
||||
This is used when we "redact" an event. We want to remove all fields that
|
||||
the user has specified, but we do want to keep necessary information like
|
||||
type, state_key etc.
|
||||
|
||||
Args:
|
||||
event (FrozenEvent)
|
||||
|
||||
Returns:
|
||||
FrozenEvent
|
||||
"""
|
||||
pruned_event_dict = prune_event_dict(event.get_dict())
|
||||
pruned_event_dict = prune_event_dict(event.room_version, event.get_dict())
|
||||
|
||||
from . import event_type_from_format_version
|
||||
from . import make_event_from_dict
|
||||
|
||||
pruned_event = event_type_from_format_version(event.format_version)(
|
||||
pruned_event_dict, event.internal_metadata.get_dict()
|
||||
pruned_event = make_event_from_dict(
|
||||
pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
|
||||
)
|
||||
|
||||
# Mark the event as redacted
|
||||
|
@ -63,15 +58,12 @@ def prune_event(event):
|
|||
return pruned_event
|
||||
|
||||
|
||||
def prune_event_dict(event_dict):
|
||||
def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
|
||||
"""Redacts the event_dict in the same way as `prune_event`, except it
|
||||
operates on dicts rather than event objects
|
||||
|
||||
Args:
|
||||
event_dict (dict)
|
||||
|
||||
Returns:
|
||||
dict: A copy of the pruned event dict
|
||||
A copy of the pruned event dict
|
||||
"""
|
||||
|
||||
allowed_keys = [
|
||||
|
@ -118,7 +110,7 @@ def prune_event_dict(event_dict):
|
|||
"kick",
|
||||
"redact",
|
||||
)
|
||||
elif event_type == EventTypes.Aliases:
|
||||
elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
|
||||
add_fields("aliases")
|
||||
elif event_type == EventTypes.RoomHistoryVisibility:
|
||||
add_fields("history_visibility")
|
||||
|
|
|
@ -15,11 +15,13 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Iterable, List
|
||||
|
||||
import six
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import DeferredList
|
||||
from twisted.internet.defer import Deferred, DeferredList
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
|
@ -29,6 +31,7 @@ from synapse.api.room_versions import (
|
|||
RoomVersion,
|
||||
)
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.crypto.keyring import Keyring
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.http.servlet import assert_params_in_dict
|
||||
|
@ -56,7 +59,12 @@ class FederationBase(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash_and_fetch(
|
||||
self, origin, pdus, room_version, outlier=False, include_none=False
|
||||
self,
|
||||
origin: str,
|
||||
pdus: List[EventBase],
|
||||
room_version: str,
|
||||
outlier: bool = False,
|
||||
include_none: bool = False,
|
||||
):
|
||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||
one. If a PDU fails its signature check then we check if we have it in
|
||||
|
@ -69,11 +77,11 @@ class FederationBase(object):
|
|||
a new list.
|
||||
|
||||
Args:
|
||||
origin (str)
|
||||
pdu (list)
|
||||
room_version (str)
|
||||
outlier (bool): Whether the events are outliers or not
|
||||
include_none (str): Whether to include None in the returned list
|
||||
origin
|
||||
pdu
|
||||
room_version
|
||||
outlier: Whether the events are outliers or not
|
||||
include_none: Whether to include None in the returned list
|
||||
for events that have failed their checks
|
||||
|
||||
Returns:
|
||||
|
@ -82,7 +90,7 @@ class FederationBase(object):
|
|||
deferreds = self._check_sigs_and_hashes(room_version, pdus)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_check_result(pdu, deferred):
|
||||
def handle_check_result(pdu: EventBase, deferred: Deferred):
|
||||
try:
|
||||
res = yield make_deferred_yieldable(deferred)
|
||||
except SynapseError:
|
||||
|
@ -96,12 +104,16 @@ class FederationBase(object):
|
|||
|
||||
if not res and pdu.origin != origin:
|
||||
try:
|
||||
res = yield self.get_pdu(
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
room_version=room_version,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
# This should not exist in the base implementation, until
|
||||
# this is fixed, ignore it for typing. See issue #6997.
|
||||
res = yield defer.ensureDeferred(
|
||||
self.get_pdu( # type: ignore
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
room_version=room_version,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
)
|
||||
)
|
||||
except SynapseError:
|
||||
pass
|
||||
|
@ -125,21 +137,23 @@ class FederationBase(object):
|
|||
else:
|
||||
return [p for p in valid_pdus if p]
|
||||
|
||||
def _check_sigs_and_hash(self, room_version, pdu):
|
||||
def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
|
||||
return make_deferred_yieldable(
|
||||
self._check_sigs_and_hashes(room_version, [pdu])[0]
|
||||
)
|
||||
|
||||
def _check_sigs_and_hashes(self, room_version, pdus):
|
||||
def _check_sigs_and_hashes(
|
||||
self, room_version: str, pdus: List[EventBase]
|
||||
) -> List[Deferred]:
|
||||
"""Checks that each of the received events is correctly signed by the
|
||||
sending server.
|
||||
|
||||
Args:
|
||||
room_version (str): The room version of the PDUs
|
||||
pdus (list[FrozenEvent]): the events to be checked
|
||||
room_version: The room version of the PDUs
|
||||
pdus: the events to be checked
|
||||
|
||||
Returns:
|
||||
list[Deferred]: for each input event, a deferred which:
|
||||
For each input event, a deferred which:
|
||||
* returns the original event if the checks pass
|
||||
* returns a redacted version of the event (if the signature
|
||||
matched but the hash did not)
|
||||
|
@ -150,7 +164,7 @@ class FederationBase(object):
|
|||
|
||||
ctx = LoggingContext.current_context()
|
||||
|
||||
def callback(_, pdu):
|
||||
def callback(_, pdu: EventBase):
|
||||
with PreserveLoggingContext(ctx):
|
||||
if not check_event_content_hash(pdu):
|
||||
# let's try to distinguish between failures because the event was
|
||||
|
@ -187,7 +201,7 @@ class FederationBase(object):
|
|||
|
||||
return pdu
|
||||
|
||||
def errback(failure, pdu):
|
||||
def errback(failure: Failure, pdu: EventBase):
|
||||
failure.trap(SynapseError)
|
||||
with PreserveLoggingContext(ctx):
|
||||
logger.warning(
|
||||
|
@ -213,16 +227,18 @@ class PduToCheckSig(
|
|||
pass
|
||||
|
||||
|
||||
def _check_sigs_on_pdus(keyring, room_version, pdus):
|
||||
def _check_sigs_on_pdus(
|
||||
keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
|
||||
) -> List[Deferred]:
|
||||
"""Check that the given events are correctly signed
|
||||
|
||||
Args:
|
||||
keyring (synapse.crypto.Keyring): keyring object to do the checks
|
||||
room_version (str): the room version of the PDUs
|
||||
pdus (Collection[EventBase]): the events to be checked
|
||||
keyring: keyring object to do the checks
|
||||
room_version: the room version of the PDUs
|
||||
pdus: the events to be checked
|
||||
|
||||
Returns:
|
||||
List[Deferred]: a Deferred for each event in pdus, which will either succeed if
|
||||
A Deferred for each event in pdus, which will either succeed if
|
||||
the signatures are valid, or fail (with a SynapseError) if not.
|
||||
"""
|
||||
|
||||
|
@ -327,7 +343,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
|
|||
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
|
||||
|
||||
|
||||
def _flatten_deferred_list(deferreds):
|
||||
def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
|
||||
"""Given a list of deferreds, either return the single deferred,
|
||||
combine into a DeferredList, or return an already resolved deferred.
|
||||
"""
|
||||
|
@ -339,7 +355,7 @@ def _flatten_deferred_list(deferreds):
|
|||
return defer.succeed(None)
|
||||
|
||||
|
||||
def _is_invite_via_3pid(event):
|
||||
def _is_invite_via_3pid(event: EventBase) -> bool:
|
||||
return (
|
||||
event.type == EventTypes.Member
|
||||
and event.membership == Membership.INVITE
|
||||
|
|
|
@ -187,7 +187,7 @@ class FederationClient(FederationBase):
|
|||
|
||||
async def backfill(
|
||||
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
|
||||
) -> List[EventBase]:
|
||||
) -> Optional[List[EventBase]]:
|
||||
"""Requests some more historic PDUs for the given room from the
|
||||
given destination server.
|
||||
|
||||
|
@ -199,9 +199,9 @@ class FederationClient(FederationBase):
|
|||
"""
|
||||
logger.debug("backfill extrem=%s", extremities)
|
||||
|
||||
# If there are no extremeties then we've (probably) reached the start.
|
||||
# If there are no extremities then we've (probably) reached the start.
|
||||
if not extremities:
|
||||
return
|
||||
return None
|
||||
|
||||
transaction_data = await self.transport_layer.backfill(
|
||||
dest, room_id, extremities, limit
|
||||
|
@ -284,7 +284,7 @@ class FederationClient(FederationBase):
|
|||
pdu_list = [
|
||||
event_from_pdu_json(p, room_version, outlier=outlier)
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
] # type: List[EventBase]
|
||||
|
||||
if pdu_list and pdu_list[0]:
|
||||
pdu = pdu_list[0]
|
||||
|
@ -615,7 +615,7 @@ class FederationClient(FederationBase):
|
|||
]
|
||||
if auth_chain_create_events != [create_event.event_id]:
|
||||
raise InvalidResponseError(
|
||||
"Unexpected create event(s) in auth chain"
|
||||
"Unexpected create event(s) in auth chain: %s"
|
||||
% (auth_chain_create_events,)
|
||||
)
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
import logging
|
||||
import time
|
||||
import unicodedata
|
||||
import urllib.parse
|
||||
from typing import Any
|
||||
|
||||
import attr
|
||||
import bcrypt
|
||||
|
@ -38,8 +40,11 @@ from synapse.api.errors import (
|
|||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
|
||||
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.types import UserID
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
||||
|
@ -108,6 +113,16 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
self._clock = self.hs.get_clock()
|
||||
|
||||
# Load the SSO redirect confirmation page HTML template
|
||||
self._sso_redirect_confirm_template = load_jinja2_templates(
|
||||
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
|
||||
)[0]
|
||||
|
||||
self._server_name = hs.config.server_name
|
||||
|
||||
# cast to tuple for use with str.startswith
|
||||
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def validate_user_via_ui_auth(self, requester, request_body, clientip):
|
||||
"""
|
||||
|
@ -927,6 +942,65 @@ class AuthHandler(BaseHandler):
|
|||
else:
|
||||
return defer.succeed(False)
|
||||
|
||||
def complete_sso_login(
|
||||
self,
|
||||
registered_user_id: str,
|
||||
request: SynapseRequest,
|
||||
client_redirect_url: str,
|
||||
):
|
||||
"""Having figured out a mxid for this user, complete the HTTP request
|
||||
|
||||
Args:
|
||||
registered_user_id: The registered user ID to complete SSO login for.
|
||||
request: The request to complete.
|
||||
client_redirect_url: The URL to which to redirect the user at the end of the
|
||||
process.
|
||||
"""
|
||||
# Create a login token
|
||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
||||
registered_user_id
|
||||
)
|
||||
|
||||
# Append the login token to the original redirect URL (i.e. with its query
|
||||
# parameters kept intact) to build the URL to which the template needs to
|
||||
# redirect the users once they have clicked on the confirmation link.
|
||||
redirect_url = self.add_query_param_to_url(
|
||||
client_redirect_url, "loginToken", login_token
|
||||
)
|
||||
|
||||
# if the client is whitelisted, we can redirect straight to it
|
||||
if client_redirect_url.startswith(self._whitelisted_sso_clients):
|
||||
request.redirect(redirect_url)
|
||||
finish_request(request)
|
||||
return
|
||||
|
||||
# Otherwise, serve the redirect confirmation page.
|
||||
|
||||
# Remove the query parameters from the redirect URL to get a shorter version of
|
||||
# it. This is only to display a human-readable URL in the template, but not the
|
||||
# URL we redirect users to.
|
||||
redirect_url_no_params = client_redirect_url.split("?")[0]
|
||||
|
||||
html = self._sso_redirect_confirm_template.render(
|
||||
display_url=redirect_url_no_params,
|
||||
redirect_url=redirect_url,
|
||||
server_name=self._server_name,
|
||||
).encode("utf-8")
|
||||
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html),))
|
||||
request.write(html)
|
||||
finish_request(request)
|
||||
|
||||
@staticmethod
|
||||
def add_query_param_to_url(url: str, param_name: str, param: Any):
|
||||
url_parts = list(urllib.parse.urlparse(url))
|
||||
query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
||||
query.update({param_name: param})
|
||||
url_parts[4] = urllib.parse.urlencode(query)
|
||||
return urllib.parse.urlunparse(url_parts)
|
||||
|
||||
|
||||
@attr.s
|
||||
class MacaroonGenerator(object):
|
||||
|
|
|
@ -13,11 +13,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import string
|
||||
from typing import List
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -30,6 +28,7 @@ from synapse.api.errors import (
|
|||
StoreError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
@ -57,7 +56,13 @@ class DirectoryHandler(BaseHandler):
|
|||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_association(self, room_alias, room_id, servers=None, creator=None):
|
||||
def _create_association(
|
||||
self,
|
||||
room_alias: RoomAlias,
|
||||
room_id: str,
|
||||
servers: Optional[Iterable[str]] = None,
|
||||
creator: Optional[str] = None,
|
||||
):
|
||||
# general association creation for both human users and app services
|
||||
|
||||
for wchar in string.whitespace:
|
||||
|
@ -83,17 +88,21 @@ class DirectoryHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def create_association(
|
||||
self, requester, room_alias, room_id, servers=None, check_membership=True,
|
||||
self,
|
||||
requester: Requester,
|
||||
room_alias: RoomAlias,
|
||||
room_id: str,
|
||||
servers: Optional[List[str]] = None,
|
||||
check_membership: bool = True,
|
||||
):
|
||||
"""Attempt to create a new alias
|
||||
|
||||
Args:
|
||||
requester (Requester)
|
||||
room_alias (RoomAlias)
|
||||
room_id (str)
|
||||
servers (list[str]|None): List of servers that others servers
|
||||
should try and join via
|
||||
check_membership (bool): Whether to check if the user is in the room
|
||||
requester
|
||||
room_alias
|
||||
room_id
|
||||
servers: Iterable of servers that others servers should try and join via
|
||||
check_membership: Whether to check if the user is in the room
|
||||
before the alias can be set (if the server's config requires it).
|
||||
|
||||
Returns:
|
||||
|
@ -147,15 +156,15 @@ class DirectoryHandler(BaseHandler):
|
|||
yield self._create_association(room_alias, room_id, servers, creator=user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_association(self, requester, room_alias):
|
||||
def delete_association(self, requester: Requester, room_alias: RoomAlias):
|
||||
"""Remove an alias from the directory
|
||||
|
||||
(this is only meant for human users; AS users should call
|
||||
delete_appservice_association)
|
||||
|
||||
Args:
|
||||
requester (Requester):
|
||||
room_alias (RoomAlias):
|
||||
requester
|
||||
room_alias
|
||||
|
||||
Returns:
|
||||
Deferred[unicode]: room id that the alias used to point to
|
||||
|
@ -191,16 +200,16 @@ class DirectoryHandler(BaseHandler):
|
|||
room_id = yield self._delete_association(room_alias)
|
||||
|
||||
try:
|
||||
yield self._update_canonical_alias(
|
||||
requester, requester.user.to_string(), room_id, room_alias
|
||||
)
|
||||
yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
|
||||
except AuthError as e:
|
||||
logger.info("Failed to update alias events: %s", e)
|
||||
|
||||
return room_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_appservice_association(self, service, room_alias):
|
||||
def delete_appservice_association(
|
||||
self, service: ApplicationService, room_alias: RoomAlias
|
||||
):
|
||||
if not service.is_interested_in_alias(room_alias.to_string()):
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
@ -210,7 +219,7 @@ class DirectoryHandler(BaseHandler):
|
|||
yield self._delete_association(room_alias)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _delete_association(self, room_alias):
|
||||
def _delete_association(self, room_alias: RoomAlias):
|
||||
if not self.hs.is_mine(room_alias):
|
||||
raise SynapseError(400, "Room alias must be local")
|
||||
|
||||
|
@ -219,7 +228,7 @@ class DirectoryHandler(BaseHandler):
|
|||
return room_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_association(self, room_alias):
|
||||
def get_association(self, room_alias: RoomAlias):
|
||||
room_id = None
|
||||
if self.hs.is_mine(room_alias):
|
||||
result = yield self.get_association_from_room_alias(room_alias)
|
||||
|
@ -284,7 +293,9 @@ class DirectoryHandler(BaseHandler):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_canonical_alias(self, requester, user_id, room_id, room_alias):
|
||||
def _update_canonical_alias(
|
||||
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
|
||||
):
|
||||
"""
|
||||
Send an updated canonical alias event if the removed alias was set as
|
||||
the canonical alias or listed in the alt_aliases field.
|
||||
|
@ -307,15 +318,17 @@ class DirectoryHandler(BaseHandler):
|
|||
send_update = True
|
||||
content.pop("alias", "")
|
||||
|
||||
# Filter alt_aliases for the removed alias.
|
||||
alt_aliases = content.pop("alt_aliases", None)
|
||||
# If the aliases are not a list (or not found) do not attempt to modify
|
||||
# the list.
|
||||
if isinstance(alt_aliases, collections.Sequence):
|
||||
# Filter the alt_aliases property for the removed alias. Note that the
|
||||
# value is not modified if alt_aliases is of an unexpected form.
|
||||
alt_aliases = content.get("alt_aliases")
|
||||
if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases:
|
||||
send_update = True
|
||||
alt_aliases = [alias for alias in alt_aliases if alias != alias_str]
|
||||
|
||||
if alt_aliases:
|
||||
content["alt_aliases"] = alt_aliases
|
||||
else:
|
||||
del content["alt_aliases"]
|
||||
|
||||
if send_update:
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
|
@ -331,7 +344,7 @@ class DirectoryHandler(BaseHandler):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_association_from_room_alias(self, room_alias):
|
||||
def get_association_from_room_alias(self, room_alias: RoomAlias):
|
||||
result = yield self.store.get_association_from_room_alias(room_alias)
|
||||
if not result:
|
||||
# Query AS to see if it exists
|
||||
|
@ -339,7 +352,7 @@ class DirectoryHandler(BaseHandler):
|
|||
result = yield as_handler.query_room_alias_exists(room_alias)
|
||||
return result
|
||||
|
||||
def can_modify_alias(self, alias, user_id=None):
|
||||
def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None):
|
||||
# Any application service "interested" in an alias they are regexing on
|
||||
# can modify the alias.
|
||||
# Users can only modify the alias if ALL the interested services have
|
||||
|
@ -360,22 +373,42 @@ class DirectoryHandler(BaseHandler):
|
|||
return defer.succeed(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _user_can_delete_alias(self, alias, user_id):
|
||||
def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
|
||||
"""Determine whether a user can delete an alias.
|
||||
|
||||
One of the following must be true:
|
||||
|
||||
1. The user created the alias.
|
||||
2. The user is a server administrator.
|
||||
3. The user has a power-level sufficient to send a canonical alias event
|
||||
for the current room.
|
||||
|
||||
"""
|
||||
creator = yield self.store.get_room_alias_creator(alias.to_string())
|
||||
|
||||
if creator is not None and creator == user_id:
|
||||
return True
|
||||
|
||||
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
|
||||
return is_admin
|
||||
# Resolve the alias to the corresponding room.
|
||||
room_mapping = yield self.get_association(alias)
|
||||
room_id = room_mapping["room_id"]
|
||||
if not room_id:
|
||||
return False
|
||||
|
||||
res = yield self.auth.check_can_change_room_list(
|
||||
room_id, UserID.from_string(user_id)
|
||||
)
|
||||
return res
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def edit_published_room_list(self, requester, room_id, visibility):
|
||||
def edit_published_room_list(
|
||||
self, requester: Requester, room_id: str, visibility: str
|
||||
):
|
||||
"""Edit the entry of the room in the published room list.
|
||||
|
||||
requester
|
||||
room_id (str)
|
||||
visibility (str): "public" or "private"
|
||||
room_id
|
||||
visibility: "public" or "private"
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
|
@ -400,7 +433,15 @@ class DirectoryHandler(BaseHandler):
|
|||
if room is None:
|
||||
raise SynapseError(400, "Unknown room")
|
||||
|
||||
yield self.auth.check_can_change_room_list(room_id, requester.user)
|
||||
can_change_room_list = yield self.auth.check_can_change_room_list(
|
||||
room_id, requester.user
|
||||
)
|
||||
if not can_change_room_list:
|
||||
raise AuthError(
|
||||
403,
|
||||
"This server requires you to be a moderator in the room to"
|
||||
" edit its room list entry",
|
||||
)
|
||||
|
||||
making_public = visibility == "public"
|
||||
if making_public:
|
||||
|
@ -421,16 +462,16 @@ class DirectoryHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def edit_published_appservice_room_list(
|
||||
self, appservice_id, network_id, room_id, visibility
|
||||
self, appservice_id: str, network_id: str, room_id: str, visibility: str
|
||||
):
|
||||
"""Add or remove a room from the appservice/network specific public
|
||||
room list.
|
||||
|
||||
Args:
|
||||
appservice_id (str): ID of the appservice that owns the list
|
||||
network_id (str): The ID of the network the list is associated with
|
||||
room_id (str)
|
||||
visibility (str): either "public" or "private"
|
||||
appservice_id: ID of the appservice that owns the list
|
||||
network_id: The ID of the network the list is associated with
|
||||
room_id
|
||||
visibility: either "public" or "private"
|
||||
"""
|
||||
if visibility not in ["public", "private"]:
|
||||
raise SynapseError(400, "Invalid visibility setting")
|
||||
|
|
|
@ -207,6 +207,13 @@ class E2eRoomKeysHandler(object):
|
|||
changed = False # if anything has changed, we need to update the etag
|
||||
for room_id, room in iteritems(room_keys["rooms"]):
|
||||
for session_id, room_key in iteritems(room["sessions"]):
|
||||
if not isinstance(room_key["is_verified"], bool):
|
||||
msg = (
|
||||
"is_verified must be a boolean in keys for session %s in"
|
||||
"room %s" % (session_id, room_id)
|
||||
)
|
||||
raise SynapseError(400, msg, Codes.INVALID_PARAM)
|
||||
|
||||
log_kv(
|
||||
{
|
||||
"message": "Trying to upload room key",
|
||||
|
|
|
@ -888,19 +888,60 @@ class EventCreationHandler(object):
|
|||
yield self.base_handler.maybe_kick_guest_users(event, context)
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
# Check the alias is acually valid (at this time at least)
|
||||
# Validate a newly added alias or newly added alt_aliases.
|
||||
|
||||
original_alias = None
|
||||
original_alt_aliases = set()
|
||||
|
||||
original_event_id = event.unsigned.get("replaces_state")
|
||||
if original_event_id:
|
||||
original_event = yield self.store.get_event(original_event_id)
|
||||
|
||||
if original_event:
|
||||
original_alias = original_event.content.get("alias", None)
|
||||
original_alt_aliases = original_event.content.get("alt_aliases", [])
|
||||
|
||||
# Check the alias is currently valid (if it has changed).
|
||||
room_alias_str = event.content.get("alias", None)
|
||||
if room_alias_str:
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
if room_alias_str and room_alias_str != original_alias:
|
||||
room_alias = RoomAlias.from_string(room_alias_str)
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
mapping = yield directory_handler.get_association(room_alias)
|
||||
|
||||
if mapping["room_id"] != event.room_id:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Room alias %s does not point to the room" % (room_alias_str,),
|
||||
Codes.BAD_ALIAS,
|
||||
)
|
||||
|
||||
# Check that alt_aliases is the proper form.
|
||||
alt_aliases = event.content.get("alt_aliases", [])
|
||||
if not isinstance(alt_aliases, (list, tuple)):
|
||||
raise SynapseError(
|
||||
400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
# If the old version of alt_aliases is of an unknown form,
|
||||
# completely replace it.
|
||||
if not isinstance(original_alt_aliases, (list, tuple)):
|
||||
original_alt_aliases = []
|
||||
|
||||
# Check that each alias is currently valid.
|
||||
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
|
||||
if new_alt_aliases:
|
||||
for alias_str in new_alt_aliases:
|
||||
room_alias = RoomAlias.from_string(alias_str)
|
||||
mapping = yield directory_handler.get_association(room_alias)
|
||||
|
||||
if mapping["room_id"] != event.room_id:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Room alias %s does not point to the room"
|
||||
% (room_alias_str,),
|
||||
Codes.BAD_ALIAS,
|
||||
)
|
||||
|
||||
federation_handler = self.hs.get_handlers().federation_handler
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
|
|
|
@ -25,7 +25,6 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.config import ConfigError
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.rest.client.v1.login import SSOAuthHandler
|
||||
from synapse.types import (
|
||||
UserID,
|
||||
map_username_to_mxid_localpart,
|
||||
|
@ -48,7 +47,7 @@ class Saml2SessionData:
|
|||
class SamlHandler:
|
||||
def __init__(self, hs):
|
||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
|
@ -116,7 +115,7 @@ class SamlHandler:
|
|||
self.expire_sessions()
|
||||
|
||||
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
|
||||
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||
self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||
|
||||
async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
|
||||
try:
|
||||
|
|
|
@ -27,10 +27,15 @@ import inspect
|
|||
import logging
|
||||
import threading
|
||||
import types
|
||||
from typing import Any, List
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.logging.scopecontextmanager import _LogContextScope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
|
@ -91,7 +96,7 @@ class ContextResourceUsage(object):
|
|||
"evt_db_fetch_count",
|
||||
]
|
||||
|
||||
def __init__(self, copy_from=None):
|
||||
def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None:
|
||||
"""Create a new ContextResourceUsage
|
||||
|
||||
Args:
|
||||
|
@ -101,27 +106,28 @@ class ContextResourceUsage(object):
|
|||
if copy_from is None:
|
||||
self.reset()
|
||||
else:
|
||||
self.ru_utime = copy_from.ru_utime
|
||||
self.ru_stime = copy_from.ru_stime
|
||||
self.db_txn_count = copy_from.db_txn_count
|
||||
# FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
|
||||
self.ru_utime = copy_from.ru_utime # type: float
|
||||
self.ru_stime = copy_from.ru_stime # type: float
|
||||
self.db_txn_count = copy_from.db_txn_count # type: int
|
||||
|
||||
self.db_txn_duration_sec = copy_from.db_txn_duration_sec
|
||||
self.db_sched_duration_sec = copy_from.db_sched_duration_sec
|
||||
self.evt_db_fetch_count = copy_from.evt_db_fetch_count
|
||||
self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float
|
||||
self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float
|
||||
self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> "ContextResourceUsage":
|
||||
return ContextResourceUsage(copy_from=self)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.ru_stime = 0.0
|
||||
self.ru_utime = 0.0
|
||||
self.db_txn_count = 0
|
||||
|
||||
self.db_txn_duration_sec = 0
|
||||
self.db_sched_duration_sec = 0
|
||||
self.db_txn_duration_sec = 0.0
|
||||
self.db_sched_duration_sec = 0.0
|
||||
self.evt_db_fetch_count = 0
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
"<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
|
||||
"db_txn_count='%r', db_txn_duration_sec='%r', "
|
||||
|
@ -135,7 +141,7 @@ class ContextResourceUsage(object):
|
|||
self.evt_db_fetch_count,
|
||||
)
|
||||
|
||||
def __iadd__(self, other):
|
||||
def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
|
||||
"""Add another ContextResourceUsage's stats to this one's.
|
||||
|
||||
Args:
|
||||
|
@ -149,7 +155,7 @@ class ContextResourceUsage(object):
|
|||
self.evt_db_fetch_count += other.evt_db_fetch_count
|
||||
return self
|
||||
|
||||
def __isub__(self, other):
|
||||
def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
|
||||
self.ru_utime -= other.ru_utime
|
||||
self.ru_stime -= other.ru_stime
|
||||
self.db_txn_count -= other.db_txn_count
|
||||
|
@ -158,17 +164,20 @@ class ContextResourceUsage(object):
|
|||
self.evt_db_fetch_count -= other.evt_db_fetch_count
|
||||
return self
|
||||
|
||||
def __add__(self, other):
|
||||
def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
|
||||
res = ContextResourceUsage(copy_from=self)
|
||||
res += other
|
||||
return res
|
||||
|
||||
def __sub__(self, other):
|
||||
def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
|
||||
res = ContextResourceUsage(copy_from=self)
|
||||
res -= other
|
||||
return res
|
||||
|
||||
|
||||
LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
|
||||
|
||||
|
||||
class LoggingContext(object):
|
||||
"""Additional context for log formatting. Contexts are scoped within a
|
||||
"with" block.
|
||||
|
@ -201,7 +210,14 @@ class LoggingContext(object):
|
|||
class Sentinel(object):
|
||||
"""Sentinel to represent the root context"""
|
||||
|
||||
__slots__ = [] # type: List[Any]
|
||||
__slots__ = ["previous_context", "alive", "request", "scope"]
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Minimal set for compatibility with LoggingContext
|
||||
self.previous_context = None
|
||||
self.alive = None
|
||||
self.request = None
|
||||
self.scope = None
|
||||
|
||||
def __str__(self):
|
||||
return "sentinel"
|
||||
|
@ -235,7 +251,7 @@ class LoggingContext(object):
|
|||
|
||||
sentinel = Sentinel()
|
||||
|
||||
def __init__(self, name=None, parent_context=None, request=None):
|
||||
def __init__(self, name=None, parent_context=None, request=None) -> None:
|
||||
self.previous_context = LoggingContext.current_context()
|
||||
self.name = name
|
||||
|
||||
|
@ -250,7 +266,7 @@ class LoggingContext(object):
|
|||
self.request = None
|
||||
self.tag = ""
|
||||
self.alive = True
|
||||
self.scope = None
|
||||
self.scope = None # type: Optional[_LogContextScope]
|
||||
|
||||
self.parent_context = parent_context
|
||||
|
||||
|
@ -261,13 +277,13 @@ class LoggingContext(object):
|
|||
# the request param overrides the request from the parent context
|
||||
self.request = request
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self.request:
|
||||
return str(self.request)
|
||||
return "%s@%x" % (self.name, id(self))
|
||||
|
||||
@classmethod
|
||||
def current_context(cls):
|
||||
def current_context(cls) -> LoggingContextOrSentinel:
|
||||
"""Get the current logging context from thread local storage
|
||||
|
||||
Returns:
|
||||
|
@ -276,7 +292,9 @@ class LoggingContext(object):
|
|||
return getattr(cls.thread_local, "current_context", cls.sentinel)
|
||||
|
||||
@classmethod
|
||||
def set_current_context(cls, context):
|
||||
def set_current_context(
|
||||
cls, context: LoggingContextOrSentinel
|
||||
) -> LoggingContextOrSentinel:
|
||||
"""Set the current logging context in thread local storage
|
||||
Args:
|
||||
context(LoggingContext): The context to activate.
|
||||
|
@ -291,7 +309,7 @@ class LoggingContext(object):
|
|||
context.start()
|
||||
return current
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "LoggingContext":
|
||||
"""Enters this logging context into thread local storage"""
|
||||
old_context = self.set_current_context(self)
|
||||
if self.previous_context != old_context:
|
||||
|
@ -304,7 +322,7 @@ class LoggingContext(object):
|
|||
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
def __exit__(self, type, value, traceback) -> None:
|
||||
"""Restore the logging context in thread local storage to the state it
|
||||
was before this context was entered.
|
||||
Returns:
|
||||
|
@ -318,7 +336,6 @@ class LoggingContext(object):
|
|||
logger.warning(
|
||||
"Expected logging context %s but found %s", self, current
|
||||
)
|
||||
self.previous_context = None
|
||||
self.alive = False
|
||||
|
||||
# if we have a parent, pass our CPU usage stats on
|
||||
|
@ -330,7 +347,7 @@ class LoggingContext(object):
|
|||
# reset them in case we get entered again
|
||||
self._resource_usage.reset()
|
||||
|
||||
def copy_to(self, record):
|
||||
def copy_to(self, record) -> None:
|
||||
"""Copy logging fields from this context to a log record or
|
||||
another LoggingContext
|
||||
"""
|
||||
|
@ -341,14 +358,14 @@ class LoggingContext(object):
|
|||
# we also track the current scope:
|
||||
record.scope = self.scope
|
||||
|
||||
def copy_to_twisted_log_entry(self, record):
|
||||
def copy_to_twisted_log_entry(self, record) -> None:
|
||||
"""
|
||||
Copy logging fields from this context to a Twisted log record.
|
||||
"""
|
||||
record["request"] = self.request
|
||||
record["scope"] = self.scope
|
||||
|
||||
def start(self):
|
||||
def start(self) -> None:
|
||||
if get_thread_id() != self.main_thread:
|
||||
logger.warning("Started logcontext %s on different thread", self)
|
||||
return
|
||||
|
@ -358,7 +375,7 @@ class LoggingContext(object):
|
|||
if not self.usage_start:
|
||||
self.usage_start = get_thread_resource_usage()
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
if get_thread_id() != self.main_thread:
|
||||
logger.warning("Stopped logcontext %s on different thread", self)
|
||||
return
|
||||
|
@ -378,7 +395,7 @@ class LoggingContext(object):
|
|||
|
||||
self.usage_start = None
|
||||
|
||||
def get_resource_usage(self):
|
||||
def get_resource_usage(self) -> ContextResourceUsage:
|
||||
"""Get resources used by this logcontext so far.
|
||||
|
||||
Returns:
|
||||
|
@ -398,11 +415,13 @@ class LoggingContext(object):
|
|||
|
||||
return res
|
||||
|
||||
def _get_cputime(self):
|
||||
def _get_cputime(self) -> Tuple[float, float]:
|
||||
"""Get the cpu usage time so far
|
||||
|
||||
Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
|
||||
"""
|
||||
assert self.usage_start is not None
|
||||
|
||||
current = get_thread_resource_usage()
|
||||
|
||||
# Indicate to mypy that we know that self.usage_start is None.
|
||||
|
@ -430,13 +449,13 @@ class LoggingContext(object):
|
|||
|
||||
return utime_delta, stime_delta
|
||||
|
||||
def add_database_transaction(self, duration_sec):
|
||||
def add_database_transaction(self, duration_sec: float) -> None:
|
||||
if duration_sec < 0:
|
||||
raise ValueError("DB txn time can only be non-negative")
|
||||
self._resource_usage.db_txn_count += 1
|
||||
self._resource_usage.db_txn_duration_sec += duration_sec
|
||||
|
||||
def add_database_scheduled(self, sched_sec):
|
||||
def add_database_scheduled(self, sched_sec: float) -> None:
|
||||
"""Record a use of the database pool
|
||||
|
||||
Args:
|
||||
|
@ -447,7 +466,7 @@ class LoggingContext(object):
|
|||
raise ValueError("DB scheduling time can only be non-negative")
|
||||
self._resource_usage.db_sched_duration_sec += sched_sec
|
||||
|
||||
def record_event_fetch(self, event_count):
|
||||
def record_event_fetch(self, event_count: int) -> None:
|
||||
"""Record a number of events being fetched from the db
|
||||
|
||||
Args:
|
||||
|
@ -464,10 +483,10 @@ class LoggingContextFilter(logging.Filter):
|
|||
missing fields
|
||||
"""
|
||||
|
||||
def __init__(self, **defaults):
|
||||
def __init__(self, **defaults) -> None:
|
||||
self.defaults = defaults
|
||||
|
||||
def filter(self, record):
|
||||
def filter(self, record) -> Literal[True]:
|
||||
"""Add each fields from the logging contexts to the record.
|
||||
Returns:
|
||||
True to include the record in the log output.
|
||||
|
@ -492,12 +511,13 @@ class PreserveLoggingContext(object):
|
|||
|
||||
__slots__ = ["current_context", "new_context", "has_parent"]
|
||||
|
||||
def __init__(self, new_context=None):
|
||||
def __init__(self, new_context: Optional[LoggingContext] = None) -> None:
|
||||
if new_context is None:
|
||||
new_context = LoggingContext.sentinel
|
||||
self.new_context = new_context
|
||||
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
|
||||
else:
|
||||
self.new_context = new_context
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> None:
|
||||
"""Captures the current logging context"""
|
||||
self.current_context = LoggingContext.set_current_context(self.new_context)
|
||||
|
||||
|
@ -506,7 +526,7 @@ class PreserveLoggingContext(object):
|
|||
if not self.current_context.alive:
|
||||
logger.debug("Entering dead context: %s", self.current_context)
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
def __exit__(self, type, value, traceback) -> None:
|
||||
"""Restores the current logging context"""
|
||||
context = LoggingContext.set_current_context(self.current_context)
|
||||
|
||||
|
@ -525,7 +545,9 @@ class PreserveLoggingContext(object):
|
|||
logger.debug("Restoring dead context: %s", self.current_context)
|
||||
|
||||
|
||||
def nested_logging_context(suffix, parent_context=None):
|
||||
def nested_logging_context(
|
||||
suffix: str, parent_context: Optional[LoggingContext] = None
|
||||
) -> LoggingContext:
|
||||
"""Creates a new logging context as a child of another.
|
||||
|
||||
The nested logging context will have a 'request' made up of the parent context's
|
||||
|
@ -546,10 +568,12 @@ def nested_logging_context(suffix, parent_context=None):
|
|||
Returns:
|
||||
LoggingContext: new logging context.
|
||||
"""
|
||||
if parent_context is None:
|
||||
parent_context = LoggingContext.current_context()
|
||||
if parent_context is not None:
|
||||
context = parent_context # type: LoggingContextOrSentinel
|
||||
else:
|
||||
context = LoggingContext.current_context()
|
||||
return LoggingContext(
|
||||
parent_context=parent_context, request=parent_context.request + "-" + suffix
|
||||
parent_context=context, request=str(context.request) + "-" + suffix
|
||||
)
|
||||
|
||||
|
||||
|
@ -654,7 +678,10 @@ def make_deferred_yieldable(deferred):
|
|||
return deferred
|
||||
|
||||
|
||||
def _set_context_cb(result, context):
|
||||
ResultT = TypeVar("ResultT")
|
||||
|
||||
|
||||
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
|
||||
"""A callback function which just sets the logging context"""
|
||||
LoggingContext.set_current_context(context)
|
||||
return result
|
||||
|
|
|
@ -17,6 +17,7 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.types import UserID
|
||||
|
||||
|
@ -211,3 +212,21 @@ class ModuleApi(object):
|
|||
Deferred[object]: result of func
|
||||
"""
|
||||
return self._store.db.runInteraction(desc, func, *args, **kwargs)
|
||||
|
||||
def complete_sso_login(
|
||||
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
|
||||
):
|
||||
"""Complete a SSO login by redirecting the user to a page to confirm whether they
|
||||
want their access token sent to `client_redirect_url`, or redirect them to that
|
||||
URL with a token directly if the URL matches with one of the whitelisted clients.
|
||||
|
||||
Args:
|
||||
registered_user_id: The MXID that has been registered as a previous step of
|
||||
of this SSO login.
|
||||
request: The request to respond to.
|
||||
client_redirect_url: The URL to which to offer to redirect the user (or to
|
||||
redirect them directly if whitelisted).
|
||||
"""
|
||||
self._auth_handler.complete_sso_login(
|
||||
registered_user_id, request, client_redirect_url,
|
||||
)
|
||||
|
|
|
@ -555,10 +555,12 @@ class Mailer(object):
|
|||
else:
|
||||
# If the reason room doesn't have a name, say who the messages
|
||||
# are from explicitly to avoid, "messages in the Bob room"
|
||||
room_id = reason["room_id"]
|
||||
|
||||
sender_ids = list(
|
||||
{
|
||||
notif_events[n["event_id"]].sender
|
||||
for n in notifs_by_room[reason["room_id"]]
|
||||
for n in notifs_by_room[room_id]
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ import logging
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import event_type_from_format_version
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
|
@ -38,6 +38,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
{
|
||||
"events": [{
|
||||
"event": { .. serialized event .. },
|
||||
"room_version": .., // "1", "2", "3", etc: the version of the room
|
||||
// containing the event
|
||||
"event_format_version": .., // 1,2,3 etc: the event format version
|
||||
"internal_metadata": { .. serialized internal_metadata .. },
|
||||
"rejected_reason": .., // The event.rejected_reason field
|
||||
"context": { .. serialized event context .. },
|
||||
|
@ -73,6 +76,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
event_payloads.append(
|
||||
{
|
||||
"event": event.get_pdu_json(),
|
||||
"room_version": event.room_version.identifier,
|
||||
"event_format_version": event.format_version,
|
||||
"internal_metadata": event.internal_metadata.get_dict(),
|
||||
"rejected_reason": event.rejected_reason,
|
||||
|
@ -95,12 +99,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
event_and_contexts = []
|
||||
for event_payload in event_payloads:
|
||||
event_dict = event_payload["event"]
|
||||
format_ver = event_payload["event_format_version"]
|
||||
room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]]
|
||||
internal_metadata = event_payload["internal_metadata"]
|
||||
rejected_reason = event_payload["rejected_reason"]
|
||||
|
||||
EventType = event_type_from_format_version(format_ver)
|
||||
event = EventType(event_dict, internal_metadata, rejected_reason)
|
||||
event = make_event_from_dict(
|
||||
event_dict, room_ver, internal_metadata, rejected_reason
|
||||
)
|
||||
|
||||
context = EventContext.deserialize(
|
||||
self.storage, event_payload["context"]
|
||||
|
|
|
@ -17,7 +17,8 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.events import event_type_from_format_version
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
|
@ -37,6 +38,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
|
||||
{
|
||||
"event": { .. serialized event .. },
|
||||
"room_version": .., // "1", "2", "3", etc: the version of the room
|
||||
// containing the event
|
||||
"event_format_version": .., // 1,2,3 etc: the event format version
|
||||
"internal_metadata": { .. serialized internal_metadata .. },
|
||||
"rejected_reason": .., // The event.rejected_reason field
|
||||
"context": { .. serialized event context .. },
|
||||
|
@ -77,6 +81,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
|
||||
payload = {
|
||||
"event": event.get_pdu_json(),
|
||||
"room_version": event.room_version.identifier,
|
||||
"event_format_version": event.format_version,
|
||||
"internal_metadata": event.internal_metadata.get_dict(),
|
||||
"rejected_reason": event.rejected_reason,
|
||||
|
@ -93,12 +98,13 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
content = parse_json_object_from_request(request)
|
||||
|
||||
event_dict = content["event"]
|
||||
format_ver = content["event_format_version"]
|
||||
room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
|
||||
internal_metadata = content["internal_metadata"]
|
||||
rejected_reason = content["rejected_reason"]
|
||||
|
||||
EventType = event_type_from_format_version(format_ver)
|
||||
event = EventType(event_dict, internal_metadata, rejected_reason)
|
||||
event = make_event_from_dict(
|
||||
event_dict, room_ver, internal_metadata, rejected_reason
|
||||
)
|
||||
|
||||
requester = Requester.deserialize(self.store, content["requester"])
|
||||
context = EventContext.deserialize(self.storage, content["context"])
|
||||
|
|
14
synapse/res/templates/sso_redirect_confirm.html
Normal file
14
synapse/res/templates/sso_redirect_confirm.html
Normal file
|
@ -0,0 +1,14 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>SSO redirect confirmation</title>
|
||||
</head>
|
||||
<body>
|
||||
<p>The application at <span style="font-weight:bold">{{ display_url | e }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p>
|
||||
<p>If you don't recognise this address, you should ignore this and close this tab.</p>
|
||||
<p>
|
||||
<a href="{{ redirect_url | e }}">I trust this address</a>
|
||||
</p>
|
||||
</body>
|
||||
</html>
|
|
@ -211,9 +211,7 @@ class UserRestServletV2(RestServlet):
|
|||
if target_user == auth_user and not set_admin_to:
|
||||
raise SynapseError(400, "You may not demote yourself.")
|
||||
|
||||
await self.admin_handler.set_user_server_admin(
|
||||
target_user, set_admin_to
|
||||
)
|
||||
await self.store.set_server_admin(target_user, set_admin_to)
|
||||
|
||||
if "password" in body:
|
||||
if (
|
||||
|
@ -651,6 +649,6 @@ class UserAdminServlet(RestServlet):
|
|||
if target_user == auth_user and not set_admin_to:
|
||||
raise SynapseError(400, "You may not demote yourself.")
|
||||
|
||||
await self.store.set_user_server_admin(target_user, set_admin_to)
|
||||
await self.store.set_server_admin(target_user, set_admin_to)
|
||||
|
||||
return 200, {}
|
||||
|
|
|
@ -28,7 +28,7 @@ from synapse.http.servlet import (
|
|||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
|
@ -548,6 +548,16 @@ class SSOAuthHandler(object):
|
|||
self._registration_handler = hs.get_registration_handler()
|
||||
self._macaroon_gen = hs.get_macaroon_generator()
|
||||
|
||||
# Load the redirect page HTML template
|
||||
self._template = load_jinja2_templates(
|
||||
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
|
||||
)[0]
|
||||
|
||||
self._server_name = hs.config.server_name
|
||||
|
||||
# cast to tuple for use with str.startswith
|
||||
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
||||
|
||||
async def on_successful_auth(
|
||||
self, username, request, client_redirect_url, user_display_name=None
|
||||
):
|
||||
|
@ -580,36 +590,9 @@ class SSOAuthHandler(object):
|
|||
localpart=localpart, default_display_name=user_display_name
|
||||
)
|
||||
|
||||
self.complete_sso_login(registered_user_id, request, client_redirect_url)
|
||||
|
||||
def complete_sso_login(
|
||||
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
|
||||
):
|
||||
"""Having figured out a mxid for this user, complete the HTTP request
|
||||
|
||||
Args:
|
||||
registered_user_id:
|
||||
request:
|
||||
client_redirect_url:
|
||||
"""
|
||||
|
||||
login_token = self._macaroon_gen.generate_short_term_login_token(
|
||||
registered_user_id
|
||||
self._auth_handler.complete_sso_login(
|
||||
registered_user_id, request, client_redirect_url
|
||||
)
|
||||
redirect_url = self._add_login_token_to_redirect_url(
|
||||
client_redirect_url, login_token
|
||||
)
|
||||
# Load page
|
||||
request.redirect(redirect_url)
|
||||
finish_request(request)
|
||||
|
||||
@staticmethod
|
||||
def _add_login_token_to_redirect_url(url, token):
|
||||
url_parts = list(urllib.parse.urlparse(url))
|
||||
query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
||||
query.update({"loginToken": token})
|
||||
url_parts[4] = urllib.parse.urlencode(query)
|
||||
return urllib.parse.urlunparse(url_parts)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
|
|
@ -18,8 +18,6 @@ from typing import Dict, Set
|
|||
from canonicaljson import encode_canonical_json, json
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.crypto.keyring import ServerKeyFetcher
|
||||
from synapse.http.server import (
|
||||
|
@ -125,8 +123,7 @@ class RemoteKey(DirectServeResource):
|
|||
|
||||
await self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
||||
async def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
||||
logger.info("Handling query for keys %r", query)
|
||||
|
||||
store_queries = []
|
||||
|
@ -143,7 +140,7 @@ class RemoteKey(DirectServeResource):
|
|||
for key_id in key_ids:
|
||||
store_queries.append((server_name, key_id, None))
|
||||
|
||||
cached = yield self.store.get_server_keys_json(store_queries)
|
||||
cached = await self.store.get_server_keys_json(store_queries)
|
||||
|
||||
json_results = set()
|
||||
|
||||
|
@ -215,8 +212,8 @@ class RemoteKey(DirectServeResource):
|
|||
json_results.add(bytes(result["key_json"]))
|
||||
|
||||
if cache_misses and query_remote_on_cache_miss:
|
||||
yield self.fetcher.get_keys(cache_misses)
|
||||
yield self.query_keys(request, query, query_remote_on_cache_miss=False)
|
||||
await self.fetcher.get_keys(cache_misses)
|
||||
await self.query_keys(request, query, query_remote_on_cache_miss=False)
|
||||
else:
|
||||
signed_keys = []
|
||||
for key_json in json_results:
|
||||
|
|
|
@ -608,6 +608,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
|
||||
return range_end
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.received_ts"
|
||||
" FROM event_push_actions AS ep"
|
||||
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
|
||||
" WHERE ep.stream_ordering > ?"
|
||||
" ORDER BY ep.stream_ordering ASC"
|
||||
" LIMIT 1"
|
||||
)
|
||||
txn.execute(sql, (stream_ordering,))
|
||||
return txn.fetchone()
|
||||
|
||||
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
|
||||
return result[0] if result else None
|
||||
|
||||
|
||||
class EventPushActionsStore(EventPushActionsWorkerStore):
|
||||
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
|
||||
|
@ -735,23 +752,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
|
||||
return push_actions
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.received_ts"
|
||||
" FROM event_push_actions AS ep"
|
||||
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
|
||||
" WHERE ep.stream_ordering > ?"
|
||||
" ORDER BY ep.stream_ordering ASC"
|
||||
" LIMIT 1"
|
||||
)
|
||||
txn.execute(sql, (stream_ordering,))
|
||||
return txn.fetchone()
|
||||
|
||||
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
|
||||
return result[0] if result else None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_latest_push_action_stream_ordering(self):
|
||||
def f(txn):
|
||||
|
|
|
@ -1168,7 +1168,11 @@ class EventsStore(
|
|||
and original_event.internal_metadata.is_redacted()
|
||||
):
|
||||
# Redaction was allowed
|
||||
pruned_json = encode_json(prune_event_dict(original_event.get_dict()))
|
||||
pruned_json = encode_json(
|
||||
prune_event_dict(
|
||||
original_event.room_version, original_event.get_dict()
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Redaction wasn't allowed
|
||||
pruned_json = None
|
||||
|
@ -1929,7 +1933,9 @@ class EventsStore(
|
|||
return
|
||||
|
||||
# Prune the event's dict then convert it to JSON.
|
||||
pruned_json = encode_json(prune_event_dict(event.get_dict()))
|
||||
pruned_json = encode_json(
|
||||
prune_event_dict(event.room_version, event.get_dict())
|
||||
)
|
||||
|
||||
# Update the event_json table to replace the event's JSON with the pruned
|
||||
# JSON.
|
||||
|
|
|
@ -28,9 +28,12 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.api.room_versions import EventFormatVersions
|
||||
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
|
||||
from synapse.events.snapshot import EventContext # noqa: F401
|
||||
from synapse.api.room_versions import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventFormatVersions,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
|
@ -580,8 +583,49 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
# of a event format version, so it must be a V1 event.
|
||||
format_version = EventFormatVersions.V1
|
||||
|
||||
original_ev = event_type_from_format_version(format_version)(
|
||||
room_version_id = row["room_version_id"]
|
||||
|
||||
if not room_version_id:
|
||||
# this should only happen for out-of-band membership events
|
||||
if not internal_metadata.get("out_of_band_membership"):
|
||||
logger.warning(
|
||||
"Room %s for event %s is unknown", d["room_id"], event_id
|
||||
)
|
||||
continue
|
||||
|
||||
# take a wild stab at the room version based on the event format
|
||||
if format_version == EventFormatVersions.V1:
|
||||
room_version = RoomVersions.V1
|
||||
elif format_version == EventFormatVersions.V2:
|
||||
room_version = RoomVersions.V3
|
||||
else:
|
||||
room_version = RoomVersions.V5
|
||||
else:
|
||||
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
||||
if not room_version:
|
||||
logger.error(
|
||||
"Event %s in room %s has unknown room version %s",
|
||||
event_id,
|
||||
d["room_id"],
|
||||
room_version_id,
|
||||
)
|
||||
continue
|
||||
|
||||
if room_version.event_format != format_version:
|
||||
logger.error(
|
||||
"Event %s in room %s with version %s has wrong format: "
|
||||
"expected %s, was %s",
|
||||
event_id,
|
||||
d["room_id"],
|
||||
room_version_id,
|
||||
room_version.event_format,
|
||||
format_version,
|
||||
)
|
||||
continue
|
||||
|
||||
original_ev = make_event_from_dict(
|
||||
event_dict=d,
|
||||
room_version=room_version,
|
||||
internal_metadata_dict=internal_metadata,
|
||||
rejected_reason=rejected_reason,
|
||||
)
|
||||
|
@ -661,6 +705,12 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
of EventFormatVersions. 'None' means the event predates
|
||||
EventFormatVersions (so the event is format V1).
|
||||
|
||||
* room_version_id (str|None): The version of the room which contains the event.
|
||||
Hopefully one of RoomVersions.
|
||||
|
||||
Due to historical reasons, there may be a few events in the database which
|
||||
do not have an associated room; in this case None will be returned here.
|
||||
|
||||
* rejected_reason (str|None): if the event was rejected, the reason
|
||||
why.
|
||||
|
||||
|
@ -676,17 +726,18 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
"""
|
||||
event_dict = {}
|
||||
for evs in batch_iter(event_ids, 200):
|
||||
sql = (
|
||||
"SELECT "
|
||||
" e.event_id, "
|
||||
" e.internal_metadata,"
|
||||
" e.json,"
|
||||
" e.format_version, "
|
||||
" rej.reason "
|
||||
" FROM event_json as e"
|
||||
" LEFT JOIN rejections as rej USING (event_id)"
|
||||
" WHERE "
|
||||
)
|
||||
sql = """\
|
||||
SELECT
|
||||
e.event_id,
|
||||
e.internal_metadata,
|
||||
e.json,
|
||||
e.format_version,
|
||||
r.room_version,
|
||||
rej.reason
|
||||
FROM event_json as e
|
||||
LEFT JOIN rooms r USING (room_id)
|
||||
LEFT JOIN rejections as rej USING (event_id)
|
||||
WHERE """
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "e.event_id", evs
|
||||
|
@ -701,7 +752,8 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
"internal_metadata": row[1],
|
||||
"json": row[2],
|
||||
"format_version": row[3],
|
||||
"rejected_reason": row[4],
|
||||
"room_version_id": row[4],
|
||||
"rejected_reason": row[5],
|
||||
"redactions": [],
|
||||
}
|
||||
|
||||
|
|
|
@ -43,13 +43,40 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
|||
|
||||
def _count_users(txn):
|
||||
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
|
||||
|
||||
txn.execute(sql)
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
return self.db.runInteraction("count_users", _count_users)
|
||||
|
||||
@cached(num_args=0)
|
||||
def get_monthly_active_count_by_service(self):
|
||||
"""Generates current count of monthly active users broken down by service.
|
||||
A service is typically an appservice but also includes native matrix users.
|
||||
Since the `monthly_active_users` table is populated from the `user_ips` table
|
||||
`config.track_appservice_user_ips` must be set to `true` for this
|
||||
method to return anything other than native matrix users.
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: dict that includes a mapping between app_service_id
|
||||
and the number of occurrences.
|
||||
|
||||
"""
|
||||
|
||||
def _count_users_by_service(txn):
|
||||
sql = """
|
||||
SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
|
||||
FROM monthly_active_users
|
||||
LEFT JOIN users ON monthly_active_users.user_id=users.name
|
||||
GROUP BY appservice_id;
|
||||
"""
|
||||
|
||||
txn.execute(sql)
|
||||
result = txn.fetchall()
|
||||
return dict(result)
|
||||
|
||||
return self.db.runInteraction("count_users_by_service", _count_users_by_service)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_registered_reserved_users(self):
|
||||
"""Of the reserved threepids defined in config, which are associated
|
||||
|
@ -291,6 +318,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
|||
)
|
||||
|
||||
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_monthly_active_count_by_service, ()
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.user_last_seen_monthly_active, (user_id,)
|
||||
)
|
||||
|
|
|
@ -301,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
admin (bool): true iff the user is to be a server admin,
|
||||
false otherwise.
|
||||
"""
|
||||
return self.db.simple_update_one(
|
||||
table="users",
|
||||
keyvalues={"name": user.to_string()},
|
||||
updatevalues={"admin": 1 if admin else 0},
|
||||
desc="set_server_admin",
|
||||
)
|
||||
|
||||
def set_server_admin_txn(txn):
|
||||
self.db.simple_update_one_txn(
|
||||
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_id, (user.to_string(),)
|
||||
)
|
||||
|
||||
return self.db.runInteraction("set_server_admin", set_server_admin_txn)
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
sql = (
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -56,7 +58,7 @@ class StateDeltasStore(SQLBaseStore):
|
|||
# if the CSDs haven't changed between prev_stream_id and now, we
|
||||
# know for certain that they haven't changed between prev_stream_id and
|
||||
# max_stream_id.
|
||||
return max_stream_id, []
|
||||
return defer.succeed((max_stream_id, []))
|
||||
|
||||
def get_current_state_deltas_txn(txn):
|
||||
# First we calculate the max stream id that will give us less than
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, Tuple
|
||||
from time import monotonic as monotonic_time
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
from six import iteritems, iterkeys, itervalues
|
||||
from six.moves import intern, range
|
||||
|
@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig
|
|||
from synapse.logging.context import LoggingContext, make_deferred_yieldable
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.util.stringutils import exception_to_unicode
|
||||
|
||||
# import a function which will return a monotonic time, in seconds
|
||||
try:
|
||||
# on python 3, use time.monotonic, since time.clock can go backwards
|
||||
from time import monotonic as monotonic_time
|
||||
except ImportError:
|
||||
# ... but python 2 doesn't have it
|
||||
from time import clock as monotonic_time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
MAX_TXN_ID = sys.maxint - 1
|
||||
except AttributeError:
|
||||
# python 3 does not have a maximum int value
|
||||
MAX_TXN_ID = 2 ** 63 - 1
|
||||
# python 3 does not have a maximum int value
|
||||
MAX_TXN_ID = 2 ** 63 - 1
|
||||
|
||||
sql_logger = logging.getLogger("synapse.storage.SQL")
|
||||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||
|
@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
|
|||
|
||||
|
||||
def make_pool(
|
||||
reactor, db_config: DatabaseConnectionConfig, engine
|
||||
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||
) -> adbapi.ConnectionPool:
|
||||
"""Get the connection pool for the database.
|
||||
"""
|
||||
|
@ -90,7 +80,9 @@ def make_pool(
|
|||
)
|
||||
|
||||
|
||||
def make_conn(db_config: DatabaseConnectionConfig, engine):
|
||||
def make_conn(
|
||||
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||
) -> Connection:
|
||||
"""Make a new connection to the database and return it.
|
||||
|
||||
Returns:
|
||||
|
@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
|
|||
return db_conn
|
||||
|
||||
|
||||
class LoggingTransaction(object):
|
||||
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
|
||||
#
|
||||
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
|
||||
# that mypy sees the type but the runtime python doesn't.
|
||||
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
|
||||
|
||||
|
||||
class LoggingTransaction:
|
||||
"""An object that almost-transparently proxies for the 'txn' object
|
||||
passed to the constructor. Adds logging and metrics to the .execute()
|
||||
method.
|
||||
|
||||
Args:
|
||||
txn: The database transcation object to wrap.
|
||||
name (str): The name of this transactions for logging.
|
||||
database_engine (Sqlite3Engine|PostgresEngine)
|
||||
after_callbacks(list|None): A list that callbacks will be appended to
|
||||
name: The name of this transactions for logging.
|
||||
database_engine
|
||||
after_callbacks: A list that callbacks will be appended to
|
||||
that have been added by `call_after` which should be run on
|
||||
successful completion of the transaction. None indicates that no
|
||||
callbacks should be allowed to be scheduled to run.
|
||||
exception_callbacks(list|None): A list that callbacks will be appended
|
||||
exception_callbacks: A list that callbacks will be appended
|
||||
to that have been added by `call_on_exception` which should be run
|
||||
if transaction ends with an error. None indicates that no callbacks
|
||||
should be allowed to be scheduled to run.
|
||||
|
@ -135,46 +134,67 @@ class LoggingTransaction(object):
|
|||
]
|
||||
|
||||
def __init__(
|
||||
self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
|
||||
self,
|
||||
txn: Cursor,
|
||||
name: str,
|
||||
database_engine: BaseDatabaseEngine,
|
||||
after_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||
):
|
||||
object.__setattr__(self, "txn", txn)
|
||||
object.__setattr__(self, "name", name)
|
||||
object.__setattr__(self, "database_engine", database_engine)
|
||||
object.__setattr__(self, "after_callbacks", after_callbacks)
|
||||
object.__setattr__(self, "exception_callbacks", exception_callbacks)
|
||||
self.txn = txn
|
||||
self.name = name
|
||||
self.database_engine = database_engine
|
||||
self.after_callbacks = after_callbacks
|
||||
self.exception_callbacks = exception_callbacks
|
||||
|
||||
def call_after(self, callback, *args, **kwargs):
|
||||
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
|
||||
"""Call the given callback on the main twisted thread after the
|
||||
transaction has finished. Used to invalidate the caches on the
|
||||
correct thread.
|
||||
"""
|
||||
# if self.after_callbacks is None, that means that whatever constructed the
|
||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||
# is not the case.
|
||||
assert self.after_callbacks is not None
|
||||
self.after_callbacks.append((callback, args, kwargs))
|
||||
|
||||
def call_on_exception(self, callback, *args, **kwargs):
|
||||
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
|
||||
# if self.exception_callbacks is None, that means that whatever constructed the
|
||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||
# is not the case.
|
||||
assert self.exception_callbacks is not None
|
||||
self.exception_callbacks.append((callback, args, kwargs))
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.txn, name)
|
||||
def fetchall(self) -> List[Tuple]:
|
||||
return self.txn.fetchall()
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
setattr(self.txn, name, value)
|
||||
def fetchone(self) -> Tuple:
|
||||
return self.txn.fetchone()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[Tuple]:
|
||||
return self.txn.__iter__()
|
||||
|
||||
@property
|
||||
def rowcount(self) -> int:
|
||||
return self.txn.rowcount
|
||||
|
||||
@property
|
||||
def description(self) -> Any:
|
||||
return self.txn.description
|
||||
|
||||
def execute_batch(self, sql, args):
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
from psycopg2.extras import execute_batch
|
||||
from psycopg2.extras import execute_batch # type: ignore
|
||||
|
||||
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
|
||||
else:
|
||||
for val in args:
|
||||
self.execute(sql, val)
|
||||
|
||||
def execute(self, sql, *args):
|
||||
def execute(self, sql: str, *args: Any):
|
||||
self._do_execute(self.txn.execute, sql, *args)
|
||||
|
||||
def executemany(self, sql, *args):
|
||||
def executemany(self, sql: str, *args: Any):
|
||||
self._do_execute(self.txn.executemany, sql, *args)
|
||||
|
||||
def _make_sql_one_line(self, sql):
|
||||
|
@ -207,6 +227,9 @@ class LoggingTransaction(object):
|
|||
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
|
||||
sql_query_timer.labels(sql.split()[0]).observe(secs)
|
||||
|
||||
def close(self):
|
||||
self.txn.close()
|
||||
|
||||
|
||||
class PerformanceCounters(object):
|
||||
def __init__(self):
|
||||
|
@ -251,7 +274,9 @@ class Database(object):
|
|||
|
||||
_TXN_ID = 0
|
||||
|
||||
def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
|
||||
def __init__(
|
||||
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||
):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self._database_config = database_config
|
||||
|
@ -259,9 +284,9 @@ class Database(object):
|
|||
|
||||
self.updates = BackgroundUpdater(hs, self)
|
||||
|
||||
self._previous_txn_total_time = 0
|
||||
self._current_txn_total_time = 0
|
||||
self._previous_loop_ts = 0
|
||||
self._previous_txn_total_time = 0.0
|
||||
self._current_txn_total_time = 0.0
|
||||
self._previous_loop_ts = 0.0
|
||||
|
||||
# TODO(paul): These can eventually be removed once the metrics code
|
||||
# is running in mainline, and we have some nice monitoring frontends
|
||||
|
@ -463,23 +488,23 @@ class Database(object):
|
|||
sql_txn_timer.labels(desc).observe(duration)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runInteraction(self, desc, func, *args, **kwargs):
|
||||
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
|
||||
"""Starts a transaction on the database and runs a given function
|
||||
|
||||
Arguments:
|
||||
desc (str): description of the transaction, for logging and metrics
|
||||
func (func): callback function, which will be called with a
|
||||
desc: description of the transaction, for logging and metrics
|
||||
func: callback function, which will be called with a
|
||||
database transaction (twisted.enterprise.adbapi.Transaction) as
|
||||
its first argument, followed by `args` and `kwargs`.
|
||||
|
||||
args (list): positional args to pass to `func`
|
||||
kwargs (dict): named args to pass to `func`
|
||||
args: positional args to pass to `func`
|
||||
kwargs: named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
"""
|
||||
after_callbacks = []
|
||||
exception_callbacks = []
|
||||
after_callbacks = [] # type: List[_CallbackListEntry]
|
||||
exception_callbacks = [] # type: List[_CallbackListEntry]
|
||||
|
||||
if LoggingContext.current_context() == LoggingContext.sentinel:
|
||||
logger.warning("Starting db txn '%s' from sentinel context", desc)
|
||||
|
@ -505,15 +530,15 @@ class Database(object):
|
|||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runWithConnection(self, func, *args, **kwargs):
|
||||
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
|
||||
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
||||
|
||||
Arguments:
|
||||
func (func): callback function, which will be called with a
|
||||
func: callback function, which will be called with a
|
||||
database connection (twisted.enterprise.adbapi.Connection) as
|
||||
its first argument, followed by `args` and `kwargs`.
|
||||
args (list): positional args to pass to `func`
|
||||
kwargs (dict): named args to pass to `func`
|
||||
args: positional args to pass to `func`
|
||||
kwargs: named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
|
@ -800,7 +825,7 @@ class Database(object):
|
|||
return False
|
||||
|
||||
# We didn't find any existing rows, so insert a new one
|
||||
allvalues = {}
|
||||
allvalues = {} # type: Dict[str, Any]
|
||||
allvalues.update(keyvalues)
|
||||
allvalues.update(values)
|
||||
allvalues.update(insertion_values)
|
||||
|
@ -829,7 +854,7 @@ class Database(object):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
allvalues = {}
|
||||
allvalues = {} # type: Dict[str, Any]
|
||||
allvalues.update(keyvalues)
|
||||
allvalues.update(insertion_values)
|
||||
|
||||
|
@ -916,7 +941,7 @@ class Database(object):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
allnames = []
|
||||
allnames = [] # type: List[str]
|
||||
allnames.extend(key_names)
|
||||
allnames.extend(value_names)
|
||||
|
||||
|
@ -1100,7 +1125,7 @@ class Database(object):
|
|||
keyvalues : dict of column names and values to select the rows with
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
results = []
|
||||
results = [] # type: List[Dict[str, Any]]
|
||||
|
||||
if not iterable:
|
||||
return results
|
||||
|
@ -1439,7 +1464,7 @@ class Database(object):
|
|||
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
|
||||
|
||||
where_clause = "WHERE " if filters or keyvalues else ""
|
||||
arg_list = []
|
||||
arg_list = [] # type: List[Any]
|
||||
if filters:
|
||||
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
|
||||
arg_list += list(filters.values())
|
||||
|
|
|
@ -12,29 +12,31 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import platform
|
||||
|
||||
from ._base import IncorrectDatabaseSetup
|
||||
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
|
||||
from .postgres import PostgresEngine
|
||||
from .sqlite import Sqlite3Engine
|
||||
|
||||
SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
|
||||
|
||||
|
||||
def create_engine(database_config):
|
||||
def create_engine(database_config) -> BaseDatabaseEngine:
|
||||
name = database_config["name"]
|
||||
engine_class = SUPPORTED_MODULE.get(name, None)
|
||||
|
||||
if engine_class:
|
||||
if name == "sqlite3":
|
||||
import sqlite3
|
||||
|
||||
return Sqlite3Engine(sqlite3, database_config)
|
||||
|
||||
if name == "psycopg2":
|
||||
# pypy requires psycopg2cffi rather than psycopg2
|
||||
if name == "psycopg2" and platform.python_implementation() == "PyPy":
|
||||
name = "psycopg2cffi"
|
||||
module = importlib.import_module(name)
|
||||
return engine_class(module, database_config)
|
||||
if platform.python_implementation() == "PyPy":
|
||||
import psycopg2cffi as psycopg2 # type: ignore
|
||||
else:
|
||||
import psycopg2 # type: ignore
|
||||
|
||||
return PostgresEngine(psycopg2, database_config)
|
||||
|
||||
raise RuntimeError("Unsupported database engine '%s'" % (name,))
|
||||
|
||||
|
||||
__all__ = ["create_engine", "IncorrectDatabaseSetup"]
|
||||
__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
|
||||
|
|
|
@ -12,7 +12,94 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from synapse.storage.types import Connection
|
||||
|
||||
|
||||
class IncorrectDatabaseSetup(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
ConnectionType = TypeVar("ConnectionType", bound=Connection)
|
||||
|
||||
|
||||
class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
|
||||
def __init__(self, module, database_config: dict):
|
||||
self.module = module
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def single_threaded(self) -> bool:
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def can_native_upsert(self) -> bool:
|
||||
"""
|
||||
Do we support native UPSERTs?
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def supports_tuple_comparison(self) -> bool:
|
||||
"""
|
||||
Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def supports_using_any_list(self) -> bool:
|
||||
"""
|
||||
Do we support using `a = ANY(?)` and passing a list
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_database(
|
||||
self, db_conn: ConnectionType, allow_outdated_version: bool = False
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_new_database(self, txn) -> None:
|
||||
"""Gets called when setting up a brand new database. This allows us to
|
||||
apply stricter checks on new databases versus existing database.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def convert_param_style(self, sql: str) -> str:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_new_connection(self, db_conn: ConnectionType) -> None:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_deadlock(self, error: Exception) -> bool:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_connection_closed(self, conn: ConnectionType) -> bool:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def lock_table(self, txn, table: str) -> None:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_next_state_group_id(self, txn) -> int:
|
||||
"""Returns an int that can be used as a new state_group ID
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def server_version(self) -> str:
|
||||
"""Gets a string giving the server version. For example: '3.22.0'
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -15,16 +15,14 @@
|
|||
|
||||
import logging
|
||||
|
||||
from ._base import IncorrectDatabaseSetup
|
||||
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostgresEngine(object):
|
||||
single_threaded = False
|
||||
|
||||
class PostgresEngine(BaseDatabaseEngine):
|
||||
def __init__(self, database_module, database_config):
|
||||
self.module = database_module
|
||||
super().__init__(database_module, database_config)
|
||||
self.module.extensions.register_type(self.module.extensions.UNICODE)
|
||||
|
||||
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
|
||||
|
@ -36,6 +34,10 @@ class PostgresEngine(object):
|
|||
self.synchronous_commit = database_config.get("synchronous_commit", True)
|
||||
self._version = None # unknown as yet
|
||||
|
||||
@property
|
||||
def single_threaded(self) -> bool:
|
||||
return False
|
||||
|
||||
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
||||
# Get the version of PostgreSQL that we're using. As per the psycopg2
|
||||
# docs: The number is formed by converting the major, minor, and
|
||||
|
|
|
@ -12,16 +12,16 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sqlite3
|
||||
import struct
|
||||
import threading
|
||||
|
||||
from synapse.storage.engines import BaseDatabaseEngine
|
||||
|
||||
class Sqlite3Engine(object):
|
||||
single_threaded = True
|
||||
|
||||
class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
|
||||
def __init__(self, database_module, database_config):
|
||||
self.module = database_module
|
||||
super().__init__(database_module, database_config)
|
||||
|
||||
database = database_config.get("args", {}).get("database")
|
||||
self._is_in_memory = database in (None, ":memory:",)
|
||||
|
@ -31,6 +31,10 @@ class Sqlite3Engine(object):
|
|||
self._current_state_group_id = None
|
||||
self._current_state_group_id_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def single_threaded(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def can_native_upsert(self):
|
||||
"""
|
||||
|
@ -68,7 +72,6 @@ class Sqlite3Engine(object):
|
|||
return sql
|
||||
|
||||
def on_new_connection(self, db_conn):
|
||||
|
||||
# We need to import here to avoid an import loop.
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
|
||||
|
|
65
synapse/storage/types.py
Normal file
65
synapse/storage/types.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Iterable, Iterator, List, Tuple
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
"""
|
||||
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
|
||||
"""
|
||||
|
||||
|
||||
class Cursor(Protocol):
|
||||
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
|
||||
...
|
||||
|
||||
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
|
||||
...
|
||||
|
||||
def fetchall(self) -> List[Tuple]:
|
||||
...
|
||||
|
||||
def fetchone(self) -> Tuple:
|
||||
...
|
||||
|
||||
@property
|
||||
def description(self) -> Any:
|
||||
return None
|
||||
|
||||
@property
|
||||
def rowcount(self) -> int:
|
||||
return 0
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple]:
|
||||
...
|
||||
|
||||
def close(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
class Connection(Protocol):
|
||||
def cursor(self) -> Cursor:
|
||||
...
|
||||
|
||||
def close(self) -> None:
|
||||
...
|
||||
|
||||
def commit(self) -> None:
|
||||
...
|
||||
|
||||
def rollback(self, *args, **kwargs) -> None:
|
||||
...
|
|
@ -23,7 +23,7 @@ import attr
|
|||
from signedjson.key import decode_verify_key_bytes
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
|
||||
# define a version of typing.Collection that works on python 3.5
|
||||
if sys.version_info[:3] >= (3, 6, 0):
|
||||
|
@ -166,11 +166,13 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
|
|||
return self
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, s):
|
||||
def from_string(cls, s: str):
|
||||
"""Parse the string given by 's' into a structure object."""
|
||||
if len(s) < 1 or s[0:1] != cls.SIGIL:
|
||||
raise SynapseError(
|
||||
400, "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL)
|
||||
400,
|
||||
"Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL),
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
parts = s[1:].split(":", 1)
|
||||
|
@ -179,6 +181,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
|
|||
400,
|
||||
"Expected %s of the form '%slocalname:domain'"
|
||||
% (cls.__name__, cls.SIGIL),
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
domain = parts[1]
|
||||
|
@ -235,11 +238,13 @@ class GroupID(DomainSpecificString):
|
|||
def from_string(cls, s):
|
||||
group_id = super(GroupID, cls).from_string(s)
|
||||
if not group_id.localpart:
|
||||
raise SynapseError(400, "Group ID cannot be empty")
|
||||
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
|
||||
|
||||
if contains_invalid_mxid_characters(group_id.localpart):
|
||||
raise SynapseError(
|
||||
400, "Group ID can only contain characters a-z, 0-9, or '=_-./'"
|
||||
400,
|
||||
"Group ID can only contain characters a-z, 0-9, or '=_-./'",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
return group_id
|
||||
|
|
|
@ -119,6 +119,9 @@ def filter_events_for_client(
|
|||
|
||||
the original event if they can see it as normal.
|
||||
"""
|
||||
if event.type == "org.matrix.dummy_event":
|
||||
return None
|
||||
|
||||
if not event.is_state() and event.sender in ignore_list:
|
||||
return None
|
||||
|
||||
|
|
|
@ -27,6 +27,11 @@ class FrontendProxyTests(HomeserverTestCase):
|
|||
|
||||
return hs
|
||||
|
||||
def default_config(self, name="test"):
|
||||
c = super().default_config(name)
|
||||
c["worker_app"] = "synapse.app.frontend_proxy"
|
||||
return c
|
||||
|
||||
def test_listen_http_with_presence_enabled(self):
|
||||
"""
|
||||
When presence is on, the stub servlet will not register.
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events.utils import (
|
||||
copy_power_levels_contents,
|
||||
|
@ -36,9 +37,9 @@ class PruneEventTestCase(unittest.TestCase):
|
|||
""" Asserts that a new event constructed with `evdict` will look like
|
||||
`matchdict` when it is redacted. """
|
||||
|
||||
def run_test(self, evdict, matchdict):
|
||||
def run_test(self, evdict, matchdict, **kwargs):
|
||||
self.assertEquals(
|
||||
prune_event(make_event_from_dict(evdict)).get_dict(), matchdict
|
||||
prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
|
||||
)
|
||||
|
||||
def test_minimal(self):
|
||||
|
@ -128,6 +129,36 @@ class PruneEventTestCase(unittest.TestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_alias_event(self):
|
||||
"""Alias events have special behavior up through room version 6."""
|
||||
self.run_test(
|
||||
{
|
||||
"type": "m.room.aliases",
|
||||
"event_id": "$test:domain",
|
||||
"content": {"aliases": ["test"]},
|
||||
},
|
||||
{
|
||||
"type": "m.room.aliases",
|
||||
"event_id": "$test:domain",
|
||||
"content": {"aliases": ["test"]},
|
||||
"signatures": {},
|
||||
"unsigned": {},
|
||||
},
|
||||
)
|
||||
|
||||
def test_msc2432_alias_event(self):
|
||||
"""After MSC2432, alias events have no special behavior."""
|
||||
self.run_test(
|
||||
{"type": "m.room.aliases", "content": {"aliases": ["test"]}},
|
||||
{
|
||||
"type": "m.room.aliases",
|
||||
"content": {},
|
||||
"signatures": {},
|
||||
"unsigned": {},
|
||||
},
|
||||
room_version=RoomVersions.MSC2432_DEV,
|
||||
)
|
||||
|
||||
|
||||
class SerializeEventTestCase(unittest.TestCase):
|
||||
def serialize(self, ev, fields):
|
||||
|
|
|
@ -18,6 +18,7 @@ from mock import Mock
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
import synapse.api.errors
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.config.room_directory import RoomDirectoryConfig
|
||||
|
@ -87,38 +88,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
|||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def test_delete_alias_not_allowed(self):
|
||||
room_id = "!8765qwer:test"
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(self.my_room, room_id, ["test"])
|
||||
)
|
||||
|
||||
self.get_failure(
|
||||
self.handler.delete_association(
|
||||
create_requester("@user:test"), self.my_room
|
||||
),
|
||||
synapse.api.errors.AuthError,
|
||||
)
|
||||
|
||||
def test_delete_alias(self):
|
||||
room_id = "!8765qwer:test"
|
||||
user_id = "@user:test"
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
self.my_room, room_id, ["test"], user_id
|
||||
)
|
||||
)
|
||||
|
||||
result = self.get_success(
|
||||
self.handler.delete_association(create_requester(user_id), self.my_room)
|
||||
)
|
||||
self.assertEquals(room_id, result)
|
||||
|
||||
# The alias should not be found.
|
||||
self.get_failure(
|
||||
self.handler.get_association(self.my_room), synapse.api.errors.SynapseError
|
||||
)
|
||||
|
||||
def test_incoming_fed_query(self):
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
|
@ -133,6 +102,119 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
|
||||
|
||||
|
||||
class TestDeleteAlias(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
directory.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.handler = hs.get_handlers().directory_handler
|
||||
self.state_handler = hs.get_state_handler()
|
||||
|
||||
# Create user
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
self.admin_user_tok = self.login("admin", "pass")
|
||||
|
||||
# Create a test room
|
||||
self.room_id = self.helper.create_room_as(
|
||||
self.admin_user, tok=self.admin_user_tok
|
||||
)
|
||||
|
||||
self.test_alias = "#test:test"
|
||||
self.room_alias = RoomAlias.from_string(self.test_alias)
|
||||
|
||||
# Create a test user.
|
||||
self.test_user = self.register_user("user", "pass", admin=False)
|
||||
self.test_user_tok = self.login("user", "pass")
|
||||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
|
||||
|
||||
def _create_alias(self, user):
|
||||
# Create a new alias to this room.
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
self.room_alias, self.room_id, ["test"], user
|
||||
)
|
||||
)
|
||||
|
||||
def test_delete_alias_not_allowed(self):
|
||||
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
|
||||
self._create_alias(self.admin_user)
|
||||
self.get_failure(
|
||||
self.handler.delete_association(
|
||||
create_requester(self.test_user), self.room_alias
|
||||
),
|
||||
synapse.api.errors.AuthError,
|
||||
)
|
||||
|
||||
def test_delete_alias_creator(self):
|
||||
"""An alias creator can delete their own alias."""
|
||||
# Create an alias from a different user.
|
||||
self._create_alias(self.test_user)
|
||||
|
||||
# Delete the user's alias.
|
||||
result = self.get_success(
|
||||
self.handler.delete_association(
|
||||
create_requester(self.test_user), self.room_alias
|
||||
)
|
||||
)
|
||||
self.assertEquals(self.room_id, result)
|
||||
|
||||
# Confirm the alias is gone.
|
||||
self.get_failure(
|
||||
self.handler.get_association(self.room_alias),
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_delete_alias_admin(self):
|
||||
"""A server admin can delete an alias created by another user."""
|
||||
# Create an alias from a different user.
|
||||
self._create_alias(self.test_user)
|
||||
|
||||
# Delete the user's alias as the admin.
|
||||
result = self.get_success(
|
||||
self.handler.delete_association(
|
||||
create_requester(self.admin_user), self.room_alias
|
||||
)
|
||||
)
|
||||
self.assertEquals(self.room_id, result)
|
||||
|
||||
# Confirm the alias is gone.
|
||||
self.get_failure(
|
||||
self.handler.get_association(self.room_alias),
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_delete_alias_sufficient_power(self):
|
||||
"""A user with a sufficient power level should be able to delete an alias."""
|
||||
self._create_alias(self.admin_user)
|
||||
|
||||
# Increase the user's power level.
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.power_levels",
|
||||
{"users": {self.test_user: 100}},
|
||||
tok=self.admin_user_tok,
|
||||
)
|
||||
|
||||
# They can now delete the alias.
|
||||
result = self.get_success(
|
||||
self.handler.delete_association(
|
||||
create_requester(self.test_user), self.room_alias
|
||||
)
|
||||
)
|
||||
self.assertEquals(self.room_id, result)
|
||||
|
||||
# Confirm the alias is gone.
|
||||
self.get_failure(
|
||||
self.handler.get_association(self.room_alias),
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
|
||||
class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
"""Test modifications of the canonical alias when delete aliases.
|
||||
"""
|
||||
|
@ -159,30 +241,42 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
self.test_alias = "#test:test"
|
||||
self.room_alias = RoomAlias.from_string(self.test_alias)
|
||||
self.room_alias = self._add_alias(self.test_alias)
|
||||
|
||||
def _add_alias(self, alias: str) -> RoomAlias:
|
||||
"""Add an alias to the test room."""
|
||||
room_alias = RoomAlias.from_string(alias)
|
||||
|
||||
# Create a new alias to this room.
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
self.room_alias, self.room_id, ["test"], self.admin_user
|
||||
room_alias, self.room_id, ["test"], self.admin_user
|
||||
)
|
||||
)
|
||||
return room_alias
|
||||
|
||||
def _set_canonical_alias(self, content):
|
||||
"""Configure the canonical alias state on the room."""
|
||||
self.helper.send_state(
|
||||
self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
|
||||
)
|
||||
|
||||
def _get_canonical_alias(self):
|
||||
"""Get the canonical alias state of the room."""
|
||||
return self.get_success(
|
||||
self.state_handler.get_current_state(
|
||||
self.room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
)
|
||||
|
||||
def test_remove_alias(self):
|
||||
"""Removing an alias that is the canonical alias should remove it there too."""
|
||||
# Set this new alias as the canonical alias for this room
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.canonical_alias",
|
||||
{"alias": self.test_alias, "alt_aliases": [self.test_alias]},
|
||||
tok=self.admin_user_tok,
|
||||
self._set_canonical_alias(
|
||||
{"alias": self.test_alias, "alt_aliases": [self.test_alias]}
|
||||
)
|
||||
|
||||
data = self.get_success(
|
||||
self.state_handler.get_current_state(
|
||||
self.room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
)
|
||||
data = self._get_canonical_alias()
|
||||
self.assertEqual(data["content"]["alias"], self.test_alias)
|
||||
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
|
||||
|
||||
|
@ -193,11 +287,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
data = self.get_success(
|
||||
self.state_handler.get_current_state(
|
||||
self.room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
)
|
||||
data = self._get_canonical_alias()
|
||||
self.assertNotIn("alias", data["content"])
|
||||
self.assertNotIn("alt_aliases", data["content"])
|
||||
|
||||
|
@ -205,29 +295,17 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
"""Removing an alias listed as in alt_aliases should remove it there too."""
|
||||
# Create a second alias.
|
||||
other_test_alias = "#test2:test"
|
||||
other_room_alias = RoomAlias.from_string(other_test_alias)
|
||||
self.get_success(
|
||||
self.store.create_room_alias_association(
|
||||
other_room_alias, self.room_id, ["test"], self.admin_user
|
||||
)
|
||||
)
|
||||
other_room_alias = self._add_alias(other_test_alias)
|
||||
|
||||
# Set the alias as the canonical alias for this room.
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
"m.room.canonical_alias",
|
||||
self._set_canonical_alias(
|
||||
{
|
||||
"alias": self.test_alias,
|
||||
"alt_aliases": [self.test_alias, other_test_alias],
|
||||
},
|
||||
tok=self.admin_user_tok,
|
||||
}
|
||||
)
|
||||
|
||||
data = self.get_success(
|
||||
self.state_handler.get_current_state(
|
||||
self.room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
)
|
||||
data = self._get_canonical_alias()
|
||||
self.assertEqual(data["content"]["alias"], self.test_alias)
|
||||
self.assertEqual(
|
||||
data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
|
||||
|
@ -240,11 +318,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
data = self.get_success(
|
||||
self.state_handler.get_current_state(
|
||||
self.room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
)
|
||||
data = self._get_canonical_alias()
|
||||
self.assertEqual(data["content"]["alias"], self.test_alias)
|
||||
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ import logging
|
|||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.handlers.room import RoomEventSource
|
||||
|
@ -58,6 +59,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
|||
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
|
||||
return super(SlavedEventStoreTestCase, self).setUp()
|
||||
|
||||
def prepare(self, *args, **kwargs):
|
||||
super().prepare(*args, **kwargs)
|
||||
|
||||
self.get_success(
|
||||
self.master_store.store_room(
|
||||
ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1,
|
||||
)
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
[unpatch() for unpatch in self.unpatches]
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import urllib.parse
|
||||
|
||||
from mock import Mock
|
||||
|
||||
|
@ -371,22 +372,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.url = "/_synapse/admin/v2/users/@bob:test"
|
||||
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
self.admin_user_tok = self.login("admin", "pass")
|
||||
|
||||
self.other_user = self.register_user("user", "pass")
|
||||
self.other_user_token = self.login("user", "pass")
|
||||
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
|
||||
self.other_user
|
||||
)
|
||||
|
||||
def test_requester_is_no_admin(self):
|
||||
"""
|
||||
If the user is not a server admin, an error is returned.
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
url = "/_synapse/admin/v2/users/@bob:test"
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url, access_token=self.other_user_token,
|
||||
"GET", url, access_token=self.other_user_token,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
|
@ -394,7 +397,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual("You are not a server admin", channel.json_body["error"])
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT", self.url, access_token=self.other_user_token, content=b"{}",
|
||||
"PUT", url, access_token=self.other_user_token, content=b"{}",
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
|
@ -417,24 +420,26 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(404, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
|
||||
|
||||
def test_requester_is_admin(self):
|
||||
def test_create_server_admin(self):
|
||||
"""
|
||||
If the user is a server admin, a new user is created.
|
||||
Check that a new admin user is created successfully.
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
url = "/_synapse/admin/v2/users/@bob:test"
|
||||
|
||||
# Create user (server admin)
|
||||
body = json.dumps(
|
||||
{
|
||||
"password": "abc123",
|
||||
"admin": True,
|
||||
"displayname": "Bob's name",
|
||||
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
|
||||
}
|
||||
)
|
||||
|
||||
# Create user
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
self.url,
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
)
|
||||
|
@ -442,29 +447,85 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("bob", channel.json_body["displayname"])
|
||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
|
||||
# Get user
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url, access_token=self.admin_user_tok,
|
||||
"GET", url, access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("bob", channel.json_body["displayname"])
|
||||
self.assertEqual(1, channel.json_body["admin"])
|
||||
self.assertEqual(0, channel.json_body["is_guest"])
|
||||
self.assertEqual(0, channel.json_body["deactivated"])
|
||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
self.assertEqual(False, channel.json_body["is_guest"])
|
||||
self.assertEqual(False, channel.json_body["deactivated"])
|
||||
|
||||
def test_create_user(self):
|
||||
"""
|
||||
Check that a new regular user is created successfully.
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
url = "/_synapse/admin/v2/users/@bob:test"
|
||||
|
||||
# Create user
|
||||
body = json.dumps(
|
||||
{
|
||||
"password": "abc123",
|
||||
"admin": False,
|
||||
"displayname": "Bob's name",
|
||||
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
|
||||
}
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(False, channel.json_body["admin"])
|
||||
|
||||
# Get user
|
||||
request, channel = self.make_request(
|
||||
"GET", url, access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
self.assertEqual(False, channel.json_body["admin"])
|
||||
self.assertEqual(False, channel.json_body["is_guest"])
|
||||
self.assertEqual(False, channel.json_body["deactivated"])
|
||||
|
||||
def test_set_password(self):
|
||||
"""
|
||||
Test setting a new password for another user.
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
|
||||
# Change password
|
||||
body = json.dumps({"password": "hahaha"})
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
self.url,
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
)
|
||||
|
@ -472,41 +533,133 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
|
||||
def test_set_displayname(self):
|
||||
"""
|
||||
Test setting the displayname of another user.
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
|
||||
# Modify user
|
||||
body = json.dumps({"displayname": "foobar"})
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual("foobar", channel.json_body["displayname"])
|
||||
|
||||
# Get user
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url_other_user, access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual("foobar", channel.json_body["displayname"])
|
||||
|
||||
def test_set_threepid(self):
|
||||
"""
|
||||
Test setting threepid for an other user.
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
|
||||
# Delete old and add new threepid to user
|
||||
body = json.dumps(
|
||||
{
|
||||
"displayname": "foobar",
|
||||
"deactivated": True,
|
||||
"threepids": [{"medium": "email", "address": "bob2@bob.bob"}],
|
||||
}
|
||||
{"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
self.url,
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("foobar", channel.json_body["displayname"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
|
||||
# Get user
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url_other_user, access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
|
||||
def test_deactivate_user(self):
|
||||
"""
|
||||
Test deactivating another user.
|
||||
"""
|
||||
|
||||
# Deactivate user
|
||||
body = json.dumps({"deactivated": True})
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["deactivated"])
|
||||
# the user is deactivated, the threepid will be deleted
|
||||
|
||||
# Get user
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url, access_token=self.admin_user_tok,
|
||||
"GET", self.url_other_user, access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("foobar", channel.json_body["displayname"])
|
||||
self.assertEqual(1, channel.json_body["admin"])
|
||||
self.assertEqual(0, channel.json_body["is_guest"])
|
||||
self.assertEqual(1, channel.json_body["deactivated"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["deactivated"])
|
||||
|
||||
def test_set_user_as_admin(self):
|
||||
"""
|
||||
Test setting the admin flag on a user.
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
|
||||
# Set a user as an admin
|
||||
body = json.dumps({"admin": True})
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
self.url_other_user,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
|
||||
# Get user
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url_other_user, access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@user:test", channel.json_body["name"])
|
||||
self.assertEqual(True, channel.json_body["admin"])
|
||||
|
||||
def test_accidental_deactivation_prevention(self):
|
||||
"""
|
||||
|
@ -514,13 +667,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
for the deactivated body parameter
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
url = "/_synapse/admin/v2/users/@bob:test"
|
||||
|
||||
# Create user
|
||||
body = json.dumps({"password": "abc123"})
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
self.url,
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
)
|
||||
|
@ -532,7 +686,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# Get user
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url, access_token=self.admin_user_tok,
|
||||
"GET", url, access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
|
@ -546,7 +700,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
self.url,
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
content=body.encode(encoding="utf_8"),
|
||||
)
|
||||
|
@ -556,7 +710,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# Check user is not deactivated
|
||||
request, channel = self.make_request(
|
||||
"GET", self.url, access_token=self.admin_user_tok,
|
||||
"GET", url, access_token=self.admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
import json
|
||||
import urllib.parse
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.rest.client.v1 import login
|
||||
|
@ -252,3 +255,111 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.code, 200, channel.result)
|
||||
|
||||
|
||||
class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.base_url = "https://matrix.goodserver.com/"
|
||||
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
|
||||
|
||||
config = self.default_config()
|
||||
config["cas_config"] = {
|
||||
"enabled": True,
|
||||
"server_url": "https://fake.test",
|
||||
"service_url": "https://matrix.goodserver.com:8448",
|
||||
}
|
||||
|
||||
async def get_raw(uri, args):
|
||||
"""Return an example response payload from a call to the `/proxyValidate`
|
||||
endpoint of a CAS server, copied from
|
||||
https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
|
||||
|
||||
This needs to be returned by an async function (as opposed to set as the
|
||||
mock's return value) because the corresponding Synapse code awaits on it.
|
||||
"""
|
||||
return """
|
||||
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
|
||||
<cas:authenticationSuccess>
|
||||
<cas:user>username</cas:user>
|
||||
<cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
|
||||
<cas:proxies>
|
||||
<cas:proxy>https://proxy2/pgtUrl</cas:proxy>
|
||||
<cas:proxy>https://proxy1/pgtUrl</cas:proxy>
|
||||
</cas:proxies>
|
||||
</cas:authenticationSuccess>
|
||||
</cas:serviceResponse>
|
||||
"""
|
||||
|
||||
mocked_http_client = Mock(spec=["get_raw"])
|
||||
mocked_http_client.get_raw.side_effect = get_raw
|
||||
|
||||
self.hs = self.setup_test_homeserver(
|
||||
config=config, proxied_http_client=mocked_http_client,
|
||||
)
|
||||
|
||||
return self.hs
|
||||
|
||||
def test_cas_redirect_confirm(self):
|
||||
"""Tests that the SSO login flow serves a confirmation page before redirecting a
|
||||
user to the redirect URL.
|
||||
"""
|
||||
base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl"
|
||||
redirect_url = "https://dodgy-site.com/"
|
||||
|
||||
url_parts = list(urllib.parse.urlparse(base_url))
|
||||
query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
||||
query.update({"redirectUrl": redirect_url})
|
||||
query.update({"ticket": "ticket"})
|
||||
url_parts[4] = urllib.parse.urlencode(query)
|
||||
cas_ticket_url = urllib.parse.urlunparse(url_parts)
|
||||
|
||||
# Get Synapse to call the fake CAS and serve the template.
|
||||
request, channel = self.make_request("GET", cas_ticket_url)
|
||||
self.render(request)
|
||||
|
||||
# Test that the response is HTML.
|
||||
self.assertEqual(channel.code, 200)
|
||||
content_type_header_value = ""
|
||||
for header in channel.result.get("headers", []):
|
||||
if header[0] == b"Content-Type":
|
||||
content_type_header_value = header[1].decode("utf8")
|
||||
|
||||
self.assertTrue(content_type_header_value.startswith("text/html"))
|
||||
|
||||
# Test that the body isn't empty.
|
||||
self.assertTrue(len(channel.result["body"]) > 0)
|
||||
|
||||
# And that it contains our redirect link
|
||||
self.assertIn(redirect_url, channel.result["body"].decode("UTF-8"))
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"sso": {
|
||||
"client_whitelist": [
|
||||
"https://legit-site.com/",
|
||||
"https://other-site.com/",
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_cas_redirect_whitelisted(self):
|
||||
"""Tests that the SSO login flow serves a redirect to a whitelisted url
|
||||
"""
|
||||
redirect_url = "https://legit-site.com/"
|
||||
cas_ticket_url = (
|
||||
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
|
||||
% (urllib.parse.quote(redirect_url))
|
||||
)
|
||||
|
||||
# Get Synapse to call the fake CAS and serve the template.
|
||||
request, channel = self.make_request("GET", cas_ticket_url)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.code, 302)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
|
||||
|
|
|
@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
|
||||
|
||||
class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
directory.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.room_owner = self.register_user("room_owner", "test")
|
||||
self.room_owner_tok = self.login("room_owner", "test")
|
||||
|
||||
self.room_id = self.helper.create_room_as(
|
||||
self.room_owner, tok=self.room_owner_tok
|
||||
)
|
||||
|
||||
self.alias = "#alias:test"
|
||||
self._set_alias_via_directory(self.alias)
|
||||
|
||||
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
|
||||
url = "/_matrix/client/r0/directory/room/" + alias
|
||||
data = {"room_id": self.room_id}
|
||||
request_data = json.dumps(data)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"PUT", url, request_data, access_token=self.room_owner_tok
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
|
||||
def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
|
||||
"""Calls the endpoint under test. returns the json response object."""
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
|
||||
access_token=self.room_owner_tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
res = channel.json_body
|
||||
self.assertIsInstance(res, dict)
|
||||
return res
|
||||
|
||||
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
|
||||
"""Calls the endpoint under test. returns the json response object."""
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
|
||||
json.dumps(content),
|
||||
access_token=self.room_owner_tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
res = channel.json_body
|
||||
self.assertIsInstance(res, dict)
|
||||
return res
|
||||
|
||||
def test_canonical_alias(self):
|
||||
"""Test a basic alias message."""
|
||||
# There is no canonical alias to start with.
|
||||
self._get_canonical_alias(expected_code=404)
|
||||
|
||||
# Create an alias.
|
||||
self._set_canonical_alias({"alias": self.alias})
|
||||
|
||||
# Canonical alias now exists!
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {"alias": self.alias})
|
||||
|
||||
# Now remove the alias.
|
||||
self._set_canonical_alias({})
|
||||
|
||||
# There is an alias event, but it is empty.
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {})
|
||||
|
||||
def test_alt_aliases(self):
|
||||
"""Test a canonical alias message with alt_aliases."""
|
||||
# Create an alias.
|
||||
self._set_canonical_alias({"alt_aliases": [self.alias]})
|
||||
|
||||
# Canonical alias now exists!
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {"alt_aliases": [self.alias]})
|
||||
|
||||
# Now remove the alt_aliases.
|
||||
self._set_canonical_alias({})
|
||||
|
||||
# There is an alias event, but it is empty.
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {})
|
||||
|
||||
def test_alias_alt_aliases(self):
|
||||
"""Test a canonical alias message with an alias and alt_aliases."""
|
||||
# Create an alias.
|
||||
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
|
||||
|
||||
# Canonical alias now exists!
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
|
||||
|
||||
# Now remove the alias and alt_aliases.
|
||||
self._set_canonical_alias({})
|
||||
|
||||
# There is an alias event, but it is empty.
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {})
|
||||
|
||||
def test_partial_modify(self):
|
||||
"""Test removing only the alt_aliases."""
|
||||
# Create an alias.
|
||||
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
|
||||
|
||||
# Canonical alias now exists!
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
|
||||
|
||||
# Now remove the alt_aliases.
|
||||
self._set_canonical_alias({"alias": self.alias})
|
||||
|
||||
# There is an alias event, but it is empty.
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {"alias": self.alias})
|
||||
|
||||
def test_add_alias(self):
|
||||
"""Test removing only the alt_aliases."""
|
||||
# Create an additional alias.
|
||||
second_alias = "#second:test"
|
||||
self._set_alias_via_directory(second_alias)
|
||||
|
||||
# Add the canonical alias.
|
||||
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
|
||||
|
||||
# Then add the second alias.
|
||||
self._set_canonical_alias(
|
||||
{"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
|
||||
)
|
||||
|
||||
# Canonical alias now exists!
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(
|
||||
res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
|
||||
)
|
||||
|
||||
def test_bad_data(self):
|
||||
"""Invalid data for alt_aliases should cause errors."""
|
||||
self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
|
||||
|
||||
def test_bad_alias(self):
|
||||
"""An alias which does not point to the room raises a SynapseError."""
|
||||
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
|
||||
|
|
|
@ -303,3 +303,45 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
|||
self.pump()
|
||||
|
||||
self.store.upsert_monthly_active_user.assert_not_called()
|
||||
|
||||
def test_get_monthly_active_count_by_service(self):
|
||||
appservice1_user1 = "@appservice1_user1:example.com"
|
||||
appservice1_user2 = "@appservice1_user2:example.com"
|
||||
|
||||
appservice2_user1 = "@appservice2_user1:example.com"
|
||||
native_user1 = "@native_user1:example.com"
|
||||
|
||||
service1 = "service1"
|
||||
service2 = "service2"
|
||||
native = "native"
|
||||
|
||||
self.store.register_user(
|
||||
user_id=appservice1_user1, password_hash=None, appservice_id=service1
|
||||
)
|
||||
self.store.register_user(
|
||||
user_id=appservice1_user2, password_hash=None, appservice_id=service1
|
||||
)
|
||||
self.store.register_user(
|
||||
user_id=appservice2_user1, password_hash=None, appservice_id=service2
|
||||
)
|
||||
self.store.register_user(user_id=native_user1, password_hash=None)
|
||||
self.pump()
|
||||
|
||||
count = self.store.get_monthly_active_count_by_service()
|
||||
self.assertEqual({}, self.get_success(count))
|
||||
|
||||
self.store.upsert_monthly_active_user(native_user1)
|
||||
self.store.upsert_monthly_active_user(appservice1_user1)
|
||||
self.store.upsert_monthly_active_user(appservice1_user2)
|
||||
self.store.upsert_monthly_active_user(appservice2_user1)
|
||||
self.pump()
|
||||
|
||||
count = self.store.get_monthly_active_count()
|
||||
self.assertEqual(4, self.get_success(count))
|
||||
|
||||
count = self.store.get_monthly_active_count_by_service()
|
||||
result = self.get_success(count)
|
||||
|
||||
self.assertEqual(2, result[service1])
|
||||
self.assertEqual(1, result[service2])
|
||||
self.assertEqual(1, result[native])
|
||||
|
|
|
@ -19,6 +19,7 @@ from synapse import event_auth
|
|||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.types import get_domain_from_id
|
||||
|
||||
|
||||
class EventAuthTestCase(unittest.TestCase):
|
||||
|
@ -51,7 +52,7 @@ class EventAuthTestCase(unittest.TestCase):
|
|||
_random_state_event(joiner),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_default_level(self):
|
||||
"""
|
||||
|
@ -87,6 +88,83 @@ class EventAuthTestCase(unittest.TestCase):
|
|||
RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False,
|
||||
)
|
||||
|
||||
def test_alias_event(self):
|
||||
"""Alias events have special behavior up through room version 6."""
|
||||
creator = "@creator:example.com"
|
||||
other = "@other:example.com"
|
||||
auth_events = {
|
||||
("m.room.create", ""): _create_event(creator),
|
||||
("m.room.member", creator): _join_event(creator),
|
||||
}
|
||||
|
||||
# creator should be able to send aliases
|
||||
event_auth.check(
|
||||
RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False,
|
||||
)
|
||||
|
||||
# Reject an event with no state key.
|
||||
with self.assertRaises(AuthError):
|
||||
event_auth.check(
|
||||
RoomVersions.V1,
|
||||
_alias_event(creator, state_key=""),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# If the domain of the sender does not match the state key, reject.
|
||||
with self.assertRaises(AuthError):
|
||||
event_auth.check(
|
||||
RoomVersions.V1,
|
||||
_alias_event(creator, state_key="test.com"),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# Note that the member does *not* need to be in the room.
|
||||
event_auth.check(
|
||||
RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False,
|
||||
)
|
||||
|
||||
def test_msc2432_alias_event(self):
|
||||
"""After MSC2432, alias events have no special behavior."""
|
||||
creator = "@creator:example.com"
|
||||
other = "@other:example.com"
|
||||
auth_events = {
|
||||
("m.room.create", ""): _create_event(creator),
|
||||
("m.room.member", creator): _join_event(creator),
|
||||
}
|
||||
|
||||
# creator should be able to send aliases
|
||||
event_auth.check(
|
||||
RoomVersions.MSC2432_DEV,
|
||||
_alias_event(creator),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# No particular checks are done on the state key.
|
||||
event_auth.check(
|
||||
RoomVersions.MSC2432_DEV,
|
||||
_alias_event(creator, state_key=""),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
event_auth.check(
|
||||
RoomVersions.MSC2432_DEV,
|
||||
_alias_event(creator, state_key="test.com"),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# Per standard auth rules, the member must be in the room.
|
||||
with self.assertRaises(AuthError):
|
||||
event_auth.check(
|
||||
RoomVersions.MSC2432_DEV,
|
||||
_alias_event(other),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
|
||||
# helpers for making events
|
||||
|
||||
|
@ -131,6 +209,19 @@ def _power_levels_event(sender, content):
|
|||
)
|
||||
|
||||
|
||||
def _alias_event(sender, **kwargs):
|
||||
data = {
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"event_id": _get_event_id(),
|
||||
"type": "m.room.aliases",
|
||||
"sender": sender,
|
||||
"state_key": get_domain_from_id(sender),
|
||||
"content": {"aliases": []},
|
||||
}
|
||||
data.update(**kwargs)
|
||||
return make_event_from_dict(data)
|
||||
|
||||
|
||||
def _random_state_event(sender):
|
||||
return make_event_from_dict(
|
||||
{
|
||||
|
|
|
@ -75,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase):
|
|||
self.fail("Parsing '%s' should raise exception" % id_string)
|
||||
except SynapseError as exc:
|
||||
self.assertEqual(400, exc.code)
|
||||
self.assertEqual("M_UNKNOWN", exc.errcode)
|
||||
self.assertEqual("M_INVALID_PARAM", exc.errcode)
|
||||
|
||||
|
||||
class MapUsernameTestCase(unittest.TestCase):
|
||||
|
|
8
tox.ini
8
tox.ini
|
@ -168,7 +168,6 @@ commands=
|
|||
coverage html
|
||||
|
||||
[testenv:mypy]
|
||||
basepython = python3.7
|
||||
skip_install = True
|
||||
deps =
|
||||
{[base]deps}
|
||||
|
@ -179,10 +178,14 @@ env =
|
|||
extras = all
|
||||
commands = mypy \
|
||||
synapse/api \
|
||||
synapse/config/ \
|
||||
synapse/appservice \
|
||||
synapse/config \
|
||||
synapse/events/spamcheck.py \
|
||||
synapse/federation/federation_base.py \
|
||||
synapse/federation/federation_client.py \
|
||||
synapse/federation/sender \
|
||||
synapse/federation/transport \
|
||||
synapse/handlers/directory.py \
|
||||
synapse/handlers/presence.py \
|
||||
synapse/handlers/sync.py \
|
||||
synapse/handlers/ui_auth \
|
||||
|
@ -192,6 +195,7 @@ commands = mypy \
|
|||
synapse/rest \
|
||||
synapse/spam_checker_api \
|
||||
synapse/storage/engines \
|
||||
synapse/storage/database.py \
|
||||
synapse/streams
|
||||
|
||||
# To find all folders that pass mypy you run:
|
||||
|
|
Loading…
Reference in a new issue