299 lines
11 KiB
Python
Executable file
299 lines
11 KiB
Python
Executable file
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import glob
|
|
import os
|
|
import pickle
|
|
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
|
from subprocess import call
|
|
|
|
import numpy as np
|
|
import torch
|
|
from dllogger import JSONStreamBackend, Logger, StdOutBackend, Verbosity
|
|
from sklearn.model_selection import KFold
|
|
|
|
|
|
def is_main_process():
    """Return True when the ``LOCAL_RANK`` environment variable is 0 or unset."""
    local_rank = os.environ.get("LOCAL_RANK", "0")
    return int(local_rank) == 0
|
|
|
|
|
|
def set_cuda_devices(args):
    """Expose the first ``args.gpus`` CUDA devices unless the user already chose some.

    Asserts that no more GPUs are requested than torch reports, then sets
    ``CUDA_VISIBLE_DEVICES`` to ``"0,1,...,gpus-1"`` only when the variable is
    not already present in the environment.
    """
    available = torch.cuda.device_count()
    assert args.gpus <= available, f"Requested {args.gpus} gpus, available {available}."
    device_list = ",".join(map(str, range(args.gpus)))
    # An existing CUDA_VISIBLE_DEVICES value always wins over the computed list.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", device_list)
|
|
|
|
|
|
def verify_ckpt_path(args):
    """Pick the checkpoint path to load.

    Returns ``<args.results>/checkpoints/last.ckpt`` when ``args.resume_training``
    is set and that file exists; otherwise returns ``args.ckpt_path`` unchanged
    (which may be ``None``).
    """
    last_ckpt = os.path.join(args.results, "checkpoints", "last.ckpt")
    if args.resume_training and os.path.exists(last_ckpt):
        return last_ckpt
    return args.ckpt_path
|
|
|
|
|
|
def get_task_code(args):
    """Build the ``<task>_<dim>d`` identifier from ``args.task`` and ``args.dim``."""
    return "{}_{}d".format(args.task, args.dim)
|
|
|
|
|
|
def get_config_file(args):
    """Load the pickled ``config.pkl`` for the selected data directory.

    When ``args.data`` is the default ``/data`` root, the config is read from
    the ``<task>_<dim>d`` subdirectory of it; any other ``args.data`` is
    treated as the directory that directly contains ``config.pkl``.

    Args:
        args: Namespace with ``data`` (and ``task``/``dim`` when the default
            data root is used).

    Returns:
        The object unpickled from ``config.pkl``.

    Raises:
        OSError: If the config file cannot be opened.
    """
    if args.data != "/data":
        path = os.path.join(args.data, "config.pkl")
    else:
        # The task code is only needed to locate the config under the default root.
        path = os.path.join(args.data, get_task_code(args), "config.pkl")
    # NOTE: unpickling can execute arbitrary code; only point this at trusted data.
    # The context manager closes the handle, which the bare open() call leaked.
    with open(path, "rb") as config_file:
        return pickle.load(config_file)
|
|
|
|
|
|
def get_dllogger(results):
    """Create a dllogger ``Logger`` writing to ``<results>/logs.json`` and to stdout.

    Both backends use ``Verbosity.VERBOSE``; the stdout backend prefixes each
    step with ``"Epoch: <step> "``.
    """
    json_backend = JSONStreamBackend(Verbosity.VERBOSE, os.path.join(results, "logs.json"))
    stdout_backend = StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: f"Epoch: {step} ")
    return Logger(backends=[json_backend, stdout_backend])
|
|
|
|
|
|
def get_tta_flips(dim):
    """Return the axis combinations to flip for test-time augmentation.

    For ``dim == 2`` every non-empty combination of axes 2 and 3 is returned;
    for any other ``dim`` every non-empty combination of axes 2, 3 and 4.
    (Presumably axes 0 and 1 are batch and channel - confirm against callers.)
    """
    flips_2d = [[2], [3], [2, 3]]
    flips_3d = [[2], [3], [4], [2, 3], [2, 4], [3, 4], [2, 3, 4]]
    return flips_2d if dim == 2 else flips_3d
|
|
|
|
|
|
def make_empty_dir(path):
    """Remove whatever exists at ``path`` and create a fresh empty directory there.

    Uses the standard library instead of spawning ``rm -rf`` so it also works
    where no ``rm`` binary is available. Removal stays best-effort, as with
    ``rm -rf``: if something survives at ``path``, the final ``os.makedirs``
    raises ``FileExistsError``.

    Args:
        path: Directory to (re)create.

    Raises:
        FileExistsError: If ``path`` could not be fully removed.
    """
    import shutil  # local import: the module does not import shutil at top level

    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path, ignore_errors=True)
    elif os.path.lexists(path):
        # Plain file or symlink: unlink it, tolerating a concurrent removal.
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
    os.makedirs(path)
