Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2020-03-18 17:17:03 +00:00
commit f86962cb6b
49 changed files with 818 additions and 447 deletions

View file

@ -124,12 +124,21 @@ sudo pacman -S base-devel python python-pip \
#### CentOS/Fedora
Installing prerequisites on CentOS 7 or Fedora 25:
Installing prerequisites on CentOS 8 or Fedora>26:
```
sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
libwebp-devel tk-devel redhat-rpm-config \
python3-virtualenv libffi-devel openssl-devel
sudo dnf groupinstall "Development Tools"
```
Installing prerequisites on CentOS 7 or Fedora<=25:
```
sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \
python-virtualenv libffi-devel openssl-devel
python3-virtualenv libffi-devel openssl-devel
sudo yum groupinstall "Development Tools"
```

1
changelog.d/6925.doc Normal file
View file

@ -0,0 +1 @@
Updated CentOS8 install instructions. Contributed by Richard Kellner.

1
changelog.d/7026.removal Normal file
View file

@ -0,0 +1 @@
Remove the unused query_auth federation endpoint per MSC2451.

1
changelog.d/7034.removal Normal file
View file

@ -0,0 +1 @@
Remove special handling of aliases events from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260) added in v1.10.0rc1.

1
changelog.d/7044.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug that renders UTF-8 text files incorrectly when loaded from media. Contributed by @TheStranjer.

1
changelog.d/7058.feature Normal file
View file

@ -0,0 +1 @@
Render a configurable and comprehensible error page if something goes wrong during the SAML2 authentication process.

1
changelog.d/7063.misc Normal file
View file

@ -0,0 +1 @@
Add type annotations and comments to the auth handler.

1
changelog.d/7066.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug that would cause Synapse to respond with an error about event visibility if a client tried to request the state of a room at a given token.

1
changelog.d/7067.feature Normal file
View file

@ -0,0 +1 @@
Render a configurable and comprehensible error page if something goes wrong during the SAML2 authentication process.

1
changelog.d/7070.bugfix Normal file
View file

@ -0,0 +1 @@
Repair a data-corruption issue which was introduced in Synapse 1.10, and fixed in Synapse 1.11, and which could cause `/sync` to return with 404 errors about missing events and unknown rooms.

1
changelog.d/7074.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug causing account validity renewal emails to be sent even if the feature is turned off in some cases.

1
changelog.d/7085.feature Normal file
View file

@ -0,0 +1 @@
Add an optional parameter to control whether other sessions are logged out when a user's password is modified.

1
changelog.d/7094.misc Normal file
View file

@ -0,0 +1 @@
Improve performance when making HTTPS requests to sygnal, sydent, etc, by sharing the SSL context object between connections.

1
changelog.d/7095.misc Normal file
View file

@ -0,0 +1 @@
Attempt to improve performance of state res v2 algorithm.

View file

@ -38,6 +38,7 @@ The parameter ``threepids`` is optional.
The parameter ``avatar_url`` is optional.
The parameter ``admin`` is optional and defaults to 'false'.
The parameter ``deactivated`` is optional and defaults to 'false'.
The parameter ``password`` is optional. If provided the user's password is updated and all devices are logged out.
If the user already exists then optional parameters default to the current value.
List Accounts
@ -168,11 +169,14 @@ with a body of:
.. code:: json
{
"new_password": "<secret>"
"new_password": "<secret>",
"logout_devices": true,
}
including an ``access_token`` of a server admin.
The parameter ``new_password`` is required.
The parameter ``logout_devices`` is optional and defaults to ``true``.
Get whether a user is a server administrator or not
===================================================

View file

@ -1347,6 +1347,25 @@ saml2_config:
#
#grandfathered_mxid_source_attribute: upn
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
# This template doesn't currently need any variable to render.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
# Enable CAS for registration and login.

View file

@ -15,6 +15,9 @@
# limitations under the License.
import logging
import os
import pkg_resources
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
@ -160,6 +163,14 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "5m")
)
template_dir = saml2_config.get("template_dir")
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
self.saml2_error_html_content = self.read_file(
os.path.join(template_dir, "saml_error.html"), "saml2_config.saml_error",
)
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
):
@ -325,6 +336,25 @@ class SAML2Config(Config):
# The default is 'uid'.
#
#grandfathered_mxid_source_attribute: upn
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
# This template doesn't currently need any variable to render.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
""" % {
"config_dir_path": config_dir_path
}

View file

@ -75,7 +75,7 @@ class ServerContextFactory(ContextFactory):
@implementer(IPolicyForHTTPS)
class ClientTLSOptionsFactory(object):
class FederationPolicyForHTTPS(object):
"""Factory for Twisted SSLClientConnectionCreators that are used to make connections
to remote servers for federation.
@ -103,15 +103,15 @@ class ClientTLSOptionsFactory(object):
# let us do).
minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
self._verify_ssl = CertificateOptions(
_verify_ssl = CertificateOptions(
trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
)
self._verify_ssl_context = self._verify_ssl.getContext()
self._verify_ssl_context.set_info_callback(self._context_info_cb)
self._verify_ssl_context = _verify_ssl.getContext()
self._verify_ssl_context.set_info_callback(_context_info_cb)
self._no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS)
self._no_verify_ssl_context = self._no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(self._context_info_cb)
_no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS)
self._no_verify_ssl_context = _no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
def get_options(self, host: bytes):
@ -136,23 +136,6 @@ class ClientTLSOptionsFactory(object):
return SSLClientConnectionCreator(host, ssl_context, should_verify)
@staticmethod
def _context_info_cb(ssl_connection, where, ret):
"""The 'information callback' for our openssl context object."""
# we assume that the app_data on the connection object has been set to
# a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator)
tls_protocol = ssl_connection.get_app_data()
try:
# ... we further assume that SSLClientConnectionCreator has set the
# '_synapse_tls_verifier' attribute to a ConnectionVerifier object.
tls_protocol._synapse_tls_verifier.verify_context_info_cb(
ssl_connection, where
)
except: # noqa: E722, taken from the twisted implementation
logger.exception("Error during info_callback")
f = Failure()
tls_protocol.failVerification(f)
def creatorForNetloc(self, hostname, port):
"""Implements the IPolicyForHTTPS interace so that this can be passed
directly to agents.
@ -160,6 +143,43 @@ class ClientTLSOptionsFactory(object):
return self.get_options(hostname)
@implementer(IPolicyForHTTPS)
class RegularPolicyForHTTPS(object):
"""Factory for Twisted SSLClientConnectionCreators that are used to make connections
to remote servers, for other than federation.
Always uses the same OpenSSL context object, which uses the default OpenSSL CA
trust root.
"""
def __init__(self):
trust_root = platformTrust()
self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext()
self._ssl_context.set_info_callback(_context_info_cb)
def creatorForNetloc(self, hostname, port):
return SSLClientConnectionCreator(hostname, self._ssl_context, True)
def _context_info_cb(ssl_connection, where, ret):
"""The 'information callback' for our openssl context objects.
Note: Once this is set as the info callback on a Context object, the Context should
only be used with the SSLClientConnectionCreator.
"""
# we assume that the app_data on the connection object has been set to
# a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator)
tls_protocol = ssl_connection.get_app_data()
try:
# ... we further assume that SSLClientConnectionCreator has set the
# '_synapse_tls_verifier' attribute to a ConnectionVerifier object.
tls_protocol._synapse_tls_verifier.verify_context_info_cb(ssl_connection, where)
except: # noqa: E722, taken from the twisted implementation
logger.exception("Error during info_callback")
f = Failure()
tls_protocol.failVerification(f)
@implementer(IOpenSSLClientConnectionCreator)
class SSLClientConnectionCreator(object):
"""Creates openssl connection objects for client connections.

