Use PydanticOrderedSet for tags
This commit is contained in:
parent
6b4917c371
commit
42d338cfe7
|
@ -67,7 +67,7 @@ def unwrap_tag(
|
|||
except FileNotFoundError:
|
||||
return
|
||||
|
||||
for id in tag.values:
|
||||
for id in tag.value_ids:
|
||||
try:
|
||||
yield MinecraftItemIdIngredient.model_validate(
|
||||
{"item": id},
|
||||
|
|
|
@ -8,6 +8,7 @@ from hexdoc.core.loader import LoaderContext
|
|||
from hexdoc.core.resource import ResourceLocation
|
||||
from hexdoc.model import HexdocModel
|
||||
from hexdoc.utils.deserialize.json import decode_json_dict
|
||||
from hexdoc.utils.types import PydanticOrderedSet
|
||||
|
||||
|
||||
class OptionalTagValue(HexdocModel, frozen=True):
|
||||
|
@ -20,7 +21,7 @@ TagValue = ResourceLocation | OptionalTagValue
|
|||
|
||||
class Tag(HexdocModel):
|
||||
registry: str = Field(exclude=True)
|
||||
raw_values: set[TagValue] = Field(alias="values")
|
||||
values: PydanticOrderedSet[TagValue]
|
||||
replace: bool = False
|
||||
|
||||
@classmethod
|
||||
|
@ -30,7 +31,7 @@ class Tag(HexdocModel):
|
|||
id: ResourceLocation,
|
||||
context: LoaderContext,
|
||||
) -> Self:
|
||||
values = set[TagValue]()
|
||||
values = PydanticOrderedSet[TagValue]()
|
||||
replace = False
|
||||
|
||||
for _, _, tag in context.loader.load_resources(
|
||||
|
@ -42,7 +43,8 @@ class Tag(HexdocModel):
|
|||
):
|
||||
if tag.replace:
|
||||
values.clear()
|
||||
values.update(tag._load_values(context))
|
||||
for value in tag._load_values(context):
|
||||
values.add(value)
|
||||
|
||||
return Tag(registry=registry, values=values, replace=replace)
|
||||
|
||||
|
@ -54,11 +56,8 @@ class Tag(HexdocModel):
|
|||
)
|
||||
|
||||
@property
|
||||
def values(self) -> set[ResourceLocation]:
|
||||
return set(self.iter_values())
|
||||
|
||||
def iter_values(self) -> Iterator[ResourceLocation]:
|
||||
for value in self.raw_values:
|
||||
def value_ids(self) -> Iterator[ResourceLocation]:
|
||||
for value in self.values:
|
||||
match value:
|
||||
case ResourceLocation():
|
||||
yield value
|
||||
|
@ -70,12 +69,12 @@ class Tag(HexdocModel):
|
|||
tag = self
|
||||
else:
|
||||
tag = self.model_copy(
|
||||
update={"raw_values": current.raw_values | self.raw_values},
|
||||
update={"raw_values": current.values | self.values},
|
||||
)
|
||||
return tag.model_dump_json(by_alias=True)
|
||||
|
||||
def _load_values(self, context: LoaderContext) -> Iterator[TagValue]:
|
||||
for value in self.raw_values:
|
||||
for value in self.values:
|
||||
match value:
|
||||
case (
|
||||
(ResourceLocation() as child_id) | OptionalTagValue(id=child_id)
|
||||
|
|
|
@ -25,10 +25,11 @@ class BookContext(
|
|||
|
||||
@model_validator(mode="after")
|
||||
def _post_root_load_tags(self) -> Self:
|
||||
self.spoilered_advancements |= Tag.load(
|
||||
tag = Tag.load(
|
||||
registry="hexdoc",
|
||||
id=ResourceLocation("hexcasting", "spoilered_advancements"),
|
||||
context=self,
|
||||
).values
|
||||
)
|
||||
self.spoilered_advancements.update(tag.value_ids)
|
||||
|
||||
return self
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
# pyright: reportUnknownArgumentType=false
|
||||
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Mapping, Protocol, TypeVar
|
||||
from typing import Any, Mapping, Protocol, TypeVar, get_args
|
||||
|
||||
from pydantic import field_validator, model_validator
|
||||
from ordered_set import OrderedSet, OrderedSetInitializer
|
||||
from pydantic import GetCoreSchemaHandler, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from hexdoc.model import DEFAULT_CONFIG
|
||||
|
||||
|
@ -96,3 +100,58 @@ class TryGetEnum(Enum):
|
|||
return cls(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
# https://docs.pydantic.dev/latest/concepts/types/#generic-containers
|
||||
class PydanticOrderedSet(OrderedSet[_T]):
|
||||
def __init__(self, initial: OrderedSetInitializer[_T] | None = None):
|
||||
super().__init__(initial or [])
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
source: type[Any],
|
||||
handler: GetCoreSchemaHandler,
|
||||
) -> core_schema.CoreSchema:
|
||||
match get_args(source):
|
||||
case [type_arg]:
|
||||
pass
|
||||
case []:
|
||||
type_arg = Any
|
||||
case args:
|
||||
raise ValueError(f"Expected 0 or 1 type args, got {len(args)}: {args}")
|
||||
|
||||
return core_schema.union_schema(
|
||||
[
|
||||
core_schema.is_instance_schema(cls),
|
||||
cls._get_non_instance_schema(type_arg, handler),
|
||||
],
|
||||
serialization=cls._get_ser_schema(type_arg, handler),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_non_instance_schema(
|
||||
cls,
|
||||
type_arg: type[Any],
|
||||
handler: GetCoreSchemaHandler,
|
||||
) -> core_schema.CoreSchema:
|
||||
# validate from OrderedSetInitializer
|
||||
return core_schema.no_info_after_validator_function(
|
||||
function=PydanticOrderedSet,
|
||||
schema=handler.generate_schema(OrderedSetInitializer[type_arg]),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_ser_schema(
|
||||
cls,
|
||||
type_arg: type[Any],
|
||||
handler: GetCoreSchemaHandler,
|
||||
) -> core_schema.SerSchema:
|
||||
# serialize to list
|
||||
return core_schema.plain_serializer_function_ser_schema(
|
||||
function=cls._get_items,
|
||||
return_schema=handler.generate_schema(list[type_arg]),
|
||||
)
|
||||
|
||||
def _get_items(self):
|
||||
return self.items
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from hexdoc.utils.types import Color
|
||||
from hexdoc.utils.types import Color, PydanticOrderedSet
|
||||
|
||||
colors: list[str] = [
|
||||
"#0099FF",
|
||||
|
@ -17,3 +18,20 @@ colors: list[str] = [
|
|||
@pytest.mark.parametrize("s", colors)
|
||||
def test_color(s: str):
|
||||
assert Color(s).value == "0099ff"
|
||||
|
||||
|
||||
def test_ordered_set_round_trip():
|
||||
data = [3, 1, 3, 2, 1]
|
||||
ta = TypeAdapter(PydanticOrderedSet[int])
|
||||
|
||||
ordered_set = ta.validate_python(data)
|
||||
|
||||
assert ordered_set.items == [3, 1, 2]
|
||||
|
||||
|
||||
def test_ordered_set_validation_error():
|
||||
data = [1, "a"]
|
||||
ta = TypeAdapter(PydanticOrderedSet[int])
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ta.validate_python(data)
|
||||
|
|
|
@ -48,6 +48,7 @@ dependencies = [
|
|||
"pluggy~=1.3",
|
||||
"typer[all]~=0.9.0",
|
||||
"requests~=2.31",
|
||||
"ordered-set~=4.1",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
|
Loading…
Reference in a new issue