Support icons for Identity Providers (#9154)

This commit is contained in:
Richard van der Hoff 2021-01-20 13:15:14 +00:00 committed by GitHub
parent 6c0dfd2e8e
commit 0cd2938bc8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 146 additions and 91 deletions

1
changelog.d/9154.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

View file

@ -1726,6 +1726,10 @@ saml2_config:
# idp_name: A user-facing name for this identity provider, which is used to # idp_name: A user-facing name for this identity provider, which is used to
# offer the user a choice of login mechanisms. # offer the user a choice of login mechanisms.
# #
# idp_icon: An optional icon for this identity provider, which is presented
# by identity picker pages. If given, must be an MXC URI of the format
# mxc://<server-name>/<media-id>
#
# discover: set to 'false' to disable the use of the OIDC discovery mechanism # discover: set to 'false' to disable the use of the OIDC discovery mechanism
# to discover endpoints. Defaults to true. # to discover endpoints. Defaults to true.
# #

View file

@ -100,6 +100,7 @@ files =
synapse/util/async_helpers.py, synapse/util/async_helpers.py,
synapse/util/caches, synapse/util/caches,
synapse/util/metrics.py, synapse/util/metrics.py,
synapse/util/stringutils.py,
tests/replication, tests/replication,
tests/test_utils, tests/test_utils,
tests/handlers/test_password_providers.py, tests/handlers/test_password_providers.py,

View file

@ -23,6 +23,7 @@ from synapse.config._util import validate_config
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import Collection, JsonDict from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_mxc_uri
from ._base import Config, ConfigError from ._base import Config, ConfigError
@ -66,6 +67,10 @@ class OIDCConfig(Config):
# idp_name: A user-facing name for this identity provider, which is used to # idp_name: A user-facing name for this identity provider, which is used to
# offer the user a choice of login mechanisms. # offer the user a choice of login mechanisms.
# #
# idp_icon: An optional icon for this identity provider, which is presented
# by identity picker pages. If given, must be an MXC URI of the format
# mxc://<server-name>/<media-id>
#
# discover: set to 'false' to disable the use of the OIDC discovery mechanism # discover: set to 'false' to disable the use of the OIDC discovery mechanism
# to discover endpoints. Defaults to true. # to discover endpoints. Defaults to true.
# #
@ -207,6 +212,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"properties": { "properties": {
"idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
"idp_name": {"type": "string"}, "idp_name": {"type": "string"},
"idp_icon": {"type": "string"},
"discover": {"type": "boolean"}, "discover": {"type": "boolean"},
"issuer": {"type": "string"}, "issuer": {"type": "string"},
"client_id": {"type": "string"}, "client_id": {"type": "string"},
@ -336,9 +342,20 @@ def _parse_oidc_config_dict(
config_path + ("idp_id",), config_path + ("idp_id",),
) )
# MSC2858 also specifies that the idp_icon must be a valid MXC uri
idp_icon = oidc_config.get("idp_icon")
if idp_icon is not None:
try:
parse_and_validate_mxc_uri(idp_icon)
except ValueError as e:
raise ConfigError(
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
) from e
return OidcProviderConfig( return OidcProviderConfig(
idp_id=idp_id, idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"), idp_name=oidc_config.get("idp_name", "OIDC"),
idp_icon=idp_icon,
discover=oidc_config.get("discover", True), discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"], issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"], client_id=oidc_config["client_id"],
@ -366,6 +383,9 @@ class OidcProviderConfig:
# user-facing name for this identity provider. # user-facing name for this identity provider.
idp_name = attr.ib(type=str) idp_name = attr.ib(type=str)
# Optional MXC URI for icon for this IdP.
idp_icon = attr.ib(type=Optional[str])
# whether the OIDC discovery mechanism is used to discover endpoints # whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool) discover = attr.ib(type=bool)

View file

@ -26,7 +26,7 @@ import yaml
from netaddr import IPSet from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name from synapse.util.stringutils import parse_and_validate_server_name
from ._base import Config, ConfigError from ._base import Config, ConfigError

View file

@ -49,7 +49,6 @@ from synapse.events import EventBase
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import ( from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer

View file

@ -28,7 +28,6 @@ from synapse.api.urls import (
FEDERATION_V1_PREFIX, FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX, FEDERATION_V2_PREFIX,
) )
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_boolean_from_args, parse_boolean_from_args,
@ -45,6 +44,7 @@ from synapse.logging.opentracing import (
) )
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import ThirdPartyInstanceID, get_domain_from_id from synapse.types import ThirdPartyInstanceID, get_domain_from_id
from synapse.util.stringutils import parse_and_validate_server_name
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -80,6 +80,10 @@ class CasHandler:
# user-facing name of this auth provider # user-facing name of this auth provider
self.idp_name = "CAS" self.idp_name = "CAS"
# we do not currently support icons for CAS auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self) self._sso_handler.register_identity_provider(self)

View file

@ -271,6 +271,9 @@ class OidcProvider:
# user-facing name of this auth provider # user-facing name of this auth provider
self.idp_name = provider.idp_name self.idp_name = provider.idp_name
# MXC URI for icon for this auth provider
self.idp_icon = provider.idp_icon
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self) self._sso_handler.register_identity_provider(self)

View file

@ -38,7 +38,6 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents from synapse.events.utils import copy_power_levels_contents
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
@ -55,6 +54,7 @@ from synapse.types import (
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler

View file

@ -78,6 +78,10 @@ class SamlHandler(BaseHandler):
# user-facing name of this auth provider # user-facing name of this auth provider
self.idp_name = "SAML" self.idp_name = "SAML"
# we do not currently support icons for SAML auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
# a map from saml session id to Saml2SessionData object # a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]

View file

@ -75,6 +75,11 @@ class SsoIdentityProvider(Protocol):
def idp_name(self) -> str: def idp_name(self) -> str:
"""User-facing name for this provider""" """User-facing name for this provider"""
@property
def idp_icon(self) -> Optional[str]:
"""Optional MXC URI for user-facing icon"""
return None
@abc.abstractmethod @abc.abstractmethod
async def handle_redirect_request( async def handle_redirect_request(
self, self,

View file

@ -1,79 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
logger = logging.getLogger(__name__)
def parse_server_name(server_name):
"""Split a server name into host/port parts.
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
try:
if server_name[-1] == "]":
# ipv6 literal, hopefully
return server_name, None
domain_port = server_name.rsplit(":", 1)
domain = domain_port[0]
port = int(domain_port[1]) if domain_port[1:] else None
return domain, port
except Exception:
raise ValueError("Invalid server name '%s'" % server_name)
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
def parse_and_validate_server_name(server_name):
"""Split a server name into host/port parts and do some basic validation.
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
host, port = parse_server_name(server_name)
# these tests don't need to be bulletproof as we'll find out soon enough
# if somebody is giving us invalid data. What we *do* need is to be sure
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
if host[0] == "[":
if host[-1] != "]":
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
raise ValueError(
"Server name '%s' contains invalid characters" % (server_name,)
)
return host, port

View file

@ -17,6 +17,9 @@
<li> <li>
<input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}"> <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
<label for="prov{{loop.index}}">{{p.idp_name | e}}</label> <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
{% if p.idp_icon %}
<img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
{% endif %}
</li> </li>
{% endfor %} {% endfor %}
</ul> </ul>

View file

@ -32,7 +32,6 @@ from synapse.api.errors import (
) )
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2 from synapse.events.utils import format_event_for_client_v2
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
@ -47,7 +46,7 @@ from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.stringutils import random_string from synapse.util.stringutils import parse_and_validate_server_name, random_string
if TYPE_CHECKING: if TYPE_CHECKING:
import synapse.server import synapse.server

View file

@ -16,7 +16,6 @@
import collections import collections
import logging import logging
import re
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore
from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import MXC_REGEX
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore):
The local and remote media as a lists of tuples where the key is The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID. the hostname and the value is the media ID.
""" """
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
sql = """ sql = """
SELECT stream_ordering, json FROM events SELECT stream_ordering, json FROM events
JOIN event_json USING (room_id, event_id) JOIN event_json USING (room_id, event_id)
@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore):
for url in (content_url, thumbnail_url): for url in (content_url, thumbnail_url):
if not url: if not url:
continue continue
matches = mxc_re.match(url) matches = MXC_REGEX.match(url)
if matches: if matches:
hostname = matches.group(1) hostname = matches.group(1)
media_id = matches.group(2) media_id = matches.group(2)

View file

@ -37,7 +37,7 @@ from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.endpoint import parse_and_validate_server_name from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService from synapse.appservice.api import ApplicationService

View file

@ -18,6 +18,7 @@ import random
import re import re
import string import string
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional, Tuple
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
# says "there is no grammar for media ids"
#
# The server_name part of this is purposely lax: use parse_and_validate_mxc for
# additional validation.
#
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
# random_string and random_string_with_symbols are used for a range of things, # random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure # some cryptographically important, some less so. We use SystemRandom to make sure
# we get cryptographically-secure randoms. # we get cryptographically-secure randoms.
@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret):
) )
def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
"""Split a server name into host/port parts.
Args:
server_name: server name to parse
Returns:
host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
try:
if server_name[-1] == "]":
# ipv6 literal, hopefully
return server_name, None
domain_port = server_name.rsplit(":", 1)
domain = domain_port[0]
port = int(domain_port[1]) if domain_port[1:] else None
return domain, port
except Exception:
raise ValueError("Invalid server name '%s'" % server_name)
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
"""Split a server name into host/port parts and do some basic validation.
Args:
server_name: server name to parse
Returns:
host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
host, port = parse_server_name(server_name)
# these tests don't need to be bulletproof as we'll find out soon enough
# if somebody is giving us invalid data. What we *do* need is to be sure
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
if host[0] == "[":
if host[-1] != "]":
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
raise ValueError(
"Server name '%s' contains invalid characters" % (server_name,)
)
return host, port
def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
"""Parse the given string as an MXC URI
Checks that the "server name" part is a valid server name
Args:
mxc: the (alleged) MXC URI to be checked
Returns:
hostname, port, media id
Raises:
ValueError if the URI cannot be parsed
"""
m = MXC_REGEX.match(mxc)
if not m:
raise ValueError("mxc URI %r did not match expected format" % (mxc,))
server_name = m.group(1)
media_id = m.group(2)
host, port = parse_and_validate_server_name(server_name)
return host, port, media_id
def shortstr(iterable: Iterable, maxitems: int = 5) -> str: def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
"""If iterable has maxitems or fewer, return the stringification of a list """If iterable has maxitems or fewer, return the stringification of a list
containing those items. containing those items.

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name
from tests import unittest from tests import unittest