|
|
|
|
|
|
def flip(data, axis):
    """Reverse ``data`` along the dimensions listed in ``axis`` via ``torch.flip``."""
    flipped = torch.flip(data, axis)
    return flipped
|
|
|
|
|
|
def positive_int(value):
    """Argparse ``type`` converter that accepts only integers greater than zero.

    Args:
        value: Raw command-line string (anything ``int()`` accepts).

    Returns:
        The parsed integer.

    Raises:
        ValueError: If ``value`` is not an integer literal (raised by ``int``).
        argparse.ArgumentTypeError: If the integer is not strictly positive.
    """
    from argparse import ArgumentTypeError  # not among the module's top-level imports

    ivalue = int(value)
    # Raise instead of assert: asserts vanish under -O, and argparse only turns
    # ArgumentTypeError/TypeError/ValueError into a clean usage error.
    if ivalue <= 0:
        raise ArgumentTypeError(f"Expected positive integer, but got {value}")
    return ivalue
|
|
|
|
|
|
def non_negative_int(value):
    """Argparse ``type`` converter that accepts only integers greater than or equal to zero.

    Args:
        value: Raw command-line string (anything ``int()`` accepts).

    Returns:
        The parsed integer.

    Raises:
        ValueError: If ``value`` is not an integer literal (raised by ``int``).
        argparse.ArgumentTypeError: If the integer is negative.
    """
    from argparse import ArgumentTypeError  # not among the module's top-level imports

    ivalue = int(value)
    # Raise instead of assert: asserts vanish under -O, and argparse only turns
    # ArgumentTypeError/TypeError/ValueError into a clean usage error.
    if ivalue < 0:
        # Zero is allowed, so the message says "non-negative", not "positive".
        raise ArgumentTypeError(f"Expected non-negative integer, but got {value}")
    return ivalue
|
|
|
|
|
|
def float_0_1(value):
    """Argparse ``type`` converter that accepts only floats in the closed range [0, 1].

    Args:
        value: Raw command-line string (anything ``float()`` accepts).

    Returns:
        The parsed float.

    Raises:
        ValueError: If ``value`` is not a float literal (raised by ``float``).
        argparse.ArgumentTypeError: If the float lies outside [0, 1] (NaN included).
    """
    from argparse import ArgumentTypeError  # not among the module's top-level imports

    fvalue = float(value)
    # Raise instead of assert: asserts vanish under -O, and argparse only turns
    # ArgumentTypeError/TypeError/ValueError into a clean usage error.
    if not 0 <= fvalue <= 1:
        # Both endpoints are accepted, so report the closed interval.
        raise ArgumentTypeError(f"Expected float to be in range [0, 1], but got {value}")
    return fvalue
|
|
|
|
|
|
def get_unet_params(args):
    """Derive per-level kernel sizes and strides from the dataset config.

    Reads ``patch_size`` and ``spacings`` from the config and repeatedly
    builds one (kernel, stride) pair per level:

    * an axis gets stride 2 when its spacing is at most twice the current
      minimum spacing and its current size is at least ``2 * args.min_fmap``,
      otherwise stride 1;
    * an axis gets kernel 3 when its spacing is at most twice the minimum
      spacing, otherwise kernel 1.

    The loop stops when no axis can be strided any more or after 6 levels.
    An all-ones stride is then prepended and an all-threes kernel appended.

    Returns:
        Tuple ``(in_channels, n_class, kernels, strides, patch_size)``.
    """
    config = get_config_file(args)
    patch_size = config["patch_size"]
    spacings = config["spacings"]
    sizes = patch_size[:]
    strides, kernels = [], []

    while len(strides) < 6:
        finest = min(spacings)
        ratios = [spacing / finest for spacing in spacings]
        stride = [
            2 if ratio <= 2 and size >= 2 * args.min_fmap else 1
            for ratio, size in zip(ratios, sizes)
        ]
        if all(step == 1 for step in stride):
            break
        kernel = [3 if ratio <= 2 else 1 for ratio in ratios]
        # Track the sizes and spacings that result from applying this level's stride.
        sizes = [size / step for size, step in zip(sizes, stride)]
        spacings = [spacing * step for spacing, step in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)

    n_axes = len(spacings)
    strides.insert(0, n_axes * [1])
    kernels.append(n_axes * [3])
    return config["in_channels"], config["n_class"], kernels, strides, patch_size
