Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2020-05-29 13:27:12 +01:00
commit 6fdf5ef66b
12 changed files with 138 additions and 102 deletions

1
changelog.d/7385.feature Normal file
View file

@ -0,0 +1 @@
For SAML authentication, add the ability to pass email addresses to be added to new users' accounts via SAML attributes. Contributed by Christopher Cooper.

1
changelog.d/7561.misc Normal file
View file

@ -0,0 +1 @@
Convert the identity handler to async/await.

1
changelog.d/7575.bugfix Normal file
View file

@ -0,0 +1 @@
Fix str placeholders in an instance of `PrepareDatabaseException`. Introduced in Synapse v1.8.0.

1
changelog.d/7591.misc Normal file
View file

@ -0,0 +1 @@
Add comment to systemd example to show postgresql dependency.

1
changelog.d/7597.bugfix Normal file
View file

@ -0,0 +1 @@
Fix metrics failing when there is a large number of active background processes.

View file

@ -15,6 +15,9 @@
[Unit]
Description=Synapse Matrix homeserver
# If you are using postgresql to persist data, uncomment this line to make sure
# synapse starts after the postgresql service.
# After=postgresql.service
[Service]
Type=notify

View file

@ -138,6 +138,8 @@ A custom mapping provider must specify the following methods:
* `mxid_localpart` - Required. The mxid localpart of the new user.
* `displayname` - The displayname of the new user. If not provided, will default to
the value of `mxid_localpart`.
* `emails` - A list of emails for the new user. If not provided, will
default to an empty list.
### Default SAML Mapping Provider

View file

