Update device

This commit is contained in:
xhlulu 2021-11-06 16:23:03 -04:00
parent b8e3a27db1
commit 449279b69e

View file

@ -3,7 +3,9 @@ import torch
def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
from model import Generator
model = Generator()
device = torch.device(device)
model = Generator().to(device)
if type(pretrained) == str:
ckpt_url = pretrained
@ -14,7 +16,7 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True):
if pretrained is True:
state_dict = torch.hub.load_state_dict_from_url(
ckpt_url,
map_location=torch.device(device),
map_location=device,
progress=progress,
check_hash=check_hash,
)