View file

@ -39,10 +39,8 @@ from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
make_deferred_yieldable,
preserve_fn,
)
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
logger = logging.getLogger(__name__)
@ -57,86 +55,6 @@ class FederationBase(object):
self.store = hs.get_datastore()
self._clock = hs.get_clock()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: List[EventBase],
room_version: str,
outlier: bool = False,
include_none: bool = False,
):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
If a PDU fails its content hash check then it is redacted.
The given list of PDUs are not modified, instead the function returns
a new list.
Args:
origin
pdu
room_version
outlier: Whether the events are outliers or not
include_none: Whether to include None in the returned list
for events that have failed their checks
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
res = yield make_deferred_yieldable(deferred)
except SynapseError:
res = None
if not res:
# Check local db.
res = yield self.store.get_event(
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
try:
# This should not exist in the base implementation, until
# this is fixed, ignore it for typing. See issue #6997.
res = yield defer.ensureDeferred(
self.get_pdu( # type: ignore
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
timeout=10000,
)
)
except SynapseError:
pass
if not res:
logger.warning(
"Failed to find copy of %s with valid signature", pdu.event_id
)
return res
handle = preserve_fn(handle_check_result)
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
valid_pdus = yield make_deferred_yieldable(
defer.gatherResults(deferreds2, consumeErrors=True)
).addErrback(unwrapFirstError)
if include_none:
return valid_pdus
else:
return [p for p in valid_pdus if p]
def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]

View file

@ -33,6 +33,7 @@ from typing import (
from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
@ -51,7 +52,7 @@ from synapse.api.room_versions import (
)
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
@ -345,6 +346,83 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids
async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: List[EventBase],
room_version: str,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
If a PDU fails its content hash check then it is redacted.
The given list of PDUs are not modified, instead the function returns
a new list.
Args:
origin
pdu
room_version
outlier: Whether the events are outliers or not
include_none: Whether to include None in the returned list
for events that have failed their checks
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
res = yield make_deferred_yieldable(deferred)
except SynapseError:
res = None
if not res:
# Check local db.
res = yield self.store.get_event(
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
try:
res = yield defer.ensureDeferred(
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version, # type: ignore
outlier=outlier,
timeout=10000,
)
)
except SynapseError:
pass
if not res:
logger.warning(
"Failed to find copy of %s with valid signature", pdu.event_id
)
return res
handle = preserve_fn(handle_check_result)
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
valid_pdus = await make_deferred_yieldable(
defer.gatherResults(deferreds2, consumeErrors=True)
).addErrback(unwrapFirstError)
if include_none:
return valid_pdus
else:
return [p for p in valid_pdus if p]
async def get_event_auth(self, destination, room_id, event_id):
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)

View file

@ -470,57 +470,6 @@ class FederationServer(FederationBase):
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
return 200, res
async def on_query_auth_request(self, origin, content, room_id, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
missing (list): A list of event_ids indicating what the other
side (`origin`) think we're missing.
rejects (dict): A mapping from event_id to a 2-tuple of reason
string and a proof (or None) of why the event was rejected.
The keys of this dict give the list of events the `origin` has
rejected.
Args:
origin (str)
content (dict)
event_id (str)
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
room_version = await self.store.get_room_version(room_id)
auth_chain = [
event_from_pdu_json(e, room_version) for e in content["auth_chain"]
]
signed_auth = await self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True, room_version=room_version.identifier
)
ret = await self.handler.on_query_auth(
origin,
event_id,
room_id,
signed_auth,
content.get("rejects", []),
content.get("missing", []),
)
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
"rejects": ret.get("rejects", []),
"missing": ret.get("missing", []),
}
return 200, send_content
@log_function
def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content)

View file

