Use PydanticOrderedSet for tags

This commit is contained in:
object-Object 2023-10-24 02:31:05 -04:00
parent 6b4917c371
commit 42d338cfe7
6 changed files with 94 additions and 16 deletions

View file

@ -67,7 +67,7 @@ def unwrap_tag(
except FileNotFoundError:
return
for id in tag.values:
for id in tag.value_ids:
try:
yield MinecraftItemIdIngredient.model_validate(
{"item": id},

View file

@ -8,6 +8,7 @@ from hexdoc.core.loader import LoaderContext
from hexdoc.core.resource import ResourceLocation
from hexdoc.model import HexdocModel
from hexdoc.utils.deserialize.json import decode_json_dict
from hexdoc.utils.types import PydanticOrderedSet
class OptionalTagValue(HexdocModel, frozen=True):
@ -20,7 +21,7 @@ TagValue = ResourceLocation | OptionalTagValue
class Tag(HexdocModel):
registry: str = Field(exclude=True)
raw_values: set[TagValue] = Field(alias="values")
values: PydanticOrderedSet[TagValue]
replace: bool = False
@classmethod
@ -30,7 +31,7 @@ class Tag(HexdocModel):
id: ResourceLocation,
context: LoaderContext,
) -> Self:
values = set[TagValue]()
values = PydanticOrderedSet[TagValue]()
replace = False
for _, _, tag in context.loader.load_resources(
@ -42,7 +43,8 @@ class Tag(HexdocModel):
):
if tag.replace:
values.clear()
values.update(tag._load_values(context))
for value in tag._load_values(context):
values.add(value)
return Tag(registry=registry, values=values, replace=replace)
@ -54,11 +56,8 @@ class Tag(HexdocModel):
)
@property
def values(self) -> set[ResourceLocation]:
return set(self.iter_values())
def iter_values(self) -> Iterator[ResourceLocation]:
for value in self.raw_values:
def value_ids(self) -> Iterator[ResourceLocation]:
for value in self.values:
match value:
case ResourceLocation():
yield value
@ -70,12 +69,12 @@ class Tag(HexdocModel):
tag = self
else:
tag = self.model_copy(
update={"raw_values": current.raw_values | self.raw_values},
update={"raw_values": current.values | self.values},
)
return tag.model_dump_json(by_alias=True)
def _load_values(self, context: LoaderContext) -> Iterator[TagValue]:
for value in self.raw_values:
for value in self.values:
match value:
case (
(ResourceLocation() as child_id) | OptionalTagValue(id=child_id)

View file

@ -25,10 +25,11 @@ class BookContext(
@model_validator(mode="after")
def _post_root_load_tags(self) -> Self:
self.spoilered_advancements |= Tag.load(
tag = Tag.load(
registry="hexdoc",
id=ResourceLocation("hexcasting", "spoilered_advancements"),
context=self,
).values
)
self.spoilered_advancements.update(tag.value_ids)
return self

View file

@ -1,10 +1,14 @@
# pyright: reportUnknownArgumentType=false
import string
from abc import ABC, abstractmethod
from enum import Enum, unique
from typing import Any, Mapping, Protocol, TypeVar
from typing import Any, Mapping, Protocol, TypeVar, get_args
from pydantic import field_validator, model_validator
from ordered_set import OrderedSet, OrderedSetInitializer
from pydantic import GetCoreSchemaHandler, field_validator, model_validator
from pydantic.dataclasses import dataclass
from pydantic_core import core_schema
from hexdoc.model import DEFAULT_CONFIG
@ -96,3 +100,58 @@ class TryGetEnum(Enum):
return cls(value)
except ValueError:
return None
# https://docs.pydantic.dev/latest/concepts/types/#generic-containers
class PydanticOrderedSet(OrderedSet[_T]):
def __init__(self, initial: OrderedSetInitializer[_T] | None = None):
super().__init__(initial or [])
@classmethod
def __get_pydantic_core_schema__(
cls,
source: type[Any],
handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
match get_args(source):
case [type_arg]:
pass
case []:
type_arg = Any
case args:
raise ValueError(f"Expected 0 or 1 type args, got {len(args)}: {args}")
return core_schema.union_schema(
[
core_schema.is_instance_schema(cls),
cls._get_non_instance_schema(type_arg, handler),
],
serialization=cls._get_ser_schema(type_arg, handler),
)
@classmethod
def _get_non_instance_schema(
cls,
type_arg: type[Any],
handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
# validate from OrderedSetInitializer
return core_schema.no_info_after_validator_function(
function=PydanticOrderedSet,
schema=handler.generate_schema(OrderedSetInitializer[type_arg]),
)
@classmethod
def _get_ser_schema(
cls,
type_arg: type[Any],
handler: GetCoreSchemaHandler,
) -> core_schema.SerSchema:
# serialize to list
return core_schema.plain_serializer_function_ser_schema(
function=cls._get_items,
return_schema=handler.generate_schema(list[type_arg]),
)
def _get_items(self):
return self.items

View file

@ -1,6 +1,7 @@
import pytest
from pydantic import TypeAdapter, ValidationError
from hexdoc.utils.types import Color
from hexdoc.utils.types import Color, PydanticOrderedSet
colors: list[str] = [
"#0099FF",
@ -17,3 +18,20 @@ colors: list[str] = [
@pytest.mark.parametrize("s", colors)
def test_color(s: str):
assert Color(s).value == "0099ff"
def test_ordered_set_round_trip():
data = [3, 1, 3, 2, 1]
ta = TypeAdapter(PydanticOrderedSet[int])
ordered_set = ta.validate_python(data)
assert ordered_set.items == [3, 1, 2]
def test_ordered_set_validation_error():
data = [1, "a"]
ta = TypeAdapter(PydanticOrderedSet[int])
with pytest.raises(ValidationError):
ta.validate_python(data)

View file

@ -48,6 +48,7 @@ dependencies = [
"pluggy~=1.3",
"typer[all]~=0.9.0",
"requests~=2.31",
"ordered-set~=4.1",
]
dynamic = ["version"]