diff --git a/hubconf.py b/hubconf.py index 4ed1837..3cacaef 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,18 +1,17 @@ import torch def generator(pretrained=True, device="cpu", progress=True, check_hash=True): - release_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights" + from model import Generator + + release_url = "https://github.com/xhlulu/animegan2-pytorch/raw/main/weights" known = { name: f"{release_url}/{name}.pt" for name in [ - 'face_paint_512_v0', 'face_paint_512_v2' + 'celeba_distill', 'face_paint_512_v1', 'face_paint_512_v2', 'paprika' ] } - from model import Generator - device = torch.device(device) - model = Generator().to(device) if type(pretrained) == str: @@ -34,14 +33,14 @@ def generator(pretrained=True, device="cpu", progress=True, check_hash=True): return model -def face2paint(device="cpu"): +def face2paint(device="cpu", size=512): from PIL import Image from torchvision.transforms.functional import to_tensor, to_pil_image def face2paint( model: torch.nn.Module, img: Image.Image, - size: int, + size: int = size, side_by_side: bool = True, device: str = device, ) -> Image.Image: @@ -50,13 +49,14 @@ def face2paint(device="cpu"): img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)) img = img.resize((size, size), Image.LANCZOS) - input = to_tensor(img).unsqueeze(0) * 2 - 1 - output = model(input.to(device)).cpu()[0] + with torch.no_grad(): + input = to_tensor(img).unsqueeze(0) * 2 - 1 + output = model(input.to(device)).cpu()[0] - if side_by_side: - output = torch.cat([input[0], output], dim=2) + if side_by_side: + output = torch.cat([input[0], output], dim=2) - output = (output * 0.5 + 0.5).clip(0, 1) + output = (output * 0.5 + 0.5).clip(0, 1) return to_pil_image(output)