Use entry hooks to load union types
This commit is contained in:
parent
371d56033e
commit
5e61d3b8b9
11 changed files with 89 additions and 78 deletions
|
@ -1,4 +1,4 @@
|
|||
# build
|
||||
# project metadata
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
@ -20,17 +20,19 @@ dependencies = [
|
|||
"dacite@git+https://github.com/mciszczon/dacite@f298260c6aedc1097c7567b1b0a61298a0ddf2a8",
|
||||
]
|
||||
|
||||
[project.entry-points."tagged_unions.Page"]
|
||||
patchouli = "patchouli.page.patchouli_pages"
|
||||
hexcasting = "hexcasting.hex_pages"
|
||||
[project.entry-points."hexdoc.Page"]
|
||||
hexdoc-patchouli = "patchouli.page.pages"
|
||||
hexdoc-hexcasting = "hexcasting.hex_pages"
|
||||
|
||||
[project.entry-points."tagged_unions.Recipe"]
|
||||
patchouli = "minecraft.recipe.minecraft_recipes"
|
||||
hexcasting = "hexcasting.hex_recipes"
|
||||
[project.entry-points."hexdoc.Recipe"]
|
||||
hexdoc-minecraft = "minecraft.recipe.recipes"
|
||||
hexdoc-hexcasting = "hexcasting.hex_recipes"
|
||||
|
||||
[project.entry-points."tagged_unions.ItemIngredient"]
|
||||
patchouli = "minecraft.recipe.ingredients"
|
||||
hexcasting = "hexcasting.hex_recipes"
|
||||
[project.entry-points."hexdoc.ItemIngredient"]
|
||||
hexdoc-minecraft = "minecraft.recipe.ingredients"
|
||||
hexdoc-hexcasting = "hexcasting.hex_recipes"
|
||||
|
||||
# Hatch settings (the build backend)
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true # TODO: remove when we switch to Pydantic
|
||||
|
|
|
@ -2,12 +2,12 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Generator, Self
|
||||
|
||||
from dacite import StrictUnionMatchError, UnionMatchError, from_dict
|
||||
from pkg_resources import iter_entry_points
|
||||
|
||||
from common.dacite_patch import UnionSkip
|
||||
from common.deserialize import TypedConfig
|
||||
|
@ -34,16 +34,17 @@ class WrongTagSkip(UnionSkip):
|
|||
tag_value: TagValue,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
f"Expected {union_type.__tag_key}={union_type.__expected_tag_value}, got {tag_value}"
|
||||
f"Expected {union_type._tag_key}={union_type.__expected_tag_value}, got {tag_value}"
|
||||
)
|
||||
|
||||
|
||||
class InternallyTaggedUnion(ABC):
|
||||
class InternallyTaggedUnion:
|
||||
"""Implements [internally tagged unions](https://serde.rs/enum-representations.html#internally-tagged)
|
||||
using the [Registry pattern](https://charlesreid1.github.io/python-patterns-the-registry.html).
|
||||
|
||||
NOTE: Make sure the module where you declare your union types is actually imported,
|
||||
or they won't be registered.
|
||||
To ensure your subtypes are loaded even if they're not imported by any file, add
|
||||
the module as a plugin to your package's entry points. For example, to add subtypes
|
||||
to a union
|
||||
|
||||
Args:
|
||||
key: The dict key for the internal tag. Should be None for classes which are not
|
||||
|
@ -52,22 +53,46 @@ class InternallyTaggedUnion(ABC):
|
|||
shouldn't be instantiated (eg. abstract classes).
|
||||
"""
|
||||
|
||||
__tag_key: ClassVar[str | None] = None
|
||||
_loaded_groups: ClassVar[set[str]] = set()
|
||||
"""Global set of groups whose plugins have already been loaded. Do not overwrite.
|
||||
|
||||
We use this so we don't have to load the same modules over and over again.
|
||||
"""
|
||||
|
||||
_group: ClassVar[str | None] = None
|
||||
_tag_key: ClassVar[str | None] = None
|
||||
|
||||
__expected_tag_value: ClassVar[TagValue | None]
|
||||
|
||||
__all_subtypes: ClassVar[set[type[Self]]]
|
||||
__concrete_subtypes: ClassVar[defaultdict[TagValue, set[type[Self]]]]
|
||||
|
||||
def __init_subclass__(cls, key: str | None, value: TagValue | None) -> None:
|
||||
# don't bother initializing classes which aren't part of any union
|
||||
if key is None:
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
*,
|
||||
group: str | None = None,
|
||||
key: str | None = None,
|
||||
value: TagValue | None,
|
||||
) -> None:
|
||||
# inherited data, so only set if not None
|
||||
if group is not None:
|
||||
cls._group = group
|
||||
if key is not None:
|
||||
cls._tag_key = key
|
||||
|
||||
# don't bother with rest of init if it's not part of a union
|
||||
if cls._tag_key is None:
|
||||
if cls._group is not None:
|
||||
raise ValueError(
|
||||
f"Expected cls._group=None for {cls} with key=None, got {cls._group}"
|
||||
)
|
||||
if value is not None:
|
||||
raise ValueError(
|
||||
f"Expected value=None for {cls} with key=None, got {value}"
|
||||
)
|
||||
return
|
||||
|
||||
# set up per-class data and lookups
|
||||
cls.__tag_key = key
|
||||
# per-class data and lookups
|
||||
cls.__expected_tag_value = value
|
||||
cls.__all_subtypes = set()
|
||||
cls.__concrete_subtypes = defaultdict(set)
|
||||
|
@ -81,7 +106,7 @@ class InternallyTaggedUnion(ABC):
|
|||
|
||||
@classmethod
|
||||
def _tag_key_or_raise(cls) -> str:
|
||||
if (tag_key := cls.__tag_key) is None:
|
||||
if (tag_key := cls._tag_key) is None:
|
||||
raise NotImplementedError
|
||||
return tag_key
|
||||
|
||||
|
@ -95,7 +120,7 @@ class InternallyTaggedUnion(ABC):
|
|||
# recursively yield bases
|
||||
# stop when we reach a non-union or a type with a different key (or no key)
|
||||
for base in cls.__bases__:
|
||||
if issubclass(base, InternallyTaggedUnion) and base.__tag_key == tag_key:
|
||||
if issubclass(base, InternallyTaggedUnion) and base._tag_key == tag_key:
|
||||
yield base
|
||||
yield from base._supertypes()
|
||||
|
||||
|
@ -109,6 +134,18 @@ class InternallyTaggedUnion(ABC):
|
|||
|
||||
@classmethod
|
||||
def _resolve_from_dict(cls, data: Self | Any, config: TypedConfig) -> Self:
|
||||
# if we haven't yet, load plugins from entry points
|
||||
if cls._group is not None and cls._group not in cls._loaded_groups:
|
||||
cls._loaded_groups.add(cls._group)
|
||||
for entry_point in iter_entry_points(cls._group):
|
||||
try:
|
||||
entry_point.load()
|
||||
except ModuleNotFoundError as e:
|
||||
e.add_note(
|
||||
f'Note: Tried to load entry point "{entry_point}" from {entry_point.dist}'
|
||||
)
|
||||
raise
|
||||
|
||||
# do this first so we know it's part of a union
|
||||
tag_key = cls._tag_key_or_raise()
|
||||
|
||||
|
@ -137,8 +174,8 @@ class InternallyTaggedUnion(ABC):
|
|||
union_matches[inner_type] = value
|
||||
except UnionSkip:
|
||||
pass
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
except Exception as entry_point:
|
||||
exceptions.append(entry_point)
|
||||
|
||||
# ensure we only matched one
|
||||
match len(union_matches):
|
||||
|
@ -151,11 +188,6 @@ class InternallyTaggedUnion(ABC):
|
|||
|
||||
# oopsies
|
||||
raise ExceptionGroup(
|
||||
f"Failed to match {cls} with {cls.__tag_key}={tag_value} to any of {tag_types}: {data}",
|
||||
f"Failed to match {cls} with {cls._tag_key}={tag_value} to any of {tag_types}: {data}",
|
||||
exceptions,
|
||||
)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _tag_value(self) -> str:
|
||||
...
|
||||
|
|
|
@ -1,38 +1,7 @@
|
|||
__all__ = [
|
||||
"BlockState",
|
||||
"BlockStateIngredient",
|
||||
"BrainsweepRecipe",
|
||||
"ModConditionalIngredient",
|
||||
"HexBook",
|
||||
"HexBookState",
|
||||
"VillagerIngredient",
|
||||
"PageWithPattern",
|
||||
"LookupPatternPage",
|
||||
"ManualPatternNosigPage",
|
||||
"ManualOpPatternPage",
|
||||
"ManualRawPatternPage",
|
||||
"CraftingMultiPage",
|
||||
"BrainsweepPage",
|
||||
]
|
||||
|
||||
from patchouli import Book
|
||||
|
||||
from .hex_pages import (
|
||||
BrainsweepPage,
|
||||
CraftingMultiPage,
|
||||
LookupPatternPage,
|
||||
ManualOpPatternPage,
|
||||
ManualPatternNosigPage,
|
||||
ManualRawPatternPage,
|
||||
PageWithPattern,
|
||||
)
|
||||
from .hex_recipes import (
|
||||
BlockState,
|
||||
BlockStateIngredient,
|
||||
BrainsweepRecipe,
|
||||
ModConditionalIngredient,
|
||||
VillagerIngredient,
|
||||
)
|
||||
from .hex_state import HexBookState
|
||||
|
||||
HexBook = Book[HexBookState]
|
||||
from .hex_state import HexBook, HexBookState
|
||||
|
|
|
@ -2,11 +2,11 @@ from dataclasses import dataclass
|
|||
from typing import Any, Literal
|
||||
|
||||
from common.types import LocalizedItem
|
||||
from minecraft.recipe import Recipe
|
||||
from minecraft.recipe.ingredients import (
|
||||
from minecraft.recipe import (
|
||||
ItemIngredient,
|
||||
MinecraftItemIdIngredient,
|
||||
MinecraftItemTagIngredient,
|
||||
Recipe,
|
||||
)
|
||||
from minecraft.resource import ResourceLocation
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from typing import Any
|
|||
|
||||
from common.pattern import PatternInfo
|
||||
from minecraft.resource import ResourceLocation
|
||||
from patchouli.book import Book
|
||||
from patchouli.state import BookState
|
||||
|
||||
|
||||
|
@ -26,3 +27,6 @@ class HexBookState(BookState):
|
|||
f"Duplicate pattern {pattern.id}\n{pattern}\n{duplicate}"
|
||||
)
|
||||
self.patterns[pattern.id] = pattern
|
||||
|
||||
|
||||
HexBook = Book[HexBookState]
|
||||
|
|
|
@ -13,13 +13,12 @@ from hexcasting.hex_pages import (
|
|||
LookupPatternPage,
|
||||
PageWithPattern,
|
||||
)
|
||||
from patchouli import Category, Entry, FormatTree
|
||||
from patchouli import Category, Entry, FormatTree, Page
|
||||
from patchouli.page import (
|
||||
CraftingPage,
|
||||
EmptyPage,
|
||||
ImagePage,
|
||||
LinkPage,
|
||||
Page,
|
||||
PageWithText,
|
||||
PageWithTitle,
|
||||
SpotlightPage,
|
||||
|
|
|
@ -16,7 +16,7 @@ class ItemResult:
|
|||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Recipe(StatefulTypeTaggedUnion[AnyState], type=None):
|
||||
class Recipe(StatefulTypeTaggedUnion[AnyState], group="hexdoc.Recipe", type=None):
|
||||
id: ResourceLocation
|
||||
group: str | None = None
|
||||
|
||||
|
|
|
@ -5,7 +5,11 @@ from minecraft.resource import ResourceLocation
|
|||
from patchouli.state import AnyState, BookState, StatefulTypeTaggedUnion
|
||||
|
||||
|
||||
class ItemIngredient(StatefulTypeTaggedUnion[AnyState], type=None):
|
||||
class ItemIngredient(
|
||||
StatefulTypeTaggedUnion[AnyState],
|
||||
group="hexdoc.ItemIngredient",
|
||||
type=None,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -7,8 +7,7 @@ from typing import Literal, Self
|
|||
from common.deserialize import from_dict_checked, load_json_data, rename
|
||||
from common.types import Color, LocalizedStr
|
||||
from minecraft.i18n import I18n
|
||||
from minecraft.recipe import Recipe
|
||||
from minecraft.recipe.ingredients import ItemIngredient
|
||||
from minecraft.recipe import ItemIngredient, Recipe
|
||||
from minecraft.resource import ItemStack, ResLoc, ResourceLocation
|
||||
|
||||
from .category import Category
|
||||
|
|
|
@ -14,7 +14,7 @@ _T = TypeVar("_T")
|
|||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Page(StatefulTypeTaggedUnion[AnyState], type=None):
|
||||
class Page(StatefulTypeTaggedUnion[AnyState], group="hexdoc.Page", type=None):
|
||||
"""Base class for Patchouli page types.
|
||||
|
||||
See: https://vazkiimods.github.io/Patchouli/docs/patchouli-basics/page-types
|
||||
|
|
|
@ -120,6 +120,7 @@ class StatefulFile(Stateful[AnyState]):
|
|||
class StatefulInternallyTaggedUnion(
|
||||
Stateful[AnyState],
|
||||
InternallyTaggedUnion,
|
||||
group=None,
|
||||
key=None,
|
||||
value=None,
|
||||
):
|
||||
|
@ -149,8 +150,13 @@ class StatefulTypeTaggedUnion(
|
|||
): # :(
|
||||
type: ResourceLocation | None = field(init=False)
|
||||
|
||||
def __init_subclass__(cls, type: TagValue | None) -> None:
|
||||
super().__init_subclass__("type", type)
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
*,
|
||||
group: str | None = None,
|
||||
type: TagValue | None,
|
||||
) -> None:
|
||||
super().__init_subclass__(group=group, value=type)
|
||||
match type:
|
||||
case str():
|
||||
cls.type = ResourceLocation.from_str(type)
|
||||
|
@ -158,7 +164,3 @@ class StatefulTypeTaggedUnion(
|
|||
cls.type = None
|
||||
case None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def _tag_value(self) -> str:
|
||||
return str(self.type)
|
||||
|
|
Loading…
Reference in a new issue