#!/usr/bin/env python
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import nullcontext

import nvidia_dlprof_pytorch_nvtx
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, early_stopping

from data_loading.data_module import DataModule
from models.nn_unet import NNUnet
from utils.gpu_affinity import set_affinity
from utils.logger import LoggingCallback
from utils.utils import get_main_args, is_main_process, log, make_empty_dir, set_cuda_devices, verify_ckpt_path

if __name__ == "__main__":
    # Entry point for nnU-Net training / evaluation / inference / benchmarking.
    # The execution path is selected by args.benchmark and args.exec_mode.
    args = get_main_args()

    # Optional DLProf profiling: instruments PyTorch ops with NVTX ranges so
    # they appear in Nsight Systems / DLProf timelines.
    if args.profile:
        nvidia_dlprof_pytorch_nvtx.init()
        print("Profiling enabled")

    # Pin this process to CPUs associated with its GPU; the rank comes from
    # the LOCAL_RANK env var set by the distributed launcher ("0" fallback
    # for single-process runs).
    if args.affinity != "disabled":
        affinity = set_affinity(os.getenv("LOCAL_RANK", "0"), args.affinity)

    set_cuda_devices(args)
    seed_everything(args.seed)  # seed RNGs for reproducibility

    data_module = DataModule(args)
    data_module.prepare_data()
    data_module.setup()
    ckpt_path = verify_ckpt_path(args)

    callbacks = None
    model_ckpt = None
    if args.benchmark:
        # Benchmark mode: measure throughput; LoggingCallback writes the perf
        # stats to <results>/<logname or "perf.json">.
        model = NNUnet(args)
        batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size
        log_dir = os.path.join(args.results, args.logname if args.logname is not None else "perf.json")
        callbacks = [
            LoggingCallback(
                log_dir=log_dir,
                global_batch_size=batch_size * args.gpus,
                mode=args.exec_mode,
                warmup=args.warmup,
                dim=args.dim,
                profile=args.profile,
            )
        ]
    elif args.exec_mode == "train":
        model = NNUnet(args)
        # FIX: renamed the local from `early_stopping`, which shadowed the
        # `pytorch_lightning.callbacks.early_stopping` module imported above.
        early_stopping_cb = EarlyStopping(monitor="dice_mean", patience=args.patience, verbose=True, mode="max")
        callbacks = [early_stopping_cb]
        if args.save_ckpt:
            model_ckpt = ModelCheckpoint(
                filename="{epoch}-{dice_mean:.2f}", monitor="dice_mean", mode="max", save_last=True
            )
            callbacks.append(model_ckpt)
    else:  # Evaluation or inference: restore weights when a checkpoint is given.
        if ckpt_path is not None:
            model = NNUnet.load_from_checkpoint(ckpt_path)
        else:
            model = NNUnet(args)

    trainer = Trainer(
        logger=False,
        gpus=args.gpus,
        precision=16 if args.amp else 32,  # mixed precision when AMP requested
        benchmark=True,  # enable cudnn autotuner (input shapes are fixed)
        deterministic=False,
        min_epochs=args.epochs,
        max_epochs=args.epochs,
        sync_batchnorm=args.sync_batchnorm,
        gradient_clip_val=args.gradient_clip_val,
        callbacks=callbacks,
        num_sanity_val_steps=0,
        default_root_dir=args.results,
        resume_from_checkpoint=ckpt_path,
        accelerator="ddp" if args.gpus > 1 else None,
        # A value of 0 means "use all batches" (Lightning reads 1.0 as 100%).
        limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
        limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
        limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
    )

    if args.benchmark:
        if args.exec_mode == "train":
            # emit_nvtx() annotates autograd ops for the profiler; use a no-op
            # context when profiling is off so the fit call is written once
            # instead of duplicated across both branches.
            profiler_ctx = torch.autograd.profiler.emit_nvtx() if args.profile else nullcontext()
            with profiler_ctx:
                trainer.fit(model, train_dataloader=data_module.train_dataloader())
        else:
            # warmup pass (not measured)
            trainer.test(model, test_dataloaders=data_module.test_dataloader())
            # benchmark run; current_epoch distinguishes it from the warmup
            # pass — presumably consumed by LoggingCallback (verify there)
            trainer.current_epoch = 1
            trainer.test(model, test_dataloaders=data_module.test_dataloader())
    elif args.exec_mode == "train":
        trainer.fit(model, data_module)
        # Only rank 0 writes the log to avoid clobbering under DDP.
        if is_main_process():
            logname = args.logname if args.logname is not None else "train_log.json"
            log(logname, torch.tensor(model.best_mean_dice), results=args.results)
    elif args.exec_mode == "evaluate":
        model.args = args
        trainer.test(model, test_dataloaders=data_module.val_dataloader())
        if is_main_process():
            logname = args.logname if args.logname is not None else "eval_log.json"
            log(logname, model.eval_dice, results=args.results)
    elif args.exec_mode == "predict":
        if args.save_preds:
            # Name the output directory after the checkpoint file (extension
            # stripped) plus task/fold, e.g. "predictions_best_task=01_fold=0".
            ckpt_name = "_".join(args.ckpt_path.split("/")[-1].split(".")[:-1])
            dir_name = f"predictions_{ckpt_name}"
            dir_name += f"_task={model.args.task}_fold={model.args.fold}"
            if args.tta:
                dir_name += "_tta"  # test-time-augmentation marker
            save_dir = os.path.join(args.results, dir_name)
            model.save_dir = save_dir
            make_empty_dir(save_dir)
        model.args = args
        trainer.test(model, test_dataloaders=data_module.test_dataloader())