|
|
|
|
|
|
def log(logname, dice, results="/results"):
    """Write the mean and per-entry dice scores to ``<results>/<logname>`` and stdout.

    ``dice`` must support ``.mean().item()`` and iteration over elements that
    support ``.item()`` (presumably a 1-D tensor of per-class scores - confirm
    with callers). Values are rounded to two decimals and logged under the
    keys ``"Mean dice"``, ``"L1"``, ``"L2"``, ...
    """
    json_backend = JSONStreamBackend(Verbosity.VERBOSE, os.path.join(results, logname))
    stdout_backend = StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: "")
    dllogger = Logger(backends=[json_backend, stdout_backend])

    metrics = {"Mean dice": round(dice.mean().item(), 2)}
    for index, score in enumerate(dice, start=1):
        metrics[f"L{index}"] = round(score.item(), 2)

    dllogger.log(step=(), data=metrics)
    dllogger.flush()
|
|
|
|
|
|
def layout_2d(img, lbl):
    """Merge the first two axes of a 5-D image (and optional label) tensor.

    ``img.shape`` is unpacked as ``(batch_size, depth, channels, height, width)``
    and the image is reshaped to ``(batch_size * depth, channels, height, width)``.
    When ``lbl`` is given it is reshaped to ``(batch_size * depth, 1, height, width)``.

    Returns:
        ``(img, lbl)`` when ``lbl`` is not ``None``, otherwise just ``img``.
    """
    batch_size, depth, channels, height, width = img.shape
    n_slices = batch_size * depth
    img = torch.reshape(img, (n_slices, channels, height, width))
    if lbl is None:
        return img
    lbl = torch.reshape(lbl, (n_slices, 1, height, width))
    return img, lbl
|
|
|
|
|
|
def get_split(data, idx):
    """Return the elements of ``data`` selected by ``idx`` as a list.

    ``data`` is converted to a NumPy array, indexed with ``idx``, and the
    result is turned back into a list (elements remain NumPy scalars).
    """
    as_array = np.array(data)
    selected = as_array[idx]
    return list(selected)
|
|
|
|
|
|
def load_data(path, files_pattern):
    """Return the sorted list of paths under ``path`` matching the glob ``files_pattern``."""
    matches = glob.glob(os.path.join(path, files_pattern))
    matches.sort()
    return matches
|
|
|
|
|
|
def get_path(args):
    """Resolve the directory that holds the data for this run.

    A non-default ``args.data`` is returned as-is. With the default ``/data``
    root the result is ``/data/<task>_<dim>d``, with a trailing ``test``
    component when ``args.exec_mode`` is ``"predict"`` and ``args.benchmark``
    is falsy.
    """
    if args.data != "/data":
        return args.data
    components = [args.data, get_task_code(args)]
    if args.exec_mode == "predict" and not args.benchmark:
        components.append("test")
    return os.path.join(*components)
|
|
|
|
|
|
def get_test_fnames(args, data_path, meta=None):
    """List the ``*_x.npy`` files to run inference on, plus matching metadata.

    All files matching ``*_x.npy`` under ``data_path`` are returned, unless
    ``args.exec_mode`` is ``"predict"`` and ``data_path`` contains the
    substring ``"val"``. In that case only the validation part of fold
    ``args.fold`` is kept, using a shuffled ``KFold`` with ``args.nfolds``
    splits and a fixed ``random_state`` of 12345; ``meta`` (when given) is
    filtered with the same indices. Both filtered results are sorted.

    Returns:
        Tuple ``(test_imgs, meta)``.
    """
    splitter = KFold(n_splits=args.nfolds, shuffle=True, random_state=12345)
    images = load_data(data_path, "*_x.npy")

    restrict_to_fold = args.exec_mode == "predict" and "val" in data_path
    if not restrict_to_fold:
        return images, meta

    folds = list(splitter.split(images))
    val_idx = folds[args.fold][1]
    images = sorted(get_split(images, val_idx))
    if meta is not None:
        meta = sorted(get_split(meta, val_idx))
    return images, meta
