From 09cc72dad7209ed393e111ed20343aa4608e5cc8 Mon Sep 17 00:00:00 2001 From: sr9 Date: Wed, 3 Mar 2021 19:44:57 +0900 Subject: [PATCH] additional cli --- model.py | 14 +++++++++----- test.py | 24 +++++++++++++++++------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/model.py b/model.py index 8024013..0970115 100644 --- a/model.py +++ b/model.py @@ -87,18 +87,22 @@ class Generator(nn.Module): nn.Tanh() ) - def forward(self, input): + def forward(self, input, align_corners=True): out = self.block_a(input) half_size = out.size()[-2:] out = self.block_b(out) out = self.block_c(out) - out = F.interpolate(out, half_size, mode="bilinear", align_corners=True) -# out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False) + if align_corners: + out = F.interpolate(out, half_size, mode="bilinear", align_corners=True) + else: + out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False) out = self.block_d(out) - out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True) -# out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False) + if align_corners: + out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True) + else: + out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False) out = self.block_e(out) out = self.out_layer(out) diff --git a/test.py b/test.py index f1ad1cc..7e6b60f 100644 --- a/test.py +++ b/test.py @@ -11,15 +11,16 @@ torch.backends.cudnn.enabled = False torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True -def load_image(image_path): +def load_image(image_path, x32=False): img = cv2.imread(image_path).astype(np.float32) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) h, w = img.shape[:2] - def to_32s(x): - return 256 if x < 256 else x - x%32 + if x32: # resize image to multiple of 32s + def to_32s(x): + return 256 if x < 256 else x - x%32 + img = cv2.resize(img, (to_32s(w), to_32s(h))) - img = cv2.resize(img, (to_32s(w), to_32s(h))) img = torch.from_numpy(img) img = img/127.5 - 1.0 return img @@ -36,14 +37,14 @@ def test(args): os.makedirs(args.output_dir, exist_ok=True) for image_name in sorted(os.listdir(args.input_dir)): - if os.path.splitext(image_name)[-1] not in [".jpg", ".png", ".bmp", ".tiff"]: + if os.path.splitext(image_name)[-1].lower() not in [".jpg", ".png", ".bmp", ".tiff"]: continue - image = load_image(os.path.join(args.input_dir, image_name)) + image = load_image(os.path.join(args.input_dir, image_name), args.x32) with torch.no_grad(): input = image.permute(2, 0, 1).unsqueeze(0).to(device) - out = net(input).squeeze(0).permute(1, 2, 0).cpu().numpy() + out = net(input, args.upsample_align).squeeze(0).permute(1, 2, 0).cpu().numpy() out = (out + 1)*127.5 out = np.clip(out, 0, 255).astype(np.uint8) @@ -74,6 +75,15 @@ if __name__ == '__main__': type=str, default='cuda:0', ) + parser.add_argument( + '--upsample_align', + type=bool, + default=False, + ) + parser.add_argument( + '--x32', + action="store_true", + ) args = parser.parse_args() test(args)