From c78943efdeb2813a6a1ff75da0c3cd2fe354f203 Mon Sep 17 00:00:00 2001 From: xhlulu Date: Sat, 6 Nov 2021 16:51:14 -0400 Subject: [PATCH] add face2paint --- hubconf.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/hubconf.py b/hubconf.py index 9f55e34..4ed1837 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,6 +1,6 @@ import torch -def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True): +def generator(pretrained=True, device="cpu", progress=True, check_hash=True): release_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights" known = { name: f"{release_url}/{name}.pt" @@ -31,4 +31,33 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True): ) model.load_state_dict(state_dict) - return model \ No newline at end of file + return model + + +def face2paint(device="cpu"): + from PIL import Image + from torchvision.transforms.functional import to_tensor, to_pil_image + + def face2paint( + model: torch.nn.Module, + img: Image.Image, + size: int, + side_by_side: bool = True, + device: str = device, + ) -> Image.Image: + w, h = img.size + s = min(w, h) + img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)) + img = img.resize((size, size), Image.LANCZOS) + + input = to_tensor(img).unsqueeze(0) * 2 - 1 + output = model(input.to(device)).cpu()[0] + + if side_by_side: + output = torch.cat([input[0], output], dim=2) + + output = (output * 0.5 + 0.5).clip(0, 1) + + return to_pil_image(output) + + return face2paint