2020-03-02 14:29:12 +01:00
|
|
|
import argparse
|
|
|
|
|
|
|
|
import tensorflow as tf
|
|
|
|
|
2020-07-04 01:42:09 +02:00
|
|
|
from dlexport.tensorflow import to_savedmodel, to_onnx, to_tensorrt
|
2020-03-02 14:29:12 +01:00
|
|
|
from utils.data_loader import Dataset
|
|
|
|
from utils.model_fn import unet_fn
|
|
|
|
|
|
|
|
# Command-line interface of the export script. `--to` selects the export
# target; the remaining flags are forwarded to the matching exporter in main().
PARSER = argparse.ArgumentParser(description="U-Net medical")

PARSER.add_argument('--to', dest='to', choices=['savedmodel', 'tensorrt', 'onnx'], required=True,
                    help="""Format to export the model to""")

# Boolean switches: absent means False, present means True.
PARSER.add_argument('--use_amp', dest='use_amp', action='store_true', default=False,
                    help="""Pass use_amp=True to the SavedModel export""")

PARSER.add_argument('--use_xla', dest='use_xla', action='store_true', default=False,
                    help="""Pass use_xla=True to the SavedModel export""")

PARSER.add_argument('--compress', dest='compress', action='store_true', default=False,
                    help="""Pass compress=True to the selected export""")

PARSER.add_argument('--input_shape',
                    nargs='+',
                    type=int,
                    help="""Shape of the model input as space-separated integers (SavedModel export)""")

PARSER.add_argument('--data_dir',
                    type=str,
                    help="""Directory where to download the dataset""")

PARSER.add_argument('--checkpoint_dir',
                    type=str,
                    help="""Directory with the checkpoint to export as a SavedModel""")

PARSER.add_argument('--savedmodel_dir',
                    type=str,
                    help="""Directory with the SavedModel to convert to TensorRT or ONNX""")

PARSER.add_argument('--precision',
                    type=str,
                    choices=['FP32', 'FP16', 'INT8'],
                    help="""Precision requested for the TensorRT conversion""")
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    Starting point of the application

    Parses the command-line flags and runs the single export selected by
    ``--to``:

    * ``savedmodel``: exports the checkpoint in ``--checkpoint_dir`` to
      ``./saved_model``.
    * ``tensorrt``: converts the SavedModel in ``--savedmodel_dir`` to
      ``./tf_trt_model``, feeding one sample read from ``--data_dir``.
    * ``onnx``: converts the SavedModel in ``--savedmodel_dir`` to
      ``./onnx_model``.
    """
    flags = PARSER.parse_args()

    # The choices of `--to` are mutually exclusive, so at most one branch runs.
    if flags.to == 'savedmodel':
        to_savedmodel(input_shape=flags.input_shape,
                      model_fn=unet_fn,
                      src_dir=flags.checkpoint_dir,
                      dst_dir='./saved_model',
                      input_names=['IteratorGetNext'],
                      output_names=['total_loss_ref'],
                      use_amp=flags.use_amp,
                      use_xla=flags.use_xla,
                      compress=flags.compress)
    elif flags.to == 'tensorrt':
        ds = Dataset(data_dir=flags.data_dir,
                     batch_size=1,
                     augment=False,
                     gpu_id=0,
                     num_gpus=1,
                     seed=42)
        iterator = ds.test_fn(count=1).make_one_shot_iterator()
        features = iterator.get_next()

        sess = tf.Session()

        def input_data():
            # Evaluated lazily by the converter; pulls one sample from the
            # dataset iterator and binds it to the model's input tensor.
            return {'input_tensor:0': sess.run(features)}

        try:
            to_tensorrt(src_dir=flags.savedmodel_dir,
                        dst_dir='./tf_trt_model',
                        precision=flags.precision,
                        feed_dict_fn=input_data,
                        num_runs=1,
                        output_tensor_names=['Softmax:0'],
                        compress=flags.compress)
        finally:
            # Release the session's resources even if the conversion fails.
            sess.close()
    elif flags.to == 'onnx':
        to_onnx(src_dir=flags.savedmodel_dir,
                dst_dir='./onnx_model',
                compress=flags.compress)
|
|
|
|
|
|
|
|
|
|
|
|
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|
|
|
|