mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-10 12:02:43 +01:00
Switch InstanceLocationConfig
to a pydantic BaseModel
(#15431)
* Switch InstanceLocationConfig to a pydantic BaseModel, apply Strict* types and add a few helper methods(that will make more sense in follow up work). Co-authored-by: David Robertson <davidr@element.io>
This commit is contained in:
parent
d935b806a5
commit
e12d788bb7
3 changed files with 71 additions and 10 deletions
1
changelog.d/15431.feature
Normal file
1
changelog.d/15431.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add some validation to `instance_map` configuration loading.
|
|
@ -11,9 +11,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Iterable
|
||||
from typing import Any, Dict, Iterable, Type, TypeVar
|
||||
|
||||
import jsonschema
|
||||
from pydantic import BaseModel, ValidationError, parse_obj_as
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.types import JsonDict
|
||||
|
@ -64,3 +65,28 @@ def json_error_to_config_error(
|
|||
else:
|
||||
path.append(str(p))
|
||||
return ConfigError(e.message, path)
|
||||
|
||||
|
||||
Model = TypeVar("Model", bound=BaseModel)
|
||||
|
||||
|
||||
def parse_and_validate_mapping(
|
||||
config: Any,
|
||||
model_type: Type[Model],
|
||||
) -> Dict[str, Model]:
|
||||
"""Parse `config` as a mapping from strings to a given `Model` type.
|
||||
Args:
|
||||
config: The configuration data to check
|
||||
model_type: The BaseModel to validate and parse against.
|
||||
Returns:
|
||||
Fully validated and parsed Dict[str, Model].
|
||||
Raises:
|
||||
ConfigError, if given improper input.
|
||||
"""
|
||||
try:
|
||||
# type-ignore: mypy doesn't like constructing `Dict[str, model_type]` because
|
||||
# `model_type` is a runtime variable. Pydantic is fine with this.
|
||||
instances = parse_obj_as(Dict[str, model_type], config) # type: ignore[valid-type]
|
||||
except ValidationError as e:
|
||||
raise ConfigError(str(e)) from e
|
||||
return instances
|
||||
|
|
|
@ -18,6 +18,7 @@ import logging
|
|||
from typing import Any, Dict, List, Union
|
||||
|
||||
import attr
|
||||
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
|
||||
|
||||
from synapse.config._base import (
|
||||
Config,
|
||||
|
@ -25,6 +26,7 @@ from synapse.config._base import (
|
|||
RoutableShardedWorkerHandlingConfig,
|
||||
ShardedWorkerHandlingConfig,
|
||||
)
|
||||
from synapse.config._util import parse_and_validate_mapping
|
||||
from synapse.config.server import (
|
||||
DIRECT_TCP_ERROR,
|
||||
TCPListenerConfig,
|
||||
|
@ -50,13 +52,43 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
|
|||
return obj
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class InstanceLocationConfig:
|
||||
class ConfigModel(BaseModel):
|
||||
"""A custom version of Pydantic's BaseModel which
|
||||
|
||||
- ignores unknown fields and
|
||||
- does not allow fields to be overwritten after construction,
|
||||
|
||||
but otherwise uses Pydantic's default behaviour.
|
||||
|
||||
For now, ignore unknown fields. In the future, we could change this so that unknown
|
||||
config values cause a ValidationError, provided the error messages are meaningful to
|
||||
server operators.
|
||||
|
||||
Subclassing in this way is recommended by
|
||||
https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
|
||||
"""
|
||||
|
||||
class Config:
|
||||
# By default, ignore fields that we don't recognise.
|
||||
extra = Extra.ignore
|
||||
# By default, don't allow fields to be reassigned after parsing.
|
||||
allow_mutation = False
|
||||
|
||||
|
||||
class InstanceLocationConfig(ConfigModel):
|
||||
"""The host and port to talk to an instance via HTTP replication."""
|
||||
|
||||
host: str
|
||||
port: int
|
||||
tls: bool = False
|
||||
host: StrictStr
|
||||
port: StrictInt
|
||||
tls: StrictBool = False
|
||||
|
||||
def scheme(self) -> str:
|
||||
"""Hardcode a retrievable scheme based on self.tls"""
|
||||
return "https" if self.tls else "http"
|
||||
|
||||
def netloc(self) -> str:
|
||||
"""Nicely format the network location data"""
|
||||
return f"{self.host}:{self.port}"
|
||||
|
||||
|
||||
@attr.s
|
||||
|
@ -183,10 +215,12 @@ class WorkerConfig(Config):
|
|||
)
|
||||
|
||||
# A map from instance name to host/port of their HTTP replication endpoint.
|
||||
instance_map = config.get("instance_map") or {}
|
||||
self.instance_map = {
|
||||
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
|
||||
}
|
||||
self.instance_map: Dict[
|
||||
str, InstanceLocationConfig
|
||||
] = parse_and_validate_mapping(
|
||||
config.get("instance_map", {}),
|
||||
InstanceLocationConfig,
|
||||
)
|
||||
|
||||
# Map from type of streams to source, c.f. WriterLocations.
|
||||
writers = config.get("stream_writers") or {}
|
||||
|
|
Loading…
Reference in a new issue