|
|
|
|
|
|
def get_main_args(strings=None):
    """Build the command-line parser and parse the arguments.

    Args:
        strings: Optional whitespace-separated argument string. When given, an
            extra positional ``strings`` argument (``nargs="*"``) is registered
            and ``strings.split()`` is parsed instead of ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding all parsed options.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    add = parser.add_argument

    add(
        "--exec_mode",
        type=str,
        default="train",
        choices=["train", "evaluate", "predict"],
        help="Execution mode to run the model",
    )
    add("--data", type=str, default="/data", help="Path to data directory")
    add("--results", type=str, default="/results", help="Path to results directory")
    add("--logname", type=str, default=None, help="Name of dlloger output")
    add("--task", type=str, help="Task number. MSD uses numbers 01-10")
    add("--gpus", type=non_negative_int, default=1, help="Number of gpus")
    add("--learning_rate", type=float, default=0.0008, help="Learning rate")
    add("--gradient_clip_val", type=float, default=0, help="Gradient clipping norm value")
    add("--negative_slope", type=float, default=0.01, help="Negative slope for LeakyReLU")

    # Consecutive boolean switches, registered in their original order.
    switches = (
        ("--tta", "Enable test time augmentation"),
        ("--brats", "Enable BraTS specific training and inference"),
        ("--deep_supervision", "Enable deep supervision"),
        ("--more_chn", "Create encoder with more channels"),
        ("--amp", "Enable automatic mixed precision"),
        ("--benchmark", "Run model benchmarking"),
        ("--focal", "Use focal loss instead of cross entropy"),
        ("--sync_batchnorm", "Enable synchronized batchnorm"),
        ("--save_ckpt", "Enable saving checkpoint"),
    )
    for flag, description in switches:
        add(flag, action="store_true", help=description)

    add("--nfolds", type=positive_int, default=5, help="Number of cross-validation folds")
    add("--seed", type=non_negative_int, default=1, help="Random seed")
    add("--skip_first_n_eval", type=non_negative_int, default=0, help="Skip the evaluation for the first n epochs.")
    add("--ckpt_path", type=str, default=None, help="Path to checkpoint")
    add("--fold", type=non_negative_int, default=0, help="Fold number")
    add("--patience", type=positive_int, default=100, help="Early stopping patience")
    add("--batch_size", type=positive_int, default=2, help="Batch size")
    add("--val_batch_size", type=positive_int, default=4, help="Validation batch size")
    add("--profile", action="store_true", help="Run dlprof profiling")
    add("--momentum", type=float, default=0.99, help="Momentum factor")
    add("--weight_decay", type=float, default=0.0001, help="Weight decay (L2 penalty)")
    add("--save_preds", action="store_true", help="Enable prediction saving")
    add("--dim", type=int, choices=[2, 3], default=3, help="UNet dimension")
    add("--resume_training", action="store_true", help="Resume training from the last checkpoint")
    add("--num_workers", type=non_negative_int, default=8, help="Number of subprocesses to use for data loading")
    add("--epochs", type=non_negative_int, default=1000, help="Number of training epochs")
    add("--warmup", type=non_negative_int, default=5, help="Warmup iterations before collecting statistics")
    add("--norm", type=str, choices=["instance", "batch", "group"], default="instance", help="Normalization layer")
    add("--nvol", type=positive_int, default=1, help="Number of volumes which come into single batch size for 2D model")
    add("--data2d_dim", type=int, choices=[2, 3], default=3, help="Input data dimension for 2d model")
    add(
        "--oversampling",
        type=float_0_1,
        default=0.33,
        help="Probability of crop to have some region with positive label",
    )
    add(
        "--overlap",
        type=float_0_1,
        default=0.5,
        help="Amount of overlap between scans during sliding window inference",
    )
    add(
        "--affinity",
        type=str,
        default="socket_unique_interleaved",
        choices=[
            "socket",
            "single",
            "single_unique",
            "socket_unique_interleaved",
            "socket_unique_continuous",
            "disabled",
        ],
        help="type of CPU affinity",
    )
    add("--scheduler", action="store_true", help="Enable cosine rate scheduler with warmup")
    add("--optimizer", type=str, default="adam", choices=["sgd", "adam"], help="Optimizer")
    add(
        "--min_fmap",
        type=non_negative_int,
        default=4,
        help="The minimal size that feature map can be reduced in bottleneck",
    )
    add(
        "--blend",
        type=str,
        default="gaussian",
        choices=["gaussian", "constant"],
        help="How to blend output of overlapping windows",
    )
    add(
        "--train_batches",
        type=non_negative_int,
        default=0,
        help="Limit number of batches for training (used for benchmarking mode only)",
    )
    add(
        "--test_batches",
        type=non_negative_int,
        default=0,
        help="Limit number of batches for inference (used for benchmarking mode only)",
    )

    if strings is None:
        return parser.parse_args()

    # Programmatic use: accept free positional tokens and parse the given string.
    add("strings", metavar="STRING", nargs="*", help="String for searching")
    return parser.parse_args(strings.split())
|