mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-12 04:52:26 +01:00
Support icons for Identity Providers (#9154)
This commit is contained in:
parent
6c0dfd2e8e
commit
0cd2938bc8
19 changed files with 146 additions and 91 deletions
1
changelog.d/9154.feature
Normal file
1
changelog.d/9154.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add support for multiple SSO Identity Providers.
|
|
@ -1726,6 +1726,10 @@ saml2_config:
|
||||||
# idp_name: A user-facing name for this identity provider, which is used to
|
# idp_name: A user-facing name for this identity provider, which is used to
|
||||||
# offer the user a choice of login mechanisms.
|
# offer the user a choice of login mechanisms.
|
||||||
#
|
#
|
||||||
|
# idp_icon: An optional icon for this identity provider, which is presented
|
||||||
|
# by identity picker pages. If given, must be an MXC URI of the format
|
||||||
|
# mxc://<server-name>/<media-id>
|
||||||
|
#
|
||||||
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
|
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
|
||||||
# to discover endpoints. Defaults to true.
|
# to discover endpoints. Defaults to true.
|
||||||
#
|
#
|
||||||
|
|
1
mypy.ini
1
mypy.ini
|
@ -100,6 +100,7 @@ files =
|
||||||
synapse/util/async_helpers.py,
|
synapse/util/async_helpers.py,
|
||||||
synapse/util/caches,
|
synapse/util/caches,
|
||||||
synapse/util/metrics.py,
|
synapse/util/metrics.py,
|
||||||
|
synapse/util/stringutils.py,
|
||||||
tests/replication,
|
tests/replication,
|
||||||
tests/test_utils,
|
tests/test_utils,
|
||||||
tests/handlers/test_password_providers.py,
|
tests/handlers/test_password_providers.py,
|
||||||
|
|
|
@ -23,6 +23,7 @@ from synapse.config._util import validate_config
|
||||||
from synapse.python_dependencies import DependencyException, check_requirements
|
from synapse.python_dependencies import DependencyException, check_requirements
|
||||||
from synapse.types import Collection, JsonDict
|
from synapse.types import Collection, JsonDict
|
||||||
from synapse.util.module_loader import load_module
|
from synapse.util.module_loader import load_module
|
||||||
|
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
|
|
||||||
|
@ -66,6 +67,10 @@ class OIDCConfig(Config):
|
||||||
# idp_name: A user-facing name for this identity provider, which is used to
|
# idp_name: A user-facing name for this identity provider, which is used to
|
||||||
# offer the user a choice of login mechanisms.
|
# offer the user a choice of login mechanisms.
|
||||||
#
|
#
|
||||||
|
# idp_icon: An optional icon for this identity provider, which is presented
|
||||||
|
# by identity picker pages. If given, must be an MXC URI of the format
|
||||||
|
# mxc://<server-name>/<media-id>
|
||||||
|
#
|
||||||
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
|
# discover: set to 'false' to disable the use of the OIDC discovery mechanism
|
||||||
# to discover endpoints. Defaults to true.
|
# to discover endpoints. Defaults to true.
|
||||||
#
|
#
|
||||||
|
@ -207,6 +212,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
|
||||||
"properties": {
|
"properties": {
|
||||||
"idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
|
"idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
|
||||||
"idp_name": {"type": "string"},
|
"idp_name": {"type": "string"},
|
||||||
|
"idp_icon": {"type": "string"},
|
||||||
"discover": {"type": "boolean"},
|
"discover": {"type": "boolean"},
|
||||||
"issuer": {"type": "string"},
|
"issuer": {"type": "string"},
|
||||||
"client_id": {"type": "string"},
|
"client_id": {"type": "string"},
|
||||||
|
@ -336,9 +342,20 @@ def _parse_oidc_config_dict(
|
||||||
config_path + ("idp_id",),
|
config_path + ("idp_id",),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# MSC2858 also specifies that the idp_icon must be a valid MXC uri
|
||||||
|
idp_icon = oidc_config.get("idp_icon")
|
||||||
|
if idp_icon is not None:
|
||||||
|
try:
|
||||||
|
parse_and_validate_mxc_uri(idp_icon)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ConfigError(
|
||||||
|
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
|
||||||
|
) from e
|
||||||
|
|
||||||
return OidcProviderConfig(
|
return OidcProviderConfig(
|
||||||
idp_id=idp_id,
|
idp_id=idp_id,
|
||||||
idp_name=oidc_config.get("idp_name", "OIDC"),
|
idp_name=oidc_config.get("idp_name", "OIDC"),
|
||||||
|
idp_icon=idp_icon,
|
||||||
discover=oidc_config.get("discover", True),
|
discover=oidc_config.get("discover", True),
|
||||||
issuer=oidc_config["issuer"],
|
issuer=oidc_config["issuer"],
|
||||||
client_id=oidc_config["client_id"],
|
client_id=oidc_config["client_id"],
|
||||||
|
@ -366,6 +383,9 @@ class OidcProviderConfig:
|
||||||
# user-facing name for this identity provider.
|
# user-facing name for this identity provider.
|
||||||
idp_name = attr.ib(type=str)
|
idp_name = attr.ib(type=str)
|
||||||
|
|
||||||
|
# Optional MXC URI for icon for this IdP.
|
||||||
|
idp_icon = attr.ib(type=Optional[str])
|
||||||
|
|
||||||
# whether the OIDC discovery mechanism is used to discover endpoints
|
# whether the OIDC discovery mechanism is used to discover endpoints
|
||||||
discover = attr.ib(type=bool)
|
discover = attr.ib(type=bool)
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ import yaml
|
||||||
from netaddr import IPSet
|
from netaddr import IPSet
|
||||||
|
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.http.endpoint import parse_and_validate_server_name
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,6 @@ from synapse.events import EventBase
|
||||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||||
from synapse.federation.persistence import TransactionActions
|
from synapse.federation.persistence import TransactionActions
|
||||||
from synapse.federation.units import Edu, Transaction
|
from synapse.federation.units import Edu, Transaction
|
||||||
from synapse.http.endpoint import parse_server_name
|
|
||||||
from synapse.http.servlet import assert_params_in_dict
|
from synapse.http.servlet import assert_params_in_dict
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
|
@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id
|
||||||
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
||||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
from synapse.util.stringutils import parse_server_name
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
|
@ -28,7 +28,6 @@ from synapse.api.urls import (
|
||||||
FEDERATION_V1_PREFIX,
|
FEDERATION_V1_PREFIX,
|
||||||
FEDERATION_V2_PREFIX,
|
FEDERATION_V2_PREFIX,
|
||||||
)
|
)
|
||||||
from synapse.http.endpoint import parse_and_validate_server_name
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
parse_boolean_from_args,
|
parse_boolean_from_args,
|
||||||
|
@ -45,6 +44,7 @@ from synapse.logging.opentracing import (
|
||||||
)
|
)
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
|
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
|
||||||
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
@ -80,6 +80,10 @@ class CasHandler:
|
||||||
# user-facing name of this auth provider
|
# user-facing name of this auth provider
|
||||||
self.idp_name = "CAS"
|
self.idp_name = "CAS"
|
||||||
|
|
||||||
|
# we do not currently support icons for CAS auth, but this is required by
|
||||||
|
# the SsoIdentityProvider protocol type.
|
||||||
|
self.idp_icon = None
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
self._sso_handler.register_identity_provider(self)
|
self._sso_handler.register_identity_provider(self)
|
||||||
|
|
|
@ -271,6 +271,9 @@ class OidcProvider:
|
||||||
# user-facing name of this auth provider
|
# user-facing name of this auth provider
|
||||||
self.idp_name = provider.idp_name
|
self.idp_name = provider.idp_name
|
||||||
|
|
||||||
|
# MXC URI for icon for this auth provider
|
||||||
|
self.idp_icon = provider.idp_icon
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
self._sso_handler.register_identity_provider(self)
|
self._sso_handler.register_identity_provider(self)
|
||||||
|
|
|
@ -38,7 +38,6 @@ from synapse.api.filtering import Filter
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.utils import copy_power_levels_contents
|
from synapse.events.utils import copy_power_levels_contents
|
||||||
from synapse.http.endpoint import parse_and_validate_server_name
|
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
JsonDict,
|
JsonDict,
|
||||||
|
@ -55,6 +54,7 @@ from synapse.types import (
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
|
@ -78,6 +78,10 @@ class SamlHandler(BaseHandler):
|
||||||
# user-facing name of this auth provider
|
# user-facing name of this auth provider
|
||||||
self.idp_name = "SAML"
|
self.idp_name = "SAML"
|
||||||
|
|
||||||
|
# we do not currently support icons for SAML auth, but this is required by
|
||||||
|
# the SsoIdentityProvider protocol type.
|
||||||
|
self.idp_icon = None
|
||||||
|
|
||||||
# a map from saml session id to Saml2SessionData object
|
# a map from saml session id to Saml2SessionData object
|
||||||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
||||||
|
|
||||||
|
|
|
@ -75,6 +75,11 @@ class SsoIdentityProvider(Protocol):
|
||||||
def idp_name(self) -> str:
|
def idp_name(self) -> str:
|
||||||
"""User-facing name for this provider"""
|
"""User-facing name for this provider"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def idp_icon(self) -> Optional[str]:
|
||||||
|
"""Optional MXC URI for user-facing icon"""
|
||||||
|
return None
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def handle_redirect_request(
|
async def handle_redirect_request(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -1,79 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_server_name(server_name):
|
|
||||||
"""Split a server name into host/port parts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
server_name (str): server name to parse
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[str, int|None]: host/port parts.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError if the server name could not be parsed.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if server_name[-1] == "]":
|
|
||||||
# ipv6 literal, hopefully
|
|
||||||
return server_name, None
|
|
||||||
|
|
||||||
domain_port = server_name.rsplit(":", 1)
|
|
||||||
domain = domain_port[0]
|
|
||||||
port = int(domain_port[1]) if domain_port[1:] else None
|
|
||||||
return domain, port
|
|
||||||
except Exception:
|
|
||||||
raise ValueError("Invalid server name '%s'" % server_name)
|
|
||||||
|
|
||||||
|
|
||||||
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_and_validate_server_name(server_name):
|
|
||||||
"""Split a server name into host/port parts and do some basic validation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
server_name (str): server name to parse
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[str, int|None]: host/port parts.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError if the server name could not be parsed.
|
|
||||||
"""
|
|
||||||
host, port = parse_server_name(server_name)
|
|
||||||
|
|
||||||
# these tests don't need to be bulletproof as we'll find out soon enough
|
|
||||||
# if somebody is giving us invalid data. What we *do* need is to be sure
|
|
||||||
# that nobody is sneaking IP literals in that look like hostnames, etc.
|
|
||||||
|
|
||||||
# look for ipv6 literals
|
|
||||||
if host[0] == "[":
|
|
||||||
if host[-1] != "]":
|
|
||||||
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
|
|
||||||
return host, port
|
|
||||||
|
|
||||||
# otherwise it should only be alphanumerics.
|
|
||||||
if not VALID_HOST_REGEX.match(host):
|
|
||||||
raise ValueError(
|
|
||||||
"Server name '%s' contains invalid characters" % (server_name,)
|
|
||||||
)
|
|
||||||
|
|
||||||
return host, port
|
|
|
@ -17,6 +17,9 @@
|
||||||
<li>
|
<li>
|
||||||
<input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
|
<input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
|
||||||
<label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
|
<label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
|
||||||
|
{% if p.idp_icon %}
|
||||||
|
<img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
|
||||||
|
{% endif %}
|
||||||
</li>
|
</li>
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
</ul>
|
</ul>
|
||||||
|
|
|
@ -32,7 +32,6 @@ from synapse.api.errors import (
|
||||||
)
|
)
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
from synapse.events.utils import format_event_for_client_v2
|
from synapse.events.utils import format_event_for_client_v2
|
||||||
from synapse.http.endpoint import parse_and_validate_server_name
|
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
assert_params_in_dict,
|
assert_params_in_dict,
|
||||||
|
@ -47,7 +46,7 @@ from synapse.storage.state import StateFilter
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
|
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import parse_and_validate_server_name, random_string
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import synapse.server
|
import synapse.server
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore
|
||||||
from synapse.types import JsonDict, ThirdPartyInstanceID
|
from synapse.types import JsonDict, ThirdPartyInstanceID
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
from synapse.util.stringutils import MXC_REGEX
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
The local and remote media as a lists of tuples where the key is
|
The local and remote media as a lists of tuples where the key is
|
||||||
the hostname and the value is the media ID.
|
the hostname and the value is the media ID.
|
||||||
"""
|
"""
|
||||||
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
|
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT stream_ordering, json FROM events
|
SELECT stream_ordering, json FROM events
|
||||||
JOIN event_json USING (room_id, event_id)
|
JOIN event_json USING (room_id, event_id)
|
||||||
|
@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
for url in (content_url, thumbnail_url):
|
for url in (content_url, thumbnail_url):
|
||||||
if not url:
|
if not url:
|
||||||
continue
|
continue
|
||||||
matches = mxc_re.match(url)
|
matches = MXC_REGEX.match(url)
|
||||||
if matches:
|
if matches:
|
||||||
hostname = matches.group(1)
|
hostname = matches.group(1)
|
||||||
media_id = matches.group(2)
|
media_id = matches.group(2)
|
||||||
|
|
|
@ -37,7 +37,7 @@ from signedjson.key import decode_verify_key_bytes
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.endpoint import parse_and_validate_server_name
|
from synapse.util.stringutils import parse_and_validate_server_name
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.appservice.api import ApplicationService
|
from synapse.appservice.api import ApplicationService
|
||||||
|
|
|
@ -18,6 +18,7 @@ import random
|
||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
|
||||||
|
@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
|
||||||
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
|
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
|
||||||
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
|
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
|
||||||
|
|
||||||
|
# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
|
||||||
|
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
|
||||||
|
# says "there is no grammar for media ids"
|
||||||
|
#
|
||||||
|
# The server_name part of this is purposely lax: use parse_and_validate_mxc for
|
||||||
|
# additional validation.
|
||||||
|
#
|
||||||
|
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
|
||||||
|
|
||||||
# random_string and random_string_with_symbols are used for a range of things,
|
# random_string and random_string_with_symbols are used for a range of things,
|
||||||
# some cryptographically important, some less so. We use SystemRandom to make sure
|
# some cryptographically important, some less so. We use SystemRandom to make sure
|
||||||
# we get cryptographically-secure randoms.
|
# we get cryptographically-secure randoms.
|
||||||
|
@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
|
||||||
|
"""Split a server name into host/port parts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: server name to parse
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
host/port parts.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if the server name could not be parsed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if server_name[-1] == "]":
|
||||||
|
# ipv6 literal, hopefully
|
||||||
|
return server_name, None
|
||||||
|
|
||||||
|
domain_port = server_name.rsplit(":", 1)
|
||||||
|
domain = domain_port[0]
|
||||||
|
port = int(domain_port[1]) if domain_port[1:] else None
|
||||||
|
return domain, port
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid server name '%s'" % server_name)
|
||||||
|
|
||||||
|
|
||||||
|
VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
|
||||||
|
"""Split a server name into host/port parts and do some basic validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: server name to parse
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
host/port parts.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if the server name could not be parsed.
|
||||||
|
"""
|
||||||
|
host, port = parse_server_name(server_name)
|
||||||
|
|
||||||
|
# these tests don't need to be bulletproof as we'll find out soon enough
|
||||||
|
# if somebody is giving us invalid data. What we *do* need is to be sure
|
||||||
|
# that nobody is sneaking IP literals in that look like hostnames, etc.
|
||||||
|
|
||||||
|
# look for ipv6 literals
|
||||||
|
if host[0] == "[":
|
||||||
|
if host[-1] != "]":
|
||||||
|
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
|
||||||
|
return host, port
|
||||||
|
|
||||||
|
# otherwise it should only be alphanumerics.
|
||||||
|
if not VALID_HOST_REGEX.match(host):
|
||||||
|
raise ValueError(
|
||||||
|
"Server name '%s' contains invalid characters" % (server_name,)
|
||||||
|
)
|
||||||
|
|
||||||
|
return host, port
|
||||||
|
|
||||||
|
|
||||||
|
def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
|
||||||
|
"""Parse the given string as an MXC URI
|
||||||
|
|
||||||
|
Checks that the "server name" part is a valid server name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mxc: the (alleged) MXC URI to be checked
|
||||||
|
Returns:
|
||||||
|
hostname, port, media id
|
||||||
|
Raises:
|
||||||
|
ValueError if the URI cannot be parsed
|
||||||
|
"""
|
||||||
|
m = MXC_REGEX.match(mxc)
|
||||||
|
if not m:
|
||||||
|
raise ValueError("mxc URI %r did not match expected format" % (mxc,))
|
||||||
|
server_name = m.group(1)
|
||||||
|
media_id = m.group(2)
|
||||||
|
host, port = parse_and_validate_server_name(server_name)
|
||||||
|
return host, port, media_id
|
||||||
|
|
||||||
|
|
||||||
def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
|
def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
|
||||||
"""If iterable has maxitems or fewer, return the stringification of a list
|
"""If iterable has maxitems or fewer, return the stringification of a list
|
||||||
containing those items.
|
containing those items.
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name
|
from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue