diff --git a/PyTorch/Segmentation/MaskRCNN/Dockerfile b/PyTorch/Segmentation/MaskRCNN/Dockerfile
index 0c017d7e..2b6f8594 100644
--- a/PyTorch/Segmentation/MaskRCNN/Dockerfile
+++ b/PyTorch/Segmentation/MaskRCNN/Dockerfile
@@ -19,7 +19,7 @@ FROM ${FROM_IMAGE_NAME}
 RUN pip install --upgrade --no-cache-dir pip \
  && pip install --no-cache-dir \
      mlperf-compliance==0.0.10 \
-     opencv-python==3.4.1.15 \
+     opencv-python==4.2.0.32 \
      git+https://github.com/NVIDIA/dllogger \
      yacs
diff --git a/PyTorch/Segmentation/MaskRCNN/README.md b/PyTorch/Segmentation/MaskRCNN/README.md
index 21060da9..796def46 100755
--- a/PyTorch/Segmentation/MaskRCNN/README.md
+++ b/PyTorch/Segmentation/MaskRCNN/README.md
@@ -47,7 +47,7 @@ Mask R-CNN is a convolution based neural network for the task of object instance
 The repository also contains scripts to interactively launch training, benchmarking and inference routines in a Docker container.
 
 The major differences between the official implementation of the paper and our version of Mask R-CNN are as follows:
-  - Mixed precision support with [PyTorch AMP](https://github.com/NVIDIA/apex).
+  - Mixed precision support with [PyTorch AMP](https://pytorch.org/docs/stable/amp.html).
   - Gradient accumulation to simulate larger batches.
   - Custom fused CUDA kernels for faster computations.
@@ -117,7 +117,8 @@ The following features are supported by this model.
 
 | **Feature** | **Mask R-CNN** |
 |:---------:|:----------:|
-|APEX AMP|Yes|
+|APEX AMP|No|
+|PyTorch AMP|Yes|
 |APEX DDP|Yes|
 
 #### Features
@@ -150,22 +151,8 @@ For information about:
 
 #### Enabling mixed precision
 
-In this repository, mixed precision training is enabled by NVIDIA’s [APEX](https://github.com/NVIDIA/apex) library. The APEX library has an automatic mixed precision module that allows mixed precision to be enabled with minimal code changes.
-
-Automatic mixed precision can be enabled with the following code changes:
-
-```
-from apex import amp
-if fp16:
-    # Wrap optimizer and model
-    model, optimizer = amp.initialize(model, optimizer, opt_level=, loss_scale="dynamic")
-
-if fp16:
-    with amp.scale_loss(loss, optimizer) as scaled_loss:
-        scaled_loss.backward()
- ```
-
-Where is the optimization level. In the MaskRCNN, "O1" is set as the optimization level. Mixed precision training can be turned on by passing in the argument fp16 to the pre-training and fine-tuning Python scripts. Shell scripts all have a positional argument available to enable mixed precision training.
+In this repository, mixed precision training is enabled with PyTorch's native [AMP](https://pytorch.org/docs/stable/amp.html) module, `torch.cuda.amp`, which requires only minimal changes to the training code.
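+
+The training loop in `maskrcnn_benchmark/engine/trainer.py` wraps the forward pass in `torch.cuda.amp.autocast` and performs loss scaling with `torch.cuda.amp.GradScaler`. A simplified sketch of the general pattern (variable names are illustrative):
+
+```
+scaler = torch.cuda.amp.GradScaler()
+
+for images, targets in data_loader:
+    optimizer.zero_grad()
+    # Run the forward pass under autocast so eligible ops execute in float16
+    with torch.cuda.amp.autocast():
+        loss_dict = model(images, targets)
+        losses = sum(loss for loss in loss_dict.values())
+    # Scale the loss before backward, then step the optimizer through the scaler
+    scaler.scale(losses).backward()
+    scaler.step(optimizer)
+    scaler.update()
+```
+
+Mixed precision training can be turned on by passing the `fp16` argument to the training and evaluation scripts, or by setting `DTYPE` to `float16` in the configuration.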
+ #### Enabling TF32 @@ -608,6 +595,10 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic ### Changelog +October 2021 +- Replace APEX AMP with PyTorch native AMP +- Use opencv-python version 4.2.0.32 + June 2020 - Updated accuracy and performance tables to include A100 results diff --git a/PyTorch/Segmentation/MaskRCNN/pytorch/Dockerfile b/PyTorch/Segmentation/MaskRCNN/pytorch/Dockerfile index 6e4c8d51..73a2acda 100755 --- a/PyTorch/Segmentation/MaskRCNN/pytorch/Dockerfile +++ b/PyTorch/Segmentation/MaskRCNN/pytorch/Dockerfile @@ -19,7 +19,7 @@ FROM ${FROM_IMAGE_NAME} RUN pip install --upgrade --no-cache-dir pip \ && pip install --no-cache-dir \ mlperf-compliance==0.0.10 \ - opencv-python==3.4.1.15 \ + opencv-python==4.2.0.32 \ git+https://github.com/NVIDIA/dllogger \ yacs diff --git a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/config/defaults.py b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/config/defaults.py index 4d7b2590..80238831 100755 --- a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/config/defaults.py +++ b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/config/defaults.py @@ -318,8 +318,6 @@ _C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py") # Precision of input, allowable: (float32, float16) _C.DTYPE = "float32" -# Enable verbosity in apex.amp -_C.AMP_VERBOSE = False # Evaluate every epoch _C.PER_EPOCH_EVAL = False diff --git a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/engine/trainer.py b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/engine/trainer.py index 36fd6ffb..b8a19fa5 100755 --- a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/engine/trainer.py +++ b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/engine/trainer.py @@ -10,13 +10,6 @@ import torch.distributed as dist from maskrcnn_benchmark.utils.comm import get_world_size from maskrcnn_benchmark.utils.metric_logger import MetricLogger -try: - from apex import amp - use_amp = True -except ImportError: - print('Use APEX for multi-precision via apex.amp') - use_amp = False - def reduce_loss_dict(loss_dict): """ Reduce the loss dictionary from all processes so that process with rank @@ -63,6 +56,8 @@ def do_train( model.train() start_training_time = time.time() end = time.time() + if use_amp: + scaler = torch.cuda.amp.GradScaler(init_scale=8192.0) for iteration, (images, targets, _) in enumerate(data_loader, start_iter): data_time = time.time() - end iteration = iteration + 1 @@ -71,7 +66,11 @@ def do_train( images = images.to(device) targets = [target.to(device) for target in targets] - loss_dict = model(images, targets) + if use_amp: + with torch.cuda.amp.autocast(): + loss_dict = model(images, targets) + else: + loss_dict = model(images, targets) losses = sum(loss for loss in loss_dict.values()) @@ -84,23 +83,27 @@ def do_train( # Note: If mixed precision is not used, this ends up doing nothing # Otherwise apply loss scaling for mixed-precision recipe if use_amp: - with amp.scale_loss(losses, optimizer) as scaled_losses: - scaled_losses.backward() + scaler.scale(losses).backward() else: losses.backward() - if not cfg.SOLVER.ACCUMULATE_GRAD: - optimizer.step() + def _take_step(): + if use_amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() scheduler.step() optimizer.zero_grad() + + if not cfg.SOLVER.ACCUMULATE_GRAD: + _take_step() else: if (iteration + 1) % cfg.SOLVER.ACCUMULATE_STEPS == 0: for param in model.parameters(): if param.grad is not None: 
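+                        # Average the accumulated gradients; with native AMP they are still
+                        # loss-scaled here, which is safe because scaler.step() unscales them inside _take_step()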
param.grad.data.div_(cfg.SOLVER.ACCUMULATE_STEPS) - optimizer.step() - scheduler.step() - optimizer.zero_grad() + _take_step() batch_time = time.time() - end end = time.time() diff --git a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/nms.py b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/nms.py index 8c057a63..7bffeb02 100755 --- a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/nms.py +++ b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/nms.py @@ -3,10 +3,10 @@ # from ._utils import _C from maskrcnn_benchmark import _C -from apex import amp +from torch.cuda.amp import custom_fwd # Only valid with fp32 inputs - give AMP the hint -nms = amp.float_function(_C.nms) +nms = custom_fwd(_C.nms) # nms.__doc__ = """ # This function performs Non-maximum suppresion""" diff --git a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/roi_align.py b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/roi_align.py index 46832aba..c930bc28 100755 --- a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/roi_align.py +++ b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/roi_align.py @@ -8,8 +8,6 @@ from torch.nn.modules.utils import _pair from maskrcnn_benchmark import _C -from apex import amp - class _ROIAlign(Function): @staticmethod def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio): @@ -55,7 +53,7 @@ class ROIAlign(nn.Module): self.spatial_scale = spatial_scale self.sampling_ratio = sampling_ratio - @amp.float_function + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) def forward(self, input, rois): return roi_align( input, rois, self.output_size, self.spatial_scale, self.sampling_ratio diff --git a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/roi_pool.py b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/roi_pool.py index 58633907..6871ae7f 100755 --- a/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/roi_pool.py +++ b/PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/layers/roi_pool.py @@ -7,8 +7,6 @@ from torch.nn.modules.utils import _pair from maskrcnn_benchmark import _C -from apex import amp - class _ROIPool(Function): @staticmethod def forward(ctx, input, roi, output_size, spatial_scale): @@ -53,7 +51,7 @@ class ROIPool(nn.Module): self.output_size = output_size self.spatial_scale = spatial_scale - @amp.float_function + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) def forward(self, input, rois): return roi_pool(input, rois, self.output_size, self.spatial_scale) diff --git a/PyTorch/Segmentation/MaskRCNN/pytorch/requirements.txt b/PyTorch/Segmentation/MaskRCNN/pytorch/requirements.txt index 5c634323..872ba7bd 100644 --- a/PyTorch/Segmentation/MaskRCNN/pytorch/requirements.txt +++ b/PyTorch/Segmentation/MaskRCNN/pytorch/requirements.txt @@ -1,4 +1,4 @@ mlperf-compliance==0.0.10 -opencv-python==3.4.1.15 +opencv-python==4.2.0.32 yacs git+https://github.com/NVIDIA/cocoapi.git@nvidia/master#egg=cocoapi&subdirectory=PythonAPI diff --git a/PyTorch/Segmentation/MaskRCNN/pytorch/tools/test_net.py b/PyTorch/Segmentation/MaskRCNN/pytorch/tools/test_net.py index b8ac7d46..338be82b 100755 --- a/PyTorch/Segmentation/MaskRCNN/pytorch/tools/test_net.py +++ b/PyTorch/Segmentation/MaskRCNN/pytorch/tools/test_net.py @@ -20,13 +20,6 @@ from maskrcnn_benchmark.utils.miscellaneous import mkdir from maskrcnn_benchmark.utils.logger import format_step import dllogger -# Check if we can enable mixed-precision 
via apex.amp -try: - from apex import amp -except ImportError: - raise ImportError('Use APEX for mixed precision via apex.amp') - - def main(): parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference") parser.add_argument( @@ -100,9 +93,7 @@ def main(): if args.fp16: use_mixed_precision = True else: - use_mixed_precision = cfg.DTYPE == "float16" - amp_opt_level = 'O1' if use_mixed_precision else 'O0' - model = amp.initialize(model, opt_level=amp_opt_level) + use_mixed_precision = cfg.DTYPE == "float16" output_dir = cfg.OUTPUT_DIR checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir) @@ -122,19 +113,35 @@ def main(): results = [] for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val): - result = inference( - model, - data_loader_val, - dataset_name=dataset_name, - iou_types=iou_types, - box_only=cfg.MODEL.RPN_ONLY, - device=cfg.MODEL.DEVICE, - expected_results=cfg.TEST.EXPECTED_RESULTS, - expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, - output_folder=output_folder, - skip_eval=args.skip_eval, - dllogger=dllogger, - ) + if use_mixed_precision: + with torch.cuda.amp.autocast(): + result = inference( + model, + data_loader_val, + dataset_name=dataset_name, + iou_types=iou_types, + box_only=cfg.MODEL.RPN_ONLY, + device=cfg.MODEL.DEVICE, + expected_results=cfg.TEST.EXPECTED_RESULTS, + expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, + output_folder=output_folder, + skip_eval=args.skip_eval, + dllogger=dllogger, + ) + else: + result = inference( + model, + data_loader_val, + dataset_name=dataset_name, + iou_types=iou_types, + box_only=cfg.MODEL.RPN_ONLY, + device=cfg.MODEL.DEVICE, + expected_results=cfg.TEST.EXPECTED_RESULTS, + expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, + output_folder=output_folder, + skip_eval=args.skip_eval, + dllogger=dllogger, + ) synchronize() results.append(result) diff --git a/PyTorch/Segmentation/MaskRCNN/pytorch/tools/train_net.py b/PyTorch/Segmentation/MaskRCNN/pytorch/tools/train_net.py index 3b9ef27b..c30d337c 100755 --- a/PyTorch/Segmentation/MaskRCNN/pytorch/tools/train_net.py +++ b/PyTorch/Segmentation/MaskRCNN/pytorch/tools/train_net.py @@ -34,13 +34,6 @@ import dllogger from maskrcnn_benchmark.utils.logger import format_step # See if we can use apex.DistributedDataParallel instead of the torch default, -# and enable mixed-precision via apex.amp -try: - from apex import amp - use_amp = True -except ImportError: - print('Use APEX for multi-precision via apex.amp') - use_amp = False try: from apex.parallel import DistributedDataParallel as DDP use_apex_ddp = True @@ -98,15 +91,11 @@ def train(cfg, local_rank, distributed, fp16, dllogger): optimizer = make_optimizer(cfg, model) scheduler = make_lr_scheduler(cfg, optimizer) - if use_amp: - # Initialize mixed-precision training - if fp16: - use_mixed_precision = True - else: - use_mixed_precision = cfg.DTYPE == "float16" - - amp_opt_level = 'O1' if use_mixed_precision else 'O0' - model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level) + use_amp = False + if fp16: + use_amp = True + else: + use_amp = cfg.DTYPE == "float16" if distributed: if use_apex_ddp:
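+            # apex is now used only for DistributedDataParallel; mixed precision comes from
+            # torch.cuda.amp and is independent of the DDP wrapper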