[nnUnet/PyT] Notebook and minor corrections

kkudrynski 2021-02-01 15:34:21 +01:00
parent a0abdfffe6
commit d07eb83bdb
21 changed files with 1429 additions and 88 deletions

View file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2021 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View file

@ -31,13 +31,14 @@ This repository provides a script and recipe to train the nnU-Net model to achie
* [Inference performance benchmark](#inference-performance-benchmark)
* [Results](#results)
* [Training accuracy results](#training-accuracy-results)
* [Training accuracy: NVIDIA DGX-1 (8x V100 16GB)](#training-accuracy-nvidia-dgx-1-8x-v100-16gb)
* [Training accuracy: NVIDIA DGX A100 (8x A100 80G)](#training-accuracy-nvidia-dgx-a100-8x-a100-80g)
* [Training accuracy: NVIDIA DGX-1 (8x V100 16G)](#training-accuracy-nvidia-dgx-1-8x-v100-16g)
* [Training performance results](#training-performance-results)
* [Training performance: NVIDIA DGX A100 80G](#training-performance-nvidia-dgx-a100-80G)
* [Training performance: NVIDIA DGX-1 (8x V100 16GB)](#training-performance-nvidia-dgx-1-8x-v100-16gb)
* [Training performance: NVIDIA DGX A100 (8x A100 80G)](#training-performance-nvidia-dgx-a100-8x-a100-80g)
* [Training performance: NVIDIA DGX-1 (8x V100 16G)](#training-performance-nvidia-dgx-1-8x-v100-16g)
* [Inference performance results](#inference-performance-results)
* [Inference performance: NVIDIA DGX A100 80G](#inference-performance-nvidia-dgx-a100-80G)
* [Inference performance: NVIDIA DGX-1 (1x V100 16GB)](#inference-performance-nvidia-dgx-1-1x-v100-16gb)
* [Inference performance: NVIDIA DGX A100 (1x A100 80G)](#inference-performance-nvidia-dgx-a100-1x-a100-80g)
* [Inference performance: NVIDIA DGX-1 (1x V100 16G)](#inference-performance-nvidia-dgx-1-1x-v100-16g)
- [Release notes](#release-notes)
* [Changelog](#changelog)
* [Known issues](#known-issues)
@ -45,19 +46,20 @@ This repository provides a script and recipe to train the nnU-Net model to achie
## Model overview
The nnU-Net ("no-new-Net") refers to a robust and self-adapting framework for U-Net based medical image segmentation. This repository contains a nnU-Net implementation as described in the paper: [nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation](https://arxiv.org/abs/1809.10486).
    
The differences between this nnU-net and [original model](https://github.com/MIC-DKFZ/nnUNet) are:
- Dynamic selection of patch size and spacings for low resolution U-Net are not supported and they need to be set in `data_preprocessing/configs.py` file.
- Cascaded U-Net is not supported.
- The following data augmentations are not used: rotation, simulation of low resolution, gamma augmentation.
  - Dynamic selection of patch size is not supported, and it has to be set in the `data_preprocessing/configs.py` file.
  - Cascaded U-Net is not supported.
  - The following data augmentations are not used: rotation, simulation of low resolution, gamma augmentation.
    
This model is trained with mixed precision using Tensor Cores on Volta, Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results 2x faster than training without Tensor Cores, while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
### Model architecture
    
The nnU-Net allows the training of two types of networks, 2D U-Net and 3D U-Net, to perform semantic segmentation of 3D images with high accuracy and performance.
The following figure shows the architecture of the 3D U-Net model and its different components. U-Net is composed of a contractive and an expanding path, that aims at building a bottleneck in its center-most part through a combination of convolution, instance norm and leaky relu operations. After this bottleneck, the image is reconstructed through a combination of convolutions and upsampling. Skip connections are added with the goal of helping the backward flow of gradients in order to improve the training.
    
The following figure shows the architecture of the 3D U-Net model and its different components. U-Net is composed of a contracting and an expanding path that aim at building a bottleneck in its centremost part through a combination of convolution, instance norm, and leaky ReLU operations. After this bottleneck, the image is reconstructed through a combination of convolutions and upsampling. Skip connections are added with the goal of helping the backward flow of gradients in order to improve training.
<img src="images/unet3d.png" width="900"/>
@ -65,13 +67,13 @@ The following figure shows the architecture of the 3D U-Net model and its differ
### Default configuration
All convolution blocks in U-Net in both encoder and decoder are using two convolution layers followed by instance normalization and a leaky ReLU nonlinearity. For downsampling we are using strided convolution whereas transposed convolution for upsampling.
All convolution blocks in U-Net, in both the encoder and decoder, use two convolution layers followed by instance normalization and a leaky ReLU nonlinearity. For downsampling we use strided convolutions, whereas for upsampling we use transposed convolutions.
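As an illustration, below is a minimal sketch of one such block; the channel counts, kernel size, and affine flag are assumptions rather than values taken from the repository's `models/layers.py`.

```
import torch.nn as nn

def conv_unit(in_ch, out_ch, stride=1):
    # One convolution -> instance norm -> leaky ReLU unit (3D variant).
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.InstanceNorm3d(out_ch, affine=True),
        nn.LeakyReLU(negative_slope=0.01, inplace=True),
    )

def conv_block(in_ch, out_ch, stride=1):
    # Two units; a strided first unit doubles as the downsampling step.
    return nn.Sequential(conv_unit(in_ch, out_ch, stride), conv_unit(out_ch, out_ch))

# Example: an encoder block that halves the spatial resolution.
block = conv_block(32, 64, stride=2)
```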
All models were trained with RAdam optimizer, learning rate 0.001 and weight_decay 0.0001. For loss function we use the average of [cross-entropy](https://en.wikipedia.org/wiki/Cross_entropy) and [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient).
All models were trained with the RAdam optimizer, learning rate 0.001 and weight decay 0.0001. For the loss function we use the sum of [cross-entropy](https://en.wikipedia.org/wiki/Cross_entropy) and [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) losses.
Early stopping is triggered if the validation dice score has not improved during the last 100 epochs.
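A hedged sketch of such a callback in PyTorch Lightning follows; the monitored metric name `dice_mean` is an assumption, not necessarily the key this repository logs.

```
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Stop when the monitored validation dice score has not improved for 100 epochs.
early_stopping = EarlyStopping(monitor="dice_mean", patience=100, mode="max", verbose=True)
trainer = Trainer(gpus=1, callbacks=[early_stopping])
```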
Used data augmentation: crop with oversampling the foreground class, mirroring, zoom, gaussian noise, gaussian blur, brightness.
Data augmentations used: crop with oversampling of the foreground class, mirroring, zoom, Gaussian noise, Gaussian blur, and brightness.
### Feature support matrix
@ -114,7 +116,7 @@ For information about:
#### Enabling mixed precision
For training and inference, mixed precision can be enabled by adding the `--amp` flag. Mixed precision is using [native Pytorch implementation](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/).
For training and inference, mixed precision can be enabled by adding the `--amp` flag. Mixed precision uses the [native PyTorch implementation](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/).
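For reference, a minimal sketch of what native AMP does under the hood; the model, data, and hyperparameters here are stand-ins, not repository code.

```
import torch
import torch.nn as nn

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    img = torch.randn(8, 16, device="cuda")
    lbl = torch.randint(0, 4, (8,), device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        loss = loss_fn(model(img), lbl)
    scaler.scale(loss).backward()    # scale the loss to avoid FP16 gradient underflow
    scaler.step(optimizer)           # unscale gradients, then update weights
    scaler.update()                  # adapt the loss scale for the next step
```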
#### TF32
@ -144,7 +146,7 @@ The following section lists the requirements that you need to meet in order to s
This repository contains a Dockerfile which extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
- PyTorch 20.12 NGC container
- PyTorch 21.02 NGC container
- Supported GPUs:
- [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
- [NVIDIA Turing architecture](https://www.nvidia.com/en-us/geforce/turing/)
@ -166,15 +168,15 @@ To train your model using mixed or TF32 precision with Tensor Cores or using FP3
Executing this command will create your local repository with all the code to run nnU-Net.
```
git clone https://github.com/NVIDIA/DeepLearningExamples
cd DeepLearningExamples/Pytorch/Segmentation/nnunet_pyt
cd DeepLearningExamples/PyTorch/Segmentation/nnUNet
```
2. Build the nnU-Net PyTorch NGC container.
This command will use the Dockerfile to create a Docker image named `nnunet_pyt`, downloading all the required components automatically.
This command will use the Dockerfile to create a Docker image named `nnunet`, downloading all the required components automatically.
```
docker build -t nnunet_pyt .
docker build -t nnunet .
```
The NGC container contains all the components optimized for usage on NVIDIA hardware.
@ -185,24 +187,19 @@ The following command will launch the container and mount the `./data` directory
```
mkdir data results
docker run -it --runtime=nvidia --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 --rm -v ${PWD}/data:/data -v ${PWD}/results:/results nnunet_pyt:latest /bin/bash
docker run -it --runtime=nvidia --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 --rm -v ${PWD}/data:/data -v ${PWD}/results:/results nnunet:latest /bin/bash
```
4. Prepare BraTS dataset.
To download dataset run:
To download and preprocess the data run:
```
python download.py --task 01
python preprocess.py --task 01 --dim 3
python preprocess.py --task 01 --dim 2
```
then to preprocess 2D or 3D dataset version run:
```
python preprocess.py --task 01 --dim {2,3}
```
If you have prepared both 2D and 3D datasets then `ls /data` should print:
Then `ls /data` should print:
```
01_3d 01_2d Task01_BrainTumour
```
@ -271,32 +268,38 @@ In the root directory, the most important files are:
* `Dockerfile`: Container with the basic set of dependencies to run nnU-Net.
* `requirements.txt`: Set of extra requirements for running nnU-Net.
The `data_preprocessing/` folder contains information about the data preprocessing used by nnU-Net. Its contents are:
The `data_preprocessing` folder contains information about the data preprocessing used by nnU-Net. Its contents are:
* `configs.py`: Defines dataset configuration like patch size or spacing.
* `preprocessor.py`: Implements data preprocessing pipeline.
* `convert2tfrec.py`: Implements conversion from numpy files to tfrecords.
* `convert2tfrec.py`: Implements conversion from NumPy files to tfrecords.
The `data_loading/` folder contains information about the data pipeline used by nnU-Net. Its contents are:
The `data_loading` folder contains information about the data pipeline used by nnU-Net. Its contents are:
* `data_module.py`: Defines `LightningDataModule` used by PyTorch Lightning.
* `dali_loader.py`: Implements DALI data loader.
The `model/` folder contains information about the building blocks of nnU-Net and the way they are assembled. Its contents are:
The `models` folder contains information about the building blocks of nnU-Net and the way they are assembled. Its contents are:
* `layers.py`: Implements convolution blocks used by U-Net template.
* `metrics.py`: Implements metrics and loss function.
* `metrics.py`: Implements the dice metric.
* `loss.py`: Implements loss function.
* `nn_unet.py`: Implements training/validation/test logic and dynamic creation of U-Net architecture used by nnU-Net.
* `unet.py`: Implements the U-Net template.
The `utils/` folder includes:
The `utils` folder includes:
* `utils.py`: Defines utility functions, e.g., parser initialization.
* `logger.py`: Defines logging callback for performance benchmarking.
The `notebooks` folder includes:
* `custom_dataset.ipynb`: Shows how to use nnU-Net with a custom dataset.
Other folders included in the root directory are:
* `images/`: Contains a model diagram.
* `scripts/`: Provides scripts for data preprocessing, training, benchmarking and inference of nnU-Net.
* `scripts/`: Provides scripts for training, benchmarking and inference of nnU-Net.
### Parameters
@ -314,6 +317,7 @@ The complete list of the available parameters for the `main.py` script contains:
* `--dim`: U-Net dimension (default: `3`)
* `--amp`: Enable automatic mixed precision (default: `False`)
* `--negative_slope`: Negative slope for LeakyReLU (default: `0.01`)
* `--gradient_clip_val`: Gradient clipping value (default: `0`)
* `--fold`: Fold number (default: `0`)
* `--nfolds`: Number of cross-validation folds (default: `5`)
* `--patience`: Early stopping patience (default: `50`)
@ -346,7 +350,6 @@ The complete list of the available parameters for the `main.py` script contains:
* `--test_batches`: Limit number of batches for evaluation/inference (default: 0)
* `--affinity`: Type of CPU affinity (default: `socket_unique_interleaved`)
* `--save_ckpt`: Enable saving checkpoint (default: `False`)
* `--gradient_clip_val`: Gradient clipping value (default: `0`)
### Command-line options
@ -357,7 +360,7 @@ To see the full list of available options and their descriptions, use the `-h` o
The following example output is printed when running the model:
```
usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--num_nodes NUM_NODES] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--deep_supervision] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--create_idx] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] [--warmup WARMUP] [--oversampling OVERSAMPLING] [--norm {instance,batch,group}] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,adam,adamw,radam,fused_adam}] [--val_mode {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES]
usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data_dim {2,3}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--deep_supervision] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--nvol NVOL] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--create_idx] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] [--warmup WARMUP] [--oversampling OVERSAMPLING] [--norm {instance,batch,group}] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,adam,adamw,radam,fused_adam}] [--val_mode {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES]
optional arguments:
  -h, --help            show this help message and exit
@ -390,6 +393,7 @@ optional arguments:
                        Patience for ReduceLROnPlateau scheduler (default: 70)
  --batch_size BATCH_SIZE
                        Batch size (default: 2)
  --nvol NVOL           For 2D effective batch size is batch_size*nvol (default: 1)
  --val_batch_size VAL_BATCH_SIZE
                        Validation batch size (default: 4)
  --steps STEPS [STEPS ...]
@ -406,7 +410,7 @@ optional arguments:
  --num_workers NUM_WORKERS
                        Number of subprocesses to use for data loading (default: 8)
  --min_epochs MIN_EPOCHS
                        Force training for at least these many epochs (default: 100)
                        Force training for at least these many epochs (default: 30)
  --max_epochs MAX_EPOCHS
                        Stop training after this number of epochs (default: 10000)
  --warmup WARMUP       Warmup iterations before collecting statistics (default: 5)
@ -416,17 +420,7 @@ optional arguments:
                        Normalization layer (default: instance)
  --overlap OVERLAP     Amount of overlap between scans during sliding window inference (default: 0.25)
  --affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}
                        type of GPU affinity (default: socket_unique_interleaved)
  --scheduler {none,multistep,cosine,plateau}
                        Learning rate scheduler (default: none)
  --optimizer {sgd,adam,adamw,radam,fused_adam}
                        Optimizer (default: radam)
  --val_mode {gaussian,constant}
                        How to blend output of overlapping windows (default: gaussian)
  --train_batches TRAIN_BATCHES
                        Limit number of batches for training (used for benchmarking mode only) (default: 0)
  --test_batches TEST_BATCHES
                        Limit number of batches for inference (used for benchmarking mode only) (default: 0)
                        type of CPU affinity (default: socket_unique_interleaved)
```
### Getting the data
@ -440,13 +434,13 @@ To train nnU-Net you will need to preprocess your dataset as a first step with `
The `preprocess.py` script uses the following command-line options:
```
--data Path to data directory (default: `/data`)
--results Path to directory for saving preprocessed data (default: `/data`)
--exec_mode Mode for data preprocessing
--task Number of tasks to be run. MSD uses numbers 01-10
--dim Data dimension to prepare (default: `3`)
--n_jobs Number of parallel jobs for data preprocessing (default: `-1`)
--vpf Number of volumes per tfrecord (default: `1`)
  --data               Path to data directory (default: `/data`)
  --results            Path to directory for saving preprocessed data (default: `/data`)
  --exec_mode          Mode for data preprocessing
  --task               Number of tasks to be run. MSD uses numbers 01-10
  --dim                Data dimension to prepare (default: `3`)
  --n_jobs             Number of parallel jobs for data preprocessing (default: `-1`) 
  --vpf                Number of volumes per tfrecord (default: `1`) 
```
To preprocess data for 3D U-Net run: `python preprocess.py --task 01 --dim 3`
@ -455,24 +449,24 @@ In `data_preprocessing/configs.py` for each [Medical Segmentation Decathlon](htt
The preprocessing pipeline consists of the following steps:
1. Cropping to the region of nonzero values.
1. Cropping to the region of non-zero values.
2. Resampling to the median voxel spacing of the dataset (except for anisotropic datasets, where the lowest-resolution axis is resampled to the 10th percentile of its spacings).
3. Padding volumes so that their dimensions are at least as large as the patch size.
4. Normalizing:
* For CT modalities the voxel values are clipped to 0.5 and 99.5 percentiles of the foreground voxels and then data is normalized with mean and standard deviation from collected from foreground voxels.
* For MRI modalities z-score normalization is applied.
    * For CT modalities the voxel values are clipped to the 0.5 and 99.5 percentiles of the foreground voxels, and then the data is normalized with the mean and standard deviation collected from the foreground voxels.
    * For MRI modalities z-score normalization is applied.
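An illustrative sketch of this normalization step is shown below; it is a simplification, and the exact logic in `data_preprocessing/preprocessor.py` may differ in details such as percentile handling.

```
import numpy as np

def normalize(volume, modality, foreground_mask):
    fg = volume[foreground_mask]
    if modality == "CT":
        low, high = np.percentile(fg, [0.5, 99.5])
        volume = np.clip(volume, low, high)
        return (volume - fg.mean()) / fg.std()      # statistics from foreground voxels
    return (volume - volume.mean()) / volume.std()  # MRI: plain z-score normalization
```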
#### Multi-dataset
Adding your dataset is possible, however, your data should correspond to [Medical Segmentation Decathlon](http://medicaldecathlon.com/) (i.e data should be `NIfTi` format and there should be `dataset.json` file where you need to provide fields: modality, labels and at least one of training, test).
Adding your dataset is possible; however, your data should correspond to the [Medical Segmentation Decathlon](http://medicaldecathlon.com/) format (i.e., data should be in `NIfTI` format, and there should be a `dataset.json` file providing the fields: modality, labels, and at least one of training or test).
To add your dataset, perform the following:
1. Mount your dataset to the `/data` directory.
 
2. In `data_preprocessing/configs.py`:
- Add to the `task_dir` dictionary your dataset directory name. For example, for Brain Tumour dataset, it corresponds to `"01": "Task01_BrainTumour"`.
- Add the patch size that you want to use for training to the `patch_size` dictionary. For example, for Brain Tumour dataset it corresponds to `"01_3d": [128, 128, 128]` for 3D U-Net and `"01_2d": [192, 160]` for 2D U-Net. There are three types of suffixes `_3d, _2d` they correspond to 3D UNet and 2D U-Net.
    - Add your dataset directory name to the `task_dir` dictionary. For example, for the Brain Tumour dataset, it corresponds to `"01": "Task01_BrainTumour"`.
    - Add the patch size that you want to use for training to the `patch_size` dictionary. For example, for the Brain Tumour dataset it corresponds to `"01_3d": [128, 128, 128]` for 3D U-Net and `"01_2d": [192, 160]` for 2D U-Net. There are two suffix types, `_3d` and `_2d`, corresponding to the 3D U-Net and 2D U-Net (see the sketch after this list).
3. Preprocess your data with the `preprocess.py` script. For example, to preprocess the Brain Tumour dataset for 2D U-Net, run `python preprocess.py --task 01 --dim 2`.
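As a sketch, the entries for a hypothetical new task `"11"` could look as follows; the directory name and patch sizes are placeholders, and the task dictionary is named `task` in the `data_preprocessing/configs.py` shown later in this diff.

```
# Hypothetical additions to data_preprocessing/configs.py.
task = {
    "01": "Task01_BrainTumour",
    "11": "Task11_MyDataset",  # your dataset directory under /data (example)
}
patch_size = {
    "11_3d": [128, 128, 128],  # patch size for 3D U-Net (example)
    "11_2d": [192, 160],       # patch size for 2D U-Net (example)
}
```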
@ -484,7 +478,7 @@ The model trains for at least `--min_epochs` and at most `--max_epochs` epochs.
This default parametrization is applied when running scripts from the `./scripts` directory and when running `main.py` without explicitly overriding these parameters. By default, the training is in full precision. To enable AMP, pass the `--amp` flag. AMP can be enabled for every mode of execution.
The default configuration minimizes a function `L = 0.5 * (1 - dice) + 0.5 * cross entropy` during training and reports achieved convergence as [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) per class. The training, with a combination of dice and cross entropy has been proven to achieve better convergence than a training using only dice.
The default configuration minimizes a function `L = (1 - dice_coefficient) + cross_entropy` during training and reports achieved convergence as [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) per class. Training with a combination of dice and cross-entropy has been shown to achieve better convergence than training with dice alone.
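A sketch of this objective, consistent with the `models/loss.py` change shown later in this diff, could look as follows; the MONAI `DiceLoss` arguments are assumptions.

```
import torch.nn as nn
from monai.losses import DiceLoss

class Loss(nn.Module):
    def __init__(self):
        super().__init__()
        self.dice = DiceLoss(to_onehot_y=True, softmax=True)  # returns 1 - dice
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, y_pred, y_true):
        # y_true carries the label map in its first channel.
        dice = self.dice(y_pred, y_true)
        cross_entropy = self.cross_entropy(y_pred, y_true[:, 0].long())
        return dice + cross_entropy
```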
The training can be run directly without using the predefined scripts. The name of the training script is `main.py`. For example:
@ -562,9 +556,9 @@ The following sections provide details on how to achieve the same performance an
#### Training accuracy results
##### Training accuracy: NVIDIA DGX A100 80G
##### Training accuracy: NVIDIA DGX A100 (8x A100 80G)
Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 20.12 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 20.12 NGC container on NVIDIA DGX A100 with (8x A100 80G) GPUs.
| Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - TF32| Time to train speedup (TF32 to mixed precision)
|:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
@ -573,9 +567,9 @@ Our results were obtained by running the `python scripts/train.py --gpus {1,8} -
| 3 | 1 | 2 |0.7436 |0.7433 |241 min|342 min| 1.42 |
| 3 | 8 | 2 |0.7443 |0.7443 |36 min | 44 min| 1.22 |
##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
##### Training accuracy: NVIDIA DGX-1 (8x V100 16G)
Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 20.12 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 20.12 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs.
| Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - FP32 | Time to train speedup (FP32 to mixed precision)
|:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
@ -586,9 +580,9 @@ Our results were obtained by running the `python scripts/train.py --gpus {1,8} -
#### Training performance results
##### Training performance: NVIDIA DGX A100 80G
##### Training performance: NVIDIA DGX A100 (8x A100 80G)
Our results were obtained by running the `python scripts/benchmark.py --mode train --gpus {1,8} --dim {2,3} --batch_size <bsize> [--amp]` training script in the NGC container on NVIDIA DGX A100 80G GPUs. Performance numbers (in volumes per second) were averaged over an entire training epoch.
Our results were obtained by running the `python scripts/benchmark.py --mode train --gpus {1,8} --dim {2,3} --batch_size <bsize> [--amp]` training script in the NGC container on NVIDIA DGX A100 (8x A100 80G) GPUs. Performance numbers (in volumes per second) were averaged over an entire training epoch.
| Dimension | GPUs | Batch size / GPU | Throughput - mixed precision [img/s] | Throughput - TF32 [img/s] | Throughput speedup (TF32 - mixed precision) | Weak scaling - mixed precision | Weak scaling - TF32 |
|:-:|:-:|:--:|:------:|:------:|:-----:|:-----:|:-----:|
@ -608,9 +602,9 @@ Our results were obtained by running the `python scripts/benchmark.py --mode tra
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
##### Training performance: NVIDIA DGX-1 (8x V100 16G)
Our results were obtained by running the `python scripts/benchmark.py --mode train --gpus {1,8} --dim {2,3} --batch_size <bsize> [--amp]` training script in the PyTorch 20.10 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs. Performance numbers (in volumes per second) were averaged over an entire training epoch.
Our results were obtained by running the `python scripts/benchmark.py --mode train --gpus {1,8} --dim {2,3} --batch_size <bsize> [--amp]` training script in the PyTorch 20.12 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs. Performance numbers (in volumes per second) were averaged over an entire training epoch.
| Dimension | GPUs | Batch size / GPU | Throughput - mixed precision [img/s] | Throughput - FP32 [img/s] | Throughput speedup (FP32 - mixed precision) | Weak scaling - mixed precision | Weak scaling - FP32 |
|:-:|:-:|:---:|:---------:|:-----------:|:--------:|:---------:|:-------------:|
@ -632,9 +626,9 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
#### Inference performance results
##### Inference performance: NVIDIA DGX A100 80G
##### Inference performance: NVIDIA DGX A100 (1x A100 80G)
Our results were obtained by running the `python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> [--amp]` inferencing benchmarking script in the PyTorch 20.10 NGC container on NVIDIA DGX A100 80G GPU.
Our results were obtained by running the `python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> [--amp]` inferencing benchmarking script in the PyTorch 20.12 NGC container on NVIDIA DGX A100 (1x A100 80G) GPU.
FP16
@ -664,9 +658,9 @@ Throughput is reported in images per second. Latency is reported in milliseconds
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)
##### Inference performance: NVIDIA DGX-1 (1x V100 16G)
Our results were obtained by running the `python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> [--amp]` inferencing benchmarking script in the PyTorch 20.10 NGC container on NVIDIA DGX-1 with (1x V100 16GB) GPU.
Our results were obtained by running the `python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> [--amp]` inferencing benchmarking script in the PyTorch 20.12 NGC container on NVIDIA DGX-1 with (1x V100 16G) GPU.
FP16
@ -699,6 +693,7 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
January 2021
- Initial release
- Add notebook with custom dataset loading
### Known issues

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
from subprocess import call

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
task = {
    "01": "Task01_BrainTumour",
    "02": "Task02_Heart",

View file

@ -1,3 +1,18 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from glob import glob

View file

@ -1,3 +1,18 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import json
import math
@ -11,7 +26,8 @@ from joblib import Parallel, delayed
from skimage.transform import resize
from utils.utils import get_task_code, make_empty_dir
from data_preprocessing.configs import ct_max, ct_mean, ct_min, ct_std, patch_size, spacings, task
from data_preprocessing.configs import (ct_max, ct_mean, ct_min, ct_std,
                                        patch_size, spacings, task)
class Preprocessor:

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from subprocess import call

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pyprof

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn as nn

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import monai
import torch
import torch.nn as nn
@ -47,4 +61,4 @@ class Loss(nn.Module):
    def forward(self, y_pred, y_true):
        dice = self.dice(y_pred, y_true)
        cross_entropy = self.cross_entropy(y_pred, y_true[:, 0].long())
        return (dice + cross_entropy) / 2
        return dice + cross_entropy

View file

@ -1,6 +1,19 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import apex
import numpy as np
import pytorch_lightning as pl
import torch
@ -48,12 +61,16 @@ class NNUnet(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        img, lbl = batch["image"], batch["label"]
        if self.args.dim == 2 and len(lbl.shape) == 3:
            lbl = lbl.unsqueeze(1)
        pred = self.model(img)
        loss = self.compute_loss(pred, lbl)
        return loss

    def validation_step(self, batch, batch_idx):
        img, lbl = batch["image"], batch["label"]
        if self.args.dim == 2 and len(lbl.shape) == 3:
            lbl = lbl.unsqueeze(1)
        pred = self.forward(img)
        loss = self.loss(pred, lbl)
        dice = self.dice(pred, lbl[:, 0])
@ -125,6 +142,8 @@ class NNUnet(pl.LightningModule):
    def do_inference(self, image):
        if self.args.dim == 2:
            if self.args.data2d_dim == 2:
                return self.model(image)
            if self.args.exec_mode == "predict" and not self.args.benchmark:
                return self.inference2d_test(image)
            return self.inference2d(image)
@ -215,9 +234,6 @@ class NNUnet(pl.LightningModule):
"adam": torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
"adamw": torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
"radam": optim.RAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
"fused_adam": apex.optimizers.FusedAdam(
self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay
),
}[self.args.optimizer.lower()]
scheduler = {

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from models.layers import ConvBlock, OutputBlock, UpsampleBlock

File diff suppressed because one or more lines are too long

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from os.path import dirname

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from os.path import dirname

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from os.path import dirname

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import math
import os

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
import time

View file

@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
@ -59,7 +73,7 @@ def float_0_1(value):
    return ivalue


def get_main_args():
def get_main_args(strings=None):
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--exec_mode",
@ -140,6 +154,14 @@ def get_main_args():
        ],
        help="type of CPU affinity",
    )
    parser.add_argument(
        "--data2d_dim",
        choices=[2, 3],
        type=int,
        default=3,
        help="Input data dimension for 2d model",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
@ -173,5 +195,14 @@ def get_main_args():
        default=0,
        help="Limit number of batches for inference (used for benchmarking mode only)",
    )
    args = parser.parse_args()
    if strings is not None:
        parser.add_argument(
            "strings",
            metavar="STRING",
            nargs="*",
            help="String for searching",
        )
        args = parser.parse_args(strings.split())
    else:
        args = parser.parse_args()
    return args