Fix return function
This commit is contained in:
parent
1799501f07
commit
b8e3a27db1
|
@ -1,9 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
from model import Generator
|
|
||||||
|
|
||||||
|
|
||||||
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
|
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
|
||||||
|
from model import Generator
|
||||||
|
|
||||||
model = Generator()
|
model = Generator()
|
||||||
|
|
||||||
if type(pretrained) == str:
|
if type(pretrained) == str:
|
||||||
ckpt_url = pretrained
|
ckpt_url = pretrained
|
||||||
pretrained = True
|
pretrained = True
|
||||||
|
@ -18,3 +19,5 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
|
||||||
check_hash=check_hash,
|
check_hash=check_hash,
|
||||||
)
|
)
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
return model
|
Loading…
Reference in a new issue