Update device
This commit is contained in:
parent
b8e3a27db1
commit
449279b69e
|
@ -3,7 +3,9 @@ import torch
|
||||||
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
|
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
|
||||||
from model import Generator
|
from model import Generator
|
||||||
|
|
||||||
model = Generator()
|
device = torch.device(device)
|
||||||
|
|
||||||
|
model = Generator().to(device)
|
||||||
|
|
||||||
if type(pretrained) == str:
|
if type(pretrained) == str:
|
||||||
ckpt_url = pretrained
|
ckpt_url = pretrained
|
||||||
|
@ -14,7 +16,7 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
|
||||||
if pretrained is True:
|
if pretrained is True:
|
||||||
state_dict = torch.hub.load_state_dict_from_url(
|
state_dict = torch.hub.load_state_dict_from_url(
|
||||||
ckpt_url,
|
ckpt_url,
|
||||||
map_location=torch.device(device),
|
map_location=device,
|
||||||
progress=progress,
|
progress=progress,
|
||||||
check_hash=check_hash,
|
check_hash=check_hash,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue