[MRCNN] use native AMP, upgrade opencv-python

This commit is contained in:
Shriya Balaji Palsamudram 2021-10-21 13:21:50 -07:00 committed by Krzysztof Kudrynski
parent cc194960e6
commit abe062867f
11 changed files with 69 additions and 85 deletions

View file

@ -19,7 +19,7 @@ FROM ${FROM_IMAGE_NAME}
RUN pip install --upgrade --no-cache-dir pip \
&& pip install --no-cache-dir \
mlperf-compliance==0.0.10 \
opencv-python==3.4.1.15 \
opencv-python==4.2.0.32 \
git+https://github.com/NVIDIA/dllogger \
yacs

View file

@ -47,7 +47,7 @@ Mask R-CNN is a convolution based neural network for the task of object instance
The repository also contains scripts to interactively launch training, benchmarking and inference routines in a Docker container.
The major differences between the official implementation of the paper and our version of Mask R-CNN are as follows:
- Mixed precision support with [PyTorch AMP](https://github.com/NVIDIA/apex).
- Mixed precision support with [PyTorch AMP](https://pytorch.org/docs/stable/amp.html).
- Gradient accumulation to simulate larger batches.
- Custom fused CUDA kernels for faster computations.
@ -117,7 +117,8 @@ The following features are supported by this model.
| **Feature** | **Mask R-CNN** |
|:---------:|:----------:|
|APEX AMP|Yes|
|APEX AMP|No|
|PyTorch AMP|Yes|
|APEX DDP|Yes|
#### Features
@ -150,22 +151,8 @@ For information about:
#### Enabling mixed precision
In this repository, mixed precision training is enabled by NVIDIAs [APEX](https://github.com/NVIDIA/apex) library. The APEX library has an automatic mixed precision module that allows mixed precision to be enabled with minimal code changes.
Automatic mixed precision can be enabled with the following code changes:
```
from apex import amp
if fp16:
# Wrap optimizer and model
model, optimizer = amp.initialize(model, optimizer, opt_level=<opt_level>, loss_scale="dynamic")
if fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
```
Where <opt_level> is the optimization level. In the MaskRCNN, "O1" is set as the optimization level. Mixed precision training can be turned on by passing in the argument fp16 to the pre-training and fine-tuning Python scripts. Shell scripts all have a positional argument available to enable mixed precision training.
In this repository, mixed precision training is enabled by using Pytorch's [AMP](https://pytorch.org/docs/stable/amp.html).
#### Enabling TF32
@ -608,6 +595,10 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
### Changelog
October 2021
- Replace APEX AMP with PyTorch native AMP
- Use opencv-python version 4.2.0.32
June 2020
- Updated accuracy and performance tables to include A100 results

View file

@ -19,7 +19,7 @@ FROM ${FROM_IMAGE_NAME}
RUN pip install --upgrade --no-cache-dir pip \
&& pip install --no-cache-dir \
mlperf-compliance==0.0.10 \
opencv-python==3.4.1.15 \
opencv-python==4.2.0.32 \
git+https://github.com/NVIDIA/dllogger \
yacs

View file

@ -318,8 +318,6 @@ _C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py")
# Precision of input, allowable: (float32, float16)
_C.DTYPE = "float32"
# Enable verbosity in apex.amp
_C.AMP_VERBOSE = False
# Evaluate every epoch
_C.PER_EPOCH_EVAL = False

View file

@ -10,13 +10,6 @@ import torch.distributed as dist
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
try:
from apex import amp
use_amp = True
except ImportError:
print('Use APEX for multi-precision via apex.amp')
use_amp = False
def reduce_loss_dict(loss_dict):
"""
Reduce the loss dictionary from all processes so that process with rank
@ -63,6 +56,8 @@ def do_train(
model.train()
start_training_time = time.time()
end = time.time()
if use_amp:
scaler = torch.cuda.amp.GradScaler(init_scale=8192.0)
for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
data_time = time.time() - end
iteration = iteration + 1
@ -71,7 +66,11 @@ def do_train(
images = images.to(device)
targets = [target.to(device) for target in targets]
loss_dict = model(images, targets)
if use_amp:
with torch.cuda.amp.autocast():
loss_dict = model(images, targets)
else:
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
@ -84,23 +83,27 @@ def do_train(
# Note: If mixed precision is not used, this ends up doing nothing
# Otherwise apply loss scaling for mixed-precision recipe
if use_amp:
with amp.scale_loss(losses, optimizer) as scaled_losses:
scaled_losses.backward()
scaler.scale(losses).backward()
else:
losses.backward()
if not cfg.SOLVER.ACCUMULATE_GRAD:
optimizer.step()
def _take_step():
if use_amp:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if not cfg.SOLVER.ACCUMULATE_GRAD:
_take_step()
else:
if (iteration + 1) % cfg.SOLVER.ACCUMULATE_STEPS == 0:
for param in model.parameters():
if param.grad is not None:
param.grad.data.div_(cfg.SOLVER.ACCUMULATE_STEPS)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
_take_step()
batch_time = time.time() - end
end = time.time()

View file

@ -3,10 +3,10 @@
# from ._utils import _C
from maskrcnn_benchmark import _C
from apex import amp
from torch.cuda.amp import custom_fwd
# Only valid with fp32 inputs - give AMP the hint
nms = amp.float_function(_C.nms)
nms = custom_fwd(_C.nms)
# nms.__doc__ = """
# This function performs Non-maximum suppresion"""

View file

@ -8,8 +8,6 @@ from torch.nn.modules.utils import _pair
from maskrcnn_benchmark import _C
from apex import amp
class _ROIAlign(Function):
@staticmethod
def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
@ -55,7 +53,7 @@ class ROIAlign(nn.Module):
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
@amp.float_function
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(self, input, rois):
return roi_align(
input, rois, self.output_size, self.spatial_scale, self.sampling_ratio

View file

@ -7,8 +7,6 @@ from torch.nn.modules.utils import _pair
from maskrcnn_benchmark import _C
from apex import amp
class _ROIPool(Function):
@staticmethod
def forward(ctx, input, roi, output_size, spatial_scale):
@ -53,7 +51,7 @@ class ROIPool(nn.Module):
self.output_size = output_size
self.spatial_scale = spatial_scale
@amp.float_function
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(self, input, rois):
return roi_pool(input, rois, self.output_size, self.spatial_scale)

View file

@ -1,4 +1,4 @@
mlperf-compliance==0.0.10
opencv-python==3.4.1.15
opencv-python==4.2.0.32
yacs
git+https://github.com/NVIDIA/cocoapi.git@nvidia/master#egg=cocoapi&subdirectory=PythonAPI

View file

@ -20,13 +20,6 @@ from maskrcnn_benchmark.utils.miscellaneous import mkdir
from maskrcnn_benchmark.utils.logger import format_step
import dllogger
# Check if we can enable mixed-precision via apex.amp
try:
from apex import amp
except ImportError:
raise ImportError('Use APEX for mixed precision via apex.amp')
def main():
parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference")
parser.add_argument(
@ -100,9 +93,7 @@ def main():
if args.fp16:
use_mixed_precision = True
else:
use_mixed_precision = cfg.DTYPE == "float16"
amp_opt_level = 'O1' if use_mixed_precision else 'O0'
model = amp.initialize(model, opt_level=amp_opt_level)
use_mixed_precision = cfg.DTYPE == "float16"
output_dir = cfg.OUTPUT_DIR
checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
@ -122,19 +113,35 @@ def main():
results = []
for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
result = inference(
model,
data_loader_val,
dataset_name=dataset_name,
iou_types=iou_types,
box_only=cfg.MODEL.RPN_ONLY,
device=cfg.MODEL.DEVICE,
expected_results=cfg.TEST.EXPECTED_RESULTS,
expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
output_folder=output_folder,
skip_eval=args.skip_eval,
dllogger=dllogger,
)
if use_mixed_precision:
with torch.cuda.amp.autocast():
result = inference(
model,
data_loader_val,
dataset_name=dataset_name,
iou_types=iou_types,
box_only=cfg.MODEL.RPN_ONLY,
device=cfg.MODEL.DEVICE,
expected_results=cfg.TEST.EXPECTED_RESULTS,
expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
output_folder=output_folder,
skip_eval=args.skip_eval,
dllogger=dllogger,
)
else:
result = inference(
model,
data_loader_val,
dataset_name=dataset_name,
iou_types=iou_types,
box_only=cfg.MODEL.RPN_ONLY,
device=cfg.MODEL.DEVICE,
expected_results=cfg.TEST.EXPECTED_RESULTS,
expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
output_folder=output_folder,
skip_eval=args.skip_eval,
dllogger=dllogger,
)
synchronize()
results.append(result)

View file

@ -34,13 +34,6 @@ import dllogger
from maskrcnn_benchmark.utils.logger import format_step
# See if we can use apex.DistributedDataParallel instead of the torch default,
# and enable mixed-precision via apex.amp
try:
from apex import amp
use_amp = True
except ImportError:
print('Use APEX for multi-precision via apex.amp')
use_amp = False
try:
from apex.parallel import DistributedDataParallel as DDP
use_apex_ddp = True
@ -98,15 +91,11 @@ def train(cfg, local_rank, distributed, fp16, dllogger):
optimizer = make_optimizer(cfg, model)
scheduler = make_lr_scheduler(cfg, optimizer)
if use_amp:
# Initialize mixed-precision training
if fp16:
use_mixed_precision = True
else:
use_mixed_precision = cfg.DTYPE == "float16"
amp_opt_level = 'O1' if use_mixed_precision else 'O0'
model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)
use_amp = False
if fp16:
use_amp = True
else:
use_amp = cfg.DTYPE == "float16"
if distributed:
if use_apex_ddp: