From 449279b69ea2aa684d7619e1c1d3169358e405d6 Mon Sep 17 00:00:00 2001 From: xhlulu Date: Sat, 6 Nov 2021 16:23:03 -0400 Subject: [PATCH] Update device --- hubconf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hubconf.py b/hubconf.py index d79d336..96d41fe 100644 --- a/hubconf.py +++ b/hubconf.py @@ -3,7 +3,9 @@ import torch def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True): from model import Generator - model = Generator() + device = torch.device(device) + + model = Generator().to(device) if type(pretrained) == str: ckpt_url = pretrained @@ -14,7 +16,7 @@ def animegan2(pretrained=True, device="cpu", progress=True, check_hash=True): if pretrained is True: state_dict = torch.hub.load_state_dict_from_url( ckpt_url, - map_location=torch.device(device), + map_location=device, progress=progress, check_hash=check_hash, )