0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 18:53:53 +01:00

Switch InstanceLocationConfig to a pydantic BaseModel (#15431)

* Switch InstanceLocationConfig to a pydantic BaseModel, apply Strict* types and add a few helper methods(that will make more sense in follow up work).

Co-authored-by: David Robertson <davidr@element.io>
This commit is contained in:
Jason Little 2023-04-17 18:53:43 -05:00 committed by GitHub
parent d935b806a5
commit e12d788bb7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 71 additions and 10 deletions

View file

@ -0,0 +1 @@
Add some validation to `instance_map` configuration loading.

View file

@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Iterable from typing import Any, Dict, Iterable, Type, TypeVar
import jsonschema import jsonschema
from pydantic import BaseModel, ValidationError, parse_obj_as
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.types import JsonDict from synapse.types import JsonDict
@ -64,3 +65,28 @@ def json_error_to_config_error(
else: else:
path.append(str(p)) path.append(str(p))
return ConfigError(e.message, path) return ConfigError(e.message, path)
Model = TypeVar("Model", bound=BaseModel)
def parse_and_validate_mapping(
config: Any,
model_type: Type[Model],
) -> Dict[str, Model]:
"""Parse `config` as a mapping from strings to a given `Model` type.
Args:
config: The configuration data to check
model_type: The BaseModel to validate and parse against.
Returns:
Fully validated and parsed Dict[str, Model].
Raises:
ConfigError, if given improper input.
"""
try:
# type-ignore: mypy doesn't like constructing `Dict[str, model_type]` because
# `model_type` is a runtime variable. Pydantic is fine with this.
instances = parse_obj_as(Dict[str, model_type], config) # type: ignore[valid-type]
except ValidationError as e:
raise ConfigError(str(e)) from e
return instances

View file

@ -18,6 +18,7 @@ import logging
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
import attr import attr
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
from synapse.config._base import ( from synapse.config._base import (
Config, Config,
@ -25,6 +26,7 @@ from synapse.config._base import (
RoutableShardedWorkerHandlingConfig, RoutableShardedWorkerHandlingConfig,
ShardedWorkerHandlingConfig, ShardedWorkerHandlingConfig,
) )
from synapse.config._util import parse_and_validate_mapping
from synapse.config.server import ( from synapse.config.server import (
DIRECT_TCP_ERROR, DIRECT_TCP_ERROR,
TCPListenerConfig, TCPListenerConfig,
@ -50,13 +52,43 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
return obj return obj
@attr.s(auto_attribs=True) class ConfigModel(BaseModel):
class InstanceLocationConfig: """A custom version of Pydantic's BaseModel which
- ignores unknown fields and
- does not allow fields to be overwritten after construction,
but otherwise uses Pydantic's default behaviour.
For now, ignore unknown fields. In the future, we could change this so that unknown
config values cause a ValidationError, provided the error messages are meaningful to
server operators.
Subclassing in this way is recommended by
https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
"""
class Config:
# By default, ignore fields that we don't recognise.
extra = Extra.ignore
# By default, don't allow fields to be reassigned after parsing.
allow_mutation = False
class InstanceLocationConfig(ConfigModel):
"""The host and port to talk to an instance via HTTP replication.""" """The host and port to talk to an instance via HTTP replication."""
host: str host: StrictStr
port: int port: StrictInt
tls: bool = False tls: StrictBool = False
def scheme(self) -> str:
"""Hardcode a retrievable scheme based on self.tls"""
return "https" if self.tls else "http"
def netloc(self) -> str:
"""Nicely format the network location data"""
return f"{self.host}:{self.port}"
@attr.s @attr.s
@ -183,10 +215,12 @@ class WorkerConfig(Config):
) )
# A map from instance name to host/port of their HTTP replication endpoint. # A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map") or {} self.instance_map: Dict[
self.instance_map = { str, InstanceLocationConfig
name: InstanceLocationConfig(**c) for name, c in instance_map.items() ] = parse_and_validate_mapping(
} config.get("instance_map", {}),
InstanceLocationConfig,
)
# Map from type of streams to source, c.f. WriterLocations. # Map from type of streams to source, c.f. WriterLocations.
writers = config.get("stream_writers") or {} writers = config.get("stream_writers") or {}