Merge pull request #849 from NVIDIA/gh/release
[nnUnet/PyT] Update to 21.02
commit 32fbd288d1
|
@ -1,4 +1,4 @@
|
||||||
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.12-py3
|
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.02-py3
|
||||||
FROM ${FROM_IMAGE_NAME}
|
FROM ${FROM_IMAGE_NAME}
|
||||||
|
|
||||||
ADD . /workspace/nnunet_pyt
|
ADD . /workspace/nnunet_pyt
|
||||||
|
@ -8,9 +8,9 @@ RUN pip install --upgrade pip
|
||||||
RUN pip install --disable-pip-version-check -r requirements.txt
|
RUN pip install --disable-pip-version-check -r requirements.txt
|
||||||
RUN pip install pytorch-lightning==1.0.0 --no-dependencies
|
RUN pip install pytorch-lightning==1.0.0 --no-dependencies
|
||||||
RUN pip install monai==0.4.0 --no-dependencies
|
RUN pip install monai==0.4.0 --no-dependencies
|
||||||
RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==0.29.0
|
RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==0.30.0
|
||||||
RUN pip install torch_optimizer==0.0.1a15 --no-dependencies
|
RUN pip install torch_optimizer==0.0.1a15 --no-dependencies
|
||||||
RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
|
RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
|
||||||
RUN unzip awscliv2.zip
|
RUN unzip -qq awscliv2.zip
|
||||||
RUN ./aws/install
|
RUN ./aws/install
|
||||||
RUN rm -rf awscliv2.zip aws
|
RUN rm -rf awscliv2.zip aws
|
||||||
|
|
|
@ -18,7 +18,6 @@ This repository provides a script and recipe to train the nnU-Net model to achie
|
||||||
- [Quick Start Guide](#quick-start-guide)
|
- [Quick Start Guide](#quick-start-guide)
|
||||||
- [Advanced](#advanced)
|
- [Advanced](#advanced)
|
||||||
* [Scripts and sample code](#scripts-and-sample-code)
|
* [Scripts and sample code](#scripts-and-sample-code)
|
||||||
* [Parameters](#parameters)
|
|
||||||
* [Command-line options](#command-line-options)
|
* [Command-line options](#command-line-options)
|
||||||
* [Getting the data](#getting-the-data)
|
* [Getting the data](#getting-the-data)
|
||||||
* [Dataset guidelines](#dataset-guidelines)
|
* [Dataset guidelines](#dataset-guidelines)
|
||||||
|
@ -46,19 +45,18 @@ This repository provides a script and recipe to train the nnU-Net model to achie
|
||||||
## Model overview
|
## Model overview
|
||||||
|
|
||||||
The nnU-Net ("no-new-Net") refers to a robust and self-adapting framework for U-Net based medical image segmentation. This repository contains a nnU-Net implementation as described in the paper: [nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation](https://arxiv.org/abs/1809.10486).
|
The nnU-Net ("no-new-Net") refers to a robust and self-adapting framework for U-Net based medical image segmentation. This repository contains a nnU-Net implementation as described in the paper: [nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation](https://arxiv.org/abs/1809.10486).
|
||||||
|
|
||||||
The differences between this nnU-net and [original model](https://github.com/MIC-DKFZ/nnUNet) are:
|
|
||||||
|
|
||||||
- Dynamic selection of patch size is not supported, and it has to be set in `data_preprocessing/configs.py` file.
|
The differences between this nnU-net and [original model](https://github.com/MIC-DKFZ/nnUNet) are:
|
||||||
- Cascaded U-Net is not supported.
|
- Dynamic selection of patch size is not supported, and it has to be set in `data_preprocessing/configs.py` file.
|
||||||
- The following data augmentations are not used: rotation, simulation of low resolution, gamma augmentation.
|
- Cascaded U-Net is not supported.
|
||||||
|
- The following data augmentations are not used: rotation, simulation of low resolution, gamma augmentation.
|
||||||
|
|
||||||
This model is trained with mixed precision using Tensor Cores on Volta, Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results 2x faster than training without Tensor Cores, while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
|
This model is trained with mixed precision using Tensor Cores on Volta, Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results 2x faster than training without Tensor Cores, while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
|
||||||
|
|
||||||
### Model architecture
|
### Model architecture
|
||||||
|
|
||||||
The nnU-Net allows training two types of networks: 2D U-Net and 3D U-Net to perform semantic segmentation of 3D images, with high accuracy and performance.
|
The nnU-Net allows training two types of networks: 2D U-Net and 3D U-Net to perform semantic segmentation of 3D images, with high accuracy and performance.
|
||||||
|
|
||||||
The following figure shows the architecture of the 3D U-Net model and its different components. U-Net is composed of a contractive and an expanding path, that aims at building a bottleneck in its centremost part through a combination of convolution, instance norm and leaky relu operations. After this bottleneck, the image is reconstructed through a combination of convolutions and upsampling. Skip connections are added with the goal of helping the backward flow of gradients in order to improve the training.
|
The following figure shows the architecture of the 3D U-Net model and its different components. U-Net is composed of a contractive and an expanding path, that aims at building a bottleneck in its centremost part through a combination of convolution, instance norm and leaky relu operations. After this bottleneck, the image is reconstructed through a combination of convolutions and upsampling. Skip connections are added with the goal of helping the backward flow of gradients in order to improve the training.
|
||||||
|
|
||||||
<img src="images/unet3d.png" width="900"/>
|
<img src="images/unet3d.png" width="900"/>
|
||||||
|
@ -213,44 +211,40 @@ Training can be started with:
|
||||||
python scripts/train.py --gpus <gpus> --fold <fold> --dim <dim> [--amp]
|
python scripts/train.py --gpus <gpus> --fold <fold> --dim <dim> [--amp]
|
||||||
```
|
```
|
||||||
|
|
||||||
Where:
|
To see descriptions of the train script arguments run `python scripts/train.py --help`. You can customize the training process. For details, see the [Training process](#training-process) section.
|
||||||
```
|
|
||||||
--gpus number of gpus
|
|
||||||
--fold fold number, possible choices: `0, 1, 2, 3, 4`
|
|
||||||
--dim U-Net dimension, possible choices: `2, 3`
|
|
||||||
--amp enable automatic mixed precision
|
|
||||||
```
|
|
||||||
You can customize the training process. For details, see the [Training process](#training-process) section.
|
|
||||||
|
|
||||||
6. Start benchmarking.
|
6. Start benchmarking.
|
||||||
|
|
||||||
The training and inference performance can be evaluated by using benchmarking scripts, such as:
|
The training and inference performance can be evaluated by using benchmarking scripts, such as:
|
||||||
|
|
||||||
```
|
```
|
||||||
python scripts/benchmark.py --mode {train, predict} --gpus <ngpus> --dim {2,3} --batch_size <bsize> [--amp]
|
python scripts/benchmark.py --mode {train,predict} --gpus <ngpus> --dim {2,3} --batch_size <bsize> [--amp]
|
||||||
```
|
```
|
||||||
|
|
||||||
which will make the model run and report the performance.
|
To see descriptions of the benchmark script arguments run `python scripts/benchmark.py --help`.
|
||||||
|
|
||||||
|
|
||||||
7. Start inference/predictions.
|
7. Start inference/predictions.
|
||||||
|
|
||||||
Inference can be started with:
|
Inference can be started with:
|
||||||
```
|
```
|
||||||
python scripts/inference.py --dim <dim> --fold <fold> --ckpt_path <path/to/checkpoint> [--amp] [--tta] [--save_preds]
|
python scripts/inference.py --data <path/to/data> --dim <dim> --fold <fold> --ckpt_path <path/to/checkpoint> [--amp] [--tta] [--save_preds]
|
||||||
```
|
```
|
||||||
|
|
||||||
Where:
|
Note: You have to prepare either validation or test dataset to run this script by running `python preprocess.py --task 01 --dim {2,3} --exec_mode {val,test}`. After preprocessing inside given task directory (e.g. `/data/01_3d/` for task 01 and dim 3) it will create `val` or `test` directory with preprocessed data ready for inference. Possible workflow:
|
||||||
|
|
||||||
```
|
```
|
||||||
--dim U-Net dimension. Possible choices: `2, 3`
|
python preprocess.py --task 01 --dim 3 --exec_mode val
|
||||||
--fold fold number. Possible choices: `0, 1, 2, 3, 4`
|
python scripts/inference.py --data /data/01_3d/val --dim 3 --fold 0 --ckpt_path <path/to/checkpoint> --amp --tta --save_preds
|
||||||
--val_batch_size batch size (default: 4)
|
|
||||||
--ckpt_path path to checkpoint
|
|
||||||
--amp enable automatic mixed precision
|
|
||||||
--tta enable test time augmentation
|
|
||||||
--save_preds enable saving prediction masks
|
|
||||||
```
|
```
|
||||||
You can customize the inference process. For details, see the [Inference process](#inference-process) section.
|
|
||||||
|
Then if you have labels for predicted images you can evaluate it with `evaluate.py` script. For example:
|
||||||
|
|
||||||
|
```
|
||||||
|
python evaluate.py --preds /results/preds_task_01_dim_3_fold_0_tta --lbls /data/Task01_BrainTumour/labelsTr
|
||||||
|
```
|
||||||
|
|
||||||
|
To see descriptions of the inference script arguments run `python scripts/inference.py --help`. You can customize the inference process. For details, see the [Inference process](#inference-process) section.
|
||||||
|
|
||||||
Now that you have your model trained and evaluated, you can choose to compare your training results with our [Training accuracy results](#training-accuracy-results). You can also choose to benchmark yours performance to [Training performance benchmark](#training-performance-results), or [Inference performance benchmark](#inference-performance-results). Following the steps in these sections will ensure that you achieve the same accuracy and performance results as stated in the [Results](#results) section.
|
Now that you have your model trained and evaluated, you can choose to compare your training results with our [Training accuracy results](#training-accuracy-results). You can also choose to benchmark yours performance to [Training performance benchmark](#training-performance-results), or [Inference performance benchmark](#inference-performance-results). Following the steps in these sections will ensure that you achieve the same accuracy and performance results as stated in the [Results](#results) section.
|
||||||
|
|
||||||
|
@ -267,6 +261,7 @@ In the root directory, the most important files are:
|
||||||
* `download.py`: Downloads given dataset from [Medical Segmentation Decathlon](http://medicaldecathlon.com/).
|
* `download.py`: Downloads given dataset from [Medical Segmentation Decathlon](http://medicaldecathlon.com/).
|
||||||
* `Dockerfile`: Container with the basic set of dependencies to run nnU-Net.
|
* `Dockerfile`: Container with the basic set of dependencies to run nnU-Net.
|
||||||
* `requirements.txt:` Set of extra requirements for running nnU-Net.
|
* `requirements.txt:` Set of extra requirements for running nnU-Net.
|
||||||
|
* `evaluate.py`: Compare predictions with ground truth and get final score.
|
||||||
|
|
||||||
The `data_preprocessing` folder contains information about the data preprocessing used by nnU-Net. Its contents are:
|
The `data_preprocessing` folder contains information about the data preprocessing used by nnU-Net. Its contents are:
|
||||||
|
|
||||||
|
@ -301,55 +296,6 @@ Other folders included in the root directory are:
|
||||||
* `images/`: Contains a model diagram.
|
* `images/`: Contains a model diagram.
|
||||||
* `scripts/`: Provides scripts for training, benchmarking and inference of nnU-Net.
|
* `scripts/`: Provides scripts for training, benchmarking and inference of nnU-Net.
|
||||||
|
|
||||||
### Parameters
|
|
||||||
|
|
||||||
The complete list of the available parameters for the `main.py` script contains:
|
|
||||||
|
|
||||||
* `--exec_mode`: Select the execution mode to run the model (default: `train`). Modes available:
|
|
||||||
- `train` - Trains model with validation evaluation after every epoch.
|
|
||||||
- `evaluate` - Loads checkpoint and performs evaluation on validation set (requires `--fold`).
|
|
||||||
- `predict` - Loads checkpoint and runs inference on the validation set. If flag `--save_preds` is also provided then stores the predictions in the `--results_dir` directory.
|
|
||||||
* `--data`: Path to data directory (default: `/data`)
|
|
||||||
* `--results` Path to results directory (default: `/results`)
|
|
||||||
* `--logname` Name of dlloger output (default: `None`)
|
|
||||||
* `--task` Task number. MSD uses numbers 01-10"
|
|
||||||
* `--gpus`: Number of GPUs (default: `1`)
|
|
||||||
* `--dim`: U-Net dimension (default: `3`)
|
|
||||||
* `--amp`: Enable automatic mixed precision (default: `False`)
|
|
||||||
* `--negative_slope` Negative slope for LeakyReLU (default: `0.01`)
|
|
||||||
* `--gradient_clip_val`: Gradient clipping value (default: `0`)
|
|
||||||
* `--fold`: Fold number (default: `0`)
|
|
||||||
* `--nfolds`: Number of cross-validation folds (default: `5`)
|
|
||||||
* `--patience`: Early stopping patience (default: `50`)
|
|
||||||
* `--min_epochs`: Force training for at least these many epochs (default: `100`)
|
|
||||||
* `--max_epochs`: Stop training after this number of epochs (default: `10000`)
|
|
||||||
* `--batch_size`: Batch size (default: `2`)
|
|
||||||
* `--val_batch_size`: Validation batch size (default: `4`)
|
|
||||||
* `--tta`: Enable test time augmentation (default: `False`)
|
|
||||||
* `--deep_supervision`: Enable deep supervision (default: `False`)
|
|
||||||
* `--benchmark`: Run model benchmarking (default: `False`)
|
|
||||||
* `--norm`: Normalization layer, one from: {`instance,batch,group`} (default: `instance`)
|
|
||||||
* `--oversampling`: Probability of cropped area to have foreground pixels (default: `0.33`)
|
|
||||||
* `--optimizer`: Optimizer, one from: {`sgd,adam,adamw,radam,fused_adam`} (default: `radam`)
|
|
||||||
* `--learning_rate`: Learning rate (default: `0.001`)
|
|
||||||
* `--momentum`: Momentum factor (default: `0.99`)
|
|
||||||
* `--scheduler`: Learning rate scheduler, one from: {`none,multistep,cosine,plateau`} (default: `none`)
|
|
||||||
* `--steps`: Steps for multi-step scheduler (default: `None`)
|
|
||||||
* `--factor`: Factor used by `multistep` and `reduceLROnPlateau` schedulers (default: `0.1`)
|
|
||||||
* `--lr_patience`: Patience for ReduceLROnPlateau scheduler (default: `75`)
|
|
||||||
* `--weight_decay`: Weight decay (L2 penalty) (default: `0.0001`)
|
|
||||||
* `--seed`: Random seed (default: `1`)
|
|
||||||
* `--num_workers`: Number of subprocesses to use for data loading (default: `8`)
|
|
||||||
* `--resume_training`: Resume training from the last checkpoint (default: `False`)
|
|
||||||
* `--overlap`: Amount of overlap between scans during sliding window inference (default: `0.25`)
|
|
||||||
* `--val_mode`: How to blend output of overlapping windows one from: {`gaussian,constant`} (default: `gaussian`)
|
|
||||||
* `--ckpt_path`: Path to checkpoint
|
|
||||||
* `--save_preds`: Enable prediction saving (default: `False`)
|
|
||||||
* `--warmup`: Warmup iterations before collecting statistics for model benchmarking. (default: `5`)
|
|
||||||
* `--train_batches`: Limit number of batches for training (default: 0)
|
|
||||||
* `--test_batches`: Limit number of batches for evaluation/inference (default: 0)
|
|
||||||
* `--affinity`: Type of CPU affinity (default: `socket_unique_interleaved`)
|
|
||||||
* `--save_ckpt`: Enable saving checkpoint (default: `False`)
|
|
||||||
|
|
||||||
### Command-line options
|
### Command-line options
|
||||||
|
|
||||||
|
@ -360,7 +306,7 @@ To see the full list of available options and their descriptions, use the `-h` o
|
||||||
The following example output is printed when running the model:
|
The following example output is printed when running the model:
|
||||||
|
|
||||||
```
|
```
|
||||||
usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data_dim {2,3}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--deep_supervision] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--nvol NVOL] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--create_idx] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] [--warmup WARMUP] [--oversampling OVERSAMPLING] [--norm {instance,batch,group}] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,adam,adamw,radam,fused_adam}] [--val_mode {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES]
|
usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--deep_supervision] [--drop_block] [--attention] [--residual] [--focal] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--skip_first_n_eval SKIP_FIRST_N_EVAL] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] [--warmup WARMUP] [--norm {instance,batch,group}] [--nvol NVOL] [--data2d_dim {2,3}] [--oversampling OVERSAMPLING] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,radam,adam}] [--blend {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES]
|
||||||
|
|
||||||
optional arguments:
|
optional arguments:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
|
@ -381,10 +327,16 @@ optional arguments:
|
||||||
--amp Enable automatic mixed precision (default: False)
|
--amp Enable automatic mixed precision (default: False)
|
||||||
--benchmark Run model benchmarking (default: False)
|
--benchmark Run model benchmarking (default: False)
|
||||||
--deep_supervision Enable deep supervision (default: False)
|
--deep_supervision Enable deep supervision (default: False)
|
||||||
|
--drop_block Enable drop block (default: False)
|
||||||
|
--attention Enable attention in decoder (default: False)
|
||||||
|
--residual Enable residual block in encoder (default: False)
|
||||||
|
--focal Use focal loss instead of cross entropy (default: False)
|
||||||
--sync_batchnorm Enable synchronized batchnorm (default: False)
|
--sync_batchnorm Enable synchronized batchnorm (default: False)
|
||||||
--save_ckpt Enable saving checkpoint (default: False)
|
--save_ckpt Enable saving checkpoint (default: False)
|
||||||
--nfolds NFOLDS Number of cross-validation folds (default: 5)
|
--nfolds NFOLDS Number of cross-validation folds (default: 5)
|
||||||
--seed SEED Random seed (default: 1)
|
--seed SEED Random seed (default: 1)
|
||||||
|
--skip_first_n_eval SKIP_FIRST_N_EVAL
|
||||||
|
Skip the evaluation for the first n epochs. (default: 0)
|
||||||
--ckpt_path CKPT_PATH
|
--ckpt_path CKPT_PATH
|
||||||
Path to checkpoint (default: None)
|
Path to checkpoint (default: None)
|
||||||
--fold FOLD Fold number (default: 0)
|
--fold FOLD Fold number (default: 0)
|
||||||
|
@ -393,12 +345,10 @@ optional arguments:
|
||||||
Patience for ReduceLROnPlateau scheduler (default: 70)
|
Patience for ReduceLROnPlateau scheduler (default: 70)
|
||||||
--batch_size BATCH_SIZE
|
--batch_size BATCH_SIZE
|
||||||
Batch size (default: 2)
|
Batch size (default: 2)
|
||||||
--nvol NVOL For 2D effective batch size is batch_size*nvol (default: 1)
|
|
||||||
--val_batch_size VAL_BATCH_SIZE
|
--val_batch_size VAL_BATCH_SIZE
|
||||||
Validation batch size (default: 4)
|
Validation batch size (default: 4)
|
||||||
--steps STEPS [STEPS ...]
|
--steps STEPS [STEPS ...]
|
||||||
Steps for multistep scheduler (default: None)
|
Steps for multistep scheduler (default: None)
|
||||||
--create_idx Create index files for tfrecord (default: False)
|
|
||||||
--profile Run dlprof profiling (default: False)
|
--profile Run dlprof profiling (default: False)
|
||||||
--momentum MOMENTUM Momentum factor (default: 0.99)
|
--momentum MOMENTUM Momentum factor (default: 0.99)
|
||||||
--weight_decay WEIGHT_DECAY
|
--weight_decay WEIGHT_DECAY
|
||||||
|
@ -414,13 +364,25 @@ optional arguments:
|
||||||
--max_epochs MAX_EPOCHS
|
--max_epochs MAX_EPOCHS
|
||||||
Stop training after this number of epochs (default: 10000)
|
Stop training after this number of epochs (default: 10000)
|
||||||
--warmup WARMUP Warmup iterations before collecting statistics (default: 5)
|
--warmup WARMUP Warmup iterations before collecting statistics (default: 5)
|
||||||
--oversampling OVERSAMPLING
|
|
||||||
Probability of crop to have some region with positive label (default: 0.33)
|
|
||||||
--norm {instance,batch,group}
|
--norm {instance,batch,group}
|
||||||
Normalization layer (default: instance)
|
Normalization layer (default: instance)
|
||||||
--overlap OVERLAP Amount of overlap between scans during sliding window inference (default: 0.25)
|
--nvol NVOL Number of volumes which come into single batch size for 2D model (default: 1)
|
||||||
|
--data2d_dim {2,3} Input data dimension for 2d model (default: 3)
|
||||||
|
--oversampling OVERSAMPLING
|
||||||
|
Probability of crop to have some region with positive label (default: 0.33)
|
||||||
|
--overlap OVERLAP Amount of overlap between scans during sliding window inference (default: 0.5)
|
||||||
--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}
|
--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}
|
||||||
type of CPU affinity (default: socket_unique_interleaved)
|
type of CPU affinity (default: socket_unique_interleaved)
|
||||||
|
--scheduler {none,multistep,cosine,plateau}
|
||||||
|
Learning rate scheduler (default: none)
|
||||||
|
--optimizer {sgd,radam,adam}
|
||||||
|
Optimizer (default: radam)
|
||||||
|
--blend {gaussian,constant}
|
||||||
|
How to blend output of overlapping windows (default: gaussian)
|
||||||
|
--train_batches TRAIN_BATCHES
|
||||||
|
Limit number of batches for training (used for benchmarking mode only) (default: 0)
|
||||||
|
--test_batches TEST_BATCHES
|
||||||
|
Limit number of batches for inference (used for benchmarking mode only) (default: 0)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Getting the data
|
### Getting the data
|
||||||
|
@ -429,21 +391,9 @@ The nnU-Net model was trained on the [Medical Segmentation Decathlon](http://med
|
||||||
|
|
||||||
#### Dataset guidelines
|
#### Dataset guidelines
|
||||||
|
|
||||||
To train nnU-Net you will need to preprocess your dataset as a first step with `preprocess.py` script.
|
To train nnU-Net you will need to preprocess your dataset as a first step with `preprocess.py` script. Run `python scripts/preprocess.py --help` to see descriptions of the preprocess script arguments.
|
||||||
|
|
||||||
The `preprocess.py` script is using the following command-line options:
|
For example to preprocess data for 3D U-Net run: `python preprocess.py --task 01 --dim 3`.
|
||||||
|
|
||||||
```
|
|
||||||
--data Path to data directory (default: `/data`)
|
|
||||||
--results Path to directory for saving preprocessed data (default: `/data`)
|
|
||||||
--exec_mode Mode for data preprocessing
|
|
||||||
--task Number of tasks to be run. MSD uses numbers 01-10
|
|
||||||
--dim Data dimension to prepare (default: `3`)
|
|
||||||
--n_jobs Number of parallel jobs for data preprocessing (default: `-1`)
|
|
||||||
--vpf Number of volumes per tfrecord (default: `1`)
|
|
||||||
```
|
|
||||||
|
|
||||||
To preprocess data for 3D U-Net run: `python preprocess.py --task 01 --dim 3`
|
|
||||||
|
|
||||||
In `data_preprocessing/configs.py` for each [Medical Segmentation Decathlon](http://medicaldecathlon.com/) task there are defined: patch size, precomputed spacings and statistics for CT datasets.
|
In `data_preprocessing/configs.py` for each [Medical Segmentation Decathlon](http://medicaldecathlon.com/) task there are defined: patch size, precomputed spacings and statistics for CT datasets.
|
||||||
|
|
||||||
|
@ -452,31 +402,31 @@ The preprocessing pipeline consists of the following steps:
|
||||||
1. Cropping to the region of non-zero values.
|
1. Cropping to the region of non-zero values.
|
||||||
2. Resampling to the median voxel spacing of their respective dataset (exception for anisotropic datasets where the lowest resolution axis is selected to be the 10th percentile of the spacings).
|
2. Resampling to the median voxel spacing of their respective dataset (exception for anisotropic datasets where the lowest resolution axis is selected to be the 10th percentile of the spacings).
|
||||||
3. Padding volumes so that dimensions are at least as patch size.
|
3. Padding volumes so that dimensions are at least as patch size.
|
||||||
4. Normalizing
|
4. Normalizing:
|
||||||
* For CT modalities the voxel values are clipped to 0.5 and 99.5 percentiles of the foreground voxels and then data is normalized with mean and standard deviation from collected from foreground voxels.
|
* For CT modalities the voxel values are clipped to 0.5 and 99.5 percentiles of the foreground voxels and then data is normalized with mean and standard deviation from collected from foreground voxels.
|
||||||
* For MRI modalities z-score normalization is applied.
|
* For MRI modalities z-score normalization is applied.
|
||||||
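Reviewer note (not part of the commit): the two normalization rules listed in step 4 can be sketched in plain Python as below. This is only an illustration under stated assumptions. The function names are hypothetical, volumes are flattened to lists, and the CT statistics are computed on the fly from the supplied foreground voxels, whereas the README says the repository uses statistics precomputed per dataset.

```python
import statistics


def percentile(sorted_values, q):
    """Linear-interpolation percentile (q in [0, 100]) of a pre-sorted list."""
    if len(sorted_values) == 1:
        return float(sorted_values[0])
    pos = (len(sorted_values) - 1) * q / 100.0
    lower = int(pos)
    upper = min(lower + 1, len(sorted_values) - 1)
    frac = pos - lower
    return sorted_values[lower] * (1.0 - frac) + sorted_values[upper] * frac


def normalize_ct(voxels, foreground):
    """Clip to the 0.5/99.5 foreground percentiles, then standardize.

    Assumption: mean and standard deviation come from `foreground`,
    a flat list of foreground intensities (hypothetical input).
    """
    fg = sorted(foreground)
    low, high = percentile(fg, 0.5), percentile(fg, 99.5)
    mean = statistics.fmean(fg)
    std = statistics.pstdev(fg) or 1.0  # guard against constant input
    return [(min(max(v, low), high) - mean) / std for v in voxels]


def normalize_mri(voxels):
    """Z-score normalization using the volume's own mean and std."""
    mean = statistics.fmean(voxels)
    std = statistics.pstdev(voxels) or 1.0  # guard against constant input
    return [(v - mean) / std for v in voxels]
```

For example, `normalize_ct([-1000, 50, 5000], list(range(101)))` clips the two outliers to the foreground percentile bounds before standardizing, so extreme intensities no longer dominate the result.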
|
|
||||||
#### Multi-dataset
|
#### Multi-dataset
|
||||||
|
|
||||||
Adding your dataset is possible, however, your data should correspond to [Medical Segmentation Decathlon](http://medicaldecathlon.com/) (i.e. data should be `NIfTi` format and there should be `dataset.json` file where you need to provide fields: modality, labels and at least one of training, test).
|
It is possible to run nnUNet on custom dataset. If your dataset correspond to [Medical Segmentation Decathlon](http://medicaldecathlon.com/) (i.e. data should be `NIfTi` format and there should be `dataset.json` file where you need to provide fields: modality, labels and at least one of training, test) you need to perform the following:
|
||||||
|
|
||||||
To add your dataset, perform the following:
|
|
||||||
|
|
||||||
1. Mount your dataset to `/data` directory.
|
1. Mount your dataset to `/data` directory.
|
||||||
|
|
||||||
2. In `data_preprocessing/config.py`:
|
2. In `data_preprocessing/config.py`:
|
||||||
- Add to the `task_dir` dictionary your dataset directory name. For example, for Brain Tumour dataset, it corresponds to `"01": "Task01_BrainTumour"`.
|
- Add to the `task_dir` dictionary your dataset directory name. For example, for Brain Tumour dataset, it corresponds to `"01": "Task01_BrainTumour"`.
|
||||||
- Add the patch size that you want to use for training to the `patch_size` dictionary. For example, for Brain Tumour dataset it corresponds to `"01_3d": [128, 128, 128]` for 3D U-Net and `"01_2d": [192, 160]` for 2D U-Net. There are three types of suffixes `_3d, _2d` they correspond to 3D UNet and 2D U-Net.
|
- Add the patch size that you want to use for training to the `patch_size` dictionary. For example, for Brain Tumour dataset it corresponds to `"01_3d": [128, 128, 128]` for 3D U-Net and `"01_2d": [192, 160]` for 2D U-Net. There are three types of suffixes `_3d, _2d` they correspond to 3D UNet and 2D U-Net.
|
||||||
|
|
||||||
3. Preprocess your data with `preprocess.py` scripts. For example, to preprocess Brain Tumour dataset for 2D U-Net you should run `python preprocess.py --task 01 --dim 2`.
|
3. Preprocess your data with `preprocess.py` scripts. For example, to preprocess Brain Tumour dataset for 2D U-Net you should run `python preprocess.py --task 01 --dim 2`.
|
||||||
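Reviewer note (not part of the commit): the dictionary entries described in step 2 might look like the sketch below. The `task_dir` and `patch_size` names and the Brain Tumour values are taken from the README text. The task number `11`, the directory name `Task11_MyDataset`, its patch sizes, and the `lookup` helper are hypothetical and only illustrate the key format.

```python
# Sketch of the dictionaries the README describes for the preprocessing config.
task_dir = {
    "01": "Task01_BrainTumour",  # example given in the README
    "11": "Task11_MyDataset",    # hypothetical custom dataset
}

patch_size = {
    "01_3d": [128, 128, 128],  # README example for 3D U-Net
    "01_2d": [192, 160],       # README example for 2D U-Net
    "11_3d": [128, 128, 128],  # hypothetical values for the custom dataset
    "11_2d": [192, 160],       # hypothetical values for the custom dataset
}


def lookup(task, dim):
    """Return (directory name, patch size) for a task number and dimension."""
    key = f"{task}_{dim}d"  # keys combine the task number and the _2d/_3d suffix
    return task_dir[task], patch_size[key]
```

With these entries, `lookup("11", 2)` resolves the hypothetical dataset the same way the Brain Tumour task is resolved, which is the point of registering both dictionaries before running `preprocess.py`.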
|
|
||||||
|
If you have dataset in other format or you want customize data preprocessing or data loading see `notebooks/custom_dataset.ipynb`.
|
||||||
|
|
||||||
### Training process
|
### Training process
|
||||||
|
|
||||||
The model trains for at least `--min_epochs` and at most `--max_epochs` epochs. After each epoch evaluation, the validation set is done and validation loss is monitored for early stopping (see `--patience` flag). Default training settings are:
|
The model trains for at least `--min_epochs` and at most `--max_epochs` epochs. After each epoch evaluation, the validation set is done and validation loss is monitored for early stopping (see `--patience` flag). Default training settings are:
|
||||||
* RAdam optimizer with learning rate of 0.001 and weight decay 0.0001.
|
* RAdam optimizer with learning rate of 0.001 and weight decay 0.0001.
|
||||||
* Training batch size is set to 2 for 3D U-Net and 16 for 2D U-Net.
|
* Training batch size is set to 2 for 3D U-Net and 16 for 2D U-Net.
|
||||||
|
|
||||||
This default parametrization is applied when running scripts from the `./scripts` directory and when running `main.py` without explicitly overriding these parameters. By default, the training is in full precision. To enable AMP, pass the `--amp` flag. AMP can be enabled for every mode of execution.
|
This default parametrization is applied when running scripts from the `scripts/` directory and when running `main.py` without explicitly overriding these parameters. By default, the training is in full precision. To enable AMP, pass the `--amp` flag. AMP can be enabled for every mode of execution.
|
||||||
|
|
||||||
The default configuration minimizes a function `L = (1 - dice_coefficient) + cross_entropy` during training and reports achieved convergence as [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) per class. The training, with a combination of dice and cross entropy has been proven to achieve better convergence than a training using only dice.
|
The default configuration minimizes a function `L = (1 - dice_coefficient) + cross_entropy` during training and reports achieved convergence as [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) per class. The training, with a combination of dice and cross entropy has been proven to achieve better convergence than a training using only dice.
|
||||||
|
|
||||||
|
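Reviewer note (not part of the commit): the objective `L = (1 - dice_coefficient) + cross_entropy` quoted in the hunk above can be illustrated with a minimal sketch. It assumes a single foreground class, predicted probabilities already flattened into a list, and a small `eps` for numerical stability. The repository's actual loss is multi-class and is not reproduced here, and all function names are hypothetical.

```python
import math


def dice_coefficient(probs, targets, eps=1e-8):
    """Soft Dice score for one class over flat lists of equal length."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    denominator = sum(probs) + sum(targets)
    return (2.0 * intersection + eps) / (denominator + eps)


def cross_entropy(probs, targets, eps=1e-8):
    """Mean binary cross entropy; probabilities are clamped away from 0 and 1."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)
        total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
    return total / len(probs)


def combined_loss(probs, targets):
    """L = (1 - dice_coefficient) + cross_entropy, as stated in the README."""
    return (1.0 - dice_coefficient(probs, targets)) + cross_entropy(probs, targets)
```

A perfect prediction drives both terms toward zero, while a fully inverted prediction is penalized by both the Dice term (overlap near zero) and the cross-entropy term (confident wrong voxels).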
@ -486,9 +436,8 @@ The training can be run directly without using the predefined scripts. The name
|
||||||
python main.py --exec_mode train --task 01 --fold 0 --gpus 1 --amp --deep_supervision
|
python main.py --exec_mode train --task 01 --fold 0 --gpus 1 --amp --deep_supervision
|
||||||
```
|
```
|
||||||
|
|
||||||
Training artifacts will be saved to `/results` (you can override it with `--results <path/to/results/>`) in the container. Some important artifacts are:
|
Training artifacts will be saved to `/results` in the container. Some important artifacts are:
|
||||||
* `/results/logs.json`: Collected dice scores and loss values evaluated after each epoch during training on the validation set.
|
* `/results/logs.json`: Collected dice scores and loss values evaluated after each epoch during training on the validation set.
|
||||||
* `/results/train_logs.json`: Selected best dice scores achieved during training.
|
|
||||||
* `/results/checkpoints`: Saved checkpoints. By default, two checkpoints are saved: one after each epoch ('last.ckpt') and one with the highest validation dice (e.g. 'epoch=5.ckpt' if the highest dice was reached at the 5th epoch).
|
* `/results/checkpoints`: Saved checkpoints. By default, two checkpoints are saved: one after each epoch ('last.ckpt') and one with the highest validation dice (e.g. 'epoch=5.ckpt' if the highest dice was reached at the 5th epoch).
|
||||||
|
|
||||||
To load the pretrained model provide `--ckpt_path <path/to/checkpoint>`.
|
To load the pretrained model provide `--ckpt_path <path/to/checkpoint>`.
|
||||||
|
@ -516,37 +465,37 @@ The following section shows how to run benchmarks to measure the model performan
|
||||||
|
|
||||||
#### Training performance benchmark
|
#### Training performance benchmark
|
||||||
|
|
||||||
To benchmark training, run one of the scripts in `./scripts`:
|
To benchmark training, run `scripts/benchmark.py` script with `--mode train`:
|
||||||
|
|
||||||
```
|
```
|
||||||
python scripts/benchmark.py --mode train --gpus <ngpus> --dim {2,3} --batch_size <bsize> [--amp]
|
python scripts/benchmark.py --mode train --gpus <ngpus> --dim {2,3} --batch_size <bsize> [--amp]
|
||||||
```
|
```
|
||||||
|
|
||||||
For example, to benchmark 3D U-Net training using mixed-precision on 8 GPUs with batch size of 2 for 80 batches, run:
|
For example, to benchmark 3D U-Net training using mixed-precision on 8 GPUs with batch size of 2, run:
|
||||||
|
|
||||||
```
|
```
|
||||||
python scripts/benchmark.py --mode train --gpus 8 --dim 3 --batch_size 2 --train_batches 80 --amp
|
python scripts/benchmark.py --mode train --gpus 8 --dim 3 --batch_size 2 --amp
|
||||||
```
|
```
|
||||||
|
|
||||||
Each of these scripts will by default run 10 warm-up iterations and benchmark the performance during the next 70 iterations. To modify these values provide: `--warmup <warmup> --train_batches <number/of/train/batches>`.
|
Each of these scripts will by default run 1 warm-up epoch and start performance benchmarking during the second epoch.
|
||||||
|
|
||||||
At the end of the script, a line reporting the best train throughput and latency will be printed.
|
At the end of the script, a line reporting the best train throughput and latency will be printed.
|
||||||
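The warm-up-then-measure pattern described above can be sketched as follows. This is a generic illustration under stated assumptions, not the logic of `scripts/benchmark.py`: the function `benchmark` and its parameters are hypothetical, and `step_fn` stands in for one training step. Timing real GPU work would additionally require synchronizing the device before reading the clock.

```python
import time


def benchmark(step_fn, batches_per_epoch, batch_size, warmup_epochs=1):
    """Run warm-up epochs untimed, then time one epoch (hypothetical helper)."""
    # Warm-up: results are discarded so one-time setup costs do not skew timing.
    for _ in range(warmup_epochs * batches_per_epoch):
        step_fn()

    # Measured epoch: record the wall-clock latency of every step.
    latencies = []
    for _ in range(batches_per_epoch):
        start = time.perf_counter()
        step_fn()
        latencies.append(time.perf_counter() - start)

    mean_latency = max(sum(latencies) / len(latencies), 1e-12)
    throughput = batch_size / mean_latency   # samples per second
    return throughput, mean_latency * 1000.0  # latency in milliseconds


calls = []


def fake_step():
    # Stand-in workload; a real step would run forward and backward passes.
    calls.append(sum(range(1000)))


throughput, latency_ms = benchmark(fake_step, batches_per_epoch=5, batch_size=2)
```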
|
|
||||||
#### Inference performance benchmark
|
#### Inference performance benchmark
|
||||||
|
|
||||||
To benchmark inference, run one of the scripts in `./scripts`:
|
To benchmark inference, run `scripts/benchmark.py` script with `--mode predict`:
|
||||||
|
|
||||||
```
|
```
|
||||||
python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> --test_batches <number/of/test/batches> [--amp]
|
python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> [--amp]
|
||||||
```
|
```
|
||||||
|
|
||||||
For example, to benchmark inference using mixed-precision for 3D U-Net, with batch size of 4 for 80 batches, run:
|
For example, to benchmark inference using mixed-precision for 3D U-Net, with batch size of 4, run:
|
||||||
|
|
||||||
```
|
```
|
||||||
python scripts/benchmark.py --mode predict --dim 3 --amp --batch_size 4 --test_batches 80
|
python scripts/benchmark.py --mode predict --dim 3 --amp --batch_size 4
|
||||||
```
|
```
|
||||||
|
|
||||||
Each of these scripts will by default run 10 warm-up iterations and benchmark the performance during the next 70 iterations. To modify these values provide: `--warmup <warmup> --test_batches <number/of/test/batches>`.
|
Each of these scripts will by default run warm-up for 1 data pass and start inference benchmarking during the second pass.
|
||||||
|
|
||||||
At the end of the script, a line reporting the inference throughput and latency will be printed.
|
At the end of the script, a line reporting the inference throughput and latency will be printed.
|
||||||
|
|
||||||
|
@ -558,25 +507,26 @@ The following sections provide details on how to achieve the same performance an
|
||||||
|
|
||||||
##### Training accuracy: NVIDIA DGX A100 (8x A100 80G)
|
##### Training accuracy: NVIDIA DGX A100 (8x A100 80G)
|
||||||
|
|
||||||
Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 20.12 NGC container on NVIDIA DGX A100 with (8x A100 80G) GPUs.
|
Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX with (8x A100 80G) GPUs.
|
||||||
|
|
||||||
| Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - TF32 | Time to train speedup (TF32 to mixed precision) |
|
| Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - TF32 | Time to train speedup (TF32 to mixed precision) |
|
||||||
|:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
|
|:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
|
||||||
| 2 | 1 | 16 |0.7021 |0.7051 |89min | 104min| 1.17 |
|
| 2 | 1 | 64 | 73.002 | 73.390 | 98 min | 150 min | 1.536 |
|
||||||
| 2 | 8 | 16 |0.7316 |0.7316 |13 min | 17 min| 1.31 |
|
| 2 | 8 | 64 | 72.916 | 73.054 | 17 min | 23 min | 1.295 |
|
||||||
| 3 | 1 | 2 |0.7436 |0.7433 |241 min|342 min| 1.42 |
|
| 3 | 1 | 2 | 74.408 | 74.402 | 118 min | 221 min | 1.869 |
|
||||||
| 3 | 8 | 2 |0.7443 |0.7443 |36 min | 44 min| 1.22 |
|
| 3 | 8 | 2 | 74.350 | 74.292 | 27 min | 46 min | 1.775 |
|
||||||
|
|
||||||
|
|
||||||
##### Training accuracy: NVIDIA DGX-1 (8x V100 16G)
|
##### Training accuracy: NVIDIA DGX-1 (8x V100 16G)
|
||||||
|
|
||||||
Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 20.12 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs.
|
Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs.
|
||||||
|
|
||||||
| Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - FP32 | Time to train speedup (FP32 to mixed precision) |
|
| Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - FP32 | Time to train speedup (FP32 to mixed precision) |
|
||||||
|:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
|
|:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
|
||||||
| 2 | 1 | 16 |0.7034 |0.7033 |144 min|180 min| 1.25 |
|
| 2 | 1 | 64 | 73.316 | 73.200 | 175 min | 342 min | 1.952 |
|
||||||
| 2 | 8 | 16 |0.7319 |0.7315 |37 min |44 min | 1.19 |
|
| 2 | 8 | 64 | 72.886 | 72.954 | 43 min | 52 min | 1.230 |
|
||||||
| 3 | 1 | 2 |0.7439 |0.7436 |317 min|738 min| 2.32 |
|
| 3 | 1 | 2 | 74.378 | 74.324 | 228 min | 667 min | 2.935 |
|
||||||
| 3 | 8 | 2 |0.7440 |0.7441 |58 min |121 min| 2.09 |
|
| 3 | 8 | 2 | 74.29 | 74.378 | 62 min | 141 min | 2.301 |
|
||||||
|
|
||||||
#### Training performance results
|
#### Training performance results
|
||||||
|
|
||||||
|
@ -586,40 +536,37 @@ Our results were obtained by running the `python scripts/benchmark.py --mode tra
|
||||||
|
|
||||||
| Dimension | GPUs | Batch size / GPU | Throughput - mixed precision [img/s] | Throughput - TF32 [img/s] | Throughput speedup (TF32 - mixed precision) | Weak scaling - mixed precision | Weak scaling - TF32 |
|
| Dimension | GPUs | Batch size / GPU | Throughput - mixed precision [img/s] | Throughput - TF32 [img/s] | Throughput speedup (TF32 - mixed precision) | Weak scaling - mixed precision | Weak scaling - TF32 |
|
||||||
|:-:|:-:|:--:|:------:|:------:|:-----:|:-----:|:-----:|
|
|:-:|:-:|:--:|:------:|:------:|:-----:|:-----:|:-----:|
|
||||||
| 2 | 1 | 32 | 674.34 | 489.3 | 1.38 | N/A | N/A |
|
| 2 | 1 | 64 | 1064.46 | 678.86 | 1.568 | N/A | N/A |
|
||||||
| 2 | 1 | 64 | 856.34 | 565.62 | 1.51 | N/A | N/A |
|
| 2 | 1 | 128 | 1129.09 | 710.09 | 1.59 | N/A | N/A |
|
||||||
| 2 | 1 | 128| 926.64 | 600.34 | 1.54 | N/A | N/A |
|
| 2 | 8 | 64 | 6477.99 | 4780.3 | 1.355 | 6.086 | 7.042 |
|
||||||
| 2 | 8 | 32 | 3957.33 | 3275.88 | 1.21| 5.868 | 6.695 |
|
| 2 | 8 | 128 | 8163.67 | 5342.49 | 1.528 | 7.23 | 7.524 |
|
||||||
| 2 | 8 | 64 | 5667.14 | 4037.82 | 1.40 | 6.618 | 7.139 |
|
| 3 | 1 | 1 | 13.39 | 8.46 | 1.583 | N/A | N/A |
|
||||||
| 2 | 8 | 128| 6310.97 | 4568.13 | 1.38 | 6.811 | 7.609 |
|
| 3 | 1 | 2 | 15.97 | 9.52 | 1.678 | N/A | N/A |
|
||||||
| 3 | 1 | 1 | 4.24 | 3.57 | 1.19 | N/A | N/A |
|
| 3 | 1 | 4 | 17.84 | 5.16 | 3.457 | N/A | N/A |
|
||||||
| 3 | 1 | 2 | 6.74 | 5.21 | 1.29 | N/A | N/A |
|
| 3 | 8 | 1 | 92.93 | 61.68 | 1.507 | 6.94 | 7.291 |
|
||||||
| 3 | 1 | 4 | 9.52 | 4.16 | 2.29 | N/A | N/A |
|
| 3 | 8 | 2 | 113.51 | 72.23 | 1.572 | 7.108 | 7.587 |
|
||||||
| 3 | 8 | 1 | 32.48 | 27.79 | 1.17 | 7.66 | 7.78 |
|
| 3 | 8 | 4 | 129.91 | 38.26 | 3.395 | 7.282 | 7.415 |
|
||||||
| 3 | 8 | 2 | 51.50 | 40.67 | 1.27 | 7.64 | 7.81 |
|
|
||||||
| 3 | 8 | 4 | 74.29 | 31.50 | 2.36 | 7.80 | 7.57 |
|
|
||||||
|
|
||||||
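The derived columns in the table above appear to be simple ratios of the throughput columns. The sketch below assumes that reading (speedup as mixed-precision over TF32 throughput, weak scaling as 8-GPU over 1-GPU throughput) and checks it against the updated 2D rows with batch size 64; the function names are hypothetical.

```python
def speedup(mixed_precision, baseline):
    """Throughput speedup of mixed precision over the baseline precision."""
    return mixed_precision / baseline


def weak_scaling(multi_gpu, single_gpu):
    """Ratio of multi-GPU to single-GPU throughput at the same per-GPU batch size."""
    return multi_gpu / single_gpu


# Throughput [img/s] for 2D U-Net, batch size 64 per GPU, taken from the table.
amp_1gpu, tf32_1gpu = 1064.46, 678.86
amp_8gpu, tf32_8gpu = 6477.99, 4780.3

amp_speedup = round(speedup(amp_1gpu, tf32_1gpu), 3)
amp_scaling = round(weak_scaling(amp_8gpu, amp_1gpu), 3)
tf32_scaling = round(weak_scaling(tf32_8gpu, tf32_1gpu), 3)
```

Under this assumption the computed values match the reported 1.568, 6.086, and 7.042 for those rows.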
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||||
|
|
||||||
|
|
||||||
##### Training performance: NVIDIA DGX-1 (8x V100 16G)
|
##### Training performance: NVIDIA DGX-1 (8x V100 16G)
|
||||||
|
|
||||||
Our results were obtained by running the `python scripts/benchmark.py --mode train --gpus {1,8} --dim {2,3} --batch_size <bsize> [--amp]` training script in the PyTorch 20.12 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs. Performance numbers (in volumes per second) were averaged over an entire training epoch.
|
Our results were obtained by running the `python scripts/benchmark.py --mode train --gpus {1,8} --dim {2,3} --batch_size <bsize> [--amp]` training script in the PyTorch 21.02 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs. Performance numbers (in volumes per second) were averaged over an entire training epoch.
|
||||||
|
|
||||||
| Dimension | GPUs | Batch size / GPU | Throughput - mixed precision [img/s] | Throughput - FP32 [img/s] | Throughput speedup (FP32 - mixed precision) | Weak scaling - mixed precision | Weak scaling - FP32 |
|
| Dimension | GPUs | Batch size / GPU | Throughput - mixed precision [img/s] | Throughput - FP32 [img/s] | Throughput speedup (FP32 - mixed precision) | Weak scaling - mixed precision | Weak scaling - FP32 |
|
||||||
|:-:|:-:|:---:|:---------:|:-----------:|:--------:|:---------:|:-------------:|
|
|:-:|:-:|:---:|:---------:|:-----------:|:--------:|:---------:|:-------------:|
|
||||||
| 2 | 1 | 32 | 416.68 | 275.99 | 1.51 | N/A | N/A |
|
| 2 | 1 | 64 | 575.11 | 277.93 | 2.069 | N/A | N/A |
|
||||||
| 2 | 1 | 64 | 524.13 | 281.84 | 1.86 | N/A | N/A |
|
| 2 | 1 | 128 | 612.32 | 268.28 | 2.282 | N/A | N/A |
|
||||||
| 2 | 1 | 128| 557.48 | 272.68 | 2.04 | N/A | N/A |
|
| 2 | 8 | 64 | 4178.94 | 2149.46 | 1.944 | 7.266 | 7.734 |
|
||||||
| 2 | 8 | 32 | 2731.22 | 2005.49 | 1.36 | 6.56 | 7.27 |
|
| 2 | 8 | 128 | 4629.01 | 2087.25 | 2.218 | 7.56 | 7.78 |
|
||||||
| 2 | 8 | 64 | 3604.83 | 2088.58 | 1.73 | 6.88 | 7.41 |
|
| 3 | 1 | 1 | 7.68 | 2.11 | 3.64 | N/A | N/A |
|
||||||
| 2 | 8 | 128| 4202.35 | 2094.63 | 2.01 | 7.54 | 7.68 |
|
| 3 | 1 | 2 | 8.27 | 2.49 | 3.321 | N/A | N/A |
|
||||||
| 3 | 1 | 1 | 3.97 | 1.77 | 2.24 | N/A | N/A |
|
| 3 | 1 | 4 | 8.5 | OOM | N/A | N/A | N/A |
|
||||||
| 3 | 1 | 2 | 5.49 | 2.32 | 2.37 | N/A | N/A |
|
| 3 | 8 | 1 | 56.4 | 16.42 | 3.435 | 7.344 | 7.782 |
|
||||||
| 3 | 1 | 4 | 6.78 | OOM | N/A | N/A | N/A |
|
| 3 | 8 | 2 | 62.46 | 19.46 | 3.21 | 7.553 | 7.815 |
|
||||||
| 3 | 8 | 1 | 29.98 | 13.78 | 2.18 | 7.55 | 7.79 |
|
| 3 | 8 | 4 | 64.46 | OOM | N/A | 7.584 | N/A |
|
||||||
| 3 | 8 | 2 | 41.31 | 18.11 | 2.28 | 7.53 | 7.81 |
|
|
||||||
| 3 | 8 | 4 | 50.26 | OOM | N/A | 7.41 | N/A |
|
|
||||||
|
|
||||||
|
|
||||||
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||||
|
@ -628,31 +575,30 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
|
||||||
|
|
||||||
##### Inference performance: NVIDIA DGX A100 (1x A100 80G)
|
##### Inference performance: NVIDIA DGX A100 (1x A100 80G)
|
||||||
|
|
||||||
Our results were obtained by running the `python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> [--amp]` inferencing benchmarking script in the PyTorch 20.12 NGC container on NVIDIA DGX A100 (1x A100 80G) GPU.
|
Our results were obtained by running the `python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> [--amp]` inferencing benchmarking script in the PyTorch 21.02 NGC container on NVIDIA DGX A100 (1x A100 80G) GPU.
|
||||||
|
|
||||||
|
|
||||||
FP16
|
FP16
|
||||||
|
|
||||||
| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
|
| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
|
||||||
|:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
|
|:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
|
||||||
| 2 | 32 | 4x192x160 | 3281.91 | 9.75 | 9.88 | 10.14 | 10.17 |
|
| 2 | 64 | 4x192x160 | 3198.8 | 20.01 | 24.1 | 30.5 | 33.75 |
|
||||||
| 2 | 64 | 4x192x160 | 3625.3 | 17.65 | 18.13 | 18.16 | 18.24 |
|
| 2 | 128 | 4x192x160 | 3587.89 | 35.68 | 36.0 | 36.08 | 36.16 |
|
||||||
| 2 |128 | 4x192x160 | 3867.24 | 33.10 | 33.29 | 33.29 | 33.35 |
|
| 3 | 1 | 4x128x128x128 | 47.16 | 21.21 | 21.56 | 21.7 | 22.5 |
|
||||||
| 3 | 1 | 4x128x128x128 | 10.93| 91.52 | 91.30 | 92,68 | 111.87|
|
| 3 | 2 | 4x128x128x128 | 47.59 | 42.02 | 53.9 | 56.97 | 77.3 |
|
||||||
| 3 | 2 | 4x128x128x128 | 18.85| 106.08| 105.12| 106.05| 127.95|
|
| 3 | 4 | 4x128x128x128 | 53.98 | 74.1 | 91.18 | 106.13 | 143.18 |
|
||||||
| 3 | 4 | 4x128x128x128 | 27.4 | 145.98| 164.05| 165.58| 183.43|
|
|
||||||
|
|
||||||
|
|
||||||
TF32
|
TF32
|
||||||
|
|
||||||
| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
|
| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
|
||||||
|:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
|
|:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
|
||||||
| 2 | 32 | 4x192x160 | 2002.66 | 15.98 | 16.14 | 16.24 | 16.37|
|
| 2 | 64 | 4x192x160 | 2353.27 | 27.2 | 27.43 | 27.53 | 27.7 |
|
||||||
| 2 | 64 | 4x192x160 | 2180.54 | 29.35 | 29.50 | 29.51 | 29.59|
|
| 2 | 128 | 4x192x160 | 2492.78 | 51.35 | 51.54 | 51.59 | 51.73 |
|
||||||
| 2 |128 | 4x192x160 | 2289.12 | 55.92 | 56.08 | 56.13 | 56.36|
|
| 3 | 1 | 4x128x128x128 | 34.33 | 29.13 | 29.41 | 29.52 | 29.67 |
|
||||||
| 3 | 1 | 4x128x128x128 | 10.05| 99.55 | 99.17 | 99.82 |120.39|
|
| 3 | 2 | 4x128x128x128 | 37.29 | 53.63 | 52.41 | 60.12 | 84.92 |
|
||||||
| 3 | 2 | 4x128x128x128 | 16.29|122.78 |123.06 |124.02 |143.47|
|
| 3 | 4 | 4x128x128x128 | 22.98 | 174.09 | 173.02 | 196.04 | 231.03 |
|
||||||
| 3 | 4 | 4x128x128x128 | 15.99|250.16 |273.67 |274.85 |297.06|
|
|
||||||
|
|
||||||
Throughput is reported in images per second. Latency is reported in milliseconds per batch.
|
Throughput is reported in images per second. Latency is reported in milliseconds per batch.
|
||||||
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||||
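Because latency is reported per batch and throughput per image, the two columns are related by the batch size. The sketch below shows that relation as an approximate consistency check against the updated FP16 rows; the helper name `throughput_from_latency` is hypothetical, and small differences from the table are expected because the reported values are averages.

```python
def throughput_from_latency(batch_size, latency_ms):
    """Approximate images per second from the average per-batch latency in ms."""
    return batch_size / (latency_ms / 1000.0)


# 2D, batch size 64, average latency 20.01 ms (reported throughput: 3198.8 img/s).
estimate_2d = throughput_from_latency(64, 20.01)

# 3D, batch size 1, average latency 21.21 ms (reported throughput: 47.16 img/s).
estimate_3d = throughput_from_latency(1, 21.21)
```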
|
@ -660,29 +606,28 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
|
||||||
|
|
||||||
##### Inference performance: NVIDIA DGX-1 (1x V100 16G)
|
##### Inference performance: NVIDIA DGX-1 (1x V100 16G)
|
||||||
|
|
||||||
Our results were obtained by running the `python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> [--amp]` inferencing benchmarking script in the PyTorch 20.12 NGC container on NVIDIA DGX-1 with (1x V100 16G) GPU.
|
Our results were obtained by running the `python scripts/benchmark.py --mode predict --dim {2,3} --batch_size <bsize> [--amp]` inferencing benchmarking script in the PyTorch 21.02 NGC container on NVIDIA DGX-1 with (1x V100 16G) GPU.
|
||||||
|
|
||||||
FP16
|
FP16
|
||||||
|
|
||||||
| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
|
| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
|
||||||
|:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
|
|:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
|
||||||
| 2 | 32 | 4x192x160 | 1697.16 | 18.86 | 18.89 | 18.95 | 18.99 |
|
| 2 | 64 | 4x192x160 | 1866.52 | 34.29 | 34.7 | 48.87 | 52.44 |
|
||||||
| 2 | 64 | 4x192x160 | 2008.81 | 31.86 | 31.95 | 32.01 | 32.08 |
|
| 2 | 128 | 4x192x160 | 2032.74 | 62.97 | 63.21 | 63.25 | 63.32 |
|
||||||
| 2 |128 | 4x192x160 | 2221.44 | 57.62 | 57.83 | 57.88 | 57.96 |
|
| 3 | 1 | 4x128x128x128 | 27.52 | 36.33 | 37.03 | 37.25 | 37.71 |
|
||||||
| 3 | 1 | 4x128x128x128 | 11.01 | 90.76 | 89.96 | 90.53 | 116.67 |
|
| 3 | 2 | 4x128x128x128 | 29.04 | 68.87 | 68.09 | 76.48 | 112.4 |
|
||||||
| 3 | 2 | 4x128x128x128 | 16.60 | 120.49 | 119.69 | 120.72 | 146.42 |
|
| 3 | 4 | 4x128x128x128 | 30.23 | 132.33 | 131.59 | 165.57 | 191.64 |
|
||||||
| 3 | 4 | 4x128x128x128 | 21.18 | 188.85 | 211.92 | 214.17 | 238.19 |
|
|
||||||
|
|
||||||
FP32
|
FP32
|
||||||
|
|
||||||
| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
|
| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
|
||||||
|:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
|
|:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
|
||||||
| 2 | 32 | 4x192x160 | 1106.22 | 28.93 | 29.06 | 29.10 | 29.15 |
|
| 2 | 64 | 4x192x160 | 1051.46 | 60.87 | 61.21 | 61.48 | 62.87 |
|
||||||
| 2 | 64 | 4x192x160 | 1157.24 | 55.30 | 55.39 | 55.44 | 55.50 |
|
| 2 | 128 | 4x192x160 | 1051.68 | 121.71 | 122.29 | 122.44 | 122.6 |
|
||||||
| 2 |128 | 4x192x160 | 1171.24 | 109.29 | 109.83 | 109.98 | 110.58 |
|
| 3 | 1 | 4x128x128x128 | 9.87 | 101.34 | 102.33 | 102.52 | 102.86 |
|
||||||
| 3 | 1 | 4x128x128x128 | 6.8 | 147.10 | 147.51 | 148.15 | 170.46 |
|
| 3 | 2 | 4x128x128x128 | 9.91 | 201.91 | 202.36 | 202.77 | 240.45 |
|
||||||
| 3 | 2 | 4x128x128x128 | 8.53| 234.46 | 237.00 | 238.43 | 258.92 |
|
| 3 | 4 | 4x128x128x128 | 10.0 | 399.91 | 400.94 | 430.72 | 466.62 |
|
||||||
| 3 | 4 | 4x128x128x128 | 9.6 | 416.83 | 439.97 | 442.12 | 454.69 |
|
|
||||||
|
|
||||||
Throughput is reported in images per second. Latency is reported in milliseconds per batch.
|
Throughput is reported in images per second. Latency is reported in milliseconds per batch.
|
||||||
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||||
|
@ -694,6 +639,8 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
|
||||||
January 2021
|
January 2021
|
||||||
- Initial release
|
- Initial release
|
||||||
- Add notebook with custom dataset loading
|
- Add notebook with custom dataset loading
|
||||||
|
February 2021
|
||||||
|
- Change data format from tfrecord to npy and data loading for 2D
|
||||||
|
|
||||||
### Known issues
|
### Known issues
|
||||||
|
|
||||||
|
|
|
@ -19,45 +19,53 @@ import numpy as np
|
||||||
import nvidia.dali.fn as fn
|
import nvidia.dali.fn as fn
|
||||||
import nvidia.dali.math as math
|
import nvidia.dali.math as math
|
||||||
import nvidia.dali.ops as ops
|
import nvidia.dali.ops as ops
|
||||||
import nvidia.dali.tfrecord as tfrec
|
|
||||||
import nvidia.dali.types as types
|
import nvidia.dali.types as types
|
||||||
from nvidia.dali.pipeline import Pipeline
|
from nvidia.dali.pipeline import Pipeline
|
||||||
from nvidia.dali.plugin.pytorch import DALIGenericIterator
|
from nvidia.dali.plugin.pytorch import DALIGenericIterator
|
||||||
|
|
||||||
|
|
||||||
class TFRecordTrain(Pipeline):
|
def get_numpy_reader(files, shard_id, num_shards, seed, shuffle):
|
||||||
def __init__(self, batch_size, num_threads, device_id, **kwargs):
|
return ops.NumpyReader(
|
||||||
super(TFRecordTrain, self).__init__(batch_size, num_threads, device_id)
|
seed=seed,
|
||||||
self.dim = kwargs["dim"]
|
files=files,
|
||||||
self.seed = kwargs["seed"]
|
device="cpu",
|
||||||
self.oversampling = kwargs["oversampling"]
|
|
||||||
self.input = ops.TFRecordReader(
|
|
||||||
path=kwargs["tfrecords"],
|
|
||||||
index_path=kwargs["tfrecords_idx"],
|
|
||||||
features={
|
|
||||||
"X_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
|
|
||||||
"Y_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
|
|
||||||
"X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
|
|
||||||
"Y": tfrec.FixedLenFeature([], tfrec.string, ""),
|
|
||||||
"fname": tfrec.FixedLenFeature([], tfrec.string, ""),
|
|
||||||
},
|
|
||||||
num_shards=kwargs["gpus"],
|
|
||||||
shard_id=device_id,
|
|
||||||
random_shuffle=True,
|
|
||||||
pad_last_batch=True,
|
|
||||||
read_ahead=True,
|
read_ahead=True,
|
||||||
seed=self.seed,
|
shard_id=shard_id,
|
||||||
|
pad_last_batch=True,
|
||||||
|
num_shards=num_shards,
|
||||||
|
dont_use_mmap=True,
|
||||||
|
shuffle_after_epoch=shuffle,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainPipeline(Pipeline):
|
||||||
|
def __init__(self, batch_size, num_threads, device_id, **kwargs):
|
||||||
|
super(TrainPipeline, self).__init__(batch_size, num_threads, device_id)
|
||||||
|
self.dim = kwargs["dim"]
|
||||||
|
self.oversampling = kwargs["oversampling"]
|
||||||
|
self.input_x = get_numpy_reader(
|
||||||
|
num_shards=kwargs["gpus"],
|
||||||
|
files=kwargs["imgs"],
|
||||||
|
seed=kwargs["seed"],
|
||||||
|
shard_id=device_id,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
self.input_y = get_numpy_reader(
|
||||||
|
num_shards=kwargs["gpus"],
|
||||||
|
files=kwargs["lbls"],
|
||||||
|
seed=kwargs["seed"],
|
||||||
|
shard_id=device_id,
|
||||||
|
shuffle=True,
|
||||||
)
|
)
|
||||||
self.patch_size = kwargs["patch_size"]
|
self.patch_size = kwargs["patch_size"]
|
||||||
|
if self.dim == 2:
|
||||||
|
self.patch_size = [kwargs["batch_size_2d"]] + self.patch_size
|
||||||
self.crop_shape = types.Constant(np.array(self.patch_size), dtype=types.INT64)
|
self.crop_shape = types.Constant(np.array(self.patch_size), dtype=types.INT64)
|
||||||
self.crop_shape_float = types.Constant(np.array(self.patch_size), dtype=types.FLOAT)
|
self.crop_shape_float = types.Constant(np.array(self.patch_size), dtype=types.FLOAT)
|
||||||
self.layout = "CDHW" if self.dim == 3 else "CHW"
|
|
||||||
self.axis_name = "DHW" if self.dim == 3 else "HW"
|
|
||||||
|
|
||||||
def load_data(self, features):
|
def load_data(self):
|
||||||
img = fn.reshape(features["X"], shape=features["X_shape"], layout=self.layout)
|
img, lbl = self.input_x(name="ReaderX"), self.input_y(name="ReaderY")
|
||||||
lbl = fn.reshape(features["Y"], shape=features["Y_shape"], layout=self.layout)
|
img, lbl = fn.reshape(img, layout="CDHW"), fn.reshape(lbl, layout="CDHW")
|
||||||
lbl = fn.reinterpret(lbl, dtype=types.DALIDataType.UINT8)
|
|
||||||
return img, lbl
|
return img, lbl
|
||||||
|
|
||||||
def random_augmentation(self, probability, augmented, original):
|
def random_augmentation(self, probability, augmented, original):
|
||||||
|
@ -66,17 +74,17 @@ class TFRecordTrain(Pipeline):
|
||||||
return condition * augmented + neg_condition * original
|
return condition * augmented + neg_condition * original
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def slice_fn(img, start_idx, length):
|
def slice_fn(img):
|
||||||
return fn.slice(img, start_idx, length, axes=[0])
|
return fn.slice(img, 1, 3, axes=[0])
|
||||||
|
|
||||||
def crop_fn(self, img, lbl):
|
def crop_fn(self, img, lbl):
|
||||||
center = fn.segmentation.random_mask_pixel(lbl, foreground=fn.coin_flip(probability=self.oversampling))
|
center = fn.segmentation.random_mask_pixel(lbl, foreground=fn.coin_flip(probability=self.oversampling))
|
||||||
crop_anchor = self.slice_fn(center, 1, self.dim) - self.crop_shape // 2
|
crop_anchor = self.slice_fn(center) - self.crop_shape // 2
|
||||||
adjusted_anchor = math.max(0, crop_anchor)
|
adjusted_anchor = math.max(0, crop_anchor)
|
||||||
max_anchor = self.slice_fn(fn.shapes(lbl), 1, self.dim) - self.crop_shape
|
max_anchor = self.slice_fn(fn.shapes(lbl)) - self.crop_shape
|
||||||
crop_anchor = math.min(adjusted_anchor, max_anchor)
|
crop_anchor = math.min(adjusted_anchor, max_anchor)
|
||||||
img = fn.slice(img.gpu(), crop_anchor, self.crop_shape, axis_names=self.axis_name, out_of_bounds_policy="pad")
|
img = fn.slice(img.gpu(), crop_anchor, self.crop_shape, axis_names="DHW", out_of_bounds_policy="pad")
|
||||||
lbl = fn.slice(lbl.gpu(), crop_anchor, self.crop_shape, axis_names=self.axis_name, out_of_bounds_policy="pad")
|
lbl = fn.slice(lbl.gpu(), crop_anchor, self.crop_shape, axis_names="DHW", out_of_bounds_policy="pad")
|
||||||
return img, lbl
|
return img, lbl
|
||||||
|
|
||||||
def zoom_fn(self, img, lbl):
|
def zoom_fn(self, img, lbl):
|
||||||
|
@ -87,7 +95,7 @@ class TFRecordTrain(Pipeline):
|
||||||
return img, lbl
|
return img, lbl
|
||||||
|
|
||||||
def noise_fn(self, img):
|
def noise_fn(self, img):
|
||||||
img_noised = img + fn.normal_distribution(img, stddev=fn.uniform(range=(0.0, 0.33)))
|
img_noised = img + fn.random.normal(img, stddev=fn.uniform(range=(0.0, 0.33)))
|
||||||
return self.random_augmentation(0.15, img_noised, img)
|
return self.random_augmentation(0.15, img_noised, img)
|
||||||
|
|
||||||
def blur_fn(self, img):
|
def blur_fn(self, img):
|
||||||
|
@ -110,111 +118,113 @@ class TFRecordTrain(Pipeline):
|
||||||
kwargs.update({"depthwise": fn.coin_flip(probability=0.33)})
|
kwargs.update({"depthwise": fn.coin_flip(probability=0.33)})
|
||||||
return fn.flip(img, **kwargs), fn.flip(lbl, **kwargs)
|
return fn.flip(img, **kwargs), fn.flip(lbl, **kwargs)
|
||||||
|
|
||||||
|
def transpose_fn(self, img, lbl):
|
||||||
|
img, lbl = fn.transpose(img, perm=(1, 0, 2, 3)), fn.transpose(lbl, perm=(1, 0, 2, 3))
|
||||||
|
return img, lbl
|
||||||
|
|
||||||
def define_graph(self):
|
def define_graph(self):
|
||||||
features = self.input(name="Reader")
|
img, lbl = self.load_data()
|
||||||
img, lbl = self.load_data(features)
|
|
||||||
img, lbl = self.crop_fn(img, lbl)
|
img, lbl = self.crop_fn(img, lbl)
|
||||||
img, lbl = self.zoom_fn(img, lbl)
|
img, lbl = self.zoom_fn(img, lbl)
|
||||||
|
img, lbl = self.flips_fn(img, lbl)
|
||||||
img = self.noise_fn(img)
|
img = self.noise_fn(img)
|
||||||
img = self.blur_fn(img)
|
img = self.blur_fn(img)
|
||||||
img = self.brightness_fn(img)
|
img = self.brightness_fn(img)
|
||||||
img = self.contrast_fn(img)
|
img = self.contrast_fn(img)
|
||||||
img, lbl = self.flips_fn(img, lbl)
|
if self.dim == 2:
|
||||||
|
img, lbl = self.transpose_fn(img, lbl)
|
||||||
return img, lbl
|
return img, lbl
|
||||||
|
|
||||||
|
|
||||||
class TFRecordEval(Pipeline):
|
class EvalPipeline(Pipeline):
|
||||||
def __init__(self, batch_size, num_threads, device_id, **kwargs):
|
def __init__(self, batch_size, num_threads, device_id, **kwargs):
|
||||||
super(TFRecordEval, self).__init__(batch_size, num_threads, device_id)
|
super(EvalPipeline, self).__init__(batch_size, num_threads, device_id)
|
||||||
self.input = ops.TFRecordReader(
|
self.input_x = get_numpy_reader(
|
||||||
path=kwargs["tfrecords"],
|
files=kwargs["imgs"],
|
||||||
index_path=kwargs["tfrecords_idx"],
|
|
||||||
features={
|
|
||||||
"X_shape": tfrec.FixedLenFeature([4], tfrec.int64, 0),
|
|
||||||
"Y_shape": tfrec.FixedLenFeature([4], tfrec.int64, 0),
|
|
||||||
"X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
|
|
||||||
"Y": tfrec.FixedLenFeature([], tfrec.string, ""),
|
|
||||||
"fname": tfrec.FixedLenFeature([], tfrec.string, ""),
|
|
||||||
},
|
|
||||||
shard_id=device_id,
|
shard_id=device_id,
|
||||||
num_shards=kwargs["gpus"],
|
num_shards=kwargs["gpus"],
|
||||||
read_ahead=True,
|
seed=kwargs["seed"],
|
||||||
random_shuffle=False,
|
shuffle=False,
|
||||||
pad_last_batch=True,
|
)
|
||||||
|
self.input_y = get_numpy_reader(
|
||||||
|
files=kwargs["lbls"],
|
||||||
|
shard_id=device_id,
|
||||||
|
num_shards=kwargs["gpus"],
|
||||||
|
seed=kwargs["seed"],
|
||||||
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_data(self, features):
|
def define_graph(self):
|
||||||
img = fn.reshape(features["X"].gpu(), shape=features["X_shape"], layout="CDHW")
|
img, lbl = self.input_x(name="ReaderX").gpu(), self.input_y(name="ReaderY").gpu()
|
||||||
lbl = fn.reshape(features["Y"].gpu(), shape=features["Y_shape"], layout="CDHW")
|
img, lbl = fn.reshape(img, layout="CDHW"), fn.reshape(lbl, layout="CDHW")
|
||||||
lbl = fn.reinterpret(lbl, dtype=types.DALIDataType.UINT8)
|
|
||||||
return img, lbl
|
return img, lbl
|
||||||
|
|
||||||
def define_graph(self):
|
|
||||||
features = self.input(name="Reader")
|
|
||||||
img, lbl = self.load_data(features)
|
|
||||||
return img, lbl, features["fname"]
|
|
||||||
|
|
||||||
|
class TestPipeline(Pipeline):
|
||||||
class TFRecordTest(Pipeline):
|
|
||||||
def __init__(self, batch_size, num_threads, device_id, **kwargs):
|
def __init__(self, batch_size, num_threads, device_id, **kwargs):
|
||||||
super(TFRecordTest, self).__init__(batch_size, num_threads, device_id)
|
super(TestPipeline, self).__init__(batch_size, num_threads, device_id)
|
||||||
self.input = ops.TFRecordReader(
|
self.input_x = get_numpy_reader(
|
||||||
path=kwargs["tfrecords"],
|
files=kwargs["imgs"],
|
||||||
index_path=kwargs["tfrecords_idx"],
|
|
||||||
features={
|
|
||||||
"X_shape": tfrec.FixedLenFeature([4], tfrec.int64, 0),
|
|
||||||
"X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
|
|
||||||
"fname": tfrec.FixedLenFeature([], tfrec.string, ""),
|
|
||||||
},
|
|
||||||
shard_id=device_id,
|
shard_id=device_id,
|
||||||
num_shards=kwargs["gpus"],
|
num_shards=kwargs["gpus"],
|
||||||
read_ahead=True,
|
seed=kwargs["seed"],
|
||||||
random_shuffle=False,
|
shuffle=False,
|
||||||
pad_last_batch=True,
|
)
|
||||||
|
self.input_meta = get_numpy_reader(
|
||||||
|
files=kwargs["meta"],
|
||||||
|
shard_id=device_id,
|
||||||
|
num_shards=kwargs["gpus"],
|
||||||
|
seed=kwargs["seed"],
|
||||||
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def define_graph(self):
|
def define_graph(self):
|
||||||
features = self.input(name="Reader")
|
img, meta = self.input_x(name="ReaderX").gpu(), self.input_meta(name="ReaderY").gpu()
|
||||||
img = fn.reshape(features["X"].gpu(), shape=features["X_shape"], layout="CDHW")
|
img = fn.reshape(img, layout="CDHW")
|
||||||
return img, features["fname"]
|
return img, meta
|
||||||
|
|
||||||
|
|
||||||
class TFRecordBenchmark(Pipeline):
|
class BenchmarkPipeline(Pipeline):
|
||||||
def __init__(self, batch_size, num_threads, device_id, **kwargs):
|
def __init__(self, batch_size, num_threads, device_id, **kwargs):
|
||||||
super(TFRecordBenchmark, self).__init__(batch_size, num_threads, device_id)
|
super(BenchmarkPipeline, self).__init__(batch_size, num_threads, device_id)
|
||||||
|
self.input_x = get_numpy_reader(
|
||||||
|
files=kwargs["imgs"],
|
||||||
|
shard_id=device_id,
|
||||||
|
seed=kwargs["seed"],
|
||||||
|
num_shards=kwargs["gpus"],
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
self.input_y = get_numpy_reader(
|
||||||
|
files=kwargs["lbls"],
|
||||||
|
shard_id=device_id,
|
||||||
|
num_shards=kwargs["gpus"],
|
||||||
|
seed=kwargs["seed"],
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
self.dim = kwargs["dim"]
|
self.dim = kwargs["dim"]
|
||||||
self.input = ops.TFRecordReader(
|
|
||||||
path=kwargs["tfrecords"],
|
|
||||||
index_path=kwargs["tfrecords_idx"],
|
|
||||||
features={
|
|
||||||
"X_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
|
|
||||||
"Y_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
|
|
||||||
"X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
|
|
||||||
"Y": tfrec.FixedLenFeature([], tfrec.string, ""),
|
|
||||||
"fname": tfrec.FixedLenFeature([], tfrec.string, ""),
|
|
||||||
},
|
|
||||||
shard_id=device_id,
|
|
||||||
num_shards=kwargs["gpus"],
|
|
||||||
read_ahead=True,
|
|
||||||
)
|
|
||||||
self.patch_size = kwargs["patch_size"]
|
self.patch_size = kwargs["patch_size"]
|
||||||
self.layout = "CDHW" if self.dim == 3 else "CHW"
|
if self.dim == 2:
|
||||||
|
self.patch_size = [kwargs["batch_size_2d"]] + self.patch_size
|
||||||
|
|
||||||
def load_data(self, features):
|
def load_data(self):
|
||||||
img = fn.reshape(features["X"].gpu(), shape=features["X_shape"], layout=self.layout)
|
img, lbl = self.input_x(name="ReaderX").gpu(), self.input_y(name="ReaderY").gpu()
|
||||||
lbl = fn.reshape(features["Y"].gpu(), shape=features["Y_shape"], layout=self.layout)
|
img, lbl = fn.reshape(img, layout="CDHW"), fn.reshape(lbl, layout="CDHW")
|
||||||
lbl = fn.reinterpret(lbl, dtype=types.DALIDataType.UINT8)
|
return img, lbl
|
||||||
|
|
||||||
|
def transpose_fn(self, img, lbl):
|
||||||
|
img, lbl = fn.transpose(img, perm=(1, 0, 2, 3)), fn.transpose(lbl, perm=(1, 0, 2, 3))
|
||||||
return img, lbl
|
return img, lbl
|
||||||
|
|
||||||
def crop_fn(self, img, lbl):
|
def crop_fn(self, img, lbl):
|
||||||
img = fn.crop(img, crop=self.patch_size)
|
img = fn.crop(img, crop=self.patch_size, out_of_bounds_policy="pad")
|
||||||
lbl = fn.crop(lbl, crop=self.patch_size)
|
lbl = fn.crop(lbl, crop=self.patch_size, out_of_bounds_policy="pad")
|
||||||
return img, lbl
|
return img, lbl
|
||||||
|
|
||||||
def define_graph(self):
|
def define_graph(self):
|
||||||
features = self.input(name="Reader")
|
img, lbl = self.load_data()
|
||||||
img, lbl = self.load_data(features)
|
|
||||||
img, lbl = self.crop_fn(img, lbl)
|
img, lbl = self.crop_fn(img, lbl)
|
||||||
|
if self.dim == 2:
|
||||||
|
img, lbl = self.transpose_fn(img, lbl)
|
||||||
return img, lbl
|
return img, lbl
|
||||||
|
|
||||||
|
|
||||||
|
@ -228,39 +238,55 @@ class LightningWrapper(DALIGenericIterator):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def fetch_dali_loader(tfrecords, idx_files, batch_size, mode, **kwargs):
|
def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
|
||||||
assert len(tfrecords) > 0, "Got empty tfrecord list"
|
assert len(imgs) > 0, "Got empty list of images"
|
||||||
assert len(idx_files) == len(tfrecords), f"Got {len(idx_files)} index files but {len(tfrecords)} tfrecords"
|
if lbls is not None:
|
||||||
|
assert len(imgs) == len(lbls), f"Got {len(imgs)} images but {len(lbls)} labels"
|
||||||
|
|
||||||
if kwargs["benchmark"]:
|
if kwargs["benchmark"]:  # Make sure the number of examples is large enough for a benchmark run.
|
||||||
tfrecords = list(itertools.chain(*(20 * [tfrecords])))
|
nbs = kwargs["test_batches"] if mode == "test" else kwargs["train_batches"]
|
||||||
idx_files = list(itertools.chain(*(20 * [idx_files])))
|
if kwargs["dim"] == 3:
|
||||||
|
nbs *= batch_size
|
||||||
|
imgs = list(itertools.chain(*(100 * [imgs])))[: nbs * kwargs["gpus"]]
|
||||||
|
lbls = list(itertools.chain(*(100 * [lbls])))[: nbs * kwargs["gpus"]]
|
||||||
|
if mode == "eval":
|
||||||
|
remainder = len(imgs) % kwargs["gpus"]
|
||||||
|
if remainder != 0:
|
||||||
|
imgs = imgs[:-remainder]
|
||||||
|
lbls = lbls[:-remainder]
|
||||||
|
|
||||||
pipe_kwargs = {
|
pipe_kwargs = {
|
||||||
"tfrecords": tfrecords,
|
"imgs": imgs,
|
||||||
"tfrecords_idx": idx_files,
|
"lbls": lbls,
|
||||||
|
"dim": kwargs["dim"],
|
||||||
"gpus": kwargs["gpus"],
|
"gpus": kwargs["gpus"],
|
||||||
"seed": kwargs["seed"],
|
"seed": kwargs["seed"],
|
||||||
|
"meta": kwargs["meta"],
|
||||||
"patch_size": kwargs["patch_size"],
|
"patch_size": kwargs["patch_size"],
|
||||||
"dim": kwargs["dim"],
|
|
||||||
"oversampling": kwargs["oversampling"],
|
"oversampling": kwargs["oversampling"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if kwargs["benchmark"] and mode == "eval":
|
if kwargs["benchmark"]:
|
||||||
pipeline = TFRecordBenchmark
|
pipeline = BenchmarkPipeline
|
||||||
output_map = ["image", "label"]
|
output_map = ["image", "label"]
|
||||||
dynamic_shape = False
|
dynamic_shape = False
|
||||||
elif mode == "training":
|
if kwargs["dim"] == 2:
|
||||||
pipeline = TFRecordTrain
|
pipe_kwargs.update({"batch_size_2d": batch_size})
|
||||||
|
batch_size = 1
|
||||||
|
elif mode == "train":
|
||||||
|
pipeline = TrainPipeline
|
||||||
output_map = ["image", "label"]
|
output_map = ["image", "label"]
|
||||||
dynamic_shape = False
|
dynamic_shape = False
|
||||||
|
if kwargs["dim"] == 2:
|
||||||
|
pipe_kwargs.update({"batch_size_2d": batch_size // kwargs["nvol"]})
|
||||||
|
batch_size = kwargs["nvol"]
|
||||||
elif mode == "eval":
|
elif mode == "eval":
|
||||||
pipeline = TFRecordEval
|
pipeline = EvalPipeline
|
||||||
output_map = ["image", "label", "fname"]
|
output_map = ["image", "label"]
|
||||||
dynamic_shape = True
|
dynamic_shape = True
|
||||||
else:
|
else:
|
||||||
pipeline = TFRecordTest
|
pipeline = TestPipeline
|
||||||
output_map = ["image", "fname"]
|
output_map = ["image", "meta"]
|
||||||
dynamic_shape = True
|
dynamic_shape = True
|
||||||
|
|
||||||
device_id = int(os.getenv("LOCAL_RANK", "0"))
|
device_id = int(os.getenv("LOCAL_RANK", "0"))
|
||||||
|
@ -268,7 +294,7 @@ def fetch_dali_loader(tfrecords, idx_files, batch_size, mode, **kwargs):
|
||||||
return LightningWrapper(
|
return LightningWrapper(
|
||||||
pipe,
|
pipe,
|
||||||
auto_reset=True,
|
auto_reset=True,
|
||||||
reader_name="Reader",
|
reader_name="ReaderX",
|
||||||
output_map=output_map,
|
output_map=output_map,
|
||||||
dynamic_shape=dynamic_shape,
|
dynamic_shape=dynamic_shape,
|
||||||
)
|
)
|
||||||
|
|
|
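The list handling at the top of the new `fetch_dali_loader` can be read in isolation. The following is a minimal sketch rather than the repository's code: the helper names `extend_for_benchmark` and `trim_to_multiple` are hypothetical, and the logic only mirrors the benchmark and `eval` branches shown in the diff above.

```python
import itertools


def extend_for_benchmark(files, num_batches, batch_size, gpus, dim, repeats=100):
    """Repeat a file list and cut it to the number of examples a benchmark run needs.

    Hypothetical helper mirroring the benchmark branch of the diff.
    """
    # The diff scales the batch count by the batch size only when dim == 3.
    if dim == 3:
        num_batches *= batch_size
    # Repeat the list `repeats` times, then keep one slice per GPU.
    repeated = list(itertools.chain(*(repeats * [files])))
    return repeated[: num_batches * gpus]


def trim_to_multiple(files, gpus):
    """Drop trailing files so the list length is a multiple of the GPU count."""
    remainder = len(files) % gpus
    return files[:-remainder] if remainder != 0 else files


if __name__ == "__main__":
    volumes = ["case_00_x.npy", "case_01_x.npy", "case_02_x.npy"]
    print(extend_for_benchmark(volumes, num_batches=2, batch_size=2, gpus=2, dim=3))
    print(trim_to_multiple(list(range(7)), gpus=2))
```

Guarding the slice with `remainder != 0` matters because `files[:-0]` would return an empty list instead of the full one.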
@ -12,16 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
from subprocess import call
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from joblib import Parallel, delayed
|
|
||||||
from pytorch_lightning import LightningDataModule
|
from pytorch_lightning import LightningDataModule
|
||||||
from sklearn.model_selection import KFold
|
from sklearn.model_selection import KFold
|
||||||
from tqdm import tqdm
|
from utils.utils import get_config_file, get_path, get_split, get_test_fnames, is_main_process, load_data
|
||||||
from utils.utils import get_config_file, get_task_code, is_main_process, make_empty_dir
|
|
||||||
|
|
||||||
from data_loading.dali_loader import fetch_dali_loader
|
from data_loading.dali_loader import fetch_dali_loader
|
||||||
|
|
||||||
|
@ -30,16 +23,13 @@ class DataModule(LightningDataModule):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
self.tfrecords_train = []
|
self.train_imgs = []
|
||||||
self.tfrecords_val = []
|
self.train_lbls = []
|
||||||
self.tfrecords_test = []
|
self.val_imgs = []
|
||||||
self.train_idx = []
|
self.val_lbls = []
|
||||||
self.val_idx = []
|
self.test_imgs = []
|
||||||
self.test_idx = []
|
|
||||||
self.kfold = KFold(n_splits=self.args.nfolds, shuffle=True, random_state=12345)
|
self.kfold = KFold(n_splits=self.args.nfolds, shuffle=True, random_state=12345)
|
||||||
self.data_path = os.path.join(self.args.data, get_task_code(self.args))
|
self.data_path = get_path(args)
|
||||||
if self.args.exec_mode == "predict" and not args.benchmark:
|
|
||||||
self.data_path = os.path.join(self.data_path, "test")
|
|
||||||
configs = get_config_file(self.args)
|
configs = get_config_file(self.args)
|
||||||
self.kwargs = {
|
self.kwargs = {
|
||||||
"dim": self.args.dim,
|
"dim": self.args.dim,
|
||||||
|
@ -48,90 +38,37 @@ class DataModule(LightningDataModule):
|
||||||
"gpus": self.args.gpus,
|
"gpus": self.args.gpus,
|
||||||
"num_workers": self.args.num_workers,
|
"num_workers": self.args.num_workers,
|
||||||
"oversampling": self.args.oversampling,
|
"oversampling": self.args.oversampling,
|
||||||
"create_idx": self.args.create_idx,
|
|
||||||
"benchmark": self.args.benchmark,
|
"benchmark": self.args.benchmark,
|
||||||
|
"nvol": self.args.nvol,
|
||||||
|
"train_batches": self.args.train_batches,
|
||||||
|
"test_batches": self.args.test_batches,
|
||||||
|
"meta": load_data(self.data_path, "*_meta.npy"),
|
||||||
}
|
}
|
||||||
|
|
||||||
def prepare_data(self):
|
|
||||||
if self.args.create_idx:
|
|
||||||
tfrecords_train, tfrecords_val, tfrecords_test = self.load_tfrecords()
|
|
||||||
make_empty_dir("train_idx")
|
|
||||||
make_empty_dir("val_idx")
|
|
||||||
make_empty_dir("test_idx")
|
|
||||||
self.create_idx("train_idx", tfrecords_train)
|
|
||||||
self.create_idx("val_idx", tfrecords_val)
|
|
||||||
self.create_idx("test_idx", tfrecords_test)
|
|
||||||
|
|
||||||
def setup(self, stage=None):
|
def setup(self, stage=None):
|
||||||
self.tfrecords_train, self.tfrecords_val, self.tfrecords_test = self.load_tfrecords()
|
imgs = load_data(self.data_path, "*_x.npy")
|
||||||
self.train_idx, self.val_idx, self.test_idx = self.load_idx_files()
|
lbls = load_data(self.data_path, "*_y.npy")
|
||||||
|
|
||||||
|
self.test_imgs, self.kwargs["meta"] = get_test_fnames(self.args, self.data_path, self.kwargs["meta"])
|
||||||
|
if self.args.exec_mode != "predict" or self.args.benchmark:
|
||||||
|
train_idx, val_idx = list(self.kfold.split(imgs))[self.args.fold]
|
||||||
|
self.train_imgs = get_split(imgs, train_idx)
|
||||||
|
self.train_lbls = get_split(lbls, train_idx)
|
||||||
|
self.val_imgs = get_split(imgs, val_idx)
|
||||||
|
self.val_lbls = get_split(lbls, val_idx)
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
ntrain, nval, ntest = len(self.tfrecords_train), len(self.tfrecords_val), len(self.tfrecords_test)
|
ntrain, nval = len(self.train_imgs), len(self.val_imgs)
|
||||||
print(f"Number of examples: Train {ntrain} - Val {nval} - Test {ntest}")
|
print(f"Number of examples: Train {ntrain} - Val {nval}")
|
||||||
|
elif is_main_process():
|
||||||
|
print(f"Number of test examples: {len(self.test_imgs)}")
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return fetch_dali_loader(self.tfrecords_train, self.train_idx, self.args.batch_size, "training", **self.kwargs)
|
return fetch_dali_loader(self.train_imgs, self.train_lbls, self.args.batch_size, "train", **self.kwargs)
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
return fetch_dali_loader(self.tfrecords_val, self.val_idx, 1, "eval", **self.kwargs)
|
return fetch_dali_loader(self.val_imgs, self.val_lbls, 1, "eval", **self.kwargs)
|
||||||
|
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
if self.kwargs["benchmark"]:
|
if self.kwargs["benchmark"]:
|
||||||
return fetch_dali_loader(
|
return fetch_dali_loader(self.train_imgs, self.train_lbls, self.args.val_batch_size, "test", **self.kwargs)
|
||||||
self.tfrecords_train, self.train_idx, self.args.val_batch_size, "eval", **self.kwargs
|
return fetch_dali_loader(self.test_imgs, None, 1, "test", **self.kwargs)
|
||||||
)
|
|
||||||
return fetch_dali_loader(self.tfrecords_test, self.test_idx, 1, "test", **self.kwargs)
|
|
||||||
|
|
||||||
def load_tfrecords(self):
|
|
||||||
if self.args.dim == 2:
|
|
||||||
train_tfrecords = self.load_data(self.data_path, "*.train_tfrecord")
|
|
||||||
val_tfrecords = self.load_data(self.data_path, "*.val_tfrecord")
|
|
||||||
else:
|
|
||||||
train_tfrecords = self.load_data(self.data_path, "*.tfrecord")
|
|
||||||
val_tfrecords = self.load_data(self.data_path, "*.tfrecord")
|
|
||||||
|
|
||||||
train_idx, val_idx = list(self.kfold.split(train_tfrecords))[self.args.fold]
|
|
||||||
train_tfrecords = self.get_split(train_tfrecords, train_idx)
|
|
||||||
val_tfrecords = self.get_split(val_tfrecords, val_idx)
|
|
||||||
|
|
||||||
return train_tfrecords, val_tfrecords, self.load_data(os.path.join(self.data_path, "test"), "*.tfrecord")
|
|
||||||
|
|
||||||
def load_idx_files(self):
|
|
||||||
if self.args.create_idx:
|
|
||||||
test_idx = sorted(glob.glob(os.path.join("test_idx", "*.idx")))
|
|
||||||
else:
|
|
||||||
test_idx = self.get_idx_list("test/idx_files", self.tfrecords_test)
|
|
||||||
|
|
||||||
if self.args.create_idx:
|
|
||||||
train_idx = sorted(glob.glob(os.path.join("train_idx", "*.idx")))
|
|
||||||
val_idx = sorted(glob.glob(os.path.join("val_idx", "*.idx")))
|
|
||||||
elif self.args.dim == 3:
|
|
||||||
train_idx = self.get_idx_list("idx_files", self.tfrecords_train)
|
|
||||||
val_idx = self.get_idx_list("idx_files", self.tfrecords_val)
|
|
||||||
else:
|
|
||||||
train_idx = self.get_idx_list("train_idx_files", self.tfrecords_train)
|
|
||||||
val_idx = self.get_idx_list("val_idx_files", self.tfrecords_val)
|
|
||||||
return train_idx, val_idx, test_idx
|
|
||||||
|
|
||||||
def create_idx(self, idx_dir, tfrecords):
|
|
||||||
idx_files = [os.path.join(idx_dir, os.path.basename(tfrec).split(".")[0] + ".idx") for tfrec in tfrecords]
|
|
||||||
Parallel(n_jobs=-1)(
|
|
||||||
delayed(self.tfrecord2idx)(tfrec, tfidx)
|
|
||||||
for tfrec, tfidx in tqdm(zip(tfrecords, idx_files), total=len(tfrecords))
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_idx_list(self, dir_name, tfrecords):
|
|
||||||
root_dir = os.path.join(self.data_path, dir_name)
|
|
||||||
return sorted([os.path.join(root_dir, os.path.basename(tfr).split(".")[0] + ".idx") for tfr in tfrecords])
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_split(data, idx):
|
|
||||||
return list(np.array(data)[idx])
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_data(path, files_pattern):
|
|
||||||
return sorted(glob.glob(os.path.join(path, files_pattern)))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tfrecord2idx(tfrecord, tfidx):
|
|
||||||
call(["tfrecord2idx", tfrecord, tfidx])
|
|
||||||
|
|
|
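The new `setup` applies the same fold indices to two separately globbed lists (`*_x.npy` and `*_y.npy`), so it relies on both sorted lists being in the same case order. The sketch below illustrates that assumption with plain lists. `check_pairs` is a hypothetical helper that does not exist in the repository, `get_split` is a list-based stand-in for the NumPy indexing used there, and the fold indices are hard-coded where the real module obtains them from `KFold.split`.

```python
import os


def get_split(data, idx):
    """Pick the entries of `data` at positions `idx` (stand-in for NumPy indexing)."""
    return [data[i] for i in idx]


def check_pairs(imgs, lbls):
    """Raise if the i-th image and the i-th label do not belong to the same case."""
    if len(imgs) != len(lbls):
        raise ValueError(f"Got {len(imgs)} images but {len(lbls)} labels")
    for img, lbl in zip(imgs, lbls):
        img_case = os.path.basename(img)[: -len("_x.npy")]
        lbl_case = os.path.basename(lbl)[: -len("_y.npy")]
        if img_case != lbl_case:
            raise ValueError(f"Mismatched pair: {img} and {lbl}")


# Sorting both lists puts matching cases at matching positions.
imgs = sorted(["data/case_02_x.npy", "data/case_00_x.npy", "data/case_01_x.npy"])
lbls = sorted(["data/case_01_y.npy", "data/case_02_y.npy", "data/case_00_y.npy"])
check_pairs(imgs, lbls)

# Illustrative fold; in the data module these indices come from KFold.split(imgs).
train_idx, val_idx = [0, 2], [1]
train_imgs, train_lbls = get_split(imgs, train_idx), get_split(lbls, train_idx)
val_imgs, val_lbls = get_split(imgs, val_idx), get_split(lbls, val_idx)
```

Because the split indexes both lists with the same positions, an image and its label always land in the same fold as long as the pairing check passes.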
@ -1,127 +0,0 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from glob import glob
|
|
||||||
from subprocess import call
|
|
||||||
|
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
|
||||||
from joblib import Parallel, delayed
|
|
||||||
from tqdm import tqdm
|
|
||||||
from utils.utils import get_task_code, make_empty_dir
|
|
||||||
|
|
||||||
|
|
||||||
class Converter:
|
|
||||||
def __init__(self, args):
|
|
||||||
self.args = args
|
|
||||||
self.mode = self.args.exec_mode
|
|
||||||
task_code = get_task_code(self.args)
|
|
||||||
self.data = os.path.join(self.args.data, task_code)
|
|
||||||
self.results = os.path.join(self.args.results, task_code)
|
|
||||||
if self.mode == "test":
|
|
||||||
self.data = os.path.join(self.data, "test")
|
|
||||||
self.results = os.path.join(self.results, "test")
|
|
||||||
self.vpf = self.args.vpf
|
|
||||||
|
|
||||||
self.imgs = self.load_files("*x.npy")
|
|
||||||
self.lbls = self.load_files("*y.npy")
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
print("Saving tfrecords...")
|
|
||||||
suffix = "tfrecord" if self.args.dim == 3 else "val_tfrecord"
|
|
||||||
self.save_tfrecords(self.imgs, self.lbls, dim=3, suffix=suffix)
|
|
||||||
if self.args.dim == 2:
|
|
||||||
self.save_tfrecords(self.imgs, self.lbls, dim=2, suffix="train_tfrecord")
|
|
||||||
train_tfrecords, train_idx_dir = self.get_tfrecords_data("*.train_tfrecord", "train_idx_files")
|
|
||||||
val_tfrecords, val_idx_dir = self.get_tfrecords_data("*.val_tfrecord", "val_idx_files")
|
|
||||||
print("Saving idx files...")
|
|
||||||
self.create_idx_files(train_tfrecords, train_idx_dir)
|
|
||||||
self.create_idx_files(val_tfrecords, val_idx_dir)
|
|
||||||
else:
|
|
||||||
tfrecords, idx_dir = self.get_tfrecords_data("*.tfrecord", "idx_files")
|
|
||||||
print("Saving idx files...")
|
|
||||||
self.create_idx_files(tfrecords, idx_dir)
|
|
||||||
|
|
||||||
def save_tfrecords(self, imgs, lbls, dim, suffix):
|
|
||||||
if len(lbls) == 0:
|
|
||||||
lbls = imgs[:]
|
|
||||||
chunks = np.array_split(list(zip(imgs, lbls)), math.ceil(len(imgs) / self.args.vpf))
|
|
||||||
Parallel(n_jobs=self.args.n_jobs)(
|
|
||||||
delayed(self.convert2tfrec)(chunk, dim, suffix) for chunk in tqdm(chunks, total=len(chunks))
|
|
||||||
)
|
|
||||||
|
|
||||||
def convert2tfrec(self, files, dim, suffix):
|
|
||||||
examples = []
|
|
||||||
for img_path, lbl_path in files:
|
|
||||||
img, lbl = np.load(img_path), np.load(lbl_path)
|
|
||||||
if dim == 2:
|
|
||||||
for depth in range(img.shape[1]):
|
|
||||||
examples.append(self.create_example(img[:, depth], lbl[:, depth], os.path.basename(img_path)))
|
|
||||||
else:
|
|
||||||
examples.append(self.create_example(img, lbl, os.path.basename(img_path)))
|
|
||||||
|
|
||||||
fname = os.path.basename(files[0][0]).replace("_x.npy", "")
|
|
||||||
tfrecord_name = os.path.join(self.results, f"{fname}.{suffix}")
|
|
||||||
with tf.io.TFRecordWriter(tfrecord_name) as writer:
|
|
||||||
for example in examples:
|
|
||||||
writer.write(example.SerializeToString())
|
|
||||||
|
|
||||||
def create_idx_files(self, tfrecords, save_dir):
|
|
||||||
make_empty_dir(save_dir)
|
|
||||||
tfrecords_idx = []
|
|
||||||
for tfrec in tfrecords:
|
|
||||||
fname = os.path.basename(tfrec).split(".")[0]
|
|
||||||
tfrecords_idx.append(os.path.join(save_dir, f"{fname}.idx"))
|
|
||||||
|
|
||||||
Parallel(n_jobs=self.args.n_jobs)(
|
|
||||||
delayed(self.create_idx)(tr, ti) for tr, ti in tqdm(zip(tfrecords, tfrecords_idx), total=len(tfrecords))
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_example(self, img, lbl, fname):
|
|
||||||
def _float_feature(value):
|
|
||||||
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
|
|
||||||
|
|
||||||
def _int64_feature(value):
|
|
||||||
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
|
|
||||||
|
|
||||||
def _bytes_feature(value):
|
|
||||||
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
|
|
||||||
|
|
||||||
feature = {
|
|
||||||
"X": _float_feature(img.flatten()),
|
|
||||||
"X_shape": _int64_feature(img.shape),
|
|
||||||
"fname": _bytes_feature(str.encode(fname)),
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.mode == "training":
|
|
||||||
feature.update({"Y": _bytes_feature(lbl.flatten().tobytes()), "Y_shape": _int64_feature(lbl.shape)})
|
|
||||||
|
|
||||||
return tf.train.Example(features=tf.train.Features(feature=feature))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_idx(tfrecord, tfidx):
|
|
||||||
call(["tfrecord2idx", tfrecord, tfidx])
|
|
||||||
|
|
||||||
def load_files(self, suffix):
|
|
||||||
return sorted(glob(os.path.join(self.data, suffix)))
|
|
||||||
|
|
||||||
def get_tfrecords_data(self, tfrec_pattern, idx_dir):
|
|
||||||
tfrecords = self.load_files(os.path.join(self.results, tfrec_pattern))
|
|
||||||
tfrecords_dir = os.path.join(self.results, idx_dir)
|
|
||||||
return tfrecords, tfrecords_dir
|
|
|
@ -12,7 +12,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
|
@ -23,11 +22,11 @@ import monai.transforms as transforms
|
||||||
import nibabel
|
import nibabel
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from joblib import Parallel, delayed
|
from joblib import Parallel, delayed
|
||||||
|
from skimage.morphology import dilation, erosion, square
|
||||||
from skimage.transform import resize
|
from skimage.transform import resize
|
||||||
from utils.utils import get_task_code, make_empty_dir
|
from utils.utils import get_task_code, make_empty_dir
|
||||||
|
|
||||||
from data_preprocessing.configs import (ct_max, ct_mean, ct_min, ct_std,
|
from data_preprocessing.configs import ct_max, ct_mean, ct_min, ct_std, patch_size, spacings, task
|
||||||
patch_size, spacings, task)
|
|
||||||
|
|
||||||
|
|
||||||
class Preprocessor:
|
class Preprocessor:
|
||||||
|
@ -45,11 +44,17 @@ class Preprocessor:
|
||||||
self.data_path = os.path.join(args.data, task[args.task])
|
self.data_path = os.path.join(args.data, task[args.task])
|
||||||
self.results = os.path.join(args.results, self.task_code)
|
self.results = os.path.join(args.results, self.task_code)
|
||||||
if not self.training:
|
if not self.training:
|
||||||
self.results = os.path.join(self.results, "test")
|
self.results = os.path.join(self.results, self.args.exec_mode)
|
||||||
self.metadata = json.load(open(os.path.join(self.data_path, "dataset.json"), "r"))
|
|
||||||
self.modality = self.metadata["modality"]["0"]
|
|
||||||
self.crop_foreg = transforms.CropForegroundd(keys=["image", "label"], source_key="image")
|
self.crop_foreg = transforms.CropForegroundd(keys=["image", "label"], source_key="image")
|
||||||
self.normalize_intensity = transforms.NormalizeIntensity(nonzero=True, channel_wise=True)
|
self.normalize_intensity = transforms.NormalizeIntensity(nonzero=False, channel_wise=True)
|
||||||
|
metadata_path = os.path.join(self.data_path, "dataset.json")
|
||||||
|
if self.args.exec_mode == "val":
|
||||||
|
dataset_json = json.load(open(metadata_path, "r"))
|
||||||
|
dataset_json["val"] = dataset_json["training"]
|
||||||
|
with open(metadata_path, "w") as outfile:
|
||||||
|
json.dump(dataset_json, outfile)
|
||||||
|
self.metadata = json.load(open(metadata_path, "r"))
|
||||||
|
self.modality = self.metadata["modality"]["0"]
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
make_empty_dir(self.results)
|
make_empty_dir(self.results)
|
||||||
|
@ -87,19 +92,31 @@ class Preprocessor:
|
||||||
)
|
)
|
||||||
|
|
||||||
def preprocess_pair(self, pair):
|
def preprocess_pair(self, pair):
|
||||||
fname = os.path.basename(pair["image"] if self.training else pair)
|
fname = os.path.basename(pair["image"] if isinstance(pair, dict) else pair)
|
||||||
image, label, image_spacings = self.load_pair(pair)
|
image, label, image_spacings = self.load_pair(pair)
|
||||||
if self.training:
|
if self.training:
|
||||||
data = self.crop_foreg({"image": image, "label": label})
|
data = self.crop_foreg({"image": image, "label": label})
|
||||||
image, label = data["image"], data["label"]
|
image, label = data["image"], data["label"]
|
||||||
|
test_metadata = None
|
||||||
|
else:
|
||||||
|
bbox = transforms.utils.generate_spatial_bounding_box(image)
|
||||||
|
test_metadata = np.vstack([bbox, image.shape[1:]])
|
||||||
|
image = transforms.SpatialCrop(roi_start=bbox[0], roi_end=bbox[1])(image)
|
||||||
|
if label is not None:
|
||||||
|
label = transforms.SpatialCrop(roi_start=bbox[0], roi_end=bbox[1])(label)
|
||||||
if self.args.dim == 3:
|
if self.args.dim == 3:
|
||||||
image, label = self.resample(image, label, image_spacings)
|
image, label = self.resample(image, label, image_spacings)
|
||||||
if self.modality == "CT":
|
if self.modality == "CT":
|
||||||
image = np.clip(image, self.ct_min, self.ct_max)
|
image = np.clip(image, self.ct_min, self.ct_max)
|
||||||
|
image = self.normalize(image)
|
||||||
if self.training:
|
if self.training:
|
||||||
image, label = self.standardize(image, label)
|
image, label = self.standardize(image, label)
|
||||||
image = self.normalize(image)
|
if self.args.dilation:
|
||||||
self.save(image, label, fname)
|
new_lbl = np.zeros(label.shape, dtype=np.uint8)
|
||||||
|
for depth in range(label.shape[1]):
|
||||||
|
new_lbl[0, depth] = erosion(dilation(label[0, depth], square(3)), square(3))
|
||||||
|
label = new_lbl
|
||||||
|
self.save(image, label, fname, test_metadata)
|
||||||
|
|
||||||
def resample(self, image, label, image_spacings):
|
def resample(self, image, label, image_spacings):
|
||||||
if self.target_spacing != image_spacings:
|
if self.target_spacing != image_spacings:
|
||||||
|
@ -108,9 +125,9 @@ class Preprocessor:
|
||||||
|
|
||||||
def standardize(self, image, label):
|
def standardize(self, image, label):
|
||||||
pad_shape = self.calculate_pad_shape(image)
|
pad_shape = self.calculate_pad_shape(image)
|
||||||
img_shape = image.shape[1:]
|
image_shape = image.shape[1:]
|
||||||
if pad_shape != img_shape:
|
if pad_shape != image_shape:
|
||||||
paddings = [(pad_sh - img_sh) / 2 for (pad_sh, img_sh) in zip(pad_shape, img_shape)]
|
paddings = [(pad_sh - image_sh) / 2 for (pad_sh, image_sh) in zip(pad_shape, image_shape)]
|
||||||
image = self.pad(image, paddings)
|
image = self.pad(image, paddings)
|
||||||
label = self.pad(label, paddings)
|
label = self.pad(label, paddings)
|
||||||
if self.args.dim == 2: # Center cropping 2D images.
|
if self.args.dim == 2: # Center cropping 2D images.
|
||||||
|
@ -126,21 +143,26 @@ class Preprocessor:
|
||||||
return (image - self.ct_mean) / self.ct_std
|
return (image - self.ct_mean) / self.ct_std
|
||||||
return self.normalize_intensity(image)
|
return self.normalize_intensity(image)
|
||||||
|
|
||||||
def save(self, image, label, fname):
|
def save(self, image, label, fname, test_metadata):
|
||||||
mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(np.std(image, (1, 2, 3)), 2)
|
mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(np.std(image, (1, 2, 3)), 2)
|
||||||
print(f"Saving {fname} shape {image.shape} mean {mean} std {std}")
|
print(f"Saving {fname} shape {image.shape} mean {mean} std {std}")
|
||||||
self.save_3d(image, label, fname)
|
self.save_npy(image, fname, "_x.npy")
|
||||||
|
if label is not None:
|
||||||
|
self.save_npy(label, fname, "_y.npy")
|
||||||
|
if test_metadata is not None:
|
||||||
|
self.save_npy(test_metadata, fname, "_meta.npy")
|
||||||
|
|
||||||
def load_pair(self, pair):
|
def load_pair(self, pair):
|
||||||
image = self.load_nifty(pair["image"] if self.training else pair)
|
image = self.load_nifty(pair["image"] if isinstance(pair, dict) else pair)
|
||||||
image_spacing = self.load_spacing(image)
|
image_spacing = self.load_spacing(image)
|
||||||
image = image.get_fdata().astype(np.float32)
|
image = image.get_fdata().astype(np.float32)
|
||||||
image = self.standardize_layout(image)
|
image = self.standardize_layout(image)
|
||||||
|
|
||||||
label = None
|
|
||||||
if self.training:
|
if self.training:
|
||||||
label = self.load_nifty(pair["label"]).get_fdata().astype(np.uint8)
|
label = self.load_nifty(pair["label"]).get_fdata().astype(np.uint8)
|
||||||
label = self.standardize_layout(label)
|
label = self.standardize_layout(label)
|
||||||
|
else:
|
||||||
|
label = None
|
||||||
|
|
||||||
return image, label, image_spacing
|
return image, label, image_spacing
|
||||||
|
|
||||||
|
@ -148,23 +170,23 @@ class Preprocessor:
|
||||||
shape = self.calculate_new_shape(spacing, image.shape[1:])
|
shape = self.calculate_new_shape(spacing, image.shape[1:])
|
||||||
if self.check_anisotrophy(spacing):
|
if self.check_anisotrophy(spacing):
|
||||||
image = self.resample_anisotrophic_image(image, shape)
|
image = self.resample_anisotrophic_image(image, shape)
|
||||||
if self.training:
|
if label is not None:
|
||||||
label = self.resample_anisotrophic_label(label, shape)
|
label = self.resample_anisotrophic_label(label, shape)
|
||||||
else:
|
else:
|
||||||
image = self.resample_regular_image(image, shape)
|
image = self.resample_regular_image(image, shape)
|
||||||
if self.training:
|
if label is not None:
|
||||||
label = self.resample_regular_label(label, shape)
|
label = self.resample_regular_label(label, shape)
|
||||||
image = image.astype(np.float32)
|
image = image.astype(np.float32)
|
||||||
if self.training:
|
if label is not None:
|
||||||
label = label.astype(np.uint8)
|
label = label.astype(np.uint8)
|
||||||
return image, label
|
return image, label
|
||||||
|
|
||||||
def calculate_pad_shape(self, image):
|
def calculate_pad_shape(self, image):
|
||||||
min_shape = self.patch_size[:]
|
min_shape = self.patch_size[:]
|
||||||
img_shape = image.shape[1:]
|
image_shape = image.shape[1:]
|
||||||
if len(min_shape) == 2: # In 2D case we don't want to pad depth axis.
|
if len(min_shape) == 2: # In 2D case we don't want to pad depth axis.
|
||||||
min_shape.insert(0, img_shape[0])
|
min_shape.insert(0, image_shape[0])
|
||||||
pad_shape = [max(mshape, ishape) for mshape, ishape in zip(min_shape, img_shape)]
|
pad_shape = [max(mshape, ishape) for mshape, ishape in zip(min_shape, image_shape)]
|
||||||
return pad_shape
|
return pad_shape
|
||||||
|
|
||||||
def get_intensities(self, pair):
|
def get_intensities(self, pair):
|
||||||
|
@ -205,13 +227,8 @@ class Preprocessor:
|
||||||
new_shape = (spacing_ratio * np.array(shape)).astype(int).tolist()
|
new_shape = (spacing_ratio * np.array(shape)).astype(int).tolist()
|
||||||
return new_shape
|
return new_shape
|
||||||
|
|
||||||
def save_3d(self, image, label, fname):
|
def save_npy(self, image, fname, suffix):
|
||||||
self.save_npy(image, fname, "_x.npy")
|
np.save(os.path.join(self.results, fname.replace(".nii.gz", suffix)), image, allow_pickle=False)
|
||||||
if self.training:
|
|
||||||
self.save_npy(label, fname, "_y.npy")
|
|
||||||
|
|
||||||
def save_npy(self, img, fname, suffix):
|
|
||||||
np.save(os.path.join(self.results, fname.replace(".nii.gz", suffix)), img, allow_pickle=False)
|
|
||||||
|
|
||||||
def run_parallel(self, func, exec_mode):
|
def run_parallel(self, func, exec_mode):
|
||||||
return Parallel(n_jobs=self.args.n_jobs)(delayed(func)(pair) for pair in self.metadata[exec_mode])
|
return Parallel(n_jobs=self.args.n_jobs)(delayed(func)(pair) for pair in self.metadata[exec_mode])
|
||||||
|
|
58
PyTorch/Segmentation/nnUNet/evaluate.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||||
|
|
||||||
|
import nibabel
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||||
|
parser.add_argument("--preds", type=str, required=True, help="Path to predictions")
|
||||||
|
parser.add_argument("--lbls", type=str, required=True, help="Path to labels")
|
||||||
|
|
||||||
|
|
||||||
|
def get_stats(pred, targ, class_idx):
|
||||||
|
tp_ = np.logical_and(pred == class_idx, targ == class_idx).sum()
|
||||||
|
fn_ = np.logical_and(pred != class_idx, targ == class_idx).sum()
|
||||||
|
fp_ = np.logical_and(pred == class_idx, targ != class_idx).sum()
|
||||||
|
return tp_, fn_, fp_
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parser.parse_args()
|
||||||
|
y_pred = sorted(glob.glob(os.path.join(args.preds, "*.npy")))
|
||||||
|
y_true = [os.path.join(args.lbls, os.path.basename(pred).replace("npy", "nii.gz")) for pred in y_pred]
|
||||||
|
assert len(y_pred) > 0
|
||||||
|
n_class = np.load(y_pred[0]).shape[0] - 1
|
||||||
|
|
||||||
|
dice = [[] for _ in range(n_class)]
|
||||||
|
for pr, lb in tqdm(zip(y_pred, y_true), total=len(y_pred)):
|
||||||
|
prd = np.transpose(np.argmax(np.load(pr), axis=0), (2, 1, 0))
|
||||||
|
lbl = nibabel.load(lb).get_fdata().astype(np.uint8)
|
||||||
|
|
||||||
|
for i in range(1, n_class + 1):
|
||||||
|
counts = np.count_nonzero(lbl == i) + np.count_nonzero(prd == i)
|
||||||
|
if counts == 0: # no foreground class
|
||||||
|
dice[i - 1].append(1)
|
||||||
|
else:
|
||||||
|
tp, fn, fp = get_stats(prd, lbl, i)
|
||||||
|
denum = 2 * tp + fp + fn
|
||||||
|
dice[i - 1].append(2 * tp / denum if denum != 0 else 0)
|
||||||
|
|
||||||
|
dice_score = np.mean(np.array(dice), axis=-1)
|
||||||
|
dice_cls = " ".join([f"L{i+1} {round(score, 4)}" for i, score in enumerate(dice_score)])
|
||||||
|
print(f"mean dice: {round(np.mean(dice_score), 4)} - {dice_cls}")
|
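As a worked illustration of the per-class Dice computed in `evaluate.py` above, the sketch below applies the same formula, 2·TP / (2·TP + FP + FN), to two tiny flat label lists. It is a minimal pure-Python sketch, not the script's NumPy code; the helper name `dice_for_class` and the toy inputs are assumptions made for illustration.

```python
def dice_for_class(pred, targ, class_idx):
    """Dice for one class over two equally long flat label sequences."""
    pairs = list(zip(pred, targ))
    tp = sum(p == class_idx and t == class_idx for p, t in pairs)
    fn = sum(p != class_idx and t == class_idx for p, t in pairs)
    fp = sum(p == class_idx and t != class_idx for p, t in pairs)
    if tp + fn + fp == 0:
        # Class absent from both prediction and label: scored as a perfect match,
        # mirroring the "no foreground class" branch in evaluate.py.
        return 1.0
    return 2 * tp / (2 * tp + fp + fn)


# Hypothetical toy data: six voxels, labels 0 (background), 1 and 2.
toy_pred = [0, 1, 1, 2, 0, 1]
toy_targ = [0, 1, 2, 2, 0, 0]
toy_scores = [dice_for_class(toy_pred, toy_targ, c) for c in (1, 2, 3)]
```

Class 3 never appears in either list, so it takes the "absent class" branch instead of dividing by zero.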
Binary file not shown. Size: 115 KiB before, 115 KiB after.
|
@ -16,7 +16,6 @@ import os
|
||||||
|
|
||||||
import pyprof
|
import pyprof
|
||||||
import torch
|
import torch
|
||||||
from dllogger import JSONStreamBackend, Logger, StdOutBackend, Verbosity
|
|
||||||
from pytorch_lightning import Trainer, seed_everything
|
from pytorch_lightning import Trainer, seed_everything
|
||||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
|
|
||||||
|
@ -24,31 +23,11 @@ from data_loading.data_module import DataModule
|
||||||
from models.nn_unet import NNUnet
|
from models.nn_unet import NNUnet
|
||||||
from utils.gpu_affinity import set_affinity
|
from utils.gpu_affinity import set_affinity
|
||||||
from utils.logger import LoggingCallback
|
from utils.logger import LoggingCallback
|
||||||
from utils.utils import get_main_args, is_main_process, make_empty_dir, set_cuda_devices, verify_ckpt_path
|
from utils.utils import get_main_args, is_main_process, log, make_empty_dir, set_cuda_devices, verify_ckpt_path
|
||||||
|
|
||||||
|
|
||||||
def log(logname, dice, epoch=None, dice_tta=None):
|
|
||||||
dllogger = Logger(
|
|
||||||
backends=[
|
|
||||||
JSONStreamBackend(Verbosity.VERBOSE, os.path.join(args.results, logname)),
|
|
||||||
StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: ""),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
metrics = {}
|
|
||||||
if epoch is not None:
|
|
||||||
metrics.update({"Epoch": epoch})
|
|
||||||
metrics.update({"Mean dice": round(dice.mean().item(), 2)})
|
|
||||||
if dice_tta is not None:
|
|
||||||
metrics.update({"Mean TTA dice": round(dice_tta.mean().item(), 2)})
|
|
||||||
metrics.update({f"L{j+1}": round(m.item(), 2) for j, m in enumerate(dice)})
|
|
||||||
if dice_tta is not None:
|
|
||||||
metrics.update({f"TTA_L{j+1}": round(m.item(), 2) for j, m in enumerate(dice_tta)})
|
|
||||||
dllogger.log(step=(), data=metrics)
|
|
||||||
dllogger.flush()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = get_main_args()
|
args = get_main_args()
|
||||||
|
|
||||||
if args.profile:
|
if args.profile:
|
||||||
pyprof.init(enable_function_stack=True)
|
pyprof.init(enable_function_stack=True)
|
||||||
print("Profiling enabled")
|
print("Profiling enabled")
|
||||||
|
@ -57,8 +36,6 @@ if __name__ == "__main__":
|
||||||
affinity = set_affinity(os.getenv("LOCAL_RANK", "0"), args.affinity)
|
affinity = set_affinity(os.getenv("LOCAL_RANK", "0"), args.affinity)
|
||||||
|
|
||||||
set_cuda_devices(args)
|
set_cuda_devices(args)
|
||||||
if is_main_process():
|
|
||||||
print(f"{args.exec_mode.upper()} TASK {args.task} FOLD {args.fold} SEED {args.seed}")
|
|
||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
data_module = DataModule(args)
|
data_module = DataModule(args)
|
||||||
data_module.prepare_data()
|
data_module.prepare_data()
|
||||||
|
@ -121,26 +98,24 @@ if __name__ == "__main__":
|
||||||
else:
|
else:
|
||||||
trainer.fit(model, train_dataloader=data_module.train_dataloader())
|
trainer.fit(model, train_dataloader=data_module.train_dataloader())
|
||||||
else:
|
else:
|
||||||
|
# warmup
|
||||||
|
trainer.test(model, test_dataloaders=data_module.test_dataloader())
|
||||||
|
# benchmark run
|
||||||
|
trainer.current_epoch = 1
|
||||||
trainer.test(model, test_dataloaders=data_module.test_dataloader())
|
trainer.test(model, test_dataloaders=data_module.test_dataloader())
|
||||||
elif args.exec_mode == "train":
|
elif args.exec_mode == "train":
|
||||||
trainer.fit(model, data_module)
|
trainer.fit(model, data_module)
|
||||||
if model_ckpt is not None:
|
|
||||||
model.args.exec_mode = "evaluate"
|
|
||||||
model.args.tta = True
|
|
||||||
trainer.interrupted = False
|
|
||||||
trainer.test(test_dataloaders=data_module.val_dataloader())
|
|
||||||
if is_main_process():
|
|
||||||
log_name = args.logname if args.logname is not None else "train_log.json"
|
|
||||||
log(log_name, model.best_sum_dice, model.best_sum_epoch, model.eval_dice)
|
|
||||||
elif args.exec_mode == "evaluate":
|
elif args.exec_mode == "evaluate":
|
||||||
model.args = args
|
model.args = args
|
||||||
trainer.test(model, test_dataloaders=data_module.val_dataloader())
|
trainer.test(model, test_dataloaders=data_module.val_dataloader())
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
log(args.logname if args.logname is not None else "eval_log.json", model.eval_dice)
|
logname = args.logname if args.logname is not None else "eval_log.json"
|
||||||
|
log(logname, model.eval_dice, results=args.results)
|
||||||
elif args.exec_mode == "predict":
|
elif args.exec_mode == "predict":
|
||||||
model.args = args
|
model.args = args
|
||||||
if args.save_preds:
|
if args.save_preds:
|
||||||
dir_name = f"preds_task_{args.task}_dim_{args.dim}_fold_{args.fold}"
|
prec = "amp" if args.amp else "fp32"
|
||||||
|
dir_name = f"preds_task_{args.task}_dim_{args.dim}_fold_{args.fold}_{prec}"
|
||||||
if args.tta:
|
if args.tta:
|
||||||
dir_name += "_tta"
|
dir_name += "_tta"
|
||||||
save_dir = os.path.join(args.results, dir_name)
|
save_dir = os.path.join(args.results, dir_name)
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from dropblock import DropBlock3D, LinearScheduler
|
||||||
|
|
||||||
normalizations = {
|
normalizations = {
|
||||||
"instancenorm3d": nn.InstanceNorm3d,
|
"instancenorm3d": nn.InstanceNorm3d,
|
||||||
|
@ -67,22 +68,40 @@ def get_output_padding(kernel_size, stride, padding):
|
||||||
return out_padding if len(out_padding) > 1 else out_padding[0]
|
return out_padding if len(out_padding) > 1 else out_padding[0]
|
||||||
|
|
||||||
|
|
||||||
class ConvLayer(nn.Module):
|
def get_drop_block():
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
|
return LinearScheduler(
|
||||||
super(ConvLayer, self).__init__()
|
DropBlock3D(block_size=5, drop_prob=0.0),
|
||||||
self.conv = get_conv(in_channels, out_channels, kernel_size, stride, dim)
|
start_value=0.0,
|
||||||
self.norm = get_norm(norm, out_channels)
|
stop_value=0.1,
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
|
nr_steps=10000,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, input_data):
|
|
||||||
return self.lrelu(self.norm(self.conv(input_data)))
|
class ConvLayer(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
|
||||||
|
super(ConvLayer, self).__init__()
|
||||||
|
self.conv = get_conv(in_channels, out_channels, kernel_size, stride, kwargs["dim"])
|
||||||
|
self.norm = get_norm(kwargs["norm"], out_channels)
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=kwargs["negative_slope"], inplace=True)
|
||||||
|
self.use_drop_block = kwargs["drop_block"]
|
||||||
|
if self.use_drop_block:
|
||||||
|
self.drop_block = get_drop_block()
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
out = self.conv(data)
|
||||||
|
if self.use_drop_block:
|
||||||
|
self.drop_block.step()
|
||||||
|
out = self.drop_block(out)
|
||||||
|
out = self.norm(out)
|
||||||
|
out = self.lrelu(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
|
def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
|
||||||
super(ConvBlock, self).__init__()
|
super(ConvBlock, self).__init__()
|
||||||
self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim)
|
self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, **kwargs)
|
||||||
self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, 1, norm, negative_slope, dim)
|
self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, 1, **kwargs)
|
||||||
|
|
||||||
def forward(self, input_data):
|
def forward(self, input_data):
|
||||||
out = self.conv1(input_data)
|
out = self.conv1(input_data)
|
||||||
|
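The `get_drop_block` helper in the hunk above wraps `DropBlock3D` in a `LinearScheduler` with `start_value=0.0`, `stop_value=0.1` and `nr_steps=10000`, and `ConvLayer.forward` calls `step()` on every pass. The sketch below shows what a linear ramp with those settings looks like. It assumes the schedule interpolates evenly over `nr_steps` values and then holds the final value; the function name `linear_drop_prob` is hypothetical, and the exact behavior should be checked against the `dropblock` package.

```python
def linear_drop_prob(step, start_value=0.0, stop_value=0.1, nr_steps=10000):
    """Drop probability after `step` scheduler steps under an assumed linear ramp."""
    if nr_steps <= 1:
        return stop_value
    # Clamp so the probability holds at stop_value once the ramp is finished.
    step = min(max(step, 0), nr_steps - 1)
    return start_value + (stop_value - start_value) * step / (nr_steps - 1)


# Probability at the start, halfway through, and well past the end of the ramp.
ramp_samples = [linear_drop_prob(s) for s in (0, 5000, 20000)]
```

Under this assumption the drop probability starts at zero, so the first forward passes behave as if DropBlock were disabled.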
@ -90,14 +109,72 @@ class ConvBlock(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResidBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
|
||||||
|
super(ResidBlock, self).__init__()
|
||||||
|
self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, **kwargs)
|
||||||
|
self.conv2 = get_conv(out_channels, out_channels, kernel_size, 1, kwargs["dim"])
|
||||||
|
self.norm = get_norm(kwargs["norm"], out_channels)
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=kwargs["negative_slope"], inplace=True)
|
||||||
|
self.use_drop_block = kwargs["drop_block"]
|
||||||
|
if self.use_drop_block:
|
||||||
|
self.drop_block = get_drop_block()
|
||||||
|
self.skip_drop_block = get_drop_block()
|
||||||
|
self.downsample = None
|
||||||
|
if max(stride) > 1 or in_channels != out_channels:
|
||||||
|
self.downsample = get_conv(in_channels, out_channels, kernel_size, stride, kwargs["dim"])
|
||||||
|
self.norm_res = get_norm(kwargs["norm"], out_channels)
|
||||||
|
|
||||||
|
def forward(self, input_data):
|
||||||
|
residual = input_data
|
||||||
|
out = self.conv1(input_data)
|
||||||
|
out = self.conv2(out)
|
||||||
|
if self.use_drop_block:
|
||||||
|
out = self.drop_block(out)
|
||||||
|
out = self.norm(out)
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(residual)
|
||||||
|
if self.use_drop_block:
|
||||||
|
residual = self.skip_drop_block(residual)
|
||||||
|
residual = self.norm_res(residual)
|
||||||
|
out = self.lrelu(out + residual)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLayer(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, norm, dim):
|
||||||
|
super(AttentionLayer, self).__init__()
|
||||||
|
self.conv = get_conv(in_channels, out_channels, kernel_size=3, stride=1, dim=dim)
|
||||||
|
self.norm = get_norm(norm, out_channels)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
out = self.conv(inputs)
|
||||||
|
out = self.norm(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class UpsampleBlock(nn.Module):
|
class UpsampleBlock(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, norm, negative_slope, dim):
|
def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
|
||||||
super(UpsampleBlock, self).__init__()
|
super(UpsampleBlock, self).__init__()
|
||||||
self.transp_conv = get_transp_conv(in_channels, out_channels, stride, stride, dim)
|
self.transp_conv = get_transp_conv(in_channels, out_channels, stride, stride, kwargs["dim"])
|
||||||
self.conv_block = ConvBlock(2 * out_channels, out_channels, kernel_size, 1, norm, negative_slope, dim)
|
self.conv_block = ConvBlock(2 * out_channels, out_channels, kernel_size, 1, **kwargs)
|
||||||
|
self.attention = kwargs["attention"]
|
||||||
|
if self.attention:
|
||||||
|
att_out, norm, dim = out_channels // 2, kwargs["norm"], kwargs["dim"]
|
||||||
|
self.conv_o = AttentionLayer(out_channels, att_out, norm, dim)
|
||||||
|
self.conv_s = AttentionLayer(out_channels, att_out, norm, dim)
|
||||||
|
self.psi = AttentionLayer(att_out, 1, norm, dim)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
def forward(self, input_data, skip_data):
|
def forward(self, input_data, skip_data):
|
||||||
out = self.transp_conv(input_data)
|
out = self.transp_conv(input_data)
|
||||||
|
if self.attention:
|
||||||
|
out_a = self.conv_o(out)
|
||||||
|
skip_a = self.conv_s(skip_data)
|
||||||
|
psi_a = self.psi(self.relu(out_a + skip_a))
|
||||||
|
attention = self.sigmoid(psi_a)
|
||||||
|
skip_data = skip_data * attention
|
||||||
out = torch.cat((out, skip_data), dim=1)
|
out = torch.cat((out, skip_data), dim=1)
|
||||||
out = self.conv_block(out)
|
out = self.conv_block(out)
|
||||||
return out
|
return out
|
||||||
|
@ -107,6 +184,7 @@ class OutputBlock(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, dim):
|
def __init__(self, in_channels, out_channels, dim):
|
||||||
super(OutputBlock, self).__init__()
|
super(OutputBlock, self).__init__()
|
||||||
self.conv = get_conv(in_channels, out_channels, kernel_size=1, stride=1, dim=dim, bias=True)
|
self.conv = get_conv(in_channels, out_channels, kernel_size=1, stride=1, dim=dim, bias=True)
|
||||||
|
nn.init.constant_(self.conv.bias, 0)
|
||||||
|
|
||||||
def forward(self, input_data):
|
def forward(self, input_data):
|
||||||
return self.conv(input_data)
|
return self.conv(input_data)
|
||||||
|
|
59
PyTorch/Segmentation/nnUNet/models/loss.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from monai.losses import FocalLoss
|
||||||
|
|
||||||
|
|
||||||
|
class DiceLoss(nn.Module):
|
||||||
|
def __init__(self, include_background=False, smooth=1e-5, eps=1e-7):
|
||||||
|
super(DiceLoss, self).__init__()
|
||||||
|
self.include_background = include_background
|
||||||
|
self.smooth = smooth
|
||||||
|
self.dims = (0, 2)
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, y_pred, y_true):
|
||||||
|
num_classes, batch_size = y_pred.size(1), y_true.size(0)
|
||||||
|
y_pred = y_pred.log_softmax(dim=1).exp()
|
||||||
|
y_true, y_pred = y_true.view(batch_size, -1), y_pred.view(batch_size, num_classes, -1)
|
||||||
|
y_true = F.one_hot(y_true.to(torch.int64), num_classes).permute(0, 2, 1)
|
||||||
|
if not self.include_background:
|
||||||
|
y_true, y_pred = y_true[:, 1:], y_pred[:, 1:]
|
||||||
|
intersection = torch.sum(y_true * y_pred, dim=self.dims)
|
||||||
|
cardinality = torch.sum(y_true + y_pred, dim=self.dims)
|
||||||
|
dice_loss = 1 - (2.0 * intersection + self.smooth) / (cardinality + self.smooth).clamp_min(self.eps)
|
||||||
|
mask = (y_true.sum(self.dims) > 0).to(dice_loss.dtype)
|
||||||
|
dice_loss *= mask.to(dice_loss.dtype)
|
||||||
|
dice_loss = dice_loss.sum() / mask.sum()
|
||||||
|
return dice_loss
|
||||||
|
|
||||||
|
|
||||||
|
class Loss(nn.Module):
|
||||||
|
def __init__(self, focal):
|
||||||
|
super(Loss, self).__init__()
|
||||||
|
self.dice = DiceLoss()
|
||||||
|
self.cross_entropy = nn.CrossEntropyLoss()
|
||||||
|
self.focal = FocalLoss(gamma=2.0)
|
||||||
|
self.use_focal = focal
|
||||||
|
|
||||||
|
def forward(self, y_pred, y_true):
|
||||||
|
loss = self.dice(y_pred, y_true)
|
||||||
|
if self.use_focal:
|
||||||
|
loss += self.focal(y_pred, y_true)
|
||||||
|
else:
|
||||||
|
loss += self.cross_entropy(y_pred, y_true[:, 0].long())
|
||||||
|
return loss
|
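The arithmetic in `DiceLoss.forward` above can be followed on plain Python lists. The sketch below is not the PyTorch module: it assumes the probabilities are already softmaxed, flattened per class, and stripped of the background channel, and the name `soft_dice_loss` is hypothetical. It keeps the same `smooth` and `eps` defaults and the same rule of averaging only over classes present in the target.

```python
def soft_dice_loss(probs, onehot, smooth=1e-5, eps=1e-7):
    """probs, onehot: one flat list of values per foreground class."""
    total, present = 0.0, 0
    for p_cls, t_cls in zip(probs, onehot):
        intersection = sum(p * t for p, t in zip(p_cls, t_cls))
        cardinality = sum(p + t for p, t in zip(p_cls, t_cls))
        loss = 1 - (2.0 * intersection + smooth) / max(cardinality + smooth, eps)
        if sum(t_cls) > 0:  # same role as the mask in the module above
            total += loss
            present += 1
    if present == 0:
        # Added guard (an assumption of this sketch): the average is undefined here.
        raise ValueError("no foreground class present in the target")
    return total / present


# A perfect prediction for class 1; class 2 is absent from the target and is ignored.
example_loss = soft_dice_loss(
    [[1.0, 0.0, 1.0, 0.0], [0.5, 0.5, 0.5, 0.5]],
    [[1, 0, 1, 0], [0, 0, 0, 0]],
)
```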
|
@ -12,9 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import monai
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from pytorch_lightning.metrics.functional import stat_scores
|
from pytorch_lightning.metrics.functional import stat_scores
|
||||||
from pytorch_lightning.metrics.metric import Metric
|
from pytorch_lightning.metrics.metric import Metric
|
||||||
|
|
||||||
|
@ -22,43 +20,28 @@ from pytorch_lightning.metrics.metric import Metric
|
||||||
class Dice(Metric):
|
class Dice(Metric):
|
||||||
def __init__(self, nclass):
|
def __init__(self, nclass):
|
||||||
super().__init__(dist_sync_on_step=True)
|
super().__init__(dist_sync_on_step=True)
|
||||||
self.add_state("dice", default=torch.zeros((nclass,)), dist_reduce_fx="mean")
|
self.add_state("n_updates", default=torch.zeros(1), dist_reduce_fx="sum")
|
||||||
|
self.add_state("dice", default=torch.zeros((nclass,)), dist_reduce_fx="sum")
|
||||||
|
|
||||||
def update(self, pred, target):
|
def update(self, pred, target):
|
||||||
self.dice = self.compute_stats(pred, target)
|
self.n_updates += 1
|
||||||
|
self.dice += self.compute_stats(pred, target)
|
||||||
|
|
||||||
def compute(self):
|
def compute(self):
|
||||||
return self.dice
|
return 100 * self.dice / self.n_updates
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_stats(pred, target):
|
def compute_stats(pred, target):
|
||||||
num_classes = pred.shape[1]
|
num_classes = pred.shape[1]
|
||||||
_bg = 1
|
scores = torch.zeros(num_classes - 1, device=pred.device, dtype=torch.float32)
|
||||||
scores = torch.zeros(num_classes - _bg, device=pred.device, dtype=torch.float32)
|
for i in range(1, num_classes):
|
||||||
precision = torch.zeros(num_classes - _bg, device=pred.device, dtype=torch.float32)
|
if (target != i).all():
|
||||||
recall = torch.zeros(num_classes - _bg, device=pred.device, dtype=torch.float32)
|
|
||||||
for i in range(_bg, num_classes):
|
|
||||||
if not (target == i).any():
|
|
||||||
# no foreground class
|
# no foreground class
|
||||||
_, _pred = torch.max(pred, 1)
|
_, _pred = torch.max(pred, 1)
|
||||||
scores[i - _bg] += 1 if not (_pred == i).any() else 0
|
scores[i - 1] += 1 if (_pred != i).all() else 0
|
||||||
recall[i - _bg] += 1 if not (_pred == i).any() else 0
|
|
||||||
precision[i - _bg] += 1 if not (_pred == i).any() else 0
|
|
||||||
continue
|
continue
|
||||||
_tp, _fp, _tn, _fn, _ = stat_scores(pred=pred, target=target, class_index=i)
|
_tp, _fp, _tn, _fn, _ = stat_scores(pred=pred, target=target, class_index=i)
|
||||||
denom = (2 * _tp + _fp + _fn).to(torch.float)
|
denom = (2 * _tp + _fp + _fn).to(torch.float)
|
||||||
score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
|
score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
|
||||||
scores[i - _bg] += score_cls
|
scores[i - 1] += score_cls
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class Loss(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(Loss, self).__init__()
|
|
||||||
self.dice = monai.losses.DiceLoss(to_onehot_y=True, softmax=True, batch=True)
|
|
||||||
self.cross_entropy = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
def forward(self, y_pred, y_true):
|
|
||||||
dice = self.dice(y_pred, y_true)
|
|
||||||
cross_entropy = self.cross_entropy(y_pred, y_true[:, 0].long())
|
|
||||||
return dice + cross_entropy
|
|
||||||
|
|
|
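In the `Dice` metric above, the new version adds each batch's scores to a running sum, counts updates in `n_updates`, and returns `100 * dice / n_updates` from `compute`. The sketch below mirrors that bookkeeping with plain lists so the averaging is easy to trace; the class name `RunningDice` is hypothetical and the distributed reduction handled by PyTorch Lightning is left out.

```python
class RunningDice:
    """Accumulates per-class Dice scores and reports their mean as a percentage."""

    def __init__(self, nclass):
        self.n_updates = 0
        self.dice = [0.0] * nclass

    def update(self, batch_scores):
        # Add this batch's per-class scores to the running sum.
        self.n_updates += 1
        self.dice = [acc + s for acc, s in zip(self.dice, batch_scores)]

    def compute(self):
        # Mean over all updates, scaled to a 0-100 range.
        return [100 * acc / self.n_updates for acc in self.dice]


metric = RunningDice(nclass=2)
metric.update([0.5, 1.0])
metric.update([1.0, 0.0])
mean_dice = metric.compute()
```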
@ -18,12 +18,23 @@ import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch_optimizer as optim
|
from apex.optimizers import FusedAdam, FusedSGD
|
||||||
from dllogger import JSONStreamBackend, Logger, StdOutBackend, Verbosity
|
|
||||||
from monai.inferers import sliding_window_inference
|
from monai.inferers import sliding_window_inference
|
||||||
from utils.utils import flip, get_config_file, is_main_process
|
from skimage.transform import resize
|
||||||
|
from torch_optimizer import RAdam
|
||||||
|
from utils.utils import (
|
||||||
|
flip,
|
||||||
|
get_dllogger,
|
||||||
|
get_path,
|
||||||
|
get_test_fnames,
|
||||||
|
get_tta_flips,
|
||||||
|
get_unet_params,
|
||||||
|
is_main_process,
|
||||||
|
layout_2d,
|
||||||
|
)
|
||||||
|
|
||||||
from models.metrics import Dice, Loss
|
from models.loss import Loss
|
||||||
|
from models.metrics import Dice
|
||||||
from models.unet import UNet
|
from models.unet import UNet
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,48 +44,41 @@ class NNUnet(pl.LightningModule):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.save_hyperparameters()
|
self.save_hyperparameters()
|
||||||
self.build_nnunet()
|
self.build_nnunet()
|
||||||
self.loss = Loss()
|
self.loss = Loss(self.args.focal)
|
||||||
self.dice = Dice(self.n_class)
|
self.dice = Dice(self.n_class)
|
||||||
self.best_sum = 0
|
self.best_sum = 0
|
||||||
self.eval_dice = 0
|
|
||||||
self.best_sum_epoch = 0
|
self.best_sum_epoch = 0
|
||||||
self.best_dice = self.n_class * [0]
|
self.best_dice = self.n_class * [0]
|
||||||
self.best_epoch = self.n_class * [0]
|
self.best_epoch = self.n_class * [0]
|
||||||
self.best_sum_dice = self.n_class * [0]
|
self.best_sum_dice = self.n_class * [0]
|
||||||
self.learning_rate = args.learning_rate
|
self.learning_rate = args.learning_rate
|
||||||
|
self.tta_flips = get_tta_flips(args.dim)
|
||||||
|
self.test_idx = 0
|
||||||
|
self.test_imgs = []
|
||||||
if self.args.exec_mode in ["train", "evaluate"]:
|
if self.args.exec_mode in ["train", "evaluate"]:
|
||||||
self.dllogger = Logger(
|
self.dllogger = get_dllogger(args.results)
|
||||||
backends=[
|
|
||||||
JSONStreamBackend(Verbosity.VERBOSE, os.path.join(args.results, "logs.json")),
|
|
||||||
StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: f"Epoch: {step} "),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tta_flips = (
|
|
||||||
[[2], [3], [2, 3]] if self.args.dim == 2 else [[2], [3], [4], [2, 3], [2, 4], [3, 4], [2, 3, 4]]
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, img):
|
def forward(self, img):
|
||||||
if self.args.benchmark:
|
if self.args.benchmark:
|
||||||
|
if self.args.dim == 2 and self.args.data2d_dim == 3:
|
||||||
|
img = layout_2d(img, None)
|
||||||
return self.model(img)
|
return self.model(img)
|
||||||
return self.tta_inference(img) if self.args.tta else self.do_inference(img)
|
return self.tta_inference(img) if self.args.tta else self.do_inference(img)
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
img, lbl = batch["image"], batch["label"]
|
img, lbl = self.get_train_data(batch)
|
||||||
if self.args.dim == 2 and len(lbl.shape) == 3:
|
|
||||||
lbl = lbl.unsqueeze(1)
|
|
||||||
pred = self.model(img)
|
pred = self.model(img)
|
||||||
loss = self.compute_loss(pred, lbl)
|
loss = self.compute_loss(pred, lbl)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
|
if self.current_epoch < self.args.skip_first_n_eval:
|
||||||
|
return None
|
||||||
img, lbl = batch["image"], batch["label"]
|
img, lbl = batch["image"], batch["label"]
|
||||||
if self.args.dim == 2 and len(lbl.shape) == 3:
|
|
||||||
lbl = lbl.unsqueeze(1)
|
|
||||||
pred = self.forward(img)
|
pred = self.forward(img)
|
||||||
loss = self.loss(pred, lbl)
|
loss = self.loss(pred, lbl)
|
||||||
dice = self.dice(pred, lbl[:, 0])
|
self.dice.update(pred, lbl[:, 0])
|
||||||
return {"val_loss": loss, "val_dice": dice}
|
return {"val_loss": loss}
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
if self.args.exec_mode == "evaluate":
|
if self.args.exec_mode == "evaluate":
|
||||||
|
@ -82,53 +86,46 @@ class NNUnet(pl.LightningModule):
|
||||||
img = batch["image"]
|
img = batch["image"]
|
||||||
pred = self.forward(img)
|
pred = self.forward(img)
|
||||||
if self.args.save_preds:
|
if self.args.save_preds:
|
||||||
self.save_mask(pred, batch["fname"])
|
meta = batch["meta"][0].cpu().detach().numpy()
|
||||||
|
original_shape = meta[2]
|
||||||
|
min_d, max_d = meta[0, 0], meta[1, 0]
|
||||||
|
min_h, max_h = meta[0, 1], meta[1, 1]
|
||||||
|
min_w, max_w = meta[0, 2], meta[1, 2]
|
||||||
|
|
||||||
def build_unet(self, in_channels, n_class, kernels, strides):
|
final_pred = torch.zeros((1, pred.shape[1], *original_shape), device=img.device)
|
||||||
return UNet(
|
final_pred[:, :, min_d:max_d, min_h:max_h, min_w:max_w] = pred
|
||||||
|
final_pred = nn.functional.softmax(final_pred, dim=1)
|
||||||
|
final_pred = final_pred.squeeze(0).cpu().detach().numpy()
|
||||||
|
|
||||||
|
if not all(original_shape == final_pred.shape[1:]):
|
||||||
|
class_ = final_pred.shape[0]
|
||||||
|
resized_pred = np.zeros((class_, *original_shape))
|
||||||
|
for i in range(class_):
|
||||||
|
resized_pred[i] = resize(
|
||||||
|
final_pred[i], original_shape, order=3, mode="edge", cval=0, clip=True, anti_aliasing=False
|
||||||
|
)
|
||||||
|
final_pred = resized_pred
|
||||||
|
|
||||||
|
self.save_mask(final_pred)
|
||||||
|
|
||||||
|
def build_nnunet(self):
|
||||||
|
in_channels, n_class, kernels, strides, self.patch_size = get_unet_params(self.args)
|
||||||
|
self.n_class = n_class - 1
|
||||||
|
self.model = UNet(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
n_class=n_class,
|
n_class=n_class,
|
||||||
kernels=kernels,
|
kernels=kernels,
|
||||||
strides=strides,
|
strides=strides,
|
||||||
|
dimension=self.args.dim,
|
||||||
|
residual=self.args.residual,
|
||||||
|
attention=self.args.attention,
|
||||||
|
drop_block=self.args.drop_block,
|
||||||
normalization_layer=self.args.norm,
|
normalization_layer=self.args.norm,
|
||||||
negative_slope=self.args.negative_slope,
|
negative_slope=self.args.negative_slope,
|
||||||
deep_supervision=self.args.deep_supervision,
|
deep_supervision=self.args.deep_supervision,
|
||||||
dimension=self.args.dim,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_unet_params(self):
|
|
||||||
config = get_config_file(self.args)
|
|
||||||
in_channels = config["in_channels"]
|
|
||||||
patch_size = config["patch_size"]
|
|
||||||
spacings = config["spacings"]
|
|
||||||
n_class = config["n_class"]
|
|
||||||
|
|
||||||
strides, kernels, sizes = [], [], patch_size[:]
|
|
||||||
while True:
|
|
||||||
spacing_ratio = [spacing / min(spacings) for spacing in spacings]
|
|
||||||
stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
|
|
||||||
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
|
|
||||||
if all(s == 1 for s in stride):
|
|
||||||
break
|
|
||||||
sizes = [i / j for i, j in zip(sizes, stride)]
|
|
||||||
spacings = [i * j for i, j in zip(spacings, stride)]
|
|
||||||
kernels.append(kernel)
|
|
||||||
strides.append(stride)
|
|
||||||
if len(strides) == 5:
|
|
||||||
break
|
|
||||||
strides.insert(0, len(spacings) * [1])
|
|
||||||
kernels.append(len(spacings) * [3])
|
|
||||||
|
|
||||||
return in_channels, n_class, kernels, strides, patch_size
|
|
||||||
|
|
||||||
def build_nnunet(self):
|
|
||||||
in_channels, n_class, kernels, strides, self.patch_size = self.get_unet_params()
|
|
||||||
self.model = self.build_unet(in_channels, n_class, kernels, strides)
|
|
||||||
self.n_class = n_class - 1
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
print(f"Filters: {self.model.filters}")
|
print(f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}")
|
||||||
print(f"Kernels: {kernels}")
|
|
||||||
print(f"Strides: {strides}")
|
|
||||||
|
|
||||||
def compute_loss(self, preds, label):
|
def compute_loss(self, preds, label):
|
||||||
if self.args.deep_supervision:
|
if self.args.deep_supervision:
|
||||||
|
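The removed `get_unet_params` method in the hunk above derived the kernel sizes and strides from the patch size and voxel spacings. The sketch below restates that loop as a standalone function so it can be run without a config file. The function name `derive_kernels_strides`, the parameter names `max_levels` and `min_size` (standing in for the literals 5 and 8), and the sample inputs are assumptions; the relocated helper in `utils.utils` is not shown in this diff and may differ.

```python
def derive_kernels_strides(patch_size, spacings, max_levels=5, min_size=8):
    """Restates the loop from the removed NNUnet.get_unet_params method."""
    strides, kernels = [], []
    sizes, spacings = list(patch_size), list(spacings)
    while True:
        spacing_ratio = [s / min(spacings) for s in spacings]
        # Downsample an axis only if it is not too coarse and still large enough.
        stride = [2 if ratio <= 2 and size >= min_size else 1 for ratio, size in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
        if len(strides) == max_levels:
            break
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    return kernels, strides


# Hypothetical inputs: an isotropic volume and one with a coarse first axis.
iso_kernels, iso_strides = derive_kernels_strides([128, 128, 128], [1.0, 1.0, 1.0])
aniso_kernels, aniso_strides = derive_kernels_strides([32, 256, 256], [4.0, 1.0, 1.0])
```

For the anisotropic input, the coarse axis is skipped at the first downsampling level (stride 1, kernel 1) until the in-plane spacing catches up.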
@ -141,15 +138,14 @@ class NNUnet(pl.LightningModule):
|
||||||
return self.loss(preds, label)
|
return self.loss(preds, label)
|
||||||
|
|
||||||
def do_inference(self, image):
|
def do_inference(self, image):
|
||||||
if self.args.dim == 2:
|
if self.args.dim == 3:
|
||||||
|
return self.sliding_window_inference(image)
|
||||||
if self.args.data2d_dim == 2:
|
if self.args.data2d_dim == 2:
|
||||||
return self.model(image)
|
return self.model(image)
|
||||||
if self.args.exec_mode == "predict" and not self.args.benchmark:
|
if self.args.exec_mode == "predict":
|
||||||
return self.inference2d_test(image)
|
return self.inference2d_test(image)
|
||||||
return self.inference2d(image)
|
return self.inference2d(image)
|
||||||
|
|
||||||
return self.sliding_window_inference(image)
|
|
||||||
|
|
||||||
def tta_inference(self, img):
|
def tta_inference(self, img):
|
||||||
pred = self.do_inference(img)
|
pred = self.do_inference(img)
|
||||||
for flip_idx in self.tta_flips:
|
for flip_idx in self.tta_flips:
|
||||||
|
@ -159,12 +155,9 @@ class NNUnet(pl.LightningModule):
|
||||||
|
|
||||||
def inference2d(self, image):
|
def inference2d(self, image):
|
||||||
batch_modulo = image.shape[2] % self.args.val_batch_size
|
batch_modulo = image.shape[2] % self.args.val_batch_size
|
||||||
if self.args.benchmark:
|
if batch_modulo != 0:
|
||||||
image = image[:, :, batch_modulo:]
|
|
||||||
elif batch_modulo != 0:
|
|
||||||
batch_pad = self.args.val_batch_size - batch_modulo
|
batch_pad = self.args.val_batch_size - batch_modulo
|
||||||
image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)
|
image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)
|
||||||
|
|
||||||
image = torch.transpose(image.squeeze(0), 0, 1)
|
image = torch.transpose(image.squeeze(0), 0, 1)
|
||||||
preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
|
preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
|
||||||
preds = torch.zeros(preds_shape, dtype=image.dtype, device=image.device)
|
preds = torch.zeros(preds_shape, dtype=image.dtype, device=image.device)
|
||||||
|
@ -172,10 +165,8 @@ class NNUnet(pl.LightningModule):
|
||||||
end = start + self.args.val_batch_size
|
end = start + self.args.val_batch_size
|
||||||
pred = self.model(image[start:end])
|
pred = self.model(image[start:end])
|
||||||
preds[start:end] = pred.data
|
preds[start:end] = pred.data
|
||||||
|
if batch_modulo != 0:
|
||||||
if batch_modulo != 0 and not self.args.benchmark:
|
|
||||||
preds = preds[batch_pad:]
|
preds = preds[batch_pad:]
|
||||||
|
|
||||||
return torch.transpose(preds, 0, 1).unsqueeze(0)
|
return torch.transpose(preds, 0, 1).unsqueeze(0)
|
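The padding logic in `inference2d` can be illustrated without tensors. The sketch below is a minimal pure-Python stand-in, not the repository's implementation: `run_in_batches` and the `model` callable are hypothetical names, and a list of slices replaces the image tensor. It pads the front of the slice list so its length is a multiple of the batch size, runs full batches, and then drops the predictions that belong to the padding.

```python
def run_in_batches(slices, batch_size, model, pad_value=0):
    """Front-pad `slices` to a multiple of `batch_size`, predict, then trim.

    `model` is a hypothetical callable that maps a list of slices to a list
    of predictions of the same length.
    """
    modulo = len(slices) % batch_size
    pad = batch_size - modulo if modulo != 0 else 0
    # Padding goes at the front, mirroring ConstantPad3d((0, 0, 0, 0, batch_pad, 0)).
    padded = [pad_value] * pad + list(slices)
    preds = []
    for start in range(0, len(padded), batch_size):
        preds.extend(model(padded[start:start + batch_size]))
    # Discard predictions made for the padded slices.
    return preds[pad:]


if __name__ == "__main__":
    doubled = run_in_batches(list(range(10)), 4, lambda batch: [2 * x for x in batch])
    print(doubled)
```

Because every batch has the same size, the model never sees a short final batch, which is the property the padding is meant to guarantee.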
||||||
|
|
||||||
def inference2d_test(self, image):
|
def inference2d_test(self, image):
|
||||||
|
@ -192,7 +183,7 @@ class NNUnet(pl.LightningModule):
|
||||||
sw_batch_size=self.args.val_batch_size,
|
sw_batch_size=self.args.val_batch_size,
|
||||||
predictor=self.model,
|
predictor=self.model,
|
||||||
overlap=self.args.overlap,
|
overlap=self.args.overlap,
|
||||||
mode=self.args.val_mode,
|
mode=self.args.blend,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -200,8 +191,12 @@ class NNUnet(pl.LightningModule):
|
||||||
return torch.stack([out[name] for out in outputs]).mean(dim=0)
|
return torch.stack([out[name] for out in outputs]).mean(dim=0)
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
def validation_epoch_end(self, outputs):
|
||||||
|
if self.current_epoch < self.args.skip_first_n_eval:
|
||||||
|
self.log("dice_sum", 0.001 * self.current_epoch)
|
||||||
|
self.dice.reset()
|
||||||
|
return None
|
||||||
loss = self.metric_mean("val_loss", outputs)
|
loss = self.metric_mean("val_loss", outputs)
|
||||||
dice = 100 * self.metric_mean("val_dice", outputs)
|
dice = self.dice.compute()
|
||||||
dice_sum = torch.sum(dice)
|
dice_sum = torch.sum(dice)
|
||||||
if dice_sum >= self.best_sum:
|
if dice_sum >= self.best_sum:
|
||||||
self.best_sum = dice_sum
|
self.best_sum = dice_sum
|
||||||
|
@ -215,6 +210,7 @@ class NNUnet(pl.LightningModule):
|
||||||
metrics = {}
|
metrics = {}
|
||||||
metrics.update({"mean dice": round(torch.mean(dice).item(), 2)})
|
metrics.update({"mean dice": round(torch.mean(dice).item(), 2)})
|
||||||
metrics.update({"TOP_mean": round(torch.mean(self.best_sum_dice).item(), 2)})
|
metrics.update({"TOP_mean": round(torch.mean(self.best_sum_dice).item(), 2)})
|
||||||
|
if self.n_class > 1:
|
||||||
metrics.update({f"L{i+1}": round(m.item(), 2) for i, m in enumerate(dice)})
|
metrics.update({f"L{i+1}": round(m.item(), 2) for i, m in enumerate(dice)})
|
||||||
metrics.update({f"TOP_L{i+1}": round(m.item(), 2) for i, m in enumerate(self.best_sum_dice)})
|
metrics.update({f"TOP_L{i+1}": round(m.item(), 2) for i, m in enumerate(self.best_sum_dice)})
|
||||||
metrics.update({"val_loss": round(loss.item(), 4)})
|
metrics.update({"val_loss": round(loss.item(), 4)})
|
||||||
|
@ -226,14 +222,13 @@ class NNUnet(pl.LightningModule):
|
||||||
|
|
||||||
def test_epoch_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
if self.args.exec_mode == "evaluate":
|
if self.args.exec_mode == "evaluate":
|
||||||
self.eval_dice = 100 * self.metric_mean("val_dice", outputs)
|
self.eval_dice = self.dice.compute()
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = {
|
optimizer = {
|
||||||
"sgd": torch.optim.SGD(self.parameters(), lr=self.learning_rate, momentum=self.args.momentum),
|
"sgd": FusedSGD(self.parameters(), lr=self.learning_rate, momentum=self.args.momentum),
|
||||||
"adam": torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
|
"adam": FusedAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
|
||||||
"adamw": torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
|
"radam": RAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
|
||||||
"radam": optim.RAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
|
|
||||||
}[self.args.optimizer.lower()]
|
}[self.args.optimizer.lower()]
|
||||||
|
|
||||||
scheduler = {
|
scheduler = {
|
||||||
|
@ -250,8 +245,16 @@ class NNUnet(pl.LightningModule):
|
||||||
opt_dict.update({"lr_scheduler": scheduler})
|
opt_dict.update({"lr_scheduler": scheduler})
|
||||||
return opt_dict
|
return opt_dict
|
||||||
|
|
||||||
def save_mask(self, pred, fname):
|
def save_mask(self, pred):
|
||||||
fname = str(fname[0].cpu().detach().numpy(), "utf-8").replace("_x", "_pred")
|
if self.test_idx == 0:
|
||||||
pred = nn.functional.softmax(torch.tensor(pred), dim=1)
|
data_path = get_path(self.args)
|
||||||
pred = pred.squeeze().cpu().detach().numpy()
|
self.test_imgs, _ = get_test_fnames(self.args, data_path)
|
||||||
|
fname = os.path.basename(self.test_imgs[self.test_idx]).replace("_x", "")
|
||||||
np.save(os.path.join(self.save_dir, fname), pred, allow_pickle=False)
|
np.save(os.path.join(self.save_dir, fname), pred, allow_pickle=False)
|
||||||
|
self.test_idx += 1
|
||||||
|
|
||||||
|
def get_train_data(self, batch):
|
||||||
|
img, lbl = batch["image"], batch["label"]
|
||||||
|
if self.args.dim == 2 and self.args.data2d_dim == 3:
|
||||||
|
img, lbl = layout_2d(img, lbl)
|
||||||
|
return img, lbl
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from models.layers import ConvBlock, OutputBlock, UpsampleBlock
|
from models.layers import ConvBlock, OutputBlock, ResidBlock, UpsampleBlock
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
|
@ -27,36 +27,44 @@ class UNet(nn.Module):
|
||||||
normalization_layer,
|
normalization_layer,
|
||||||
negative_slope,
|
negative_slope,
|
||||||
deep_supervision,
|
deep_supervision,
|
||||||
|
attention,
|
||||||
|
drop_block,
|
||||||
|
residual,
|
||||||
dimension,
|
dimension,
|
||||||
):
|
):
|
||||||
super(UNet, self).__init__()
|
super(UNet, self).__init__()
|
||||||
self.dim = dimension
|
self.dim = dimension
|
||||||
self.n_class = n_class
|
self.n_class = n_class
|
||||||
|
self.attention = attention
|
||||||
|
self.residual = residual
|
||||||
self.negative_slope = negative_slope
|
self.negative_slope = negative_slope
|
||||||
self.deep_supervision = deep_supervision
|
self.deep_supervision = deep_supervision
|
||||||
self.norm = normalization_layer + f"norm{dimension}d"
|
self.norm = normalization_layer + f"norm{dimension}d"
|
||||||
self.filters = [min(2 ** (5 + i), 320) for i in range(len(strides))]
|
self.filters = [min(2 ** (5 + i), 320 if dimension == 3 else 512) for i in range(len(strides))]
|
||||||
|
|
||||||
|
down_block = ResidBlock if self.residual else ConvBlock
|
||||||
self.input_block = self.get_conv_block(
|
self.input_block = self.get_conv_block(
|
||||||
conv_block=ConvBlock,
|
conv_block=down_block,
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=self.filters[0],
|
out_channels=self.filters[0],
|
||||||
kernel_size=kernels[0],
|
kernel_size=kernels[0],
|
||||||
stride=strides[0],
|
stride=strides[0],
|
||||||
)
|
)
|
||||||
self.downsamples = self.get_module_list(
|
self.downsamples = self.get_module_list(
|
||||||
conv_block=ConvBlock,
|
conv_block=down_block,
|
||||||
in_channels=self.filters[:-1],
|
in_channels=self.filters[:-1],
|
||||||
out_channels=self.filters[1:],
|
out_channels=self.filters[1:],
|
||||||
kernels=kernels[1:-1],
|
kernels=kernels[1:-1],
|
||||||
strides=strides[1:-1],
|
strides=strides[1:-1],
|
||||||
|
drop_block=drop_block,
|
||||||
)
|
)
|
||||||
self.bottleneck = self.get_conv_block(
|
self.bottleneck = self.get_conv_block(
|
||||||
conv_block=ConvBlock,
|
conv_block=down_block,
|
||||||
in_channels=self.filters[-2],
|
in_channels=self.filters[-2],
|
||||||
out_channels=self.filters[-1],
|
out_channels=self.filters[-1],
|
||||||
kernel_size=kernels[-1],
|
kernel_size=kernels[-1],
|
||||||
stride=strides[-1],
|
stride=strides[-1],
|
||||||
|
drop_block=drop_block,
|
||||||
)
|
)
|
||||||
self.upsamples = self.get_module_list(
|
self.upsamples = self.get_module_list(
|
||||||
conv_block=UpsampleBlock,
|
conv_block=UpsampleBlock,
|
||||||
|
@ -87,15 +95,17 @@ class UNet(nn.Module):
|
||||||
out.append(self.deep_supervision_heads[i](decoder_out))
|
out.append(self.deep_supervision_heads[i](decoder_out))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def get_conv_block(self, conv_block, in_channels, out_channels, kernel_size, stride):
|
def get_conv_block(self, conv_block, in_channels, out_channels, kernel_size, stride, drop_block=False):
|
||||||
return conv_block(
|
return conv_block(
|
||||||
in_channels=in_channels,
|
dim=self.dim,
|
||||||
out_channels=out_channels,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
stride=stride,
|
stride=stride,
|
||||||
norm=self.norm,
|
norm=self.norm,
|
||||||
|
drop_block=drop_block,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
in_channels=in_channels,
|
||||||
|
attention=self.attention,
|
||||||
|
out_channels=out_channels,
|
||||||
negative_slope=self.negative_slope,
|
negative_slope=self.negative_slope,
|
||||||
dim=self.dim,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_output_block(self, decoder_level):
|
def get_output_block(self, decoder_level):
|
||||||
|
@ -104,10 +114,11 @@ class UNet(nn.Module):
|
||||||
def get_deep_supervision_heads(self):
|
def get_deep_supervision_heads(self):
|
||||||
return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)])
|
return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)])
|
||||||
|
|
||||||
def get_module_list(self, in_channels, out_channels, kernels, strides, conv_block):
|
def get_module_list(self, in_channels, out_channels, kernels, strides, conv_block, drop_block=False):
|
||||||
layers = []
|
layers = []
|
||||||
for in_channel, out_channel, kernel, stride in zip(in_channels, out_channels, kernels, strides):
|
for i, (in_channel, out_channel, kernel, stride) in enumerate(zip(in_channels, out_channels, kernels, strides)):
|
||||||
conv_layer = self.get_conv_block(conv_block, in_channel, out_channel, kernel, stride)
|
use_drop_block = drop_block and len(in_channels) - i <= 2
|
||||||
|
conv_layer = self.get_conv_block(conv_block, in_channel, out_channel, kernel, stride, use_drop_block)
|
||||||
layers.append(conv_layer)
|
layers.append(conv_layer)
|
||||||
return nn.ModuleList(layers)
|
return nn.ModuleList(layers)
|
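The condition `len(in_channels) - i <= 2` in the new `get_module_list` enables drop block only for the last two blocks of the list. A minimal sketch of that rule follows; `drop_block_flags` is a hypothetical helper introduced here for illustration and takes the number of blocks rather than the channel lists.

```python
def drop_block_flags(n_blocks, drop_block=True):
    """Return, per block index, whether drop block would be enabled.

    Mirrors `drop_block and len(in_channels) - i <= 2` from the diff,
    with `n_blocks` standing in for `len(in_channels)`.
    """
    return [drop_block and n_blocks - i <= 2 for i in range(n_blocks)]


if __name__ == "__main__":
    # With four downsampling blocks only the two deepest ones use drop block.
    print(drop_block_flags(4))
```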
||||||
|
|
||||||
|
@ -115,5 +126,3 @@ class UNet(nn.Module):
|
||||||
name = module.__class__.__name__.lower()
|
name = module.__class__.__name__.lower()
|
||||||
if name in ["conv2d", "conv3d"]:
|
if name in ["conv2d", "conv3d"]:
|
||||||
nn.init.kaiming_normal_(module.weight, a=self.negative_slope)
|
nn.init.kaiming_normal_(module.weight, a=self.negative_slope)
|
||||||
elif name in ["convtranspose2d", "convtranspose3d"]:
|
|
||||||
nn.init.kaiming_normal_(module.weight, a=1.0)
|
|
||||||
|
|
|
@ -141,7 +141,7 @@
|
||||||
" self.fp[0] += false_pos\n",
|
" self.fp[0] += false_pos\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def compute(self):\n",
|
" def compute(self):\n",
|
||||||
" return 2 * self.tp / (2 * self.tp + self.fp + self.fn)\n",
|
" return 200 * self.tp / (2 * self.tp + self.fp + self.fn)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" @staticmethod\n",
|
" @staticmethod\n",
|
||||||
" def get_stats(pred, targ, class_idx):\n",
|
" def get_stats(pred, targ, class_idx):\n",
|
||||||
|
@ -277,7 +277,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"### Data loading <a name=\"dataloader\"></a>\n",
|
"### Data loading <a name=\"dataloader\"></a>\n",
|
||||||
"\n",
|
"\n",
|
||||||
"In our nnUNet repository we convert data to the [tfrecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) format and use [NVIDIA DALI](https://docs.nvidia.com/deeplearning/dali/master-user-guide/docs/index.html) for data loading. However, you can modify this part and create your own data loading pipeline.\n",
|
"In our nnUNet repository we convert data to npy format and use [NVIDIA DALI](https://docs.nvidia.com/deeplearning/dali/master-user-guide/docs/index.html) for data loading. However, you can modify this part and create your own data loading pipeline.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"In this example we use PyTorch DataLoader with *zoom*, *crop*, *flips*, *gaussian noise*, *gamma*, *brightness* and *contrast* for data augmentation from [albumentations](https://albumentations.ai) library.\n",
|
"In this example we use PyTorch DataLoader with *zoom*, *crop*, *flips*, *gaussian noise*, *gamma*, *brightness* and *contrast* for data augmentation from [albumentations](https://albumentations.ai) library.\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -324,6 +324,7 @@
|
||||||
" img = self.brctr(image=img)[\"image\"]\n",
|
" img = self.brctr(image=img)[\"image\"]\n",
|
||||||
" img = self.gamma(image=img)[\"image\"]\n",
|
" img = self.gamma(image=img)[\"image\"]\n",
|
||||||
" img = self.normalize(image=img)[\"image\"]\n",
|
" img = self.normalize(image=img)[\"image\"]\n",
|
||||||
|
" lbl = np.expand_dims(lbl, 0)\n",
|
||||||
" return {\"image\": np.transpose(img, (2, 0, 1)), \"label\": lbl}\n",
|
" return {\"image\": np.transpose(img, (2, 0, 1)), \"label\": lbl}\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def load_pair(self, idx):\n",
|
" def load_pair(self, idx):\n",
|
||||||
|
@ -344,6 +345,7 @@
|
||||||
" def __getitem__(self, idx):\n",
|
" def __getitem__(self, idx):\n",
|
||||||
" img, lbl = self.load_pair(idx)\n",
|
" img, lbl = self.load_pair(idx)\n",
|
||||||
" img = self.normalize(image=img)[\"image\"]\n",
|
" img = self.normalize(image=img)[\"image\"]\n",
|
||||||
|
" lbl = np.expand_dims(lbl, 0)\n",
|
||||||
" return {\"image\": np.transpose(img, (2, 0, 1)), \"label\": lbl}\n",
|
" return {\"image\": np.transpose(img, (2, 0, 1)), \"label\": lbl}\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def load_pair(self, idx):\n",
|
" def load_pair(self, idx):\n",
|
||||||
|
|
|
@ -15,9 +15,7 @@
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||||
from subprocess import call
|
|
||||||
|
|
||||||
from data_preprocessing.convert2tfrec import Converter
|
|
||||||
from data_preprocessing.preprocessor import Preprocessor
|
from data_preprocessing.preprocessor import Preprocessor
|
||||||
from utils.utils import get_task_code
|
from utils.utils import get_task_code
|
||||||
|
|
||||||
|
@ -28,24 +26,22 @@ parser.add_argument(
|
||||||
"--exec_mode",
|
"--exec_mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="training",
|
default="training",
|
||||||
choices=["training", "test"],
|
choices=["training", "val", "test"],
|
||||||
help="Mode for data preprocessing",
|
help="Mode for data preprocessing",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--dilation", action="store_true", help="Perform morphological label dilation")
|
||||||
parser.add_argument("--task", type=str, help="Number of task to be run. MSD uses numbers 01-10")
|
parser.add_argument("--task", type=str, help="Number of task to be run. MSD uses numbers 01-10")
|
||||||
parser.add_argument("--dim", type=int, default=3, choices=[2, 3], help="Data dimension to prepare")
|
parser.add_argument("--dim", type=int, default=3, choices=[2, 3], help="Data dimension to prepare")
|
||||||
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs for data preprocessing")
|
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs for data preprocessing")
|
||||||
parser.add_argument("--vpf", type=int, default=1, help="Volumes per tfrecord")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
start = time.time()
|
start = time.time()
|
||||||
Preprocessor(args).run()
|
Preprocessor(args).run()
|
||||||
Converter(args).run()
|
|
||||||
task_code = get_task_code(args)
|
task_code = get_task_code(args)
|
||||||
path = os.path.join(args.data, task_code)
|
path = os.path.join(args.data, task_code)
|
||||||
if args.exec_mode == "test":
|
if args.exec_mode == "test":
|
||||||
path = os.path.join(path, "test")
|
path = os.path.join(path, "test")
|
||||||
call(f'find {path} -name "*.npy" -print0 | xargs -0 rm', shell=True)
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"Preprocessing time: {(end - start):.2f}")
|
print(f"Preprocessing time: {(end - start):.2f}")
|
||||||
|
|
|
@ -3,7 +3,7 @@ nibabel==3.1.1
|
||||||
joblib==0.16.0
|
joblib==0.16.0
|
||||||
scikit-learn==0.23.2
|
scikit-learn==0.23.2
|
||||||
pynvml==8.0.4
|
pynvml==8.0.4
|
||||||
tensorflow==2.3.1
|
|
||||||
pillow==6.2.0
|
pillow==6.2.0
|
||||||
fsspec==0.8.0
|
fsspec==0.8.0
|
||||||
pytorch_ranger==0.1.1
|
pytorch_ranger==0.1.1
|
||||||
|
dropblock
|
|
@ -19,22 +19,23 @@ from subprocess import call
|
||||||
|
|
||||||
parser = ArgumentParser(ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument("--mode", type=str, required=True, choices=["train", "predict"], help="Benchmarking mode")
|
parser.add_argument("--mode", type=str, required=True, choices=["train", "predict"], help="Benchmarking mode")
|
||||||
|
parser.add_argument("--task", type=str, default="01", help="Task code")
|
||||||
parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs to use")
|
parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs to use")
|
||||||
parser.add_argument("--dim", type=int, required=True, help="Dimension of UNet")
|
parser.add_argument("--dim", type=int, required=True, help="Dimension of UNet")
|
||||||
parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
|
parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
|
||||||
parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
|
parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
|
||||||
parser.add_argument("--train_batches", type=int, default=80, help="Number of batches for training")
|
parser.add_argument("--train_batches", type=int, default=150, help="Number of batches for training")
|
||||||
parser.add_argument("--test_batches", type=int, default=80, help="Number of batches for inference")
|
parser.add_argument("--test_batches", type=int, default=150, help="Number of batches for inference")
|
||||||
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations before collecting statistics")
|
parser.add_argument("--warmup", type=int, default=50, help="Warmup iterations before collecting statistics")
|
||||||
parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
|
parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
|
||||||
parser.add_argument("--logname", type=str, default="perf.json", help="Name of dllogger output")
|
parser.add_argument("--logname", type=str, default="perf.json", help="Name of dllogger output")
|
||||||
parser.add_argument("--create_idx", action="store_true", help="Create index files for tfrecord")
|
|
||||||
parser.add_argument("--profile", action="store_true", help="Enable dlprof profiling")
|
parser.add_argument("--profile", action="store_true", help="Enable dlprof profiling")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
|
path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
|
||||||
cmd = "python main.py --task 01 --benchmark --max_epochs 1 --min_epochs 1 "
|
cmd = ""
|
||||||
|
cmd += f"python main.py --task {args.task} --benchmark --max_epochs 2 --min_epochs 1 --optimizer adam "
|
||||||
cmd += f"--results {args.results} "
|
cmd += f"--results {args.results} "
|
||||||
cmd += f"--logname {args.logname} "
|
cmd += f"--logname {args.logname} "
|
||||||
cmd += f"--exec_mode {args.mode} "
|
cmd += f"--exec_mode {args.mode} "
|
||||||
|
@ -44,7 +45,6 @@ if __name__ == "__main__":
|
||||||
cmd += f"--test_batches {args.test_batches} "
|
cmd += f"--test_batches {args.test_batches} "
|
||||||
cmd += f"--warmup {args.warmup} "
|
cmd += f"--warmup {args.warmup} "
|
||||||
cmd += "--amp " if args.amp else ""
|
cmd += "--amp " if args.amp else ""
|
||||||
cmd += "--create_idx " if args.create_idx else ""
|
|
||||||
cmd += "--profile " if args.profile else ""
|
cmd += "--profile " if args.profile else ""
|
||||||
if args.mode == "train":
|
if args.mode == "train":
|
||||||
cmd += f"--batch_size {args.batch_size} "
|
cmd += f"--batch_size {args.batch_size} "
|
||||||
|
|
|
@ -18,10 +18,12 @@ from os.path import dirname
|
||||||
from subprocess import call
|
from subprocess import call
|
||||||
|
|
||||||
parser = ArgumentParser(ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(ArgumentDefaultsHelpFormatter)
|
||||||
|
parser.add_argument("--data", type=str, required=True, help="Path to data")
|
||||||
|
parser.add_argument("--task", type=str, default="01", help="Task code")
|
||||||
parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4], help="Fold number")
|
parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4], help="Fold number")
|
||||||
parser.add_argument("--dim", type=int, required=True, help="Dimension of UNet")
|
parser.add_argument("--dim", type=int, required=True, help="Dimension of UNet")
|
||||||
parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
|
parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
|
||||||
parser.add_argument("--val_batch_size", type=int, default=4, help="Batch size")
|
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
|
||||||
parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
|
parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
|
||||||
parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
|
parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
|
||||||
parser.add_argument("--save_preds", action="store_true", help="Save predicted masks")
|
parser.add_argument("--save_preds", action="store_true", help="Save predicted masks")
|
||||||
|
@ -30,11 +32,12 @@ parser.add_argument("--save_preds", action="store_true", help="Save predicted ma
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
|
path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
|
||||||
cmd = f"python {path_to_main} --exec_mode evaluate --task 01 --gpus 1 "
|
cmd = f"python {path_to_main} --exec_mode predict --task {args.task} --gpus 1 "
|
||||||
|
cmd += f"--data {args.data} "
|
||||||
cmd += f"--dim {args.dim} "
|
cmd += f"--dim {args.dim} "
|
||||||
cmd += f"--fold {args.fold} "
|
cmd += f"--fold {args.fold} "
|
||||||
cmd += f"--ckpt_path {args.ckpt_path} "
|
cmd += f"--ckpt_path {args.ckpt_path} "
|
||||||
cmd += f"--val_batch_size {args.val_batch_size} "
|
cmd += f"--val_batch_size {args.batch_size} "
|
||||||
cmd += "--amp " if args.amp else ""
|
cmd += "--amp " if args.amp else ""
|
||||||
cmd += "--tta " if args.tta else ""
|
cmd += "--tta " if args.tta else ""
|
||||||
cmd += "--save_preds " if args.save_preds else ""
|
cmd += "--save_preds " if args.save_preds else ""
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
python download.py --task 01
|
|
||||||
python preprocess.py --task 01 --dim 3
|
|
||||||
python preprocess.py --task 01 --dim 2
|
|
|
@ -18,19 +18,22 @@ from os.path import dirname
|
||||||
from subprocess import call
|
from subprocess import call
|
||||||
|
|
||||||
parser = ArgumentParser(ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(ArgumentDefaultsHelpFormatter)
|
||||||
|
parser.add_argument("--task", type=str, default="01", help="Task code")
|
||||||
parser.add_argument("--gpus", type=int, required=True, help="Number of GPUs")
|
parser.add_argument("--gpus", type=int, required=True, help="Number of GPUs")
|
||||||
parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4], help="Fold number")
|
parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4], help="Fold number")
|
||||||
parser.add_argument("--dim", type=int, required=True, choices=[2, 3], help="Dimension of UNet")
|
parser.add_argument("--dim", type=int, required=True, choices=[2, 3], help="Dimension of UNet")
|
||||||
parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
|
parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
|
||||||
|
parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
|
path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
|
||||||
cmd = f"python {path_to_main} --exec_mode train --task 01 --deep_supervision --save_ckpt "
|
cmd = f"python {path_to_main} --exec_mode train --task {args.task} --deep_supervision --save_ckpt "
|
||||||
cmd += f"--dim {args.dim} "
|
cmd += f"--dim {args.dim} "
|
||||||
cmd += f"--batch_size {2 if args.dim == 3 else 16} "
|
cmd += f"--batch_size {2 if args.dim == 3 else 64} "
|
||||||
cmd += f"--val_batch_size {4 if args.dim == 3 else 64} "
|
cmd += f"--val_batch_size {4 if args.dim == 3 else 64} "
|
||||||
cmd += f"--fold {args.fold} "
|
cmd += f"--fold {args.fold} "
|
||||||
cmd += f"--gpus {args.gpus} "
|
cmd += f"--gpus {args.gpus} "
|
||||||
cmd += "--amp " if args.amp else ""
|
cmd += "--amp " if args.amp else ""
|
||||||
|
cmd += "--tta " if args.tta else ""
|
||||||
call(cmd, shell=True)
|
call(cmd, shell=True)
|
||||||
|
|
|
@ -39,13 +39,15 @@ class LoggingCallback(Callback):
|
||||||
self.step += 1
|
self.step += 1
|
||||||
if self.profile and self.step == self.warmup_steps:
|
if self.profile and self.step == self.warmup_steps:
|
||||||
profiler.start()
|
profiler.start()
|
||||||
if self.step >= self.warmup_steps:
|
if self.step > self.warmup_steps:
|
||||||
self.timestamps.append(time.time())
|
self.timestamps.append(time.time())
|
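The change from `>=` to `>` in `do_step` means that the step equal to `warmup_steps` is no longer timed. The sketch below shows the effect under stated assumptions: `StepTimer` and `mean_step_time` are hypothetical names, the optional `now` argument exists only to make the example deterministic, and the body of `process_performance_stats` is not shown in this diff, so the averaging here is an illustration rather than the callback's actual statistics.

```python
import time


class StepTimer:
    """Collects timestamps only after a number of warmup steps."""

    def __init__(self, warmup_steps):
        self.warmup_steps = warmup_steps
        self.step = 0
        self.timestamps = []

    def do_step(self, now=None):
        self.step += 1
        # Strictly greater: the first `warmup_steps` steps are never recorded.
        if self.step > self.warmup_steps:
            self.timestamps.append(time.time() if now is None else now)


def mean_step_time(timestamps):
    """Average gap between consecutive timestamps (hypothetical helper)."""
    deltas = [b - a for a, b in zip(timestamps, timestamps[1:])]
    return sum(deltas) / len(deltas)


if __name__ == "__main__":
    timer = StepTimer(warmup_steps=3)
    for i in range(10):
        timer.do_step(now=float(i))
    print(len(timer.timestamps), mean_step_time(timer.timestamps))
```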
||||||
|
|
||||||
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
|
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
|
||||||
|
if trainer.current_epoch == 1:
|
||||||
self.do_step()
|
self.do_step()
|
||||||
|
|
||||||
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
|
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
|
||||||
|
if trainer.current_epoch == 1:
|
||||||
self.do_step()
|
self.do_step()
|
||||||
|
|
||||||
def process_performance_stats(self, deltas):
|
def process_performance_stats(self, deltas):
|
||||||
|
@ -77,4 +79,5 @@ class LoggingCallback(Callback):
|
||||||
self.log()
|
self.log()
|
||||||
|
|
||||||
def on_test_end(self, trainer, pl_module):
|
def on_test_end(self, trainer, pl_module):
|
||||||
|
if trainer.current_epoch == 1:
|
||||||
self.log()
|
self.log()
|
||||||
|
|
|
@ -12,12 +12,16 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import glob
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||||
from subprocess import call
|
from subprocess import call
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from dllogger import JSONStreamBackend, Logger, StdOutBackend, Verbosity
|
||||||
|
from sklearn.model_selection import KFold
|
||||||
|
|
||||||
|
|
||||||
def is_main_process():
|
def is_main_process():
|
||||||
|
@ -42,8 +46,26 @@ def get_task_code(args):
|
||||||
|
|
||||||
def get_config_file(args):
|
def get_config_file(args):
|
||||||
task_code = get_task_code(args)
|
task_code = get_task_code(args)
|
||||||
config_file = os.path.join(args.data, task_code, "config.pkl")
|
if args.data != "/data":
|
||||||
return pickle.load(open(config_file, "rb"))
|
path = os.path.join(args.data, "config.pkl")
|
||||||
|
else:
|
||||||
|
path = os.path.join(args.data, task_code, "config.pkl")
|
||||||
|
return pickle.load(open(path, "rb"))
|
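The new `get_config_file` treats any `--data` value other than the default `/data` as a directory that already contains `config.pkl`. A minimal sketch of that path rule, assuming POSIX-style paths; `config_path` and the example task directory name are hypothetical and only illustrate the branch.

```python
import os


def config_path(data, task_code, default_root="/data"):
    """Resolve the location of config.pkl the way the new branch does.

    A custom data directory is used as-is; the default root is joined
    with the task code first.
    """
    if data != default_root:
        return os.path.join(data, "config.pkl")
    return os.path.join(data, task_code, "config.pkl")


if __name__ == "__main__":
    print(config_path("/data", "task_dir"))
    print(config_path("/mnt/custom", "task_dir"))
```

One consequence of comparing strings is that `/data/` (with a trailing slash) counts as a custom directory, so the task code would not be appended in that case.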
||||||
|
|
||||||
|
|
||||||
|
def get_dllogger(results):
|
||||||
|
return Logger(
|
||||||
|
backends=[
|
||||||
|
JSONStreamBackend(Verbosity.VERBOSE, os.path.join(results, "logs.json")),
|
||||||
|
StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: f"Epoch: {step} "),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tta_flips(dim):
|
||||||
|
if dim == 2:
|
||||||
|
return [[2], [3], [2, 3]]
|
||||||
|
return [[2], [3], [4], [2, 3], [2, 4], [3, 4], [2, 3, 4]]
|
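The hard-coded lists in `get_tta_flips` are exactly the non-empty subsets of the spatial axes, where axes 0 and 1 are assumed to be batch and channel. The sketch below generates the same lists with `itertools.combinations`; `tta_flips` is a hypothetical name used only for this illustration.

```python
from itertools import combinations


def tta_flips(dim):
    """All non-empty combinations of spatial axes for a `dim`-dimensional input.

    Spatial axes start at index 2 because index 0 is assumed to be the
    batch axis and index 1 the channel axis.
    """
    axes = list(range(2, 2 + dim))
    return [list(c) for r in range(1, dim + 1) for c in combinations(axes, r)]


if __name__ == "__main__":
    print(tta_flips(2))
    print(tta_flips(3))
```

For `dim` spatial axes this yields `2**dim - 1` flip combinations, so 3 for 2D and 7 for 3D inputs.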
||||||
|
|
||||||
|
|
||||||
def make_empty_dir(path):
|
def make_empty_dir(path):
|
||||||
|
@ -73,74 +95,151 @@ def float_0_1(value):
|
||||||
return ivalue
|
return ivalue
|
||||||
|
|
||||||
|
|
||||||
|
def get_unet_params(args):
|
||||||
|
config = get_config_file(args)
|
||||||
|
patch_size, spacings = config["patch_size"], config["spacings"]
|
||||||
|
strides, kernels, sizes = [], [], patch_size[:]
|
||||||
|
while True:
|
||||||
|
spacing_ratio = [spacing / min(spacings) for spacing in spacings]
|
||||||
|
stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
|
||||||
|
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
|
||||||
|
if all(s == 1 for s in stride):
|
||||||
|
break
|
||||||
|
sizes = [i / j for i, j in zip(sizes, stride)]
|
||||||
|
spacings = [i * j for i, j in zip(spacings, stride)]
|
||||||
|
kernels.append(kernel)
|
||||||
|
strides.append(stride)
|
||||||
|
if len(strides) == 5:
|
||||||
|
break
|
||||||
|
strides.insert(0, len(spacings) * [1])
|
||||||
|
kernels.append(len(spacings) * [3])
|
||||||
|
return config["in_channels"], config["n_class"], kernels, strides, patch_size
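To see what the loop in `get_unet_params` produces, the same logic can be run on hand-picked inputs. The sketch below repeats the loop but takes `patch_size` and `spacings` directly instead of reading `config.pkl`; the function name `kernels_and_strides` and the two sample configurations are assumptions chosen for illustration.

```python
def kernels_and_strides(patch_size, spacings):
    """Derive per-level kernel sizes and strides from patch size and voxel spacing."""
    strides, kernels, sizes = [], [], patch_size[:]
    while True:
        spacing_ratio = [spacing / min(spacings) for spacing in spacings]
        # Downsample an axis only if it is not much coarser than the finest one
        # and still has at least 8 voxels left.
        stride = [2 if ratio <= 2 and size >= 8 else 1 for ratio, size in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
        if len(strides) == 5:  # cap at five downsampling levels
            break
    strides.insert(0, len(spacings) * [1])  # input level is not downsampled
    kernels.append(len(spacings) * [3])     # bottleneck always uses 3x3(x3) kernels
    return kernels, strides


iso_kernels, iso_strides = kernels_and_strides([128, 128, 128], [1.0, 1.0, 1.0])
aniso_kernels, aniso_strides = kernels_and_strides([32, 256, 256], [5.0, 1.0, 1.0])
```

For the isotropic input every level uses kernel `[3, 3, 3]`, with one stride-1 level followed by five stride-2 levels. For the anisotropic input the coarse first axis keeps kernel 1 and stride 1 for the first two levels, until in-plane downsampling brings the spacing ratio down to 2 or below.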
|
||||||
|
|
||||||
|
|
||||||
|
def log(logname, dice, results="/results"):
|
||||||
|
dllogger = Logger(
|
||||||
|
backends=[
|
||||||
|
JSONStreamBackend(Verbosity.VERBOSE, os.path.join(results, logname)),
|
||||||
|
StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: ""),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
metrics = {}
|
||||||
|
metrics.update({"Mean dice": round(dice.mean().item(), 2)})
|
||||||
|
metrics.update({f"L{j+1}": round(m.item(), 2) for j, m in enumerate(dice)})
|
||||||
|
dllogger.log(step=(), data=metrics)
|
||||||
|
dllogger.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def layout_2d(img, lbl):
|
||||||
|
batch_size, depth, channels, height, width = img.shape
|
||||||
|
img = torch.reshape(img, (batch_size * depth, channels, height, width))
|
||||||
|
if lbl is not None:
|
||||||
|
lbl = torch.reshape(lbl, (batch_size * depth, 1, height, width))
|
||||||
|
return img, lbl
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def get_split(data, idx):
|
||||||
|
return list(np.array(data)[idx])
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(path, files_pattern):
|
||||||
|
return sorted(glob.glob(os.path.join(path, files_pattern)))
|
||||||
|
|
||||||
|
|
||||||
|
def get_path(args):
|
||||||
|
if args.data != "/data":
|
||||||
|
return args.data
|
||||||
|
data_path = os.path.join(args.data, get_task_code(args))
|
||||||
|
if args.exec_mode == "predict" and not args.benchmark:
|
||||||
|
data_path = os.path.join(data_path, "test")
|
||||||
|
return data_path
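The branching in `get_path` can be checked in isolation. In the sketch below the task code is passed in explicitly because `get_task_code` is not shown in this hunk; `resolve_data_path` and the sample task code `"01_3d"` are hypothetical stand-ins.

```python
import os
from types import SimpleNamespace


def resolve_data_path(args, task_code):
    # A non-default --data is used verbatim, with no task subdirectory appended.
    if args.data != "/data":
        return args.data
    data_path = os.path.join(args.data, task_code)
    # Plain prediction (not benchmarking) reads from the "test" subdirectory.
    if args.exec_mode == "predict" and not args.benchmark:
        data_path = os.path.join(data_path, "test")
    return data_path


custom = SimpleNamespace(data="/mnt/my_data", exec_mode="train", benchmark=False)
train = SimpleNamespace(data="/data", exec_mode="train", benchmark=False)
predict = SimpleNamespace(data="/data", exec_mode="predict", benchmark=False)
bench = SimpleNamespace(data="/data", exec_mode="predict", benchmark=True)
```

One consequence of this design is that a custom `--data` path must already point at the preprocessed task directory, which matches how `get_config_file` looks for `config.pkl` directly under `args.data` in that case.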
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_fnames(args, data_path, meta=None):
|
||||||
|
kfold = KFold(n_splits=args.nfolds, shuffle=True, random_state=12345)
|
||||||
|
test_imgs = load_data(data_path, "*_x.npy")
|
||||||
|
|
||||||
|
if args.exec_mode == "predict" and "val" in data_path:
|
||||||
|
_, val_idx = list(kfold.split(test_imgs))[args.fold]
|
||||||
|
test_imgs = sorted(get_split(test_imgs, val_idx))
|
||||||
|
if meta is not None:
|
||||||
|
meta = sorted(get_split(meta, val_idx))
|
||||||
|
|
||||||
|
return test_imgs, meta
|
||||||
|
|
||||||
|
|
||||||
def get_main_args(strings=None):
|
def get_main_args(strings=None):
|
||||||
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument(
|
arg = parser.add_argument
|
||||||
|
arg(
|
||||||
"--exec_mode",
|
"--exec_mode",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["train", "evaluate", "predict"],
|
choices=["train", "evaluate", "predict"],
|
||||||
default="train",
|
default="train",
|
||||||
help="Execution mode to run the model",
|
help="Execution mode to run the model",
|
||||||
)
|
)
|
||||||
parser.add_argument("--data", type=str, default="/data", help="Path to data directory")
|
arg("--data", type=str, default="/data", help="Path to data directory")
|
||||||
parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
|
arg("--results", type=str, default="/results", help="Path to results directory")
|
||||||
parser.add_argument("--logname", type=str, default=None, help="Name of dllogger output")
|
arg("--logname", type=str, default=None, help="Name of dllogger output")
|
||||||
parser.add_argument("--task", type=str, help="Task number. MSD uses numbers 01-10")
|
arg("--task", type=str, help="Task number. MSD uses numbers 01-10")
|
||||||
parser.add_argument("--gpus", type=non_negative_int, default=1, help="Number of gpus")
|
arg("--gpus", type=non_negative_int, default=1, help="Number of gpus")
|
||||||
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate")
|
arg("--learning_rate", type=float, default=0.001, help="Learning rate")
|
||||||
parser.add_argument("--gradient_clip_val", type=float, default=0, help="Gradient clipping norm value")
|
arg("--gradient_clip_val", type=float, default=0, help="Gradient clipping norm value")
|
||||||
parser.add_argument("--negative_slope", type=float, default=0.01, help="Negative slope for LeakyReLU")
|
arg("--negative_slope", type=float, default=0.01, help="Negative slope for LeakyReLU")
|
||||||
parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
|
arg("--tta", action="store_true", help="Enable test time augmentation")
|
||||||
parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
|
arg("--amp", action="store_true", help="Enable automatic mixed precision")
|
||||||
parser.add_argument("--benchmark", action="store_true", help="Run model benchmarking")
|
arg("--benchmark", action="store_true", help="Run model benchmarking")
|
||||||
parser.add_argument("--deep_supervision", action="store_true", help="Enable deep supervision")
|
arg("--deep_supervision", action="store_true", help="Enable deep supervision")
|
||||||
parser.add_argument("--sync_batchnorm", action="store_true", help="Enable synchronized batchnorm")
|
arg("--drop_block", action="store_true", help="Enable drop block")
|
||||||
parser.add_argument("--save_ckpt", action="store_true", help="Enable saving checkpoint")
|
arg("--attention", action="store_true", help="Enable attention in decoder")
|
||||||
parser.add_argument("--nfolds", type=positive_int, default=5, help="Number of cross-validation folds")
|
arg("--residual", action="store_true", help="Enable residual block in encoder")
|
||||||
parser.add_argument("--seed", type=non_negative_int, default=1, help="Random seed")
|
arg("--focal", action="store_true", help="Use focal loss instead of cross entropy")
|
||||||
parser.add_argument("--ckpt_path", type=str, default=None, help="Path to checkpoint")
|
arg("--sync_batchnorm", action="store_true", help="Enable synchronized batchnorm")
|
||||||
parser.add_argument("--fold", type=non_negative_int, default=0, help="Fold number")
|
arg("--save_ckpt", action="store_true", help="Enable saving checkpoint")
|
||||||
parser.add_argument("--patience", type=positive_int, default=100, help="Early stopping patience")
|
arg("--nfolds", type=positive_int, default=5, help="Number of cross-validation folds")
|
||||||
parser.add_argument("--lr_patience", type=positive_int, default=70, help="Patience for ReduceLROnPlateau scheduler")
|
arg("--seed", type=non_negative_int, default=1, help="Random seed")
|
||||||
parser.add_argument("--batch_size", type=positive_int, default=2, help="Batch size")
|
arg("--skip_first_n_eval", type=non_negative_int, default=0, help="Skip the evaluation for the first n epochs.")
|
||||||
parser.add_argument("--val_batch_size", type=positive_int, default=4, help="Validation batch size")
|
arg("--ckpt_path", type=str, default=None, help="Path to checkpoint")
|
||||||
parser.add_argument("--steps", nargs="+", type=positive_int, required=False, help="Steps for multistep scheduler")
|
arg("--fold", type=non_negative_int, default=0, help="Fold number")
|
||||||
parser.add_argument("--create_idx", action="store_true", help="Create index files for tfrecord")
|
arg("--patience", type=positive_int, default=100, help="Early stopping patience")
|
||||||
parser.add_argument("--profile", action="store_true", help="Run dlprof profiling")
|
arg("--lr_patience", type=positive_int, default=70, help="Patience for ReduceLROnPlateau scheduler")
|
||||||
parser.add_argument("--momentum", type=float, default=0.99, help="Momentum factor")
|
arg("--batch_size", type=positive_int, default=2, help="Batch size")
|
||||||
parser.add_argument("--weight_decay", type=float, default=0.0001, help="Weight decay (L2 penalty)")
|
arg("--val_batch_size", type=positive_int, default=4, help="Validation batch size")
|
||||||
parser.add_argument("--save_preds", action="store_true", help="Enable prediction saving")
|
arg("--steps", nargs="+", type=positive_int, required=False, help="Steps for multistep scheduler")
|
||||||
parser.add_argument("--dim", type=int, choices=[2, 3], default=3, help="UNet dimension")
|
arg("--profile", action="store_true", help="Run dlprof profiling")
|
||||||
parser.add_argument("--resume_training", action="store_true", help="Resume training from the last checkpoint")
|
arg("--momentum", type=float, default=0.99, help="Momentum factor")
|
||||||
parser.add_argument("--factor", type=float, default=0.3, help="Scheduler factor")
|
arg("--weight_decay", type=float, default=0.0001, help="Weight decay (L2 penalty)")
|
||||||
parser.add_argument(
|
arg("--save_preds", action="store_true", help="Enable prediction saving")
|
||||||
"--num_workers", type=non_negative_int, default=8, help="Number of subprocesses to use for data loading"
|
arg("--dim", type=int, choices=[2, 3], default=3, help="UNet dimension")
|
||||||
|
arg("--resume_training", action="store_true", help="Resume training from the last checkpoint")
|
||||||
|
arg("--factor", type=float, default=0.3, help="Scheduler factor")
|
||||||
|
arg("--num_workers", type=non_negative_int, default=8, help="Number of subprocesses to use for data loading")
|
||||||
|
arg("--min_epochs", type=non_negative_int, default=30, help="Force training for at least these many epochs")
|
||||||
|
arg("--max_epochs", type=non_negative_int, default=10000, help="Stop training after this number of epochs")
|
||||||
|
arg("--warmup", type=non_negative_int, default=5, help="Warmup iterations before collecting statistics")
|
||||||
|
arg("--norm", type=str, choices=["instance", "batch", "group"], default="instance", help="Normalization layer")
|
||||||
|
arg("--nvol", type=positive_int, default=1, help="Number of volumes which come into single batch size for 2D model")
|
||||||
|
arg(
|
||||||
|
"--data2d_dim",
|
||||||
|
choices=[2, 3],
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Input data dimension for 2d model",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
arg(
|
||||||
"--min_epochs", type=non_negative_int, default=30, help="Force training for at least these many epochs"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_epochs", type=non_negative_int, default=10000, help="Stop training after this number of epochs"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--warmup", type=non_negative_int, default=5, help="Warmup iterations before collecting statistics"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--oversampling",
|
"--oversampling",
|
||||||
type=float_0_1,
|
type=float_0_1,
|
||||||
default=0.33,
|
default=0.33,
|
||||||
help="Probability of crop to have some region with positive label",
|
help="Probability of crop to have some region with positive label",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
arg(
|
||||||
"--norm", type=str, choices=["instance", "batch", "group"], default="instance", help="Normalization layer"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--overlap",
|
"--overlap",
|
||||||
type=float_0_1,
|
type=float_0_1,
|
||||||
default=0.25,
|
default=0.5,
|
||||||
help="Amount of overlap between scans during sliding window inference",
|
help="Amount of overlap between scans during sliding window inference",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
arg(
|
||||||
"--affinity",
|
"--affinity",
|
||||||
type=str,
|
type=str,
|
||||||
default="socket_unique_interleaved",
|
default="socket_unique_interleaved",
|
||||||
|
@ -154,49 +253,41 @@ def get_main_args(strings=None):
|
||||||
],
|
],
|
||||||
help="type of CPU affinity",
|
help="type of CPU affinity",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
arg(
|
||||||
"--data2d_dim",
|
|
||||||
choices=[2, 3],
|
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="Input data dimension for 2d model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--scheduler",
|
"--scheduler",
|
||||||
type=str,
|
type=str,
|
||||||
default="none",
|
default="none",
|
||||||
choices=["none", "multistep", "cosine", "plateau"],
|
choices=["none", "multistep", "cosine", "plateau"],
|
||||||
help="Learning rate scheduler",
|
help="Learning rate scheduler",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
arg(
|
||||||
"--optimizer",
|
"--optimizer",
|
||||||
type=str,
|
type=str,
|
||||||
default="radam",
|
default="radam",
|
||||||
choices=["sgd", "adam", "adamw", "radam", "fused_adam"],
|
choices=["sgd", "radam", "adam"],
|
||||||
help="Optimizer",
|
help="Optimizer",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
arg(
|
||||||
"--val_mode",
|
"--blend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["gaussian", "constant"],
|
choices=["gaussian", "constant"],
|
||||||
default="gaussian",
|
default="gaussian",
|
||||||
help="How to blend output of overlapping windows",
|
help="How to blend output of overlapping windows",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
arg(
|
||||||
"--train_batches",
|
"--train_batches",
|
||||||
type=non_negative_int,
|
type=non_negative_int,
|
||||||
default=0,
|
default=0,
|
||||||
help="Limit number of batches for training (used for benchmarking mode only)",
|
help="Limit number of batches for training (used for benchmarking mode only)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
arg(
|
||||||
"--test_batches",
|
"--test_batches",
|
||||||
type=non_negative_int,
|
type=non_negative_int,
|
||||||
default=0,
|
default=0,
|
||||||
help="Limit number of batches for inference (used for benchmarking mode only)",
|
help="Limit number of batches for inference (used for benchmarking mode only)",
|
||||||
)
|
)
|
||||||
if strings is not None:
|
if strings is not None:
|
||||||
parser.add_argument(
|
arg(
|
||||||
"strings",
|
"strings",
|
||||||
metavar="STRING",
|
metavar="STRING",
|
||||||
nargs="*",
|
nargs="*",
|
||||||
|
|