Use entry hooks to load union types

This commit is contained in:
object-Object 2023-06-27 22:08:10 -04:00
parent 371d56033e
commit 5e61d3b8b9
11 changed files with 89 additions and 78 deletions

View file

@ -1,4 +1,4 @@
# build
# project metadata
[build-system]
requires = ["hatchling"]
@ -20,17 +20,19 @@ dependencies = [
"dacite@git+https://github.com/mciszczon/dacite@f298260c6aedc1097c7567b1b0a61298a0ddf2a8",
]
[project.entry-points."tagged_unions.Page"]
patchouli = "patchouli.page.patchouli_pages"
hexcasting = "hexcasting.hex_pages"
[project.entry-points."hexdoc.Page"]
hexdoc-patchouli = "patchouli.page.pages"
hexdoc-hexcasting = "hexcasting.hex_pages"
[project.entry-points."tagged_unions.Recipe"]
patchouli = "minecraft.recipe.minecraft_recipes"
hexcasting = "hexcasting.hex_recipes"
[project.entry-points."hexdoc.Recipe"]
hexdoc-minecraft = "minecraft.recipe.recipes"
hexdoc-hexcasting = "hexcasting.hex_recipes"
[project.entry-points."tagged_unions.ItemIngredient"]
patchouli = "minecraft.recipe.ingredients"
hexcasting = "hexcasting.hex_recipes"
[project.entry-points."hexdoc.ItemIngredient"]
hexdoc-minecraft = "minecraft.recipe.ingredients"
hexdoc-hexcasting = "hexcasting.hex_recipes"
# Hatch settings (the build backend)
[tool.hatch.metadata]
allow-direct-references = true # TODO: remove when we switch to Pydantic

View file

@ -2,12 +2,12 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import Any, ClassVar, Generator, Self
from dacite import StrictUnionMatchError, UnionMatchError, from_dict
from pkg_resources import iter_entry_points
from common.dacite_patch import UnionSkip
from common.deserialize import TypedConfig
@ -34,16 +34,17 @@ class WrongTagSkip(UnionSkip):
tag_value: TagValue,
) -> None:
super().__init__(
f"Expected {union_type.__tag_key}={union_type.__expected_tag_value}, got {tag_value}"
f"Expected {union_type._tag_key}={union_type.__expected_tag_value}, got {tag_value}"
)
class InternallyTaggedUnion(ABC):
class InternallyTaggedUnion:
"""Implements [internally tagged unions](https://serde.rs/enum-representations.html#internally-tagged)
using the [Registry pattern](https://charlesreid1.github.io/python-patterns-the-registry.html).
NOTE: Make sure the module where you declare your union types is actually imported,
or they won't be registered.
To ensure your subtypes are loaded even if they're not imported by any file, add
the module as a plugin to your package's entry points. For example, to add subtypes
to a union
Args:
key: The dict key for the internal tag. Should be None for classes which are not
@ -52,22 +53,46 @@ class InternallyTaggedUnion(ABC):
shouldn't be instantiated (eg. abstract classes).
"""
__tag_key: ClassVar[str | None] = None
_loaded_groups: ClassVar[set[str]] = set()
"""Global set of groups whose plugins have already been loaded. Do not overwrite.
We use this so we don't have to load the same modules over and over again.
"""
_group: ClassVar[str | None] = None
_tag_key: ClassVar[str | None] = None
__expected_tag_value: ClassVar[TagValue | None]
__all_subtypes: ClassVar[set[type[Self]]]
__concrete_subtypes: ClassVar[defaultdict[TagValue, set[type[Self]]]]
def __init_subclass__(cls, key: str | None, value: TagValue | None) -> None:
# don't bother initializing classes which aren't part of any union
if key is None:
def __init_subclass__(
cls,
*,
group: str | None = None,
key: str | None = None,
value: TagValue | None,
) -> None:
# inherited data, so only set if not None
if group is not None:
cls._group = group
if key is not None:
cls._tag_key = key
# don't bother with rest of init if it's not part of a union
if cls._tag_key is None:
if cls._group is not None:
raise ValueError(
f"Expected cls._group=None for {cls} with key=None, got {cls._group}"
)
if value is not None:
raise ValueError(
f"Expected value=None for {cls} with key=None, got {value}"
)
return
# set up per-class data and lookups
cls.__tag_key = key
# per-class data and lookups
cls.__expected_tag_value = value
cls.__all_subtypes = set()
cls.__concrete_subtypes = defaultdict(set)
@ -81,7 +106,7 @@ class InternallyTaggedUnion(ABC):
@classmethod
def _tag_key_or_raise(cls) -> str:
if (tag_key := cls.__tag_key) is None:
if (tag_key := cls._tag_key) is None:
raise NotImplementedError
return tag_key
@ -95,7 +120,7 @@ class InternallyTaggedUnion(ABC):
# recursively yield bases
# stop when we reach a non-union or a type with a different key (or no key)
for base in cls.__bases__:
if issubclass(base, InternallyTaggedUnion) and base.__tag_key == tag_key:
if issubclass(base, InternallyTaggedUnion) and base._tag_key == tag_key:
yield base
yield from base._supertypes()
@ -109,6 +134,18 @@ class InternallyTaggedUnion(ABC):
@classmethod
def _resolve_from_dict(cls, data: Self | Any, config: TypedConfig) -> Self:
# if we haven't yet, load plugins from entry points
if cls._group is not None and cls._group not in cls._loaded_groups:
cls._loaded_groups.add(cls._group)
for entry_point in iter_entry_points(cls._group):
try:
entry_point.load()
except ModuleNotFoundError as e:
e.add_note(
f'Note: Tried to load entry point "{entry_point}" from {entry_point.dist}'
)
raise
# do this first so we know it's part of a union
tag_key = cls._tag_key_or_raise()
@ -137,8 +174,8 @@ class InternallyTaggedUnion(ABC):
union_matches[inner_type] = value
except UnionSkip:
pass
except Exception as e:
exceptions.append(e)
except Exception as entry_point:
exceptions.append(entry_point)
# ensure we only matched one
match len(union_matches):
@ -151,11 +188,6 @@ class InternallyTaggedUnion(ABC):
# oopsies
raise ExceptionGroup(
f"Failed to match {cls} with {cls.__tag_key}={tag_value} to any of {tag_types}: {data}",
f"Failed to match {cls} with {cls._tag_key}={tag_value} to any of {tag_types}: {data}",
exceptions,
)
@property
@abstractmethod
def _tag_value(self) -> str:
...

