DeepLearningExamples/PyTorch/Classification/ConvNets/model2onnx.py

import argparse
import torch
import pytorch_quantization
from image_classification.models import (
    resnet50,
    resnext101_32x4d,
    se_resnext101_32x4d,
    efficientnet_b0,
    efficientnet_b4,
    efficientnet_widese_b0,
    efficientnet_widese_b4,
    efficientnet_quant_b0,
    efficientnet_quant_b4,
)

def available_models():
    models = {
        m.name: m
        for m in [
            resnet50,
            resnext101_32x4d,
            se_resnext101_32x4d,
            efficientnet_b0,
            efficientnet_b4,
            efficientnet_widese_b0,
            efficientnet_widese_b4,
            efficientnet_quant_b0,
            efficientnet_quant_b4,
        ]
    }
    return models

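# The dictionary built above maps each entrypoint's registered name to the entrypoint object
# itself: the names are the strings accepted by --arch (for example "resnet50" or
# "efficientnet-quant-b0", matching the quantized names checked in main()), and each entrypoint
# exposes its own argument parser via .parser(), which the __main__ block below relies on.
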
def parse_args(parser):
    """
    Parse commandline arguments.
    """
    model_names = available_models().keys()
    parser.add_argument("--arch", "-a", metavar="ARCH", default="resnet50", choices=model_names,
                        help="model architecture: " + " | ".join(model_names) + " (default: resnet50)")
    parser.add_argument("--device", metavar="DEVICE", default="cuda", choices=["cpu", "cuda"],
                        help="device on which the model is run: cpu, cuda (default: cuda)")
    parser.add_argument("--image-size", default=None, type=int, help="resolution of image")
    parser.add_argument("--output", type=str, help="path to the output ONNX model")
    parser.add_argument("-b", "--batch-size", default=256, type=int, metavar="N",
                        help="mini-batch size (default: 256) per gpu")
    return parser

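# Note that parse_args() above only registers the conversion-level flags; architecture-specific
# options (e.g. the flag behind model_args.pretrained_from_file) are handled by the per-model
# parser in the __main__ block. A minimal illustration of that split (a sketch, not part of the
# original script):
#
#   args, rest = parse_args(argparse.ArgumentParser()).parse_known_args(
#       ["--arch", "resnet50", "-b", "1", "--pretrained-from-file", "x.pth"])
#   # args picks up --arch and -b, while "--pretrained-from-file x.pth" stays in `rest`
#   # for available_models()["resnet50"].parser() to consume.
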
def final_name(base_name):
    splitted = base_name.split('.')
    if splitted[-1] in ('pt', 'pth'):
        # swap only the extension, not every occurrence of 'pt'/'pth' in the path
        fin_name = '.'.join(splitted[:-1] + ['onnx'])
    elif len(splitted) > 1:
        fin_name = '.'.join(splitted[:-1] + ['onnx'])
    else:
        fin_name = base_name + '.onnx'
    return fin_name

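# Illustrative behaviour of the helper above (a sketch; the file names are placeholders):
#   final_name("nvidia_resnet50.pth") -> "nvidia_resnet50.onnx"
#   final_name("model.ckpt")          -> "model.onnx"
#   final_name("checkpoint")          -> "checkpoint.onnx"
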
def get_dataloader(image_size, bs, num_classes):
    """return dataloader for inference"""
    from image_classification.dataloaders import get_syntetic_loader

    def data_loader():
        loader, _ = get_syntetic_loader(None, image_size, bs, num_classes, False)
        for inp, _ in loader:
            yield inp
            break

    return data_loader()

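# Note on the generator above: the synthetic loader produces random image batches, so it is only
# used here to trace input shapes; the `break` makes data_loader() yield exactly one batch. That
# batch is assumed to have shape [bs, 3, image_size, image_size], which is what the downstream
# torch.onnx.export call records as the model's input shape.
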
def prepare_inputs(dataloader, device):
    """load sample inputs to device"""
    inputs = []
    for batch in dataloader:
        if type(batch) is torch.Tensor:
            batch_d = batch.to(device)
            batch_d = (batch_d,)
            inputs.append(batch_d)
        else:
            batch_d = []
            for x in batch:
                assert type(x) is torch.Tensor, "input is not a tensor"
                batch_d.append(x.to(device))
            batch_d = tuple(batch_d)
            inputs.append(batch_d)
    return inputs

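# Sketch of the expected result (assuming a 2-image batch at 224x224 resolution and 1000 classes;
# not part of the original script):
#   inputs = prepare_inputs(get_dataloader(224, 2, 1000), "cuda")
#   # inputs[0] is a 1-tuple containing a CUDA tensor, i.e. the tuple form torch.onnx.export
#   # expects for its `args` parameter.
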
def check_quant_weight_correctness(checkpoint_path, model):
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    state_dict = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()}
    quantizers_sd_keys = {f'{n[0]}._amax' for n in model.named_modules() if 'quantizer' in n[0]}
    sd_all_keys = quantizers_sd_keys | set(model.state_dict().keys())
    assert set(state_dict.keys()) == sd_all_keys, (
        f'Passed quantized architecture, but the following keys are missing in '
        f'checkpoint: {list(sd_all_keys - set(state_dict.keys()))}'
    )

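# The extra '<module>._amax' keys checked above are the calibration ranges that pytorch-quantization
# stores on every TensorQuantizer; a quantized checkpoint is expected to contain them alongside the
# regular weights, otherwise the exported fake-quant nodes would lack calibration data.
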
def main(args, model_args, model_arch):
    quant_arch = args.arch in ['efficientnet-quant-b0', 'efficientnet-quant-b4']
    if quant_arch:
        pytorch_quantization.nn.modules.tensor_quantizer.TensorQuantizer.use_fb_fake_quant = True

    model = model_arch(**model_args.__dict__)
    if quant_arch and model_args.pretrained_from_file is not None:
        check_quant_weight_correctness(model_args.pretrained_from_file, model)

    image_size = args.image_size if args.image_size is not None else model.arch.default_image_size
    train_loader = get_dataloader(image_size, args.batch_size, model_args.num_classes)
    inputs = prepare_inputs(train_loader, args.device)

    final_model_path = args.output if args.output is not None else final_name(model_args.pretrained_from_file)

    model.to(args.device)
    model.eval()

    with torch.no_grad():
        torch.onnx.export(
            model,
            inputs[0],
            final_model_path,
            verbose=True,
            opset_version=13,
            enable_onnx_checker=True,
            do_constant_folding=True,
        )

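# Optional sanity check of the exported file (a sketch, assuming the onnx and onnxruntime packages
# are available; they are not imported by this script):
#
#   import onnx, onnxruntime, numpy as np
#   onnx.checker.check_model(onnx.load(final_model_path))
#   sess = onnxruntime.InferenceSession(final_model_path, providers=["CPUExecutionProvider"])
#   dummy = np.random.rand(args.batch_size, 3, image_size, image_size).astype(np.float32)
#   print(sess.run(None, {sess.get_inputs()[0].name: dummy})[0].shape)
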
if __name__ == '__main__':
    epilog = [
        "Based on the architecture picked by the --arch flag, you may use the following options:\n"
    ]
    for model, ep in available_models().items():
        model_help = "\n".join(ep.parser().format_help().split("\n")[2:])
        epilog.append(model_help)

    parser = argparse.ArgumentParser(
        description="PyTorch model to ONNX converter",
        epilog="\n".join(epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser = parse_args(parser)

    args, rest = parser.parse_known_args()

    model_arch = available_models()[args.arch]
    model_args, rest = model_arch.parser().parse_known_args(rest)

    assert len(rest) == 0, f"Unknown args passed: {rest}"

    main(args, model_args, model_arch)
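
# Example end-to-end invocation (a sketch; the checkpoint path is a placeholder):
#   python ./model2onnx.py --arch efficientnet-quant-b0 --device cuda -b 1 \
#       --pretrained-from-file /path/to/efficientnet_quant_b0.pth --output efficientnet_quant_b0.onnx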