diff --git a/PyTorch/Segmentation/nnUNet/README.md b/PyTorch/Segmentation/nnUNet/README.md index d1372607..fad2f87f 100755 --- a/PyTorch/Segmentation/nnUNet/README.md +++ b/PyTorch/Segmentation/nnUNet/README.md @@ -507,7 +507,7 @@ The following sections provide details on how to achieve the same performance an ##### Training accuracy: NVIDIA DGX A100 (8x A100 80G) -Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX with (8x A100 80G) GPUs. +Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX with (8x A100 80G) GPUs. | Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - TF32| Time to train speedup (TF32 to mixed precision) |:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:| @@ -519,7 +519,7 @@ Our results were obtained by running the `python scripts/train.py --gpus {1,8} - ##### Training accuracy: NVIDIA DGX-1 (8x V100 16G) -Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs. +Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs. 
| Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - FP32 | Time to train speedup (FP32 to mixed precision) |:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:| @@ -580,7 +580,7 @@ Our results were obtained by running the `python scripts/benchmark.py --mode pre FP16 -| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] | +| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] | |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:| | 2 | 64 | 4x192x160 | 3198.8 | 20.01 | 24.1 | 30.5 | 33.75 | | 2 | 128 | 4x192x160 | 3587.89 | 35.68 | 36.0 | 36.08 | 36.16 | @@ -591,7 +591,7 @@ FP16 TF32 -| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] | +| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] | |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:| | 2 | 64 | 4x192x160 | 2353.27 | 27.2 | 27.43 | 27.53 | 27.7 | | 2 | 128 | 4x192x160 | 2492.78 | 51.35 | 51.54 | 51.59 | 51.73 | @@ -610,7 +610,7 @@ Our results were obtained by running the `python scripts/benchmark.py --mode pre FP16 -| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] | +| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] | 
|:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:| | 2 | 64 | 4x192x160 | 1866.52 | 34.29 | 34.7 | 48.87 | 52.44 | | 2 | 128 | 4x192x160 | 2032.74 | 62.97 | 63.21 | 63.25 | 63.32 | @@ -620,7 +620,7 @@ FP16 FP32 -| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] | +| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] | |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:| | 2 | 64 | 4x192x160 | 1051.46 | 60.87 | 61.21 | 61.48 | 62.87 | | 2 | 128 | 4x192x160 | 1051.68 | 121.71 | 122.29 | 122.44 | 122.6 | diff --git a/PyTorch/Segmentation/nnUNet/models/loss.py b/PyTorch/Segmentation/nnUNet/models/loss.py index 1b0eca41..88ebe167 100644 --- a/PyTorch/Segmentation/nnUNet/models/loss.py +++ b/PyTorch/Segmentation/nnUNet/models/loss.py @@ -12,42 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch import torch.nn as nn -import torch.nn.functional as F -from monai.losses import FocalLoss - - -class DiceLoss(nn.Module): - def __init__(self, include_background=False, smooth=1e-5, eps=1e-7): - super(DiceLoss, self).__init__() - self.include_background = include_background - self.smooth = smooth - self.dims = (0, 2) - self.eps = eps - - def forward(self, y_pred, y_true): - num_classes, batch_size = y_pred.size(1), y_true.size(0) - y_pred = y_pred.log_softmax(dim=1).exp() - y_true, y_pred = y_true.view(batch_size, -1), y_pred.view(batch_size, num_classes, -1) - y_true = F.one_hot(y_true.to(torch.int64), num_classes).permute(0, 2, 1) - if not self.include_background: - y_true, y_pred = y_true[:, 1:], y_pred[:, 1:] - intersection = torch.sum(y_true * y_pred, dim=self.dims) - cardinality = torch.sum(y_true + y_pred, dim=self.dims) - dice_loss = 1 - (2.0 * intersection + self.smooth) / (cardinality + self.smooth).clamp_min(self.eps) - mask = (y_true.sum(self.dims) > 0).to(dice_loss.dtype) - dice_loss *= mask.to(dice_loss.dtype) - dice_loss = dice_loss.sum() / mask.sum() - return dice_loss +from monai.losses import DiceLoss, FocalLoss class Loss(nn.Module): def __init__(self, focal): super(Loss, self).__init__() - self.dice = DiceLoss() - self.cross_entropy = nn.CrossEntropyLoss() + self.dice = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True) self.focal = FocalLoss(gamma=2.0) + self.cross_entropy = nn.CrossEntropyLoss() self.use_focal = focal def forward(self, y_pred, y_true): diff --git a/PyTorch/Segmentation/nnUNet/models/nn_unet.py b/PyTorch/Segmentation/nnUNet/models/nn_unet.py index 34f9ba3c..20e8850d 100644 --- a/PyTorch/Segmentation/nnUNet/models/nn_unet.py +++ b/PyTorch/Segmentation/nnUNet/models/nn_unet.py @@ -42,6 +42,8 @@ class NNUnet(pl.LightningModule): def __init__(self, args): super(NNUnet, self).__init__() self.args = args + if not hasattr(self.args, "drop_block"): # For backward compatibility + 
self.args.drop_block = False self.save_hyperparameters() self.build_nnunet() self.loss = Loss(self.args.focal) diff --git a/PyTorch/Segmentation/nnUNet/scripts/train.py b/PyTorch/Segmentation/nnUNet/scripts/train.py index 7945a426..6d8431cb 100644 --- a/PyTorch/Segmentation/nnUNet/scripts/train.py +++ b/PyTorch/Segmentation/nnUNet/scripts/train.py @@ -24,11 +24,15 @@ parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4], parser.add_argument("--dim", type=int, required=True, choices=[2, 3], help="Dimension of UNet") parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision") parser.add_argument("--tta", action="store_true", help="Enable test time augmentation") +parser.add_argument("--results", type=str, default="/results", help="Path to results directory") +parser.add_argument("--logname", type=str, default="log", help="Name of dlloger output") if __name__ == "__main__": args = parser.parse_args() path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py") - cmd = f"python {path_to_main} --exec_mode train --task {args.data} --deep_supervision --save_ckpt " + cmd = f"python {path_to_main} --exec_mode train --task {args.task} --deep_supervision --save_ckpt " + cmd += f"--results {args.results} " + cmd += f"--logname {args.logname} " cmd += f"--dim {args.dim} " cmd += f"--batch_size {2 if args.dim == 3 else 64} " cmd += f"--val_batch_size {4 if args.dim == 3 else 64} "