0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-10 12:02:43 +01:00

Switch InstanceLocationConfig to a pydantic BaseModel (#15431)

* Switch InstanceLocationConfig to a pydantic BaseModel, apply Strict* types and add a few helper methods(that will make more sense in follow up work).

Co-authored-by: David Robertson <davidr@element.io>
This commit is contained in:
Jason Little 2023-04-17 18:53:43 -05:00 committed by GitHub
parent d935b806a5
commit e12d788bb7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 71 additions and 10 deletions

View file

@ -0,0 +1 @@
Add some validation to `instance_map` configuration loading.

View file

@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable
from typing import Any, Dict, Iterable, Type, TypeVar
import jsonschema
from pydantic import BaseModel, ValidationError, parse_obj_as
from synapse.config._base import ConfigError
from synapse.types import JsonDict
@ -64,3 +65,28 @@ def json_error_to_config_error(
else:
path.append(str(p))
return ConfigError(e.message, path)
Model = TypeVar("Model", bound=BaseModel)
def parse_and_validate_mapping(
config: Any,
model_type: Type[Model],
) -> Dict[str, Model]:
"""Parse `config` as a mapping from strings to a given `Model` type.
Args:
config: The configuration data to check
model_type: The BaseModel to validate and parse against.
Returns:
Fully validated and parsed Dict[str, Model].
Raises:
ConfigError, if given improper input.
"""
try:
# type-ignore: mypy doesn't like constructing `Dict[str, model_type]` because
# `model_type` is a runtime variable. Pydantic is fine with this.
instances = parse_obj_as(Dict[str, model_type], config) # type: ignore[valid-type]
except ValidationError as e:
raise ConfigError(str(e)) from e
return instances

View file

@ -18,6 +18,7 @@ import logging
from typing import Any, Dict, List, Union
import attr
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
from synapse.config._base import (
Config,
@ -25,6 +26,7 @@ from synapse.config._base import (
RoutableShardedWorkerHandlingConfig,
ShardedWorkerHandlingConfig,
)
from synapse.config._util import parse_and_validate_mapping
from synapse.config.server import (
DIRECT_TCP_ERROR,
TCPListenerConfig,
@ -50,13 +52,43 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
return obj
@attr.s(auto_attribs=True)
class InstanceLocationConfig:
class ConfigModel(BaseModel):
"""A custom version of Pydantic's BaseModel which
- ignores unknown fields and
- does not allow fields to be overwritten after construction,
but otherwise uses Pydantic's default behaviour.
For now, ignore unknown fields. In the future, we could change this so that unknown
config values cause a ValidationError, provided the error messages are meaningful to
server operators.
Subclassing in this way is recommended by
https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
"""
class Config:
# By default, ignore fields that we don't recognise.
extra = Extra.ignore
# By default, don't allow fields to be reassigned after parsing.
allow_mutation = False
class InstanceLocationConfig(ConfigModel):
"""The host and port to talk to an instance via HTTP replication."""
host: str
port: int
tls: bool = False
host: StrictStr
port: StrictInt
tls: StrictBool = False
def scheme(self) -> str:
"""Hardcode a retrievable scheme based on self.tls"""
return "https" if self.tls else "http"
def netloc(self) -> str:
"""Nicely format the network location data"""
return f"{self.host}:{self.port}"
@attr.s
@ -183,10 +215,12 @@ class WorkerConfig(Config):
)
# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map") or {}
self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
}
self.instance_map: Dict[
str, InstanceLocationConfig
] = parse_and_validate_mapping(
config.get("instance_map", {}),
InstanceLocationConfig,
)
# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("stream_writers") or {}