init
This commit is contained in:
parent
92a736d876
commit
a2d49c885c
26
README.md
26
README.md
|
@ -1 +1,25 @@
|
|||
# animegan2-pytorch
|
||||
### PyTorch Implementation of [AnimeGANv2](https://github.com/TachibanaYoshino/AnimeGANv2)
|
||||
|
||||
|
||||
**Weight Conversion (Optional)**
|
||||
```
|
||||
git clone https://github.com/TachibanaYoshino/AnimeGANv2
|
||||
python convert_weights.py
|
||||
|
||||
```
|
||||
|
||||
**Inference**
|
||||
```
|
||||
python test.py --input_dir [image_folder_path]
|
||||
|
||||
```
|
||||
|
||||
**Results from converted [[Paprika](https://drive.google.com/file/d/1K_xN32uoQKI8XmNYNLTX5gDn1UnQVe5I/view?usp=sharing)] style model**
|
||||
|
||||
(input image, original tensorflow result, pytorch result from left to right)
|
||||
|
||||
<img src="./samples/compare/1.jpg" width="650">
|
||||
<img src="./samples/compare/2.jpg" width="650">
|
||||
<img src="./samples/compare/3.jpg" width="650">
|
||||
|
||||
**Note:** Training code not included / Results looks slightly blurrier than the original ones.
|
140
convert_weights.py
Normal file
140
convert_weights.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
from AnimeGANv2.net import generator as tf_generator
|
||||
|
||||
import torch
|
||||
from model import Generator
|
||||
|
||||
|
||||
def load_tf_weights(tf_path):
|
||||
test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
|
||||
with tf.variable_scope("generator", reuse=False):
|
||||
test_generated = tf_generator.G_net(test_real).fake
|
||||
|
||||
saver = tf.train.Saver()
|
||||
|
||||
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, device_count = {'GPU': 0})) as sess:
|
||||
ckpt = tf.train.get_checkpoint_state(tf_path)
|
||||
|
||||
assert ckpt is not None and ckpt.model_checkpoint_path is not None, f"Failed to load checkpoint {checkpoint_dir}"
|
||||
|
||||
saver.restore(sess, ckpt.model_checkpoint_path)
|
||||
print(f"Tensorflow model checkpoint {ckpt.model_checkpoint_path} loaded")
|
||||
|
||||
tf_weights = {}
|
||||
for v in tf.trainable_variables():
|
||||
tf_weights[v.name] = v.eval()
|
||||
|
||||
return tf_weights
|
||||
|
||||
|
||||
def convert_keys(k):
|
||||
|
||||
# 1. divide tf weight name in three parts [block_idx, layer_idx, weight/bias]
|
||||
# 2. handle each part & merge into a pytorch model keys
|
||||
|
||||
k = k.replace("Conv/", "Conv_0/").replace("LayerNorm/", "LayerNorm_0/")
|
||||
keys = k.split("/")[2:]
|
||||
|
||||
is_dconv = False
|
||||
|
||||
# handle C block..
|
||||
if keys[0] == "C":
|
||||
if keys[1] in ["Conv_1", "LayerNorm_1"]:
|
||||
keys[1] = keys[1].replace("1", "5")
|
||||
|
||||
if len(keys) == 4:
|
||||
assert "r" in keys[1]
|
||||
|
||||
if keys[1] == keys[2]:
|
||||
is_dconv = True
|
||||
keys[2] = "1.1"
|
||||
|
||||
block_c_maps = {
|
||||
"1": "1.2",
|
||||
"Conv_1": "2",
|
||||
"2": "3",
|
||||
}
|
||||
if keys[2] in block_c_maps:
|
||||
keys[2] = block_c_maps[keys[2]]
|
||||
|
||||
keys[1] = keys[1].replace("r", "") + ".layers." + keys[2]
|
||||
keys[2] = keys[3]
|
||||
keys.pop(-1)
|
||||
assert len(keys) == 3
|
||||
|
||||
# handle output block
|
||||
if "out" in keys[0]:
|
||||
keys[1] = "0"
|
||||
|
||||
# first part
|
||||
if keys[0] in ["A", "B", "C", "D", "E"]:
|
||||
keys[0] = "block_" + keys[0].lower()
|
||||
|
||||
# second part
|
||||
if "LayerNorm_" in keys[1]:
|
||||
keys[1] = keys[1].replace("LayerNorm_", "") + ".2"
|
||||
if "Conv_" in keys[1]:
|
||||
keys[1] = keys[1].replace("Conv_", "") + ".1"
|
||||
|
||||
# third part
|
||||
keys[2] = {
|
||||
"weights:0": "weight",
|
||||
"w:0": "weight",
|
||||
"bias:0": "bias",
|
||||
"gamma:0": "weight",
|
||||
"beta:0": "bias",
|
||||
}[keys[2]]
|
||||
|
||||
return ".".join(keys), is_dconv
|
||||
|
||||
|
||||
def convert_and_save(tf_checkpoint_path, save_name):
|
||||
|
||||
tf_weights = load_tf_weights(tf_checkpoint_path)
|
||||
|
||||
torch_net = Generator()
|
||||
torch_weights = torch_net.state_dict()
|
||||
|
||||
torch_converted_weights = {}
|
||||
for k, v in tf_weights.items():
|
||||
torch_k, is_dconv = convert_keys(k)
|
||||
assert torch_k in torch_weights, f"weight name mismatch: {k}"
|
||||
|
||||
converted_weight = torch.from_numpy(v)
|
||||
if len(converted_weight.shape) == 4:
|
||||
if is_dconv:
|
||||
converted_weight = converted_weight.permute(2, 3, 0, 1)
|
||||
else:
|
||||
converted_weight = converted_weight.permute(3, 2, 0, 1)
|
||||
|
||||
assert torch_weights[torch_k].shape == converted_weight.shape, f"shape mismatch: {k}"
|
||||
|
||||
torch_converted_weights[torch_k] = converted_weight
|
||||
|
||||
assert sorted(list(torch_converted_weights)) == sorted(list(torch_weights)), f"some weights are missing"
|
||||
torch_net.load_state_dict(torch_converted_weights)
|
||||
torch.save(torch_net.state_dict(), save_name)
|
||||
print(f"PyTorch model saved at {save_name}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--tf_checkpoint_path',
|
||||
type=str,
|
||||
default='AnimeGANv2/checkpoint/generator_Paprika_weight',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--save_name',
|
||||
type=str,
|
||||
default='pytorch_generator_Paprika.pt',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_and_save(args.tf_checkpoint_path, args.save_name)
|
106
model.py
Normal file
106
model.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ConvNormLReLU(nn.Sequential):
|
||||
def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
|
||||
|
||||
pad_layer = {
|
||||
"zero": nn.ZeroPad2d,
|
||||
"same": nn.ReplicationPad2d,
|
||||
"reflect": nn.ReflectionPad2d,
|
||||
}
|
||||
if pad_mode not in pad_layer:
|
||||
raise NotImplementedError
|
||||
|
||||
super(ConvNormLReLU, self).__init__(
|
||||
pad_layer[pad_mode](padding),
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
|
||||
nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResBlock(nn.Module):
|
||||
def __init__(self, in_ch, out_ch, expansion_ratio=2):
|
||||
super(InvertedResBlock, self).__init__()
|
||||
|
||||
self.use_res_connect = in_ch == out_ch
|
||||
bottleneck = int(round(in_ch*expansion_ratio))
|
||||
layers = []
|
||||
if expansion_ratio != 1:
|
||||
layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))
|
||||
|
||||
# dw
|
||||
layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
|
||||
# pw
|
||||
layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
|
||||
layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.layers(input)
|
||||
if self.use_res_connect:
|
||||
out = input + out
|
||||
return out
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, ):
|
||||
super().__init__()
|
||||
|
||||
self.block_a = nn.Sequential(
|
||||
ConvNormLReLU(3, 32, kernel_size=7, padding=3),
|
||||
ConvNormLReLU(32, 64, stride=2, padding=(0,1,0,1)),
|
||||
ConvNormLReLU(64, 64)
|
||||
)
|
||||
|
||||
self.block_b = nn.Sequential(
|
||||
ConvNormLReLU(64, 128, stride=2, padding=(0,1,0,1)),
|
||||
ConvNormLReLU(128, 128)
|
||||
)
|
||||
|
||||
self.block_c = nn.Sequential(
|
||||
ConvNormLReLU(128, 128),
|
||||
InvertedResBlock(128, 256, 2),
|
||||
InvertedResBlock(256, 256, 2),
|
||||
InvertedResBlock(256, 256, 2),
|
||||
InvertedResBlock(256, 256, 2),
|
||||
ConvNormLReLU(256, 128),
|
||||
)
|
||||
|
||||
self.block_d = nn.Sequential(
|
||||
ConvNormLReLU(128, 128),
|
||||
ConvNormLReLU(128, 128)
|
||||
)
|
||||
|
||||
self.block_e = nn.Sequential(
|
||||
ConvNormLReLU(128, 64),
|
||||
ConvNormLReLU(64, 64),
|
||||
ConvNormLReLU(64, 32, kernel_size=7, padding=3)
|
||||
)
|
||||
|
||||
self.out_layer = nn.Sequential(
|
||||
nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.block_a(input)
|
||||
half_size = out.size()[-2:]
|
||||
out = self.block_b(out)
|
||||
out = self.block_c(out)
|
||||
|
||||
# out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
|
||||
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
|
||||
out = self.block_d(out)
|
||||
|
||||
# out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
|
||||
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
|
||||
out = self.block_e(out)
|
||||
|
||||
out = self.out_layer(out)
|
||||
return out
|
||||
|
BIN
samples/compare/1.jpg
Normal file
BIN
samples/compare/1.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 863 KiB |
BIN
samples/compare/2.jpg
Normal file
BIN
samples/compare/2.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.3 MiB |
BIN
samples/compare/3.jpg
Normal file
BIN
samples/compare/3.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.4 MiB |
BIN
samples/inputs/1.jpg
Normal file
BIN
samples/inputs/1.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 719 KiB |
BIN
samples/inputs/2.jpg
Normal file
BIN
samples/inputs/2.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 536 KiB |
BIN
samples/inputs/3.jpg
Normal file
BIN
samples/inputs/3.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 227 KiB |
80
test.py
Normal file
80
test.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import argparse
|
||||
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from model import Generator
|
||||
|
||||
torch.backends.cudnn.enabled = False
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
def load_image(image_path):
|
||||
img = cv2.imread(image_path).astype(np.float32)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
h, w = img.shape[:2]
|
||||
|
||||
def to_32s(x):
|
||||
return 256 if x < 256 else x - x%32
|
||||
|
||||
img = cv2.resize(img, (to_32s(w), to_32s(h)))
|
||||
img = torch.from_numpy(img)
|
||||
img = img/127.5 - 1.0
|
||||
return img
|
||||
|
||||
|
||||
def test(args):
|
||||
device = args.device
|
||||
|
||||
net = Generator()
|
||||
net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
|
||||
net.to(device).eval()
|
||||
print(f"model loaded: {args.checkpoint}")
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
for image_name in sorted(os.listdir(args.input_dir)):
|
||||
if os.path.splitext(image_name)[-1] not in [".jpg", ".png", ".bmp", ".tiff"]:
|
||||
continue
|
||||
|
||||
image = load_image(os.path.join(args.input_dir, image_name))
|
||||
|
||||
with torch.no_grad():
|
||||
input = image.permute(2, 0, 1).unsqueeze(0).to(device)
|
||||
out = net(input).squeeze(0).permute(1, 2, 0).cpu().numpy()
|
||||
out = (out + 1)*127.5
|
||||
out = np.clip(out, 0, 255).astype(np.uint8)
|
||||
|
||||
cv2.imwrite(os.path.join(args.output_dir, image_name), cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
|
||||
print(f"image saved: {image_name}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--checkpoint',
|
||||
type=str,
|
||||
default='./pytorch_generator_Paprika.pt',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--input_dir',
|
||||
type=str,
|
||||
default='./samples/inputs',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='./samples/results',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--device',
|
||||
type=str,
|
||||
default='cuda:0',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
test(args)
|
||||
|
Loading…
Reference in a new issue