0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 11:43:51 +01:00

Additional type hints for config module. (#11465)

This adds some misc. type hints to helper methods used
in the `synapse.config` module.
This commit is contained in:
Patrick Cloke 2021-12-01 07:28:23 -05:00 committed by GitHub
parent a265fbd397
commit f44d729d4c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 129 additions and 99 deletions

1
changelog.d/11465.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to `synapse.config` module.

View file

@ -32,6 +32,7 @@ from typing import (
Iterable, Iterable,
List, List,
NoReturn, NoReturn,
Optional,
Tuple, Tuple,
cast, cast,
) )
@ -129,7 +130,7 @@ def start_worker_reactor(
def start_reactor( def start_reactor(
appname: str, appname: str,
soft_file_limit: int, soft_file_limit: int,
gc_thresholds: Tuple[int, int, int], gc_thresholds: Optional[Tuple[int, int, int]],
pid_file: str, pid_file: str,
daemonize: bool, daemonize: bool,
print_pidfile: bool, print_pidfile: bool,

View file

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
from typing import List
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
def main(args): def main(args: List[str]) -> None:
action = args[1] if len(args) > 1 and args[1] == "read" else None action = args[1] if len(args) > 1 and args[1] == "read" else None
# If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]` # If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]`
# will be the key to read. # will be the key to read.

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,14 +14,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict from typing import Dict, List
from urllib import parse as urlparse from urllib import parse as urlparse
import yaml import yaml
from netaddr import IPSet from netaddr import IPSet
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.types import UserID from synapse.types import JsonDict, UserID
from ._base import Config, ConfigError from ._base import Config, ConfigError
@ -30,12 +31,12 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config): class AppServiceConfig(Config):
section = "appservice" section = "appservice"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs) -> None:
self.app_service_config_files = config.get("app_service_config_files", []) self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True) self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False) self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
def generate_config_section(cls, **kwargs): def generate_config_section(cls, **kwargs) -> str:
return """\ return """\
# A list of application service config files to use # A list of application service config files to use
# #
@ -50,7 +51,9 @@ class AppServiceConfig(Config):
""" """
def load_appservices(hostname, config_files): def load_appservices(
hostname: str, config_files: List[str]
) -> List[ApplicationService]:
"""Returns a list of Application Services from the config files.""" """Returns a list of Application Services from the config files."""
if not isinstance(config_files, list): if not isinstance(config_files, list):
logger.warning("Expected %s to be a list of AS config files.", config_files) logger.warning("Expected %s to be a list of AS config files.", config_files)
@ -93,7 +96,9 @@ def load_appservices(hostname, config_files):
return appservices return appservices
def _load_appservice(hostname, as_info, config_filename): def _load_appservice(
hostname: str, as_info: JsonDict, config_filename: str
) -> ApplicationService:
required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"] required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
for field in required_string_fields: for field in required_string_fields:
if not isinstance(as_info.get(field), str): if not isinstance(as_info.get(field), str):
@ -115,9 +120,9 @@ def _load_appservice(hostname, as_info, config_filename):
user_id = user.to_string() user_id = user.to_string()
# Rate limiting for users of this AS is on by default (excludes sender) # Rate limiting for users of this AS is on by default (excludes sender)
rate_limited = True rate_limited = as_info.get("rate_limited")
if isinstance(as_info.get("rate_limited"), bool): if not isinstance(rate_limited, bool):
rate_limited = as_info.get("rate_limited") rate_limited = True
# namespace checks # namespace checks
if not isinstance(as_info.get("namespaces"), dict): if not isinstance(as_info.get("namespaces"), dict):

View file