View file

@ -1,38 +1,7 @@
__all__ = [
"BlockState",
"BlockStateIngredient",
"BrainsweepRecipe",
"ModConditionalIngredient",
"HexBook",
"HexBookState",
"VillagerIngredient",
"PageWithPattern",
"LookupPatternPage",
"ManualPatternNosigPage",
"ManualOpPatternPage",
"ManualRawPatternPage",
"CraftingMultiPage",
"BrainsweepPage",
]
from patchouli import Book
from .hex_pages import (
BrainsweepPage,
CraftingMultiPage,
LookupPatternPage,
ManualOpPatternPage,
ManualPatternNosigPage,
ManualRawPatternPage,
PageWithPattern,
)
from .hex_recipes import (
BlockState,
BlockStateIngredient,
BrainsweepRecipe,
ModConditionalIngredient,
VillagerIngredient,
)
from .hex_state import HexBookState
HexBook = Book[HexBookState]
from .hex_state import HexBook, HexBookState

View file

@ -2,11 +2,11 @@ from dataclasses import dataclass
from typing import Any, Literal
from common.types import LocalizedItem
from minecraft.recipe import Recipe
from minecraft.recipe.ingredients import (
from minecraft.recipe import (
ItemIngredient,
MinecraftItemIdIngredient,
MinecraftItemTagIngredient,
Recipe,
)
from minecraft.resource import ResourceLocation

View file

@ -3,6 +3,7 @@ from typing import Any
from common.pattern import PatternInfo
from minecraft.resource import ResourceLocation
from patchouli.book import Book
from patchouli.state import BookState
@ -26,3 +27,6 @@ class HexBookState(BookState):
f"Duplicate pattern {pattern.id}\n{pattern}\n{duplicate}"
)
self.patterns[pattern.id] = pattern
HexBook = Book[HexBookState]

View file

@ -13,13 +13,12 @@ from hexcasting.hex_pages import (
LookupPatternPage,
PageWithPattern,
)
from patchouli import Category, Entry, FormatTree
from patchouli import Category, Entry, FormatTree, Page
from patchouli.page import (
CraftingPage,
EmptyPage,
ImagePage,
LinkPage,
Page,
PageWithText,
PageWithTitle,
SpotlightPage,

View file

@ -16,7 +16,7 @@ class ItemResult:
@dataclass(kw_only=True)
class Recipe(StatefulTypeTaggedUnion[AnyState], type=None):
class Recipe(StatefulTypeTaggedUnion[AnyState], group="hexdoc.Recipe", type=None):
id: ResourceLocation
group: str | None = None

View file

@ -5,7 +5,11 @@ from minecraft.resource import ResourceLocation
from patchouli.state import AnyState, BookState, StatefulTypeTaggedUnion
class ItemIngredient(StatefulTypeTaggedUnion[AnyState], type=None):
class ItemIngredient(
StatefulTypeTaggedUnion[AnyState],
group="hexdoc.ItemIngredient",
type=None,
):
pass

View file

@ -7,8 +7,7 @@ from typing import Literal, Self
from common.deserialize import from_dict_checked, load_json_data, rename
from common.types import Color, LocalizedStr
from minecraft.i18n import I18n
from minecraft.recipe import Recipe
from minecraft.recipe.ingredients import ItemIngredient
from minecraft.recipe import ItemIngredient, Recipe
from minecraft.resource import ItemStack, ResLoc, ResourceLocation
from .category import Category

View file

@ -14,7 +14,7 @@ _T = TypeVar("_T")
@dataclass(kw_only=True)
class Page(StatefulTypeTaggedUnion[AnyState], type=None):
class Page(StatefulTypeTaggedUnion[AnyState], group="hexdoc.Page", type=None):
"""Base class for Patchouli page types.
See: https://vazkiimods.github.io/Patchouli/docs/patchouli-basics/page-types

View file

@ -120,6 +120,7 @@ class StatefulFile(Stateful[AnyState]):
class StatefulInternallyTaggedUnion(
Stateful[AnyState],
InternallyTaggedUnion,
group=None,
key=None,
value=None,
):
@ -149,8 +150,13 @@ class StatefulTypeTaggedUnion(
): # :(
type: ResourceLocation | None = field(init=False)
def __init_subclass__(cls, type: TagValue | None) -> None:
super().__init_subclass__("type", type)
def __init_subclass__(
cls,
*,
group: str | None = None,
type: TagValue | None,
) -> None:
super().__init_subclass__(group=group, value=type)
match type:
case str():
cls.type = ResourceLocation.from_str(type)
@ -158,7 +164,3 @@ class StatefulTypeTaggedUnion(
cls.type = None
case None:
pass
@property
def _tag_value(self) -> str:
return str(self.type)