animegan2-pytorch/convert_weights.py


import argparse
import numpy as np
import os
import tensorflow as tf
from AnimeGANv2.net import generator as tf_generator
import torch
from model import Generator
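
# Note: this script relies on the TensorFlow 1.x graph API (tf.placeholder, tf.Session).
# If running under TensorFlow 2.x, something along these lines would likely be needed
# before building the graph (an assumption, not part of the original script):
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()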


def load_tf_weights(tf_path):
    test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')

    with tf.variable_scope("generator", reuse=False):
        test_generated = tf_generator.G_net(test_real).fake

    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})) as sess:
        ckpt = tf.train.get_checkpoint_state(tf_path)
        assert ckpt is not None and ckpt.model_checkpoint_path is not None, f"Failed to load checkpoint {tf_path}"

        saver.restore(sess, ckpt.model_checkpoint_path)
        print(f"Tensorflow model checkpoint {ckpt.model_checkpoint_path} loaded")

        tf_weights = {}
        for v in tf.trainable_variables():
            tf_weights[v.name] = v.eval()

    return tf_weights


def convert_keys(k):
    # 1. split the tf weight name into three parts [block_idx, layer_idx, weight/bias]
    # 2. handle each part & merge them into a pytorch state_dict key
    k = k.replace("Conv/", "Conv_0/").replace("LayerNorm/", "LayerNorm_0/")
    keys = k.split("/")[2:]

    is_dconv = False

    # handle the C block..
    if keys[0] == "C":
        if keys[1] in ["Conv_1", "LayerNorm_1"]:
            keys[1] = keys[1].replace("1", "5")

        if len(keys) == 4:
            assert "r" in keys[1]

            if keys[1] == keys[2]:
                is_dconv = True
                keys[2] = "1.1"

            block_c_maps = {
                "1": "1.2",
                "Conv_1": "2",
                "2": "3",
            }
            if keys[2] in block_c_maps:
                keys[2] = block_c_maps[keys[2]]

            keys[1] = keys[1].replace("r", "") + ".layers." + keys[2]
            keys[2] = keys[3]
            keys.pop(-1)

    assert len(keys) == 3

    # handle output block
    if "out" in keys[0]:
        keys[1] = "0"

    # first part
    if keys[0] in ["A", "B", "C", "D", "E"]:
        keys[0] = "block_" + keys[0].lower()

    # second part
    if "LayerNorm_" in keys[1]:
        keys[1] = keys[1].replace("LayerNorm_", "") + ".2"
    if "Conv_" in keys[1]:
        keys[1] = keys[1].replace("Conv_", "") + ".1"

    # third part
    keys[2] = {
        "weights:0": "weight",
        "w:0": "weight",
        "bias:0": "bias",
        "gamma:0": "weight",
        "beta:0": "bias",
    }[keys[2]]

    return ".".join(keys), is_dconv


def convert_and_save(tf_checkpoint_path, save_name):
    tf_weights = load_tf_weights(tf_checkpoint_path)

    torch_net = Generator()
    torch_weights = torch_net.state_dict()

    torch_converted_weights = {}
    for k, v in tf_weights.items():
        torch_k, is_dconv = convert_keys(k)
        assert torch_k in torch_weights, f"weight name mismatch: {k}"

        converted_weight = torch.from_numpy(v)
        if len(converted_weight.shape) == 4:
            if is_dconv:
                converted_weight = converted_weight.permute(2, 3, 0, 1)
            else:
                converted_weight = converted_weight.permute(3, 2, 0, 1)

        assert torch_weights[torch_k].shape == converted_weight.shape, f"shape mismatch: {k}"

        torch_converted_weights[torch_k] = converted_weight

    assert sorted(list(torch_converted_weights)) == sorted(list(torch_weights)), "some weights are missing"
    torch_net.load_state_dict(torch_converted_weights)
    torch.save(torch_net.state_dict(), save_name)
    print(f"PyTorch model saved at {save_name}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--tf_checkpoint_path',
        type=str,
        default='AnimeGANv2/checkpoint/generator_Paprika_weight',
    )
    parser.add_argument(
        '--save_name',
        type=str,
        default='pytorch_generator_Paprika.pt',
    )
    args = parser.parse_args()

    convert_and_save(args.tf_checkpoint_path, args.save_name)
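
# Example invocation (uses the script's default paths; adjust as needed):
#   python convert_weights.py \
#       --tf_checkpoint_path AnimeGANv2/checkpoint/generator_Paprika_weight \
#       --save_name pytorch_generator_Paprika.pt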