@ -1,4 +1,4 @@
# Copyright 2019 Matrix.org Foundation C.I.C. # Copyright 2019-2021 Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,6 +17,8 @@ import re
import threading import threading
from typing import Callable, Dict, Optional from typing import Callable, Dict, Optional
import attr
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError from ._base import Config, ConfigError
@ -34,13 +36,13 @@ _DEFAULT_FACTOR_SIZE = 0.5
_DEFAULT_EVENT_CACHE_SIZE = "10K" _DEFAULT_EVENT_CACHE_SIZE = "10K"
@attr.s(slots=True, auto_attribs=True)
class CacheProperties: class CacheProperties:
def __init__(self): # The default factor size for all caches
# The default factor size for all caches default_factor_size: float = float(
self.default_factor_size = float( os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) )
) resize_all_caches_func: Optional[Callable[[], None]] = None
self.resize_all_caches_func = None
properties = CacheProperties() properties = CacheProperties()
@ -62,7 +64,7 @@ def _canonicalise_cache_name(cache_name: str) -> str:
def add_resizable_cache( def add_resizable_cache(
cache_name: str, cache_resize_callback: Callable[[float], None] cache_name: str, cache_resize_callback: Callable[[float], None]
): ) -> None:
"""Register a cache that's size can dynamically change """Register a cache that's size can dynamically change
Args: Args:
@ -91,7 +93,7 @@ class CacheConfig(Config):
_environ = os.environ _environ = os.environ
@staticmethod @staticmethod
def reset(): def reset() -> None:
"""Resets the caches to their defaults. Used for tests.""" """Resets the caches to their defaults. Used for tests."""
properties.default_factor_size = float( properties.default_factor_size = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
@ -100,7 +102,7 @@ class CacheConfig(Config):
with _CACHES_LOCK: with _CACHES_LOCK:
_CACHES.clear() _CACHES.clear()
def generate_config_section(self, **kwargs): def generate_config_section(self, **kwargs) -> str:
return """\ return """\
## Caching ## ## Caching ##
@ -162,7 +164,7 @@ class CacheConfig(Config):
#sync_response_cache_duration: 2m #sync_response_cache_duration: 2m
""" """
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs) -> None:
self.event_cache_size = self.parse_size( self.event_cache_size = self.parse_size(
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
) )
@ -232,7 +234,7 @@ class CacheConfig(Config):
# needing an instance of Config # needing an instance of Config
properties.resize_all_caches_func = self.resize_all_caches properties.resize_all_caches_func = self.resize_all_caches
def resize_all_caches(self): def resize_all_caches(self) -> None:
"""Ensure all cache sizes are up to date """Ensure all cache sizes are up to date
For each cache, run the mapped callback function with either For each cache, run the mapped callback function with either

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -28,7 +29,7 @@ class CasConfig(Config):
section = "cas" section = "cas"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs) -> None:
cas_config = config.get("cas_config", None) cas_config = config.get("cas_config", None)
self.cas_enabled = cas_config and cas_config.get("enabled", True) self.cas_enabled = cas_config and cas_config.get("enabled", True)
@ -51,7 +52,7 @@ class CasConfig(Config):
self.cas_displayname_attribute = None self.cas_displayname_attribute = None
self.cas_required_attributes = [] self.cas_required_attributes = []
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
return """\ return """\
# Enable Central Authentication Service (CAS) for registration and login. # Enable Central Authentication Service (CAS) for registration and login.
# #

View file