@ -643,17 +643,6 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
return 200, response
class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
async def on_POST(self, origin, content, query, context, event_id):
new_content = await self.handler.on_query_auth_request(
origin, content, context, event_id
)
return 200, new_content
class FederationGetMissingEventsServlet(BaseFederationServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
@ -1412,7 +1401,6 @@ FEDERATION_SERVLET_CLASSES = (
FederationV2SendLeaveServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,

View file

@ -44,7 +44,11 @@ class AccountValidityHandler(object):
self._account_validity = self.hs.config.account_validity
if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
if (
self._account_validity.enabled
and self._account_validity.renew_by_email_enabled
and load_jinja2_templates
):
# Don't do email-specific configuration if renewal by email is disabled.
try:
app_name = self.hs.config.email_app_name

View file

@ -18,10 +18,10 @@ import logging
import time
import unicodedata
import urllib.parse
from typing import Any
from typing import Any, Dict, Iterable, List, Optional
import attr
import bcrypt
import bcrypt # type: ignore[import]
import pymacaroons
from twisted.internet import defer
@ -45,7 +45,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import UserID
from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@ -63,11 +63,11 @@ class AuthHandler(BaseHandler):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
self.checkers[inst.AUTH_TYPE] = inst
self.checkers[inst.AUTH_TYPE] = inst # type: ignore
self.bcrypt_rounds = hs.config.bcrypt_rounds
@ -124,7 +124,9 @@ class AuthHandler(BaseHandler):
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
@defer.inlineCallbacks
def validate_user_via_ui_auth(self, requester, request_body, clientip):
def validate_user_via_ui_auth(
self, requester: Requester, request_body: Dict[str, Any], clientip: str
):
"""
Checks that the user is who they claim to be, via a UI auth.
@ -133,11 +135,11 @@ class AuthHandler(BaseHandler):
that it isn't stolen by re-authenticating them.
Args:
requester (Requester): The user, as given by the access token
requester: The user, as given by the access token
request_body (dict): The body of the request sent by the client
request_body: The body of the request sent by the client
clientip (str): The IP address of the client.
clientip: The IP address of the client.
Returns:
defer.Deferred[dict]: the parameters for this request (which may
@ -208,7 +210,9 @@ class AuthHandler(BaseHandler):
return self.checkers.keys()
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
def check_auth(
self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
@ -223,14 +227,14 @@ class AuthHandler(BaseHandler):
decorator.
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
flows: A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
clientip: The IP address of the client.
Returns:
defer.Deferred[dict, dict, str]: a deferred tuple of
@ -250,7 +254,7 @@ class AuthHandler(BaseHandler):
"""
authdict = None
sid = None
sid = None # type: Optional[str]
if clientdict and "auth" in clientdict:
authdict = clientdict["auth"]
del clientdict["auth"]
@ -283,9 +287,9 @@ class AuthHandler(BaseHandler):
creds = session["creds"]
# check auth type currently being presented
errordict = {}
errordict = {} # type: Dict[str, Any]
if "type" in authdict:
login_type = authdict["type"]
login_type = authdict["type"] # type: str
try:
result = yield self._check_auth_dict(authdict, clientip)
if result:
@ -326,7 +330,7 @@ class AuthHandler(BaseHandler):
raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
@ -348,7 +352,7 @@ class AuthHandler(BaseHandler):
return True
return False
def get_session_id(self, clientdict):
def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
"""
Gets the session ID for a client given the client dictionary
@ -356,7 +360,7 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary sent by the client in the request
Returns:
str|None: The string session ID the client sent. If the client did
The string session ID the client sent. If the client did
not send a session ID, returns None.
"""
sid = None
@ -366,40 +370,42 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
def set_session_data(self, session_id, key, value):
def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
Args:
session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
value (any): The data to store
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
value: The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
def get_session_data(self, session_id, key, default=None):
def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
Retrieve data stored with set_session_data
Args:
session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
default (any): Value to return if the key has not been set
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
def _check_auth_dict(self, authdict, clientip):
def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
"""Attempt to validate the auth dict provided by a client
Args:
authdict (object): auth dict provided by the client
clientip (str): IP address of the client
authdict: auth dict provided by the client
clientip: IP address of the client
Returns:
Deferred: result of the stage verification.
@ -425,10 +431,10 @@ class AuthHandler(BaseHandler):
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
return canonical_id
def _get_params_recaptcha(self):
def _get_params_recaptcha(self) -> dict:
return {"public_key": self.hs.config.recaptcha_public_key}
def _get_params_terms(self):
def _get_params_terms(self) -> dict:
return {
"policies": {
"privacy_policy": {
@ -445,7 +451,9 @@ class AuthHandler(BaseHandler):
}
}
def _auth_dict_for_flows(self, flows, session):
def _auth_dict_for_flows(
self, flows: List[List[str]], session: Dict[str, Any]
) -> Dict[str, Any]:
public_flows = []
for f in flows:
public_flows.append(f)
@ -455,7 +463,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms,
}
params = {}
params = {} # type: Dict[str, Any]
for f in public_flows:
for stage in f:
@ -468,7 +476,13 @@ class AuthHandler(BaseHandler):
"params": params,
}
def _get_session_info(self, session_id):
def _get_session_info(self, session_id: Optional[str]) -> dict:
"""
Gets or creates a session given a session ID.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
"""
if session_id not in self.sessions:
session_id = None
@ -481,7 +495,9 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
"""
Creates a new access token for the user with the given user ID.
@ -491,11 +507,11 @@ class AuthHandler(BaseHandler):
The device will be recorded in the table if it is not there already.
Args:
user_id (str): canonical User ID
device_id (str|None): the device ID to associate with the tokens.
user_id: canonical User ID
device_id: the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
valid_until_ms (int|None): when the token is valid until. None for
valid_until_ms: when the token is valid until. None for
no expiry.
Returns:
The access token for the user's session.
@ -530,13 +546,13 @@ class AuthHandler(BaseHandler):
return access_token
@defer.inlineCallbacks
def check_user_exists(self, user_id):
def check_user_exists(self, user_id: str):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
Args:
(unicode|bytes) user_id: complete @user:id
user_id: complete @user:id
Returns:
defer.Deferred: (unicode) canonical_user_id, or None if zero or
@ -551,7 +567,7 @@ class AuthHandler(BaseHandler):
return None
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
def _find_user_id_and_pwd_hash(self, user_id: str):
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
@ -581,7 +597,7 @@ class AuthHandler(BaseHandler):
)
return result
def get_supported_login_types(self):
def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API
By default this is just 'm.login.password' (unless password_enabled is
@ -589,20 +605,20 @@ class AuthHandler(BaseHandler):
other login types.
Returns:
Iterable[str]: login types
login types
"""
return self._supported_login_types
@defer.inlineCallbacks
def validate_login(self, username, login_submission):
def validate_login(self, username: str, login_submission: Dict[str, Any]):
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Args:
username (str): username supplied by the user
login_submission (dict): the whole of the login submission
username: username supplied by the user
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
Deferred[str, func]: canonical user id, and optional callback
@ -690,13 +706,13 @@ class AuthHandler(BaseHandler):
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def check_password_provider_3pid(self, medium, address, password):
def check_password_provider_3pid(self, medium: str, address: str, password: str):
"""Check if a password provider is able to validate a thirdparty login
Args:
medium (str): The medium of the 3pid (ex. email).
address (str): The address of the 3pid (ex. jdoe@example.com).
password (str): The password of the user.
medium: The medium of the 3pid (ex. email).
address: The address of the 3pid (ex. jdoe@example.com).
password: The password of the user.
Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id,
@ -724,15 +740,15 @@ class AuthHandler(BaseHandler):
return None, None
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
def _check_local_password(self, user_id: str, password: str):
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
multiple inexact matches.
Args:
user_id (unicode): complete @user:id
password (unicode): the provided password
user_id: complete @user:id
password: the provided password
Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
@ -755,7 +771,7 @@ class AuthHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token):
def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
try:
@ -769,11 +785,11 @@ class AuthHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
def delete_access_token(self, access_token):
def delete_access_token(self, access_token: str):
"""Invalidate a single access token
Args:
access_token (str): access token to be deleted
access_token: access token to be deleted
Returns:
Deferred
@ -798,15 +814,17 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def delete_access_tokens_for_user(
self, user_id, except_token_id=None, device_id=None
self,
user_id: str,
except_token_id: Optional[str] = None,
device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_id (str|None): access_token ID which should *not* be
deleted
device_id (str|None): ID of device the tokens are associated with.
user_id: ID of user the tokens belong to
except_token_id: access_token ID which should *not* be deleted
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
@ -830,7 +848,7 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@ -856,19 +874,20 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
def delete_threepid(self, user_id, medium, address, id_server=None):
def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
):
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
Args:
user_id (str)
medium (str)
address (str)
id_server (str|None): Use the given identity server when unbinding
user_id: ID of user to remove the 3pid from.
medium: The medium of the 3pid being removed: "email" or "msisdn".
address: The 3pid address to remove.
id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the
@ -887,17 +906,18 @@ class AuthHandler(BaseHandler):
yield self.store.user_delete_threepid(user_id, medium, address)
return result
def _save_session(self, session):
def _save_session(self, session: Dict[str, Any]) -> None:
"""Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
def hash(self, password):
def hash(self, password: str):
"""Computes a secure hash of password.
Args:
password (unicode): Password to hash.
password: Password to hash.
Returns:
Deferred(unicode): Hashed password.
@ -914,12 +934,12 @@ class AuthHandler(BaseHandler):
return defer_to_thread(self.hs.get_reactor(), _do_hash)
def validate_hash(self, password, stored_hash):
def validate_hash(self, password: str, stored_hash: bytes):
"""Validates that self.hash(password) == stored_hash.
Args:
password (unicode): Password to hash.
stored_hash (bytes): Expected hash value.
password: Password to hash.
stored_hash: Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
@ -1007,7 +1027,9 @@ class MacaroonGenerator(object):
hs = attr.ib()
def generate_access_token(self, user_id, extra_caveats=None):
def generate_access_token(
self, user_id: str, extra_caveats: Optional[List[str]] = None
) -> str:
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
@ -1020,16 +1042,9 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
"""
Args:
user_id (unicode):
duration_in_ms (int):
Returns:
unicode
"""
def generate_short_term_login_token(
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
@ -1037,12 +1052,12 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def _generate_base_macaroon(self, user_id):
def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",

View file

@ -292,16 +292,6 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
new_pl_content = copy_power_levels_contents(old_room_pl_state.content)
# pre-msc2260 rooms may not have the right setting for aliases. If no other
# value is set, set it now.
events_default = new_pl_content.get("events_default", 0)
new_pl_content.setdefault("events", {}).setdefault(
EventTypes.Aliases, events_default
)
logger.debug("Setting correct PLs in new room to %s", new_pl_content)
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
@ -309,7 +299,7 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
"content": new_pl_content,
"content": old_room_pl_state.content,
},
ratelimit=False,
)
@ -814,10 +804,6 @@ class RoomCreationHandler(BaseHandler):
EventTypes.RoomHistoryVisibility: 100,
EventTypes.CanonicalAlias: 50,
EventTypes.RoomAvatar: 50,
# MSC2260: Allow everybody to send alias events by default
# This will be reudundant on pre-MSC2260 rooms, since the
# aliases event is special-cased.
EventTypes.Aliases: 0,
EventTypes.Tombstone: 100,
EventTypes.ServerACL: 100,
},

View file

@ -23,6 +23,7 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
from synapse.types import (
@ -73,6 +74,8 @@ class SamlHandler:
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
self._error_html_content = hs.config.saml2_error_html_content
def handle_redirect_request(self, client_redirect_url):
"""Handle an incoming request to /login/sso/redirect
@ -114,7 +117,22 @@ class SamlHandler:
# the dict.
self.expire_sessions()
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
try:
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
except Exception as e:
# If decoding the response or mapping it to a user failed, then log the
# error and tell the user that something went wrong.
logger.error(e)
request.setResponseCode(400)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(
b"Content-Length", b"%d" % (len(self._error_html_content),)
)
request.write(self._error_html_content.encode("utf8"))
finish_request(request)
return
self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):

View file

@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
from ._base import BaseHandler
@ -32,14 +34,17 @@ class SetPasswordHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
def set_password(
self,
user_id: str,
new_password: str,
logout_devices: bool,
requester: Optional[Requester] = None,
):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
password_hash = yield self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
except_access_token_id = requester.access_token_id if requester else None
password_hash = yield self._auth_handler.hash(new_password)
try:
yield self.store.user_set_password_hash(user_id, password_hash)
@ -48,14 +53,18 @@ class SetPasswordHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
# we want to log out all of the user's other sessions. First delete
# all his other devices.
yield self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id
)
# Optionally, log out all of the user's other sessions.
if logout_devices:
except_device_id = requester.device_id if requester else None
except_access_token_id = requester.access_token_id if requester else None
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id
)
# First delete all of their other devices.
yield self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id
)

View file

@ -244,9 +244,6 @@ class SimpleHttpClient(object):
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
pool.cachedConnectionTimeout = 2 * 60
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = ProxyAgent(
self.reactor,
connectTimeout=15,

View file

@ -45,7 +45,7 @@ class MatrixFederationAgent(object):
Args:
reactor (IReactor): twisted reactor to use for underlying requests
tls_client_options_factory (ClientTLSOptionsFactory|None):
tls_client_options_factory (FederationPolicyForHTTPS|None):
factory to use for fetching client tls options, or none to disable TLS.
_srv_resolver (SrvResolver|None):

View file

@ -210,7 +210,7 @@ class LoggingContext(object):
class Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = ["previous_context", "alive", "request", "scope"]
__slots__ = ["previous_context", "alive", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
@ -218,6 +218,7 @@ class LoggingContext(object):
self.alive = None
self.request = None
self.scope = None
self.tag = None
def __str__(self):
return "sentinel"
@ -511,7 +512,7 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context: Optional[LoggingContext] = None) -> None:
def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
if new_context is None:
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
else:

View file

@ -7,7 +7,7 @@
<body>
<p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p>
<p>
If you're seeing this page after clicking a link sent to you via email, make
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
validation link in the same client you're logging in from.
</p>
@ -24,19 +24,22 @@
// we just don't print anything specific.
let searchStr = "";
if (window.location.search) {
// For some reason window.location.searchParams isn't always defined when
// window.location.search is, so we can't just use it right away.
// window.location.searchParams isn't always defined when
// window.location.search is, so it's more reliable to parse the latter.
searchStr = window.location.search;
} else if (window.location.hash) {
//
// Replace the # with a ? so that URLSearchParams does the right thing and
// doesn't parse the first parameter incorrectly.
searchStr = window.location.hash.replace("#", "?");
}
// We might end up with no error in the URL, so we need to check if we have one
// to print one.
let errorDesc = new URLSearchParams(searchStr).get("error_description")
if (errorDesc) {
document.getElementById("errormsg").innerHTML = ` ("${errorDesc}")`;
document.getElementById("errormsg").innerText = ` ("${errorDesc}")`;
}
</script>
</body>
</html>
</html>

View file

@ -221,8 +221,9 @@ class UserRestServletV2(RestServlet):
raise SynapseError(400, "Invalid password")
else:
new_password = body["password"]
logout_devices = True
await self.set_password_handler.set_password(
target_user.to_string(), new_password, requester
target_user.to_string(), new_password, logout_devices, requester
)
if "deactivated" in body:
@ -536,9 +537,10 @@ class ResetPasswordRestServlet(RestServlet):
params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"]
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(
target_user_id, new_password, requester
target_user_id, new_password, logout_devices, requester
)
return 200, {}

View file

@ -189,12 +189,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content = parse_json_object_from_request(request)
if event_type == EventTypes.Aliases:
# MSC2260
raise SynapseError(
400, "Cannot send m.room.aliases events via /rooms/{room_id}/state"
)
event_dict = {
"type": event_type,
"content": content,
@ -242,12 +236,6 @@ class RoomSendEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
if event_type == EventTypes.Aliases:
# MSC2260
raise SynapseError(
400, "Cannot send m.room.aliases events via /rooms/{room_id}/send"
)
event_dict = {
"type": event_type,
"content": content,

View file

@ -265,8 +265,11 @@ class PasswordRestServlet(RestServlet):
assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"]
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(user_id, new_password, requester)
await self._set_password_handler.set_password(
user_id, new_password, logout_devices, requester
)
return 200, {}

View file

@ -30,6 +30,22 @@ from synapse.util.stringutils import is_ascii
logger = logging.getLogger(__name__)
# list all text content types that will have the charset default to UTF-8 when
# none is given
TEXT_CONTENT_TYPES = [
"text/css",
"text/csv",
"text/html",
"text/calendar",
"text/plain",
"text/javascript",
"application/json",
"application/ld+json",
"application/rtf",
"image/svg+xml",
"text/xml",
]
def parse_media_id(request):
try:
@ -96,7 +112,14 @@ def add_file_headers(request, media_type, file_size, upload_name):
def _quote(x):
return urllib.parse.quote(x.encode("utf-8"))
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
# Default to a UTF-8 charset for text content types.
# ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16'
if media_type.lower() in TEXT_CONTENT_TYPES:
content_type = media_type + "; charset=UTF-8"
else:
content_type = media_type
request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
if upload_name:
# RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
#

View file

@ -14,7 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import DirectServeResource, wrap_html_request_handler
from synapse.http.server import (
DirectServeResource,
finish_request,
wrap_html_request_handler,
)
class SAML2ResponseResource(DirectServeResource):
@ -24,8 +28,20 @@ class SAML2ResponseResource(DirectServeResource):
def __init__(self, hs):
super().__init__()
self._error_html_content = hs.config.saml2_error_html_content
self._saml_handler = hs.get_saml_handler()
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
request.setResponseCode(400)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(self._error_html_content),))
request.write(self._error_html_content.encode("utf8"))
finish_request(request)
@wrap_html_request_handler
async def _async_render_POST(self, request):
return await self._saml_handler.handle_saml_response(request)

View file

@ -26,7 +26,6 @@ import logging
import os
from twisted.mail.smtp import sendmail
from twisted.web.client import BrowserLikePolicyForHTTPS
from synapse.api.auth import Auth
from synapse.api.filtering import Filtering
@ -35,6 +34,7 @@ from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.crypto.context_factory import RegularPolicyForHTTPS
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker
@ -310,7 +310,7 @@ class HomeServer(object):
return (
InsecureInterceptableContextFactory()
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS()
else RegularPolicyForHTTPS()
)
def build_simple_http_client(self):
@ -420,7 +420,7 @@ class HomeServer(object):
return PusherPool(self)
def build_http_client(self):
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
self.config
)
return MatrixFederationHttpClient(self, tls_client_options_factory)

View file

@ -662,28 +662,16 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
"""Gets the full auth chain for a set of events (including rejected
events).
Includes the given event IDs in the result.
Note that:
1. All events must be state events.
2. For v1 rooms this may not have the full auth chain in the
presence of rejected events
Args:
event_ids: The event IDs of the events to fetch the auth chain for.
Must be state events.
ignore_events: Set of events to exclude from the returned auth
chain.
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
This equivalent to fetching the full auth chain for each set of state
and returning the events that don't appear in each and every auth
chain.
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
Deferred[Set[str]]: Set of event IDs.
"""
return self.store.get_auth_chain_ids(
event_ids, include_given=True, ignore_events=ignore_events,
)
return self.store.get_auth_chain_difference(state_sets)

View file

@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Returns:
Deferred[set[str]]: Set of event IDs
"""
common = set(itervalues(state_sets[0])).intersection(
*(itervalues(s) for s in state_sets[1:])
difference = yield state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
)
auth_sets = []
for state_set in state_sets:
auth_ids = {
eid
for key, eid in iteritems(state_set)
if (
key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
or key
in (
(EventTypes.PowerLevels, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
)
)
and eid not in common
}
auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
auth_ids.update(auth_chain)
auth_sets.append(auth_ids)
intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
union = set().union(*auth_sets)
return union - intersection
return difference
def _seperate(state_sets):

View file

@ -14,7 +14,7 @@
# limitations under the License.
import itertools
import logging
from typing import List, Optional, Set
from typing import Dict, List, Optional, Set, Tuple
from six.moves.queue import Empty, PriorityQueue
@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
This equivalent to fetching the full auth chain for each set of state
and returning the events that don't appear in each and every auth
chain.
Returns:
Deferred[Set[str]]
"""
return self.db.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)
def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
) -> Set[str]:
# Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~
#
# The idea here is to basically walk the auth graph of each state set in
# tandem, keeping track of which auth events are reachable by each state
# set. If we reach an auth event we've already visited (via a different
# state set) then we mark that auth event and all ancestors as reachable
# by the state set. This requires that we keep track of the auth chains
# in memory.
#
# Doing it in a such a way means that we can stop early if all auth
# events we're currently walking are reachable by all state sets.
#
# *Note*: We can't stop walking an event's auth chain if it is reachable
# by all state sets. This is because other auth chains we're walking
# might be reachable only via the original auth chain. For example,
# given the following auth chain:
#
# A -> C -> D -> E
# / /
# B -´---------´
#
# and state sets {A} and {B} then walking the auth chains of A and B
# would immediately show that C is reachable by both. However, if we
# stopped at C then we'd only reach E via the auth chain of B and so E
# would errornously get included in the returned difference.
#
# The other thing that we do is limit the number of auth chains we walk
# at once, due to practical limits (i.e. we can only query the database
# with a limited set of parameters). We pick the auth chains we walk
# each iteration based on their depth, in the hope that events with a
# lower depth are likely reachable by those with higher depths.
#
# We could use any ordering that we believe would give a rough
# topological ordering, e.g. origin server timestamp. If the ordering
# chosen is not topological then the algorithm still produces the right
# result, but perhaps a bit more inefficiently. This is why it is safe
# to use "depth" here.
initial_events = set(state_sets[0]).union(*state_sets[1:])
# Dict from events in auth chains to which sets *cannot* reach them.
# I.e. if the set is empty then all sets can reach the event.
event_to_missing_sets = {
event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
for event_id in initial_events
}
# We need to get the depth of the initial events for sorting purposes.
sql = """
SELECT depth, event_id FROM events
WHERE %s
ORDER BY depth ASC
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", initial_events
)
txn.execute(sql % (clause,), args)
# The sorted list of events whose auth chains we should walk.
search = txn.fetchall() # type: List[Tuple[int, str]]
# Map from event to its auth events
event_to_auth_events = {} # type: Dict[str, Set[str]]
base_sql = """
SELECT a.event_id, auth_id, depth
FROM event_auth AS a
INNER JOIN events AS e ON (e.event_id = a.auth_id)
WHERE
"""
while search:
# Check whether all our current walks are reachable by all state
# sets. If so we can bail.
if all(not event_to_missing_sets[eid] for _, eid in search):
break
# Fetch the auth events and their depths of the N last events we're
# currently walking
search, chunk = search[:-100], search[-100:]
clause, args = make_in_list_sql_clause(
txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
)
txn.execute(base_sql + clause, args)
for event_id, auth_event_id, auth_event_depth in txn:
event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
sets = event_to_missing_sets.get(auth_event_id)
if sets is None:
# First time we're seeing this event, so we add it to the
# queue of things to fetch.
search.append((auth_event_depth, auth_event_id))
# Assume that this event is unreachable from any of the
# state sets until proven otherwise
sets = event_to_missing_sets[auth_event_id] = set(
range(len(state_sets))
)
else:
# We've previously seen this event, so look up its auth
# events and recursively mark all ancestors as reachable
# by the current event's state set.
a_ids = event_to_auth_events.get(auth_event_id)
while a_ids:
new_aids = set()
for a_id in a_ids:
event_to_missing_sets[a_id].intersection_update(
event_to_missing_sets[event_id]
)
b = event_to_auth_events.get(a_id)
if b:
new_aids.update(b)
a_ids = new_aids
# Mark that the auth event is reachable by the approriate sets.
sets.intersection_update(event_to_missing_sets[event_id])
search.sort()
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_in_room(self, room_id):
return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id

View file

@ -29,7 +29,11 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.logging.context import (
LoggingContext,
LoggingContextOrSentinel,
make_deferred_yieldable,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
@ -543,7 +547,9 @@ class Database(object):
Returns:
Deferred: The result of func
"""
parent_context = LoggingContext.current_context()
parent_context = (
LoggingContext.current_context()
) # type: Optional[LoggingContextOrSentinel]
if parent_context == LoggingContext.sentinel:
logger.warning(
"Starting db connection from sentinel context: metrics will be lost"

View file

@ -118,30 +118,36 @@ def filter_events_for_client(
the original event if they can see it as normal.
"""
if event.type == "org.matrix.dummy_event" and filter_send_to_client:
return None
# Only run some checks if these events aren't about to be sent to clients. This is
# because, if this is not the case, we're probably only checking if the users can
# see events in the room at that point in the DAG, and that shouldn't be decided
# on those checks.
if filter_send_to_client:
if event.type == "org.matrix.dummy_event":
return None
if not event.is_state() and event.sender in ignore_list and filter_send_to_client:
return None
if not event.is_state() and event.sender in ignore_list:
return None
# Until MSC2261 has landed we can't redact malicious alias events, so for
# now we temporarily filter out m.room.aliases entirely to mitigate
# abuse, while we spec a better solution to advertising aliases
# on rooms.
if event.type == EventTypes.Aliases:
return None
# Until MSC2261 has landed we can't redact malicious alias events, so for
# now we temporarily filter out m.room.aliases entirely to mitigate
# abuse, while we spec a better solution to advertising aliases
# on rooms.
if event.type == EventTypes.Aliases:
return None
# Don't try to apply the room's retention policy if the event is a state event, as
# MSC1763 states that retention is only considered for non-state events.
if filter_send_to_client and not event.is_state():
retention_policy = retention_policies[event.room_id]
max_lifetime = retention_policy.get("max_lifetime")
# Don't try to apply the room's retention policy if the event is a state
# event, as MSC1763 states that retention is only considered for non-state
# events.
if not event.is_state():
retention_policy = retention_policies[event.room_id]
max_lifetime = retention_policy.get("max_lifetime")
if max_lifetime is not None:
oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
if max_lifetime is not None:
oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
if event.origin_server_ts < oldest_allowed_ts:
return None
if event.origin_server_ts < oldest_allowed_ts:
return None
if event.event_id in always_include_ids:
return event

View file

@ -23,7 +23,7 @@ from OpenSSL import SSL
from synapse.config._base import Config, RootConfig
from synapse.config.tls import ConfigError, TlsConfig
from synapse.crypto.context_factory import ClientTLSOptionsFactory
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from tests.unittest import TestCase
@ -180,12 +180,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
t = TestConfig()
t.read_config(config, config_dir_path="", data_dir_path="")
cf = ClientTLSOptionsFactory(t)
cf = FederationPolicyForHTTPS(t)
options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
self.assertNotEqual(options & SSL.OP_NO_TLSv1, 0)
self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
def test_tls_client_minimum_set_passed_through_1_0(self):
"""
@ -195,12 +196,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
t = TestConfig()
t.read_config(config, config_dir_path="", data_dir_path="")
cf = ClientTLSOptionsFactory(t)
cf = FederationPolicyForHTTPS(t)
options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has not had any of the NO_TLS set.
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
self.assertEqual(options & SSL.OP_NO_TLSv1, 0)
self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
def test_acme_disabled_in_generated_config_no_acme_domain_provied(self):
"""
@ -273,7 +275,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
t = TestConfig()
t.read_config(config, config_dir_path="", data_dir_path="")
cf = ClientTLSOptionsFactory(t)
cf = FederationPolicyForHTTPS(t)
# Not in the whitelist
opts = cf.get_options(b"notexample.com")
@ -282,3 +284,10 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
# Caught by the wildcard
opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
self.assertFalse(opts._verifier._verify_certs)
def _get_ssl_context_options(ssl_context: SSL.Context) -> int:
"""get the options bits from an openssl context object"""
# the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
# use the low-level interface
return SSL._lib.SSL_CTX_get_options(ssl_context._context)

View file

@ -31,7 +31,7 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto.context_factory import ClientTLSOptionsFactory
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server
from synapse.http.federation.well_known_resolver import (
@ -79,7 +79,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self._config = config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
self.tls_factory = ClientTLSOptionsFactory(config)
self.tls_factory = FederationPolicyForHTTPS(config)
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
@ -715,7 +715,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
config = default_config("test", parse=True)
# Build a new agent and WellKnownResolver with a different tls factory
tls_factory = ClientTLSOptionsFactory(config)
tls_factory = FederationPolicyForHTTPS(config)
agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=tls_factory,

View file

@ -868,6 +868,13 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Set this new alias as the canonical alias for this room
self.helper.send_state(
room_id,
"m.room.aliases",
{"aliases": [test_alias]},
tok=self.admin_user_tok,
state_key="test",
)
self.helper.send_state(
room_id,
"m.room.canonical_alias",

View file

@ -51,30 +51,26 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.user = self.register_user("user", "test")
self.user_tok = self.login("user", "test")
def test_cannot_set_alias_via_state_event(self):
self.ensure_user_joined_room()
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
self.room_id,
self.hs.hostname,
)
data = {"aliases": [self.random_alias(5)]}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.render(request)
self.assertEqual(channel.code, 400, channel.result)
def test_state_event_not_in_room(self):
self.ensure_user_left_room()
self.set_alias_via_state_event(403)
def test_directory_endpoint_not_in_room(self):
self.ensure_user_left_room()
self.set_alias_via_directory(403)
def test_state_event_in_room_too_long(self):
self.ensure_user_joined_room()
self.set_alias_via_state_event(400, alias_length=256)
def test_directory_in_room_too_long(self):
self.ensure_user_joined_room()
self.set_alias_via_directory(400, alias_length=256)
def test_state_event_in_room(self):
self.ensure_user_joined_room()
self.set_alias_via_state_event(200)
def test_directory_in_room(self):
self.ensure_user_joined_room()
self.set_alias_via_directory(200)
@ -106,6 +102,21 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
def set_alias_via_state_event(self, expected_code, alias_length=5):
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
self.room_id,
self.hs.hostname,
)
data = {"aliases": [self.random_alias(alias_length)]}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def set_alias_via_directory(self, expected_code, alias_length=5):
url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
data = {"room_id": self.room_id}

View file

@ -603,7 +603,7 @@ class TestStateResolutionStore(object):
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
def get_auth_chain(self, event_ids, ignore_events):
def _get_auth_chain(self, event_ids):
"""Gets the full auth chain for a set of events (including rejected
events).
@ -617,9 +617,6 @@ class TestStateResolutionStore(object):
Args:
event_ids (list): The event IDs of the events to fetch the auth
chain for. Must be state events.
ignore_events: Set of events to exclude from the returned auth
chain.
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
"""
@ -629,7 +626,7 @@ class TestStateResolutionStore(object):
stack = list(event_ids)
while stack:
event_id = stack.pop()
if event_id in result or event_id in ignore_events:
if event_id in result:
continue
result.add(event_id)
@ -639,3 +636,9 @@ class TestStateResolutionStore(object):
stack.append(aid)
return list(result)
def get_auth_chain_difference(self, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
return set(chains[0]).union(*chains[1:]) - common

View file

@ -13,19 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.unittest
import tests.utils
class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_get_prev_events_for_room(self):
room_id = "@ROOM:local"
@ -61,15 +56,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
)
for i in range(0, 20):
yield self.store.db.runInteraction("insert", insert_event, i)
self.get_success(self.store.db.runInteraction("insert", insert_event, i))
# this should get the last ten
r = yield self.store.get_prev_events_for_room(room_id)
r = self.get_success(self.store.get_prev_events_for_room(room_id))
self.assertEqual(10, len(r))
for i in range(0, 10):
self.assertEqual("$event_%i:local" % (19 - i), r[i])
@defer.inlineCallbacks
def test_get_rooms_with_many_extremities(self):
room1 = "#room1"
room2 = "#room2"
@ -86,25 +80,154 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
)
for i in range(0, 20):
yield self.store.db.runInteraction("insert", insert_event, i, room1)
yield self.store.db.runInteraction("insert", insert_event, i, room2)
yield self.store.db.runInteraction("insert", insert_event, i, room3)
self.get_success(
self.store.db.runInteraction("insert", insert_event, i, room1)
)
self.get_success(
self.store.db.runInteraction("insert", insert_event, i, room2)
)
self.get_success(
self.store.db.runInteraction("insert", insert_event, i, room3)
)
# Test simple case
r = yield self.store.get_rooms_with_many_extremities(5, 5, [])
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, []))
self.assertEqual(len(r), 3)
# Does filter work?
r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1])
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1]))
self.assertTrue(room2 in r)
self.assertTrue(room3 in r)
self.assertEqual(len(r), 2)
r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
r = self.get_success(
self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
)
self.assertEqual(r, [room3])
# Does filter and limit work?
r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1])
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
def test_auth_difference(self):
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
# where the top are the most recent events.
#
# A B
# \ /
# D E
# \ |
# ` F C
# | /|
# G ´ |
# | \ |
# H I
# | |
# K J
auth_graph = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
"d": ["f"],
"e": ["f"],
"f": ["g"],
"g": ["h", "i"],
"h": ["k"],
"i": ["j"],
"k": [],
"j": [],
}
depth_map = {
"a": 7,
"b": 7,
"c": 4,
"d": 6,
"e": 6,
"f": 5,
"g": 3,
"h": 2,
"i": 2,
"k": 1,
"j": 1,
}
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
def insert_event(txn, event_id, stream_ordering):
depth = depth_map[event_id]
self.store.db.simple_insert_txn(
txn,
table="events",
values={
"event_id": event_id,
"room_id": room_id,
"depth": depth,
"topological_ordering": depth,
"type": "m.test",
"processed": True,
"outlier": False,
"stream_ordering": stream_ordering,
},
)
self.store.db.simple_insert_many_txn(
txn,
table="event_auth",
values=[
{"event_id": event_id, "room_id": room_id, "auth_id": a}
for a in auth_graph[event_id]
],
)
next_stream_ordering = 0
for event_id in auth_graph:
next_stream_ordering += 1
self.get_success(
self.store.db.runInteraction(
"insert", insert_event, event_id, next_stream_ordering
)
)
# Now actually test that various combinations give the right result:
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "d", "e"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
self.assertSetEqual(difference, set())

View file

@ -185,6 +185,7 @@ commands = mypy \
synapse/federation/federation_client.py \
synapse/federation/sender \
synapse/federation/transport \
synapse/handlers/auth.py \
synapse/handlers/directory.py \
synapse/handlers/presence.py \
synapse/handlers/sync.py \