additional cli
This commit is contained in:
parent
adc0516db9
commit
09cc72dad7
14
model.py
14
model.py
|
@ -87,18 +87,22 @@ class Generator(nn.Module):
|
||||||
nn.Tanh()
|
nn.Tanh()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input, align_corners=True):
|
||||||
out = self.block_a(input)
|
out = self.block_a(input)
|
||||||
half_size = out.size()[-2:]
|
half_size = out.size()[-2:]
|
||||||
out = self.block_b(out)
|
out = self.block_b(out)
|
||||||
out = self.block_c(out)
|
out = self.block_c(out)
|
||||||
|
|
||||||
out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
|
if align_corners:
|
||||||
# out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
|
out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
|
||||||
|
else:
|
||||||
|
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
|
||||||
out = self.block_d(out)
|
out = self.block_d(out)
|
||||||
|
|
||||||
out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
|
if align_corners:
|
||||||
# out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
|
out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
|
||||||
|
else:
|
||||||
|
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
|
||||||
out = self.block_e(out)
|
out = self.block_e(out)
|
||||||
|
|
||||||
out = self.out_layer(out)
|
out = self.out_layer(out)
|
||||||
|
|
24
test.py
24
test.py
|
@ -11,15 +11,16 @@ torch.backends.cudnn.enabled = False
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
def load_image(image_path):
|
def load_image(image_path, x32=False):
|
||||||
img = cv2.imread(image_path).astype(np.float32)
|
img = cv2.imread(image_path).astype(np.float32)
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
h, w = img.shape[:2]
|
h, w = img.shape[:2]
|
||||||
|
|
||||||
def to_32s(x):
|
if x32: # resize image to multiple of 32s
|
||||||
return 256 if x < 256 else x - x%32
|
def to_32s(x):
|
||||||
|
return 256 if x < 256 else x - x%32
|
||||||
|
img = cv2.resize(img, (to_32s(w), to_32s(h)))
|
||||||
|
|
||||||
img = cv2.resize(img, (to_32s(w), to_32s(h)))
|
|
||||||
img = torch.from_numpy(img)
|
img = torch.from_numpy(img)
|
||||||
img = img/127.5 - 1.0
|
img = img/127.5 - 1.0
|
||||||
return img
|
return img
|
||||||
|
@ -36,14 +37,14 @@ def test(args):
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
for image_name in sorted(os.listdir(args.input_dir)):
|
for image_name in sorted(os.listdir(args.input_dir)):
|
||||||
if os.path.splitext(image_name)[-1] not in [".jpg", ".png", ".bmp", ".tiff"]:
|
if os.path.splitext(image_name)[-1].lower() not in [".jpg", ".png", ".bmp", ".tiff"]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
image = load_image(os.path.join(args.input_dir, image_name))
|
image = load_image(os.path.join(args.input_dir, image_name), args.x32)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
input = image.permute(2, 0, 1).unsqueeze(0).to(device)
|
input = image.permute(2, 0, 1).unsqueeze(0).to(device)
|
||||||
out = net(input).squeeze(0).permute(1, 2, 0).cpu().numpy()
|
out = net(input, args.upsample_align).squeeze(0).permute(1, 2, 0).cpu().numpy()
|
||||||
out = (out + 1)*127.5
|
out = (out + 1)*127.5
|
||||||
out = np.clip(out, 0, 255).astype(np.uint8)
|
out = np.clip(out, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
@ -74,6 +75,15 @@ if __name__ == '__main__':
|
||||||
type=str,
|
type=str,
|
||||||
default='cuda:0',
|
default='cuda:0',
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--upsample_align',
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--x32',
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
test(args)
|
test(args)
|
||||||
|
|
Loading…
Reference in a new issue