Merge pull request #882 from michal2409/mfutrega/nnunet-fix

correct train.py script and add backward compatibility for dropblock
Committed by nv-kkudrynski on 2021-03-24 11:00:22 +01:00 via GitHub (commit 7863f0e486).
4 changed files with 16 additions and 36 deletions

README.md

@@ -507,7 +507,7 @@ The following sections provide details on how to achieve the same performance an
 ##### Training accuracy: NVIDIA DGX A100 (8x A100 80G)
 
-Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX with (8x A100 80G) GPUs.
+Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX with (8x A100 80G) GPUs.
 
 | Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - TF32 | Time to train - mixed precision | Time to train - TF32 | Time to train speedup (TF32 to mixed precision) |
 |:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
@@ -519,7 +519,7 @@ Our results were obtained by running the `python scripts/train.py --gpus {1,8} -
 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16G)
 
-Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs.
+Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs.
 
 | Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - FP32 | Time to train speedup (FP32 to mixed precision) |
 |:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
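For illustration, one concrete instantiation of the updated command is `python scripts/train.py --gpus 8 --fold 0 --dim 3 --amp`; the batch size no longer appears on the command line because `scripts/train.py` now derives it from `--dim`, as its diff further down shows.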
@@ -580,7 +580,7 @@ Our results were obtained by running the `python scripts/benchmark.py --mode pre
 FP16
 
-| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 3198.8 | 20.01 | 24.1 | 30.5 | 33.75 |
 | 2 | 128 | 4x192x160 | 3587.89 | 35.68 | 36.0 | 36.08 | 36.16 |
@@ -591,7 +591,7 @@ FP16
 TF32
 
-| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 2353.27 | 27.2 | 27.43 | 27.53 | 27.7 |
 | 2 | 128 | 4x192x160 | 2492.78 | 51.35 | 51.54 | 51.59 | 51.73 |
@@ -610,7 +610,7 @@ Our results were obtained by running the `python scripts/benchmark.py --mode pre
 FP16
 
-| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 1866.52 | 34.29 | 34.7 | 48.87 | 52.44 |
 | 2 | 128 | 4x192x160 | 2032.74 | 62.97 | 63.21 | 63.25 | 63.32 |
@@ -620,7 +620,7 @@ FP16
 FP32
 
-| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 1051.46 | 60.87 | 61.21 | 61.48 | 62.87 |
 | 2 | 128 | 4x192x160 | 1051.68 | 121.71 | 122.29 | 122.44 | 122.6 |

loss.py

@@ -12,42 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from monai.losses import FocalLoss
-
-
-class DiceLoss(nn.Module):
-    def __init__(self, include_background=False, smooth=1e-5, eps=1e-7):
-        super(DiceLoss, self).__init__()
-        self.include_background = include_background
-        self.smooth = smooth
-        self.dims = (0, 2)
-        self.eps = eps
-
-    def forward(self, y_pred, y_true):
-        num_classes, batch_size = y_pred.size(1), y_true.size(0)
-        y_pred = y_pred.log_softmax(dim=1).exp()
-        y_true, y_pred = y_true.view(batch_size, -1), y_pred.view(batch_size, num_classes, -1)
-        y_true = F.one_hot(y_true.to(torch.int64), num_classes).permute(0, 2, 1)
-        if not self.include_background:
-            y_true, y_pred = y_true[:, 1:], y_pred[:, 1:]
-        intersection = torch.sum(y_true * y_pred, dim=self.dims)
-        cardinality = torch.sum(y_true + y_pred, dim=self.dims)
-        dice_loss = 1 - (2.0 * intersection + self.smooth) / (cardinality + self.smooth).clamp_min(self.eps)
-        mask = (y_true.sum(self.dims) > 0).to(dice_loss.dtype)
-        dice_loss *= mask.to(dice_loss.dtype)
-        dice_loss = dice_loss.sum() / mask.sum()
-        return dice_loss
+from monai.losses import DiceLoss, FocalLoss
 
 
 class Loss(nn.Module):
     def __init__(self, focal):
         super(Loss, self).__init__()
-        self.dice = DiceLoss()
-        self.cross_entropy = nn.CrossEntropyLoss()
+        self.dice = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True)
         self.focal = FocalLoss(gamma=2.0)
+        self.cross_entropy = nn.CrossEntropyLoss()
         self.use_focal = focal
 
     def forward(self, y_pred, y_true):
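The hunk ends at the `forward` signature, so the new body is not visible here. As context, a minimal self-contained sketch of how this combined loss is plausibly wired, assuming the dice term is always applied and the `focal` flag switches the auxiliary term between focal loss and cross-entropy (the exact body in the repository may differ):

```python
import torch.nn as nn
from monai.losses import DiceLoss, FocalLoss


class Loss(nn.Module):
    """Dice loss plus either focal loss or cross-entropy, chosen by a flag."""

    def __init__(self, focal):
        super(Loss, self).__init__()
        self.dice = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True)
        self.focal = FocalLoss(gamma=2.0)
        self.cross_entropy = nn.CrossEntropyLoss()
        self.use_focal = focal

    def forward(self, y_pred, y_true):
        # y_pred: (N, C, ...) logits; y_true: (N, 1, ...) integer labels.
        # The channel axis on y_true is what to_onehot_y=True expects.
        loss = self.dice(y_pred, y_true)
        if self.use_focal:
            loss = loss + self.focal(y_pred, y_true)
        else:
            # CrossEntropyLoss wants class indices without the channel axis.
            loss = loss + self.cross_entropy(y_pred, y_true[:, 0].long())
        return loss
```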

nn_unet.py

@@ -42,6 +42,8 @@ class NNUnet(pl.LightningModule):
     def __init__(self, args):
         super(NNUnet, self).__init__()
         self.args = args
+        if not hasattr(self.args, "drop_block"):  # For backward compatibility
+            self.args.drop_block = False
         self.save_hyperparameters()
         self.build_nnunet()
         self.loss = Loss(self.args.focal)

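The added guard matters because `save_hyperparameters` stores `args` in the checkpoint: checkpoints written before the `drop_block` flag existed restore an `args` object without that attribute, and model construction would otherwise crash on attribute access. A standalone sketch of the pattern (the helper and names here are illustrative, not from the repository):

```python
from types import SimpleNamespace


def ensure_drop_block_default(args):
    # Older checkpoints predate the flag; default it instead of raising
    # AttributeError when the model later reads args.drop_block.
    if not hasattr(args, "drop_block"):
        args.drop_block = False
    return args


old_args = SimpleNamespace(focal=False)  # as restored from an old checkpoint
print(ensure_drop_block_default(old_args).drop_block)  # False
```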
scripts/train.py

@@ -24,11 +24,15 @@ parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4],
 parser.add_argument("--dim", type=int, required=True, choices=[2, 3], help="Dimension of UNet")
 parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
 parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
+parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
+parser.add_argument("--logname", type=str, default="log", help="Name of dllogger output")
 
 if __name__ == "__main__":
     args = parser.parse_args()
     path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
-    cmd = f"python {path_to_main} --exec_mode train --task {args.data} --deep_supervision --save_ckpt "
+    cmd = f"python {path_to_main} --exec_mode train --task {args.task} --deep_supervision --save_ckpt "
+    cmd += f"--results {args.results} "
+    cmd += f"--logname {args.logname} "
     cmd += f"--dim {args.dim} "
     cmd += f"--batch_size {2 if args.dim == 3 else 64} "
     cmd += f"--val_batch_size {4 if args.dim == 3 else 64} "
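To make the assembled command concrete, the f-strings above evaluate as follows for a 3D run (the task id and the resolved `main.py` path are illustrative placeholders; the real script computes `path_to_main` from its own location):

```python
# Illustrative stand-ins for values the script derives at runtime.
path_to_main, task, results, logname, dim = "main.py", "01", "/results", "log", 3

cmd = f"python {path_to_main} --exec_mode train --task {task} --deep_supervision --save_ckpt "
cmd += f"--results {results} "
cmd += f"--logname {logname} "
cmd += f"--dim {dim} "
cmd += f"--batch_size {2 if dim == 3 else 64} "
cmd += f"--val_batch_size {4 if dim == 3 else 64} "
print(cmd)
# python main.py --exec_mode train --task 01 --deep_supervision --save_ckpt
#   --results /results --logname log --dim 3 --batch_size 2 --val_batch_size 4
```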