Add known networks
This commit is contained in:
parent
449279b69e
commit
4a9c882a0a
13
hubconf.py
13
hubconf.py
|
@ -1,6 +1,14 @@
|
|||
import torch
|
||||
|
||||
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
|
||||
release_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights"
|
||||
known = {
|
||||
name: f"{release_url}/{name}.pt"
|
||||
for name in [
|
||||
'face_paint_512_v0', 'face_paint_512_v2'
|
||||
]
|
||||
}
|
||||
|
||||
from model import Generator
|
||||
|
||||
device = torch.device(device)
|
||||
|
@ -8,10 +16,11 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
|
|||
model = Generator().to(device)
|
||||
|
||||
if type(pretrained) == str:
|
||||
ckpt_url = pretrained
|
||||
# Look if a known name is passed, otherwise assume it's a URL
|
||||
ckpt_url = known.get(pretrained, pretrained)
|
||||
pretrained = True
|
||||
else:
|
||||
ckpt_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights/face_paint_512_v2_0.pt"
|
||||
ckpt_url = known.get('face_paint_512_v2')
|
||||
|
||||
if pretrained is True:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
|
|
Loading…
Reference in a new issue