2021-11-06 21:12:06 +01:00
|
|
|
import torch
|
|
|
|
|
|
|
|
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
    """Build an AnimeGANv2 ``Generator`` and optionally load pretrained weights.

    Args:
        pretrained: Selects which weights (if any) to load.
            * A string is treated as a known checkpoint name
              (``'face_paint_512_v0'`` or ``'face_paint_512_v2'``); any other
              string is used verbatim as the checkpoint URL.
            * Any other truthy value loads the default
              ``'face_paint_512_v2'`` checkpoint.
            * A falsy value returns the freshly initialised model without
              downloading anything.
        device: Anything accepted by ``torch.device``. The model is moved there
            and the checkpoint tensors are mapped onto it while loading.
        progress: Forwarded to ``torch.hub.load_state_dict_from_url``.
        check_hash: Forwarded to ``torch.hub.load_state_dict_from_url``.

    Returns:
        The ``Generator`` instance, placed on ``device``.

    Raises:
        ValueError: If ``pretrained`` is an empty string.
    """
    # Repo-local module, imported inside the entrypoint as in the original.
    from model import Generator

    release_url = "https://github.com/xhlulu/animegan2-pytorch/releases/download/weights"
    known = {
        name: f"{release_url}/{name}.pt"
        for name in ['face_paint_512_v0', 'face_paint_512_v2']
    }

    device = torch.device(device)
    model = Generator().to(device)

    if isinstance(pretrained, str):
        if not pretrained:
            # An empty name would otherwise be passed to torch.hub as a URL
            # and fail with an obscure download/IO error.
            raise ValueError("pretrained must be a non-empty checkpoint name or URL")
        # Look if a known name is passed, otherwise assume it's a URL.
        ckpt_url = known.get(pretrained, pretrained)
        load_weights = True
    else:
        ckpt_url = known['face_paint_512_v2']
        # Plain truthiness: the previous `is True` check silently skipped
        # loading for truthy non-bool values such as 1 or numpy.bool_(True).
        load_weights = bool(pretrained)

    if load_weights:
        state_dict = torch.hub.load_state_dict_from_url(
            ckpt_url,
            map_location=device,
            progress=progress,
            check_hash=check_hash,
        )
        model.load_state_dict(state_dict)

    return model
|