@ -25,7 +25,6 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from twisted.internet import defer
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
@ -60,8 +59,7 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_http_client()
self.hs = hs
@defer.inlineCallbacks
def threepid_from_creds(self, id_server, creds):
async def threepid_from_creds(self, id_server, creds):
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
@ -97,7 +95,7 @@ class IdentityHandler(BaseHandler):
url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
try:
data = yield self.http_client.get_json(url, query_params)
data = await self.http_client.get_json(url, query_params)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
@ -120,8 +118,7 @@ class IdentityHandler(BaseHandler):
logger.info("%s reported non-validated threepid: %s", id_server, creds)
return None
@defer.inlineCallbacks
def bind_threepid(
async def bind_threepid(
self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
):
"""Bind a 3PID to an identity server
@ -161,12 +158,12 @@ class IdentityHandler(BaseHandler):
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
data = yield self.blacklisting_http_client.post_json_get_json(
data = await self.blacklisting_http_client.post_json_get_json(
bind_url, bind_data, headers=headers
)
# Remember where we bound the threepid
yield self.store.add_user_bound_threepid(
await self.store.add_user_bound_threepid(
user_id=mxid,
medium=data["medium"],
address=data["address"],
@ -185,13 +182,12 @@ class IdentityHandler(BaseHandler):
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
res = yield self.bind_threepid(
res = await self.bind_threepid(
client_secret, sid, mxid, id_server, id_access_token, use_v2=False
)
return res
@defer.inlineCallbacks
def try_unbind_threepid(self, mxid, threepid):
async def try_unbind_threepid(self, mxid, threepid):
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on
@ -211,7 +207,7 @@ class IdentityHandler(BaseHandler):
if threepid.get("id_server"):
id_servers = [threepid["id_server"]]
else:
id_servers = yield self.store.get_id_servers_user_bound(
id_servers = await self.store.get_id_servers_user_bound(
user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)
@ -221,14 +217,13 @@ class IdentityHandler(BaseHandler):
changed = True
for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server(
changed &= await self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server
)
return changed
@defer.inlineCallbacks
def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
"""Removes a binding from an identity server
Args:
@ -266,7 +261,7 @@ class IdentityHandler(BaseHandler):
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
yield self.blacklisting_http_client.post_json_get_json(
await self.blacklisting_http_client.post_json_get_json(
url, content, headers
)
changed = True
@ -281,7 +276,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
yield self.store.remove_user_bound_threepid(
await self.store.remove_user_bound_threepid(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
@ -376,8 +371,7 @@ class IdentityHandler(BaseHandler):
return session_id
@defer.inlineCallbacks
def requestEmailToken(
async def requestEmailToken(
self, id_server, email, client_secret, send_attempt, next_link=None
):
"""
@ -412,7 +406,7 @@ class IdentityHandler(BaseHandler):
)
try:
data = yield self.http_client.post_json_get_json(
data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
params,
)
@ -423,8 +417,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@defer.inlineCallbacks
def requestMsisdnToken(
async def requestMsisdnToken(
self,
id_server,
country,
@ -466,7 +459,7 @@ class IdentityHandler(BaseHandler):
)
try:
data = yield self.http_client.post_json_get_json(
data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
params,
)
@ -487,8 +480,7 @@ class IdentityHandler(BaseHandler):
)
return data
@defer.inlineCallbacks
def validate_threepid_session(self, client_secret, sid):
async def validate_threepid_session(self, client_secret, sid):
"""Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally.
@ -510,12 +502,12 @@ class IdentityHandler(BaseHandler):
# Try to validate as email
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Ask our delegated email identity server
validation_session = yield self.threepid_from_creds(
validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details
validation_session = yield self.store.get_threepid_validation_session(
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)
@ -525,14 +517,13 @@ class IdentityHandler(BaseHandler):
# Try to validate as msisdn
if self.hs.config.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = yield self.threepid_from_creds(
validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
return validation_session
@defer.inlineCallbacks
def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
"""Proxy a POST submitToken request to an identity server for verification purposes
Args:
@ -553,11 +544,9 @@ class IdentityHandler(BaseHandler):
body = {"client_secret": client_secret, "sid": sid, "token": token}
try:
return (
yield self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
body,
)
return await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
body,
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@ -565,8 +554,7 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
@defer.inlineCallbacks
def lookup_3pid(self, id_server, medium, address, id_access_token=None):
async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
Args:
@ -582,7 +570,7 @@ class IdentityHandler(BaseHandler):
"""
if id_access_token is not None:
try:
results = yield self._lookup_3pid_v2(
results = await self._lookup_3pid_v2(
id_server, id_access_token, medium, address
)
return results
@ -601,10 +589,9 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e)
return None
return (yield self._lookup_3pid_v1(id_server, medium, address))
return await self._lookup_3pid_v1(id_server, medium, address)
@defer.inlineCallbacks
def _lookup_3pid_v1(self, id_server, medium, address):
async def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
@ -617,7 +604,7 @@ class IdentityHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.blacklisting_http_client.get_json(
data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address},
)
@ -625,7 +612,7 @@ class IdentityHandler(BaseHandler):
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
yield self._verify_any_signature(data, id_server)
await self._verify_any_signature(data, id_server)
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@ -634,8 +621,7 @@ class IdentityHandler(BaseHandler):
return None
@defer.inlineCallbacks
def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
@ -650,7 +636,7 @@ class IdentityHandler(BaseHandler):
"""
# Check what hashing details are supported by this identity server
try:
hash_details = yield self.blacklisting_http_client.get_json(
hash_details = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
@ -717,7 +703,7 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
lookup_results = yield self.blacklisting_http_client.post_json_get_json(
lookup_results = await self.blacklisting_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
@ -745,13 +731,12 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
@defer.inlineCallbacks
def _verify_any_signature(self, data, server_hostname):
async def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
try:
key_data = yield self.blacklisting_http_client.get_json(
key_data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name)
)
@ -770,8 +755,7 @@ class IdentityHandler(BaseHandler):
)
return
@defer.inlineCallbacks
def ask_id_server_for_third_party_invite(
async def ask_id_server_for_third_party_invite(
self,
requester,
id_server,
@ -844,7 +828,7 @@ class IdentityHandler(BaseHandler):
# Attempt a v2 lookup
url = base_url + "/v2/store-invite"
try:
data = yield self.blacklisting_http_client.post_json_get_json(
data = await self.blacklisting_http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
@ -864,7 +848,7 @@ class IdentityHandler(BaseHandler):
url = base_url + "/api/v1/store-invite"
try:
data = yield self.blacklisting_http_client.post_json_get_json(
data = await self.blacklisting_http_client.post_json_get_json(
url, invite_config
)
except TimeoutError:
@ -882,7 +866,7 @@ class IdentityHandler(BaseHandler):
# types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170
try:
data = yield self.blacklisting_http_client.post_urlencoded_get_json(
data = await self.blacklisting_http_client.post_urlencoded_get_json(
url, invite_config
)
except HttpResponseException as e:

View file

@ -271,6 +271,7 @@ class SamlHandler:
raise SynapseError(500, "Error parsing SAML2 response")
displayname = attribute_dict.get("displayname")
emails = attribute_dict.get("emails", [])
# Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
@ -288,7 +289,9 @@ class SamlHandler:
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayname
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
)
await self._datastore.record_user_external_id(
@ -381,6 +384,7 @@ class DefaultSamlMappingProvider(object):
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user
* emails (list[str]): Any emails for the user
"""
try:
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
@ -403,9 +407,13 @@ class DefaultSamlMappingProvider(object):
# If displayname is None, the mxid_localpart will be used instead
displayname = saml_response.ava.get("displayName", [None])[0]
# Retrieve any emails present in the saml response
emails = saml_response.ava.get("email", [])
return {
"mxid_localpart": localpart,
"displayname": displayname,
"emails": emails,
}
@staticmethod
@ -444,4 +452,4 @@ class DefaultSamlMappingProvider(object):
second set consists of those attributes which can be used if
available, but are not necessary
"""
return {"uid", config.mxid_source_attribute}, {"displayName"}
return {"uid", config.mxid_source_attribute}, {"displayName", "email"}

View file

@ -138,8 +138,7 @@ class _BaseThreepidAuthChecker:
self.hs = hs
self.store = hs.get_datastore()
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict):
async def _check_threepid(self, medium, authdict):
if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
@ -155,18 +154,18 @@ class _BaseThreepidAuthChecker:
raise SynapseError(
400, "Phone number verification is not enabled on this homeserver"
)
threepid = yield identity_handler.threepid_from_creds(
threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds
)
elif medium == "email":
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email
threepid = yield identity_handler.threepid_from_creds(
threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds
)
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
threepid = None
row = yield self.store.get_threepid_validation_session(
row = await self.store.get_threepid_validation_session(
medium,
threepid_creds["client_secret"],
sid=threepid_creds["sid"],
@ -181,7 +180,7 @@ class _BaseThreepidAuthChecker:
}
# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
await self.store.delete_threepid_session(threepid_creds["sid"])
else:
raise SynapseError(
400, "Email address verification is not enabled on this homeserver"
@ -220,7 +219,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
)
def check_auth(self, authdict, clientip):
return self._check_threepid("email", authdict)
return defer.ensureDeferred(self._check_threepid("email", authdict))
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@ -234,7 +233,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return bool(self.hs.config.account_threepid_delegate_msisdn)
def check_auth(self, authdict, clientip):
return self._check_threepid("msisdn", authdict)
return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
INTERACTIVE_AUTH_CHECKERS = [

View file

@ -17,16 +17,18 @@ import logging
import threading
from asyncio import iscoroutine
from functools import wraps
from typing import Dict, Set
from typing import TYPE_CHECKING, Dict, Optional, Set
import six
from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
if TYPE_CHECKING:
import resource
logger = logging.getLogger(__name__)
@ -36,6 +38,12 @@ _background_process_start_count = Counter(
["name"],
)
_background_process_in_flight_count = Gauge(
"synapse_background_process_in_flight_count",
"Number of background processes in flight",
labelnames=["name"],
)
# we set registry=None in all of these to stop them getting registered with
# the default registry. Instead we collect them all via the CustomCollector,
# which ensures that we can update them before they are collected.
@ -83,13 +91,17 @@ _background_process_db_sched_duration = Counter(
# it's much simpler to do so than to try to combine them.)
_background_process_counts = {} # type: Dict[str, int]
# map from description to the currently running background processes.
# Set of all running background processes that became active active since the
# last time metrics were scraped (i.e. background processes that performed some
# work since the last scrape.)
#
# it's kept as a dict of sets rather than a big set so that we can keep track
# of process descriptions that no longer have any active processes.
_background_processes = {} # type: Dict[str, Set[_BackgroundProcess]]
# We do it like this to handle the case where we have a large number of
# background processes stacking up behind a lock or linearizer, where we then
# only need to iterate over and update metrics for the process that have
# actually been active and can ignore the idle ones.
_background_processes_active_since_last_scrape = set() # type: Set[_BackgroundProcess]
# A lock that covers the above dicts
# A lock that covers the above set and dict
_bg_metrics_lock = threading.Lock()
@ -101,25 +113,16 @@ class _Collector(object):
"""
def collect(self):
background_process_in_flight_count = GaugeMetricFamily(
"synapse_background_process_in_flight_count",
"Number of background processes in flight",
labels=["name"],
)
global _background_processes_active_since_last_scrape
# We copy the dict so that it doesn't change from underneath us.
# We also copy the process lists as that can also change
# We swap out the _background_processes set with an empty one so that
# we can safely iterate over the set without holding the lock.
with _bg_metrics_lock:
_background_processes_copy = {
k: list(v) for k, v in six.iteritems(_background_processes)
}
_background_processes_copy = _background_processes_active_since_last_scrape
_background_processes_active_since_last_scrape = set()
for desc, processes in six.iteritems(_background_processes_copy):
background_process_in_flight_count.add_metric((desc,), len(processes))
for process in processes:
process.update_metrics()
yield background_process_in_flight_count
for process in _background_processes_copy:
process.update_metrics()
# now we need to run collect() over each of the static Counters, and
# yield each metric they return.
@ -191,13 +194,10 @@ def run_as_background_process(desc, func, *args, **kwargs):
_background_process_counts[desc] = count + 1
_background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc()
with LoggingContext(desc) as context:
with BackgroundProcessLoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count)
proc = _BackgroundProcess(desc, context)
with _bg_metrics_lock:
_background_processes.setdefault(desc, set()).add(proc)
try:
result = func(*args, **kwargs)
@ -214,10 +214,7 @@ def run_as_background_process(desc, func, *args, **kwargs):
except Exception:
logger.exception("Background process '%s' threw an exception", desc)
finally:
proc.update_metrics()
with _bg_metrics_lock:
_background_processes[desc].remove(proc)
_background_process_in_flight_count.labels(desc).dec()
with PreserveLoggingContext():
return run()
@ -238,3 +235,42 @@ def wrap_as_background_process(desc):
return wrap_as_background_process_inner_2
return wrap_as_background_process_inner
class BackgroundProcessLoggingContext(LoggingContext):
"""A logging context that tracks in flight metrics for background
processes.
"""
__slots__ = ["_proc"]
def __init__(self, name: str):
super().__init__(name)
self._proc = _BackgroundProcess(name, self)
def start(self, rusage: "Optional[resource._RUsage]"):
"""Log context has started running (again).
"""
super().start(rusage)
# We've become active again so we make sure we're in the list of active
# procs. (Note that "start" here means we've become active, as opposed
# to starting for the first time.)
with _bg_metrics_lock:
_background_processes_active_since_last_scrape.add(self._proc)
def __exit__(self, type, value, traceback) -> None:
"""Log context has finished.
"""
super().__exit__(type, value, traceback)
# The background process has finished. We explictly remove and manually
# update the metrics here so that if nothing is scraping metrics the set
# doesn't infinitely grow.
with _bg_metrics_lock:
_background_processes_active_since_last_scrape.discard(self._proc)
self._proc.update_metrics()

View file

@ -366,9 +366,8 @@ def _upgrade_existing_database(
if duplicates:
# We don't support using the same file name in the same delta version.
raise PrepareDatabaseException(
"Found multiple delta files with the same name in v%d: %s",
v,
duplicates,
"Found multiple delta files with the same name in v%d: %s"
% (v, duplicates,)
)
# We sort to ensure that we apply the delta files in a consistent