Refactor

Dheeraj Peri 2020-07-01 11:42:19 -07:00
parent e0f399def4
commit 6d2357a9b8
2 changed files with 38 additions and 23 deletions

View file

@@ -4,28 +4,25 @@ import numpy as np
import argparse
import os
import shutil
def main(args):
def process_checkpoint(input_ckpt, output_ckpt_path):
"""
This function loads a RN50 checkpoint with Dense layer as the final layer
and transforms the final dense layer into a 1x1 convolution layer. The weights
of the dense layer are reshaped into weights of 1x1 conv layer.
Args:
input_ckpt: Path to the input RN50 ckpt which has dense layer as classification layer.
Returns:
None. New checkpoint with 1x1 conv layer as classification layer is generated.
"""
with tf.Session() as sess:
ckpt = args.ckpt
new_ckpt=args.out
output_dir = "./new_ckpt_dir"
if os.path.isdir(output_dir):
shutil.rmtree(output_dir)
# Create an output directory
os.mkdir(output_dir)
new_ckpt_path = os.path.join(output_dir, new_ckpt)
with open(os.path.join(output_dir, "checkpoint"), 'w') as file:
file.write("model_checkpoint_path: "+ "\"" + new_ckpt + "\"")
file.close()
# Load all the variables
all_vars = tf.train.list_variables(ckpt)
ckpt_reader = tf.train.load_checkpoint(ckpt)
# Capture the dense layer weights and reshape them to a 4D tensor which would be
# the weights of a 1x1 convolution layer. This code replaces the dense (FC) layer
# to a 1x1 conv layer.
dense_layer = 'resnet50_v1.5/output/dense/kernel'
dense_layer_value=0.
new_var_list=[]
for var in all_vars:
@@ -34,18 +31,36 @@ def main(args):
dense_layer_value = curr_var
else:
new_var_list.append(tf.Variable(curr_var, name=var[0]))
new_var_value = np.reshape(dense_layer_value, [1, 1, 2048, 1001])
dense_layer_shape = [1, 1, 2048, 1001]
new_var_value = np.reshape(dense_layer_value, dense_layer_shape)
new_var = tf.Variable(new_var_value, name=dense_layer)
new_var_list.append(new_var)
sess.run(tf.global_variables_initializer())
tf.train.Saver(var_list=new_var_list).save(sess, new_ckpt_path, write_meta_graph=False, write_state=False)
print ("Rewriting checkpoints completed")
tf.train.Saver(var_list=new_var_list).save(sess, output_ckpt_path, write_meta_graph=False, write_state=False)
print ("Rewriting checkpoint completed")
if __name__=='__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt', type=str, required=True)
parser.add_argument('--out', type=str, default='./new.ckpt')
parser.add_argument('--input', type=str, required=True, help='Path to pretrained RN50 checkpoint with dense layer')
parser.add_argument('--dense_layer', type=str, default='resnet/output/dense/kernel')
parser.add_argument('--output', type=str, default='output_dir', help="Output directory to store new checkpoint")
args = parser.parse_args()
main(args)
input_ckpt = args.input
# Create an output directory
os.mkdir(args.output)
new_ckpt='new.ckpt'
new_ckpt_path = os.path.join(args.output, new_ckpt)
with open(os.path.join(args.output, "checkpoint"), 'w') as file:
file.write("model_checkpoint_path: "+ "\"" + new_ckpt + "\"")
# Process the input checkpoint, apply transforms and generate a new checkpoint.
process_checkpoint(input_ckpt, new_ckpt_path)
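
The core of this refactor is isolating the dense-to-1x1-conv rewrite in process_checkpoint(), while argument parsing and output-directory setup move to the __main__ block. The following sketch is not part of the commit; it only illustrates, in plain NumPy and with the [1, 1, 2048, 1001] shape used above, why reshaping the dense kernel into a 1x1 convolution kernel leaves the classifier logits unchanged:

# Illustrative sketch only (not in the commit): a dense layer on pooled
# features and a 1x1 conv with the reshaped kernel compute the same logits.
import numpy as np

in_channels, num_classes = 2048, 1001
features = np.random.rand(1, in_channels)                 # globally pooled RN50 features
dense_kernel = np.random.rand(in_channels, num_classes)   # dense classification kernel

# Dense (fully connected) classifier: logits = features @ kernel
dense_logits = features @ dense_kernel

# Same kernel reshaped to a [1, 1, in, out] conv kernel, applied to a 1x1 feature map (NHWC)
conv_kernel = np.reshape(dense_kernel, [1, 1, in_channels, num_classes])
conv_input = features.reshape(1, 1, 1, in_channels)
conv_logits = np.einsum('nhwc,ijco->nhwo', conv_input, conv_kernel).reshape(1, num_classes)

assert np.allclose(dense_logits, conv_logits)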

View file

@@ -1,4 +1,4 @@
# This script does quantization-aware training of ResNet-50 by fine-tuning the pre-trained model, using 1 GPU and a batch size of 32.
# Usage ./GPU1_RN50_QAT.sh <path to the pre-trained model> <path to dataset> <path to results directory>
python main.py --mode=train_and_evaluate --batch_size=32 --lr_warmup_epochs=1 --label_smoothing 0.1 --lr_init=0.00005 --momentum=0.875 --weight_decay=3.0517578125e-05 --finetune_checkpoint=$1 --data_dir=$2 --results_dir=$3 --quantize --symmetric --num_iter 10 --data_format NHWC
python main.py --mode=train_and_evaluate --batch_size=32 --lr_warmup_epochs=1 --quantize --symmetric --use_qdq --label_smoothing 0.1 --lr_init=0.00005 --momentum=0.875 --weight_decay=3.0517578125e-05 --finetune_checkpoint=$1 --data_dir=$2 --results_dir=$3 --num_iter 10 --data_format NHWC
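
The only functional change to the training command is the added --use_qdq flag next to --quantize and --symmetric. Assuming it behaves like TensorFlow's standard quantize-and-dequantize (QDQ) simulation (an assumption, since the flag's implementation is not shown in this commit), the idea is sketched below in the same TF 1.x style as the checkpoint script above:

# Hedged sketch (assumption, not this repo's code): --use_qdq presumably makes
# the QAT graph simulate INT8 with quantize-and-dequantize (QDQ) nodes, so
# downstream ops still see float tensors that carry the 8-bit rounding error.
import numpy as np
import tensorflow as tf  # TF 1.x, matching the tf.Session usage above

x = tf.placeholder(tf.float32, shape=[None, 7, 7, 2048], name='features')

# Quantize to 8 bits over a symmetric, fixed range, then dequantize back to float.
x_qdq = tf.quantization.quantize_and_dequantize(
    x, input_min=-6.0, input_max=6.0,   # example range; in QAT the ranges are learned/calibrated
    signed_input=True, num_bits=8, range_given=True)

with tf.Session() as sess:
    feats = np.random.randn(1, 7, 7, 2048).astype(np.float32)
    out = sess.run(x_qdq, feed_dict={x: feats})
    print(out.dtype)  # float32, but values snapped to an 8-bit grid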