From 55669bd3de7137553085e9f1c16a686ff657108c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 23 Nov 2021 10:21:19 -0500 Subject: [PATCH] Add missing type hints to config base classes (#11377) --- changelog.d/11377.bugfix | 1 + changelog.d/11377.misc | 1 + mypy.ini | 3 + synapse/config/_base.py | 157 +++++++++++------- synapse/config/_base.pyi | 87 +++++++--- synapse/config/cache.py | 4 +- synapse/config/key.py | 3 +- synapse/config/logger.py | 4 +- synapse/config/server.py | 4 +- synapse/config/tls.py | 2 +- synapse/module_api/__init__.py | 2 +- .../storage/databases/main/registration.py | 3 +- tests/config/test_load.py | 22 ++- 13 files changed, 184 insertions(+), 109 deletions(-) create mode 100644 changelog.d/11377.bugfix create mode 100644 changelog.d/11377.misc diff --git a/changelog.d/11377.bugfix b/changelog.d/11377.bugfix new file mode 100644 index 000000000..9831fb7bb --- /dev/null +++ b/changelog.d/11377.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.45.0 where the `read_templates` method of the module API would error. diff --git a/changelog.d/11377.misc b/changelog.d/11377.misc new file mode 100644 index 000000000..3dac62557 --- /dev/null +++ b/changelog.d/11377.misc @@ -0,0 +1 @@ +Add type hints to configuration classes. diff --git a/mypy.ini b/mypy.ini index 308cfd95d..bc4f59154 100644 --- a/mypy.ini +++ b/mypy.ini @@ -151,6 +151,9 @@ disallow_untyped_defs = True [mypy-synapse.app.*] disallow_untyped_defs = True +[mypy-synapse.config._base] +disallow_untyped_defs = True + [mypy-synapse.crypto.*] disallow_untyped_defs = True diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 7c4428a13..1265738dc 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -20,7 +20,18 @@ import os from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Iterable, List, MutableMapping, Optional, Union +from typing import ( + Any, + Dict, + Iterable, + List, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import attr import jinja2 @@ -78,7 +89,7 @@ CONFIG_FILE_HEADER = """\ """ -def path_exists(file_path): +def path_exists(file_path: str) -> bool: """Check if a file exists Unlike os.path.exists, this throws an exception if there is an error @@ -86,7 +97,7 @@ def path_exists(file_path): the parent dir). Returns: - bool: True if the file exists; False if not. + True if the file exists; False if not. """ try: os.stat(file_path) @@ -102,15 +113,15 @@ class Config: A configuration section, containing configuration keys and values. Attributes: - section (str): The section title of this config object, such as + section: The section title of this config object, such as "tls" or "logger". This is used to refer to it on the root logger (for example, `config.tls.some_option`). Must be defined in subclasses. """ - section = None + section: str - def __init__(self, root_config=None): + def __init__(self, root_config: "RootConfig" = None): self.root = root_config # Get the path to the default Synapse template directory @@ -119,7 +130,7 @@ class Config: ) @staticmethod - def parse_size(value): + def parse_size(value: Union[str, int]) -> int: if isinstance(value, int): return value sizes = {"K": 1024, "M": 1024 * 1024} @@ -162,15 +173,15 @@ class Config: return int(value) * size @staticmethod - def abspath(file_path): + def abspath(file_path: str) -> str: return os.path.abspath(file_path) if file_path else file_path @classmethod - def path_exists(cls, file_path): + def path_exists(cls, file_path: str) -> bool: return path_exists(file_path) @classmethod - def check_file(cls, file_path, config_name): + def check_file(cls, file_path: Optional[str], config_name: str) -> str: if file_path is None: raise ConfigError("Missing config for %s." % (config_name,)) try: @@ -183,7 +194,7 @@ class Config: return cls.abspath(file_path) @classmethod - def ensure_directory(cls, dir_path): + def ensure_directory(cls, dir_path: str) -> str: dir_path = cls.abspath(dir_path) os.makedirs(dir_path, exist_ok=True) if not os.path.isdir(dir_path): @@ -191,7 +202,7 @@ class Config: return dir_path @classmethod - def read_file(cls, file_path, config_name): + def read_file(cls, file_path: Any, config_name: str) -> str: """Deprecated: call read_file directly""" return read_file(file_path, (config_name,)) @@ -284,6 +295,9 @@ class Config: return [env.get_template(filename) for filename in filenames] +TRootConfig = TypeVar("TRootConfig", bound="RootConfig") + + class RootConfig: """ Holder of an application's configuration. @@ -308,7 +322,9 @@ class RootConfig: raise Exception("Failed making %s: %r" % (config_class.section, e)) setattr(self, config_class.section, conf) - def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: + def invoke_all( + self, func_name: str, *args: Any, **kwargs: Any + ) -> MutableMapping[str, Any]: """ Invoke a function on all instantiated config objects this RootConfig is configured to use. @@ -317,6 +333,7 @@ class RootConfig: func_name: Name of function to invoke *args **kwargs + Returns: ordered dictionary of config section name and the result of the function from it. @@ -332,7 +349,7 @@ class RootConfig: return res @classmethod - def invoke_all_static(cls, func_name: str, *args, **kwargs): + def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: any) -> None: """ Invoke a static function on config objects this RootConfig is configured to use. @@ -341,6 +358,7 @@ class RootConfig: func_name: Name of function to invoke *args **kwargs + Returns: ordered dictionary of config section name and the result of the function from it. @@ -351,16 +369,16 @@ class RootConfig: def generate_config( self, - config_dir_path, - data_dir_path, - server_name, - generate_secrets=False, - report_stats=None, - open_private_ports=False, - listeners=None, - tls_certificate_path=None, - tls_private_key_path=None, - ): + config_dir_path: str, + data_dir_path: str, + server_name: str, + generate_secrets: bool = False, + report_stats: Optional[bool] = None, + open_private_ports: bool = False, + listeners: Optional[List[dict]] = None, + tls_certificate_path: Optional[str] = None, + tls_private_key_path: Optional[str] = None, + ) -> str: """ Build a default configuration file @@ -368,27 +386,27 @@ class RootConfig: (eg with --generate_config). Args: - config_dir_path (str): The path where the config files are kept. Used to + config_dir_path: The path where the config files are kept. Used to create filenames for things like the log config and the signing key. - data_dir_path (str): The path where the data files are kept. Used to create + data_dir_path: The path where the data files are kept. Used to create filenames for things like the database and media store. - server_name (str): The server name. Used to initialise the server_name + server_name: The server name. Used to initialise the server_name config param, but also used in the names of some of the config files. - generate_secrets (bool): True if we should generate new secrets for things + generate_secrets: True if we should generate new secrets for things like the macaroon_secret_key. If False, these parameters will be left unset. - report_stats (bool|None): Initial setting for the report_stats setting. + report_stats: Initial setting for the report_stats setting. If None, report_stats will be left unset. - open_private_ports (bool): True to leave private ports (such as the non-TLS + open_private_ports: True to leave private ports (such as the non-TLS HTTP listener) open to the internet. - listeners (list(dict)|None): A list of descriptions of the listeners - synapse should start with each of which specifies a port (str), a list of + listeners: A list of descriptions of the listeners synapse should + start with each of which specifies a port (int), a list of resources (list(str)), tls (bool) and type (str). For example: [{ "port": 8448, @@ -403,16 +421,12 @@ class RootConfig: "type": "http", }], + tls_certificate_path: The path to the tls certificate. - database (str|None): The database type to configure, either `psycog2` - or `sqlite3`. - - tls_certificate_path (str|None): The path to the tls certificate. - - tls_private_key_path (str|None): The path to the tls private key. + tls_private_key_path: The path to the tls private key. Returns: - str: the yaml config file + The yaml config file """ return CONFIG_FILE_HEADER + "\n\n".join( @@ -432,12 +446,15 @@ class RootConfig: ) @classmethod - def load_config(cls, description, argv): + def load_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> TRootConfig: """Parse the commandline and config files Doesn't support config-file-generation: used by the worker apps. - Returns: Config object. + Returns: + Config object. """ config_parser = argparse.ArgumentParser(description=description) cls.add_arguments_to_parser(config_parser) @@ -446,7 +463,7 @@ class RootConfig: return obj @classmethod - def add_arguments_to_parser(cls, config_parser): + def add_arguments_to_parser(cls, config_parser: argparse.ArgumentParser) -> None: """Adds all the config flags to an ArgumentParser. Doesn't support config-file-generation: used by the worker apps. @@ -454,7 +471,7 @@ class RootConfig: Used for workers where we want to add extra flags/subcommands. Args: - config_parser (ArgumentParser): App description + config_parser: App description """ config_parser.add_argument( @@ -477,7 +494,9 @@ class RootConfig: cls.invoke_all_static("add_arguments", config_parser) @classmethod - def load_config_with_parser(cls, parser, argv): + def load_config_with_parser( + cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str] + ) -> Tuple[TRootConfig, argparse.Namespace]: """Parse the commandline and config files with the given parser Doesn't support config-file-generation: used by the worker apps. @@ -485,13 +504,12 @@ class RootConfig: Used for workers where we want to add extra flags/subcommands. Args: - parser (ArgumentParser) - argv (list[str]) + parser + argv Returns: - tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed - config object and the parsed argparse.Namespace object from - `parser.parse_args(..)` + Returns the parsed config object and the parsed argparse.Namespace + object from parser.parse_args(..)` """ obj = cls() @@ -520,12 +538,15 @@ class RootConfig: return obj, config_args @classmethod - def load_or_generate_config(cls, description, argv): + def load_or_generate_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> Optional[TRootConfig]: """Parse the commandline and config files Supports generation of config files, so is used for the main homeserver app. - Returns: Config object, or None if --generate-config or --generate-keys was set + Returns: + Config object, or None if --generate-config or --generate-keys was set """ parser = argparse.ArgumentParser(description=description) parser.add_argument( @@ -680,16 +701,21 @@ class RootConfig: return obj - def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None): + def parse_config_dict( + self, + config_dict: Dict[str, Any], + config_dir_path: Optional[str] = None, + data_dir_path: Optional[str] = None, + ) -> None: """Read the information from the config dict into this Config object. Args: - config_dict (dict): Configuration data, as read from the yaml + config_dict: Configuration data, as read from the yaml - config_dir_path (str): The path where the config files are kept. Used to + config_dir_path: The path where the config files are kept. Used to create filenames for things like the log config and the signing key. - data_dir_path (str): The path where the data files are kept. Used to create + data_dir_path: The path where the data files are kept. Used to create filenames for things like the database and media store. """ self.invoke_all( @@ -699,17 +725,20 @@ class RootConfig: data_dir_path=data_dir_path, ) - def generate_missing_files(self, config_dict, config_dir_path): + def generate_missing_files( + self, config_dict: Dict[str, Any], config_dir_path: str + ) -> None: self.invoke_all("generate_files", config_dict, config_dir_path) -def read_config_files(config_files): +def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: """Read the config files into a dict Args: - config_files (iterable[str]): A list of the config files to read + config_files: A list of the config files to read - Returns: dict + Returns: + The configuration dictionary. """ specified_config = {} for config_file in config_files: @@ -733,17 +762,17 @@ def read_config_files(config_files): return specified_config -def find_config_files(search_paths): +def find_config_files(search_paths: List[str]) -> List[str]: """Finds config files using a list of search paths. If a path is a file then that file path is added to the list. If a search path is a directory then all the "*.yaml" files in that directory are added to the list in sorted order. Args: - search_paths(list(str)): A list of paths to search. + search_paths: A list of paths to search. Returns: - list(str): A list of file paths. + A list of file paths. """ config_files = [] @@ -777,7 +806,7 @@ def find_config_files(search_paths): return config_files -@attr.s +@attr.s(auto_attribs=True) class ShardedWorkerHandlingConfig: """Algorithm for choosing which instance is responsible for handling some sharded work. @@ -787,7 +816,7 @@ class ShardedWorkerHandlingConfig: below). """ - instances = attr.ib(type=List[str]) + instances: List[str] def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key.""" diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index c1d906979..1eb5f5a68 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,4 +1,18 @@ -from typing import Any, Iterable, List, Optional +import argparse +from typing import ( + Any, + Dict, + Iterable, + List, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +import jinja2 from synapse.config import ( account_validity, @@ -19,6 +33,7 @@ from synapse.config import ( logger, metrics, modules, + oembed, oidc, password_auth_providers, push, @@ -27,6 +42,7 @@ from synapse.config import ( registration, repository, retention, + room, room_directory, saml2, server, @@ -51,7 +67,9 @@ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str MISSING_REPORT_STATS_SPIEL: str MISSING_SERVER_NAME: str -def path_exists(file_path: str): ... +def path_exists(file_path: str) -> bool: ... + +TRootConfig = TypeVar("TRootConfig", bound="RootConfig") class RootConfig: server: server.ServerConfig @@ -61,6 +79,7 @@ class RootConfig: logging: logger.LoggingConfig ratelimiting: ratelimiting.RatelimitConfig media: repository.ContentRepositoryConfig + oembed: oembed.OembedConfig captcha: captcha.CaptchaConfig voip: voip.VoipConfig registration: registration.RegistrationConfig @@ -80,6 +99,7 @@ class RootConfig: authproviders: password_auth_providers.PasswordAuthProviderConfig push: push.PushConfig spamchecker: spam_checker.SpamCheckerConfig + room: room.RoomConfig groups: groups.GroupsConfig userdirectory: user_directory.UserDirectoryConfig consent: consent.ConsentConfig @@ -87,72 +107,85 @@ class RootConfig: servernotices: server_notices.ServerNoticesConfig roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig - tracer: tracer.TracerConfig + tracing: tracer.TracerConfig redis: redis.RedisConfig modules: modules.ModulesConfig caches: cache.CacheConfig federation: federation.FederationConfig retention: retention.RetentionConfig - config_classes: List = ... + config_classes: List[Type["Config"]] = ... def __init__(self) -> None: ... - def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ... + def invoke_all( + self, func_name: str, *args: Any, **kwargs: Any + ) -> MutableMapping[str, Any]: ... @classmethod def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ... - def __getattr__(self, item: str): ... def parse_config_dict( self, - config_dict: Any, - config_dir_path: Optional[Any] = ..., - data_dir_path: Optional[Any] = ..., + config_dict: Dict[str, Any], + config_dir_path: Optional[str] = ..., + data_dir_path: Optional[str] = ..., ) -> None: ... - read_config: Any = ... def generate_config( self, config_dir_path: str, data_dir_path: str, server_name: str, generate_secrets: bool = ..., - report_stats: Optional[str] = ..., + report_stats: Optional[bool] = ..., open_private_ports: bool = ..., listeners: Optional[Any] = ..., - database_conf: Optional[Any] = ..., tls_certificate_path: Optional[str] = ..., tls_private_key_path: Optional[str] = ..., - ): ... + ) -> str: ... @classmethod - def load_or_generate_config(cls, description: Any, argv: Any): ... + def load_or_generate_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> Optional[TRootConfig]: ... @classmethod - def load_config(cls, description: Any, argv: Any): ... + def load_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> TRootConfig: ... @classmethod - def add_arguments_to_parser(cls, config_parser: Any) -> None: ... + def add_arguments_to_parser( + cls, config_parser: argparse.ArgumentParser + ) -> None: ... @classmethod - def load_config_with_parser(cls, parser: Any, argv: Any): ... + def load_config_with_parser( + cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str] + ) -> Tuple[TRootConfig, argparse.Namespace]: ... def generate_missing_files( self, config_dict: dict, config_dir_path: str ) -> None: ... class Config: root: RootConfig + default_template_dir: str def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ... - def __getattr__(self, item: str, from_root: bool = ...): ... @staticmethod - def parse_size(value: Any): ... + def parse_size(value: Union[str, int]) -> int: ... @staticmethod - def parse_duration(value: Any): ... + def parse_duration(value: Union[str, int]) -> int: ... @staticmethod - def abspath(file_path: Optional[str]): ... + def abspath(file_path: Optional[str]) -> str: ... @classmethod - def path_exists(cls, file_path: str): ... + def path_exists(cls, file_path: str) -> bool: ... @classmethod - def check_file(cls, file_path: str, config_name: str): ... + def check_file(cls, file_path: str, config_name: str) -> str: ... @classmethod - def ensure_directory(cls, dir_path: str): ... + def ensure_directory(cls, dir_path: str) -> str: ... @classmethod - def read_file(cls, file_path: str, config_name: str): ... + def read_file(cls, file_path: str, config_name: str) -> str: ... + def read_template(self, filenames: str) -> jinja2.Template: ... + def read_templates( + self, + filenames: List[str], + custom_template_directories: Optional[Iterable[str]] = None, + ) -> List[jinja2.Template]: ... -def read_config_files(config_files: List[str]): ... -def find_config_files(search_paths: List[str]): ... +def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: ... +def find_config_files(search_paths: List[str]) -> List[str]: ... class ShardedWorkerHandlingConfig: instances: List[str] diff --git a/synapse/config/cache.py b/synapse/config/cache.py index d119427ad..f05445553 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -15,7 +15,7 @@ import os import re import threading -from typing import Callable, Dict +from typing import Callable, Dict, Optional from synapse.python_dependencies import DependencyException, check_requirements @@ -217,7 +217,7 @@ class CacheConfig(Config): expiry_time = cache_config.get("expiry_time") if expiry_time: - self.expiry_time_msec = self.parse_duration(expiry_time) + self.expiry_time_msec: Optional[int] = self.parse_duration(expiry_time) else: self.expiry_time_msec = None diff --git a/synapse/config/key.py b/synapse/config/key.py index 015dbb8a6..035ee2416 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -16,6 +16,7 @@ import hashlib import logging import os +from typing import Any, Dict import attr import jsonschema @@ -312,7 +313,7 @@ class KeyConfig(Config): ) return keys - def generate_files(self, config, config_dir_path): + def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: if "signing_key" in config: return diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 5252e61a9..63aab0bab 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -18,7 +18,7 @@ import os import sys import threading from string import Template -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict import yaml from zope.interface import implementer @@ -185,7 +185,7 @@ class LoggingConfig(Config): help=argparse.SUPPRESS, ) - def generate_files(self, config, config_dir_path): + def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: log_config = config.get("log_config") if log_config and not os.path.exists(log_config): log_file = self.abspath("homeserver.log") diff --git a/synapse/config/server.py b/synapse/config/server.py index 7bc0030a9..8445e9dd0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -421,7 +421,7 @@ class ServerConfig(Config): # before redacting them. redaction_retention_period = config.get("redaction_retention_period", "7d") if redaction_retention_period is not None: - self.redaction_retention_period = self.parse_duration( + self.redaction_retention_period: Optional[int] = self.parse_duration( redaction_retention_period ) else: @@ -430,7 +430,7 @@ class ServerConfig(Config): # How long to keep entries in the `users_ips` table. user_ips_max_age = config.get("user_ips_max_age", "28d") if user_ips_max_age is not None: - self.user_ips_max_age = self.parse_duration(user_ips_max_age) + self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age) else: self.user_ips_max_age = None diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 21e5ddd15..4ca111618 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -245,7 +245,7 @@ class TlsConfig(Config): cert_path = self.tls_certificate_file logger.info("Loading TLS certificate from %s", cert_path) cert_pem = self.read_file(cert_path, "tls_certificate_path") - cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem.encode()) return cert diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index ac8e8142f..96d7a8f2a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1014,7 +1014,7 @@ class ModuleApi: A list containing the loaded templates, with the orders matching the one of the filenames parameter. """ - return self._hs.config.read_templates( + return self._hs.config.server.read_templates( filenames, (td for td in (self.custom_template_dir, custom_template_directory) if td), ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 8478463a2..0e8c16866 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1198,8 +1198,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts = now_ms + self._account_validity_period if use_delta: + assert self._account_validity_startup_job_max_delta is not None expiration_ts = random.randrange( - expiration_ts - self._account_validity_startup_job_max_delta, + int(expiration_ts - self._account_validity_startup_job_max_delta), expiration_ts, ) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index d8668d56b..69a4e9413 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -46,15 +46,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): "was: %r" % (config.key.macaroon_secret_key,) ) - config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) + config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) + assert config2 is not None self.assertTrue( - hasattr(config.key, "macaroon_secret_key"), + hasattr(config2.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) - if len(config.key.macaroon_secret_key) < 5: + if len(config2.key.macaroon_secret_key) < 5: self.fail( "Want macaroon secret key to be string of at least length 5," - "was: %r" % (config.key.macaroon_secret_key,) + "was: %r" % (config2.key.macaroon_secret_key,) ) def test_load_succeeds_if_macaroon_secret_key_missing(self): @@ -62,6 +63,9 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): config1 = HomeServerConfig.load_config("", ["-c", self.config_file]) config2 = HomeServerConfig.load_config("", ["-c", self.config_file]) config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) + assert config1 is not None + assert config2 is not None + assert config3 is not None self.assertEqual( config1.key.macaroon_secret_key, config2.key.macaroon_secret_key ) @@ -78,14 +82,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.registration.enable_registration) - config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) - self.assertFalse(config.registration.enable_registration) + config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) + assert config2 is not None + self.assertFalse(config2.registration.enable_registration) # Check that either config value is clobbered by the command line. - config = HomeServerConfig.load_or_generate_config( + config3 = HomeServerConfig.load_or_generate_config( "", ["-c", self.config_file, "--enable-registration"] ) - self.assertTrue(config.registration.enable_registration) + assert config3 is not None + self.assertTrue(config3.registration.enable_registration) def test_stats_enabled(self): self.generate_config_and_remove_lines_containing("enable_metrics")