diff --git a/README.md b/README.md
index b445919..04f8eeb 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,25 @@
-# animegan2-pytorch
\ No newline at end of file
+### PyTorch Implementation of [AnimeGANv2](https://github.com/TachibanaYoshino/AnimeGANv2)
+
+**Weight Conversion (Optional)**
+```
+git clone https://github.com/TachibanaYoshino/AnimeGANv2
+python convert_weights.py
+```
+
+**Inference**
+```
+python test.py --input_dir [image_folder_path]
+```
+
+**Results from converted [[Paprika](https://drive.google.com/file/d/1K_xN32uoQKI8XmNYNLTX5gDn1UnQVe5I/view?usp=sharing)] style model**
+
+(input image, original TensorFlow result, PyTorch result, from left to right)
+
+![](samples/compare/1.jpg)
+![](samples/compare/2.jpg)
+![](samples/compare/3.jpg)
+
+**Note:** Training code is not included, and the results look slightly blurrier than the original TensorFlow ones.
\ No newline at end of file
diff --git a/convert_weights.py b/convert_weights.py
new file mode 100644
index 0000000..e064709
--- /dev/null
+++ b/convert_weights.py
@@ -0,0 +1,140 @@
+import argparse
+
+import numpy as np
+import tensorflow as tf
+import torch
+
+from AnimeGANv2.net import generator as tf_generator
+from model import Generator
+
+
+def load_tf_weights(tf_path):
+    test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
+    # build the generator graph so the checkpoint variables exist
+    with tf.variable_scope("generator", reuse=False):
+        test_generated = tf_generator.G_net(test_real).fake
+
+    saver = tf.train.Saver()
+
+    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})) as sess:
+        ckpt = tf.train.get_checkpoint_state(tf_path)
+
+        assert ckpt is not None and ckpt.model_checkpoint_path is not None, f"Failed to load checkpoint {tf_path}"
+
+        saver.restore(sess, ckpt.model_checkpoint_path)
+        print(f"Tensorflow model checkpoint {ckpt.model_checkpoint_path} loaded")
+
+        tf_weights = {}
+        for v in tf.trainable_variables():
+            tf_weights[v.name] = v.eval()
+
+    return tf_weights
+
+
+def convert_keys(k):
+    # 1. split the tf variable name into three parts: [block, layer, parameter]
+    # 2. transform each part, then merge them into a pytorch state_dict key
+
+    k = k.replace("Conv/", "Conv_0/").replace("LayerNorm/", "LayerNorm_0/")
+    keys = k.split("/")[2:]
+
+    is_dconv = False
+
+    # handle the C block (the inverted residual blocks)
+    if keys[0] == "C":
+        if keys[1] in ["Conv_1", "LayerNorm_1"]:
+            keys[1] = keys[1].replace("1", "5")
+
+        if len(keys) == 4:
+            assert "r" in keys[1]
+
+            if keys[1] == keys[2]:
+                is_dconv = True
+                keys[2] = "1.1"
+
+            block_c_maps = {
+                "1": "1.2",
+                "Conv_1": "2",
+                "2": "3",
+            }
+            if keys[2] in block_c_maps:
+                keys[2] = block_c_maps[keys[2]]
+
+            keys[1] = keys[1].replace("r", "") + ".layers." + keys[2]
+            keys[2] = keys[3]
+            keys.pop(-1)
+
+        assert len(keys) == 3
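+
+        # illustrative trace (assuming tf names of the form "generator/G_MODEL/<block>/..."):
+        # "generator/G_MODEL/C/r1/Conv/weights:0" has by this point become
+        # ["C", "1.layers.Conv_0", "weights:0"]; the sections below finish
+        # the mapping, yielding "block_c.1.layers.0.1.weight"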
+
+    # handle the output block
+    if "out" in keys[0]:
+        keys[1] = "0"
+
+    # first part: block name
+    if keys[0] in ["A", "B", "C", "D", "E"]:
+        keys[0] = "block_" + keys[0].lower()
+
+    # second part: layer index within the block
+    if "LayerNorm_" in keys[1]:
+        keys[1] = keys[1].replace("LayerNorm_", "") + ".2"
+    if "Conv_" in keys[1]:
+        keys[1] = keys[1].replace("Conv_", "") + ".1"
+
+    # third part: parameter name
+    keys[2] = {
+        "weights:0": "weight",
+        "w:0": "weight",
+        "bias:0": "bias",
+        "gamma:0": "weight",
+        "beta:0": "bias",
+    }[keys[2]]
+
+    return ".".join(keys), is_dconv
+
+
+def convert_and_save(tf_checkpoint_path, save_name):
+    tf_weights = load_tf_weights(tf_checkpoint_path)
+
+    torch_net = Generator()
+    torch_weights = torch_net.state_dict()
+
+    torch_converted_weights = {}
+    for k, v in tf_weights.items():
+        torch_k, is_dconv = convert_keys(k)
+        assert torch_k in torch_weights, f"weight name mismatch: {k}"
+
+        converted_weight = torch.from_numpy(v)
+        if len(converted_weight.shape) == 4:
+            # tf kernels are (H, W, in, out); depthwise kernels are (H, W, in, multiplier)
+            if is_dconv:
+                converted_weight = converted_weight.permute(2, 3, 0, 1)
+            else:
+                converted_weight = converted_weight.permute(3, 2, 0, 1)
+
+        assert torch_weights[torch_k].shape == converted_weight.shape, f"shape mismatch: {k}"
+
+        torch_converted_weights[torch_k] = converted_weight
+
+    assert sorted(torch_converted_weights) == sorted(torch_weights), "some weights are missing"
+    torch_net.load_state_dict(torch_converted_weights)
+    torch.save(torch_net.state_dict(), save_name)
+    print(f"PyTorch model saved at {save_name}")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--tf_checkpoint_path',
+        type=str,
+        default='AnimeGANv2/checkpoint/generator_Paprika_weight',
+    )
+    parser.add_argument(
+        '--save_name',
+        type=str,
+        default='pytorch_generator_Paprika.pt',
+    )
+    args = parser.parse_args()
+
+    convert_and_save(args.tf_checkpoint_path, args.save_name)
\ No newline at end of file
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..14cd552
--- /dev/null
+++ b/model.py
@@ -0,0 +1,106 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class ConvNormLReLU(nn.Sequential):
+    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
+        pad_layer = {
+            "zero": nn.ZeroPad2d,
+            "same": nn.ReplicationPad2d,
+            "reflect": nn.ReflectionPad2d,
+        }
+        if pad_mode not in pad_layer:
+            raise NotImplementedError(f"unsupported pad_mode: {pad_mode}")
+
+        super().__init__(
+            pad_layer[pad_mode](padding),
+            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
+            # a single-group GroupNorm normalizes over (C, H, W), mirroring
+            # the layer normalization used by the tf implementation
+            nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+
+class InvertedResBlock(nn.Module):
+    def __init__(self, in_ch, out_ch, expansion_ratio=2):
+        super().__init__()
+
+        self.use_res_connect = in_ch == out_ch
+        bottleneck = int(round(in_ch * expansion_ratio))
+
+        layers = []
+        if expansion_ratio != 1:
+            # pointwise expansion
+            layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))
+        # depthwise convolution
+        layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
+        # pointwise projection (linear, no activation)
+        layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
+        layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, input):
+        out = self.layers(input)
+        if self.use_res_connect:
+            out = input + out
+        return out
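+
+
+# Rough shape flow for a 256x256 input (illustrative):
+#   block_a:               3x256x256 -> 64x128x128
+#   block_b:                         -> 128x64x64
+#   block_c:                         -> 128x64x64  (inverted residual blocks)
+#   2x upsample + block_d:           -> 128x128x128
+#   2x upsample + block_e:           -> 32x256x256
+#   out_layer:                       -> 3x256x256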
+class Generator(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.block_a = nn.Sequential(
+            ConvNormLReLU(3, 32, kernel_size=7, padding=3),
+            ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1)),
+            ConvNormLReLU(64, 64),
+        )
+
+        self.block_b = nn.Sequential(
+            ConvNormLReLU(64, 128, stride=2, padding=(0, 1, 0, 1)),
+            ConvNormLReLU(128, 128),
+        )
+
+        self.block_c = nn.Sequential(
+            ConvNormLReLU(128, 128),
+            InvertedResBlock(128, 256, 2),
+            InvertedResBlock(256, 256, 2),
+            InvertedResBlock(256, 256, 2),
+            InvertedResBlock(256, 256, 2),
+            ConvNormLReLU(256, 128),
+        )
+
+        self.block_d = nn.Sequential(
+            ConvNormLReLU(128, 128),
+            ConvNormLReLU(128, 128),
+        )
+
+        self.block_e = nn.Sequential(
+            ConvNormLReLU(128, 64),
+            ConvNormLReLU(64, 64),
+            ConvNormLReLU(64, 32, kernel_size=7, padding=3),
+        )
+
+        self.out_layer = nn.Sequential(
+            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
+            nn.Tanh(),
+        )
+
+    def forward(self, input):
+        out = self.block_a(input)
+        half_size = out.size()[-2:]
+        out = self.block_b(out)
+        out = self.block_c(out)
+
+        # interpolating to the recorded sizes, e.g.
+        # F.interpolate(out, half_size, mode="bilinear", align_corners=True),
+        # is an alternative to the fixed scale factor used here
+        out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
+        out = self.block_d(out)
+
+        out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
+        out = self.block_e(out)
+
+        out = self.out_layer(out)
+        return out
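+
+
+if __name__ == "__main__":
+    # minimal sanity check (illustrative; not part of the original script)
+    net = Generator()
+    x = torch.rand(1, 3, 256, 256) * 2 - 1  # the generator expects inputs in [-1, 1]
+    with torch.no_grad():
+        y = net(x)
+    assert y.shape == (1, 3, 256, 256)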
diff --git a/samples/compare/1.jpg b/samples/compare/1.jpg
new file mode 100644
index 0000000..58cb590
Binary files /dev/null and b/samples/compare/1.jpg differ
diff --git a/samples/compare/2.jpg b/samples/compare/2.jpg
new file mode 100644
index 0000000..97a5417
Binary files /dev/null and b/samples/compare/2.jpg differ
diff --git a/samples/compare/3.jpg b/samples/compare/3.jpg
new file mode 100644
index 0000000..cd6fdd0
Binary files /dev/null and b/samples/compare/3.jpg differ
diff --git a/samples/inputs/1.jpg b/samples/inputs/1.jpg
new file mode 100644
index 0000000..f1cf3fa
Binary files /dev/null and b/samples/inputs/1.jpg differ
diff --git a/samples/inputs/2.jpg b/samples/inputs/2.jpg
new file mode 100644
index 0000000..c249fd3
Binary files /dev/null and b/samples/inputs/2.jpg differ
diff --git a/samples/inputs/3.jpg b/samples/inputs/3.jpg
new file mode 100644
index 0000000..d74bed2
Binary files /dev/null and b/samples/inputs/3.jpg differ
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..f1ad1cc
--- /dev/null
+++ b/test.py
@@ -0,0 +1,80 @@
+import argparse
+import os
+
+import cv2
+import numpy as np
+import torch
+
+from model import Generator
+
+# disable cuDNN so results are deterministic across runs
+torch.backends.cudnn.enabled = False
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+
+
+def load_image(image_path):
+    img = cv2.imread(image_path)
+    if img is None:
+        raise FileNotFoundError(f"failed to read image: {image_path}")
+
+    img = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_BGR2RGB)
+    h, w = img.shape[:2]
+
+    def to_32s(x):
+        # snap each side to a multiple of 32 (minimum 256) so the generator's
+        # down- and upsampling stages line up
+        return 256 if x < 256 else x - x % 32
+
+    img = cv2.resize(img, (to_32s(w), to_32s(h)))
+    img = torch.from_numpy(img)
+    img = img / 127.5 - 1.0  # scale to [-1, 1]
+    return img
+
+
+def test(args):
+    device = args.device
+
+    net = Generator()
+    net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
+    net.to(device).eval()
+    print(f"model loaded: {args.checkpoint}")
+
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    for image_name in sorted(os.listdir(args.input_dir)):
+        if os.path.splitext(image_name)[-1].lower() not in [".jpg", ".png", ".bmp", ".tiff"]:
+            continue
+
+        image = load_image(os.path.join(args.input_dir, image_name))
+
+        with torch.no_grad():
+            input = image.permute(2, 0, 1).unsqueeze(0).to(device)
+            out = net(input).squeeze(0).permute(1, 2, 0).cpu().numpy()
+            out = (out + 1) * 127.5
+            out = np.clip(out, 0, 255).astype(np.uint8)
+
+        # the network output is RGB; convert back to BGR for cv2.imwrite
+        cv2.imwrite(os.path.join(args.output_dir, image_name), cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
+        print(f"image saved: {image_name}")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--checkpoint',
+        type=str,
+        default='./pytorch_generator_Paprika.pt',
+    )
+    parser.add_argument(
+        '--input_dir',
+        type=str,
+        default='./samples/inputs',
+    )
+    parser.add_argument(
+        '--output_dir',
+        type=str,
+        default='./samples/results',
+    )
+    parser.add_argument(
+        '--device',
+        type=str,
+        default='cuda:0',
+    )
+    args = parser.parse_args()
+
+    test(args)
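+
+# Example invocation (illustrative; assumes weights produced by convert_weights.py):
+#   python test.py --checkpoint ./pytorch_generator_Paprika.pt \
+#       --input_dir ./samples/inputs --output_dir ./samples/results --device cpu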