code Refactor
This commit is contained in:
parent
e0f399def4
commit
6d2357a9b8
|
@ -4,28 +4,25 @@ import numpy as np
|
|||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
def main(args):
|
||||
|
||||
def process_checkpoint(input_ckpt, output_ckpt_path):
|
||||
"""
|
||||
This function loads an RN50 checkpoint with a dense layer as the final layer
|
||||
and transforms the final dense layer into a 1x1 convolution layer. The weights
|
||||
of the dense layer are reshaped into weights of 1x1 conv layer.
|
||||
Args:
|
||||
input_ckpt: Path to the input RN50 ckpt which has a dense layer as its classification layer.
output_ckpt_path: Path to which the new checkpoint is saved.
|
||||
Returns:
|
||||
None. New checkpoint with 1x1 conv layer as classification layer is generated.
|
||||
"""
|
||||
|
||||
with tf.Session() as sess:
|
||||
ckpt = args.ckpt
|
||||
new_ckpt=args.out
|
||||
output_dir = "./new_ckpt_dir"
|
||||
if os.path.isdir(output_dir):
|
||||
shutil.rmtree(output_dir)
|
||||
# Create an output directory
|
||||
os.mkdir(output_dir)
|
||||
|
||||
new_ckpt_path = os.path.join(output_dir, new_ckpt)
|
||||
with open(os.path.join(output_dir, "checkpoint"), 'w') as file:
|
||||
file.write("model_checkpoint_path: "+ "\"" + new_ckpt + "\"")
|
||||
file.close()
|
||||
# Load all the variables
|
||||
all_vars = tf.train.list_variables(ckpt)
|
||||
ckpt_reader = tf.train.load_checkpoint(ckpt)
|
||||
# Capture the dense layer weights and reshape them to a 4D tensor which would be
|
||||
# the weights of a 1x1 convolution layer. This code replaces the dense (FC) layer
|
||||
# with a 1x1 conv layer.
|
||||
dense_layer = 'resnet50_v1.5/output/dense/kernel'
|
||||
dense_layer_value=0.
|
||||
new_var_list=[]
|
||||
for var in all_vars:
|
||||
|
@ -34,18 +31,36 @@ def main(args):
|
|||
dense_layer_value = curr_var
|
||||
else:
|
||||
new_var_list.append(tf.Variable(curr_var, name=var[0]))
|
||||
|
||||
new_var_value = np.reshape(dense_layer_value, [1, 1, 2048, 1001])
|
||||
|
||||
dense_layer_shape = [1, 1, 2048, 1001]
|
||||
new_var_value = np.reshape(dense_layer_value, )
|
||||
new_var = tf.Variable(new_var_value, name=dense_layer)
|
||||
new_var_list.append(new_var)
|
||||
|
||||
sess.run(tf.global_variables_initializer())
|
||||
tf.train.Saver(var_list=new_var_list).save(sess, new_ckpt_path, write_meta_graph=False, write_state=False)
|
||||
print ("Rewriting checkpoints completed")
|
||||
tf.train.Saver(var_list=new_var_list).save(sess, output_ckpt_path, write_meta_graph=False, write_state=False)
|
||||
print ("Rewriting checkpoint completed")
|
||||
|
||||
if __name__=='__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--ckpt', type=str, required=True)
|
||||
parser.add_argument('--out', type=str, default='./new.ckpt')
|
||||
parser.add_argument('--input', type=str, required=True, help='Path to pretrained RN50 checkpoint with dense layer')
|
||||
parser.add_argument('--dense_layer', type=str, default='resnet/output/dense/kernel')
|
||||
parser.add_argument('--output', type=str, default='output_dir', help="Output directory to store new checkpoint")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
main(args)
|
||||
|
||||
input_ckpt = args.input
|
||||
# Create an output directory
|
||||
os.mkdir(args.output)
|
||||
|
||||
new_ckpt='new.ckpt'
|
||||
new_ckpt_path = os.path.join(args.output, new_ckpt)
|
||||
with open(os.path.join(output_dir, "checkpoint"), 'w') as file:
|
||||
file.write("model_checkpoint_path: "+ "\"" + new_ckpt + "\"")
|
||||
|
||||
# Process the input checkpoint, apply transforms and generate a new checkpoint.
|
||||
process_checkpoint(input_ckpt, new_ckpt_path)
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
# This script performs quantization-aware training (QAT) of ResNet-50 by fine-tuning the pre-trained model using 1 GPU and a batch size of 32.
|
||||
# Usage ./GPU1_RN50_QAT.sh <path to the pre-trained model> <path to dataset> <path to results directory>
|
||||
|
||||
python main.py --mode=train_and_evaluate --batch_size=32 --lr_warmup_epochs=1 --label_smoothing 0.1 --lr_init=0.00005 --momentum=0.875 --weight_decay=3.0517578125e-05 --finetune_checkpoint=$1 --data_dir=$2 --results_dir=$3 --quantize --symmetric --num_iter 10 --data_format NHWC
|
||||
python main.py --mode=train_and_evaluate --batch_size=32 --lr_warmup_epochs=1 --quantize --symmetric --use_qdq --label_smoothing 0.1 --lr_init=0.00005 --momentum=0.875 --weight_decay=3.0517578125e-05 --finetune_checkpoint=$1 --data_dir=$2 --results_dir=$3 --num_iter 10 --data_format NHWC
|
||||
|
|
Loading…
Reference in a new issue