Fix return function

This commit is contained in:
xhlulu 2021-11-06 16:20:24 -04:00
parent 1799501f07
commit b8e3a27db1

View file

@ -1,9 +1,10 @@
import torch
from model import Generator
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
from model import Generator
model = Generator()
if type(pretrained) == str:
ckpt_url = pretrained
pretrained = True
@ -18,3 +19,5 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
check_hash=check_hash,
)
model.load_state_dict(state_dict)
return model