@ -1,5 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
import logging import logging
import os import os
@ -119,7 +120,7 @@ class DatabaseConfig(Config):
self.databases = [] self.databases = []
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs) -> None:
# We *experimentally* support specifying multiple databases via the # We *experimentally* support specifying multiple databases via the
# `databases` key. This is a map from a label to database config in the # `databases` key. This is a map from a label to database config in the
# same format as the `database` config option, plus an extra # same format as the `database` config option, plus an extra
@ -163,12 +164,12 @@ class DatabaseConfig(Config):
self.databases = [DatabaseConnectionConfig("master", database_config)] self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(database_path) self.set_databasepath(database_path)
def generate_config_section(self, data_dir_path, **kwargs): def generate_config_section(self, data_dir_path, **kwargs) -> str:
return DEFAULT_CONFIG % { return DEFAULT_CONFIG % {
"database_path": os.path.join(data_dir_path, "homeserver.db") "database_path": os.path.join(data_dir_path, "homeserver.db")
} }
def read_arguments(self, args): def read_arguments(self, args: argparse.Namespace) -> None:
""" """
Cases for the cli input: Cases for the cli input:
- If no databases are configured and no database_path is set, raise. - If no databases are configured and no database_path is set, raise.
@ -194,7 +195,7 @@ class DatabaseConfig(Config):
else: else:
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
def set_databasepath(self, database_path): def set_databasepath(self, database_path: str) -> None:
if database_path != ":memory:": if database_path != ":memory:":
database_path = self.abspath(database_path) database_path = self.abspath(database_path)
@ -202,7 +203,7 @@ class DatabaseConfig(Config):
self.databases[0].config["args"]["database"] = database_path self.databases[0].config["args"]["database"] = database_path
@staticmethod @staticmethod
def add_arguments(parser): def add_arguments(parser: argparse.ArgumentParser) -> None:
db_group = parser.add_argument_group("database") db_group = parser.add_argument_group("database")
db_group.add_argument( db_group.add_argument(
"-d", "-d",

View file

@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,7 +19,7 @@ import os
import sys import sys
import threading import threading
from string import Template from string import Template
from typing import TYPE_CHECKING, Any, Dict from typing import TYPE_CHECKING, Any, Dict, Optional
import yaml import yaml
from zope.interface import implementer from zope.interface import implementer
@ -40,6 +41,7 @@ from synapse.util.versionstring import get_version_string
from ._base import Config, ConfigError from ._base import Config, ConfigError
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.config.homeserver import HomeServerConfig
from synapse.server import HomeServer from synapse.server import HomeServer
DEFAULT_LOG_CONFIG = Template( DEFAULT_LOG_CONFIG = Template(
@ -141,13 +143,13 @@ removed in Synapse 1.3.0. You should instead set up a separate log configuration
class LoggingConfig(Config): class LoggingConfig(Config):
section = "logging" section = "logging"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs) -> None:
if config.get("log_file"): if config.get("log_file"):
raise ConfigError(LOG_FILE_ERROR) raise ConfigError(LOG_FILE_ERROR)
self.log_config = self.abspath(config.get("log_config")) self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False) self.no_redirect_stdio = config.get("no_redirect_stdio", False)
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
log_config = os.path.join(config_dir_path, server_name + ".log.config") log_config = os.path.join(config_dir_path, server_name + ".log.config")
return ( return (
"""\ """\
@ -161,14 +163,14 @@ class LoggingConfig(Config):
% locals() % locals()
) )
def read_arguments(self, args): def read_arguments(self, args: argparse.Namespace) -> None:
if args.no_redirect_stdio is not None: if args.no_redirect_stdio is not None:
self.no_redirect_stdio = args.no_redirect_stdio self.no_redirect_stdio = args.no_redirect_stdio
if args.log_file is not None: if args.log_file is not None:
raise ConfigError(LOG_FILE_ERROR) raise ConfigError(LOG_FILE_ERROR)
@staticmethod @staticmethod
def add_arguments(parser): def add_arguments(parser: argparse.ArgumentParser) -> None:
logging_group = parser.add_argument_group("logging") logging_group = parser.add_argument_group("logging")
logging_group.add_argument( logging_group.add_argument(
"-n", "-n",
@ -197,7 +199,9 @@ class LoggingConfig(Config):
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None: def _setup_stdlib_logging(
config: "HomeServerConfig", log_config_path: Optional[str], logBeginner: LogBeginner
) -> None:
""" """
Set up Python standard library logging. Set up Python standard library logging.
""" """
@ -230,7 +234,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
log_metadata_filter = MetadataFilter({"server_name": config.server.server_name}) log_metadata_filter = MetadataFilter({"server_name": config.server.server_name})
old_factory = logging.getLogRecordFactory() old_factory = logging.getLogRecordFactory()
def factory(*args, **kwargs): def factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
record = old_factory(*args, **kwargs) record = old_factory(*args, **kwargs)
log_context_filter.filter(record) log_context_filter.filter(record)
log_metadata_filter.filter(record) log_metadata_filter.filter(record)
@ -297,7 +301,7 @@ def _load_logging_config(log_config_path: str) -> None:
logging.config.dictConfig(log_config) logging.config.dictConfig(log_config)
def _reload_logging_config(log_config_path): def _reload_logging_config(log_config_path: Optional[str]) -> None:
""" """
Reload the log configuration from the file and apply it. Reload the log configuration from the file and apply it.
""" """
@ -311,8 +315,8 @@ def _reload_logging_config(log_config_path):
def setup_logging( def setup_logging(
hs: "HomeServer", hs: "HomeServer",
config, config: "HomeServerConfig",
use_worker_options=False, use_worker_options: bool = False,
logBeginner: LogBeginner = globalLogBeginner, logBeginner: LogBeginner = globalLogBeginner,
) -> None: ) -> None:
""" """

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from collections import Counter from collections import Counter
from typing import Collection, Iterable, List, Mapping, Optional, Tuple, Type from typing import Any, Collection, Iterable, List, Mapping, Optional, Tuple, Type
import attr import attr
@ -36,7 +36,7 @@ LEGACY_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingPr
class OIDCConfig(Config): class OIDCConfig(Config):
section = "oidc" section = "oidc"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs) -> None:
self.oidc_providers = tuple(_parse_oidc_provider_configs(config)) self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
if not self.oidc_providers: if not self.oidc_providers:
return return
@ -66,7 +66,7 @@ class OIDCConfig(Config):
# OIDC is enabled if we have a provider # OIDC is enabled if we have a provider
return bool(self.oidc_providers) return bool(self.oidc_providers)
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
return """\ return """\
# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
# and login. # and login.
@ -495,89 +495,89 @@ def _parse_oidc_config_dict(
) )
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class OidcProviderClientSecretJwtKey: class OidcProviderClientSecretJwtKey:
# a pem-encoded signing key # a pem-encoded signing key
key = attr.ib(type=str) key: str
# properties to include in the JWT header # properties to include in the JWT header
jwt_header = attr.ib(type=Mapping[str, str]) jwt_header: Mapping[str, str]
# properties to include in the JWT payload. # properties to include in the JWT payload.
jwt_payload = attr.ib(type=Mapping[str, str]) jwt_payload: Mapping[str, str]
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class OidcProviderConfig: class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids' # a unique identifier for this identity provider. Used in the 'user_external_ids'
# table, as well as the query/path parameter used in the login protocol. # table, as well as the query/path parameter used in the login protocol.
idp_id = attr.ib(type=str) idp_id: str
# user-facing name for this identity provider. # user-facing name for this identity provider.
idp_name = attr.ib(type=str) idp_name: str
# Optional MXC URI for icon for this IdP. # Optional MXC URI for icon for this IdP.
idp_icon = attr.ib(type=Optional[str]) idp_icon: Optional[str]
# Optional brand identifier for this IdP. # Optional brand identifier for this IdP.
idp_brand = attr.ib(type=Optional[str]) idp_brand: Optional[str]
# whether the OIDC discovery mechanism is used to discover endpoints # whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool) discover: bool
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints. # discover the provider's endpoints.
issuer = attr.ib(type=str) issuer: str
# oauth2 client id to use # oauth2 client id to use
client_id = attr.ib(type=str) client_id: str
# oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
# a secret. # a secret.
client_secret = attr.ib(type=Optional[str]) client_secret: Optional[str]
# key to use to construct a JWT to use as a client secret. May be `None` if # key to use to construct a JWT to use as a client secret. May be `None` if
# `client_secret` is set. # `client_secret` is set.
client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey]) client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey]
# auth method to use when exchanging the token. # auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and # Valid values are 'client_secret_basic', 'client_secret_post' and
# 'none'. # 'none'.
client_auth_method = attr.ib(type=str) client_auth_method: str
# list of scopes to request # list of scopes to request
scopes = attr.ib(type=Collection[str]) scopes: Collection[str]
# the oauth2 authorization endpoint. Required if discovery is disabled. # the oauth2 authorization endpoint. Required if discovery is disabled.
authorization_endpoint = attr.ib(type=Optional[str]) authorization_endpoint: Optional[str]
# the oauth2 token endpoint. Required if discovery is disabled. # the oauth2 token endpoint. Required if discovery is disabled.
token_endpoint = attr.ib(type=Optional[str]) token_endpoint: Optional[str]
# the OIDC userinfo endpoint. Required if discovery is disabled and the # the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested. # "openid" scope is not requested.
userinfo_endpoint = attr.ib(type=Optional[str]) userinfo_endpoint: Optional[str]
# URI where to fetch the JWKS. Required if discovery is disabled and the # URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used. # "openid" scope is used.
jwks_uri = attr.ib(type=Optional[str]) jwks_uri: Optional[str]
# Whether to skip metadata verification # Whether to skip metadata verification
skip_verification = attr.ib(type=bool) skip_verification: bool
# Whether to fetch the user profile from the userinfo endpoint. Valid # Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint". # values are: "auto" or "userinfo_endpoint".
user_profile_method = attr.ib(type=str) user_profile_method: str
# whether to allow a user logging in via OIDC to match a pre-existing account # whether to allow a user logging in via OIDC to match a pre-existing account
# instead of failing # instead of failing
allow_existing_users = attr.ib(type=bool) allow_existing_users: bool
# the class of the user mapping provider # the class of the user mapping provider
user_mapping_provider_class = attr.ib(type=Type) user_mapping_provider_class: Type
# the config of the user mapping provider # the config of the user mapping provider
user_mapping_provider_config = attr.ib() user_mapping_provider_config: Any
# required attributes to require in userinfo to allow login/registration # required attributes to require in userinfo to allow login/registration
attribute_requirements = attr.ib(type=List[SsoAttributeRequirement]) attribute_requirements: List[SsoAttributeRequirement]

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
from typing import Optional from typing import Optional
from synapse.api.constants import RoomCreationPreset from synapse.api.constants import RoomCreationPreset
@ -362,7 +364,7 @@ class RegistrationConfig(Config):
) )
@staticmethod @staticmethod
def add_arguments(parser): def add_arguments(parser: argparse.ArgumentParser) -> None:
reg_group = parser.add_argument_group("registration") reg_group = parser.add_argument_group("registration")
reg_group.add_argument( reg_group.add_argument(
"--enable-registration", "--enable-registration",
@ -371,6 +373,6 @@ class RegistrationConfig(Config):
help="Enable registration for new users.", help="Enable registration for new users.",
) )
def read_arguments(self, args): def read_arguments(self, args: argparse.Namespace) -> None:
if args.enable_registration is not None: if args.enable_registration is not None:
self.enable_registration = strtobool(str(args.enable_registration)) self.enable_registration = strtobool(str(args.enable_registration))

