mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 15:43:22 +01:00
Support icons for Identity Providers (#9154)
This commit is contained in:
parent
6c0dfd2e8e
commit
0cd2938bc8
19 changed files with 146 additions and 91 deletions
1
changelog.d/9154.feature
Normal file
1
changelog.d/9154.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add support for multiple SSO Identity Providers.
|
|
@ -1726,6 +1726,10 @@ saml2_config:
|
|||
# idp_name: A user-facing name for this identity provider, which is used to
|
||||
# offer the user a choice of login mechanisms.
|
||||
#
|
||||
# idp_icon: An optional icon for this identity provider, which is presented
|
||||
# by identity picker pages. If given, must be an MXC URI of the format
|
||||
# mxc://<server-name>/<media-id>
|
||||
#
|
||||
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
|
||||
# to discover endpoints. Defaults to true.
|
||||
#
|
||||
|
|
1
mypy.ini
1
mypy.ini
|
@ -100,6 +100,7 @@ files =
|
|||
synapse/util/async_helpers.py,
|
||||
synapse/util/caches,
|
||||
synapse/util/metrics.py,
|
||||
synapse/util/stringutils.py,
|
||||
tests/replication,
|
||||
tests/test_utils,
|
||||
tests/handlers/test_password_providers.py,
|
||||
|
|
|
@ -23,6 +23,7 @@ from synapse.config._util import validate_config
|
|||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
from synapse.types import Collection, JsonDict
|
||||
from synapse.util.module_loader import load_module
|
||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
|
@ -66,6 +67,10 @@ class OIDCConfig(Config):
|
|||
# idp_name: A user-facing name for this identity provider, which is used to
|
||||
# offer the user a choice of login mechanisms.
|
||||
#
|
||||
# idp_icon: An optional icon for this identity provider, which is presented
|
||||
# by identity picker pages. If given, must be an MXC URI of the format
|
||||
# mxc://<server-name>/<media-id>
|
||||
#
|
||||
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
|
||||
# to discover endpoints. Defaults to true.
|
||||
#
|
||||
|
@ -207,6 +212,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
|
|||
"properties": {
|
||||
"idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
|
||||
"idp_name": {"type": "string"},
|
||||
"idp_icon": {"type": "string"},
|
||||
"discover": {"type": "boolean"},
|
||||
"issuer": {"type": "string"},
|
||||
"client_id": {"type": "string"},
|
||||
|
@ -336,9 +342,20 @@ def _parse_oidc_config_dict(
|
|||
config_path + ("idp_id",),
|
||||
)
|
||||
|
||||
# MSC2858 also specifies that the idp_icon must be a valid MXC uri
|
||||
idp_icon = oidc_config.get("idp_icon")
|
||||
if idp_icon is not None:
|
||||
try:
|
||||
parse_and_validate_mxc_uri(idp_icon)
|
||||
except ValueError as e:
|
||||
raise ConfigError(
|
||||
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
|
||||
) from e
|
||||
|
||||
return OidcProviderConfig(
|
||||
idp_id=idp_id,
|
||||
idp_name=oidc_config.get("idp_name", "OIDC"),
|
||||
idp_icon=idp_icon,
|
||||
discover=oidc_config.get("discover", True),
|
||||
issuer=oidc_config["issuer"],
|
||||
client_id=oidc_config["client_id"],
|
||||
|
@ -366,6 +383,9 @@ class OidcProviderConfig:
|
|||
# user-facing name for this identity provider.
|
||||
idp_name = attr.ib(type=str)
|
||||
|
||||
# Optional MXC URI for icon for this IdP.
|
||||
idp_icon = attr.ib(type=Optional[str])
|
||||
|
||||
# whether the OIDC discovery mechanism is used to discover endpoints
|
||||
discover = attr.ib(type=bool)
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ import yaml
|
|||
from netaddr import IPSet
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
|
|
|
@ -49,7 +49,6 @@ from synapse.events import EventBase
|
|||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||
from synapse.federation.persistence import TransactionActions
|
||||
from synapse.federation.units import Edu, Transaction
|
||||
from synapse.http.endpoint import parse_server_name
|
||||
from synapse.http.servlet import assert_params_in_dict
|
||||
from synapse.logging.context import (
|
||||
make_deferred_yieldable,
|
||||
|
@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id
|
|||
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.stringutils import parse_server_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
|
|
@ -28,7 +28,6 @@ from synapse.api.urls import (
|
|||
FEDERATION_V1_PREFIX,
|
||||
FEDERATION_V2_PREFIX,
|
||||
)
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import (
|
||||
parse_boolean_from_args,
|
||||
|
@ -45,6 +44,7 @@ from synapse.logging.opentracing import (
|
|||
)
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -80,6 +80,10 @@ class CasHandler:
|
|||
# user-facing name of this auth provider
|
||||
self.idp_name = "CAS"
|
||||
|
||||
# we do not currently support icons for CAS auth, but this is required by
|
||||
# the SsoIdentityProvider protocol type.
|
||||
self.idp_icon = None
|
||||
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
||||
self._sso_handler.register_identity_provider(self)
|
||||
|
|
|
@ -271,6 +271,9 @@ class OidcProvider:
|
|||
# user-facing name of this auth provider
|
||||
self.idp_name = provider.idp_name
|
||||
|
||||
# MXC URI for icon for this auth provider
|
||||
self.idp_icon = provider.idp_icon
|
||||
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
||||
self._sso_handler.register_identity_provider(self)
|
||||
|
|
|
@ -38,7 +38,6 @@ from synapse.api.filtering import Filter
|
|||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import copy_power_levels_contents
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
|
@ -55,6 +54,7 @@ from synapse.types import (
|
|||
from synapse.util import stringutils
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
|
|
@ -78,6 +78,10 @@ class SamlHandler(BaseHandler):
|
|||
# user-facing name of this auth provider
|
||||
self.idp_name = "SAML"
|
||||
|
||||
# we do not currently support icons for SAML auth, but this is required by
|
||||
# the SsoIdentityProvider protocol type.
|
||||
self.idp_icon = None
|
||||
|
||||
# a map from saml session id to Saml2SessionData object
|
||||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
||||
|
||||
|
|
|
@ -75,6 +75,11 @@ class SsoIdentityProvider(Protocol):
|
|||
def idp_name(self) -> str:
|
||||
"""User-facing name for this provider"""
|
||||
|
||||
@property
|
||||
def idp_icon(self) -> Optional[str]:
|
||||
"""Optional MXC URI for user-facing icon"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def handle_redirect_request(
|
||||
self,
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_server_name(server_name):
|
||||
"""Split a server name into host/port parts.
|
||||
|
||||
Args:
|
||||
server_name (str): server name to parse
|
||||
|
||||
Returns:
|
||||
Tuple[str, int|None]: host/port parts.
|
||||
|
||||
Raises:
|
||||
ValueError if the server name could not be parsed.
|
||||
"""
|
||||
try:
|
||||
if server_name[-1] == "]":
|
||||
# ipv6 literal, hopefully
|
||||
return server_name, None
|
||||
|
||||
domain_port = server_name.rsplit(":", 1)
|
||||
domain = domain_port[0]
|
||||
port = int(domain_port[1]) if domain_port[1:] else None
|
||||
return domain, port
|
||||
except Exception:
|
||||
raise ValueError("Invalid server name '%s'" % server_name)
|
||||
|
||||
|
||||
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
|
||||
|
||||
|
||||
def parse_and_validate_server_name(server_name):
|
||||
"""Split a server name into host/port parts and do some basic validation.
|
||||
|
||||
Args:
|
||||
server_name (str): server name to parse
|
||||
|
||||
Returns:
|
||||
Tuple[str, int|None]: host/port parts.
|
||||
|
||||
Raises:
|
||||
ValueError if the server name could not be parsed.
|
||||
"""
|
||||
host, port = parse_server_name(server_name)
|
||||
|
||||
# these tests don't need to be bulletproof as we'll find out soon enough
|
||||
# if somebody is giving us invalid data. What we *do* need is to be sure
|
||||
# that nobody is sneaking IP literals in that look like hostnames, etc.
|
||||
|
||||
# look for ipv6 literals
|
||||
if host[0] == "[":
|
||||
if host[-1] != "]":
|
||||
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
|
||||
return host, port
|
||||
|
||||
# otherwise it should only be alphanumerics.
|
||||
if not VALID_HOST_REGEX.match(host):
|
||||
raise ValueError(
|
||||
"Server name '%s' contains invalid characters" % (server_name,)
|
||||
)
|
||||
|
||||
return host, port
|
|
@ -17,6 +17,9 @@
|
|||
<li>
|
||||
<input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
|
||||
<label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
|
||||
{% if p.idp_icon %}
|
||||
<img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
|
||||
{% endif %}
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
|
|
|
@ -32,7 +32,6 @@ from synapse.api.errors import (
|
|||
)
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.events.utils import format_event_for_client_v2
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
|
@ -47,7 +46,7 @@ from synapse.storage.state import StateFilter
|
|||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.util.stringutils import parse_and_validate_server_name, random_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import synapse.server
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
import collections
|
||||
import logging
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore
|
|||
from synapse.types import JsonDict, ThirdPartyInstanceID
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.stringutils import MXC_REGEX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
The local and remote media as a lists of tuples where the key is
|
||||
the hostname and the value is the media ID.
|
||||
"""
|
||||
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
|
||||
|
||||
sql = """
|
||||
SELECT stream_ordering, json FROM events
|
||||
JOIN event_json USING (room_id, event_id)
|
||||
|
@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
for url in (content_url, thumbnail_url):
|
||||
if not url:
|
||||
continue
|
||||
matches = mxc_re.match(url)
|
||||
matches = MXC_REGEX.match(url)
|
||||
if matches:
|
||||
hostname = matches.group(1)
|
||||
media_id = matches.group(2)
|
||||
|
|
|
@ -37,7 +37,7 @@ from signedjson.key import decode_verify_key_bytes
|
|||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.appservice.api import ApplicationService
|
||||
|
|
|
@ -18,6 +18,7 @@ import random
|
|||
import re
|
||||
import string
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
|
||||
|
@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
|
|||
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
|
||||
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
|
||||
|
||||
# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
|
||||
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
|
||||
# says "there is no grammar for media ids"
|
||||
#
|
||||
# The server_name part of this is purposely lax: use parse_and_validate_mxc for
|
||||
# additional validation.
|
||||
#
|
||||
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
|
||||
|
||||
# random_string and random_string_with_symbols are used for a range of things,
|
||||
# some cryptographically important, some less so. We use SystemRandom to make sure
|
||||
# we get cryptographically-secure randoms.
|
||||
|
@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret):
|
|||
)
|
||||
|
||||
|
||||
def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
|
||||
"""Split a server name into host/port parts.
|
||||
|
||||
Args:
|
||||
server_name: server name to parse
|
||||
|
||||
Returns:
|
||||
host/port parts.
|
||||
|
||||
Raises:
|
||||
ValueError if the server name could not be parsed.
|
||||
"""
|
||||
try:
|
||||
if server_name[-1] == "]":
|
||||
# ipv6 literal, hopefully
|
||||
return server_name, None
|
||||
|
||||
domain_port = server_name.rsplit(":", 1)
|
||||
domain = domain_port[0]
|
||||
port = int(domain_port[1]) if domain_port[1:] else None
|
||||
return domain, port
|
||||
except Exception:
|
||||
raise ValueError("Invalid server name '%s'" % server_name)
|
||||
|
||||
|
||||
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
|
||||
|
||||
|
||||
def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
|
||||
"""Split a server name into host/port parts and do some basic validation.
|
||||
|
||||
Args:
|
||||
server_name: server name to parse
|
||||
|
||||
Returns:
|
||||
host/port parts.
|
||||
|
||||
Raises:
|
||||
ValueError if the server name could not be parsed.
|
||||
"""
|
||||
host, port = parse_server_name(server_name)
|
||||
|
||||
# these tests don't need to be bulletproof as we'll find out soon enough
|
||||
# if somebody is giving us invalid data. What we *do* need is to be sure
|
||||
# that nobody is sneaking IP literals in that look like hostnames, etc.
|
||||
|
||||
# look for ipv6 literals
|
||||
if host[0] == "[":
|
||||
if host[-1] != "]":
|
||||
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
|
||||
return host, port
|
||||
|
||||
# otherwise it should only be alphanumerics.
|
||||
if not VALID_HOST_REGEX.match(host):
|
||||
raise ValueError(
|
||||
"Server name '%s' contains invalid characters" % (server_name,)
|
||||
)
|
||||
|
||||
return host, port
|
||||
|
||||
|
||||
def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
|
||||
"""Parse the given string as an MXC URI
|
||||
|
||||
Checks that the "server name" part is a valid server name
|
||||
|
||||
Args:
|
||||
mxc: the (alleged) MXC URI to be checked
|
||||
Returns:
|
||||
hostname, port, media id
|
||||
Raises:
|
||||
ValueError if the URI cannot be parsed
|
||||
"""
|
||||
m = MXC_REGEX.match(mxc)
|
||||
if not m:
|
||||
raise ValueError("mxc URI %r did not match expected format" % (mxc,))
|
||||
server_name = m.group(1)
|
||||
media_id = m.group(2)
|
||||
host, port = parse_and_validate_server_name(server_name)
|
||||
return host, port, media_id
|
||||
|
||||
|
||||
def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
|
||||
"""If iterable has maxitems or fewer, return the stringification of a list
|
||||
containing those items.
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name
|
||||
from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
|
Loading…
Reference in a new issue