diff --git a/hubconf.py b/hubconf.py index d79d336..96d41fe 100644 --- a/hubconf.py +++ b/hubconf.py @@ -3,7 +3,9 @@ import torch def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True): from model import Generator - model = Generator() + device = torch.device(device) + + model = Generator().to(device) if type(pretrained) == str: ckpt_url = pretrained @@ -14,7 +16,7 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True): if pretrained is True: state_dict = torch.hub.load_state_dict_from_url( ckpt_url, - map_location=torch.device(device), + map_location=device, progress=progress, check_hash=check_hash, )