View file

@ -15,11 +15,12 @@
import logging import logging
import os import os
from collections import namedtuple from collections import namedtuple
from typing import Dict, List from typing import Dict, List, Tuple
from urllib.request import getproxies_environment # type: ignore from urllib.request import getproxies_environment # type: ignore
from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import JsonDict
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
from ._base import Config, ConfigError from ._base import Config, ConfigError
@ -57,7 +58,9 @@ MediaStorageProviderConfig = namedtuple(
) )
def parse_thumbnail_requirements(thumbnail_sizes): def parse_thumbnail_requirements(
thumbnail_sizes: List[JsonDict],
) -> Dict[str, Tuple[ThumbnailRequirement, ...]]:
"""Takes a list of dictionaries with "width", "height", and "method" keys """Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumbnailing and creates a map from image media types to the thumbnail size, thumbnailing
method, and thumbnail media type to precalculate method, and thumbnail media type to precalculate
@ -69,7 +72,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of Dictionary mapping from media type string to list of
ThumbnailRequirement tuples. ThumbnailRequirement tuples.
""" """
requirements: Dict[str, List] = {} requirements: Dict[str, List[ThumbnailRequirement]] = {}
for size in thumbnail_sizes: for size in thumbnail_sizes:
width = size["width"] width = size["width"]
height = size["height"] height = size["height"]

