diff --git a/hubconf.py b/hubconf.py index 96d41fe..9f55e34 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,6 +1,14 @@ import torch def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True): + release_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights" + known = { + name: f"{release_url}/{name}.pt" + for name in [ + 'face_paint_512_v0', 'face_paint_512_v2' + ] + } + from model import Generator device = torch.device(device) @@ -8,10 +16,11 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True): model = Generator().to(device) if type(pretrained) == str: - ckpt_url = pretrained + # Look if a known name is passed, otherwise assume it's a URL + ckpt_url = known.get(pretrained, pretrained) pretrained = True else: - ckpt_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights/face_paint_512_v2_0.pt" + ckpt_url = known.get('face_paint_512_v2') if pretrained is True: state_dict = torch.hub.load_state_dict_from_url(