0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-14 14:01:59 +01:00

Additional type hints for config module. (#11465)

This adds some misc. type hints to helper methods used
in the `synapse.config` module.
This commit is contained in:
Patrick Cloke 2021-12-01 07:28:23 -05:00 committed by GitHub
parent a265fbd397
commit f44d729d4c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 129 additions and 99 deletions

1
changelog.d/11465.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to `synapse.config` module.

View file

@ -32,6 +32,7 @@ from typing import (
Iterable,
List,
NoReturn,
Optional,
Tuple,
cast,
)
@ -129,7 +130,7 @@ def start_worker_reactor(
def start_reactor(
appname: str,
soft_file_limit: int,
gc_thresholds: Tuple[int, int, int],
gc_thresholds: Optional[Tuple[int, int, int]],
pid_file: str,
daemonize: bool,
print_pidfile: bool,

View file

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import List
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
def main(args):
def main(args: List[str]) -> None:
action = args[1] if len(args) > 1 and args[1] == "read" else None
# If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]`
# will be the key to read.

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,14 +14,14 @@
# limitations under the License.
import logging
from typing import Dict
from typing import Dict, List
from urllib import parse as urlparse
import yaml
from netaddr import IPSet
from synapse.appservice import ApplicationService
from synapse.types import UserID
from synapse.types import JsonDict, UserID
from ._base import Config, ConfigError
@ -30,12 +31,12 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
section = "appservice"
def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
def generate_config_section(cls, **kwargs):
def generate_config_section(cls, **kwargs) -> str:
return """\
# A list of application service config files to use
#
@ -50,7 +51,9 @@ class AppServiceConfig(Config):
"""
def load_appservices(hostname, config_files):
def load_appservices(
hostname: str, config_files: List[str]
) -> List[ApplicationService]:
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning("Expected %s to be a list of AS config files.", config_files)
@ -93,7 +96,9 @@ def load_appservices(hostname, config_files):
return appservices
def _load_appservice(hostname, as_info, config_filename):
def _load_appservice(
hostname: str, as_info: JsonDict, config_filename: str
) -> ApplicationService:
required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
for field in required_string_fields:
if not isinstance(as_info.get(field), str):
@ -115,9 +120,9 @@ def _load_appservice(hostname, as_info, config_filename):
user_id = user.to_string()
# Rate limiting for users of this AS is on by default (excludes sender)
rate_limited = True
if isinstance(as_info.get("rate_limited"), bool):
rate_limited = as_info.get("rate_limited")
rate_limited = as_info.get("rate_limited")
if not isinstance(rate_limited, bool):
rate_limited = True
# namespace checks
if not isinstance(as_info.get("namespaces"), dict):

View file

@ -1,4 +1,4 @@
# Copyright 2019 Matrix.org Foundation C.I.C.
# Copyright 2019-2021 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,6 +17,8 @@ import re
import threading
from typing import Callable, Dict, Optional
import attr
from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError
@ -34,13 +36,13 @@ _DEFAULT_FACTOR_SIZE = 0.5
_DEFAULT_EVENT_CACHE_SIZE = "10K"
@attr.s(slots=True, auto_attribs=True)
class CacheProperties:
def __init__(self):
# The default factor size for all caches
self.default_factor_size = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
)
self.resize_all_caches_func = None
# The default factor size for all caches
default_factor_size: float = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
)
resize_all_caches_func: Optional[Callable[[], None]] = None
properties = CacheProperties()
@ -62,7 +64,7 @@ def _canonicalise_cache_name(cache_name: str) -> str:
def add_resizable_cache(
cache_name: str, cache_resize_callback: Callable[[float], None]
):
) -> None:
"""Register a cache that's size can dynamically change
Args:
@ -91,7 +93,7 @@ class CacheConfig(Config):
_environ = os.environ
@staticmethod
def reset():
def reset() -> None:
"""Resets the caches to their defaults. Used for tests."""
properties.default_factor_size = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
@ -100,7 +102,7 @@ class CacheConfig(Config):
with _CACHES_LOCK:
_CACHES.clear()
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs) -> str:
return """\
## Caching ##
@ -162,7 +164,7 @@ class CacheConfig(Config):
#sync_response_cache_duration: 2m
"""
def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
self.event_cache_size = self.parse_size(
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
)
@ -232,7 +234,7 @@ class CacheConfig(Config):
# needing an instance of Config
properties.resize_all_caches_func = self.resize_all_caches
def resize_all_caches(self):
def resize_all_caches(self) -> None:
"""Ensure all cache sizes are up to date
For each cache, run the mapped callback function with either

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -28,7 +29,7 @@ class CasConfig(Config):
section = "cas"
def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
cas_config = config.get("cas_config", None)
self.cas_enabled = cas_config and cas_config.get("enabled", True)
@ -51,7 +52,7 @@ class CasConfig(Config):
self.cas_displayname_attribute = None
self.cas_required_attributes = []
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
return """\
# Enable Central Authentication Service (CAS) for registration and login.
#

View file

@ -1,5 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
@ -119,7 +120,7 @@ class DatabaseConfig(Config):
self.databases = []
def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
# We *experimentally* support specifying multiple databases via the
# `databases` key. This is a map from a label to database config in the
# same format as the `database` config option, plus an extra
@ -163,12 +164,12 @@ class DatabaseConfig(Config):
self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(database_path)
def generate_config_section(self, data_dir_path, **kwargs):
def generate_config_section(self, data_dir_path, **kwargs) -> str:
return DEFAULT_CONFIG % {
"database_path": os.path.join(data_dir_path, "homeserver.db")
}
def read_arguments(self, args):
def read_arguments(self, args: argparse.Namespace) -> None:
"""
Cases for the cli input:
- If no databases are configured and no database_path is set, raise.
@ -194,7 +195,7 @@ class DatabaseConfig(Config):
else:
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
def set_databasepath(self, database_path):
def set_databasepath(self, database_path: str) -> None:
if database_path != ":memory:":
database_path = self.abspath(database_path)
@ -202,7 +203,7 @@ class DatabaseConfig(Config):
self.databases[0].config["args"]["database"] = database_path
@staticmethod
def add_arguments(parser):
def add_arguments(parser: argparse.ArgumentParser) -> None:
db_group = parser.add_argument_group("database")
db_group.add_argument(
"-d",

View file

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,7 +19,7 @@ import os
import sys
import threading
from string import Template
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional
import yaml
from zope.interface import implementer
@ -40,6 +41,7 @@ from synapse.util.versionstring import get_version_string
from ._base import Config, ConfigError
if TYPE_CHECKING:
from synapse.config.homeserver import HomeServerConfig
from synapse.server import HomeServer
DEFAULT_LOG_CONFIG = Template(
@ -141,13 +143,13 @@ removed in Synapse 1.3.0. You should instead set up a separate log configuration
class LoggingConfig(Config):
section = "logging"
def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
if config.get("log_file"):
raise ConfigError(LOG_FILE_ERROR)
self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
log_config = os.path.join(config_dir_path, server_name + ".log.config")
return (
"""\
@ -161,14 +163,14 @@ class LoggingConfig(Config):
% locals()
)
def read_arguments(self, args):
def read_arguments(self, args: argparse.Namespace) -> None:
if args.no_redirect_stdio is not None:
self.no_redirect_stdio = args.no_redirect_stdio
if args.log_file is not None:
raise ConfigError(LOG_FILE_ERROR)
@staticmethod
def add_arguments(parser):
def add_arguments(parser: argparse.ArgumentParser) -> None:
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
"-n",
@ -197,7 +199,9 @@ class LoggingConfig(Config):
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
def _setup_stdlib_logging(
config: "HomeServerConfig", log_config_path: Optional[str], logBeginner: LogBeginner
) -> None:
"""
Set up Python standard library logging.
"""
@ -230,7 +234,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
log_metadata_filter = MetadataFilter({"server_name": config.server.server_name})
old_factory = logging.getLogRecordFactory()
def factory(*args, **kwargs):
def factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
record = old_factory(*args, **kwargs)
log_context_filter.filter(record)
log_metadata_filter.filter(record)
@ -297,7 +301,7 @@ def _load_logging_config(log_config_path: str) -> None:
logging.config.dictConfig(log_config)
def _reload_logging_config(log_config_path):
def _reload_logging_config(log_config_path: Optional[str]) -> None:
"""
Reload the log configuration from the file and apply it.
"""
@ -311,8 +315,8 @@ def _reload_logging_config(log_config_path):
def setup_logging(
hs: "HomeServer",
config,
use_worker_options=False,
config: "HomeServerConfig",
use_worker_options: bool = False,
logBeginner: LogBeginner = globalLogBeginner,
) -> None:
"""

View file

@ -14,7 +14,7 @@
# limitations under the License.
from collections import Counter
from typing import Collection, Iterable, List, Mapping, Optional, Tuple, Type
from typing import Any, Collection, Iterable, List, Mapping, Optional, Tuple, Type
import attr
@ -36,7 +36,7 @@ LEGACY_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingPr
class OIDCConfig(Config):
section = "oidc"
def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
if not self.oidc_providers:
return
@ -66,7 +66,7 @@ class OIDCConfig(Config):
# OIDC is enabled if we have a provider
return bool(self.oidc_providers)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
return """\
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
# and login.
@ -495,89 +495,89 @@ def _parse_oidc_config_dict(
)
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class OidcProviderClientSecretJwtKey:
# a pem-encoded signing key
key = attr.ib(type=str)
key: str
# properties to include in the JWT header
jwt_header = attr.ib(type=Mapping[str, str])
jwt_header: Mapping[str, str]
# properties to include in the JWT payload.
jwt_payload = attr.ib(type=Mapping[str, str])
jwt_payload: Mapping[str, str]
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids'
# table, as well as the query/path parameter used in the login protocol.
idp_id = attr.ib(type=str)
idp_id: str
# user-facing name for this identity provider.
idp_name = attr.ib(type=str)
idp_name: str
# Optional MXC URI for icon for this IdP.
idp_icon = attr.ib(type=Optional[str])
idp_icon: Optional[str]
# Optional brand identifier for this IdP.
idp_brand = attr.ib(type=Optional[str])
idp_brand: Optional[str]
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
discover: bool
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints.
issuer = attr.ib(type=str)
issuer: str
# oauth2 client id to use
client_id = attr.ib(type=str)
client_id: str
# oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
# a secret.
client_secret = attr.ib(type=Optional[str])
client_secret: Optional[str]
# key to use to construct a JWT to use as a client secret. May be `None` if
# `client_secret` is set.
client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey]
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
# 'none'.
client_auth_method = attr.ib(type=str)
client_auth_method: str
# list of scopes to request
scopes = attr.ib(type=Collection[str])
scopes: Collection[str]
# the oauth2 authorization endpoint. Required if discovery is disabled.
authorization_endpoint = attr.ib(type=Optional[str])
authorization_endpoint: Optional[str]
# the oauth2 token endpoint. Required if discovery is disabled.
token_endpoint = attr.ib(type=Optional[str])
token_endpoint: Optional[str]
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
userinfo_endpoint = attr.ib(type=Optional[str])
userinfo_endpoint: Optional[str]
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
jwks_uri = attr.ib(type=Optional[str])
jwks_uri: Optional[str]
# Whether to skip metadata verification
skip_verification = attr.ib(type=bool)
skip_verification: bool
# Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint".
user_profile_method = attr.ib(type=str)
user_profile_method: str
# whether to allow a user logging in via OIDC to match a pre-existing account
# instead of failing
allow_existing_users = attr.ib(type=bool)
allow_existing_users: bool
# the class of the user mapping provider
user_mapping_provider_class = attr.ib(type=Type)
user_mapping_provider_class: Type
# the config of the user mapping provider
user_mapping_provider_config = attr.ib()
user_mapping_provider_config: Any
# required attributes to require in userinfo to allow login/registration
attribute_requirements = attr.ib(type=List[SsoAttributeRequirement])
attribute_requirements: List[SsoAttributeRequirement]

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Optional
from synapse.api.constants import RoomCreationPreset
@ -362,7 +364,7 @@ class RegistrationConfig(Config):
)
@staticmethod
def add_arguments(parser):
def add_arguments(parser: argparse.ArgumentParser) -> None:
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--enable-registration",
@ -371,6 +373,6 @@ class RegistrationConfig(Config):
help="Enable registration for new users.",
)
def read_arguments(self, args):
def read_arguments(self, args: argparse.Namespace) -> None:
if args.enable_registration is not None:
self.enable_registration = strtobool(str(args.enable_registration))

View file

@ -15,11 +15,12 @@
import logging
import os
from collections import namedtuple
from typing import Dict, List
from typing import Dict, List, Tuple
from urllib.request import getproxies_environment # type: ignore
from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import JsonDict
from synapse.util.module_loader import load_module
from ._base import Config, ConfigError
@ -57,7 +58,9 @@ MediaStorageProviderConfig = namedtuple(
)
def parse_thumbnail_requirements(thumbnail_sizes):
def parse_thumbnail_requirements(
thumbnail_sizes: List[JsonDict],
) -> Dict[str, Tuple[ThumbnailRequirement, ...]]:
"""Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumbnailing
method, and thumbnail media type to precalculate
@ -69,7 +72,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
requirements: Dict[str, List] = {}
requirements: Dict[str, List[ThumbnailRequirement]] = {}
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]

View file

@ -1,5 +1,5 @@
# Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,10 +14,11 @@
# limitations under the License.
import logging
from typing import Any, List
from typing import Any, List, Set
from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import JsonDict
from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
@ -33,7 +34,7 @@ LEGACY_USER_MAPPING_PROVIDER = (
)
def _dict_merge(merge_dict, into_dict):
def _dict_merge(merge_dict: dict, into_dict: dict) -> None:
"""Do a deep merge of two dicts
Recursively merges `merge_dict` into `into_dict`:
@ -43,8 +44,8 @@ def _dict_merge(merge_dict, into_dict):
the value from `merge_dict`.
Args:
merge_dict (dict): dict to merge
into_dict (dict): target dict
merge_dict: dict to merge
into_dict: target dict to be modified
"""
for k, v in merge_dict.items():
if k not in into_dict:
@ -64,7 +65,7 @@ def _dict_merge(merge_dict, into_dict):
class SAML2Config(Config):
section = "saml2"
def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
self.saml2_enabled = False
saml2_config = config.get("saml2_config")
@ -183,8 +184,8 @@ class SAML2Config(Config):
)
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
):
self, required_attributes: Set[str], optional_attributes: Set[str]
) -> JsonDict:
"""Generate a configuration dictionary with required and optional attributes that
will be needed to process new user registration
@ -195,7 +196,7 @@ class SAML2Config(Config):
additional information to Synapse user accounts, but are not required
Returns:
dict: A SAML configuration dictionary
A SAML configuration dictionary
"""
import saml2
@ -222,7 +223,7 @@ class SAML2Config(Config):
},
}
def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
return """\
## Single sign-on integration ##

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import logging
import os.path
@ -27,6 +28,7 @@ from netaddr import AddrFormatError, IPNetwork, IPSet
from twisted.conch.ssh.keys import Key
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.types import JsonDict
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_server_name
@ -1223,7 +1225,7 @@ class ServerConfig(Config):
% locals()
)
def read_arguments(self, args):
def read_arguments(self, args: argparse.Namespace) -> None:
if args.manhole is not None:
self.manhole = args.manhole
if args.daemonize is not None:
@ -1232,7 +1234,7 @@ class ServerConfig(Config):
self.print_pidfile = args.print_pidfile
@staticmethod
def add_arguments(parser):
def add_arguments(parser: argparse.ArgumentParser) -> None:
server_group = parser.add_argument_group("server")
server_group.add_argument(
"-D",
@ -1274,14 +1276,16 @@ class ServerConfig(Config):
)
def is_threepid_reserved(reserved_threepids, threepid):
def is_threepid_reserved(
reserved_threepids: List[JsonDict], threepid: JsonDict
) -> bool:
"""Check the threepid against the reserved threepid config
Args:
reserved_threepids([dict]) - list of reserved threepids
threepid(dict) - The threepid to test for
reserved_threepids: List of reserved threepids
threepid: The threepid to test for
Returns:
boolean Is the threepid undertest reserved_user
Is the threepid undertest reserved_user
"""
for tp in reserved_threepids:
@ -1290,7 +1294,9 @@ def is_threepid_reserved(reserved_threepids, threepid):
return False
def read_gc_thresholds(thresholds):
def read_gc_thresholds(
thresholds: Optional[List[Any]],
) -> Optional[Tuple[int, int, int]]:
"""Reads the three integer thresholds for garbage collection. Ensures that
the thresholds are integers if thresholds are supplied.
"""

View file

@ -1,4 +1,4 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,13 +29,13 @@ https://matrix-org.github.io/synapse/latest/templates.html
---------------------------------------------------------------------------------------"""
@attr.s(frozen=True)
@attr.s(frozen=True, auto_attribs=True)
class SsoAttributeRequirement:
"""Object describing a single requirement for SSO attributes."""
attribute = attr.ib(type=str)
attribute: str
# If a value is not given, than the attribute must simply exist.
value = attr.ib(type=Optional[str])
value: Optional[str]
JSON_SCHEMA = {
"type": "object",
@ -49,7 +49,7 @@ class SSOConfig(Config):
section = "sso"
def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
sso_config: Dict[str, Any] = config.get("sso") or {}
# The sso-specific template_dir
@ -106,7 +106,7 @@ class SSOConfig(Config):
)
self.sso_client_whitelist.append(login_fallback_url)
def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs) -> str:
return """\
# Additional settings to use with single-sign on systems such as OpenID Connect,
# SAML2 and CAS.

View file

@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List, Union
import attr
@ -343,7 +345,7 @@ class WorkerConfig(Config):
#worker_replication_secret: ""
"""
def read_arguments(self, args):
def read_arguments(self, args: argparse.Namespace) -> None:
# We support a bunch of command line arguments that override options in
# the config. A lot of these options have a worker_* prefix when running
# on workers so we also have to override them when command line options