Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

Erik Johnston 2020-03-18 17:17:03 +00:00
commit f86962cb6b
49 changed files with 818 additions and 447 deletions

@@ -124,12 +124,21 @@ sudo pacman -S base-devel python python-pip \
#### CentOS/Fedora
-Installing prerequisites on CentOS 7 or Fedora 25:
+Installing prerequisites on CentOS 8 or Fedora>26:
+
+```
+sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
+libwebp-devel tk-devel redhat-rpm-config \
+python3-virtualenv libffi-devel openssl-devel
+sudo dnf groupinstall "Development Tools"
+```
+
+Installing prerequisites on CentOS 7 or Fedora<=25:
```
sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \
-python-virtualenv libffi-devel openssl-devel
+python3-virtualenv libffi-devel openssl-devel
sudo yum groupinstall "Development Tools"
```

changelog.d/6925.doc (new file)
@@ -0,0 +1 @@
Updated CentOS8 install instructions. Contributed by Richard Kellner.

changelog.d/7026.removal (new file)
@@ -0,0 +1 @@
Remove the unused query_auth federation endpoint per MSC2451.

changelog.d/7034.removal (new file)
@@ -0,0 +1 @@
Remove special handling of aliases events from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260) added in v1.10.0rc1.

changelog.d/7044.bugfix (new file)
@@ -0,0 +1 @@
Fix a bug that renders UTF-8 text files incorrectly when loaded from media. Contributed by @TheStranjer.

changelog.d/7058.feature (new file)
@@ -0,0 +1 @@
Render a configurable and comprehensible error page if something goes wrong during the SAML2 authentication process.

changelog.d/7063.misc (new file)
@@ -0,0 +1 @@
Add type annotations and comments to the auth handler.

changelog.d/7066.bugfix (new file)
@@ -0,0 +1 @@
Fix a bug that would cause Synapse to respond with an error about event visibility if a client tried to request the state of a room at a given token.

changelog.d/7067.feature (new file)
@@ -0,0 +1 @@
Render a configurable and comprehensible error page if something goes wrong during the SAML2 authentication process.

changelog.d/7070.bugfix (new file)
@@ -0,0 +1 @@
Repair a data-corruption issue which was introduced in Synapse 1.10, and fixed in Synapse 1.11, and which could cause `/sync` to return with 404 errors about missing events and unknown rooms.

changelog.d/7074.bugfix (new file)
@@ -0,0 +1 @@
Fix a bug causing account validity renewal emails to be sent even if the feature is turned off in some cases.

changelog.d/7085.feature (new file)
@@ -0,0 +1 @@
Add an optional parameter to control whether other sessions are logged out when a user's password is modified.

changelog.d/7094.misc (new file)
@@ -0,0 +1 @@
Improve performance when making HTTPS requests to sygnal, sydent, etc, by sharing the SSL context object between connections.

changelog.d/7095.misc (new file)
@@ -0,0 +1 @@
Attempt to improve performance of state res v2 algorithm.

@@ -38,6 +38,7 @@ The parameter ``threepids`` is optional.
The parameter ``avatar_url`` is optional.
The parameter ``admin`` is optional and defaults to 'false'.
The parameter ``deactivated`` is optional and defaults to 'false'.
+The parameter ``password`` is optional. If provided the user's password is updated and all devices are logged out.
If the user already exists then optional parameters default to the current value.
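
For illustration only (not part of the diff): a sketch of calling the create/modify endpoint documented above with the new ``password`` field. The endpoint path, homeserver URL and token are assumptions for this example.

```
import requests

# Hypothetical values; substitute a real homeserver, admin access token and user ID.
resp = requests.put(
    "https://homeserver.example/_synapse/admin/v2/users/@alice:example.com",
    headers={"Authorization": "Bearer <admin_access_token>"},
    # "password" is optional: if present, the password is updated and all devices are logged out.
    json={"password": "<new_secret>"},
)
print(resp.status_code, resp.json())
```
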
List Accounts
@@ -168,11 +169,14 @@ with a body of:
.. code:: json
{
-"new_password": "<secret>"
+"new_password": "<secret>",
+"logout_devices": true,
}
including an ``access_token`` of a server admin.
+The parameter ``new_password`` is required.
+The parameter ``logout_devices`` is optional and defaults to ``true``.
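
As a hedged sketch (not part of the diff), this is how the password-reset endpoint documented above could be called with the new ``logout_devices`` flag; the endpoint path and token are assumptions.

```
import requests

# Hypothetical values; passing logout_devices=False keeps the user's other sessions alive.
resp = requests.post(
    "https://homeserver.example/_synapse/admin/v1/reset_password/@alice:example.com",
    headers={"Authorization": "Bearer <admin_access_token>"},
    json={"new_password": "<secret>", "logout_devices": False},
)
resp.raise_for_status()
```
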
Get whether a user is a server administrator or not
===================================================

@@ -1347,6 +1347,25 @@ saml2_config:
#
#grandfathered_mxid_source_attribute: upn
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
# This template doesn't currently need any variable to render.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
# Enable CAS for registration and login.

@@ -15,6 +15,9 @@
# limitations under the License.
import logging
+import os
+import pkg_resources
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
@@ -160,6 +163,14 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "5m")
)
template_dir = saml2_config.get("template_dir")
if not template_dir:
template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
self.saml2_error_html_content = self.read_file(
os.path.join(template_dir, "saml_error.html"), "saml2_config.saml_error",
)
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
):
@@ -325,6 +336,25 @@ class SAML2Config(Config):
# The default is 'uid'.
#
#grandfathered_mxid_source_attribute: upn
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
# This template doesn't currently need any variable to render.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
""" % { """ % {
"config_dir_path": config_dir_path "config_dir_path": config_dir_path
} }

@@ -75,7 +75,7 @@ class ServerContextFactory(ContextFactory):
@implementer(IPolicyForHTTPS)
-class ClientTLSOptionsFactory(object):
+class FederationPolicyForHTTPS(object):
"""Factory for Twisted SSLClientConnectionCreators that are used to make connections
to remote servers for federation.
@@ -103,15 +103,15 @@ class ClientTLSOptionsFactory(object):
# let us do).
minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
-self._verify_ssl = CertificateOptions(
+_verify_ssl = CertificateOptions(
trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
)
-self._verify_ssl_context = self._verify_ssl.getContext()
+self._verify_ssl_context = _verify_ssl.getContext()
-self._verify_ssl_context.set_info_callback(self._context_info_cb)
+self._verify_ssl_context.set_info_callback(_context_info_cb)
-self._no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS)
+_no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS)
-self._no_verify_ssl_context = self._no_verify_ssl.getContext()
+self._no_verify_ssl_context = _no_verify_ssl.getContext()
-self._no_verify_ssl_context.set_info_callback(self._context_info_cb)
+self._no_verify_ssl_context.set_info_callback(_context_info_cb)
def get_options(self, host: bytes):
@@ -136,23 +136,6 @@ class ClientTLSOptionsFactory(object):
return SSLClientConnectionCreator(host, ssl_context, should_verify)
@staticmethod
def _context_info_cb(ssl_connection, where, ret):
"""The 'information callback' for our openssl context object."""
# we assume that the app_data on the connection object has been set to
# a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator)
tls_protocol = ssl_connection.get_app_data()
try:
# ... we further assume that SSLClientConnectionCreator has set the
# '_synapse_tls_verifier' attribute to a ConnectionVerifier object.
tls_protocol._synapse_tls_verifier.verify_context_info_cb(
ssl_connection, where
)
except: # noqa: E722, taken from the twisted implementation
logger.exception("Error during info_callback")
f = Failure()
tls_protocol.failVerification(f)
def creatorForNetloc(self, hostname, port):
"""Implements the IPolicyForHTTPS interace so that this can be passed
directly to agents.
@@ -160,6 +143,43 @@ class ClientTLSOptionsFactory(object):
return self.get_options(hostname)
@implementer(IPolicyForHTTPS)
class RegularPolicyForHTTPS(object):
"""Factory for Twisted SSLClientConnectionCreators that are used to make connections
to remote servers, for other than federation.
Always uses the same OpenSSL context object, which uses the default OpenSSL CA
trust root.
"""
def __init__(self):
trust_root = platformTrust()
self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext()
self._ssl_context.set_info_callback(_context_info_cb)
def creatorForNetloc(self, hostname, port):
return SSLClientConnectionCreator(hostname, self._ssl_context, True)
def _context_info_cb(ssl_connection, where, ret):
"""The 'information callback' for our openssl context objects.
Note: Once this is set as the info callback on a Context object, the Context should
only be used with the SSLClientConnectionCreator.
"""
# we assume that the app_data on the connection object has been set to
# a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator)
tls_protocol = ssl_connection.get_app_data()
try:
# ... we further assume that SSLClientConnectionCreator has set the
# '_synapse_tls_verifier' attribute to a ConnectionVerifier object.
tls_protocol._synapse_tls_verifier.verify_context_info_cb(ssl_connection, where)
except: # noqa: E722, taken from the twisted implementation
logger.exception("Error during info_callback")
f = Failure()
tls_protocol.failVerification(f)
@implementer(IOpenSSLClientConnectionCreator)
class SSLClientConnectionCreator(object):
"""Creates openssl connection objects for client connections.

@@ -39,10 +39,8 @@ from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
make_deferred_yieldable,
-preserve_fn,
)
from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import unwrapFirstError
logger = logging.getLogger(__name__)
@@ -57,86 +55,6 @@ class FederationBase(object):
self.store = hs.get_datastore()
self._clock = hs.get_clock()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: List[EventBase],
room_version: str,
outlier: bool = False,
include_none: bool = False,
):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
If a PDU fails its content hash check then it is redacted.
The given list of PDUs are not modified, instead the function returns
a new list.
Args:
origin
pdu
room_version
outlier: Whether the events are outliers or not
include_none: Whether to include None in the returned list
for events that have failed their checks
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
res = yield make_deferred_yieldable(deferred)
except SynapseError:
res = None
if not res:
# Check local db.
res = yield self.store.get_event(
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
try:
# This should not exist in the base implementation, until
# this is fixed, ignore it for typing. See issue #6997.
res = yield defer.ensureDeferred(
self.get_pdu( # type: ignore
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
timeout=10000,
)
)
except SynapseError:
pass
if not res:
logger.warning(
"Failed to find copy of %s with valid signature", pdu.event_id
)
return res
handle = preserve_fn(handle_check_result)
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
valid_pdus = yield make_deferred_yieldable(
defer.gatherResults(deferreds2, consumeErrors=True)
).addErrback(unwrapFirstError)
if include_none:
return valid_pdus
else:
return [p for p in valid_pdus if p]
def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]

@@ -33,6 +33,7 @@ from typing import (
from prometheus_client import Counter
from twisted.internet import defer
+from twisted.internet.defer import Deferred
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
@@ -51,7 +52,7 @@ from synapse.api.room_versions import (
)
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
@@ -345,6 +346,83 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids
async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: List[EventBase],
room_version: str,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
If a PDU fails its content hash check then it is redacted.
The given list of PDUs are not modified, instead the function returns
a new list.
Args:
origin
pdu
room_version
outlier: Whether the events are outliers or not
include_none: Whether to include None in the returned list
for events that have failed their checks
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
res = yield make_deferred_yieldable(deferred)
except SynapseError:
res = None
if not res:
# Check local db.
res = yield self.store.get_event(
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
try:
res = yield defer.ensureDeferred(
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version, # type: ignore
outlier=outlier,
timeout=10000,
)
)
except SynapseError:
pass
if not res:
logger.warning(
"Failed to find copy of %s with valid signature", pdu.event_id
)
return res
handle = preserve_fn(handle_check_result)
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
valid_pdus = await make_deferred_yieldable(
defer.gatherResults(deferreds2, consumeErrors=True)
).addErrback(unwrapFirstError)
if include_none:
return valid_pdus
else:
return [p for p in valid_pdus if p]
async def get_event_auth(self, destination, room_id, event_id):
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)

@@ -470,57 +470,6 @@ class FederationServer(FederationBase):
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
return 200, res
async def on_query_auth_request(self, origin, content, room_id, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
missing (list): A list of event_ids indicating what the other
side (`origin`) think we're missing.
rejects (dict): A mapping from event_id to a 2-tuple of reason
string and a proof (or None) of why the event was rejected.
The keys of this dict give the list of events the `origin` has
rejected.
Args:
origin (str)
content (dict)
event_id (str)
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
room_version = await self.store.get_room_version(room_id)
auth_chain = [
event_from_pdu_json(e, room_version) for e in content["auth_chain"]
]
signed_auth = await self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True, room_version=room_version.identifier
)
ret = await self.handler.on_query_auth(
origin,
event_id,
room_id,
signed_auth,
content.get("rejects", []),
content.get("missing", []),
)
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
"rejects": ret.get("rejects", []),
"missing": ret.get("missing", []),
}
return 200, send_content
@log_function
def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content)

@@ -643,17 +643,6 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
return 200, response
class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
async def on_POST(self, origin, content, query, context, event_id):
new_content = await self.handler.on_query_auth_request(
origin, content, context, event_id
)
return 200, new_content
class FederationGetMissingEventsServlet(BaseFederationServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
@@ -1412,7 +1401,6 @@ FEDERATION_SERVLET_CLASSES = (
FederationV2SendLeaveServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
-FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,

@@ -44,7 +44,11 @@ class AccountValidityHandler(object):
self._account_validity = self.hs.config.account_validity
-if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
+if (
+self._account_validity.enabled
+and self._account_validity.renew_by_email_enabled
+and load_jinja2_templates
+):
# Don't do email-specific configuration if renewal by email is disabled.
try:
app_name = self.hs.config.email_app_name

@@ -18,10 +18,10 @@ import logging
import time
import unicodedata
import urllib.parse
-from typing import Any
+from typing import Any, Dict, Iterable, List, Optional
import attr
-import bcrypt
+import bcrypt # type: ignore[import]
import pymacaroons
from twisted.internet import defer
@@ -45,7 +45,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
-from synapse.types import UserID
+from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@@ -63,11 +63,11 @@ class AuthHandler(BaseHandler):
"""
super(AuthHandler, self).__init__(hs)
-self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
+self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
-self.checkers[inst.AUTH_TYPE] = inst
+self.checkers[inst.AUTH_TYPE] = inst # type: ignore
self.bcrypt_rounds = hs.config.bcrypt_rounds
@@ -124,7 +124,9 @@ class AuthHandler(BaseHandler):
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
@defer.inlineCallbacks
-def validate_user_via_ui_auth(self, requester, request_body, clientip):
+def validate_user_via_ui_auth(
+self, requester: Requester, request_body: Dict[str, Any], clientip: str
+):
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -133,11 +135,11 @@ class AuthHandler(BaseHandler):
that it isn't stolen by re-authenticating them.
Args:
-requester (Requester): The user, as given by the access token
+requester: The user, as given by the access token
-request_body (dict): The body of the request sent by the client
+request_body: The body of the request sent by the client
-clientip (str): The IP address of the client.
+clientip: The IP address of the client.
Returns:
defer.Deferred[dict]: the parameters for this request (which may
@@ -208,7 +210,9 @@ class AuthHandler(BaseHandler):
return self.checkers.keys()
@defer.inlineCallbacks
-def check_auth(self, flows, clientdict, clientip):
+def check_auth(
+self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
+):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
@@ -223,14 +227,14 @@ class AuthHandler(BaseHandler):
decorator.
Args:
-flows (list): A list of login flows. Each flow is an ordered list of
+flows: A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
-clientip (str): The IP address of the client.
+clientip: The IP address of the client.
Returns:
defer.Deferred[dict, dict, str]: a deferred tuple of
@@ -250,7 +254,7 @@ class AuthHandler(BaseHandler):
"""
authdict = None
-sid = None
+sid = None # type: Optional[str]
if clientdict and "auth" in clientdict:
authdict = clientdict["auth"]
del clientdict["auth"]
@@ -283,9 +287,9 @@ class AuthHandler(BaseHandler):
creds = session["creds"]
# check auth type currently being presented
-errordict = {}
+errordict = {} # type: Dict[str, Any]
if "type" in authdict:
-login_type = authdict["type"]
+login_type = authdict["type"] # type: str
try:
result = yield self._check_auth_dict(authdict, clientip)
if result:
@@ -326,7 +330,7 @@ class AuthHandler(BaseHandler):
raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
-def add_oob_auth(self, stagetype, authdict, clientip):
+def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
@@ -348,7 +352,7 @@ class AuthHandler(BaseHandler):
return True
return False
-def get_session_id(self, clientdict):
+def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
"""
Gets the session ID for a client given the client dictionary
@@ -356,7 +360,7 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary sent by the client in the request
Returns:
-str|None: The string session ID the client sent. If the client did
+The string session ID the client sent. If the client did
not send a session ID, returns None.
"""
sid = None
@@ -366,40 +370,42 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
-def set_session_data(self, session_id, key, value):
+def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
Args:
-session_id (string): The ID of this session as returned from check_auth
+session_id: The ID of this session as returned from check_auth
-key (string): The key to store the data under
+key: The key to store the data under
-value (any): The data to store
+value: The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
-def get_session_data(self, session_id, key, default=None):
+def get_session_data(
+self, session_id: str, key: str, default: Optional[Any] = None
+) -> Any:
"""
Retrieve data stored with set_session_data
Args:
-session_id (string): The ID of this session as returned from check_auth
+session_id: The ID of this session as returned from check_auth
-key (string): The key to store the data under
+key: The key to store the data under
-default (any): Value to return if the key has not been set
+default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
-def _check_auth_dict(self, authdict, clientip):
+def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
"""Attempt to validate the auth dict provided by a client
Args:
-authdict (object): auth dict provided by the client
+authdict: auth dict provided by the client
-clientip (str): IP address of the client
+clientip: IP address of the client
Returns:
Deferred: result of the stage verification.
@@ -425,10 +431,10 @@ class AuthHandler(BaseHandler):
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
return canonical_id
-def _get_params_recaptcha(self):
+def _get_params_recaptcha(self) -> dict:
return {"public_key": self.hs.config.recaptcha_public_key}
-def _get_params_terms(self):
+def _get_params_terms(self) -> dict:
return {
"policies": {
"privacy_policy": {
@@ -445,7 +451,9 @@ class AuthHandler(BaseHandler):
}
}
-def _auth_dict_for_flows(self, flows, session):
+def _auth_dict_for_flows(
+self, flows: List[List[str]], session: Dict[str, Any]
+) -> Dict[str, Any]:
public_flows = []
for f in flows:
public_flows.append(f)
@@ -455,7 +463,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms,
}
-params = {}
+params = {} # type: Dict[str, Any]
for f in public_flows:
for stage in f:
@@ -468,7 +476,13 @@ class AuthHandler(BaseHandler):
"params": params,
}
-def _get_session_info(self, session_id):
+def _get_session_info(self, session_id: Optional[str]) -> dict:
+"""
+Gets or creates a session given a session ID.
+The session can be used to track data across multiple requests, e.g. for
+interactive authentication.
+"""
if session_id not in self.sessions:
session_id = None
@@ -481,7 +495,9 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
-def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
+def get_access_token_for_user_id(
+self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
+):
"""
Creates a new access token for the user with the given user ID.
@@ -491,11 +507,11 @@ class AuthHandler(BaseHandler):
The device will be recorded in the table if it is not there already.
Args:
-user_id (str): canonical User ID
+user_id: canonical User ID
-device_id (str|None): the device ID to associate with the tokens.
+device_id: the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
-valid_until_ms (int|None): when the token is valid until. None for
+valid_until_ms: when the token is valid until. None for
no expiry.
Returns:
The access token for the user's session.
@@ -530,13 +546,13 @@ class AuthHandler(BaseHandler):
return access_token
@defer.inlineCallbacks
-def check_user_exists(self, user_id):
+def check_user_exists(self, user_id: str):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
Args:
-(unicode|bytes) user_id: complete @user:id
+user_id: complete @user:id
Returns:
defer.Deferred: (unicode) canonical_user_id, or None if zero or
@@ -551,7 +567,7 @@ class AuthHandler(BaseHandler):
return None
@defer.inlineCallbacks
-def _find_user_id_and_pwd_hash(self, user_id):
+def _find_user_id_and_pwd_hash(self, user_id: str):
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
@@ -581,7 +597,7 @@ class AuthHandler(BaseHandler):
)
return result
-def get_supported_login_types(self):
+def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API
By default this is just 'm.login.password' (unless password_enabled is
@@ -589,20 +605,20 @@ class AuthHandler(BaseHandler):
other login types.
Returns:
-Iterable[str]: login types
+login types
"""
return self._supported_login_types
@defer.inlineCallbacks
-def validate_login(self, username, login_submission):
+def validate_login(self, username: str, login_submission: Dict[str, Any]):
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Args:
-username (str): username supplied by the user
+username: username supplied by the user
-login_submission (dict): the whole of the login submission
+login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
Deferred[str, func]: canonical user id, and optional callback
@@ -690,13 +706,13 @@ class AuthHandler(BaseHandler):
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
-def check_password_provider_3pid(self, medium, address, password):
+def check_password_provider_3pid(self, medium: str, address: str, password: str):
"""Check if a password provider is able to validate a thirdparty login
Args:
-medium (str): The medium of the 3pid (ex. email).
+medium: The medium of the 3pid (ex. email).
-address (str): The address of the 3pid (ex. jdoe@example.com).
+address: The address of the 3pid (ex. jdoe@example.com).
-password (str): The password of the user.
+password: The password of the user.
Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id,
@@ -724,15 +740,15 @@ class AuthHandler(BaseHandler):
return None, None
@defer.inlineCallbacks
-def _check_local_password(self, user_id, password):
+def _check_local_password(self, user_id: str, password: str):
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
multiple inexact matches.
Args:
-user_id (unicode): complete @user:id
+user_id: complete @user:id
-password (unicode): the provided password
+password: the provided password
Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
@@ -755,7 +771,7 @@ class AuthHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
-def validate_short_term_login_token_and_get_user_id(self, login_token):
+def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
try:
@@ -769,11 +785,11 @@ class AuthHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
-def delete_access_token(self, access_token):
+def delete_access_token(self, access_token: str):
"""Invalidate a single access token
Args:
-access_token (str): access token to be deleted
+access_token: access token to be deleted
Returns:
Deferred
@@ -798,15 +814,17 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def delete_access_tokens_for_user(
-self, user_id, except_token_id=None, device_id=None
+self,
+user_id: str,
+except_token_id: Optional[str] = None,
+device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user
Args:
-user_id (str): ID of user the tokens belong to
+user_id: ID of user the tokens belong to
-except_token_id (str|None): access_token ID which should *not* be
-deleted
+except_token_id: access_token ID which should *not* be deleted
-device_id (str|None): ID of device the tokens are associated with.
+device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
@@ -830,7 +848,7 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
-def add_threepid(self, user_id, medium, address, validated_at):
+def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -856,19 +874,20 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
-def delete_threepid(self, user_id, medium, address, id_server=None):
+def delete_threepid(
+self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
+):
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
Args:
-user_id (str)
+user_id: ID of user to remove the 3pid from.
-medium (str)
+medium: The medium of the 3pid being removed: "email" or "msisdn".
-address (str)
+address: The 3pid address to remove.
-id_server (str|None): Use the given identity server when unbinding
+id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the
@@ -887,17 +906,18 @@ class AuthHandler(BaseHandler):
yield self.store.user_delete_threepid(user_id, medium, address)
return result
-def _save_session(self, session):
+def _save_session(self, session: Dict[str, Any]) -> None:
+"""Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
-def hash(self, password):
+def hash(self, password: str):
"""Computes a secure hash of password.
Args:
-password (unicode): Password to hash.
+password: Password to hash.
Returns:
Deferred(unicode): Hashed password.
@@ -914,12 +934,12 @@ class AuthHandler(BaseHandler):
return defer_to_thread(self.hs.get_reactor(), _do_hash)
-def validate_hash(self, password, stored_hash):
+def validate_hash(self, password: str, stored_hash: bytes):
"""Validates that self.hash(password) == stored_hash.
Args:
-password (unicode): Password to hash.
+password: Password to hash.
-stored_hash (bytes): Expected hash value.
+stored_hash: Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
@@ -1007,7 +1027,9 @@ class MacaroonGenerator(object):
hs = attr.ib()
-def generate_access_token(self, user_id, extra_caveats=None):
+def generate_access_token(
+self, user_id: str, extra_caveats: Optional[List[str]] = None
+) -> str:
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
@@ -1020,16 +1042,9 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
-def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
-"""
-Args:
-user_id (unicode):
-duration_in_ms (int):
-Returns:
-unicode
-"""
+def generate_short_term_login_token(
+self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
@@ -1037,12 +1052,12 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
-def generate_delete_pusher_token(self, user_id):
+def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
-def _generate_base_macaroon(self, user_id):
+def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",

@@ -292,16 +292,6 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
-new_pl_content = copy_power_levels_contents(old_room_pl_state.content)
-# pre-msc2260 rooms may not have the right setting for aliases. If no other
-# value is set, set it now.
-events_default = new_pl_content.get("events_default", 0)
-new_pl_content.setdefault("events", {}).setdefault(
-EventTypes.Aliases, events_default
-)
-logger.debug("Setting correct PLs in new room to %s", new_pl_content)
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
@@ -309,7 +299,7 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
-"content": new_pl_content,
+"content": old_room_pl_state.content,
},
ratelimit=False,
)
@@ -814,10 +804,6 @@ class RoomCreationHandler(BaseHandler):
EventTypes.RoomHistoryVisibility: 100,
EventTypes.CanonicalAlias: 50,
EventTypes.RoomAvatar: 50,
-# MSC2260: Allow everybody to send alias events by default
-# This will be reudundant on pre-MSC2260 rooms, since the
-# aliases event is special-cased.
-EventTypes.Aliases: 0,
EventTypes.Tombstone: 100,
EventTypes.ServerACL: 100,
},

@@ -23,6 +23,7 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
+from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
from synapse.types import (
@@ -73,6 +74,8 @@ class SamlHandler:
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
+self._error_html_content = hs.config.saml2_error_html_content
def handle_redirect_request(self, client_redirect_url):
"""Handle an incoming request to /login/sso/redirect
@@ -114,7 +117,22 @@ class SamlHandler:
# the dict.
self.expire_sessions()
-user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+try:
+user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+except Exception as e:
+# If decoding the response or mapping it to a user failed, then log the
+# error and tell the user that something went wrong.
+logger.error(e)
+request.setResponseCode(400)
+request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+request.setHeader(
+b"Content-Length", b"%d" % (len(self._error_html_content),)
+)
+request.write(self._error_html_content.encode("utf8"))
+finish_request(request)
+return
self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):

@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Optional
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.types import Requester
from ._base import BaseHandler
@@ -32,14 +34,17 @@ class SetPasswordHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks
-def set_password(self, user_id, newpassword, requester=None):
+def set_password(
+self,
+user_id: str,
+new_password: str,
+logout_devices: bool,
+requester: Optional[Requester] = None,
+):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
-password_hash = yield self._auth_handler.hash(newpassword)
+password_hash = yield self._auth_handler.hash(new_password)
-except_device_id = requester.device_id if requester else None
-except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
@@ -48,14 +53,18 @@ class SetPasswordHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
-# we want to log out all of the user's other sessions. First delete
-# all his other devices.
-yield self._device_handler.delete_all_devices_for_user(
-user_id, except_device_id=except_device_id
-)
-# and now delete any access tokens which weren't associated with
-# devices (or were associated with this device).
-yield self._auth_handler.delete_access_tokens_for_user(
-user_id, except_token_id=except_access_token_id
-)
+# Optionally, log out all of the user's other sessions.
+if logout_devices:
+except_device_id = requester.device_id if requester else None
+except_access_token_id = requester.access_token_id if requester else None
+# First delete all of their other devices.
+yield self._device_handler.delete_all_devices_for_user(
+user_id, except_device_id=except_device_id
+)
+# and now delete any access tokens which weren't associated with
+# devices (or were associated with this device).
+yield self._auth_handler.delete_access_tokens_for_user(
+user_id, except_token_id=except_access_token_id
+)

@@ -244,9 +244,6 @@ class SimpleHttpClient(object):
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
pool.cachedConnectionTimeout = 2 * 60
-# The default context factory in Twisted 14.0.0 (which we require) is
-# BrowserLikePolicyForHTTPS which will do regular cert validation
-# 'like a browser'
self.agent = ProxyAgent(
self.reactor,
connectTimeout=15,

@@ -45,7 +45,7 @@ class MatrixFederationAgent(object):
Args:
reactor (IReactor): twisted reactor to use for underlying requests
-tls_client_options_factory (ClientTLSOptionsFactory|None):
+tls_client_options_factory (FederationPolicyForHTTPS|None):
factory to use for fetching client tls options, or none to disable TLS.
_srv_resolver (SrvResolver|None):

@@ -210,7 +210,7 @@ class LoggingContext(object):
class Sentinel(object):
"""Sentinel to represent the root context"""
-__slots__ = ["previous_context", "alive", "request", "scope"]
+__slots__ = ["previous_context", "alive", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
@@ -218,6 +218,7 @@ class LoggingContext(object):
self.alive = None
self.request = None
self.scope = None
+self.tag = None
def __str__(self):
return "sentinel"
@@ -511,7 +512,7 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
-def __init__(self, new_context: Optional[LoggingContext] = None) -> None:
+def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
if new_context is None:
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
else:

@@ -7,7 +7,7 @@
<body>
<p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p>
<p>
-If you're seeing this page after clicking a link sent to you via email, make
+If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
validation link in the same client you're logging in from.
</p>
@@ -24,19 +24,22 @@
// we just don't print anything specific.
let searchStr = "";
if (window.location.search) {
-// For some reason window.location.searchParams isn't always defined when
-// window.location.search is, so we can't just use it right away.
+// window.location.searchParams isn't always defined when
+// window.location.search is, so it's more reliable to parse the latter.
searchStr = window.location.search;
} else if (window.location.hash) {
-//
+// Replace the # with a ? so that URLSearchParams does the right thing and
+// doesn't parse the first parameter incorrectly.
searchStr = window.location.hash.replace("#", "?");
}
+// We might end up with no error in the URL, so we need to check if we have one
+// to print one.
let errorDesc = new URLSearchParams(searchStr).get("error_description")
if (errorDesc) {
-document.getElementById("errormsg").innerHTML = ` ("${errorDesc}")`;
+document.getElementById("errormsg").innerText = ` ("${errorDesc}")`;
}
</script>
</body>
</html>

View file

@ -221,8 +221,9 @@ class UserRestServletV2(RestServlet):
raise SynapseError(400, "Invalid password") raise SynapseError(400, "Invalid password")
else: else:
new_password = body["password"] new_password = body["password"]
logout_devices = True
await self.set_password_handler.set_password( await self.set_password_handler.set_password(
target_user.to_string(), new_password, requester target_user.to_string(), new_password, logout_devices, requester
) )
if "deactivated" in body: if "deactivated" in body:
@ -536,9 +537,10 @@ class ResetPasswordRestServlet(RestServlet):
params = parse_json_object_from_request(request) params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["new_password"]) assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"] new_password = params["new_password"]
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password( await self._set_password_handler.set_password(
target_user_id, new_password, requester target_user_id, new_password, logout_devices, requester
) )
return 200, {} return 200, {}
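The hunk above threads a new `logout_devices` flag through the admin password-reset path, defaulting to `True` so the old behaviour of revoking every other session is preserved. A minimal sketch of how a caller might opt out; the `/_synapse/admin/v1/reset_password/<user_id>` path and bearer-token header are assumptions, and only the `new_password` and `logout_devices` fields come from the diff itself:

```python
# Hedged illustration only: the endpoint path and auth header are assumptions;
# the JSON fields mirror the logout_devices handling added above.
import requests

resp = requests.post(
    "https://homeserver.example/_synapse/admin/v1/reset_password/@alice:example.com",
    headers={"Authorization": "Bearer ADMIN_ACCESS_TOKEN"},
    json={"new_password": "correct-horse-battery-staple", "logout_devices": False},
)
resp.raise_for_status()  # on success, the user's other sessions stay logged in
```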


@ -189,12 +189,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
if event_type == EventTypes.Aliases:
# MSC2260
raise SynapseError(
400, "Cannot send m.room.aliases events via /rooms/{room_id}/state"
)
event_dict = { event_dict = {
"type": event_type, "type": event_type,
"content": content, "content": content,
@ -242,12 +236,6 @@ class RoomSendEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
if event_type == EventTypes.Aliases:
# MSC2260
raise SynapseError(
400, "Cannot send m.room.aliases events via /rooms/{room_id}/send"
)
event_dict = { event_dict = {
"type": event_type, "type": event_type,
"content": content, "content": content,


@ -265,8 +265,11 @@ class PasswordRestServlet(RestServlet):
assert_params_in_dict(params, ["new_password"]) assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"] new_password = params["new_password"]
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(user_id, new_password, requester) await self._set_password_handler.set_password(
user_id, new_password, logout_devices, requester
)
return 200, {} return 200, {}
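The client-facing password-change servlet gains the same opt-out. A sketch of a request body that keeps other sessions alive; apart from `new_password` and `logout_devices`, which appear in the hunk above, the shape (including the elided user-interactive auth object) is an assumption based on the client-server API:

```python
# Hypothetical body for POST /_matrix/client/r0/account/password.
body = {
    "new_password": "correct-horse-battery-staple",
    "logout_devices": False,  # keep the user's other devices logged in
    # "auth": {...},          # user-interactive auth object, elided here
}
```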


@ -30,6 +30,22 @@ from synapse.util.stringutils import is_ascii
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# list all text content types that will have the charset default to UTF-8 when
# none is given
TEXT_CONTENT_TYPES = [
"text/css",
"text/csv",
"text/html",
"text/calendar",
"text/plain",
"text/javascript",
"application/json",
"application/ld+json",
"application/rtf",
"image/svg+xml",
"text/xml",
]
def parse_media_id(request): def parse_media_id(request):
try: try:
@ -96,7 +112,14 @@ def add_file_headers(request, media_type, file_size, upload_name):
def _quote(x): def _quote(x):
return urllib.parse.quote(x.encode("utf-8")) return urllib.parse.quote(x.encode("utf-8"))
request.setHeader(b"Content-Type", media_type.encode("UTF-8")) # Default to a UTF-8 charset for text content types.
# e.g. uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16'
if media_type.lower() in TEXT_CONTENT_TYPES:
content_type = media_type + "; charset=UTF-8"
else:
content_type = media_type
request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
if upload_name: if upload_name:
# RFC6266 section 4.1 [1] defines both `filename` and `filename*`. # RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
# #
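To make the charset defaulting above concrete, here is a small standalone restatement of the rule (an abridged type list and a throwaway helper, not Synapse code): bare text-like types get `; charset=UTF-8` appended, while anything else, including types that already carry a charset parameter, passes through unchanged.

```python
# Illustrative sketch only, with an abridged content-type list.
TEXT_CONTENT_TYPES = ["text/css", "text/csv", "text/html", "text/plain"]

def content_type_with_charset(media_type: str) -> str:
    if media_type.lower() in TEXT_CONTENT_TYPES:
        return media_type + "; charset=UTF-8"
    return media_type

assert content_type_with_charset("text/plain") == "text/plain; charset=UTF-8"
assert content_type_with_charset("text/css; charset=UTF-16") == "text/css; charset=UTF-16"
assert content_type_with_charset("image/png") == "image/png"
```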


@ -14,7 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.http.server import DirectServeResource, wrap_html_request_handler from synapse.http.server import (
DirectServeResource,
finish_request,
wrap_html_request_handler,
)
class SAML2ResponseResource(DirectServeResource): class SAML2ResponseResource(DirectServeResource):
@ -24,8 +28,20 @@ class SAML2ResponseResource(DirectServeResource):
def __init__(self, hs): def __init__(self, hs):
super().__init__() super().__init__()
self._error_html_content = hs.config.saml2_error_html_content
self._saml_handler = hs.get_saml_handler() self._saml_handler = hs.get_saml_handler()
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
request.setResponseCode(400)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(self._error_html_content),))
request.write(self._error_html_content.encode("utf8"))
finish_request(request)
@wrap_html_request_handler @wrap_html_request_handler
async def _async_render_POST(self, request): async def _async_render_POST(self, request):
return await self._saml_handler.handle_saml_response(request) return await self._saml_handler.handle_saml_response(request)


@ -26,7 +26,6 @@ import logging
import os import os
from twisted.mail.smtp import sendmail from twisted.mail.smtp import sendmail
from twisted.web.client import BrowserLikePolicyForHTTPS
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.filtering import Filtering from synapse.api.filtering import Filtering
@ -35,6 +34,7 @@ from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.crypto.context_factory import RegularPolicyForHTTPS
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker from synapse.events.spamcheck import SpamChecker
@ -310,7 +310,7 @@ class HomeServer(object):
return ( return (
InsecureInterceptableContextFactory() InsecureInterceptableContextFactory()
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS() else RegularPolicyForHTTPS()
) )
def build_simple_http_client(self): def build_simple_http_client(self):
@ -420,7 +420,7 @@ class HomeServer(object):
return PusherPool(self) return PusherPool(self)
def build_http_client(self): def build_http_client(self):
tls_client_options_factory = context_factory.ClientTLSOptionsFactory( tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
self.config self.config
) )
return MatrixFederationHttpClient(self, tls_client_options_factory) return MatrixFederationHttpClient(self, tls_client_options_factory)
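One motivation for replacing the per-call `BrowserLikePolicyForHTTPS()` with Synapse's own `RegularPolicyForHTTPS` is that the policy, and the OpenSSL context behind it, can be created once and shared across outbound connections instead of being rebuilt per request. The same "build once, reuse everywhere" pattern, sketched with the standard library purely as an illustration:

```python
# Illustration of reusing a single TLS context across requests; not Synapse code.
import ssl
import urllib.request

SHARED_CTX = ssl.create_default_context()  # created a single time at startup

def fetch(url: str) -> bytes:
    # every request reuses the same verified SSL context object
    with urllib.request.urlopen(url, context=SHARED_CTX) as resp:
        return resp.read()
```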


@ -662,28 +662,16 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected, allow_rejected=allow_rejected,
) )
def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]): def get_auth_chain_difference(self, state_sets: List[Set[str]]):
"""Gets the full auth chain for a set of events (including rejected """Given sets of state events figure out the auth chain difference (as
events). per state res v2 algorithm).
Includes the given event IDs in the result.
Note that:
1. All events must be state events.
2. For v1 rooms this may not have the full auth chain in the
presence of rejected events
Args:
event_ids: The event IDs of the events to fetch the auth chain for.
Must be state events.
ignore_events: Set of events to exclude from the returned auth
chain.
This is equivalent to fetching the full auth chain for each set of state
and returning the events that don't appear in each and every auth
chain.
Returns: Returns:
Deferred[list[str]]: List of event IDs of the auth chain. Deferred[Set[str]]: Set of event IDs.
""" """
return self.store.get_auth_chain_ids( return self.store.get_auth_chain_difference(state_sets)
event_ids, include_given=True, ignore_events=ignore_events,
)


@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Returns: Returns:
Deferred[set[str]]: Set of event IDs Deferred[set[str]]: Set of event IDs
""" """
common = set(itervalues(state_sets[0])).intersection(
*(itervalues(s) for s in state_sets[1:]) difference = yield state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
) )
auth_sets = [] return difference
for state_set in state_sets:
auth_ids = {
eid
for key, eid in iteritems(state_set)
if (
key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
or key
in (
(EventTypes.PowerLevels, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
)
)
and eid not in common
}
auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
auth_ids.update(auth_chain)
auth_sets.append(auth_ids)
intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
union = set().union(*auth_sets)
return union - intersection
def _seperate(state_sets): def _seperate(state_sets):
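For reference, the semantics being delegated to the store are unchanged: the auth chain difference is the union of the full (inclusive) auth chains of each state set minus their intersection. A hedged in-memory restatement, with `event_to_auth_ids` standing in for whatever lookup the store actually performs:

```python
from typing import Dict, Iterable, List, Set

def _full_auth_chain(
    event_ids: Iterable[str], event_to_auth_ids: Dict[str, List[str]]
) -> Set[str]:
    # Walk auth edges, including the starting events themselves.
    seen = set()  # type: Set[str]
    stack = list(event_ids)
    while stack:
        eid = stack.pop()
        if eid in seen:
            continue
        seen.add(eid)
        stack.extend(event_to_auth_ids.get(eid, []))
    return seen

def auth_chain_difference(
    state_sets: List[Set[str]], event_to_auth_ids: Dict[str, List[str]]
) -> Set[str]:
    chains = [_full_auth_chain(s, event_to_auth_ids) for s in state_sets]
    return set().union(*chains) - chains[0].intersection(*chains[1:])
```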


@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
import logging import logging
from typing import List, Optional, Set from typing import Dict, List, Optional, Set, Tuple
from six.moves.queue import Empty, PriorityQueue from six.moves.queue import Empty, PriorityQueue
@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results) return list(results)
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
This is equivalent to fetching the full auth chain for each set of state
and returning the events that don't appear in each and every auth
chain.
Returns:
Deferred[Set[str]]
"""
return self.db.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)
def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
) -> Set[str]:
# Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~
#
# The idea here is to basically walk the auth graph of each state set in
# tandem, keeping track of which auth events are reachable by each state
# set. If we reach an auth event we've already visited (via a different
# state set) then we mark that auth event and all ancestors as reachable
# by the state set. This requires that we keep track of the auth chains
# in memory.
#
# Doing it in such a way means that we can stop early if all auth
# events we're currently walking are reachable by all state sets.
#
# *Note*: We can't stop walking an event's auth chain if it is reachable
# by all state sets. This is because other auth chains we're walking
# might be reachable only via the original auth chain. For example,
# given the following auth chain:
#
# A -> C -> D -> E
# / /
# B -´---------´
#
# and state sets {A} and {B} then walking the auth chains of A and B
# would immediately show that C is reachable by both. However, if we
# stopped at C then we'd only reach E via the auth chain of B and so E
# would erroneously get included in the returned difference.
#
# The other thing that we do is limit the number of auth chains we walk
# at once, due to practical limits (i.e. we can only query the database
# with a limited set of parameters). We pick the auth chains we walk
# each iteration based on their depth, in the hope that events with a
# lower depth are likely reachable by those with higher depths.
#
# We could use any ordering that we believe would give a rough
# topological ordering, e.g. origin server timestamp. If the ordering
# chosen is not topological then the algorithm still produces the right
# result, but perhaps a bit more inefficiently. This is why it is safe
# to use "depth" here.
initial_events = set(state_sets[0]).union(*state_sets[1:])
# Dict from events in auth chains to which sets *cannot* reach them.
# I.e. if the set is empty then all sets can reach the event.
event_to_missing_sets = {
event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
for event_id in initial_events
}
# We need to get the depth of the initial events for sorting purposes.
sql = """
SELECT depth, event_id FROM events
WHERE %s
ORDER BY depth ASC
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", initial_events
)
txn.execute(sql % (clause,), args)
# The sorted list of events whose auth chains we should walk.
search = txn.fetchall() # type: List[Tuple[int, str]]
# Map from event to its auth events
event_to_auth_events = {} # type: Dict[str, Set[str]]
base_sql = """
SELECT a.event_id, auth_id, depth
FROM event_auth AS a
INNER JOIN events AS e ON (e.event_id = a.auth_id)
WHERE
"""
while search:
# Check whether all our current walks are reachable by all state
# sets. If so we can bail.
if all(not event_to_missing_sets[eid] for _, eid in search):
break
# Fetch the auth events and their depths for the last N events we're
# currently walking
search, chunk = search[:-100], search[-100:]
clause, args = make_in_list_sql_clause(
txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
)
txn.execute(base_sql + clause, args)
for event_id, auth_event_id, auth_event_depth in txn:
event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
sets = event_to_missing_sets.get(auth_event_id)
if sets is None:
# First time we're seeing this event, so we add it to the
# queue of things to fetch.
search.append((auth_event_depth, auth_event_id))
# Assume that this event is unreachable from any of the
# state sets until proven otherwise
sets = event_to_missing_sets[auth_event_id] = set(
range(len(state_sets))
)
else:
# We've previously seen this event, so look up its auth
# events and recursively mark all ancestors as reachable
# by the current event's state set.
a_ids = event_to_auth_events.get(auth_event_id)
while a_ids:
new_aids = set()
for a_id in a_ids:
event_to_missing_sets[a_id].intersection_update(
event_to_missing_sets[event_id]
)
b = event_to_auth_events.get(a_id)
if b:
new_aids.update(b)
a_ids = new_aids
# Mark that the auth event is reachable by the appropriate sets.
sets.intersection_update(event_to_missing_sets[event_id])
search.sort()
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_in_room(self, room_id): def get_oldest_events_in_room(self, room_id):
return self.db.runInteraction( return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id


@ -29,7 +29,11 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.logging.context import (
LoggingContext,
LoggingContextOrSentinel,
make_deferred_yieldable,
)
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
@ -543,7 +547,9 @@ class Database(object):
Returns: Returns:
Deferred: The result of func Deferred: The result of func
""" """
parent_context = LoggingContext.current_context() parent_context = (
LoggingContext.current_context()
) # type: Optional[LoggingContextOrSentinel]
if parent_context == LoggingContext.sentinel: if parent_context == LoggingContext.sentinel:
logger.warning( logger.warning(
"Starting db connection from sentinel context: metrics will be lost" "Starting db connection from sentinel context: metrics will be lost"


@ -118,30 +118,36 @@ def filter_events_for_client(
the original event if they can see it as normal. the original event if they can see it as normal.
""" """
if event.type == "org.matrix.dummy_event" and filter_send_to_client: # Only run some checks if these events aren't about to be sent to clients. This is
return None # because, if this is not the case, we're probably only checking if the users can
# see events in the room at that point in the DAG, and that shouldn't be decided
# on those checks.
if filter_send_to_client:
if event.type == "org.matrix.dummy_event":
return None
if not event.is_state() and event.sender in ignore_list and filter_send_to_client: if not event.is_state() and event.sender in ignore_list:
return None return None
# Until MSC2261 has landed we can't redact malicious alias events, so for # Until MSC2261 has landed we can't redact malicious alias events, so for
# now we temporarily filter out m.room.aliases entirely to mitigate # now we temporarily filter out m.room.aliases entirely to mitigate
# abuse, while we spec a better solution to advertising aliases # abuse, while we spec a better solution to advertising aliases
# on rooms. # on rooms.
if event.type == EventTypes.Aliases: if event.type == EventTypes.Aliases:
return None return None
# Don't try to apply the room's retention policy if the event is a state event, as # Don't try to apply the room's retention policy if the event is a state
# MSC1763 states that retention is only considered for non-state events. # event, as MSC1763 states that retention is only considered for non-state
if filter_send_to_client and not event.is_state(): # events.
retention_policy = retention_policies[event.room_id] if not event.is_state():
max_lifetime = retention_policy.get("max_lifetime") retention_policy = retention_policies[event.room_id]
max_lifetime = retention_policy.get("max_lifetime")
if max_lifetime is not None: if max_lifetime is not None:
oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
if event.origin_server_ts < oldest_allowed_ts: if event.origin_server_ts < oldest_allowed_ts:
return None return None
if event.event_id in always_include_ids: if event.event_id in always_include_ids:
return event return event


@ -23,7 +23,7 @@ from OpenSSL import SSL
from synapse.config._base import Config, RootConfig from synapse.config._base import Config, RootConfig
from synapse.config.tls import ConfigError, TlsConfig from synapse.config.tls import ConfigError, TlsConfig
from synapse.crypto.context_factory import ClientTLSOptionsFactory from synapse.crypto.context_factory import FederationPolicyForHTTPS
from tests.unittest import TestCase from tests.unittest import TestCase
@ -180,12 +180,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
t = TestConfig() t = TestConfig()
t.read_config(config, config_dir_path="", data_dir_path="") t.read_config(config, config_dir_path="", data_dir_path="")
cf = ClientTLSOptionsFactory(t) cf = FederationPolicyForHTTPS(t)
options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2 # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) self.assertNotEqual(options & SSL.OP_NO_TLSv1, 0)
self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
def test_tls_client_minimum_set_passed_through_1_0(self): def test_tls_client_minimum_set_passed_through_1_0(self):
""" """
@ -195,12 +196,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
t = TestConfig() t = TestConfig()
t.read_config(config, config_dir_path="", data_dir_path="") t.read_config(config, config_dir_path="", data_dir_path="")
cf = ClientTLSOptionsFactory(t) cf = FederationPolicyForHTTPS(t)
options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has not had any of the NO_TLS set. # The context has not had any of the NO_TLS set.
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) self.assertEqual(options & SSL.OP_NO_TLSv1, 0)
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
def test_acme_disabled_in_generated_config_no_acme_domain_provied(self): def test_acme_disabled_in_generated_config_no_acme_domain_provied(self):
""" """
@ -273,7 +275,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
t = TestConfig() t = TestConfig()
t.read_config(config, config_dir_path="", data_dir_path="") t.read_config(config, config_dir_path="", data_dir_path="")
cf = ClientTLSOptionsFactory(t) cf = FederationPolicyForHTTPS(t)
# Not in the whitelist # Not in the whitelist
opts = cf.get_options(b"notexample.com") opts = cf.get_options(b"notexample.com")
@ -282,3 +284,10 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
# Caught by the wildcard # Caught by the wildcard
opts = cf.get_options(idna.encode("テスト.ドメイン.テスト")) opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
self.assertFalse(opts._verifier._verify_certs) self.assertFalse(opts._verifier._verify_certs)
def _get_ssl_context_options(ssl_context: SSL.Context) -> int:
"""get the options bits from an openssl context object"""
# the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
# use the low-level interface
return SSL._lib.SSL_CTX_get_options(ssl_context._context)


@ -31,7 +31,7 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS from twisted.web.iweb import IPolicyForHTTPS
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto.context_factory import ClientTLSOptionsFactory from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server from synapse.http.federation.srv_resolver import Server
from synapse.http.federation.well_known_resolver import ( from synapse.http.federation.well_known_resolver import (
@ -79,7 +79,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self._config = config = HomeServerConfig() self._config = config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "") config.parse_config_dict(config_dict, "", "")
self.tls_factory = ClientTLSOptionsFactory(config) self.tls_factory = FederationPolicyForHTTPS(config)
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
@ -715,7 +715,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
config = default_config("test", parse=True) config = default_config("test", parse=True)
# Build a new agent and WellKnownResolver with a different tls factory # Build a new agent and WellKnownResolver with a different tls factory
tls_factory = ClientTLSOptionsFactory(config) tls_factory = FederationPolicyForHTTPS(config)
agent = MatrixFederationAgent( agent = MatrixFederationAgent(
reactor=self.reactor, reactor=self.reactor,
tls_client_options_factory=tls_factory, tls_client_options_factory=tls_factory,


@ -868,6 +868,13 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Set this new alias as the canonical alias for this room # Set this new alias as the canonical alias for this room
self.helper.send_state(
room_id,
"m.room.aliases",
{"aliases": [test_alias]},
tok=self.admin_user_tok,
state_key="test",
)
self.helper.send_state( self.helper.send_state(
room_id, room_id,
"m.room.canonical_alias", "m.room.canonical_alias",


@ -51,30 +51,26 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.user = self.register_user("user", "test") self.user = self.register_user("user", "test")
self.user_tok = self.login("user", "test") self.user_tok = self.login("user", "test")
def test_cannot_set_alias_via_state_event(self): def test_state_event_not_in_room(self):
self.ensure_user_joined_room() self.ensure_user_left_room()
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( self.set_alias_via_state_event(403)
self.room_id,
self.hs.hostname,
)
data = {"aliases": [self.random_alias(5)]}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.render(request)
self.assertEqual(channel.code, 400, channel.result)
def test_directory_endpoint_not_in_room(self): def test_directory_endpoint_not_in_room(self):
self.ensure_user_left_room() self.ensure_user_left_room()
self.set_alias_via_directory(403) self.set_alias_via_directory(403)
def test_state_event_in_room_too_long(self):
self.ensure_user_joined_room()
self.set_alias_via_state_event(400, alias_length=256)
def test_directory_in_room_too_long(self): def test_directory_in_room_too_long(self):
self.ensure_user_joined_room() self.ensure_user_joined_room()
self.set_alias_via_directory(400, alias_length=256) self.set_alias_via_directory(400, alias_length=256)
def test_state_event_in_room(self):
self.ensure_user_joined_room()
self.set_alias_via_state_event(200)
def test_directory_in_room(self): def test_directory_in_room(self):
self.ensure_user_joined_room() self.ensure_user_joined_room()
self.set_alias_via_directory(200) self.set_alias_via_directory(200)
@ -106,6 +102,21 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.render(request) self.render(request)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
def set_alias_via_state_event(self, expected_code, alias_length=5):
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
self.room_id,
self.hs.hostname,
)
data = {"aliases": [self.random_alias(alias_length)]}
request_data = json.dumps(data)
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def set_alias_via_directory(self, expected_code, alias_length=5): def set_alias_via_directory(self, expected_code, alias_length=5):
url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
data = {"room_id": self.room_id} data = {"room_id": self.room_id}


@ -603,7 +603,7 @@ class TestStateResolutionStore(object):
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
def get_auth_chain(self, event_ids, ignore_events): def _get_auth_chain(self, event_ids):
"""Gets the full auth chain for a set of events (including rejected """Gets the full auth chain for a set of events (including rejected
events). events).
@ -617,9 +617,6 @@ class TestStateResolutionStore(object):
Args: Args:
event_ids (list): The event IDs of the events to fetch the auth event_ids (list): The event IDs of the events to fetch the auth
chain for. Must be state events. chain for. Must be state events.
ignore_events: Set of events to exclude from the returned auth
chain.
Returns: Returns:
Deferred[list[str]]: List of event IDs of the auth chain. Deferred[list[str]]: List of event IDs of the auth chain.
""" """
@ -629,7 +626,7 @@ class TestStateResolutionStore(object):
stack = list(event_ids) stack = list(event_ids)
while stack: while stack:
event_id = stack.pop() event_id = stack.pop()
if event_id in result or event_id in ignore_events: if event_id in result:
continue continue
result.add(event_id) result.add(event_id)
@ -639,3 +636,9 @@ class TestStateResolutionStore(object):
stack.append(aid) stack.append(aid)
return list(result) return list(result)
def get_auth_chain_difference(self, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
return set(chains[0]).union(*chains[1:]) - common


@ -13,19 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import tests.unittest import tests.unittest
import tests.utils import tests.utils
class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_get_prev_events_for_room(self): def test_get_prev_events_for_room(self):
room_id = "@ROOM:local" room_id = "@ROOM:local"
@ -61,15 +56,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
) )
for i in range(0, 20): for i in range(0, 20):
yield self.store.db.runInteraction("insert", insert_event, i) self.get_success(self.store.db.runInteraction("insert", insert_event, i))
# this should get the last ten # this should get the last ten
r = yield self.store.get_prev_events_for_room(room_id) r = self.get_success(self.store.get_prev_events_for_room(room_id))
self.assertEqual(10, len(r)) self.assertEqual(10, len(r))
for i in range(0, 10): for i in range(0, 10):
self.assertEqual("$event_%i:local" % (19 - i), r[i]) self.assertEqual("$event_%i:local" % (19 - i), r[i])
@defer.inlineCallbacks
def test_get_rooms_with_many_extremities(self): def test_get_rooms_with_many_extremities(self):
room1 = "#room1" room1 = "#room1"
room2 = "#room2" room2 = "#room2"
@ -86,25 +80,154 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
) )
for i in range(0, 20): for i in range(0, 20):
yield self.store.db.runInteraction("insert", insert_event, i, room1) self.get_success(
yield self.store.db.runInteraction("insert", insert_event, i, room2) self.store.db.runInteraction("insert", insert_event, i, room1)
yield self.store.db.runInteraction("insert", insert_event, i, room3) )
self.get_success(
self.store.db.runInteraction("insert", insert_event, i, room2)
)
self.get_success(
self.store.db.runInteraction("insert", insert_event, i, room3)
)
# Test simple case # Test simple case
r = yield self.store.get_rooms_with_many_extremities(5, 5, []) r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, []))
self.assertEqual(len(r), 3) self.assertEqual(len(r), 3)
# Does filter work? # Does filter work?
r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1]) r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1]))
self.assertTrue(room2 in r) self.assertTrue(room2 in r)
self.assertTrue(room3 in r) self.assertTrue(room3 in r)
self.assertEqual(len(r), 2) self.assertEqual(len(r), 2)
r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) r = self.get_success(
self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
)
self.assertEqual(r, [room3]) self.assertEqual(r, [room3])
# Does filter and limit work? # Does filter and limit work?
r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1]) r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3]) self.assertTrue(r == [room2] or r == [room3])
def test_auth_difference(self):
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
# where the top events are the most recent.
#
# A B
# \ /
# D E
# \ |
# ` F C
# | /|
# G ´ |
# | \ |
# H I
# | |
# K J
auth_graph = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
"d": ["f"],
"e": ["f"],
"f": ["g"],
"g": ["h", "i"],
"h": ["k"],
"i": ["j"],
"k": [],
"j": [],
}
depth_map = {
"a": 7,
"b": 7,
"c": 4,
"d": 6,
"e": 6,
"f": 5,
"g": 3,
"h": 2,
"i": 2,
"k": 1,
"j": 1,
}
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
def insert_event(txn, event_id, stream_ordering):
depth = depth_map[event_id]
self.store.db.simple_insert_txn(
txn,
table="events",
values={
"event_id": event_id,
"room_id": room_id,
"depth": depth,
"topological_ordering": depth,
"type": "m.test",
"processed": True,
"outlier": False,
"stream_ordering": stream_ordering,
},
)
self.store.db.simple_insert_many_txn(
txn,
table="event_auth",
values=[
{"event_id": event_id, "room_id": room_id, "auth_id": a}
for a in auth_graph[event_id]
],
)
next_stream_ordering = 0
for event_id in auth_graph:
next_stream_ordering += 1
self.get_success(
self.store.db.runInteraction(
"insert", insert_event, event_id, next_stream_ordering
)
)
# Now actually test that various combinations give the right result:
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
)
self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "d", "e"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
)
self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
self.assertSetEqual(difference, set())
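As a sanity check, the `[{a}, {b}, {c}]` assertion above can be reproduced with the naive union-minus-intersection definition on the same toy graph (illustration only, not part of the test suite):

```python
# Cross-check of assertSetEqual(difference, {"a", "b", "c", "e", "f"}) above.
auth_graph = {"a": ["e"], "b": ["e"], "c": ["g", "i"], "d": ["f"], "e": ["f"],
              "f": ["g"], "g": ["h", "i"], "h": ["k"], "i": ["j"], "k": [], "j": []}

def full_chain(event_ids):
    # Full auth chain, including the starting events themselves.
    seen, stack = set(), list(event_ids)
    while stack:
        eid = stack.pop()
        if eid not in seen:
            seen.add(eid)
            stack.extend(auth_graph[eid])
    return seen

chains = [full_chain(s) for s in ({"a"}, {"b"}, {"c"})]
assert set().union(*chains) - chains[0].intersection(*chains[1:]) == {"a", "b", "c", "e", "f"}
```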


@ -185,6 +185,7 @@ commands = mypy \
synapse/federation/federation_client.py \ synapse/federation/federation_client.py \
synapse/federation/sender \ synapse/federation/sender \
synapse/federation/transport \ synapse/federation/transport \
synapse/handlers/auth.py \
synapse/handlers/directory.py \ synapse/handlers/directory.py \
synapse/handlers/presence.py \ synapse/handlers/presence.py \
synapse/handlers/sync.py \ synapse/handlers/sync.py \