# DeepLearningExamples/PyTorch/Segmentation/nnUNet/main.py
# (132 lines, 5.2 KiB, Python, executable file)

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import nvidia_dlprof_pytorch_nvtx
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, early_stopping
from data_loading.data_module import DataModule
from models.nn_unet import NNUnet
from utils.gpu_affinity import set_affinity
from utils.logger import LoggingCallback
from utils.utils import get_main_args, is_main_process, log, make_empty_dir, set_cuda_devices, verify_ckpt_path
def get_prediction_dir_name(ckpt_path, task, fold, tta):
    """Build the name of the directory that receives saved predictions.

    The name has the form ``predictions_<ckpt>_task=<task>_fold=<fold>``,
    followed by ``_tta`` when ``tta`` is truthy. ``<ckpt>`` is the last
    "/"-separated component of ``ckpt_path`` with its final dot-separated
    part dropped and the remaining parts joined with underscores.

    Args:
        ckpt_path: Path of the checkpoint file, as a string.
        task: Task identifier embedded in the name.
        fold: Fold identifier embedded in the name.
        tta: When truthy, the ``_tta`` suffix is appended.

    Returns:
        The directory name (not a full path).

    Raises:
        ValueError: If ``ckpt_path`` is None.
    """
    if ckpt_path is None:
        # Previously this surfaced as an AttributeError on ``None.split``.
        raise ValueError("args.ckpt_path must be set when args.save_preds is enabled in predict mode.")
    ckpt_name = "_".join(ckpt_path.split("/")[-1].split(".")[:-1])
    dir_name = f"predictions_{ckpt_name}_task={task}_fold={fold}"
    if tta:
        dir_name += "_tta"
    return dir_name


def main():
    """Parse the command line and run the selected execution mode.

    The steps are: read arguments, optionally enable DLProf/NVTX profiling and
    GPU affinity, seed everything, prepare the data module, build the model
    and its callbacks, construct the Lightning ``Trainer``, and finally
    dispatch on ``args.benchmark`` / ``args.exec_mode`` ("train", "evaluate"
    or "predict").
    """
    args = get_main_args()

    if args.profile:
        nvidia_dlprof_pytorch_nvtx.init()
        print("Profiling enabled")

    if args.affinity != "disabled":
        # The return value is not used anywhere in this script.
        set_affinity(os.getenv("LOCAL_RANK", "0"), args.affinity)

    set_cuda_devices(args)
    seed_everything(args.seed)
    data_module = DataModule(args)
    data_module.prepare_data()
    data_module.setup()
    ckpt_path = verify_ckpt_path(args)

    # Stays None for evaluation/inference, where no callbacks are registered.
    callbacks = None
    if args.benchmark:
        model = NNUnet(args)
        batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size
        log_dir = os.path.join(args.results, args.logname if args.logname is not None else "perf.json")
        callbacks = [
            LoggingCallback(
                log_dir=log_dir,
                global_batch_size=batch_size * args.gpus,
                mode=args.exec_mode,
                warmup=args.warmup,
                dim=args.dim,
                profile=args.profile,
            )
        ]
    elif args.exec_mode == "train":
        model = NNUnet(args)
        # Named so that it does not shadow the imported ``early_stopping`` module.
        early_stopping_cb = EarlyStopping(monitor="dice_mean", patience=args.patience, verbose=True, mode="max")
        callbacks = [early_stopping_cb]
        if args.save_ckpt:
            model_ckpt = ModelCheckpoint(
                filename="{epoch}-{dice_mean:.2f}", monitor="dice_mean", mode="max", save_last=True
            )
            callbacks.append(model_ckpt)
    else:  # Evaluation or inference
        if ckpt_path is not None:
            model = NNUnet.load_from_checkpoint(ckpt_path)
        else:
            model = NNUnet(args)

    trainer = Trainer(
        logger=False,
        gpus=args.gpus,
        precision=16 if args.amp else 32,
        benchmark=True,
        deterministic=False,
        min_epochs=args.epochs,
        max_epochs=args.epochs,
        sync_batchnorm=args.sync_batchnorm,
        gradient_clip_val=args.gradient_clip_val,
        callbacks=callbacks,
        num_sanity_val_steps=0,
        default_root_dir=args.results,
        resume_from_checkpoint=ckpt_path,
        accelerator="ddp" if args.gpus > 1 else None,
        # A value of 0 means "no limit": use the whole loader (1.0).
        limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
        limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
        limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
    )

    if args.benchmark:
        if args.exec_mode == "train":
            if args.profile:
                with torch.autograd.profiler.emit_nvtx():
                    trainer.fit(model, train_dataloader=data_module.train_dataloader())
            else:
                trainer.fit(model, train_dataloader=data_module.train_dataloader())
        else:
            # warmup
            trainer.test(model, test_dataloaders=data_module.test_dataloader())
            # benchmark run
            # NOTE(review): presumably the epoch counter lets the logging
            # callback tell the warmup pass from the measured one - confirm.
            trainer.current_epoch = 1
            trainer.test(model, test_dataloaders=data_module.test_dataloader())
    elif args.exec_mode == "train":
        trainer.fit(model, data_module)
        if is_main_process():
            logname = args.logname if args.logname is not None else "train_log.json"
            log(logname, torch.tensor(model.best_mean_dice), results=args.results)
    elif args.exec_mode == "evaluate":
        model.args = args
        trainer.test(model, test_dataloaders=data_module.val_dataloader())
        if is_main_process():
            logname = args.logname if args.logname is not None else "eval_log.json"
            log(logname, model.eval_dice, results=args.results)
    elif args.exec_mode == "predict":
        if args.save_preds:
            # task/fold are read from model.args *before* it is replaced by the
            # current command-line args below; the order matters.
            dir_name = get_prediction_dir_name(args.ckpt_path, model.args.task, model.args.fold, args.tta)
            save_dir = os.path.join(args.results, dir_name)
            model.save_dir = save_dir
            make_empty_dir(save_dir)
        model.args = args
        trainer.test(model, test_dataloaders=data_module.test_dataloader())


if __name__ == "__main__":
    main()