Merge pull request #914 from NVIDIA/gh/release

[ConvNets/PyT] QAT for EfficientNet
This commit is contained in:
nv-kkudrynski 2021-04-13 19:15:50 +02:00 committed by GitHub
commit 26206dbf87
21 changed files with 842 additions and 136 deletions

View file

@ -1,8 +1,9 @@
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.12-py3
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.03-py3
FROM ${FROM_IMAGE_NAME}
ADD requirements.txt /workspace/
WORKDIR /workspace/
RUN pip install nvidia-pyindex
RUN pip install --no-cache-dir -r requirements.txt
ADD . /workspace/rn50
WORKDIR /workspace/rn50

View file

@ -25,7 +25,7 @@ The following table provides links to where you can find additional information
| resnet50 | [README](./resnet50v1.5/README.md) |
| resnext101-32x4d | [README](./resnext101-32x4d/README.md) |
| se-resnext101-32x4d | [README](./se-resnext101-32x4d/README.md) |
| EfficientNet-B0 | [README](./efficientnet/README.md) |
| EfficientNet | [README](./efficientnet/README.md) |
## Validation accuracy results
@ -39,11 +39,16 @@ in the corresponding model's README.
The following table shows the validation accuracy results of the
classification models side-by-side.
| **Model** | **Mixed Precision Top1** | **Mixed Precision Top5** | **32 bit Top1** | **32 bit Top5** |
|:-------------------:|:------------------------:|:------------------------:|:---------------:|:---------------:|
| resnet50 | 78.60 | 94.19 | 78.69 | 94.16 |
| resnext101-32x4d | 80.43 | 95.06 | 80.40 | 95.04 |
| se-resnext101-32x4d | 81.00 | 95.48 | 81.09 | 95.45 |
| **Model** | **Mixed Precision Top1** | **Mixed Precision Top5** | **32 bit Top1** | **32 bit Top5** |
|:----------------------:|:------------------------:|:------------------------:|:---------------:|:---------------:|
| efficientnet-b0 | 77.63 | 93.82 | 77.31 | 93.76 |
| efficientnet-b4 | 82.98 | 96.44 | 82.92 | 96.43 |
| efficientnet-widese-b0 | 77.89 | 94.00 | 77.97 | 94.05 |
| efficientnet-widese-b4 | 83.28 | 96.45 | 83.30 | 96.47 |
| resnet50 | 78.60 | 94.19 | 78.69 | 94.16 |
| resnext101-32x4d | 80.43 | 95.06 | 80.40 | 95.04 |
| se-resnext101-32x4d | 81.00 | 95.48 | 81.09 | 95.45 |
## Training performance results
@ -62,11 +67,15 @@ The following table shows the training accuracy results of the
classification models side-by-side.
| **Model** | **Mixed Precision** | **TF32** | **Mixed Precision Speedup** |
|:-------------------:|:-------------------:|:----------:|:---------------------------:|
| resnet50 | 15977 img/s | 7365 img/s | 2.16 x |
| resnext101-32x4d | 7399 img/s | 3193 img/s | 2.31 x |
| se-resnext101-32x4d | 5248 img/s | 2665 img/s | 1.96 x |
| **Model** | **Mixed Precision** | **TF32** | **Mixed Precision Speedup** |
|:----------------------:|:-------------------:|:----------:|:---------------------------:|
| efficientnet-b0 | 14391 img/s | 8225 img/s | 1.74 x |
| efficientnet-b4 | 2341 img/s | 1204 img/s | 1.94 x |
| efficientnet-widese-b0 | 15053 img/s | 8233 img/s | 1.82 x |
| efficientnet-widese-b4 | 2339 img/s | 1202 img/s | 1.94 x |
| resnet50 | 15977 img/s | 7365 img/s | 2.16 x |
| resnext101-32x4d | 7399 img/s | 3193 img/s | 2.31 x |
| se-resnext101-32x4d | 5248 img/s | 2665 img/s | 1.96 x |
### Training performance: NVIDIA DGX-1 16G (8x V100 16GB)
@ -81,11 +90,15 @@ in the corresponding model's README.
The following table shows the training performance results of the
classification models side-by-side.
| **Model** | **Mixed Precision** | **FP32** | **Mixed Precision Speedup** |
|:-------------------:|:-------------------:|:----------:|:---------------------------:|
| resnet50 | 7608 img/s | 2851 img/s | 2.66 x |
| resnext101-32x4d | 3742 img/s | 1117 img/s | 3.34 x |
| se-resnext101-32x4d | 2716 img/s | 994 img/s | 2.73 x |
| **Model** | **Mixed Precision** | **FP32** | **Mixed Precision Speedup** |
|:----------------------:|:-------------------:|:----------:|:---------------------------:|
| efficientnet-b0 | 7664 img/s | 4571 img/s | 1.67 x |
| efficientnet-b4 | 1330 img/s | 598 img/s | 2.22 x |
| efficientnet-widese-b0 | 7694 img/s | 4489 img/s | 1.71 x |
| efficientnet-widese-b4 | 1323 img/s | 590 img/s | 2.24 x |
| resnet50 | 7608 img/s | 2851 img/s | 2.66 x |
| resnext101-32x4d | 3742 img/s | 1117 img/s | 3.34 x |
| se-resnext101-32x4d | 2716 img/s | 994 img/s | 2.73 x |
## Model Comparison

View file

@ -36,7 +36,6 @@ if __name__ == "__main__":
k[len("module.") :] if "module." in k else k: v
for k, v in checkpoint["state_dict"].items()
}
print(f"Loaded model, acc : {checkpoint['best_prec1']}")
print(f"Loaded {checkpoint['arch']} : {checkpoint['best_prec1']}")
torch.save(model_state_dict, args.weight_path.format(arch=checkpoint['arch'][0], acc = checkpoint['best_prec1']))
torch.save(model_state_dict, args.weight_path)

View file

@ -30,6 +30,8 @@ from image_classification.models import (
efficientnet_b4,
efficientnet_widese_b0,
efficientnet_widese_b4,
efficientnet_quant_b0,
efficientnet_quant_b4,
)
def available_models():
@ -43,6 +45,8 @@ def available_models():
efficientnet_b4,
efficientnet_widese_b0,
efficientnet_widese_b4,
efficientnet_quant_b0,
efficientnet_quant_b4,
]
}
return models
@ -93,9 +97,21 @@ def load_jpeg_from_file(path, image_size, cuda=True):
return input
def check_quant_weight_correctness(checkpoint_path, model):
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
state_dict = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()}
quantizers_sd_keys = {f'{n[0]}._amax' for n in model.named_modules() if 'quantizer' in n[0]}
sd_all_keys = quantizers_sd_keys | set(model.state_dict().keys())
assert set(state_dict.keys()) == sd_all_keys, (f'Passed quantized architecture, but following keys are missing in '
f'checkpoint: {list(sd_all_keys - set(state_dict.keys()))}')
def main(args, model_args):
imgnet_classes = np.array(json.load(open("./LOC_synset_mapping.json", "r")))
model = available_models()[args.arch](**model_args.__dict__)
if args.arch in ['efficientnet-quant-b0', 'efficientnet-quant-b4']:
check_quant_weight_correctness(model_args.pretrained_from_file, model)
if not args.cpu:
model = model.cuda()
model.eval()

View file

@ -12,6 +12,8 @@ achieve state-of-the-art accuracy, and is tested and maintained by NVIDIA.
* [Mixed precision training](#mixed-precision-training)
* [Enabling mixed precision](#enabling-mixed-precision)
* [Enabling TF32](#enabling-tf32)
* [Quantization](#quantization)
* [Quantization-aware training](#qat)
* [Setup](#setup)
* [Requirements](#requirements)
* [Quick Start Guide](#quick-start-guide)
@ -22,6 +24,7 @@ achieve state-of-the-art accuracy, and is tested and maintained by NVIDIA.
* [Training process](#training-process)
* [Inference process](#inference-process)
* [NGC pretrained weights](#ngc-pretrained-weights)
* [QAT process](#qat-process)
* [Performance](#performance)
* [Benchmarking](#benchmarking)
* [Training performance benchmark](#training-performance-benchmark)
@ -37,7 +40,10 @@ achieve state-of-the-art accuracy, and is tested and maintained by NVIDIA.
* [Training performance: NVIDIA DGX-1 (8x V100 32GB)](#training-performance-nvidia-dgx-1-8x-v100-32gb)
* [Inference performance results](#inference-performance-results)
* [Inference performance: NVIDIA A100 (1x A100 80GB)](#inference-performance-nvidia-a100-1x-a100-80gb)
* [Inference performance: NVIDIA DGX-1 (1x V100 16GB)](#inference-performance-nvidia-dgx-1-1x-v100-16gb)
* [Inference performance: NVIDIA V100 (1x V100 16GB)](#inference-performance-nvidia-v100-1x-v100-16gb)
* [QAT results](#qat-results)
* [QAT Training performance: NVIDIA DGX-1 (8x V100 32GB)](#qat-training-performance-nvidia-dgx-1-8x-v100-32gb)
* [QAT Inference accuracy](#qat-inference-accuracy)
* [Release notes](#release-notes)
* [Changelog](#changelog)
* [Known issues](#known-issues)
@ -76,6 +82,28 @@ scale the learning rate.
* [MixUp](https://arxiv.org/pdf/1710.09412.pdf) = 0.2
* We train for 400 epochs
**Optimizer for QAT**
This model uses the SGD optimizer for B0 models and the RMSprop optimizer (alpha=0.853, epsilon=0.00422) for B4 models. The other hyperparameters we used are listed below; a sketch of the corresponding optimizer setup follows this list:
* Momentum:
    * 0.89 for B0 models
    * 0.9 for B4 models
* Learning rate (LR):
    * 0.0125 for a batch size of 128 for B0 models
    * 4.09e-06 for a batch size of 32 for B4 models
    For other batch sizes, we linearly scale the learning rate.
* Learning rate schedule:
    * cosine LR schedule for B0 models
    * linear LR schedule for B4 models
* Weight decay (WD):
    * 4.50e-05 for B0 models
    * 9.714e-04 for B4 models
* We do not apply WD on Batch Norm trainable parameters (gamma/bias)
* We train for:
    * 10 epochs for B0 models
    * 2 epochs for B4 models
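For illustration only, a minimal PyTorch sketch of these optimizer settings (the repository's scripts set them through command-line flags, as in the launch scripts added by this PR; `qat_optimizer` is a hypothetical helper, and the real setup additionally excludes Batch Norm parameters from weight decay):

```python
import torch

def qat_optimizer(model, variant="b0"):
    # Hypothetical helper mirroring the hyperparameters listed above.
    # The actual scripts also skip weight decay for Batch Norm parameters.
    if variant == "b0":
        return torch.optim.SGD(model.parameters(), lr=0.0125,
                               momentum=0.89, weight_decay=4.50e-05)
    # B4 models: RMSprop with alpha=0.853 and eps=0.00422
    return torch.optim.RMSprop(model.parameters(), lr=4.09e-06, alpha=0.853,
                               eps=0.00422, momentum=0.9, weight_decay=9.714e-04)
```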
**Data augmentation**
This model uses the following data augmentation:
@ -102,6 +130,7 @@ The following features are supported by this model:
|-----------------------|--------------------------
|[DALI](https://docs.nvidia.com/deeplearning/dali/release-notes/index.html) | Yes (without autoaugmentation)
|[APEX AMP](https://nvidia.github.io/apex/amp.html) | Yes
|[QAT](https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization) | Yes
#### Features
@ -124,6 +153,11 @@ DALI currently does not support Autoaugmentation, so for best accuracy it has to
A PyTorch extension that contains utility libraries, such as [Automatic Mixed Precision (AMP)](https://nvidia.github.io/apex/amp.html), which require minimal network code changes to leverage Tensor Cores performance. Refer to the [Enabling mixed precision](#enabling-mixed-precision) section for more details.
**[QAT](https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization)**
Quantization-aware training (QAT) is a method of training with precision reduced to INT8, which speeds up the inference process at the price of a slight decrease in network accuracy. Refer to the [Quantization](#quantization) section for more details.
### Mixed precision training
Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format, while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in Volta, and following with both the Turing and Ampere architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps:
@ -173,6 +207,28 @@ For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates A
TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
### Quantization
Quantization is the process of transforming deep learning models to use parameters and computations at a lower precision. Traditionally, DNN training and inference have relied on the IEEE single-precision floating-point format, using 32 bits to represent the floating-point model weights and activation tensors.
This compute budget may be acceptable during training, as most DNNs are trained in data centers or in the cloud on NVIDIA V100 or A100 GPUs that have significant compute capability and much larger power budgets. However, during deployment, these models are most often required to run on devices with much smaller computing resources and lower power budgets at the edge. Running DNN inference using the full 32-bit representation is not practical for real-time analysis given the compute, memory, and power constraints of the edge.
To help reduce the compute budget, while not compromising on the structure and number of parameters in the model, you can run inference at a lower precision. Initially, quantized inference was run at half precision, with tensors and weights represented as 16-bit floating-point numbers. While this resulted in compute savings of about 1.2x to 1.5x, there was still some compute budget and memory bandwidth that could be leveraged. Because of this, models are now quantized to an even lower precision, with an 8-bit integer representation for weights and tensors. This results in a model that is 4x smaller in memory and about 2x to 4x faster in throughput.
While 8-bit quantization is appealing for saving compute and memory budgets, it is a lossy process. During quantization, a small range of floating-point numbers is squeezed into a fixed number of information buckets. This results in a loss of information.
The minute differences which could originally be resolved using 32-bit representations are now lost because they are quantized to the same bucket in the 8-bit representation. This is similar to the rounding errors one encounters when representing fractional numbers as integers. To maintain accuracy during inference at a lower precision, it is important to try to mitigate the errors arising from this loss of information.
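To make the bucket analogy concrete, here is a small, self-contained sketch of symmetric INT8 quantize-dequantize (an illustration of the rounding effect only, not the exact scheme used by the library; `fake_quant_int8` is a hypothetical helper):

```python
import torch

def fake_quant_int8(x: torch.Tensor, amax: float) -> torch.Tensor:
    # Illustrative helper: map [-amax, amax] onto 8-bit integers, then back to floats.
    scale = amax / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127)
    return q * scale  # values that fall into the same bucket become identical

x = torch.tensor([0.0012, 0.0018, 0.5, -0.9])
print(fake_quant_int8(x, amax=1.0))  # the two small values both collapse to 0.0
```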
#### Quantization-aware training
In QAT, the quantization error is considered when training the model. The training graph is modified to simulate the lower precision behavior in the forward pass of the training process. This introduces the quantization errors as part of the training loss, which the optimizer tries to minimize during the training. Thus, QAT helps in modeling the quantization errors during training and mitigates their effect on the accuracy of the model at deployment.
However, the process of modifying the training graph to simulate lower precision behavior is intricate. To run QAT, it is necessary to insert FakeQuantization nodes for the weights of the DNN Layers and Quantize-Dequantize (QDQ) nodes to the intermediate activation tensors to compute their dynamic ranges.
For more information, see this [Quantization paper](https://arxiv.org/abs/2004.09602) and [Quantization-Aware Training](https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html#quantization-training) documentation.
A tutorial for the `pytorch-quantization` library can be found here: [`pytorch-quantization` tutorial](https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/tutorials/quant_resnet50.html).
It is important to mention that EfficientNet is a network that is hard to quantize, because the activation function used throughout the network is SiLU (also called Swish), whose negative values lie in a very narrow range, which introduces a large quantization error. More details can be found in Appendix D of the [Quantization paper](https://arxiv.org/abs/2004.09602).
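As a rough illustration of how the `pytorch-quantization` library inserts these nodes (a sketch, not this repository's code; it mirrors calls used in the `image_classification/quantization.py` module added by this PR):

```python
import torch
from pytorch_quantization import quant_modules, nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

# Use histogram calibration for activation quantizers (the default chosen by this PR).
quant_desc = QuantDescriptor(calib_method="histogram")
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc)

# While initialized, torch.nn layers are swapped for quantized versions that
# carry TensorQuantizer (FakeQuantization) nodes on their inputs and weights.
quant_modules.initialize()
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
quant_modules.deactivate()  # restore the original torch.nn layers

quantizers = [n for n, m in model.named_modules()
              if isinstance(m, quant_nn.TensorQuantizer)]
print(quantizers)  # e.g. ['0._input_quantizer', '0._weight_quantizer']
```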
## Setup
@ -183,7 +239,7 @@ The following section lists the requirements that you need to meet in order to s
This repository contains Dockerfile which extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
* [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
* [PyTorch 20.12-py3 NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch) or newer
* [PyTorch 21.03-py3 NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch) or newer
* Supported GPUs:
* [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
* [NVIDIA Turing architecture](https://www.nvidia.com/en-us/geforce/turing/)
@ -393,6 +449,29 @@ To run inference on JPEG images using pretrained weights, run:
`python classify.py --arch efficientnet-<version> --weights --precision AMP|FP32 --image <path to JPEG image>`
### Quantization process
The EfficientNet-b0 and EfficientNet-b4 models can be quantized using the QAT process by running the `quant_main.py` script.
`python ./quant_main.py <path to imagenet> --arch efficientnet-quant-<version> --epochs <# of QAT epochs> --pretrained-from-file <path to non-quantized model weights> <any other parameters for training such as batch, momentum etc.>`
During the QAT process, evaluation is done in the same way as during standard training. `quant_main.py` works in the same way as the original `main.py` script, but with quantized models. This means that `quant_main.py` can be used to resume the QAT process with the flag `--resume`:
`python ./quant_main.py <path to imagenet> --arch efficientnet-quant-<version> --resume <path to mid-training checkpoint> ...`
or to evaluate a created checkpoint with the flag `--evaluate`:
`python ./quant_main.py --arch efficientnet-quant-<version> --evaluate --epochs 1 --resume <path to checkpoint> -b <batch size> <path to imagenet>`
It can also run on multiple GPUs in the same way as the standard `main.py` script:
`python ./multiproc.py --nproc_per_node 8 ./quant_main.py --arch efficientnet-quant-<version> ... <path to imagenet>`
Trained models (quantized or not) can also be transformed into the ONNX format, which is needed to later convert them to TensorRT, where quantized networks are much faster during inference. Conversion to TensorRT will be supported in the next release. The conversion to ONNX consists of two steps:
* translate checkpoint to pure weights:
`python checkpoint2model.py --checkpoint-path <path to quant checkpoint> --weight-path <path where quant weights will be stored>`
* translate pure weights to ONNX:
`python model2onnx.py --arch efficientnet-quant-<version> --pretrained-from-file <path to model quant weights> -b <batch size>`
Quantized models can also be used to classify new images using the `classify.py` script. For example:
`python classify.py --arch efficientnet-quant-<version> -c fanin --pretrained-from-file <path to quant weights> --image <path to JPEG image>`
## Performance
### Benchmarking
@ -466,9 +545,7 @@ Our results were obtained by running the applicable `efficientnet/training/<AMP|
| **Model** | **Epochs** | **GPUs** | **Top1 accuracy - FP32** | **Top1 accuracy - mixed precision** | **Time to train - FP32** | **Time to train - mixed precision** | **Time to train speedup (FP32 to mixed precision)** |
|:----------------------:|:----------:|:--------:|:------------------------:|:-----------------------------------:|:------------------------:|:-----------------------------------:|:---------------------------------------------------:|
| efficientnet-b0 | 400 | 8 | 77.02 +/- 0.04 | 77.17 +/- 0.08 | 34 | 24 | 1.417 |
| efficientnet-b4 | 400 | 8 | NaN | 82.68 +/- 0.1 | NaN | 113 | NaN |
| efficientnet-widese-b0 | 400 | 8 | 77.59 +/- 0.16 | 77.69 +/- 0.12 | 35 | 24 | 1.458 |
| efficientnet-widese-b4 | 400 | 8 | NaN | 82.89 +/- 0.07 | NaN | 116 | NaN |
##### Example plots
@ -707,8 +784,27 @@ Our results were obtained by running the applicable `efficientnet/inference/<AMP
| efficientnet-widese-b4 | 256 | 771 img/s | 344.05 ms | 329.69 ms | 330.7 ms |
## Release notes
#### QAT results
##### QAT Training performance: NVIDIA DGX-1 (8x V100 32GB)
| **Model** | **GPUs** | **Calibration** | **QAT model** | **FP32** | **QAT ratio** |
|:---------------------:|:---------|:---------------:|:---------------:|:----------:|:-------------:|
| efficientnet-quant-b0 | 8 | 14.71 img/s | 2644.62 img/s | 3798 img/s | 0.696 x |
| efficientnet-quant-b4 | 8 | 1.85 img/s | 310.41 img/s | 666 img/s | 0.466 x |
##### QAT Inference accuracy
The best checkpoints generated during training were used as a base for the QAT.
| **Model** | **QAT Epochs** | **QAT Top1** | **Gap between FP32 Top1 and QAT Top1** |
|:---------------------:|:--------------:|:------------:|:--------------------------------------:|
| efficientnet-quant-b0 | 10 | 77.12 | 0.51 |
| efficientnet-quant-b4 | 2 | 82.54 | 0.44 |
## Release notes
### Changelog
1. April 2020

View file

@ -0,0 +1,14 @@
python ./multiproc.py \
--nproc_per_node 8 \
./quant_main.py /imagenet \
--arch efficientnet-quant-b0 \
--epochs 10 \
-j5 -p 500 \
--data-backend pytorch \
--optimizer sgd \
-b 128 \
--lr 0.0125 \
--momentum 0.89 \
--weight-decay 4.50e-05 \
--lr-schedule cosine \
--pretrained-from-file "${1}"

View file

@ -0,0 +1,16 @@
python ./multiproc.py \
--nproc_per_node 8 \
./quant_main.py /imagenet \
--arch efficientnet-quant-b4 \
--epochs 2 \
-j5 -p 500 \
--data-backend pytorch \
--optimizer rmsprop \
-b 32 \
--lr 4.09e-06 \
--momentum 0.9 \
--weight-decay 9.714e-04 \
--lr-schedule linear \
--rmsprop-alpha 0.853 \
--rmsprop-eps 0.00422 \
--pretrained-from-file "${1}"

View file

@ -444,7 +444,7 @@ def get_pytorch_train_loader(
pin_memory=True,
collate_fn=partial(fast_collate, memory_format),
drop_last=True,
persistent_workers=False,
persistent_workers=True,
)
return (
@ -498,7 +498,7 @@ def get_pytorch_val_loader(
pin_memory=True,
collate_fn=partial(fast_collate, memory_format),
drop_last=False,
persistent_workers=False,
persistent_workers=True,
)
return PrefetchedWrapper(val_loader, 0, num_classes, one_hot), len(val_loader)

View file

@ -28,6 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections import OrderedDict
from numbers import Number
import dllogger
import numpy as np
@ -37,7 +38,10 @@ def format_step(step):
return step
s = ""
if len(step) > 0:
s += "Epoch: {} ".format(step[0])
if isinstance(step[0], Number):
s += "Epoch: {} ".format(step[0])
else:
s += "{} ".format(step[0])
if len(step) > 1:
s += "Iteration: {} ".format(step[1])
if len(step) > 2:
@ -211,6 +215,7 @@ class Logger(object):
self.epoch = start_epoch
self.iteration = -1
self.val_iteration = -1
self.calib_iteration = -1
self.metrics = OrderedDict()
self.backends = backends
self.print_interval = print_interval
@ -229,23 +234,32 @@ class Logger(object):
def log_metric(self, metric_name, val, n=1):
self.metrics[metric_name]["meter"].record(val, n=n)
def start_iteration(self, val=False):
if val:
def start_iteration(self, mode='train'):
if mode == 'val':
self.val_iteration += 1
else:
elif mode == 'train':
self.iteration += 1
elif mode == 'calib':
self.calib_iteration += 1
def end_iteration(self, val=False):
it = self.val_iteration if val else self.iteration
if it % self.print_interval == 0:
def end_iteration(self, mode='train'):
if mode == 'val':
it = self.val_iteration
elif mode == 'train':
it = self.iteration
elif mode == 'calib':
it = self.calib_iteration
if it % self.print_interval == 0 or mode == 'calib':
metrics = {
n: m for n, m in self.metrics.items() if n.startswith("val") == val
n: m for n, m in self.metrics.items() if n.startswith(mode)
}
step = (
(self.epoch, self.iteration)
if not val
else (self.epoch, self.iteration, self.val_iteration)
)
if mode == 'train':
step = (self.epoch, self.iteration)
elif mode == 'val':
step = (self.epoch, self.iteration, self.val_iteration)
elif mode == 'calib':
step = ('Calibration', self.calib_iteration)
verbositys = {m["level"] for _, m in metrics.items()}
for ll in verbositys:
@ -268,11 +282,13 @@ class Logger(object):
self.val_iteration = 0
for n, m in self.metrics.items():
m["meter"].reset_epoch()
if not n.startswith('calib'):
m["meter"].reset_epoch()
def end_epoch(self):
for n, m in self.metrics.items():
m["meter"].reset_iteration()
if not n.startswith('calib'):
m["meter"].reset_iteration()
verbositys = {m["level"] for _, m in self.metrics.items()}
for ll in verbositys:
@ -282,6 +298,18 @@ class Logger(object):
data={n: m["meter"].get_epoch() for n, m in llm.items()},
)
def start_calibration(self):
self.calib_iteration = 0
for n, m in self.metrics.items():
if n.startswith('calib'):
m["meter"].reset_epoch()
def end_calibration(self):
for n, m in self.metrics.items():
if n.startswith('calib'):
m["meter"].reset_iteration()
def end(self):
for n, m in self.metrics.items():
m["meter"].reset_epoch()
@ -298,11 +326,11 @@ class Logger(object):
dllogger.flush()
def iteration_generator_wrapper(self, gen, val=False):
def iteration_generator_wrapper(self, gen, mode='train'):
for g in gen:
self.start_iteration(val=val)
self.start_iteration(mode=mode)
yield g
self.end_iteration(val=val)
self.end_iteration(mode=mode)
def epoch_generator_wrapper(self, gen):
for g in gen:

View file

@ -18,4 +18,6 @@ from .efficientnet import (
efficientnet_b4,
efficientnet_widese_b0,
efficientnet_widese_b4,
efficientnet_quant_b0,
efficientnet_quant_b4,
)

View file

@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from pytorch_quantization import nn as quant_nn
# LayerBuilder {{{
class LayerBuilder(object):
@ -135,18 +136,18 @@ class LambdaLayer(nn.Module):
class SqueezeAndExcitation(nn.Module):
def __init__(self, in_channels, squeeze, activation):
super(SqueezeAndExcitation, self).__init__()
self.squeeze = nn.Linear(in_channels, squeeze)
self.expand = nn.Linear(squeeze, in_channels)
self.pooling = nn.AdaptiveAvgPool2d(1)
self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
self.expand = nn.Conv2d(squeeze, in_channels, 1)
self.activation = activation
self.sigmoid = nn.Sigmoid()
def forward(self, x):
out = torch.mean(x, [2, 3])
out = self.pooling(x)
out = self.squeeze(out)
out = self.activation(out)
out = self.expand(out)
out = self.sigmoid(out)
out = out.unsqueeze(2).unsqueeze(3)
return out
@ -198,5 +199,16 @@ class ONNXSiLU(nn.Module):
class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
def __init__(self, in_channels, squeeze, activation, quantized=False):
super().__init__(in_channels, squeeze, activation,)
self.quantized = quantized
if quantized:
self.mul_a_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)
self.mul_b_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)
def forward(self, x):
return super().forward(x) * x
if not self.quantized:
return super().forward(x) * x
else:
x_quant = self.mul_a_quantizer(super().forward(x))
return x_quant * self.mul_b_quantizer(x)

View file

@ -8,6 +8,7 @@ from dataclasses import dataclass, replace
import torch
from torch import nn
from functools import partial
from pytorch_quantization import nn as quant_nn
from .common import (
SqueezeAndExcitation,
@ -26,6 +27,8 @@ from .model import (
EntryPoint,
)
from ..quantization import switch_on_quantization
# EffNetArch {{{
@dataclass
class EffNetArch(ModelArch):
@ -103,6 +106,7 @@ class EffNetParams(ModelParams):
bn_momentum: float = 1 - 0.99
bn_epsilon: float = 1e-3
survival_prob: float = 1
quantized: bool = False
def parser(self, name):
p = super().parser(name)
@ -141,14 +145,6 @@ class EffNetParams(ModelParams):
p.add_argument(
"--dropout", default=self.dropout, type=float, help="Dropout drop prob"
)
p.add_argument(
"--onnx",
dest="activation",
action="store_const",
const="onnx-silu",
default=self.activation,
help="Use ONNX-compatible SiLU implementation",
)
return p
@ -166,44 +162,47 @@ class EfficientNet(nn.Module):
bn_momentum: float = 1 - 0.99,
bn_epsilon: float = 1e-3,
survival_prob: float = 1,
quantized: bool = False
):
super(EfficientNet, self).__init__()
self.arch = arch
self.num_layers = arch.num_layers()
self.num_blocks = sum(arch.num_repeat)
self.survival_prob = survival_prob
self.builder = LayerBuilder(
LayerBuilder.Config(
activation=activation,
conv_init=conv_init,
bn_momentum=bn_momentum,
bn_epsilon=bn_epsilon,
self.quantized = quantized
with switch_on_quantization(self.quantized):
super(EfficientNet, self).__init__()
self.arch = arch
self.num_layers = arch.num_layers()
self.num_blocks = sum(arch.num_repeat)
self.survival_prob = survival_prob
self.builder = LayerBuilder(
LayerBuilder.Config(
activation=activation,
conv_init=conv_init,
bn_momentum=bn_momentum,
bn_epsilon=bn_epsilon,
)
)
)
self.stem = self._make_stem(arch.stem_channels)
out_channels = arch.stem_channels
self.stem = self._make_stem(arch.stem_channels)
out_channels = arch.stem_channels
plc = 0
for i, (k, s, r, e, c) in arch.enumerate():
layer, out_channels = self._make_layer(
block=arch.block,
kernel_size=k,
stride=s,
num_repeat=r,
expansion=e,
in_channels=out_channels,
out_channels=c,
squeeze_excitation_ratio=arch.squeeze_excitation_ratio,
prev_layer_count=plc,
plc = 0
for i, (k, s, r, e, c) in arch.enumerate():
layer, out_channels = self._make_layer(
block=arch.block,
kernel_size=k,
stride=s,
num_repeat=r,
expansion=e,
in_channels=out_channels,
out_channels=c,
squeeze_excitation_ratio=arch.squeeze_excitation_ratio,
prev_layer_count=plc,
)
plc = plc + r
setattr(self, f"layer{i+1}", layer)
self.features = self._make_features(out_channels, arch.feature_channels)
self.classifier = self._make_classifier(
arch.feature_channels, num_classes, dropout
)
plc = plc + r
setattr(self, f"layer{i+1}", layer)
self.features = self._make_features(out_channels, arch.feature_channels)
self.classifier = self._make_classifier(
arch.feature_channels, num_classes, dropout
)
def forward(self, x):
x = self.stem(x)
@ -275,7 +274,8 @@ class EfficientNet(nn.Module):
return nn.Sequential(
OrderedDict(
[
("pooling", LambdaLayer(lambda x: torch.mean(x, [2, 3]))),
("pooling", nn.AdaptiveAvgPool2d(1)),
("squeeze", LambdaLayer(lambda x: x.squeeze(-1).squeeze(-1))),
("dropout", nn.Dropout(dropout)),
("fc", nn.Linear(num_features, num_classes)),
]
@ -307,6 +307,7 @@ class EfficientNet(nn.Module):
stride,
self.arch.squeeze_excitation_ratio,
survival_prob if stride == 1 and in_channels == out_channels else 1.0,
self.quantized
)
layers.append((f"block{idx}", blk))
@ -321,6 +322,7 @@ class EfficientNet(nn.Module):
1, # stride
squeeze_excitation_ratio,
survival_prob,
self.quantized
)
layers.append((f"block{idx}", blk))
return nn.Sequential(OrderedDict(layers)), out_channels
@ -341,8 +343,10 @@ class MBConvBlock(nn.Module):
squeeze_excitation_ratio: int,
squeeze_hidden=False,
survival_prob: float = 1.0,
quantized: bool = False
):
super().__init__()
self.quantized = quantized
self.residual = stride == 1 and in_channels == out_channels
hidden_dim = in_channels * expand_ratio
squeeze_base = hidden_dim if squeeze_hidden else in_channels
@ -357,12 +361,15 @@ class MBConvBlock(nn.Module):
depsep_kernel_size, hidden_dim, hidden_dim, stride, bn=True, act=True
)
self.se = SequentialSqueezeAndExcitation(
hidden_dim, squeeze_dim, builder.activation()
hidden_dim, squeeze_dim, builder.activation(), self.quantized
)
self.proj = builder.conv1x1(hidden_dim, out_channels, bn=True)
self.survival_prob = survival_prob
if self.quantized and self.residual:
self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input) # TODO QuantConv2d ?!?
def drop(self):
if self.survival_prob == 1.0:
return False
@ -384,7 +391,8 @@ class MBConvBlock(nn.Module):
multiplication_factor = 1.0 / self.survival_prob
else:
multiplication_factor = 1.0
if self.quantized:
x = self.residual_quantizer(x)
return torch.add(x, alpha=multiplication_factor, other=b)
@ -397,6 +405,7 @@ def original_mbconv(
stride: int,
squeeze_excitation_ratio: int,
survival_prob: float,
quantized: bool,
):
return MBConvBlock(
builder,
@ -408,6 +417,7 @@ def original_mbconv(
squeeze_excitation_ratio,
squeeze_hidden=False,
survival_prob=survival_prob,
quantized=quantized
)
@ -420,6 +430,7 @@ def widese_mbconv(
stride: int,
squeeze_excitation_ratio: int,
survival_prob: float,
quantized: bool,
):
return MBConvBlock(
builder,
@ -431,6 +442,7 @@ def widese_mbconv(
squeeze_excitation_ratio,
squeeze_hidden=True,
survival_prob=survival_prob,
quantized=False
)
@ -477,6 +489,14 @@ architectures = {
"efficientnet-widese-b5": _m(arch=replace(effnet_b5_layers, block=widese_mbconv), params=EffNetParams(dropout=0.4)),
"efficientnet-widese-b6": _m(arch=replace(effnet_b6_layers, block=widese_mbconv), params=EffNetParams(dropout=0.5)),
"efficientnet-widese-b7": _m(arch=replace(effnet_b7_layers, block=widese_mbconv), params=EffNetParams(dropout=0.5)),
"efficientnet-quant-b0": _m(arch=effnet_b0_layers, params=EffNetParams(dropout=0.2, quantized=True)),
"efficientnet-quant-b1": _m(arch=effnet_b1_layers, params=EffNetParams(dropout=0.2, quantized=True)),
"efficientnet-quant-b2": _m(arch=effnet_b2_layers, params=EffNetParams(dropout=0.3, quantized=True)),
"efficientnet-quant-b3": _m(arch=effnet_b3_layers, params=EffNetParams(dropout=0.3, quantized=True)),
"efficientnet-quant-b4": _m(arch=effnet_b4_layers, params=EffNetParams(dropout=0.4, survival_prob=0.8, quantized=True)),
"efficientnet-quant-b5": _m(arch=effnet_b5_layers, params=EffNetParams(dropout=0.4, quantized=True)),
"efficientnet-quant-b6": _m(arch=effnet_b6_layers, params=EffNetParams(dropout=0.5, quantized=True)),
"efficientnet-quant-b7": _m(arch=effnet_b7_layers, params=EffNetParams(dropout=0.5, quantized=True)),
}
# fmt: on
@ -488,3 +508,6 @@ efficientnet_b4 = _ce("efficientnet-b4")
efficientnet_widese_b0 = _ce("efficientnet-widese-b0")
efficientnet_widese_b4 = _ce("efficientnet-widese-b4")
efficientnet_quant_b0 = _ce("efficientnet-quant-b0")
efficientnet_quant_b4 = _ce("efficientnet-quant-b4")

View file

@ -13,7 +13,7 @@ class ModelArch:
@dataclass
class ModelParams:
def parser(self, name):
return argparse.ArgumentParser(description=f"{name} arguments", add_help = False, usage="")
return argparse.ArgumentParser(description=f"{name} arguments", add_help=False, usage="")
@dataclass
@ -44,7 +44,7 @@ class EntryPoint:
state_dict = None
if pretrained:
assert self.model.checkpoint_url is not None
state_dict = torch.hub.load_state_dict_from_url(self.model.checkpoint_url, map_location=torch.device('cpu'))
state_dict = torch.hub.load_state_dict_from_url(self.model.checkpoint_url, map_location=torch.device('cpu'))
if pretrained_from_file is not None:
if os.path.isfile(pretrained_from_file):
@ -63,8 +63,12 @@ class EntryPoint:
# Temporary fix to allow NGC checkpoint loading
if state_dict is not None:
state_dict = {
k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()
}
k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()
}
state_dict = {
k: v.view(v.shape[0], -1, 1, 1) if is_linear_se_weight(k, v) else v for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
return model
@ -89,10 +93,13 @@ class EntryPoint:
return parser
def is_linear_se_weight(key, value):
return (key.endswith('squeeze.weight') or key.endswith('expand.weight')) and len(value.shape) == 2
def create_entrypoint(m: Model):
def _ep(**kwargs):
params = replace(m.params, **kwargs)
return m.constructor(arch=m.arch, **asdict(params))
return _ep

View file

@ -0,0 +1,144 @@
from tqdm import tqdm
import torch
import contextlib
import time
import logging
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor
from . import logger as log
from .utils import calc_ips
import dllogger
initialize = quant_modules.initialize
deactivate = quant_modules.deactivate
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
TIME_METADATA = {"unit": "s", "format": ":.5f"}
def select_default_calib_method(calib_method='histogram'):
"""Set up selected calibration method in whole network"""
quant_desc_input = QuantDescriptor(calib_method=calib_method)
quant_nn.QuantConv1d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(quant_desc_input)
def quantization_setup(calib_method='histogram'):
"""Change network into quantized version "automatically" and selects histogram as default quantization method"""
select_default_calib_method(calib_method)
initialize()
def disable_calibration(model):
"""Disables calibration in whole network. Should be run always before running interference."""
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
module.enable_quant()
module.disable_calib()
else:
module.enable()
def collect_stats(model, data_loader, logger, num_batches):
"""Feed data to the network and collect statistic"""
if logger is not None:
logger.register_metric(
f"calib.total_ips",
log.PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=IPS_METADATA,
)
logger.register_metric(
f"calib.data_time",
log.PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=TIME_METADATA,
)
logger.register_metric(
f"calib.compute_latency",
log.PERF_METER(),
verbosity=dllogger.Verbosity.DEFAULT,
metadata=TIME_METADATA,
)
# Enable calibrators
data_iter = enumerate(data_loader)
if logger is not None:
data_iter = logger.iteration_generator_wrapper(data_iter, mode='calib')
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
module.disable_quant()
module.enable_calib()
else:
module.disable()
end = time.time()
if logger is not None:
logger.start_calibration()
for i, (image, _) in data_iter:
bs = image.size(0)
data_time = time.time() - end
model(image.cuda())
it_time = time.time() - end
if logger is not None:
logger.log_metric(f"calib.total_ips", calc_ips(bs, it_time))
logger.log_metric(f"calib.data_time", data_time)
logger.log_metric(f"calib.compute_latency", it_time - data_time)
if i >= num_batches:
time.sleep(5)
break
end = time.time()
if logger is not None:
logger.end_calibration()
logging.disable(logging.WARNING)
disable_calibration(model)
def compute_amax(model, **kwargs):
"""Loads statistics of data and calculates quantization parameters in whole network"""
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer) and module._calibrator is not None:
if isinstance(module._calibrator, calib.MaxCalibrator):
module.load_calib_amax()
else:
module.load_calib_amax(**kwargs)
model.cuda()
def calibrate(model, train_loader, logger, calib_iter=1, percentile=99.99):
"""Calibrates whole network i.e. gathers data for quantization and calculates quantization parameters"""
model.eval()
with torch.no_grad():
collect_stats(model, train_loader, logger, num_batches=calib_iter)
compute_amax(model, method="percentile", percentile=percentile)
logging.disable(logging.NOTSET)
@contextlib.contextmanager
def switch_on_quantization(do_quantization=True):
"""Context manager for quantization activation"""
if do_quantization:
initialize()
try:
yield
finally:
if do_quantization:
deactivate()

View file

@ -274,7 +274,7 @@ def train(
data_iter = enumerate(train_loader)
if logger is not None:
data_iter = logger.iteration_generator_wrapper(data_iter)
data_iter = logger.iteration_generator_wrapper(data_iter, mode='train')
for i, (input, target) in data_iter:
bs = input.size(0)
@ -290,8 +290,8 @@ def train(
if logger is not None:
logger.log_metric("train.loss", loss.item(), bs)
logger.log_metric("train.compute_ips", calc_ips(bs, it_time - data_time))
logger.log_metric("train.total_ips", calc_ips(bs, it_time))
logger.log_metric("train.compute_ips", utils.calc_ips(bs, it_time - data_time))
logger.log_metric("train.total_ips", utils.calc_ips(bs, it_time))
logger.log_metric("train.data_time", data_time)
logger.log_metric("train.compute_time", it_time - data_time)
@ -413,7 +413,7 @@ def validate(
data_iter = enumerate(val_loader)
if not logger is None:
data_iter = logger.iteration_generator_wrapper(data_iter, val=True)
data_iter = logger.iteration_generator_wrapper(data_iter, mode='val')
for i, (input, target) in data_iter:
bs = input.size(0)
@ -428,8 +428,8 @@ def validate(
logger.log_metric(f"{prefix}.top1", prec1.item(), bs)
logger.log_metric(f"{prefix}.top5", prec5.item(), bs)
logger.log_metric(f"{prefix}.loss", loss.item(), bs)
logger.log_metric(f"{prefix}.compute_ips", calc_ips(bs, it_time - data_time))
logger.log_metric(f"{prefix}.total_ips", calc_ips(bs, it_time))
logger.log_metric(f"{prefix}.compute_ips", utils.calc_ips(bs, it_time - data_time))
logger.log_metric(f"{prefix}.total_ips", utils.calc_ips(bs, it_time))
logger.log_metric(f"{prefix}.data_time", data_time)
logger.log_metric(f"{prefix}.compute_latency", it_time - data_time)
logger.log_metric(f"{prefix}.compute_latency_at95", it_time - data_time)
@ -445,12 +445,6 @@ def validate(
# Train loop {{{
def calc_ips(batch_size, time):
world_size = (
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
)
tbs = world_size * batch_size
return tbs / time
def train_loop(
@ -483,7 +477,9 @@ def train_loop(
if early_stopping_patience > 0:
epochs_since_improvement = 0
backup_prefix = checkpoint_filename[:-len("checkpoint.pth.tar")] if \
checkpoint_filename.endswith("checkpoint.pth.tar") else ""
print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
with utils.TimeoutHandler() as timeout_handler:
interrupted = False
@ -546,7 +542,7 @@ def train_loop(
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
):
if should_backup_checkpoint(epoch):
backup_filename = "checkpoint-{}.pth.tar".format(epoch + 1)
backup_filename = "{}checkpoint-{}.pth.tar".format(backup_prefix, epoch + 1)
else:
backup_filename = None
checkpoint_state = {

View file

@ -154,3 +154,11 @@ class TimeoutHandler:
signal.signal(self.sig, self.original_handler)
self.released = True
return True
def calc_ips(batch_size, time):
world_size = (
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
)
tbs = world_size * batch_size
return tbs / time

View file

@ -31,7 +31,7 @@ if __name__ == "__main__":
with open(yaml_args.cfg_file, "r") as cfg_file:
config = yaml.load(cfg_file, Loader=yaml.FullLoader)
cfg = {
**config["precision"][yaml_args.precision],
**config["platform"][yaml_args.platform],
@ -43,11 +43,11 @@ if __name__ == "__main__":
add_parser_arguments(parser)
parser.set_defaults(**cfg)
args, rest = parser.parse_known_args(rest)
model_parser = available_models()[args.arch].parser()
model_args, rest = model_parser.parse_known_args(rest)
model_arch = available_models()[args.arch]
model_args, rest = model_arch.parser().parse_known_args(rest)
assert len(rest) == 0, f"Unknown args passed: {rest}"
cudnn.benchmark = True
main(args, model_args)
main(args, model_args, model_arch)

View file

@ -77,7 +77,7 @@ def available_models():
return models
def add_parser_arguments(parser):
def add_parser_arguments(parser, skip_arch=False):
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
"--data-backend",
@ -94,15 +94,16 @@ def add_parser_arguments(parser):
default="bilinear",
help="interpolation type for resizing images: bilinear, bicubic or triangular(DALI only)",
)
model_names = available_models().keys()
parser.add_argument(
"--arch",
"-a",
metavar="ARCH",
default="resnet50",
choices=model_names,
help="model architecture: " + " | ".join(model_names) + " (default: resnet50)",
)
if not skip_arch:
model_names = available_models().keys()
parser.add_argument(
"--arch",
"-a",
metavar="ARCH",
default="resnet50",
choices=model_names,
help="model architecture: " + " | ".join(model_names) + " (default: resnet50)",
)
parser.add_argument(
"-j",
@ -312,14 +313,11 @@ def add_parser_arguments(parser):
type=str,
default=None,
choices=[None, "autoaugment"],
help="augmenation method",
help="augmentation method",
)
def main(args, model_args):
exp_start_time = time.time()
global best_prec1
best_prec1 = 0
def prepare_for_training(args, model_args, model_arch):
args.distributed = False
if "WORLD_SIZE" in os.environ:
@ -413,7 +411,7 @@ def main(args, model_args):
memory_format = (
torch.channels_last if args.memory_format == "nhwc" else torch.contiguous_format
)
model = available_models()[args.arch](
model = model_arch(
**{
k: v
if k != "pretrained"
@ -536,6 +534,18 @@ def main(args, model_args):
print("load ema")
ema.load_state_dict(model_state_ema)
return (model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema,
train_loader_len, batch_size_multiplier, start_epoch)
def main(args, model_args, model_arch):
exp_start_time = time.time()
global best_prec1
best_prec1 = 0
model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
train_loop(
model_and_loss,
optimizer,
@ -586,12 +596,13 @@ if __name__ == "__main__":
add_parser_arguments(parser)
args, rest = parser.parse_known_args()
model_args, rest = available_models()[args.arch].parser().parse_known_args(rest)
model_arch = available_models()[args.arch]
model_args, rest = model_arch.parser().parse_known_args(rest)
print(model_args)
assert len(rest) == 0, f"Unknown args passed: {rest}"
cudnn.benchmark = True
main(args, model_args)
main(args, model_args, model_arch)

View file

@ -0,0 +1,157 @@
import argparse
import torch
import pytorch_quantization
from image_classification.models import (
resnet50,
resnext101_32x4d,
se_resnext101_32x4d,
efficientnet_b0,
efficientnet_b4,
efficientnet_widese_b0,
efficientnet_widese_b4,
efficientnet_quant_b0,
efficientnet_quant_b4,
)
def available_models():
models = {
m.name: m
for m in [
resnet50,
resnext101_32x4d,
se_resnext101_32x4d,
efficientnet_b0,
efficientnet_b4,
efficientnet_widese_b0,
efficientnet_widese_b4,
efficientnet_quant_b0,
efficientnet_quant_b4,
]
}
return models
def parse_args(parser):
"""
Parse commandline arguments.
"""
model_names = available_models().keys()
parser.add_argument("--arch", "-a", metavar="ARCH", default="resnet50", choices=model_names,
help="model architecture: " + " | ".join(model_names) + " (default: resnet50)")
parser.add_argument("--device", metavar="DEVICE", default="cuda", choices=['cpu', 'cuda'],
help="device on which model is settled: cpu, cuda (default: cuda)")
parser.add_argument("--image-size", default=None, type=int, help="resolution of image")
parser.add_argument('--output', type=str, help='Path to converted model')
parser.add_argument("-b", "--batch-size", default=256, type=int, metavar="N",
help="mini-batch size (default: 256) per gpu")
return parser
def final_name(base_name):
splitted = base_name.split('.')
if 'pt' in splitted:
fin_name = base_name.replace('pt', 'onnx')
elif 'pth' in splitted:
fin_name = base_name.replace('pth', 'onnx')
elif len(splitted) > 1:
fin_name = '.'.join(splitted[:-1] + ['onnx'])
else:
fin_name = base_name + '.onnx'
return fin_name
def get_dataloader(image_size, bs, num_classes):
"""return dataloader for inference"""
from image_classification.dataloaders import get_syntetic_loader
def data_loader():
loader, _ = get_syntetic_loader(None, image_size, bs, num_classes, False)
for inp, _ in loader:
yield inp
break
return data_loader()
def prepare_inputs(dataloader, device):
"""load sample inputs to device"""
inputs = []
for batch in dataloader:
if type(batch) is torch.Tensor:
batch_d = batch.to(device)
batch_d = (batch_d, )
inputs.append(batch_d)
else:
batch_d = []
for x in batch:
assert type(x) is torch.Tensor, "input is not a tensor"
batch_d.append(x.to(device))
batch_d = tuple(batch_d)
inputs.append(batch_d)
return inputs
def check_quant_weight_correctness(checkpoint_path, model):
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
state_dict = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()}
quantizers_sd_keys = {f'{n[0]}._amax' for n in model.named_modules() if 'quantizer' in n[0]}
sd_all_keys = quantizers_sd_keys | set(model.state_dict().keys())
assert set(state_dict.keys()) == sd_all_keys, (f'Passed quantized architecture, but following keys are missing in '
f'checkpoint: {list(sd_all_keys - set(state_dict.keys()))}')
def main(args, model_args, model_arch):
quant_arch = args.arch in ['efficientnet-quant-b0', 'efficientnet-quant-b4']
if quant_arch:
pytorch_quantization.nn.modules.tensor_quantizer.TensorQuantizer.use_fb_fake_quant = True
model = model_arch(**model_args.__dict__)
if quant_arch:
check_quant_weight_correctness(model_args.pretrained_from_file, model)
image_size = args.image_size if args.image_size is not None else model.arch.default_image_size
train_loader = get_dataloader(image_size, args.batch_size, model_args.num_classes)
inputs = prepare_inputs(train_loader, args.device)
final_model_path = args.output if args.output is not None else final_name(model_args.pretrained_from_file)
model.to(args.device)
model.eval()
with torch.no_grad():
torch.onnx.export(model,
inputs[0],
final_model_path,
verbose=True,
opset_version=13,
enable_onnx_checker=True,
do_constant_folding=True)
if __name__ == '__main__':
epilog = [
"Based on the architecture picked by --arch flag, you may use the following options:\n"
]
for model, ep in available_models().items():
model_help = "\n".join(ep.parser().format_help().split("\n")[2:])
epilog.append(model_help)
parser = argparse.ArgumentParser(
description="PyTorch ImageNet Training",
epilog="\n".join(epilog),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser = parse_args(parser)
args, rest = parser.parse_known_args()
model_arch = available_models()[args.arch]
model_args, rest = model_arch.parser().parse_known_args(rest)
assert len(rest) == 0, f"Unknown args passed: {rest}"
main(args, model_args, model_arch)

View file

@ -0,0 +1,162 @@
# Copyright (c) 2018-2019, NVIDIA CORPORATION
# Copyright (c) 2017- Facebook, Inc
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import random
from copy import deepcopy
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from image_classification.training import *
from image_classification.utils import *
from image_classification.quantization import *
from image_classification.models import efficientnet_quant_b0, efficientnet_quant_b4
from main import prepare_for_training, add_parser_arguments as parse_training
import dllogger
def available_models():
models = {
m.name: m
for m in [
efficientnet_quant_b0,
efficientnet_quant_b4,
]
}
return models
def parse_quantization(parser):
model_names = available_models().keys()
parser.add_argument(
"--arch",
"-a",
metavar="ARCH",
default="efficientnet-quant-b0",
choices=model_names,
help="model architecture: " + " | ".join(model_names) + " (default: efficientnet-quant-b0)",
)
parser.add_argument(
"--skip-calibration",
action="store_true",
help="skip calibration before training, (default: false)",
)
def parse_training_args(parser):
from main import add_parser_arguments
return add_parser_arguments(parser)
def main(args, model_args, model_arch):
exp_start_time = time.time()
global best_prec1
best_prec1 = 0
skip_calibration = args.skip_calibration or args.evaluate or args.resume is not None
select_default_calib_method()
model_and_loss, optimizer, lr_policy, scaler, train_loader, val_loader, logger, ema, model_ema, train_loader_len, \
batch_size_multiplier, start_epoch = prepare_for_training(args, model_args, model_arch)
print(f"RUNNING QUANTIZATION")
if not skip_calibration:
calibrate(model_and_loss.model, train_loader, logger, calib_iter=10)
train_loop(
model_and_loss,
optimizer,
scaler,
lr_policy,
train_loader,
val_loader,
logger,
should_backup_checkpoint(args),
ema=ema,
model_ema=model_ema,
steps_per_epoch=train_loader_len,
use_amp=args.amp,
batch_size_multiplier=batch_size_multiplier,
start_epoch=start_epoch,
end_epoch=min((start_epoch + args.run_epochs), args.epochs)
if args.run_epochs != -1
else args.epochs,
best_prec1=best_prec1,
prof=args.prof,
skip_training=args.evaluate,
skip_validation=args.training_only,
save_checkpoints=args.save_checkpoints,
checkpoint_dir=args.workspace,
checkpoint_filename='quantized_' + args.checkpoint_filename,
)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.end()
print("Experiment ended")
if __name__ == "__main__":
epilog = [
"Based on the architecture picked by --arch flag, you may use the following options:\n"
]
for model, ep in available_models().items():
model_help = "\n".join(ep.parser().format_help().split("\n")[2:])
epilog.append(model_help)
parser = argparse.ArgumentParser(
description="PyTorch ImageNet Training",
epilog="\n".join(epilog),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parse_quantization(parser)
parse_training(parser, skip_arch=True)
args, rest = parser.parse_known_args()
model_arch = available_models()[args.arch]
model_args, rest = model_arch.parser().parse_known_args(rest)
print(model_args)
assert len(rest) == 0, f"Unknown args passed: {rest}"
cudnn.benchmark = True
main(args, model_args, model_arch)

View file

@ -1 +1,2 @@
git+git://github.com/NVIDIA/dllogger.git@26a0f8f1958de2c0c460925ff6102a4d2486d6cc#egg=dllogger
pytorch-quantization