Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

Erik Johnston 2021-01-14 15:29:29 +00:00
commit 43dc637136
125 changed files with 4659 additions and 1102 deletions

View file

@ -15,6 +15,7 @@
# limitations under the License.
import logging
from synapse.storage.engines import create_engine
logger = logging.getLogger("create_postgres_db")

View file

@ -1,5 +1,30 @@
Synapse 1.25.0rc1 (2021-01-06)
==============================
Synapse 1.25.0 (2021-01-13)
===========================
Ending Support for Python 3.5 and Postgres 9.5
----------------------------------------------
With this release, the Synapse team is announcing a formal deprecation policy for our platform dependencies, like Python and PostgreSQL:
All future releases of Synapse will follow the upstream end-of-life schedules.
Which means:
* This is the last release which guarantees support for Python 3.5.
* We will end support for PostgreSQL 9.5 early next month.
* We will end support for Python 3.6 and PostgreSQL 9.6 near the end of the year.
Crucially, this means __we will not produce .deb packages for Debian 9 (Stretch) or Ubuntu 16.04 (Xenial)__ beyond the transition period described below.
The website https://endoflife.date/ has convenient summaries of the support schedules for projects like [Python](https://endoflife.date/python) and [PostgreSQL](https://endoflife.date/postgresql).
If you are unable to upgrade your environment to a supported version of Python or Postgres, we encourage you to consider using the [Synapse Docker images](./INSTALL.md#docker-images-and-ansible-playbooks) instead.
### Transition Period
We will make a good faith attempt to avoid breaking compatibility in all releases through the end of March 2021. However, critical security vulnerabilities in dependencies or other unanticipated circumstances may arise which necessitate breaking compatibility earlier.
We intend to continue producing .deb packages for Debian 9 (Stretch) and Ubuntu 16.04 (Xenial) through the transition period.
Removal warning
---------------
@ -12,6 +37,15 @@ are deprecated and will be removed in a future release. They will be replaced by
`POST /_synapse/admin/v1/rooms/<room_id>/delete` replaces `POST /_synapse/admin/v1/purge_room` and
`POST /_synapse/admin/v1/shutdown_room/<room_id>`.
Bugfixes
--------
- Fix HTTP proxy support when using a proxy that is on a blacklisted IP. Introduced in v1.25.0rc1. Contributed by @Bubu. ([\#9084](https://github.com/matrix-org/synapse/issues/9084))
Synapse 1.25.0rc1 (2021-01-06)
==============================
Features
--------
@ -61,7 +95,7 @@ Improved Documentation
- Combine related media admin API docs. ([\#8839](https://github.com/matrix-org/synapse/issues/8839))
- Fix an error in the documentation for the SAML username mapping provider. ([\#8873](https://github.com/matrix-org/synapse/issues/8873))
- Clarify comments around template directories in `sample_config.yaml`. ([\#8891](https://github.com/matrix-org/synapse/issues/8891))
- Moved instructions for database setup, adjusted heading levels and improved syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by fossterer. ([\#8987](https://github.com/matrix-org/synapse/issues/8987))
- Move instructions for database setup, adjust heading levels and improve syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by @fossterer. ([\#8987](https://github.com/matrix-org/synapse/issues/8987))
- Update the example value of `group_creation_prefix` in the sample configuration. ([\#8992](https://github.com/matrix-org/synapse/issues/8992))
- Link the Synapse developer room to the development section in the docs. ([\#9002](https://github.com/matrix-org/synapse/issues/9002))

View file

@ -257,7 +257,7 @@ for a number of platforms.
#### Docker images and Ansible playbooks
There is an offical synapse image available at
There is an official synapse image available at
<https://hub.docker.com/r/matrixdotorg/synapse> which can be used with
the docker-compose file available at [contrib/docker](contrib/docker). Further
information on this including configuration options is available in the README

View file

@ -243,7 +243,7 @@ Then update the ``users`` table in the database::
Synapse Development
===================
Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)
Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_
Before setting up a development environment for synapse, make sure you have the
system dependencies (such as the python header files) installed - see

View file

@ -5,6 +5,16 @@ Before upgrading check if any special steps are required to upgrade from the
version you currently have installed to the current version of Synapse. The extra
instructions that may be required are listed later in this document.
* Check that your versions of Python and PostgreSQL are still supported.
Synapse follows upstream lifecycles for `Python`_ and `PostgreSQL`_, and
removes support for versions which are no longer maintained.
The website https://endoflife.date also offers convenient summaries.
.. _Python: https://devguide.python.org/devcycle/#end-of-life-branches
.. _PostgreSQL: https://www.postgresql.org/support/versioning/
* If Synapse was installed using `prebuilt packages
<INSTALL.md#prebuilt-packages>`_, you will need to follow the normal process
for upgrading those packages.
@ -78,6 +88,18 @@ for example:
Upgrading to v1.25.0
====================
Last release supporting Python 3.5
----------------------------------
This is the last release of Synapse which guarantees support with Python 3.5,
which passed its upstream End of Life date several months ago.
We will attempt to maintain support through March 2021, but without guarantees.
In the future, Synapse will follow upstream schedules for ending support of
older versions of Python and PostgreSQL. Please upgrade to at least Python 3.6
and PostgreSQL 9.6 as soon as possible.
Blacklisting IP ranges
----------------------

1
changelog.d/8868.misc Normal file
View file

@ -0,0 +1 @@
Improve efficiency of large state resolutions.

1
changelog.d/8932.feature Normal file
View file

@ -0,0 +1 @@
Remove a user's avatar URL and display name when deactivated with the Admin API.

1
changelog.d/8948.feature Normal file
View file

@ -0,0 +1 @@
Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users.

1
changelog.d/9016.misc Normal file
View file

@ -0,0 +1 @@
Ensure rejected events get added to some metadata tables.

1
changelog.d/9025.misc Normal file
View file

@ -0,0 +1 @@
Removed an unused column from `access_tokens` table.

1
changelog.d/9029.misc Normal file
View file

@ -0,0 +1 @@
Improve efficiency of large state resolutions.

1
changelog.d/9036.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

1
changelog.d/9038.misc Normal file
View file

@ -0,0 +1 @@
Configure the linters to run on a consistent set of files.

1
changelog.d/9039.removal Normal file
View file

@ -0,0 +1 @@
Remove broken and unmaintained `demo/webserver.py` script.

1
changelog.d/9040.doc Normal file
View file

@ -0,0 +1 @@
Corrected a typo in `INSTALL.md`.

1
changelog.d/9051.bugfix Normal file
View file

@ -0,0 +1 @@
Fix error handling during insertion of client IPs into the database.

1
changelog.d/9053.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where we didn't correctly record CPU time spent in 'on_new_event' block.

1
changelog.d/9054.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a minor bug which could cause confusing error messages from invalid configurations.

1
changelog.d/9055.misc Normal file
View file

@ -0,0 +1 @@
Drop unused database tables.

1
changelog.d/9057.doc Normal file
View file

@ -0,0 +1 @@
Add missing user_mapping_provider configuration to the Keycloak OIDC example. Contributed by @chris-ruecker.

1
changelog.d/9058.misc Normal file
View file

@ -0,0 +1 @@
Remove unused `SynapseService` class.

1
changelog.d/9059.bugfix Normal file
View file

@ -0,0 +1 @@
Fix incorrect exit code when there is an error at startup.

1
changelog.d/9063.misc Normal file
View file

@ -0,0 +1 @@
Removes unnecessary declarations in the tests for the admin API.

1
changelog.d/9067.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

1
changelog.d/9068.feature Normal file
View file

@ -0,0 +1 @@
Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes.

1
changelog.d/9069.misc Normal file
View file

@ -0,0 +1 @@
Remove `SynapseRequest.get_user_agent`.

1
changelog.d/9070.bugfix Normal file
View file

@ -0,0 +1 @@
Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers.

1
changelog.d/9071.bugfix Normal file
View file

@ -0,0 +1 @@
Fix "Failed to send request" errors when a client provides an invalid room alias.

1
changelog.d/9080.misc Normal file
View file

@ -0,0 +1 @@
Remove redundant `Homeserver.get_ip_from_request` method.

1
changelog.d/9081.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

1
changelog.d/9082.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

1
changelog.d/9092.feature Normal file
View file

@ -0,0 +1 @@
Add experimental support for handling `/devices` API on worker processes.

1
changelog.d/9098.misc Normal file
View file

@ -0,0 +1 @@
Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung.

1
changelog.d/9105.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

1
changelog.d/9106.misc Normal file
View file

@ -0,0 +1 @@
Reduce the scope of caught exceptions in `BlacklistingAgentWrapper`.

1
changelog.d/9107.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

11
debian/changelog vendored
View file

@ -1,3 +1,14 @@
matrix-synapse-py3 (1.25.0) stable; urgency=medium
[ Dan Callahan ]
* Update dependencies to account for the removal of the transitional
dh-systemd package from Debian Bullseye.
[ Synapse Packaging team ]
* New synapse release 1.25.0.
-- Synapse Packaging team <packages@matrix.org> Wed, 13 Jan 2021 10:14:55 +0000
matrix-synapse-py3 (1.24.0) stable; urgency=medium
* New synapse release 1.24.0.

6
debian/control vendored
View file

@ -3,9 +3,11 @@ Section: contrib/python
Priority: extra
Maintainer: Synapse Packaging team <packages@matrix.org>
# keep this list in sync with the build dependencies in docker/Dockerfile-dhvirtualenv.
# TODO: Remove the dependency on dh-systemd after dropping support for Ubuntu xenial
# On all other supported releases, it's merely a transitional package which
# does nothing but depends on debhelper (> 9.20160709)
Build-Depends:
debhelper (>= 9),
dh-systemd,
debhelper (>= 9.20160709) | dh-systemd,
dh-virtualenv (>= 1.1),
libsystemd-dev,
libpq-dev,

View file

@ -1,59 +0,0 @@
import argparse
import BaseHTTPServer
import os
import SimpleHTTPServer
import cgi, logging
from daemonize import Daemonize
class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler):
UPLOAD_PATH = "upload"
"""
Accept all post request as file upload
"""
def do_POST(self):
path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path))
length = self.headers["content-length"]
data = self.rfile.read(int(length))
with open(path, "wb") as fh:
fh.write(data)
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.end_headers()
# Return the absolute path of the uploaded file
self.wfile.write('{"url":"/%s"}' % path)
def setup():
parser = argparse.ArgumentParser()
parser.add_argument("directory")
parser.add_argument("-p", "--port", dest="port", type=int, default=8080)
parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid")
args = parser.parse_args()
# Get absolute path to directory to serve, as daemonize changes to '/'
os.chdir(args.directory)
dr = os.getcwd()
httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST)
def run():
os.chdir(dr)
httpd.serve_forever()
daemon = Daemonize(
app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False
)
daemon.start()
if __name__ == "__main__":
setup()

View file

@ -50,17 +50,22 @@ FROM ${distro}
ARG distro=""
ENV distro ${distro}
# Python < 3.7 assumes LANG="C" means ASCII-only and throws on printing unicode
# http://bugs.python.org/issue19846
ENV LANG C.UTF-8
# Install the build dependencies
#
# NB: keep this list in sync with the list of build-deps in debian/control
# TODO: it would be nice to do that automatically.
# TODO: Remove the dh-systemd stanza after dropping support for Ubuntu xenial
# it's a transitional package on all other, more recent releases
RUN apt-get update -qq -o Acquire::Languages=none \
&& env DEBIAN_FRONTEND=noninteractive apt-get install \
-yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \
build-essential \
debhelper \
devscripts \
dh-systemd \
libsystemd-dev \
lsb-release \
pkg-config \
@ -70,7 +75,10 @@ RUN apt-get update -qq -o Acquire::Languages=none \
python3-venv \
sqlite3 \
libpq-dev \
xmlsec1
xmlsec1 \
&& ( env DEBIAN_FRONTEND=noninteractive apt-get install \
-yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \
dh-systemd || true )
COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /

View file

@ -98,6 +98,8 @@ Body parameters:
- ``deactivated``, optional. If unspecified, deactivation state will be left
unchanged on existing accounts and set to ``false`` for new accounts.
A user cannot be erased by deactivating with this API. For details on deactivating users see
`Deactivate Account <#deactivate-account>`_.
If the user already exists then optional parameters default to the current value.
@ -248,6 +250,25 @@ server admin: see `README.rst <README.rst>`_.
The erase parameter is optional and defaults to ``false``.
An empty body may be passed for backwards compatibility.
The following actions are performed when deactivating a user:
- Try to unbind 3PIDs from the identity server
- Remove all 3PIDs from the homeserver
- Delete all devices and E2EE keys
- Delete all access tokens
- Delete the password hash
- Remove the user from all rooms they are a member of
- Remove the user from the user directory
- Reject all pending invites
- Remove all account validity information related to the user
The following additional actions are performed during deactivation if ``erase``
is set to ``true``:
- Remove the user's display name
- Remove the user's avatar URL
- Mark the user as erased
Reset password
==============
@ -337,6 +358,10 @@ A response body like the following is returned:
"total": 2
}
The server returns the list of rooms of which both the user and the server
are members. If the user is local, all the rooms of which the user is a
member are returned.
**Parameters**
The following parameters should be set in the URL:

32
docs/auth_chain_diff.dot Normal file
View file

@ -0,0 +1,32 @@
digraph auth {
nodesep=0.5;
rankdir="RL";
C [label="Create (1,1)"];
BJ [label="Bob's Join (2,1)", color=red];
BJ2 [label="Bob's Join (2,2)", color=red];
BJ2 -> BJ [color=red, dir=none];
subgraph cluster_foo {
A1 [label="Alice's invite (4,1)", color=blue];
A2 [label="Alice's Join (4,2)", color=blue];
A3 [label="Alice's Join (4,3)", color=blue];
A3 -> A2 -> A1 [color=blue, dir=none];
color=none;
}
PL1 [label="Power Level (3,1)", color=darkgreen];
PL2 [label="Power Level (3,2)", color=darkgreen];
PL2 -> PL1 [color=darkgreen, dir=none];
{rank = same; C; BJ; PL1; A1;}
A1 -> C [color=grey];
A1 -> BJ [color=grey];
PL1 -> C [color=grey];
BJ2 -> PL1 [penwidth=2];
A3 -> PL2 [penwidth=2];
A1 -> PL1 -> BJ -> C [penwidth=2];
}

Binary file not shown. Size: 41 KiB

View file

@ -0,0 +1,108 @@
# Auth Chain Difference Algorithm
The auth chain difference algorithm is used by V2 state resolution, where a
naive implementation can be a significant source of CPU and DB usage.
### Definitions
A *state set* is a set of state events; e.g. the input of a state resolution
algorithm is a collection of state sets.
The *auth chain* of a set of events are all the events' auth events and *their*
auth events, recursively (i.e. the events reachable by walking the graph induced
by an event's auth events links).
The *auth chain difference* of a collection of state sets is the union minus the
intersection of the sets of auth chains corresponding to the state sets, i.e. an
event is in the auth chain difference if it is reachable by walking the auth
event graph from at least one of the state sets but not from *all* of the state
sets.
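
As a rough illustration of these definitions (not Synapse's implementation), the
sketch below computes the auth chain difference naively by materialising every
auth chain. The `auth_events` mapping, the event IDs and both function names are
hypothetical, and each starting event is counted as reachable from itself, which
is an assumption of this sketch.

```python
from typing import Dict, Iterable, List, Set


def auth_chain(auth_events: Dict[str, List[str]], events: Iterable[str]) -> Set[str]:
    """Return every event reachable from `events` by following auth event links.

    Assumption of this sketch: the starting events are included in the result.
    """
    seen: Set[str] = set()
    stack = list(events)
    while stack:
        event_id = stack.pop()
        if event_id in seen:
            continue
        seen.add(event_id)
        stack.extend(auth_events.get(event_id, []))
    return seen


def auth_chain_difference(
    auth_events: Dict[str, List[str]], state_sets: List[Set[str]]
) -> Set[str]:
    """Union minus intersection of the auth chains of each state set."""
    if not state_sets:
        return set()
    chains = [auth_chain(auth_events, state_set) for state_set in state_sets]
    return set.union(*chains) - set.intersection(*chains)


# Hypothetical room: every event maps to the list of its auth events.
example_auth_events = {
    "create": [],
    "join": ["create"],
    "power": ["create", "join"],
    "topic_a": ["create", "join", "power"],
    "topic_b": ["create", "join"],
}
example_difference = auth_chain_difference(
    example_auth_events, [{"topic_a"}, {"topic_b"}]
)
```

This walks the full auth chain of every state set, which is the cost the
algorithms described in the rest of this document try to avoid.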
## Breadth First Walk Algorithm
A way of calculating the auth chain difference without calculating the full auth
chains for each state set is to do a parallel breadth first walk (ordered by
depth) of each state set's auth chain. By tracking which events are reachable
from each state set we can finish early if every pending event is reachable from
every state set.
This can work well for state sets that have a small auth chain difference, but
can be very inefficient for larger differences. However, this algorithm is still
used if we don't have a chain cover index for the room (e.g. because we're in
the process of indexing it).
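
A minimal sketch of such a walk is shown below. It is an illustration rather
than Synapse's code: the `auth_events` and `depth` mappings and the function name
are hypothetical, and it assumes that every event has a strictly greater depth
than each of its auth events, so that an event is only processed after all
events that point to it.

```python
import heapq
from typing import Dict, Iterable, List, Set, Tuple


def auth_chain_difference_bfs(
    auth_events: Dict[str, List[str]],
    depth: Dict[str, int],
    state_sets: List[Set[str]],
) -> Set[str]:
    everyone = set(range(len(state_sets)))
    reached_from: Dict[str, Set[int]] = {}  # event -> state sets that reach it
    queue: List[Tuple[int, str]] = []       # max-heap on depth (negated)
    incomplete = 0  # queued events not yet reachable from every state set

    def mark(event_id: str, sources: Iterable[int]) -> None:
        nonlocal incomplete
        is_new = event_id not in reached_from
        seen = reached_from.setdefault(event_id, set())
        was_complete = not is_new and seen == everyone
        seen.update(sources)
        now_complete = seen == everyone
        if is_new:
            heapq.heappush(queue, (-depth[event_id], event_id))
            if not now_complete:
                incomplete += 1
        elif now_complete and not was_complete:
            incomplete -= 1

    for index, state_set in enumerate(state_sets):
        for event_id in state_set:
            mark(event_id, [index])

    difference: Set[str] = set()
    # Finish early once every pending event is reachable from every state set.
    while queue and incomplete:
        _, event_id = heapq.heappop(queue)
        sources = reached_from[event_id]
        if sources != everyone:
            difference.add(event_id)
            incomplete -= 1
        for auth_id in auth_events.get(event_id, []):
            mark(auth_id, sources)
    return difference


# Hypothetical room used to exercise the sketch.
bfs_auth_events = {
    "create": [],
    "join": ["create"],
    "power": ["create", "join"],
    "topic_a": ["create", "join", "power"],
    "topic_b": ["create", "join"],
}
bfs_depth = {"create": 1, "join": 2, "power": 3, "topic_a": 4, "topic_b": 4}
bfs_difference = auth_chain_difference_bfs(
    bfs_auth_events, bfs_depth, [{"topic_a"}, {"topic_b"}]
)
```

In this example the walk stops as soon as only `create` and `join` remain
pending, since both are already known to be reachable from both state sets.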
## Chain Cover Index
Synapse computes auth chain differences by pre-computing a "chain cover" index
for the auth chain in a room, allowing efficient reachability queries like "is
event A in the auth chain of event B". This is done by assigning every event a
*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
is in the auth chain of `B`) if and only if either:
1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
sequence number; or
2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
`L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
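
The two conditions can be written down as a small predicate. This is a sketch
under assumptions: the `ChainPosition` and `Link` types and the function name are
hypothetical, and `links` is taken to be an in-memory collection of all links
rather than a database table.

```python
from typing import Iterable, NamedTuple


class ChainPosition(NamedTuple):
    chain_id: int
    seq_no: int


class Link(NamedTuple):
    start_chain_id: int
    start_seq_no: int
    end_chain_id: int
    end_seq_no: int


def is_in_auth_chain(a: ChainPosition, b: ChainPosition, links: Iterable[Link]) -> bool:
    """Return True if `a` is in the auth chain of `b`."""
    # Condition 1: same chain, and `a` comes strictly before `b`.
    if a.chain_id == b.chain_id and a.seq_no < b.seq_no:
        return True
    # Condition 2: a link from `b`'s chain to `a`'s chain that starts at or
    # before `b` and ends at or after `a`.
    return any(
        link.start_chain_id == b.chain_id
        and link.end_chain_id == a.chain_id
        and link.start_seq_no <= b.seq_no
        and a.seq_no <= link.end_seq_no
        for link in links
    )


# The example link from the text: (5,3) -> (2,4).
example_links = [Link(5, 3, 2, 4)]
```

With that single link, `(2,4)` is in the auth chain of `(5,3)` and of any later
event in chain 5, while `(2,5)` is not.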
There are actually two potential implementations, one where we store links from
each chain to every other reachable chain (the transitive closure of the links
graph), and one where we remove redundant links (the transitive reduction of the
links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
would not be stored. Synapse uses the former implementation so that it doesn't
need to recurse to test reachability between chains.
### Example
An example auth graph would look like the following, where chains have been
formed based on type/state_key and are denoted by colour and are labelled with
`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
are those that would be removed in the second implementation described above).
![Example](auth_chain_diff.dot.png)
Note that we don't include all links between events and their auth events, as
most of those links would be redundant. For example, all events point to the
create event, but each chain only needs the one link from its base to the
create event.
## Using the Index
This index can be used to calculate the auth chain difference of the state sets
by looking at the chain ID and sequence numbers reachable from each state set:
1. For every state set look up the chain ID/sequence numbers of each state event
2. Use the index to find all chains and the maximum sequence number reachable
from each state set.
3. The auth chain difference is then all events in each chain that have sequence
numbers between the maximum sequence number reachable from *any* state set and
the minimum reachable by *all* state sets (if any).
Note that step 2 is effectively calculating the auth chain for each state set
(in terms of chain IDs and sequence numbers), and step 3 is calculating the
difference between the union and intersection of the auth chains.
### Worked Example
For example, given the above graph, we can calculate the difference between
state sets consisting of:
1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
Using the index we see that the following auth chains are reachable from each
state set:
1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
And so, for each chain, the ranges that are in the auth chain difference are:
1. Chain 1: None, (since everything can reach the create event).
2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
sets and the maximum reachable is `2` (corresponding to Bob's second join).
3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
level).
4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
So the final result is: Bob's second join `(2,2)`, the second power level
`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.
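
The worked example can be checked with a short sketch of the three steps. It is
not Synapse's implementation: the function names and the in-memory `links` list
are hypothetical, the links are transcribed from the example graph, and sequence
numbers are assumed to start at 1 and be contiguous within a chain. The result is
expressed as `(chain ID, sequence number)` pairs rather than event IDs.

```python
from typing import Dict, Iterable, List, Set, Tuple

Position = Tuple[int, int]        # (chain ID, sequence number)
Link = Tuple[Position, Position]  # (start, end)


def max_reachable(state_set: Iterable[Position], links: List[Link]) -> Dict[int, int]:
    """Step 2: maximum sequence number reachable in each chain from a state set."""
    reach: Dict[int, int] = {}
    for chain_id, seq_no in state_set:
        reach[chain_id] = max(reach.get(chain_id, 0), seq_no)
        for (start_chain, start_seq), (end_chain, end_seq) in links:
            if start_chain == chain_id and start_seq <= seq_no:
                reach[end_chain] = max(reach.get(end_chain, 0), end_seq)
    return reach


def chain_cover_difference(
    state_sets: List[Set[Position]], links: List[Link]
) -> Set[Position]:
    """Step 3: per chain, the range (reachable by all, reachable by any]."""
    reaches = [max_reachable(state_set, links) for state_set in state_sets]
    difference: Set[Position] = set()
    for chain_id in set().union(*reaches):
        # A chain that a state set cannot reach at all counts as 0.
        seqs = [reach.get(chain_id, 0) for reach in reaches]
        low, high = min(seqs), max(seqs)
        difference.update((chain_id, seq) for seq in range(low + 1, high + 1))
    return difference


# Links from the example graph (chain 1: create, 2: Bob, 3: power levels, 4: Alice).
graph_links: List[Link] = [
    ((2, 1), (1, 1)),
    ((3, 1), (1, 1)),
    ((3, 1), (2, 1)),
    ((2, 2), (3, 1)),
    ((4, 1), (1, 1)),
    ((4, 1), (2, 1)),
    ((4, 1), (3, 1)),
    ((4, 3), (3, 2)),
]
s1 = {(4, 1), (2, 2)}  # Alice's invite and Bob's second join
s2 = {(4, 3), (2, 1)}  # Alice's second join and Bob's first join
worked_difference = chain_cover_difference([s1, s2], graph_links)
```

Because the links form a transitive closure, a single pass over `links` per
state event is enough and no recursion between chains is needed.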

View file

@ -158,6 +158,10 @@ oidc_config:
client_id: "synapse"
client_secret: "copy secret generated from above"
scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
```
### [Auth0][auth0]

View file

@ -214,6 +214,7 @@ expressions:
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
^/_matrix/client/(api/v1|r0|unstable)/devices$
^/_matrix/client/(api/v1|r0|unstable)/keys/query$
^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
^/_matrix/client/versions$

View file

@ -103,6 +103,7 @@ files =
tests/replication,
tests/test_utils,
tests/handlers/test_password_providers.py,
tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_stream_change_cache.py

View file

@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier", "contains_url"],
"rooms": ["is_public"],
"rooms": ["is_public", "has_auth_chain_index"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
"presence_stream": ["currently_active"],

View file

@ -15,16 +15,7 @@
# Stub for frozendict.
from typing import (
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)
from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload
_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.

View file

@ -7,17 +7,17 @@ from typing import (
Callable,
Dict,
Hashable,
Iterator,
Iterable,
ItemsView,
Iterable,
Iterator,
KeysView,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Tuple,
Union,
ValuesView,
overload,

View file

@ -16,7 +16,7 @@
"""Contains *incomplete* type hints for txredisapi.
"""
from typing import List, Optional, Union, Type
from typing import List, Optional, Type, Union
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...

View file

@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.25.0rc1"
__version__ = "1.25.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
@ -186,8 +187,8 @@ class Auth:
AuthError if access is denied for the user in the access token
"""
try:
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.get_user_agent("")
ip_addr = request.getClientIP()
user_agent = get_request_user_agent(request)
access_token = self.get_access_token_from_request(request)
@ -275,7 +276,7 @@ class Auth:
return None, None
if app_service.ip_range_whitelist:
ip_address = IPAddress(self.hs.get_ip_from_request(request))
ip_address = IPAddress(request.getClientIP())
if ip_address not in app_service.ip_range_whitelist:
return None, None

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
# Copyright 2019-2021 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,7 +20,7 @@ import signal
import socket
import sys
import traceback
from typing import Iterable
from typing import Awaitable, Callable, Iterable
from typing_extensions import NoReturn
@ -143,6 +144,45 @@ def quit_with_error(error_string: str) -> NoReturn:
sys.exit(1)
def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
"""Register a callback with the reactor, to be called once it is running
This can be used to initialise parts of the system which require an asynchronous
setup.
Any exception raised by the callback will be printed and logged, and the process
will exit.
"""
async def wrapper():
try:
await cb(*args, **kwargs)
except Exception:
# previously, we used Failure().printTraceback() here, in the hope that
# would give better tracebacks than traceback.print_exc(). However, that
# doesn't handle chained exceptions (with a __cause__ or __context__) well,
# and I *think* the need for Failure() is reduced now that we mostly use
# async/await.
# Write the exception to both the logs *and* the unredirected stderr,
# because people tend to get confused if it only goes to one or the other.
#
# One problem with this is that if people are using a logging config that
# logs to the console (as is common eg under docker), they will get two
# copies of the exception. We could maybe try to detect that, but it's
# probably a cost we can bear.
logger.fatal("Error during startup", exc_info=True)
print("Error during startup:", file=sys.__stderr__)
traceback.print_exc(file=sys.__stderr__)
# it's no use calling sys.exit here, since that just raises a SystemExit
# exception which is then caught by the reactor, and everything carries
# on as normal.
os._exit(1)
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
def listen_metrics(bind_addresses, port):
"""
Start Prometheus metrics server.
@ -227,7 +267,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
"""
Start a Synapse server or worker.
@ -241,75 +281,67 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
hs: homeserver instance
listeners: Listener configuration ('listeners' in homeserver.yaml)
"""
try:
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
reactor = hs.get_reactor()
@wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")
for i, args, kwargs in _sighup_callbacks:
i(*args, **kwargs)
sdnotify(b"READY=1")
# We defer running the sighup handlers until next reactor tick. This
# is so that we're in a sane state, e.g. flushing the logs may fail
# if the sighup happens in the middle of writing a log entry.
def run_sighup(*args, **kwargs):
# `callFromThread` should be "signal safe" as well as thread
# safe.
reactor.callFromThread(handle_sighup, *args, **kwargs)
signal.signal(signal.SIGHUP, run_sighup)
register_sighup(refresh_certificate, hs)
# Load the certificate from disk.
refresh_certificate(hs)
# Start the tracer
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs
)
# It is now safe to start your Synapse.
hs.start_listening(listeners)
hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
# Log when we start the shut down process.
hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", logger.info, "Shutting down..."
)
setup_sentry(hs)
setup_sdnotify(hs)
# If background tasks are running on the main process, start collecting the
# phone home stats.
if hs.config.run_background_tasks:
start_phone_stats_home(hs)
# We now freeze all allocated objects in the hopes that (almost)
# everything currently allocated are things that will be used for the
# rest of time. Doing so means less work each GC (hopefully).
#
# This only works on Python 3.7
if sys.version_info >= (3, 7):
gc.collect()
gc.freeze()
except Exception:
traceback.print_exc(file=sys.stderr)
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
reactor = hs.get_reactor()
if reactor.running:
reactor.stop()
sys.exit(1)
@wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")
for i, args, kwargs in _sighup_callbacks:
i(*args, **kwargs)
sdnotify(b"READY=1")
# We defer running the sighup handlers until next reactor tick. This
# is so that we're in a sane state, e.g. flushing the logs may fail
# if the sighup happens in the middle of writing a log entry.
def run_sighup(*args, **kwargs):
# `callFromThread` should be "signal safe" as well as thread
# safe.
reactor.callFromThread(handle_sighup, *args, **kwargs)
signal.signal(signal.SIGHUP, run_sighup)
register_sighup(refresh_certificate, hs)
# Load the certificate from disk.
refresh_certificate(hs)
# Start the tracer
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs
)
# It is now safe to start your Synapse.
hs.start_listening(listeners)
hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
# Log when we start the shut down process.
hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", logger.info, "Shutting down..."
)
setup_sentry(hs)
setup_sdnotify(hs)
# If background tasks are running on the main process, start collecting the
# phone home stats.
if hs.config.run_background_tasks:
start_phone_stats_home(hs)
# We now freeze all allocated objects in the hopes that (almost)
# everything currently allocated are things that will be used for the
# rest of time. Doing so means less work each GC (hopefully).
#
# This only works on Python 3.7
if sys.version_info >= (3, 7):
gc.collect()
gc.freeze()
def setup_sentry(hs):

View file

@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
from twisted.internet import address, reactor
from twisted.internet import address
import synapse
import synapse.events
@ -34,6 +34,7 @@ from synapse.api.urls import (
SERVER_KEY_V2_PREFIX,
)
from synapse.app import _base
from synapse.app._base import register_start
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
@ -99,14 +100,19 @@ from synapse.rest.client.v1.profile import (
)
from synapse.rest.client.v1.push_rule import PushRuleRestServlet
from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory
from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
from synapse.rest.client.v2_alpha.account_data import (
AccountDataServlet,
RoomAccountDataServlet,
)
from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.devices import DevicesRestServlet
from synapse.rest.client.v2_alpha.keys import (
KeyChangesServlet,
KeyQueryServlet,
OneTimeKeyServlet,
)
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
@ -115,6 +121,7 @@ from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer, cache_in_self
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
from synapse.storage.databases.main.metrics import ServerMetricsStore
from synapse.storage.databases.main.monthly_active_users import (
@ -446,6 +453,7 @@ class GenericWorkerSlavedStore(
UserDirectoryStore,
StatsStore,
UIAuthWorkerStore,
EndToEndRoomKeyStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
@ -502,7 +510,9 @@ class GenericWorkerServer(HomeServer):
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)
KeyQueryServlet(self).register(resource)
OneTimeKeyServlet(self).register(resource)
KeyChangesServlet(self).register(resource)
VoipRestServlet(self).register(resource)
PushRuleRestServlet(self).register(resource)
@ -520,6 +530,7 @@ class GenericWorkerServer(HomeServer):
room.register_servlets(self, resource, True)
room.register_deprecated_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
room_keys.register_servlets(self, resource)
SendToDeviceRestServlet(self).register(resource)
@ -960,9 +971,7 @@ def start(config_options):
# streams. Will no-op if no streams can be written to by this worker.
hs.get_replication_streamer()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, hs, config.worker_listeners
)
register_start(_base.start, hs, config.worker_listeners)
_base.start_worker_reactor("synapse-generic-worker", config)


@ -15,15 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import logging
import os
import sys
from typing import Iterable, Iterator
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from twisted.internet import reactor
from twisted.web.resource import EncodingResourceWrapper, IResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@ -40,7 +37,7 @@ from synapse.api.urls import (
WEB_CLIENT_PREFIX,
)
from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
from synapse.config._base import ConfigError
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
@ -73,7 +70,6 @@ from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.homeserver")
@ -417,40 +413,28 @@ def setup(config_options):
_base.refresh_certificate(hs)
async def start():
try:
# Run the ACME provisioning code, if it's enabled.
if hs.config.acme_enabled:
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
await acme.start_listening()
await do_acme()
# Run the ACME provisioning code, if it's enabled.
if hs.config.acme_enabled:
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
await acme.start_listening()
await do_acme()
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
# Load the OIDC provider metadata, if OIDC is enabled.
if hs.config.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
await oidc.load_jwks()
# Load the OIDC provider metadata, if OIDC is enabled.
if hs.config.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
_base.start(hs, config.listeners)
await _base.start(hs, config.listeners)
hs.get_datastore().db_pool.updates.start_doing_background_updates()
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
hs.get_datastore().db_pool.updates.start_doing_background_updates()
# this gives better tracebacks than traceback.print_exc()
Failure().printTraceback(file=sys.stderr)
if reactor.running:
reactor.stop()
sys.exit(1)
reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
register_start(start)
return hs
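This hunk drops the try/except from `start()` and passes the coroutine to `register_start`, whose body is not shown in this diff. As a rough illustration only, the removed error handling could be centralised in one wrapper along these lines. The sketch uses `asyncio` rather than Twisted, and `run_startup`, `good_start` and `bad_start` are hypothetical names, not Synapse APIs.

```python
import asyncio
import sys
import traceback
from typing import Any, Awaitable, Callable, List


def run_startup(cb: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> int:
    """Hypothetical helper: run a startup coroutine and report failures once.

    Returns a process exit code instead of calling sys.exit(), so the
    sketch stays easy to test.
    """

    async def wrapper() -> int:
        try:
            await cb(*args, **kwargs)
        except Exception:
            # Print the exception and signal failure, as the removed block did.
            print("Error during startup:", file=sys.stderr)
            traceback.print_exc(file=sys.stderr)
            return 1
        return 0

    return asyncio.run(wrapper())


async def good_start(listeners: List[str]) -> None:
    listeners.append("started")


async def bad_start() -> None:
    raise RuntimeError("port already in use")
```

With a wrapper of this shape, each entry point only supplies its own startup coroutine and no longer repeats the error reporting.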
@ -487,25 +471,6 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
e = e.__cause__
class SynapseService(service.Service):
"""
A twisted Service class that will start synapse. Used to run synapse
via twistd and a .tac.
"""
def __init__(self, config):
self.config = config
def startService(self):
hs = setup(self.config)
change_resource_limit(hs.config.soft_file_limit)
if hs.config.gc_thresholds:
gc.set_threshold(*hs.config.gc_thresholds)
def stopService(self):
return self._port.stopListening()
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:


@ -56,7 +56,7 @@ def json_error_to_config_error(
"""
# copy `config_path` before modifying it.
path = list(config_path)
for p in list(e.path):
for p in list(e.absolute_path):
if isinstance(p, int):
path.append("<item %i>" % p)
else:
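The loop in this hunk turns the path elements of a validation error into a readable config path. In jsonschema, `path` on a nested error is relative to its parent error, while `absolute_path` starts at the root of the validated instance, which is presumably why the change was made. The formatting step can be sketched without jsonschema as follows; `format_config_path` is a hypothetical name, and the body of the `else` branch is an assumption because the diff cuts off before it.

```python
from typing import Iterable, List, Union


def format_config_path(
    prefix: Iterable[str], error_path: Iterable[Union[str, int]]
) -> List[str]:
    """Append the elements of an error path to a copy of the config path."""
    # Copy the prefix so that the caller's sequence is not modified.
    path = list(prefix)
    for p in error_path:
        if isinstance(p, int):
            # List indices are rendered as "<item N>".
            path.append("<item %i>" % p)
        else:
            # Assumption: other elements are appended as strings.
            path.append(str(p))
    return path


print(format_config_path(("oidc_config",), ["scopes", 0]))
```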


@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,7 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Type
import attr
from synapse.config._util import validate_config
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module
from ._base import Config, ConfigError
@ -25,65 +32,32 @@ class OIDCConfig(Config):
section = "oidc"
def read_config(self, config, **kwargs):
self.oidc_enabled = False
validate_config(MAIN_CONFIG_SCHEMA, config, ())
self.oidc_provider = None # type: Optional[OidcProviderConfig]
oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config")
self.oidc_provider = _parse_oidc_config_dict(oidc_config)
if not oidc_config or not oidc_config.get("enabled", False):
if not self.oidc_provider:
return
try:
check_requirements("oidc")
except DependencyException as e:
raise ConfigError(e.message)
raise ConfigError(e.message) from e
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("oidc_config requires a public_baseurl to be set")
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
self.oidc_enabled = True
self.oidc_discover = oidc_config.get("discover", True)
self.oidc_issuer = oidc_config["issuer"]
self.oidc_client_id = oidc_config["client_id"]
self.oidc_client_secret = oidc_config["client_secret"]
self.oidc_client_auth_method = oidc_config.get(
"client_auth_method", "client_secret_basic"
)
self.oidc_scopes = oidc_config.get("scopes", ["openid"])
self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
self.oidc_token_endpoint = oidc_config.get("token_endpoint")
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {})
(
self.oidc_user_mapping_provider_class,
self.oidc_user_mapping_provider_config,
) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
"get_remote_user_id",
"map_user_attributes",
]
missing_methods = [
method
for method in required_methods
if not hasattr(self.oidc_user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by oidc_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
@property
def oidc_enabled(self) -> bool:
# OIDC is enabled if we have a provider
return bool(self.oidc_provider)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
@ -224,3 +198,154 @@ class OIDCConfig(Config):
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)
# jsonschema definition of the configuration settings for an oidc identity provider
OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object",
"required": ["issuer", "client_id", "client_secret"],
"properties": {
"discover": {"type": "boolean"},
"issuer": {"type": "string"},
"client_id": {"type": "string"},
"client_secret": {"type": "string"},
"client_auth_method": {
"type": "string",
# the following list is the same as the keys of
# authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
# to avoid importing authlib here.
"enum": ["client_secret_basic", "client_secret_post", "none"],
},
"scopes": {"type": "array", "items": {"type": "string"}},
"authorization_endpoint": {"type": "string"},
"token_endpoint": {"type": "string"},
"userinfo_endpoint": {"type": "string"},
"jwks_uri": {"type": "string"},
"skip_verification": {"type": "boolean"},
"user_profile_method": {
"type": "string",
"enum": ["auto", "userinfo_endpoint"],
},
"allow_existing_users": {"type": "boolean"},
"user_mapping_provider": {"type": ["object", "null"]},
},
}
# the `oidc_config` setting can either be None (as it is in the default
# config), or an object. If an object, it is ignored unless it has an "enabled: True"
# property.
#
# It's *possible* to represent this with jsonschema, but the resultant errors aren't
# particularly clear, so we just check for either an object or a null here, and do
# additional checks in the code.
OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
MAIN_CONFIG_SCHEMA = {
"type": "object",
"properties": {"oidc_config": OIDC_CONFIG_SCHEMA},
}
def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
"""Take the configuration dict and parse it into an OidcProviderConfig
Raises:
ConfigError if the configuration is malformed.
"""
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {})
(user_mapping_provider_class, user_mapping_provider_config,) = load_module(
ump_config, ("oidc_config", "user_mapping_provider")
)
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
"get_remote_user_id",
"map_user_attributes",
]
missing_methods = [
method
for method in required_methods
if not hasattr(user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by oidc_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
return OidcProviderConfig(
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
client_secret=oidc_config["client_secret"],
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
token_endpoint=oidc_config.get("token_endpoint"),
userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
jwks_uri=oidc_config.get("jwks_uri"),
skip_verification=oidc_config.get("skip_verification", False),
user_profile_method=oidc_config.get("user_profile_method", "auto"),
allow_existing_users=oidc_config.get("allow_existing_users", False),
user_mapping_provider_class=user_mapping_provider_class,
user_mapping_provider_config=user_mapping_provider_config,
)
@attr.s
class OidcProviderConfig:
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints.
issuer = attr.ib(type=str)
# oauth2 client id to use
client_id = attr.ib(type=str)
# oauth2 client secret to use
client_secret = attr.ib(type=str)
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
# 'none'.
client_auth_method = attr.ib(type=str)
# list of scopes to request
scopes = attr.ib(type=Collection[str])
# the oauth2 authorization endpoint. Required if discovery is disabled.
authorization_endpoint = attr.ib(type=Optional[str])
# the oauth2 token endpoint. Required if discovery is disabled.
token_endpoint = attr.ib(type=Optional[str])
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
userinfo_endpoint = attr.ib(type=Optional[str])
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
jwks_uri = attr.ib(type=Optional[str])
# Whether to skip metadata verification
skip_verification = attr.ib(type=bool)
# Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint".
user_profile_method = attr.ib(type=str)
# whether to allow a user logging in via OIDC to match a pre-existing account
# instead of failing
allow_existing_users = attr.ib(type=bool)
# the class of the user mapping provider
user_mapping_provider_class = attr.ib(type=Type)
# the config of the user mapping provider
user_mapping_provider_config = attr.ib()
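The pattern introduced in this file (check the raw dict, then parse it into one typed object, and derive `oidc_enabled` from whether that object exists) can be sketched with the standard library alone. This is not the Synapse code: it uses `dataclasses` in place of `attr`, a plain key check in place of the jsonschema validation, and only a subset of the fields. `ProviderConfigSketch` and `parse_provider` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

REQUIRED_KEYS = ("issuer", "client_id", "client_secret")


@dataclass(frozen=True)
class ProviderConfigSketch:
    issuer: str
    client_id: str
    client_secret: str
    discover: bool = True
    client_auth_method: str = "client_secret_basic"
    scopes: Tuple[str, ...] = ("openid",)
    jwks_uri: Optional[str] = None


def parse_provider(raw: Optional[Dict[str, Any]]) -> Optional[ProviderConfigSketch]:
    """Return a typed config, or None when OIDC is absent or not enabled."""
    if not raw or not raw.get("enabled", False):
        return None
    missing = [key for key in REQUIRED_KEYS if key not in raw]
    if missing:
        raise ValueError("oidc_config is missing: %s" % ", ".join(missing))
    return ProviderConfigSketch(
        issuer=raw["issuer"],
        client_id=raw["client_id"],
        client_secret=raw["client_secret"],
        discover=raw.get("discover", True),
        client_auth_method=raw.get("client_auth_method", "client_secret_basic"),
        scopes=tuple(raw.get("scopes", ["openid"])),
        jwks_uri=raw.get("jwks_uri"),
    )


provider = parse_provider(
    {
        "enabled": True,
        "issuer": "https://idp.example.com",
        "client_id": "synapse",
        "client_secret": "s3cret",
    }
)
# Mirrors the oidc_enabled property: enabled if and only if a provider exists.
oidc_enabled = bool(provider)
```

Keeping every setting on one object means a second provider would only need a second instance rather than another set of `oidc_*` attributes on the config class.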


@ -49,8 +49,13 @@ from synapse.api.errors import (
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers._base import BaseHandler
from synapse.handlers.ui_auth import (
INTERACTIVE_AUTH_CHECKERS,
UIAuthSessionDataConstants,
)
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http import get_request_user_agent
from synapse.http.server import finish_request, respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
@ -62,8 +67,6 @@ from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@ -284,7 +287,6 @@ class AuthHandler(BaseHandler):
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
clientip: str,
description: str,
) -> Tuple[dict, Optional[str]]:
"""
@ -301,8 +303,6 @@ class AuthHandler(BaseHandler):
request_body: The body of the request sent by the client
clientip: The IP address of the client.
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
@ -338,10 +338,10 @@ class AuthHandler(BaseHandler):
request_body.pop("auth", None)
return request_body, None
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
# build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types(
@ -349,13 +349,16 @@ class AuthHandler(BaseHandler):
)
flows = [[login_type] for login_type in supported_ui_auth_types]
def get_new_session_data() -> JsonDict:
return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
try:
result, params, session_id = await self.check_ui_auth(
flows, request, request_body, clientip, description
flows, request, request_body, description, get_new_session_data,
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
raise
# find the completed login type
@ -363,14 +366,14 @@ class AuthHandler(BaseHandler):
if login_type not in result:
continue
user_id = result[login_type]
validated_user_id = result[login_type]
break
else:
# this can't happen
raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
if validated_user_id != requester_user_id:
raise AuthError(403, "Invalid auth")
# Note that the access token has been validated.
@ -402,13 +405,9 @@ class AuthHandler(BaseHandler):
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
if await self.store.get_external_ids_by_user(user.to_string()):
ui_auth_types.add(LoginType.SSO)
# Our CAS impl does not (yet) correctly register users in user_external_ids,
# so always offer that if it's available.
if self.hs.config.cas.cas_enabled:
if await self.hs.get_sso_handler().get_identity_providers_for_user(
user.to_string()
):
ui_auth_types.add(LoginType.SSO)
return ui_auth_types
@ -426,8 +425,8 @@ class AuthHandler(BaseHandler):
flows: List[List[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
clientip: str,
description: str,
get_new_session_data: Optional[Callable[[], JsonDict]] = None,
) -> Tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
@ -448,11 +447,16 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip: The IP address of the client.
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
get_new_session_data:
an optional callback which will be called when starting a new session.
it should return data to be stored as part of the session.
The keys of the returned data should be entries in
UIAuthSessionDataConstants.
Returns:
A tuple of (creds, params, session_id).
@ -480,10 +484,15 @@ class AuthHandler(BaseHandler):
# If there's no session ID, create a new session.
if not sid:
new_session_data = get_new_session_data() if get_new_session_data else {}
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
for k, v in new_session_data.items():
await self.set_session_data(session.session_id, k, v)
else:
try:
session = await self.store.get_ui_auth_session(sid)
@ -539,7 +548,8 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
user_agent = request.get_user_agent("")
user_agent = get_request_user_agent(request)
clientip = request.getClientIP()
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
@ -644,7 +654,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
key: The key to store the data under. An entry from
UIAuthSessionDataConstants.
value: The data to store
"""
try:
@ -660,7 +671,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
key: The key the data was stored under. An entry from
UIAuthSessionDataConstants.
default: Value to return if the key has not been set
"""
try:
@ -1334,12 +1346,12 @@ class AuthHandler(BaseHandler):
else:
return False
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
Args:
redirect_url: The URL to redirect to the SSO provider.
request: The incoming HTTP request
session_id: The user interactive authentication session ID.
Returns:
@ -1349,6 +1361,35 @@ class AuthHandler(BaseHandler):
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
user_id_to_verify = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify
)
if not idps:
# we checked that the user had some remote identities before offering an SSO
# flow, so either it's been deleted or the client has requested SSO despite
# it not being offered.
raise SynapseError(400, "User has no SSO identities")
# for now, just pick one
idp_id, sso_auth_provider = next(iter(idps.items()))
if len(idps) > 1:
logger.warning(
"User %r has previously logged in with multiple SSO IdPs; arbitrarily "
"picking %r",
user_id_to_verify,
idp_id,
)
redirect_url = await sso_auth_provider.handle_redirect_request(
request, None, session_id
)
return self._sso_auth_confirm_template.render(
description=session.description, redirect_url=redirect_url,
)
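The provider selection in `start_sso_ui_auth` can be exercised on its own. The sketch below is a simplified, synchronous stand-in with a hypothetical function name; it raises `ValueError` where the handler raises `SynapseError`, and it warns only when the user has more than one IdP, which is the situation the log message describes.

```python
import logging
from typing import Mapping, Tuple, TypeVar

logger = logging.getLogger(__name__)

P = TypeVar("P")


def pick_identity_provider(user_id: str, idps: Mapping[str, P]) -> Tuple[str, P]:
    """Pick one SSO identity provider for a user (hypothetical helper)."""
    if not idps:
        raise ValueError("User has no SSO identities")

    # For now, just pick the first one in iteration order.
    idp_id, provider = next(iter(idps.items()))
    if len(idps) > 1:
        logger.warning(
            "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
            "picking %r",
            user_id,
            idp_id,
        )
    return idp_id, provider


print(pick_identity_provider("@alice:example.com", {"oidc": "oidc-provider"}))
```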


@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester
from synapse.types import Requester, UserID, create_requester
from ._base import BaseHandler
@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self._server_name = hs.hostname
@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity.enabled
async def deactivate_account(
self, user_id: str, erase_data: bool, id_server: Optional[str] = None
self,
user_id: str,
erase_data: bool,
requester: Requester,
id_server: Optional[str] = None,
by_admin: bool = False,
) -> bool:
"""Deactivate a user's account
Args:
user_id: ID of user to be deactivated
erase_data: whether to GDPR-erase the user's data
requester: The user attempting to make this change.
id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
by_admin: Whether this change was made by an administrator.
Returns:
True if identity server supports removing threepids, otherwise False.
@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
# Mark the user as erased, if they asked for that
if erase_data:
user = UserID.from_string(user_id)
# Remove avatar URL from this user
await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
# Remove displayname from this user
await self._profile_handler.set_displayname(user, requester, "", by_admin)
logger.info("Marking %s as erased", user_id)
await self.store.mark_user_erased(user_id)


@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
from urllib.parse import urlencode
import attr
@ -35,7 +35,7 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.config.oidc_config import OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
@ -71,6 +71,131 @@ JWK = Dict[str, str]
JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler()
provider_conf = hs.config.oidc.oidc_provider
# we should not have been instantiated if there is no configured provider.
assert provider_conf is not None
self._token_generator = OidcSessionTokenGenerator(hs)
self._provider = OidcProvider(hs, self._token_generator, provider_conf)
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
Called at startup to ensure we have everything we need.
"""
await self._provider.load_metadata()
await self._provider.load_jwks()
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
- first, we check if there was any error returned by the provider and
display it
- then we fetch the session cookie, decode and verify it
- the ``state`` query parameter should match with the one stored in the
session cookie
Once we know the session is legit, we then delegate to the OIDC Provider
implementation, which will exchange the code with the provider and complete the
login/authentication.
Args:
request: the incoming request from the browser.
"""
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
# error response from the auth server. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
error = request.args[b"error"][0].decode()
description = request.args.get(b"error_description", [b""])[0].decode()
# Most of the errors returned by the provider could be due to
# either the provider misbehaving or Synapse being misconfigured.
# The only exception to that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2
# Fetch the session cookie
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
# Remove the cookie. There is a good chance that if the callback failed
# once, it will fail next time and the code will already be exchanged.
# Removing it early avoids spamming the provider with token requests.
request.addCookie(
SESSION_COOKIE_NAME,
b"",
path="/_synapse/oidc",
expires="Thu, Jan 01 1970 00:00:00 UTC",
httpOnly=True,
sameSite="lax",
)
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return
state = request.args[b"state"][0].decode()
# Deserialize the session token and verify it.
try:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return
code = request.args[b"code"][0].decode()
await self._provider.handle_oidc_callback(request, session_data, code)
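The order of the checks in the new `OidcHandler.handle_oidc_callback` can be summarised as a pure function. This is a sketch, not the handler itself: `first_callback_error` is a hypothetical name, it returns the error instead of rendering an HTML page, and the macaroon verification of the session token is left out.

```python
from typing import Dict, List, Optional, Tuple

Args = Dict[bytes, List[bytes]]


def first_callback_error(
    args: Args, session_cookie: Optional[bytes]
) -> Optional[Tuple[str, str]]:
    """Return (error, description) for the first failed check, else None."""
    # 1. An error reported by the provider is passed through as-is.
    if b"error" in args:
        error = args[b"error"][0].decode()
        description = args.get(b"error_description", [b""])[0].decode()
        return error, description
    # 2. The session cookie must be present.
    if session_cookie is None:
        return "missing_session", "No session cookie found"
    # 3. The state parameter must be present.
    if b"state" not in args:
        return "invalid_request", "State parameter is missing"
    # The real handler verifies the session token against `state` here.
    # 4. The authorization code must be present.
    if b"code" not in args:
        return "invalid_request", "Code parameter is missing"
    return None


print(first_callback_error({b"state": [b"abc"]}, b"cookie"))
```

Only when all of these checks pass does the handler hand the decoded session data and the code to the provider-specific `OidcProvider.handle_oidc_callback`.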
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint
"""
@ -85,38 +210,47 @@ class OidcError(Exception):
return self.error
class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
class OidcProvider:
"""Wraps the config for a single OIDC IdentityProvider
Provides methods for handling redirect requests and callbacks via that particular
IdP.
"""
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
def __init__(
self,
hs: "HomeServer",
token_generator: "OidcSessionTokenGenerator",
provider: OidcProviderConfig,
):
self._store = hs.get_datastore()
self._token_generator = token_generator
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
hs.config.oidc_client_auth_method,
provider.client_id, provider.client_secret, provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = hs.config.oidc_client_auth_method # type: str
self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata(
issuer=hs.config.oidc_issuer,
authorization_endpoint=hs.config.oidc_authorization_endpoint,
token_endpoint=hs.config.oidc_token_endpoint,
userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
jwks_uri=hs.config.oidc_jwks_uri,
issuer=provider.issuer,
authorization_endpoint=provider.authorization_endpoint,
token_endpoint=provider.token_endpoint,
userinfo_endpoint=provider.userinfo_endpoint,
jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = hs.config.oidc_discover # type: bool
self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._provider_needs_discovery = provider.discover
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
)
self._skip_verification = provider.skip_verification
self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
# identifier for the external_ids table
self.idp_id = "oidc"
@ -519,11 +653,13 @@ class OidcHandler(BaseHandler):
if not client_redirect_url:
client_redirect_url = b""
cookie = self._generate_oidc_session_token(
cookie = self._token_generator.generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
session_data=OidcSessionData(
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
),
)
request.addCookie(
SESSION_COOKIE_NAME,
@ -546,22 +682,16 @@ class OidcHandler(BaseHandler):
nonce=nonce,
)
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
async def handle_oidc_callback(
self, request: SynapseRequest, session_data: "OidcSessionData", code: str
) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._sso_handler.render_error`` which displays an HTML page for the error.
By this time we have already validated the session on the synapse side, and
now need to do the provider-specific operations. This includes:
Most of the OpenID Connect logic happens here:
- first, we check if there was any error returned by the provider and
display it
- then we fetch the session cookie, decode and verify it
- the ``state`` query parameter should match with the one stored in the
session cookie
- once we known this session is legit, exchange the code with the
provider using the ``token_endpoint`` (see ``_exchange_code``)
- exchange the code with the provider using the ``token_endpoint`` (see
``_exchange_code``)
- once we have the token, use it to either extract the UserInfo from
the ``id_token`` (``_parse_id_token``), or use the ``access_token``
to fetch UserInfo from the ``userinfo_endpoint``
@ -571,88 +701,12 @@ class OidcHandler(BaseHandler):
Args:
request: the incoming request from the browser.
session_data: the session data, extracted from our cookie
code: The authorization code we got from the callback.
"""
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
# error response from the auth server. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2.1
# https://openid.net/specs/openid-connect-core-1_0.html#AuthError
error = request.args[b"error"][0].decode()
description = request.args.get(b"error_description", [b""])[0].decode()
# Most of the errors returned by the provider could be due by
# either the provider misbehaving or Synapse being misconfigured.
# The only exception of that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
# https://tools.ietf.org/html/rfc6749#section-4.1.2
# Fetch the session cookie
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
# Remove the cookie. There is a good chance that if the callback failed
# once, it will fail next time and the code will already be exchanged.
# Removing it early avoids spamming the provider with token requests.
request.addCookie(
SESSION_COOKIE_NAME,
b"",
path="/_synapse/oidc",
expires="Thu, Jan 01 1970 00:00:00 UTC",
httpOnly=True,
sameSite="lax",
)
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return
state = request.args[b"state"][0].decode()
# Deserialize the session token and verify it.
try:
(
nonce,
client_redirect_url,
ui_auth_session_id,
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return
logger.debug("Exchanging code")
code = request.args[b"code"][0].decode()
try:
logger.debug("Exchanging code")
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
@ -674,14 +728,14 @@ class OidcHandler(BaseHandler):
else:
logger.debug("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=nonce)
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
# first check if we're doing a UIA
if ui_auth_session_id:
if session_data.ui_auth_session_id:
try:
remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
@ -690,7 +744,7 @@ class OidcHandler(BaseHandler):
return
return await self._sso_handler.complete_sso_ui_auth_request(
self.idp_id, remote_user_id, ui_auth_session_id, request
self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
)
# otherwise, it's a login
@ -698,133 +752,12 @@ class OidcHandler(BaseHandler):
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
userinfo, token, request, client_redirect_url
userinfo, token, request, session_data.client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str],
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.
Args:
state: The ``state`` parameter passed to the OIDC provider.
nonce: The ``nonce`` parameter passed to the OIDC provider.
client_redirect_url: The URL the client gave when it initiated the
flow.
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (client_redirect_url,)
)
if ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def _verify_oidc_session_token(
self, session: bytes, state: str
) -> Tuple[str, str, Optional[str]]:
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.
Args:
session: The session token to verify
state: The state the OIDC provider gave back
Returns:
The nonce, client_redirect_url, and ui_auth_session_id for this session
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
return nonce, client_redirect_url, ui_auth_session_id
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
Exception: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self.clock.time_msec()
return now < expiry
async def _complete_oidc_login(
self,
userinfo: UserInfo,
@ -901,8 +834,8 @@ class OidcHandler(BaseHandler):
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
user_id = UserID(attributes.localpart, self._server_name).to_string()
users = await self._store.get_users_by_id_case_insensitive(user_id)
if users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
@ -954,6 +887,148 @@ class OidcHandler(BaseHandler):
return str(remote_user_id)
class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._server_name = hs.hostname
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
def generate_oidc_session_token(
self,
state: str,
session_data: "OidcSessionData",
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.
Args:
state: The ``state`` parameter passed to the OIDC provider.
session_data: data to include in the session token.
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
if session_data.ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def verify_oidc_session_token(
self, session: bytes, state: str
) -> "OidcSessionData":
"""Verifies and extracts an OIDC session token.
This verifies that a given session token was issued by this homeserver
and extracts the nonce and client_redirect_url caveats.
Args:
session: The session token to verify
state: The state the OIDC provider gave back
Returns:
The data extracted from the session cookie
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
return OidcSessionData(
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
)
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
ValueError: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
return now < expiry
@attr.s(frozen=True, slots=True)
class OidcSessionData:
"""The attributes which are stored in an OIDC session cookie"""
# The `nonce` parameter passed to the OIDC provider.
nonce = attr.ib(type=str)
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
# The session ID of the ongoing UI Auth (None if this is a login)
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
)
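
Note: the caveat handling in `OidcSessionTokenGenerator` above can be illustrated without pymacaroons. The following is a minimal sketch, assuming the caveats are available as plain strings (in the diff they are read from `macaroon.caveats[...].caveat_id`); the function names and sample values here are illustrative only and are not part of Synapse.

```python
from typing import List


def get_caveat_value(caveats: List[str], key: str) -> str:
    """Return the value of the first caveat of the form "<key> = <value>"."""
    prefix = key + " = "
    for caveat in caveats:
        if caveat.startswith(prefix):
            return caveat[len(prefix):]
    raise ValueError("No %s caveat in macaroon" % (key,))


def verify_expiry(caveat: str, now_ms: int) -> bool:
    """Accept only "time < <expiry>" caveats whose expiry is still in the future."""
    prefix = "time < "
    if not caveat.startswith(prefix):
        return False
    return now_ms < int(caveat[len(prefix):])


# Sample caveats as a login (non UI Auth) session would carry them:
# there is no "ui_auth_session_id" caveat in this case.
caveats = [
    "gen = 1",
    "type = session",
    "state = abc",
    "nonce = n0nce",
    "client_redirect_url = https://client.example/cb",
    "time < 2000",
]

nonce = get_caveat_value(caveats, "nonce")
try:
    ui_auth_session_id = get_caveat_value(caveats, "ui_auth_session_id")
except ValueError:
    # Mirrors the diff: a missing UI Auth caveat means this is a plain login.
    ui_auth_session_id = None
```

The missing-caveat case is why `verify_oidc_session_token` wraps the `ui_auth_session_id` lookup in `try`/`except ValueError`.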


@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
avatar_url_to_set = new_avatar_url # type: Optional[str]
if new_avatar_url == "":
avatar_url_to_set = None
# Same as set_displayname
if by_admin:
requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
await self.store.set_profile_avatar_url(
target_user.localpart, avatar_url_to_set
)
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)


@ -23,6 +23,7 @@ from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@ -166,6 +167,37 @@ class SsoHandler:
"""Get the configured identity providers"""
return self._identity_providers
async def get_identity_providers_for_user(
self, user_id: str
) -> Mapping[str, SsoIdentityProvider]:
"""Get the SsoIdentityProviders which a user has used
Given a user id, get the identity providers that the user has previously used
to log in (and thus could use to re-identify themselves for UI Auth).
Args:
user_id: MXID of user to look up
Returns:
a map of idp_id to SsoIdentityProvider
"""
external_ids = await self._store.get_external_ids_by_user(user_id)
valid_idps = {}
for idp_id, _ in external_ids:
idp = self._identity_providers.get(idp_id)
if not idp:
logger.warning(
"User %r has an SSO mapping for IdP %r, but this is no longer "
"configured.",
user_id,
idp_id,
)
else:
valid_idps[idp_id] = idp
return valid_idps
def render_error(
self,
request: Request,
@ -362,7 +394,7 @@ class SsoHandler:
attributes,
auth_provider_id,
remote_user_id,
request.get_user_agent(""),
get_request_user_agent(request),
request.getClientIP(),
)
@ -628,7 +660,7 @@ class SsoHandler:
attributes,
session.auth_provider_id,
session.remote_user_id,
request.get_user_agent(""),
get_request_user_agent(request),
request.getClientIP(),
)
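
Note: the filtering done by `get_identity_providers_for_user` in the hunk above can be sketched in isolation. This is a simplified illustration, assuming the stored external IDs are `(idp_id, external_id)` pairs and the configured providers are held in a plain dict; `providers_for_user` and the sample data are hypothetical and not Synapse API.

```python
import logging
from typing import Dict, Iterable, Tuple

logger = logging.getLogger(__name__)


def providers_for_user(
    user_id: str,
    external_ids: Iterable[Tuple[str, str]],
    configured: Dict[str, object],
) -> Dict[str, object]:
    """Keep only the identity providers that are still configured."""
    valid = {}
    for idp_id, _ in external_ids:
        idp = configured.get(idp_id)
        if not idp:
            # The user has a mapping for a provider that was removed from config.
            logger.warning(
                "User %r has an SSO mapping for IdP %r, but this is no longer "
                "configured.",
                user_id,
                idp_id,
            )
        else:
            valid[idp_id] = idp
    return valid


result = providers_for_user(
    "@alice:example.com",
    external_ids=[("oidc", "sub-123"), ("saml", "uid-456")],
    configured={"oidc": "oidc-handler"},
)
```

Here only the `oidc` entry survives; the stale `saml` mapping is logged and skipped rather than raising.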


@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
"""
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
class UIAuthSessionDataConstants:
"""Constants for use with AuthHandler.set_session_data"""
# used during registration and password reset to store a hashed copy of the
# password, so that the client does not need to submit it each time.
PASSWORD_HASH = "password_hash"
# used during registration to store the mxid of the registered user
REGISTERED_USER_ID = "registered_user_id"
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for.
REQUEST_USER_ID = "request_user_id"


@ -17,6 +17,7 @@ import re
from twisted.internet import task
from twisted.web.client import FileBodyProducer
from twisted.web.iweb import IRequest
from synapse.api.errors import SynapseError
@ -50,3 +51,17 @@ class QuieterFileBodyProducer(FileBodyProducer):
FileBodyProducer.stopProducing(self)
except task.TaskStopped:
pass
def get_request_user_agent(request: IRequest, default: str = "") -> str:
"""Return the last User-Agent header, or the given default.
"""
# There could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
# with maximum recursion trying to log errors about
# the charset problem.
# c.f. https://github.com/matrix-org/synapse/issues/3471
h = request.getHeader(b"User-Agent")
return h.decode("ascii", "replace") if h else default
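
Note: a small sketch of how the `"replace"` error handler behaves in `get_request_user_agent`. `FakeRequest` is a hypothetical stand-in that only provides `getHeader()`; it is not a twisted class, and the function body is copied from the hunk above.

```python
from typing import Optional


class FakeRequest:
    """Hypothetical stand-in exposing only getHeader()."""

    def __init__(self, user_agent: Optional[bytes]):
        self._user_agent = user_agent

    def getHeader(self, name: bytes) -> Optional[bytes]:
        return self._user_agent if name.lower() == b"user-agent" else None


def get_request_user_agent(request, default: str = "") -> str:
    """Return the User-Agent header, or the given default."""
    h = request.getHeader(b"User-Agent")
    # Non-ASCII bytes are replaced with U+FFFD instead of raising.
    return h.decode("ascii", "replace") if h else default


plain = get_request_user_agent(FakeRequest(b"curl/7.68.0"))
missing = get_request_user_agent(FakeRequest(None), "-")
non_ascii = get_request_user_agent(FakeRequest("é".encode("utf-8")))
```

Because decoding never raises, the result is always safe to pass to the logger, which is the problem the comment in the diff refers to.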


@ -32,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
from netaddr import IPAddress, IPSet
from netaddr import AddrFormatError, IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@ -261,16 +261,16 @@ class BlacklistingAgentWrapper(Agent):
try:
ip_address = IPAddress(h.hostname)
except AddrFormatError:
# Not an IP
pass
else:
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
except Exception:
# Not an IP
pass
return self._agent.request(
method, uri, headers=headers, bodyProducer=bodyProducer
@ -341,6 +341,7 @@ class SimpleHttpClient:
self.agent = ProxyAgent(
self.reactor,
hs.get_reactor(),
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
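
Note: the `BlacklistingAgentWrapper` hunk above narrows `except Exception` to `except AddrFormatError` and moves the blacklist check into an `else` branch, so that only the "hostname is not an IP literal" case is swallowed. The same shape can be sketched with the standard library; this assumes `ipaddress` and `ValueError` as stand-ins for `netaddr` and `AddrFormatError`, and `BLACKLIST` and `is_blocked` are illustrative names.

```python
import ipaddress

# Illustrative blacklist; Synapse takes its ranges from configuration.
BLACKLIST = [ipaddress.ip_network("10.0.0.0/8")]


def is_blocked(hostname: str) -> bool:
    try:
        ip = ipaddress.ip_address(hostname)
    except ValueError:
        # Not an IP literal: nothing to check at this layer.
        return False
    else:
        # Errors raised here are no longer caught by the handler above.
        return any(ip in network for network in BLACKLIST)
```

With the broad handler, an unexpected failure inside the check itself would have been treated the same as "not an IP" and the request would have gone ahead.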


@ -102,7 +102,6 @@ class MatrixFederationAgent:
pool=self._pool,
contextFactory=tls_client_options_factory,
),
self._reactor,
ip_blacklist=ip_blacklist,
),
user_agent=self.user_agent,


@ -174,6 +174,16 @@ async def _handle_json_response(
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = await make_deferred_yieldable(d)
except ValueError as e:
# The JSON content was invalid.
logger.warning(
"{%s} [%s] Failed to parse JSON response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=False) from e
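
Note: the new `except ValueError` branch relies on JSON decode errors being `ValueError` subclasses. A minimal sketch of that behaviour, assuming the standard `json` module and a hypothetical stand-in for Synapse's `RequestSendFailed` (the real class lives in `synapse.api.errors` and may differ):

```python
import json


class RequestSendFailed(Exception):
    """Hypothetical stand-in for the Synapse exception of the same name."""

    def __init__(self, inner_exception: Exception, can_retry: bool):
        super().__init__("Failed to send request: %s" % (inner_exception,))
        self.inner_exception = inner_exception
        self.can_retry = can_retry


def parse_body(raw: bytes):
    try:
        return json.loads(raw)
    except ValueError as e:
        # json.JSONDecodeError subclasses ValueError. Retrying would not help,
        # since the remote server would send the same malformed body again.
        raise RequestSendFailed(e, can_retry=False) from e
```

Using `raise ... from e` keeps the original decode error available as `__cause__` for debugging.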
except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",


@ -39,6 +39,10 @@ class ProxyAgent(_AgentBase):
reactor: twisted reactor to place outgoing
connections.
proxy_reactor: twisted reactor to use for connections to the proxy server.
`reactor` might have some blacklisting applied (e.g. for DNS queries),
but we need unblocked access to the proxy.
contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
verification parameters of OpenSSL. The default is to use a
`BrowserLikePolicyForHTTPS`, so unless you have special
@ -59,6 +63,7 @@ class ProxyAgent(_AgentBase):
def __init__(
self,
reactor,
proxy_reactor=None,
contextFactory=BrowserLikePolicyForHTTPS(),
connectTimeout=None,
bindAddress=None,
@ -68,6 +73,11 @@ class ProxyAgent(_AgentBase):
):
_AgentBase.__init__(self, reactor, pool)
if proxy_reactor is None:
self.proxy_reactor = reactor
else:
self.proxy_reactor = proxy_reactor
self._endpoint_kwargs = {}
if connectTimeout is not None:
self._endpoint_kwargs["timeout"] = connectTimeout
@ -75,11 +85,11 @@ class ProxyAgent(_AgentBase):
self._endpoint_kwargs["bindAddress"] = bindAddress
self.http_proxy_endpoint = _http_proxy_endpoint(
http_proxy, reactor, **self._endpoint_kwargs
http_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
self.https_proxy_endpoint = _http_proxy_endpoint(
https_proxy, reactor, **self._endpoint_kwargs
https_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
self._policy_for_https = contextFactory
@ -137,7 +147,7 @@ class ProxyAgent(_AgentBase):
request_path = uri
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
endpoint = HTTPConnectProxyEndpoint(
self._reactor,
self.proxy_reactor,
self.https_proxy_endpoint,
parsed_uri.host,
parsed_uri.port,


@ -20,7 +20,7 @@ from twisted.python.failure import Failure
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
from synapse.http import redact_uri
from synapse.http import get_request_user_agent, redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.types import Requester
@ -113,15 +113,6 @@ class SynapseRequest(Request):
method = self.method.decode("ascii")
return method
def get_user_agent(self, default: str) -> str:
"""Return the last User-Agent header, or the given default.
"""
user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
if user_agent is None:
return default
return user_agent.decode("ascii", "replace")
def render(self, resrc):
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
@ -292,12 +283,7 @@ class SynapseRequest(Request):
# and can see that we're doing something wrong.
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
# ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
# with maximum recursion trying to log errors about
# the charset problem.
# c.f. https://github.com/matrix-org/synapse/issues/3471
user_agent = self.get_user_agent("-")
user_agent = get_request_user_agent(self, "-")
code = str(self.code)
if not self.finished:


@ -396,31 +396,30 @@ class Notifier:
Will wake up all listeners for the given users and rooms.
"""
with PreserveLoggingContext():
with Measure(self.clock, "on_new_event"):
user_streams = set()
with Measure(self.clock, "on_new_event"):
user_streams = set()
for user in users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
for user in users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set())
for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set())
time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
try:
user_stream.notify(stream_key, new_token, time_now_ms)
except Exception:
logger.exception("Failed to notify listener")
time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
try:
user_stream.notify(stream_key, new_token, time_now_ms)
except Exception:
logger.exception("Failed to notify listener")
self.notify_replication()
self.notify_replication()
# Notify appservices
self._notify_app_services_ephemeral(
stream_key, new_token, users,
)
# Notify appservices
self._notify_app_services_ephemeral(
stream_key, new_token, users,
)
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happened


@ -244,7 +244,7 @@ class UserRestServletV2(RestServlet):
if deactivate and not user["deactivated"]:
await self.deactivate_account_handler.deactivate_account(
target_user.to_string(), False
target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
if "password" not in body:
@ -486,12 +486,22 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
if not self.is_mine(UserID.from_string(target_user_id)):
raise SynapseError(400, "Can only deactivate local users")
if not await self.store.get_user_by_id(target_user_id):
raise NotFoundError("User not found")
async def on_POST(self, request, target_user_id):
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@ -501,10 +511,8 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON,
)
UserID.from_string(target_user_id)
result = await self._deactivate_account_handler.deactivate_account(
target_user_id, erase
target_user_id, erase, requester, by_admin=True
)
if result:
id_server_unbind_result = "success"
@ -714,13 +722,6 @@ class UserMembershipRestServlet(RestServlet):
async def on_GET(self, request, user_id):
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "Can only lookup local users")
user = await self.store.get_user_by_id(user_id)
if user is None:
raise NotFoundError("Unknown user")
room_ids = await self.store.get_rooms_for_user(user_id)
ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
return 200, ret


@ -319,9 +319,9 @@ class SsoRedirectServlet(RestServlet):
# register themselves with the main SSOHandler.
if hs.config.cas_enabled:
hs.get_cas_handler()
elif hs.config.saml2_enabled:
if hs.config.saml2_enabled:
hs.get_saml_handler()
elif hs.config.oidc_enabled:
if hs.config.oidc_enabled:
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
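
Note: the switch from `elif` to `if` in this hunk matters when more than one SSO method is enabled. A toy sketch of the difference, with hypothetical function names and plain booleans in place of `hs.config`:

```python
from typing import List


def handlers_with_elif(cas: bool, saml: bool, oidc: bool) -> List[str]:
    started = []
    if cas:
        started.append("cas")
    elif saml:
        started.append("saml")
    elif oidc:
        started.append("oidc")
    return started


def handlers_with_if(cas: bool, saml: bool, oidc: bool) -> List[str]:
    started = []
    if cas:
        started.append("cas")
    if saml:
        started.append("saml")
    if oidc:
        started.append("oidc")
    return started
```

With the chained `elif`, only the first enabled handler would be instantiated, so the others would never register themselves with the main SSO handler.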


@ -20,9 +20,6 @@ from http import HTTPStatus
from typing import TYPE_CHECKING
from urllib.parse import urlparse
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
@ -31,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError,
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -189,11 +191,7 @@ class PasswordRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
try:
params, session_id = await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
requester, request, body, "modify your account password",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but
@ -204,7 +202,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
)
raise
user_id = requester.user.to_string()
@ -215,7 +215,6 @@ class PasswordRestServlet(RestServlet):
[[LoginType.EMAIL_IDENTITY]],
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@ -227,7 +226,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
)
raise
@ -260,7 +261,7 @@ class PasswordRestServlet(RestServlet):
password_hash = await self.auth_handler.hash(new_password)
elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
else:
# UI validation was skipped, but the request did not include a new
@ -304,19 +305,18 @@ class DeactivateAccountRestServlet(RestServlet):
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
requester.user.to_string(), erase, requester
)
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"deactivate your account",
requester, request, body, "deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server")
requester.user.to_string(),
erase,
requester,
id_server=body.get("id_server"),
)
if result:
id_server_unbind_result = "success"
@ -695,11 +695,7 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"add a third-party identifier to your account",
requester, request, body, "add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(


@ -19,7 +19,6 @@ from typing import TYPE_CHECKING
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet, parse_string
@ -46,22 +45,6 @@ class AuthRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
# SSO configuration.
self._cas_enabled = hs.config.cas_enabled
if self._cas_enabled:
self._cas_handler = hs.get_cas_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self._saml_enabled = hs.config.saml2_enabled
if self._saml_enabled:
self._saml_handler = hs.get_saml_handler()
self._oidc_enabled = hs.config.oidc_enabled
if self._oidc_enabled:
self._oidc_handler = hs.get_oidc_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
self.success_template = hs.config.fallback_success_template
@ -90,21 +73,7 @@ class AuthRestServlet(RestServlet):
elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to
# re-authenticate with their SSO provider.
if self._cas_enabled:
sso_auth_provider = self._cas_handler # type: SsoIdentityProvider
elif self._saml_enabled:
sso_auth_provider = self._saml_handler
elif self._oidc_enabled:
sso_auth_provider = self._oidc_handler
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
sso_redirect_url = await sso_auth_provider.handle_redirect_request(
request, None, session
)
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
html = await self.auth_handler.start_sso_ui_auth(request, session)
else:
raise SynapseError(404, "Unknown auth stage type")
@ -128,7 +97,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
success = await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
LoginType.RECAPTCHA, authdict, request.getClientIP()
)
if success:
@ -144,7 +113,7 @@ class AuthRestServlet(RestServlet):
authdict = {"session": session}
success = await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
LoginType.TERMS, authdict, request.getClientIP()
)
if success:

View file

@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"remove device(s) from your account",
requester, request, body, "remove device(s) from your account",
)
await self.device_handler.delete_devices(
@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"remove a device from your account",
requester, request, body, "remove a device from your account",
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)

View file

@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"add a device signing key to your account",
requester, request, body, "add a device signing key to your account",
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)

View file

@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@ -353,7 +354,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
ip = self.hs.get_ip_from_request(request)
ip = request.getClientIP()
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred
@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet):
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
)
# Extract the previously-hashed password from the session.
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
# Ensure that the username is valid.
@ -513,11 +514,7 @@ class RegisterRestServlet(RestServlet):
# not this will raise a user-interactive auth error.
try:
auth_result, params, session_id = await self.auth_handler.check_ui_auth(
self._registration_flows,
request,
body,
self.hs.get_ip_from_request(request),
"register a new account",
self._registration_flows, request, body, "register a new account",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth.
@ -532,7 +529,9 @@ class RegisterRestServlet(RestServlet):
if not password_hash and password:
password_hash = await self.auth_handler.hash(password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
)
raise
@ -633,7 +632,9 @@ class RegisterRestServlet(RestServlet):
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
session_id,
UIAuthSessionDataConstants.REGISTERED_USER_ID,
registered_user_id,
)
registered = True

View file

@ -283,10 +283,6 @@ class HomeServer(metaclass=abc.ABCMeta):
"""
return self._reactor
def get_ip_from_request(self, request) -> str:
# X-Forwarded-For is handled by our custom request type.
return request.getClientIP()
def is_mine(self, domain_specific_string: DomainSpecificString) -> bool:
return domain_specific_string.domain == self.hostname
@ -505,7 +501,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return InitialSyncHandler(self)
@cache_in_self
def get_profile_handler(self):
def get_profile_handler(self) -> ProfileHandler:
return ProfileHandler(self)
@cache_in_self

View file

@ -179,6 +179,9 @@ class LoggingDatabaseConnection:
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
R = TypeVar("R")
class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@ -266,6 +269,20 @@ class LoggingTransaction:
for val in args:
self.execute(sql, val)
def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
using postgres.
Always sets fetch=True when calling `execute_values`, so will return the
results.
"""
assert isinstance(self.database_engine, PostgresEngine)
from psycopg2.extras import execute_values # type: ignore
return self._do_execute(
lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
)
def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
@ -276,7 +293,7 @@ class LoggingTransaction:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
def _do_execute(self, func, sql: str, *args: Any) -> None:
def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
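The `_do_execute` signature change above threads a `TypeVar` through the wrapper so that the wrapper returns whatever the wrapped callable returns (nothing for `execute`, a list of rows for `execute_values`). A minimal, self-contained sketch of that typing pattern follows. The names `do_execute`, `fake_execute` and `fake_execute_values` are hypothetical stand-ins and not part of Synapse, and the logging and metrics done by the real method are omitted.

```python
from typing import Any, Callable, List, Tuple, TypeVar

R = TypeVar("R")


def do_execute(func: Callable[..., R], sql: str, *args: Any) -> R:
    """Hypothetical stand-in for a wrapper like `_do_execute`.

    Normalises the SQL onto one line, then returns whatever `func` returns,
    so the static return type follows the wrapped callable.
    """
    one_line_sql = " ".join(
        line.strip() for line in sql.splitlines() if line.strip()
    )
    return func(one_line_sql, *args)


def fake_execute(sql: str, *args: Any) -> None:
    # Stand-in for a cursor's execute(): no return value.
    return None


def fake_execute_values(sql: str, *args: Any) -> List[Tuple]:
    # Stand-in for a call that fetches rows: echoes its inputs as one row.
    return [(sql, args)]
```

With this shape a type checker infers `None` for `do_execute(fake_execute, ...)` and `List[Tuple]` for `do_execute(fake_execute_values, ...)`, without needing a cast at either call site.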
@ -347,9 +364,6 @@ class PerformanceCounters:
return top_n_counters
R = TypeVar("R")
class DatabasePool:
"""Wraps a single physical database and connection pool.

View file

@ -312,12 +312,9 @@ class AccountDataStore(AccountDataWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"account_data_max_stream_id",
"room_account_data",
"stream_id",
extra_tables=[
("room_account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
],
extra_tables=[("room_tags_revisions", "stream_id")],
)
super().__init__(database, db_conn, hs)
@ -362,14 +359,6 @@ class AccountDataStore(AccountDataWorkerStore):
lock=False,
)
# it's theoretically possible for the above to succeed and the
# below to fail - in which case we might reuse a stream id on
# restart, and the above update might not get propagated. That
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
await self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
@ -402,18 +391,6 @@ class AccountDataStore(AccountDataWorkerStore):
content,
)
# it's theoretically possible for the above to succeed and the
# below to fail - in which case we might reuse a stream id on
# restart, and the above update might not get propagated. That
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
#
# Note: This is only here for backwards compat to allow admins to
# roll back to a previous Synapse version. Next time we update the
# database version we can remove this table.
await self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
@ -486,24 +463,3 @@ class AccountDataStore(AccountDataWorkerStore):
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id
Args:
next_id: The revision to advance to.
"""
# Note: This is only here for backwards compat to allow admins to
# roll back to a previous Synapse version. Next time we update the
# database version we can remove this table.
def _update(txn):
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))
await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)

View file

@ -407,6 +407,34 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
"_prune_old_user_ips", _prune_old_user_ips_txn
)
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]:
"""For each device_id listed, give the user_ip it was last seen on.
The result might be slightly out of date as client IPs are inserted in batches.
Args:
user_id: The user to fetch devices for.
device_id: If None fetches all devices for the user
Returns:
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
)
return {(d["user_id"], d["device_id"]): d for d in res}
class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@ -470,43 +498,35 @@ class ClientIpStore(ClientIpWorkerStore):
for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
self.db_pool.simple_upsert_txn(
self.db_pool.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
values={
"user_agent": user_agent,
"device_id": device_id,
"last_seen": last_seen,
},
lock=False,
)
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
# this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it.
self.db_pool.simple_update_txn(
txn,
table="user_ips",
keyvalues={
"user_id": user_id,
"access_token": access_token,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
values={
"user_agent": user_agent,
"device_id": device_id,
"last_seen": last_seen,
},
lock=False,
)
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
# this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it.
self.db_pool.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
)
except Exception as e:
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]:
@ -520,18 +540,9 @@ class ClientIpStore(ClientIpWorkerStore):
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""
ret = await super().get_last_client_ip_by_device(user_id, device_id)
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
# Update what is retrieved from the database with data which is pending insertion.
for key in self._batch_row_update:
uid, access_token, ip = key
if uid == user_id:

View file

@ -707,50 +707,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
"""Get the current stream id from the _device_list_id_gen"""
...
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
log_kv({"Message": "Device key already stored."})
return False
self.db_pool.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time_now, "key_json": new_key_json},
)
log_kv({"message": "Device keys stored."})
return True
return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
@ -840,6 +796,50 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
log_kv({"Message": "Device key already stored."})
return False
self.db_pool.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time_now, "key_json": new_key_json},
)
log_kv({"message": "Device keys stored."})
return True
return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(

View file

@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
The set of the difference in auth chains.
"""
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id)
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
self._get_auth_chain_difference_using_cover_index_txn,
room_id,
state_sets,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
# for the events in question, so we fall back to the old method.
pass
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)
def _get_auth_chain_difference_using_cover_index_txn(
self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using the chain index.
See docs/auth_chain_difference_algorithm.md for details
"""
# First we look up the chain ID/sequence numbers for all the events, and
# work out the chain/sequence numbers reachable from each state set.
initial_events = set(state_sets[0]).union(*state_sets[1:])
# Map from event_id -> (chain ID, seq no)
chain_info = {} # type: Dict[str, Tuple[int, int]]
# Map from chain ID -> seq no -> event Id
chain_to_event = {} # type: Dict[int, Dict[int, str]]
# All the chains that we've found that are reachable from the state
# sets.
seen_chains = set() # type: Set[int]
sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
for event_id, chain_id, sequence_number in txn:
chain_info[event_id] = (chain_id, sequence_number)
seen_chains.add(chain_id)
chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(chain_info)
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info,
)
raise _NoChainCoverIndex(room_id)
# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
set_to_chain = [] # type: List[Dict[int, int]]
for state_set in state_sets:
chains = {} # type: Dict[int, int]
set_to_chain.append(chains)
for event_id in state_set:
chain_id, seq_no = chain_info[event_id]
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
# Now we look up all links for the chains we have, adding chains to
# set_to_chain that are reachable from each set.
sql = """
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
"""
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
for batch in batch_iter(set(seen_chains), 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch
)
txn.execute(sql % (clause,), args)
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
for chains in set_to_chain:
# chains are only reachable if the origin sequence number of
# the link is less than or equal to the max sequence number in
# the origin chain.
if origin_sequence_number <= chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number, chains.get(target_chain_id, 0),
)
seen_chains.add(target_chain_id)
# Now for each chain we figure out the maximum sequence number reachable
# from *any* state set and the minimum sequence number reachable from
# *all* state sets. Events in that range are in the auth chain
# difference.
result = set()
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
for chain_id in seen_chains:
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
if min_seq_no < max_seq_no:
# We have a non empty gap, try and fill it from the events that
# we have, otherwise add them to the list of gaps to pull out
# from the DB.
for seq_no in range(min_seq_no + 1, max_seq_no + 1):
event_id = chain_to_event.get(chain_id, {}).get(seq_no)
if event_id:
result.add(event_id)
else:
chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
break
if not chain_to_gap:
# If there are no gaps to fetch, we're done!
return result
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
WHERE
c.chain_id = l.chain_id
AND min_seq < sequence_number AND sequence_number <= max_seq
"""
args = [
(chain_id, min_no, max_no)
for chain_id, (min_no, max_no) in chain_to_gap.items()
]
rows = txn.execute_values(sql, args)
result.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
sql = """
SELECT event_id FROM event_auth_chains
WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
"""
for chain_id, (min_no, max_no) in chain_to_gap.items():
txn.execute(sql, (chain_id, min_no, max_no))
result.update(r for r, in txn)
return result
def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using a breadth first search.
This is used when we don't have a cover index for the room.
"""
# Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~
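The chain cover approach in `_get_auth_chain_difference_using_cover_index_txn` can be illustrated without a database. The sketch below is an in-memory approximation under stated assumptions, not the Synapse implementation: `chain_info` stands in for the whole `event_auth_chains` table, `links` stands in for the rows of `event_auth_chain_links`, and `links` is assumed to already list every chain reachable from each origin chain, so a single pass over it suffices. The function name `auth_chain_difference` is hypothetical.

```python
from typing import Dict, List, Set, Tuple

ChainPos = Tuple[int, int]  # (chain ID, sequence number)
Link = Tuple[int, int, int, int]  # (origin chain, origin seq, target chain, target seq)


def auth_chain_difference(
    chain_info: Dict[str, ChainPos],
    links: List[Link],
    state_sets: List[Set[str]],
) -> Set[str]:
    # For each state set: chain ID -> max sequence number reachable from it.
    set_to_chain: List[Dict[int, int]] = []
    for state_set in state_sets:
        chains: Dict[int, int] = {}
        for event_id in state_set:
            chain_id, seq_no = chain_info[event_id]
            chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
        set_to_chain.append(chains)

    # Follow links: a link applies to a state set if that set reaches at
    # least the link's origin sequence number in the origin chain.
    seen_chains = {c for chains in set_to_chain for c in chains}
    for origin_chain, origin_seq, target_chain, target_seq in links:
        for chains in set_to_chain:
            if origin_seq <= chains.get(origin_chain, 0):
                chains[target_chain] = max(target_seq, chains.get(target_chain, 0))
                seen_chains.add(target_chain)

    # An event is in the difference if some set reaches it but not all do,
    # i.e. min reach < sequence number <= max reach within its chain.
    result: Set[str] = set()
    for event_id, (chain_id, seq_no) in chain_info.items():
        if chain_id not in seen_chains:
            continue
        reach = [chains.get(chain_id, 0) for chains in set_to_chain]
        if min(reach) < seq_no <= max(reach):
            result.add(event_id)
    return result


# Example data: chain 1 is A -> B -> C, chain 2 is X -> Y, and C (chain 1,
# seq 3) has X (chain 2, seq 1) in its auth chain.
example_chain_info = {
    "A": (1, 1), "B": (1, 2), "C": (1, 3),
    "X": (2, 1), "Y": (2, 2),
}
example_links = [(1, 3, 2, 1)]
example_difference = auth_chain_difference(
    example_chain_info, example_links, [{"C"}, {"B", "Y"}]
)
```

In the example, the first state set reaches A, B, C and X while the second reaches A, B, X and Y, so the events reachable from only one of them are C and Y. The code in the diff avoids scanning every event by first reusing events it has already loaded and then querying only the per-chain gaps.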

View file

@ -17,7 +17,17 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Tuple,
)
import attr
from prometheus_client import Counter
@ -33,9 +43,10 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter
from synapse.util.iterutils import batch_iter, sorted_topologically
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -89,6 +100,14 @@ class PersistEventsStore:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
def get_chain_id_txn(txn):
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]
self._event_chain_id_gen = build_sequence_generator(
db.engine, get_chain_id_txn, "event_auth_chain_id"
)
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@ -366,26 +385,7 @@ class PersistEventsStore:
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
# We want to store event_auth mappings for rejected events, as they're
# used in state res v2.
# This is only necessary if the rejected event appears in an accepted
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
{
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
}
for event, _ in events_and_contexts
for auth_id in event.auth_event_ids()
if event.is_state()
],
)
self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@ -407,6 +407,403 @@ class PersistEventsStore:
# room_memberships, where applicable.
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
def _persist_event_auth_chain_txn(
self, txn: LoggingTransaction, events: List[EventBase],
) -> None:
# We only care about state events, so skip this if there are no state events.
if not any(e.is_state() for e in events):
return
# We want to store event_auth mappings for rejected events, as they're
# used in state res v2.
# This is only necessary if the rejected event appears in an accepted
# event's auth chain, but it's easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
{
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
}
for event in events
for auth_id in event.auth_event_ids()
if event.is_state()
],
)
# We now calculate chain ID/sequence numbers for any state events we're
# persisting. We ignore out of band memberships as we're not in the room
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
rows = self.db_pool.simple_select_many_txn(
txn,
table="rooms",
column="room_id",
iterable={event.room_id for event in events if event.is_state()},
keyvalues={},
retcols=("room_id", "has_auth_chain_index"),
)
rooms_using_chain_index = {
row["room_id"] for row in rows if row["has_auth_chain_index"]
}
state_events = {
event.event_id: event
for event in events
if event.is_state() and event.room_id in rooms_using_chain_index
}
if not state_events:
return
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
# we don't need the overhead of fetching/parsing the full event JSON.
event_to_types = {
e.event_id: (e.type, e.state_key) for e in state_events.values()
}
event_to_auth_chain = {
e.event_id: e.auth_event_ids() for e in state_events.values()
}
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
self._add_chain_cover_index(
txn, event_to_room_id, event_to_types, event_to_auth_chain
)
def _add_chain_cover_index(
self,
txn,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
) -> None:
"""Calculate the chain cover index for the given events.
Args:
event_to_room_id: Event ID to the room ID of the event
event_to_types: Event ID to type and state_key of the event
event_to_auth_chain: Event ID to list of auth event IDs of the
event (events with no auth events can be excluded).
"""
# Map from event ID to chain ID/sequence number.
chain_map = {} # type: Dict[str, Tuple[int, int]]
# Set of event IDs to calculate chain ID/seq numbers for.
events_to_calc_chain_id_for = set(event_to_room_id)
# We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where
# we didn't have the auth chain when we first persisted.
rows = self.db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="room_id",
iterable=set(event_to_room_id.values()),
retcols=("event_id", "type", "state_key"),
)
for row in rows:
event_id = row["event_id"]
event_type = row["type"]
state_key = row["state_key"]
# (We could pull out the auth events for all rows at once using
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
auth_events = self.db_pool.simple_select_onecol_txn(
txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
)
events_to_calc_chain_id_for.add(event_id)
event_to_types[event_id] = (event_type, state_key)
event_to_auth_chain[event_id] = auth_events
# First we get the chain ID and sequence numbers for the events'
# auth events (that aren't also currently being persisted).
#
# Note that there is an edge case here where we might not have
# calculated chains and sequence numbers for events that were "out
# of band". We handle this case by fetching the necessary info and
# adding it to the set of events to calculate chain IDs for.
missing_auth_chains = {
a_id
for auth_events in event_to_auth_chain.values()
for a_id in auth_events
if a_id not in events_to_calc_chain_id_for
}
# We loop here in case we find an out of band membership and need to
# fetch their auth event info.
while missing_auth_chains:
sql = """
SELECT event_id, events.type, state_key, chain_id, sequence_number
FROM events
INNER JOIN state_events USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
WHERE
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", missing_auth_chains,
)
txn.execute(sql + clause, args)
missing_auth_chains.clear()
for auth_id, event_type, state_key, chain_id, sequence_number in txn:
event_to_types[auth_id] = (event_type, state_key)
if chain_id is None:
# No chain ID, so the event was persisted out of band.
# We add to list of events to calculate auth chains for.
events_to_calc_chain_id_for.add(auth_id)
event_to_auth_chain[
auth_id
] = self.db_pool.simple_select_onecol_txn(
txn,
"event_auth",
keyvalues={"event_id": auth_id},
retcol="auth_id",
)
missing_auth_chains.update(
e
for e in event_to_auth_chain[auth_id]
if e not in event_to_types
)
else:
chain_map[auth_id] = (chain_id, sequence_number)
# Now we check if we have any events where we don't have auth chain,
# this should only be out of band memberships.
for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
for auth_id in event_to_auth_chain[event_id]:
if (
auth_id not in chain_map
and auth_id not in events_to_calc_chain_id_for
):
events_to_calc_chain_id_for.discard(event_id)
# If this is an event we're trying to persist we add it to
# the list of events to calculate chain IDs for next time
# around. (Otherwise we will have already added it to the
# table).
room_id = event_to_room_id.get(event_id)
if room_id:
e_type, state_key = event_to_types[event_id]
self.db_pool.simple_insert_txn(
txn,
table="event_auth_chain_to_calculate",
values={
"event_id": event_id,
"room_id": room_id,
"type": e_type,
"state_key": state_key,
},
)
# We stop checking the event's auth events since we've
# discarded it.
break
if not events_to_calc_chain_id_for:
return
# We now calculate the chain IDs/sequence numbers for the events. We
# do this by looking at the chain ID and sequence number of any auth
# event with the same type/state_key and incrementing the sequence
# number by one. If there was no match or the chain ID/sequence
# number is already taken we generate a new chain.
#
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events
# before the event itself.
chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break
new_chain_tuple = None
if existing_chain_id:
# We found a chain ID/sequence number candidate, check it's
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
already_allocated = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_auth_chains",
keyvalues={
"chain_id": proposed_new_id,
"sequence_number": proposed_new_seq,
},
retcol="event_id",
allow_none=True,
)
if already_allocated:
# Mark it as already allocated so we don't need to hit
# the DB again.
chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
else:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
if not new_chain_tuple:
new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
chains_tuples_allocated.add(new_chain_tuple)
chain_map[event_id] = new_chain_tuple
new_chain_tuples[event_id] = new_chain_tuple
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
values=[
{"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
for event_id, (c_id, seq) in new_chain_tuples.items()
],
)
self.db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
iterable=new_chain_tuples,
)
# Now we need to calculate any new links between chains caused by
# the new events.
#
# Links are pairs of chain ID/sequence numbers such that for any
# event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
# if and only if there is at least one link (CA, S1) -> (CB, S2)
# where SA >= S1 and S2 >= SB.
#
# We try and avoid adding redundant links to the table, e.g. if we
# have two links between two chains which both start/end at the
# same sequence number (or cross) then one can be safely dropped.
#
# To calculate new links we look at every new event and:
# 1. Fetch the chain ID/sequence numbers of its auth events,
# discarding any that are reachable by other auth events, or
# that have the same chain ID as the event.
# 2. For each retained auth event we:
# a. Add a link from the event's to the auth event's chain
# ID/sequence number; and
# b. Add a link from the event to every chain reachable by the
# auth event.
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
rows = self.db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_links",
column="origin_chain_id",
iterable={chain_id for chain_id, _ in chain_map.values()},
keyvalues={},
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
),
)
for row in rows:
chain_links.add_link(
(row["origin_chain_id"], row["origin_sequence_number"]),
(row["target_chain_id"], row["target_sequence_number"]),
new=False,
)
# We do this in topological order to avoid adding redundant links.
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
chain_id, sequence_number = chain_map[event_id]
# Filter out auth events that are reachable by other auth
# events. We do this by looking at every permutation of pairs of
# auth events (A, B) to check if B is reachable from A.
reduction = {
a_id
for a_id in event_to_auth_chain.get(event_id, [])
if chain_map[a_id][0] != chain_id
}
for start_auth_id, end_auth_id in itertools.permutations(
event_to_auth_chain.get(event_id, []), r=2,
):
if chain_links.exists_path_from(
chain_map[start_auth_id], chain_map[end_auth_id]
):
reduction.discard(end_auth_id)
# Step 2, figure out what the new links are from the reduced
# list of auth events.
for auth_id in reduction:
auth_chain_id, auth_sequence_number = chain_map[auth_id]
# Step 2a, add link between the event and auth event
chain_links.add_link(
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)
# Step 2b, add a link to chains reachable from the auth
# event.
for target_id, target_seq in chain_links.get_links_from(
(auth_chain_id, auth_sequence_number)
):
if target_id == chain_id:
continue
chain_links.add_link(
(chain_id, sequence_number), (target_id, target_seq)
)
self.db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
values=[
{
"origin_chain_id": source_id,
"origin_sequence_number": source_seq,
"target_chain_id": target_id,
"target_sequence_number": target_seq,
}
for (
source_id,
source_seq,
target_id,
target_seq,
) in chain_links.get_additions()
],
)
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
@ -799,7 +1196,8 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
def _store_event_txn(self, txn, events_and_contexts):
"""Insert new events into the event and event_json tables
"""Insert new events into the event, event_json, redaction and
state_events tables.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
@ -871,6 +1269,29 @@ class PersistEventsStore:
updatevalues={"have_censored": False},
)
state_events_and_contexts = [
ec for ec in events_and_contexts if ec[0].is_state()
]
state_values = []
for event, context in state_events_and_contexts:
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
# TODO: How does this work with backfilling?
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
state_values.append(vals)
self.db_pool.simple_insert_many_txn(
txn, table="state_events", values=state_values
)
def _store_rejected_events_txn(self, txn, events_and_contexts):
"""Add rows to the 'rejections' table for received events which were
rejected
@ -987,29 +1408,6 @@ class PersistEventsStore:
txn, [event for event, _ in events_and_contexts]
)
state_events_and_contexts = [
ec for ec in events_and_contexts if ec[0].is_state()
]
state_values = []
for event, context in state_events_and_contexts:
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
# TODO: How does this work with backfilling?
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
state_values.append(vals)
self.db_pool.simple_insert_many_txn(
txn, table="state_events", values=state_values
)
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@ -1520,3 +1918,131 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier()
],
)
@attr.s(slots=True)
class _LinkMap:
"""A helper type for tracking links between chains.
"""
# Stores the set of links as nested maps: source chain ID -> target chain ID
# -> source sequence number -> target sequence number.
maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
# Stores the links that have been added (with new set to true), as tuples of
# `(source chain ID, source sequence no, target chain ID, target sequence no.)`
additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
def add_link(
self,
src_tuple: Tuple[int, int],
target_tuple: Tuple[int, int],
new: bool = True,
) -> bool:
"""Add a new link between two chains, ensuring no redundant links are added.
New links should be added in topological order.
Args:
src_tuple: The chain ID/sequence number of the source of the link.
target_tuple: The chain ID/sequence number of the target of the link.
new: Whether this is a "new" link, i.e. should it be returned
by `get_additions`.
Returns:
True if a link was added, false if the given link was dropped as redundant
"""
src_chain, src_seq = src_tuple
target_chain, target_seq = target_tuple
current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
assert src_chain != target_chain
if new:
# Check if the new link is redundant
for current_seq_src, current_seq_target in current_links.items():
# If a link "crosses" another link then it's redundant. For example
# in the following link 1 (L1) is redundant, as any event reachable
# via L1 is *also* reachable via L2.
#
# Chain A Chain B
# | |
# L1 |------ |
# | | |
# L2 |---- | -->|
# | | |
# | |--->|
# | |
# | |
#
# So we only need to keep links which *do not* cross, i.e. links
# that both start and end above or below an existing link.
#
# Note, since we add links in topological ordering we should never
# see `src_seq` less than `current_seq_src`.
if current_seq_src <= src_seq and target_seq <= current_seq_target:
# This new link is redundant, nothing to do.
return False
self.additions.add((src_chain, src_seq, target_chain, target_seq))
current_links[src_seq] = target_seq
return True
def get_links_from(
self, src_tuple: Tuple[int, int]
) -> Generator[Tuple[int, int], None, None]:
"""Gets the chains reachable from the given chain/sequence number.
Yields:
The chain ID and sequence number the link points to.
"""
src_chain, src_seq = src_tuple
for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
for link_src_seq, target_seq in sequence_numbers.items():
if link_src_seq <= src_seq:
yield target_id, target_seq
def get_links_between(
self, source_chain: int, target_chain: int
) -> Generator[Tuple[int, int], None, None]:
"""Gets the links between two chains.
Yields:
The source and target sequence numbers.
"""
yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
"""Gets any newly added links.
Yields:
The source chain ID/sequence number and target chain ID/sequence number
"""
for src_chain, src_seq, target_chain, _ in self.additions:
target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
if target_seq is not None:
yield (src_chain, src_seq, target_chain, target_seq)
def exists_path_from(
self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
) -> bool:
"""Checks if there is a path between the source chain ID/sequence and
target chain ID/sequence.
"""
src_chain, src_seq = src_tuple
target_chain, target_seq = target_tuple
if src_chain == target_chain:
return target_seq <= src_seq
links = self.get_links_between(src_chain, target_chain)
for link_start_seq, link_end_seq in links:
if link_start_seq <= src_seq and target_seq <= link_end_seq:
return True
return False
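
The crossing rule used by `_LinkMap.add_link` can be tried in isolation. The following is a minimal sketch, not the Synapse implementation: `is_redundant` is a hypothetical helper name, and the links between one pair of chains are modelled as a plain dict of source sequence number to target sequence number, in the same shape as the innermost level of `maps` above.

```python
from typing import Dict


def is_redundant(existing: Dict[int, int], src_seq: int, target_seq: int) -> bool:
    """Return True if an existing link already covers the proposed one.

    An existing link (s, t) covers (src_seq, target_seq) when it starts at
    or before src_seq in the source chain and ends at or after target_seq
    in the target chain. This is the same comparison as in `add_link`.
    """
    return any(s <= src_seq and target_seq <= t for s, t in existing.items())


# Hypothetical existing link from chain A to chain B: A:2 -> B:5.
links = {2: 5}

covered = is_redundant(links, 3, 4)      # starts later, ends earlier: covered
not_covered = is_redundant(links, 3, 6)  # reaches further into chain B: kept
too_early = is_redundant(links, 1, 4)    # starts before the existing link: kept
```

`exists_path_from` applies the same comparison when the two chains differ, which is why a link that is dropped as redundant does not change which events are reachable.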

View file

@ -14,10 +14,15 @@
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -99,6 +104,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
columns=["user_id", "created_ts"],
)
self.db_pool.updates.register_background_update_handler(
"rejected_events_metadata", self._rejected_events_metadata,
)
self.db_pool.updates.register_background_update_handler(
"chain_cover", self._chain_cover_index,
)
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@ -582,3 +595,302 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
await self.db_pool.updates._end_background_update("event_store_labels")
return num_rows
async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int:
"""Adds rejected events to the `state_events` and `event_auth` metadata
tables.
"""
last_event_id = progress.get("last_event_id", "")
def get_rejected_events(
txn: Cursor,
) -> List[Tuple[str, str, JsonDict, bool, bool]]:
# Fetch rejected event json, their room version and whether we have
# inserted them into the state_events or event_auth tables.
#
# Note we can assume that events that don't have a corresponding
# room version are V1 rooms.
sql = """
SELECT DISTINCT
event_id,
COALESCE(room_version, '1'),
json,
state_events.event_id IS NOT NULL,
event_auth.event_id IS NOT NULL
FROM rejections
INNER JOIN event_json USING (event_id)
LEFT JOIN rooms USING (room_id)
LEFT JOIN state_events USING (event_id)
LEFT JOIN event_auth USING (event_id)
WHERE event_id > ?
ORDER BY event_id
LIMIT ?
"""
txn.execute(sql, (last_event_id, batch_size,))
return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
results = await self.db_pool.runInteraction(
desc="_rejected_events_metadata_get", func=get_rejected_events
)
if not results:
await self.db_pool.updates._end_background_update(
"rejected_events_metadata"
)
return 0
state_events = []
auth_events = []
for event_id, room_version, event_json, has_state, has_event_auth in results:
last_event_id = event_id
if has_state and has_event_auth:
continue
room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version)
if not room_version_obj:
# We no longer support this room version, so we just ignore the
# events entirely.
logger.info(
"Ignoring event with unknown room version %r: %r",
room_version,
event_id,
)
continue
event = make_event_from_dict(event_json, room_version_obj)
if not event.is_state():
continue
if not has_state:
state_events.append(
{
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
)
if not has_event_auth:
for auth_id in event.auth_event_ids():
auth_events.append(
{
"room_id": event.room_id,
"event_id": event.event_id,
"auth_id": auth_id,
}
)
if state_events:
await self.db_pool.simple_insert_many(
table="state_events",
values=state_events,
desc="_rejected_events_metadata_state_events",
)
if auth_events:
await self.db_pool.simple_insert_many(
table="event_auth",
values=auth_events,
desc="_rejected_events_metadata_event_auth",
)
await self.db_pool.updates._background_update_progress(
"rejected_events_metadata", {"last_event_id": last_event_id}
)
if len(results) < batch_size:
await self.db_pool.updates._end_background_update(
"rejected_events_metadata"
)
return len(results)
async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
"""A background update that iterates over all rooms and generates the
chain cover index for them.
"""
current_room_id = progress.get("current_room_id", "")
# Have we finished processing the current room?
finished = progress.get("finished", True)
# Where we've processed up to in the room, defaults to the start of the
# room.
last_depth = progress.get("last_depth", -1)
last_stream = progress.get("last_stream", -1)
# Have we set the `has_auth_chain_index` for the room yet?
has_set_room_has_chain_index = progress.get(
"has_set_room_has_chain_index", False
)
if finished:
# If we've finished with the previous room (or it's our first
# iteration) we move on to the next room.
def _get_next_room(txn: Cursor) -> Optional[str]:
sql = """
SELECT room_id FROM rooms
WHERE room_id > ?
AND (
NOT has_auth_chain_index
OR has_auth_chain_index IS NULL
)
ORDER BY room_id
LIMIT 1
"""
txn.execute(sql, (current_room_id,))
row = txn.fetchone()
if row:
return row[0]
return None
current_room_id = await self.db_pool.runInteraction(
"_chain_cover_index", _get_next_room
)
if not current_room_id:
await self.db_pool.updates._end_background_update("chain_cover")
return 0
logger.debug("Adding chain cover to %s", current_room_id)
def _calculate_auth_chain(
txn: Cursor, last_depth: int, last_stream: int
) -> Tuple[int, int, int]:
# Get the next set of events in the room (that we haven't already
# computed chain cover for). We do this in topological order.
# We want to do a `(topological_ordering, stream_ordering) > (?,?)`
# comparison, but that is not supported on older SQLite versions
tuple_clause, tuple_args = make_tuple_comparison_clause(
self.database_engine,
[
("topological_ordering", last_depth),
("stream_ordering", last_stream),
],
)
sql = """
SELECT
event_id, state_events.type, state_events.state_key,
topological_ordering, stream_ordering
FROM events
INNER JOIN state_events USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
LEFT JOIN event_auth_chain_to_calculate USING (event_id)
WHERE events.room_id = ?
AND event_auth_chains.event_id IS NULL
AND event_auth_chain_to_calculate.event_id IS NULL
AND %(tuple_cmp)s
ORDER BY topological_ordering, stream_ordering
LIMIT ?
""" % {
"tuple_cmp": tuple_clause,
}
args = [current_room_id]
args.extend(tuple_args)
args.append(batch_size)
txn.execute(sql, args)
rows = txn.fetchall()
# Put the results in the necessary format for
# `_add_chain_cover_index`
event_to_room_id = {row[0]: current_room_id for row in rows}
event_to_types = {row[0]: (row[1], row[2]) for row in rows}
new_last_depth = rows[-1][3] if rows else last_depth # type: int
new_last_stream = rows[-1][4] if rows else last_stream # type: int
count = len(rows)
# We also need to fetch the auth events for them.
auth_events = self.db_pool.simple_select_many_txn(
txn,
table="event_auth",
column="event_id",
iterable=event_to_room_id,
keyvalues={},
retcols=("event_id", "auth_id"),
)
event_to_auth_chain = {} # type: Dict[str, List[str]]
for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append(
row["auth_id"]
)
# Calculate and persist the chain cover index for this set of events.
#
# Annoyingly we need to gut wrench into the persist event store so that
# we can reuse the function to calculate the chain cover for rooms.
self.hs.get_datastores().persist_events._add_chain_cover_index(
txn, event_to_room_id, event_to_types, event_to_auth_chain,
)
return new_last_depth, new_last_stream, count
last_depth, last_stream, count = await self.db_pool.runInteraction(
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
)
total_rows_processed = count
if count < batch_size and not has_set_room_has_chain_index:
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that it's possible for
# further events to be persisted between the above and setting the
# flag without having the chain cover calculated for them. This is
# fine as a) the code gracefully handles these cases and b) we'll
# calculate them below.
await self.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": current_room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
has_set_room_has_chain_index = True
# Handle any events that might have raced with us flipping the
# bit above.
last_depth, last_stream, count = await self.db_pool.runInteraction(
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
)
total_rows_processed += count
# Note that at this point it's technically possible that more events
# than our `batch_size` have been persisted without their chain
# cover, so we need to continue processing this room if the last
# count returned was equal to the `batch_size`.
if count < batch_size:
# We've finished calculating the index for this room, move on to the
# next room.
await self.db_pool.updates._background_update_progress(
"chain_cover", {"current_room_id": current_room_id, "finished": True},
)
else:
# We still have outstanding events to calculate the index for.
await self.db_pool.updates._background_update_progress(
"chain_cover",
{
"current_room_id": current_room_id,
"last_depth": last_depth,
"last_stream": last_stream,
"has_set_room_has_chain_index": has_set_room_has_chain_index,
"finished": False,
},
)
return total_rows_processed
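
The pagination in `_calculate_auth_chain` relies on a row-value comparison, `(topological_ordering, stream_ordering) > (?, ?)`, that older SQLite versions cannot parse. A portable equivalent expands it into nested `OR`/`AND` terms. The sketch below illustrates that expansion only; `tuple_gt_clause` is a hypothetical helper and is not claimed to match how `make_tuple_comparison_clause` is written. The table `ev` and its rows are made up for the demonstration.

```python
import sqlite3
from typing import Any, List, Tuple


def tuple_gt_clause(columns: List[Tuple[str, Any]]) -> Tuple[str, List[Any]]:
    """Build a portable equivalent of `(c1, c2, ...) > (?, ?, ...)`.

    Hypothetical helper for illustration. Column names are interpolated
    directly, so they must come from trusted code, never from user input.
    """
    name, value = columns[0]
    if len(columns) == 1:
        return "%s > ?" % (name,), [value]
    rest_clause, rest_args = tuple_gt_clause(columns[1:])
    clause = "(%s > ? OR (%s = ? AND %s))" % (name, name, rest_clause)
    return clause, [value, value] + rest_args


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE ev (topological_ordering INT, stream_ordering INT)")
conn.executemany("INSERT INTO ev VALUES (?, ?)", [(1, 5), (2, 1), (2, 3), (3, 0)])

# Resume after the row (topological_ordering=2, stream_ordering=1).
clause, args = tuple_gt_clause([("topological_ordering", 2), ("stream_ordering", 1)])
rows = conn.execute(
    "SELECT topological_ordering, stream_ordering FROM ev WHERE %s"
    " ORDER BY topological_ordering, stream_ordering" % (clause,),
    args,
).fetchall()
```

Only rows strictly after the last processed position are returned, so each batch picks up where the previous one stopped.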

View file

@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_avatar_url(
self, user_localpart: str, new_avatar_url: str
self, user_localpart: str, new_avatar_url: Optional[str]
) -> None:
await self.db_pool.simple_update_one(
table="profiles",

View file

@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore):
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
desc="get_room",
allow_none=True,
)
@ -1166,6 +1166,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
# It's overridden by RoomStore for the synapse master.
raise NotImplementedError()
async def has_auth_chain_index(self, room_id: str) -> bool:
"""Check if the room has (or can have) a chain cover index.
Defaults to True if we don't have an entry in `rooms` table nor any
events for the room.
"""
has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
table="rooms",
keyvalues={"room_id": room_id},
retcol="has_auth_chain_index",
desc="has_auth_chain_index",
allow_none=True,
)
if has_auth_chain_index:
return True
# It's possible that we already have events for the room in our DB
# without a corresponding room entry. If we do then we don't want to
# mark the room as having an auth chain cover index.
max_ordering = await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"room_id": room_id},
retcol="MAX(stream_ordering)",
allow_none=True,
desc="has_auth_chain_index",
)
return max_ordering is None
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@ -1179,12 +1210,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
Called when we join a room over federation, and overwrites any room version
currently in the table.
"""
# It's possible that we already have events for the room in our DB
# without a corresponding room entry. If we do then we don't want to
# mark the room as having an auth chain cover index.
has_auth_chain_index = await self.has_auth_chain_index(room_id)
await self.db_pool.simple_upsert(
desc="upsert_room_on_join",
table="rooms",
keyvalues={"room_id": room_id},
values={"room_version": room_version.identifier},
insertion_values={"is_public": False, "creator": ""},
insertion_values={
"is_public": False,
"creator": "",
"has_auth_chain_index": has_auth_chain_index,
},
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.
lock=False,
@ -1219,6 +1259,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"creator": room_creator_user_id,
"is_public": is_public,
"room_version": room_version.identifier,
"has_auth_chain_index": True,
},
)
if is_public:
@ -1247,6 +1288,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
When we receive an invite or any other event over federation that may relate to a room
we are not in, store the version of the room if we don't already know the room version.
"""
# It's possible that we already have events for the room in our DB
# without a corresponding room entry. If we do then we don't want to
# mark the room as having an auth chain cover index.
has_auth_chain_index = await self.has_auth_chain_index(room_id)
await self.db_pool.simple_upsert(
desc="maybe_store_room_on_outlier_membership",
table="rooms",
@ -1256,6 +1302,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"room_version": room_version.identifier,
"is_public": False,
"creator": "",
"has_auth_chain_index": has_auth_chain_index,
},
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.
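
The decision made by `has_auth_chain_index` can be restated as a pure function of the two values it reads from the database. This is a sketch for reasoning about the cases only; `can_mark_chain_index` and its parameter names are hypothetical and do not exist in Synapse.

```python
from typing import Optional


def can_mark_chain_index(
    flag: Optional[bool], max_stream_ordering: Optional[int]
) -> bool:
    """Mirror the two-step check in `has_auth_chain_index`.

    flag: the `has_auth_chain_index` column, or None if the room has no row
        (or the column is NULL).
    max_stream_ordering: MAX(stream_ordering) over the room's events, or
        None if no events are stored for the room.
    """
    if flag:
        return True
    # Without the flag, the room only qualifies if there are no events
    # that were persisted before the index existed.
    return max_stream_ordering is None


already_indexed = can_mark_chain_index(True, 10)
brand_new_room = can_mark_chain_index(None, None)
events_without_room_row = can_mark_chain_index(None, 42)
```

The third case is the one the comments above describe: events exist without a corresponding `rooms` entry, so the room must not be marked as indexed until the background update has processed it.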

View file

@ -0,0 +1,16 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE access_tokens DROP COLUMN last_used;

View file

@ -0,0 +1,62 @@
/*
* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Dropping last_used column from access_tokens table.
CREATE TABLE access_tokens2 (
id BIGINT PRIMARY KEY,
user_id TEXT NOT NULL,
device_id TEXT,
token TEXT NOT NULL,
valid_until_ms BIGINT,
puppets_user_id TEXT,
last_validated BIGINT,
UNIQUE(token)
);
INSERT INTO access_tokens2(id, user_id, device_id, token)
SELECT id, user_id, device_id, token FROM access_tokens;
DROP TABLE access_tokens;
ALTER TABLE access_tokens2 RENAME TO access_tokens;
CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id);
-- Re-adding foreign key reference in event_txn_id table
CREATE TABLE event_txn_id2 (
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
token_id BIGINT NOT NULL,
txn_id TEXT NOT NULL,
inserted_ts BIGINT NOT NULL,
FOREIGN KEY (event_id)
REFERENCES events (event_id) ON DELETE CASCADE,
FOREIGN KEY (token_id)
REFERENCES access_tokens (id) ON DELETE CASCADE
);
INSERT INTO event_txn_id2(event_id, room_id, user_id, token_id, txn_id, inserted_ts)
SELECT event_id, room_id, user_id, token_id, txn_id, inserted_ts FROM event_txn_id;
DROP TABLE event_txn_id;
ALTER TABLE event_txn_id2 RENAME TO event_txn_id;
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
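
The SQLite delta above follows the usual rebuild pattern for dropping a column on SQLite versions without `ALTER TABLE ... DROP COLUMN`: create a replacement table, copy the rows, drop the original, rename, then recreate the indexes. A minimal sketch of the same pattern using Python's `sqlite3` module follows; the `tokens` table and its columns are hypothetical and much smaller than `access_tokens`.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE tokens (id BIGINT PRIMARY KEY, token TEXT NOT NULL, last_used BIGINT);
    CREATE INDEX tokens_token ON tokens (token);
    INSERT INTO tokens VALUES (1, 'abc', 100), (2, 'def', 200);

    -- Rebuild the table without the unwanted column.
    CREATE TABLE tokens2 (id BIGINT PRIMARY KEY, token TEXT NOT NULL);
    INSERT INTO tokens2 (id, token) SELECT id, token FROM tokens;
    DROP TABLE tokens;
    ALTER TABLE tokens2 RENAME TO tokens;

    -- Dropping the old table also dropped its index, so recreate it.
    CREATE INDEX tokens_token ON tokens (token);
    """
)

columns = [row[1] for row in conn.execute("PRAGMA table_info(tokens)")]
rows = conn.execute("SELECT id, token FROM tokens ORDER BY id").fetchall()
```

Every column that should survive has to appear in the copying `INSERT ... SELECT`; any column left out of that list comes back as NULL (or its default) in the rebuilt table.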

View file

@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5828, 'rejected_events_metadata', '{}');

View file

@ -0,0 +1,52 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- See docs/auth_chain_difference_algorithm.md
CREATE TABLE event_auth_chains (
event_id TEXT PRIMARY KEY,
chain_id BIGINT NOT NULL,
sequence_number BIGINT NOT NULL
);
CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number);
CREATE TABLE event_auth_chain_links (
origin_chain_id BIGINT NOT NULL,
origin_sequence_number BIGINT NOT NULL,
target_chain_id BIGINT NOT NULL,
target_sequence_number BIGINT NOT NULL
);
CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id);
-- Events that we have persisted but not calculated auth chains for,
-- e.g. out of band memberships (where we don't have the auth chain)
CREATE TABLE event_auth_chain_to_calculate (
event_id TEXT PRIMARY KEY,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL
);
CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id);
-- Whether we've calculated the above index for a room.
ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN;

View file

@ -0,0 +1,16 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id;

View file

@ -0,0 +1,17 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- This is no longer used and was only kept until we bumped the schema version.
DROP TABLE IF EXISTS account_data_max_stream_id;

View file

@ -0,0 +1,17 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- This is no longer used and was only kept until we bumped the schema version.
DROP TABLE IF EXISTS cache_invalidation_stream;

View file

@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
(5906, 'chain_cover', '{}', 'rejected_events_metadata');

View file

@ -255,16 +255,6 @@ class TagsStore(TagsWorkerStore):
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
         )
-        # Note: This is only here for backwards compat to allow admins to
-        # roll back to a previous Synapse version. Next time we update the
-        # database version we can remove this table.
-        update_max_id_sql = (
-            "UPDATE account_data_max_stream_id"
-            " SET stream_id = ?"
-            " WHERE stream_id < ?"
-        )
-        txn.execute(update_max_id_sql, (next_id, next_id))
         update_sql = (
             "UPDATE room_tags_revisions"
             " SET stream_id = ?"

View file

@ -35,9 +35,6 @@ logger = logging.getLogger(__name__)
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-# XXX: If you're about to bump this to 59 (or higher) please create an update
-# that drops the unused `cache_invalidation_stream` table, as per #7436!
-# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
 SCHEMA_VERSION = 59
 dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -37,6 +37,7 @@ from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
 from synapse.api.errors import Codes, SynapseError
+from synapse.http.endpoint import parse_and_validate_server_name
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationService
@ -257,8 +258,13 @@ class DomainSpecificString(
     @classmethod
     def is_valid(cls: Type[DS], s: str) -> bool:
         """Parses the input string and attempts to ensure it is valid."""
         try:
-            cls.from_string(s)
+            obj = cls.from_string(s)
+            # Apply additional validation to the domain. This is only done
+            # during is_valid (and not part of from_string) since it is
+            # possible for invalid data to exist in room-state, etc.
+            parse_and_validate_server_name(obj.domain)
             return True
         except Exception:
             return False
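
The split between lenient parsing and strict validation in `is_valid` can be sketched in isolation. The following is a hypothetical stand-in, not Synapse's code: `ExampleID` and `validate_server_name` are invented for illustration, and the hostname check is a deliberately simplified regular expression rather than the rules applied by `parse_and_validate_server_name`.

```python
import re
from typing import NamedTuple

# Simplified, hypothetical rule: letters, digits, dots and hyphens, with an
# optional numeric port. The real server-name validation is more involved.
_SERVER_NAME_RE = re.compile(r"^[A-Za-z0-9.-]+(:[0-9]+)?$")


def validate_server_name(server_name: str) -> None:
    """Raise ValueError if the server name fails the simplified check."""
    if not _SERVER_NAME_RE.match(server_name):
        raise ValueError("Invalid server name %r" % (server_name,))


class ExampleID(NamedTuple):
    localpart: str
    domain: str

    @classmethod
    def from_string(cls, s: str) -> "ExampleID":
        # Lenient parsing: only the overall "@localpart:domain" shape is
        # checked, so identifiers already stored with odd domains still load.
        if not s.startswith("@") or ":" not in s:
            raise ValueError("Expected '@localpart:domain', got %r" % (s,))
        localpart, domain = s[1:].split(":", 1)
        return cls(localpart, domain)

    @classmethod
    def is_valid(cls, s: str) -> bool:
        # Strict path: parse first, then validate the domain separately.
        try:
            obj = cls.from_string(s)
            validate_server_name(obj.domain)
            return True
        except Exception:
            return False
```

Under these assumptions, `ExampleID.from_string("@alice:bad domain!")` still returns an object, while `ExampleID.is_valid("@alice:bad domain!")` returns `False`. Keeping the strict check out of `from_string` is what lets existing data with invalid domains continue to parse.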

View file

@ -13,8 +13,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import heapq
 from itertools import islice
-from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+from typing import (
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    Mapping,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+)
+from synapse.types import Collection
 T = TypeVar("T")
@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
     If the input is empty, no chunks are returned.
     """
     return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+def sorted_topologically(
+    nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+) -> Generator[T, None, None]:
+    """Given a set of nodes and a graph, yield the nodes in topological order.
+    For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
+    """
+    # This is implemented by Kahn's algorithm.
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph = {}  # type: Dict[T, Set[T]]
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+        for edge in edges:
+            if edge in degree_map:
+                degree_map[node] += 1
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+    heapq.heapify(zero_degree)
+    while zero_degree:
+        node = heapq.heappop(zero_degree)
+        yield node
+        for edge in reverse_graph[node]:
+            if edge in degree_map:
+                degree_map[edge] -= 1
+                if degree_map[edge] == 0:
+                    heapq.heappush(zero_degree, edge)
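
A minimal, self-contained sketch of the same Kahn's-algorithm approach follows, so the ordering can be tried outside Synapse. It is an illustration rather than the code from this commit: it uses `typing.Collection` in place of `synapse.types.Collection`, de-duplicates each edge list with `set(edges)`, and looks up dependents with `.get()` so that nodes which never appear in `graph` are tolerated. The last two choices are assumptions added here.

```python
import heapq
from typing import Collection, Dict, Generator, Iterable, Mapping, Set, TypeVar

T = TypeVar("T")


def sorted_topologically(
    nodes: Iterable[T], graph: Mapping[T, Collection[T]]
) -> Generator[T, None, None]:
    """Yield `nodes` so that each node comes after the nodes it maps to."""
    # Number of not-yet-emitted dependencies for each node of interest.
    degree_map = {node: 0 for node in nodes}
    # Maps a dependency to the nodes that depend on it.
    reverse_graph = {}  # type: Dict[T, Set[T]]

    for node, edges in graph.items():
        if node not in degree_map:
            continue
        for edge in set(edges):
            # Only count dependencies that are themselves being sorted.
            if edge in degree_map:
                degree_map[node] += 1
                reverse_graph.setdefault(edge, set()).add(node)

    # A heap makes the choice among ready nodes deterministic (smallest first).
    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
    heapq.heapify(zero_degree)

    while zero_degree:
        node = heapq.heappop(zero_degree)
        yield node
        for dependent in reverse_graph.get(node, ()):
            degree_map[dependent] -= 1
            if degree_map[dependent] == 0:
                heapq.heappush(zero_degree, dependent)


print(list(sorted_topologically([1, 2], {1: [2]})))     # [2, 1]
print(list(sorted_topologically([3, 1, 2], {1: [2]})))  # [2, 1, 3]
```

In this sketch, nodes that sit on a cycle never reach degree zero and are therefore never yielded; callers that need to detect cycles can compare the number of yielded nodes with the number of input nodes.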

Some files were not shown because too many files have changed in this diff