View file

@ -1,5 +1,5 @@
# Copyright 2018 New Vector Ltd # Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,10 +14,11 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, List from typing import Any, List, Set
from synapse.config.sso import SsoAttributeRequirement from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import JsonDict
from synapse.util.module_loader import load_module, load_python_module from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError from ._base import Config, ConfigError
@ -33,7 +34,7 @@ LEGACY_USER_MAPPING_PROVIDER = (
) )
def _dict_merge(merge_dict, into_dict): def _dict_merge(merge_dict: dict, into_dict: dict) -> None:
"""Do a deep merge of two dicts """Do a deep merge of two dicts
Recursively merges `merge_dict` into `into_dict`: Recursively merges `merge_dict` into `into_dict`:
@ -43,8 +44,8 @@ def _dict_merge(merge_dict, into_dict):
the value from `merge_dict`. the value from `merge_dict`.
Args: Args:
merge_dict (dict): dict to merge merge_dict: dict to merge
into_dict (dict): target dict into_dict: target dict to be modified
""" """
for k, v in merge_dict.items(): for k, v in merge_dict.items():
if k not in into_dict: if k not in into_dict:
@ -64,7 +65,7 @@ def _dict_merge(merge_dict, into_dict):
class SAML2Config(Config): class SAML2Config(Config):
section = "saml2" section = "saml2"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs) -> None:
self.saml2_enabled = False self.saml2_enabled = False
saml2_config = config.get("saml2_config") saml2_config = config.get("saml2_config")
@ -183,8 +184,8 @@ class SAML2Config(Config):
) )
def _default_saml_config_dict( def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set self, required_attributes: Set[str], optional_attributes: Set[str]
): ) -> JsonDict:
"""Generate a configuration dictionary with required and optional attributes that """Generate a configuration dictionary with required and optional attributes that
will be needed to process new user registration will be needed to process new user registration
@ -195,7 +196,7 @@ class SAML2Config(Config):
additional information to Synapse user accounts, but are not required additional information to Synapse user accounts, but are not required
Returns: Returns:
dict: A SAML configuration dictionary A SAML configuration dictionary
""" """
import saml2 import saml2
@ -222,7 +223,7 @@ class SAML2Config(Config):
}, },
} }
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
return """\ return """\
## Single sign-on integration ## ## Single sign-on integration ##

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
import itertools import itertools
import logging import logging
import os.path import os.path
@ -27,6 +28,7 @@ from netaddr import AddrFormatError, IPNetwork, IPSet
from twisted.conch.ssh.keys import Key from twisted.conch.ssh.keys import Key
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.types import JsonDict
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.stringutils import parse_and_validate_server_name
@ -1223,7 +1225,7 @@ class ServerConfig(Config):
% locals() % locals()
) )
def read_arguments(self, args): def read_arguments(self, args: argparse.Namespace) -> None:
if args.manhole is not None: if args.manhole is not None:
self.manhole = args.manhole self.manhole = args.manhole
if args.daemonize is not None: if args.daemonize is not None:
@ -1232,7 +1234,7 @@ class ServerConfig(Config):
self.print_pidfile = args.print_pidfile self.print_pidfile = args.print_pidfile
@staticmethod @staticmethod
def add_arguments(parser): def add_arguments(parser: argparse.ArgumentParser) -> None:
server_group = parser.add_argument_group("server") server_group = parser.add_argument_group("server")
server_group.add_argument( server_group.add_argument(
"-D", "-D",
@ -1274,14 +1276,16 @@ class ServerConfig(Config):
) )
def is_threepid_reserved(reserved_threepids, threepid): def is_threepid_reserved(
reserved_threepids: List[JsonDict], threepid: JsonDict
) -> bool:
"""Check the threepid against the reserved threepid config """Check the threepid against the reserved threepid config
Args: Args:
reserved_threepids([dict]) - list of reserved threepids reserved_threepids: List of reserved threepids
threepid(dict) - The threepid to test for threepid: The threepid to test for
Returns: Returns:
boolean Is the threepid undertest reserved_user Is the threepid undertest reserved_user
""" """
for tp in reserved_threepids: for tp in reserved_threepids:
@ -1290,7 +1294,9 @@ def is_threepid_reserved(reserved_threepids, threepid):
return False return False
def read_gc_thresholds(thresholds): def read_gc_thresholds(
thresholds: Optional[List[Any]],
) -> Optional[Tuple[int, int, int]]:
"""Reads the three integer thresholds for garbage collection. Ensures that """Reads the three integer thresholds for garbage collection. Ensures that
the thresholds are integers if thresholds are supplied. the thresholds are integers if thresholds are supplied.
""" """

View file

@ -1,4 +1,4 @@
# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -29,13 +29,13 @@ https://matrix-org.github.io/synapse/latest/templates.html
---------------------------------------------------------------------------------------""" ---------------------------------------------------------------------------------------"""
@attr.s(frozen=True) @attr.s(frozen=True, auto_attribs=True)
class SsoAttributeRequirement: class SsoAttributeRequirement:
"""Object describing a single requirement for SSO attributes.""" """Object describing a single requirement for SSO attributes."""
attribute = attr.ib(type=str) attribute: str
# If a value is not given, than the attribute must simply exist. # If a value is not given, than the attribute must simply exist.
value = attr.ib(type=Optional[str]) value: Optional[str]
JSON_SCHEMA = { JSON_SCHEMA = {
"type": "object", "type": "object",
@ -49,7 +49,7 @@ class SSOConfig(Config):
section = "sso" section = "sso"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs) -> None:
sso_config: Dict[str, Any] = config.get("sso") or {} sso_config: Dict[str, Any] = config.get("sso") or {}
# The sso-specific template_dir # The sso-specific template_dir
@ -106,7 +106,7 @@ class SSOConfig(Config):
) )
self.sso_client_whitelist.append(login_fallback_url) self.sso_client_whitelist.append(login_fallback_url)
def generate_config_section(self, **kwargs): def generate_config_section(self, **kwargs) -> str:
return """\ return """\
# Additional settings to use with single-sign on systems such as OpenID Connect, # Additional settings to use with single-sign on systems such as OpenID Connect,
# SAML2 and CAS. # SAML2 and CAS.

View file

@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
from typing import List, Union from typing import List, Union
import attr import attr
@ -343,7 +345,7 @@ class WorkerConfig(Config):
#worker_replication_secret: "" #worker_replication_secret: ""
""" """
def read_arguments(self, args): def read_arguments(self, args: argparse.Namespace) -> None:
# We support a bunch of command line arguments that override options in # We support a bunch of command line arguments that override options in
# the config. A lot of these options have a worker_* prefix when running # the config. A lot of these options have a worker_* prefix when running
# on workers so we also have to override them when command line options # on workers so we also have to override them when command line options