[nnUNet/PyT] Add Jupyter notebook with BraTS21 solution

This commit is contained in:
Michal Futrega 2021-10-20 10:03:05 -07:00 committed by Krzysztof Kudrynski
parent 8c98f155e0
commit 028534a5b9
20 changed files with 1064 additions and 234 deletions

View file

@ -1,30 +1,24 @@
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.02-py3 ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.02-py3
FROM ${FROM_IMAGE_NAME} FROM ${FROM_IMAGE_NAME}
ADD . /workspace/nnunet_pyt ADD ./triton/requirements.txt /
WORKDIR /workspace/nnunet_pyt RUN pip install --disable-pip-version-check -r /requirements.txt
RUN apt-get update && apt-get install -y libb64-dev libb64-0d
ADD ./requirements.txt /
RUN pip install --upgrade pip RUN pip install --upgrade pip
RUN pip install --disable-pip-version-check -r requirements.txt RUN pip install --disable-pip-version-check -r /requirements.txt
RUN pip install --disable-pip-version-check -r triton/requirements.txt RUN pip install monai==0.7.0 --no-dependencies
RUN pip install pytorch-lightning==1.0.0 --no-dependencies RUN pip install numpy --upgrade
RUN pip install torchtext==0.6.0 --no-dependencies
RUN pip install monai==0.4.0 --no-dependencies
RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==0.30.0
RUN pip install torch_optimizer==0.0.1a15 --no-dependencies
RUN pip install numpy==1.20.3
RUN pip install nvidia-pyindex==1.0.9 RUN pip install nvidia-pyindex==1.0.9
RUN pip install nvidia-dlprof==1.2.0 RUN pip install nvidia-dlprof==1.2.0
RUN pip install nvidia_dlprof_pytorch_nvtx==1.2.0 RUN pip install nvidia_dlprof_pytorch_nvtx==1.2.0
RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==1.5.0
RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
RUN unzip -qq awscliv2.zip RUN unzip -qq awscliv2.zip
RUN ./aws/install RUN ./aws/install
RUN rm -rf awscliv2.zip aws RUN rm -rf awscliv2.zip aws
# Install Perf Client required library WORKDIR /workspace/nnunet_pyt
RUN apt-get update && apt-get install -y libb64-dev libb64-0d ADD . /workspace/nnunet_pyt
# Install Triton Client Python API and copy Perf Client
#COPY --from=triton-client /workspace/install/ /workspace/install/
#RUN pip install /workspace/install/python/triton*.whl

View file

@ -69,7 +69,7 @@ The following figure shows the architecture of the 3D U-Net model and its differ
All convolution blocks in U-Net, in both the encoder and the decoder, use two convolution layers followed by instance normalization and a leaky ReLU nonlinearity. For downsampling we use strided convolution, and for upsampling we use transposed convolution. All convolution blocks in U-Net, in both the encoder and the decoder, use two convolution layers followed by instance normalization and a leaky ReLU nonlinearity. For downsampling we use strided convolution, and for upsampling we use transposed convolution.
All models were trained with the RAdam optimizer, a learning rate of 0.001, and a weight decay of 0.0001. For the loss function, we use the average of [cross-entropy](https://en.wikipedia.org/wiki/Cross_entropy) and [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient). All models were trained with the Adam optimizer, a learning rate of 0.0008, and a weight decay of 0.0001. For the loss function, we use the average of [cross-entropy](https://en.wikipedia.org/wiki/Cross_entropy) and [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient).
Early stopping is triggered if the validation Dice score has not improved during the last 100 epochs. Early stopping is triggered if the validation Dice score has not improved during the last 100 epochs.
@ -304,7 +304,7 @@ To see the full list of available options and their descriptions, use the `-h` o
The following example output is printed when running the model: The following example output is printed when running the model:
``` ```
usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--residual] [--focal] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--skip_first_n_eval SKIP_FIRST_N_EVAL] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] [--warmup WARMUP] [--norm {instance,batch,group}] [--nvol NVOL] [--data2d_dim {2,3}] [--oversampling OVERSAMPLING] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,radam,adam}] [--blend {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES] usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--residual] [--focal] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--skip_first_n_eval SKIP_FIRST_N_EVAL] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] 
[--warmup WARMUP] [--norm {instance,batch,group}] [--nvol NVOL] [--data2d_dim {2,3}] [--oversampling OVERSAMPLING] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,adam}] [--blend {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES]
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
@ -370,8 +370,8 @@ optional arguments:
type of CPU affinity (default: socket_unique_interleaved) type of CPU affinity (default: socket_unique_interleaved)
--scheduler {none,multistep,cosine,plateau} --scheduler {none,multistep,cosine,plateau}
Learning rate scheduler (default: none) Learning rate scheduler (default: none)
--optimizer {sgd,radam,adam} --optimizer {sgd,adam}
Optimizer (default: radam) Optimizer (default: adam)
--blend {gaussian,constant} --blend {gaussian,constant}
How to blend output of overlapping windows (default: gaussian) How to blend output of overlapping windows (default: gaussian)
--train_batches TRAIN_BATCHES --train_batches TRAIN_BATCHES
@ -418,7 +418,7 @@ If you have dataset in other format or you want customize data preprocessing or
### Training process ### Training process
The model trains for at least `--min_epochs` and at most `--max_epochs` epochs. After each epoch, evaluation on the validation set is performed and the validation loss is monitored for early stopping (see the `--patience` flag). Default training settings are: The model trains for at least `--min_epochs` and at most `--max_epochs` epochs. After each epoch, evaluation on the validation set is performed and the validation loss is monitored for early stopping (see the `--patience` flag). Default training settings are:
* RAdam optimizer with learning rate of 0.001 and weight decay 0.0001. * Adam optimizer with learning rate of 0.0008 and weight decay 0.0001.
* Training batch size is set to 2 for 3D U-Net and 16 for 2D U-Net. * Training batch size is set to 2 for 3D U-Net and 16 for 2D U-Net.
This default parametrization is applied when running scripts from the `scripts/` directory and when running `main.py` without explicitly overriding these parameters. By default, the training is in full precision. To enable AMP, pass the `--amp` flag. AMP can be enabled for every mode of execution. This default parametrization is applied when running scripts from the `scripts/` directory and when running `main.py` without explicitly overriding these parameters. By default, the training is in full precision. To enable AMP, pass the `--amp` flag. AMP can be enabled for every mode of execution.
@ -454,8 +454,6 @@ The script will then:
## Performance ## Performance
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA's latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
### Benchmarking ### Benchmarking
The following section shows how to run benchmarks to measure the model performance in training and inference modes. The following section shows how to run benchmarks to measure the model performance in training and inference modes.
@ -633,6 +631,9 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
### Changelog ### Changelog
October 2021
- Add Jupyter Notebook with BraTS solution
May 2021 May 2021
- Add Triton Inference Server support - Add Triton Inference Server support
- Removed deep supervision, attention and drop block - Removed deep supervision, attention and drop block

View file

@ -25,7 +25,7 @@ from nvidia.dali.plugin.pytorch import DALIGenericIterator
def get_numpy_reader(files, shard_id, num_shards, seed, shuffle): def get_numpy_reader(files, shard_id, num_shards, seed, shuffle):
return ops.NumpyReader( return ops.readers.Numpy(
seed=seed, seed=seed,
files=files, files=files,
device="cpu", device="cpu",
@ -42,6 +42,7 @@ class TrainPipeline(Pipeline):
def __init__(self, batch_size, num_threads, device_id, **kwargs): def __init__(self, batch_size, num_threads, device_id, **kwargs):
super(TrainPipeline, self).__init__(batch_size, num_threads, device_id) super(TrainPipeline, self).__init__(batch_size, num_threads, device_id)
self.dim = kwargs["dim"] self.dim = kwargs["dim"]
self.internal_seed = kwargs["seed"]
self.oversampling = kwargs["oversampling"] self.oversampling = kwargs["oversampling"]
self.input_x = get_numpy_reader( self.input_x = get_numpy_reader(
num_shards=kwargs["gpus"], num_shards=kwargs["gpus"],
@ -60,6 +61,7 @@ class TrainPipeline(Pipeline):
self.patch_size = kwargs["patch_size"] self.patch_size = kwargs["patch_size"]
if self.dim == 2: if self.dim == 2:
self.patch_size = [kwargs["batch_size_2d"]] + self.patch_size self.patch_size = [kwargs["batch_size_2d"]] + self.patch_size
self.crop_shape = types.Constant(np.array(self.patch_size), dtype=types.INT64) self.crop_shape = types.Constant(np.array(self.patch_size), dtype=types.INT64)
self.crop_shape_float = types.Constant(np.array(self.patch_size), dtype=types.FLOAT) self.crop_shape_float = types.Constant(np.array(self.patch_size), dtype=types.FLOAT)
@ -69,7 +71,7 @@ class TrainPipeline(Pipeline):
return img, lbl return img, lbl
def random_augmentation(self, probability, augmented, original): def random_augmentation(self, probability, augmented, original):
condition = fn.cast(fn.coin_flip(probability=probability), dtype=types.DALIDataType.BOOL) condition = fn.cast(fn.random.coin_flip(probability=probability), dtype=types.DALIDataType.BOOL)
neg_condition = condition ^ True neg_condition = condition ^ True
return condition * augmented + neg_condition * original return condition * augmented + neg_condition * original
@ -77,45 +79,60 @@ class TrainPipeline(Pipeline):
def slice_fn(img): def slice_fn(img):
return fn.slice(img, 1, 3, axes=[0]) return fn.slice(img, 1, 3, axes=[0])
def crop_fn(self, img, lbl): def biased_crop_fn(self, img, label):
center = fn.segmentation.random_mask_pixel(lbl, foreground=fn.coin_flip(probability=self.oversampling)) roi_start, roi_end = fn.segmentation.random_object_bbox(
crop_anchor = self.slice_fn(center) - self.crop_shape // 2 label,
adjusted_anchor = math.max(0, crop_anchor) format="start_end",
max_anchor = self.slice_fn(fn.shapes(lbl)) - self.crop_shape foreground_prob=self.oversampling,
crop_anchor = math.min(adjusted_anchor, max_anchor) background=0,
img = fn.slice(img.gpu(), crop_anchor, self.crop_shape, axis_names="DHW", out_of_bounds_policy="pad") seed=self.internal_seed,
lbl = fn.slice(lbl.gpu(), crop_anchor, self.crop_shape, axis_names="DHW", out_of_bounds_policy="pad") device="cpu",
return img, lbl cache_objects=True,
)
anchor = fn.roi_random_crop(label, roi_start=roi_start, roi_end=roi_end, crop_shape=[1, *self.patch_size])
anchor = fn.slice(anchor, 1, 3, axes=[0]) # drop channels from anchor
img, label = fn.slice(
[img, label], anchor, self.crop_shape, axis_names="DHW", out_of_bounds_policy="pad", device="cpu"
)
return img.gpu(), label.gpu()
def zoom_fn(self, img, lbl): def zoom_fn(self, img, lbl):
resized_shape = self.crop_shape * self.random_augmentation(0.15, fn.uniform(range=(0.7, 1.0)), 1.0) scale = self.random_augmentation(0.15, fn.random.uniform(range=(0.7, 1.0)), 1.0)
img, lbl = fn.crop(img, crop=resized_shape), fn.crop(lbl, crop=resized_shape) d, h, w = [scale * x for x in self.patch_size]
if self.dim == 2:
d = self.patch_size[0]
img, lbl = fn.crop(img, crop_h=h, crop_w=w, crop_d=d), fn.crop(lbl, crop_h=h, crop_w=w, crop_d=d)
img = fn.resize(img, interp_type=types.DALIInterpType.INTERP_CUBIC, size=self.crop_shape_float) img = fn.resize(img, interp_type=types.DALIInterpType.INTERP_CUBIC, size=self.crop_shape_float)
lbl = fn.resize(lbl, interp_type=types.DALIInterpType.INTERP_NN, size=self.crop_shape_float) lbl = fn.resize(lbl, interp_type=types.DALIInterpType.INTERP_NN, size=self.crop_shape_float)
return img, lbl return img, lbl
def noise_fn(self, img): def noise_fn(self, img):
img_noised = img + fn.random.normal(img, stddev=fn.uniform(range=(0.0, 0.33))) img_noised = img + fn.random.normal(img, stddev=fn.random.uniform(range=(0.0, 0.33)))
return self.random_augmentation(0.15, img_noised, img) return self.random_augmentation(0.15, img_noised, img)
def blur_fn(self, img): def blur_fn(self, img):
img_blured = fn.gaussian_blur(img, sigma=fn.uniform(range=(0.5, 1.5))) img_blurred = fn.gaussian_blur(img, sigma=fn.random.uniform(range=(0.5, 1.5)))
return self.random_augmentation(0.15, img_blured, img) return self.random_augmentation(0.15, img_blurred, img)
def brightness_fn(self, img): def brightness_fn(self, img):
brightness_scale = self.random_augmentation(0.15, fn.uniform(range=(0.7, 1.3)), 1.0) brightness_scale = self.random_augmentation(0.15, fn.random.uniform(range=(0.7, 1.3)), 1.0)
return img * brightness_scale return img * brightness_scale
def contrast_fn(self, img): def contrast_fn(self, img):
min_, max_ = fn.reductions.min(img), fn.reductions.max(img) min_, max_ = fn.reductions.min(img), fn.reductions.max(img)
scale = self.random_augmentation(0.15, fn.uniform(range=(0.65, 1.5)), 1.0) scale = self.random_augmentation(0.15, fn.random.uniform(range=(0.65, 1.5)), 1.0)
img = math.clamp(img * scale, min_, max_) img = math.clamp(img * scale, min_, max_)
return img return img
def flips_fn(self, img, lbl): def flips_fn(self, img, lbl):
kwargs = {"horizontal": fn.coin_flip(probability=0.33), "vertical": fn.coin_flip(probability=0.33)} kwargs = {
"horizontal": fn.random.coin_flip(probability=0.33),
"vertical": fn.random.coin_flip(probability=0.33),
}
if self.dim == 3: if self.dim == 3:
kwargs.update({"depthwise": fn.coin_flip(probability=0.33)}) kwargs.update({"depthwise": fn.random.coin_flip(probability=0.33)})
return fn.flip(img, **kwargs), fn.flip(lbl, **kwargs) return fn.flip(img, **kwargs), fn.flip(lbl, **kwargs)
def transpose_fn(self, img, lbl): def transpose_fn(self, img, lbl):
@ -124,7 +141,7 @@ class TrainPipeline(Pipeline):
def define_graph(self): def define_graph(self):
img, lbl = self.load_data() img, lbl = self.load_data()
img, lbl = self.crop_fn(img, lbl) img, lbl = self.biased_crop_fn(img, lbl)
img, lbl = self.zoom_fn(img, lbl) img, lbl = self.zoom_fn(img, lbl)
img, lbl = self.flips_fn(img, lbl) img, lbl = self.flips_fn(img, lbl)
img = self.noise_fn(img) img = self.noise_fn(img)
@ -141,15 +158,15 @@ class EvalPipeline(Pipeline):
super(EvalPipeline, self).__init__(batch_size, num_threads, device_id) super(EvalPipeline, self).__init__(batch_size, num_threads, device_id)
self.input_x = get_numpy_reader( self.input_x = get_numpy_reader(
files=kwargs["imgs"], files=kwargs["imgs"],
shard_id=device_id, shard_id=0,
num_shards=kwargs["gpus"], num_shards=1,
seed=kwargs["seed"], seed=kwargs["seed"],
shuffle=False, shuffle=False,
) )
self.input_y = get_numpy_reader( self.input_y = get_numpy_reader(
files=kwargs["lbls"], files=kwargs["lbls"],
shard_id=device_id, shard_id=0,
num_shards=kwargs["gpus"], num_shards=1,
seed=kwargs["seed"], seed=kwargs["seed"],
shuffle=False, shuffle=False,
) )
@ -281,6 +298,12 @@ def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
imgs = list(itertools.chain(*(100 * [imgs])))[: nbs * kwargs["gpus"]] imgs = list(itertools.chain(*(100 * [imgs])))[: nbs * kwargs["gpus"]]
lbls = list(itertools.chain(*(100 * [lbls])))[: nbs * kwargs["gpus"]] lbls = list(itertools.chain(*(100 * [lbls])))[: nbs * kwargs["gpus"]]
if mode == "eval": # To avoid padding for the multigpu evaluation.
rank = int(os.getenv("LOCAL_RANK", "0"))
imgs, lbls = np.array_split(imgs, kwargs["gpus"]), np.array_split(lbls, kwargs["gpus"])
imgs, lbls = [list(x) for x in imgs], [list(x) for x in lbls]
imgs, lbls = imgs[rank], lbls[rank]
pipe_kwargs = { pipe_kwargs = {
"imgs": imgs, "imgs": imgs,
"lbls": lbls, "lbls": lbls,

View file

@ -57,8 +57,7 @@ class DataModule(LightningDataModule):
self.val_imgs = get_split(imgs, val_idx) self.val_imgs = get_split(imgs, val_idx)
self.val_lbls = get_split(lbls, val_idx) self.val_lbls = get_split(lbls, val_idx)
if is_main_process(): if is_main_process():
ntrain, nval = len(self.train_imgs), len(self.val_imgs) print(f"Number of examples: Train {len(self.train_imgs)} - Val {len(self.val_imgs)}")
print(f"Number of examples: Train {ntrain} - Val {nval}")
elif is_main_process(): elif is_main_process():
print(f"Number of test examples: {len(self.test_imgs)}") print(f"Number of test examples: {len(self.test_imgs)}")

View file

@ -23,6 +23,8 @@ task = {
"08": "Task08_HepaticVessel", "08": "Task08_HepaticVessel",
"09": "Task09_Spleen", "09": "Task09_Spleen",
"10": "Task10_Colon", "10": "Task10_Colon",
"11": "BraTS2021_train",
"12": "BraTS2021_val",
} }
patch_size = { patch_size = {
@ -36,6 +38,8 @@ patch_size = {
"08_3d": [64, 192, 192], "08_3d": [64, 192, 192],
"09_3d": [64, 192, 160], "09_3d": [64, 192, 160],
"10_3d": [56, 192, 160], "10_3d": [56, 192, 160],
"11_3d": [128, 128, 128],
"12_3d": [128, 128, 128],
"01_2d": [192, 160], "01_2d": [192, 160],
"02_2d": [320, 256], "02_2d": [320, 256],
"03_2d": [512, 512], "03_2d": [512, 512],
@ -60,7 +64,8 @@ spacings = {
"08_3d": [1.5, 0.8, 0.8], "08_3d": [1.5, 0.8, 0.8],
"09_3d": [1.6, 0.79, 0.79], "09_3d": [1.6, 0.79, 0.79],
"10_3d": [3, 0.78, 0.78], "10_3d": [3, 0.78, 0.78],
"11_3d": [5, 0.741, 0.741], "11_3d": [1.0, 1.0, 1.0],
"12_3d": [1.0, 1.0, 1.0],
"01_2d": [1.0, 1.0], "01_2d": [1.0, 1.0],
"02_2d": [1.25, 1.25], "02_2d": [1.25, 1.25],
"03_2d": [0.7676, 0.7676], "03_2d": [0.7676, 0.7676],
@ -80,7 +85,6 @@ ct_min = {
"08": -3, "08": -3,
"09": -41, "09": -41,
"10": -30, "10": -30,
"11": -958,
} }
ct_max = { ct_max = {
@ -90,9 +94,8 @@ ct_max = {
"08": 243, "08": 243,
"09": 176, "09": 176,
"10": 165.82, "10": 165.82,
"11": 93,
} }
ct_mean = {"03": 99.4, "06": -158.58, "07": 77.9, "08": 104.37, "09": 99.29, "10": 62.18, "11": -547.7} ct_mean = {"03": 99.4, "06": -158.58, "07": 77.9, "08": 104.37, "09": 99.29, "10": 62.18}
ct_std = {"03": 39.36, "06": 324.7, "07": 75.4, "08": 52.62, "09": 39.47, "10": 32.65, "11": 281.08} ct_std = {"03": 39.36, "06": 324.7, "07": 75.4, "08": 52.62, "09": 39.47, "10": 32.65}

View file

@ -22,7 +22,6 @@ import monai.transforms as transforms
import nibabel import nibabel
import numpy as np import numpy as np
from joblib import Parallel, delayed from joblib import Parallel, delayed
from skimage.morphology import dilation, erosion, square
from skimage.transform import resize from skimage.transform import resize
from utils.utils import get_task_code, make_empty_dir from utils.utils import get_task_code, make_empty_dir
@ -39,32 +38,34 @@ class Preprocessor:
self.target_spacing = None self.target_spacing = None
self.task = args.task self.task = args.task
self.task_code = get_task_code(args) self.task_code = get_task_code(args)
self.verbose = args.verbose
self.patch_size = patch_size[self.task_code] self.patch_size = patch_size[self.task_code]
self.training = args.exec_mode == "training" self.training = args.exec_mode == "training"
self.data_path = os.path.join(args.data, task[args.task]) self.data_path = os.path.join(args.data, task[args.task])
metadata_path = os.path.join(self.data_path, "dataset.json")
self.metadata = json.load(open(metadata_path, "r"))
self.modality = self.metadata["modality"]["0"]
self.results = os.path.join(args.results, self.task_code) self.results = os.path.join(args.results, self.task_code)
if not self.training: if not self.training:
self.results = os.path.join(self.results, self.args.exec_mode) self.results = os.path.join(self.results, self.args.exec_mode)
self.crop_foreg = transforms.CropForegroundd(keys=["image", "label"], source_key="image") self.crop_foreg = transforms.CropForegroundd(keys=["image", "label"], source_key="image")
self.normalize_intensity = transforms.NormalizeIntensity(nonzero=False, channel_wise=True) nonzero = True if self.modality != "CT" else False # normalize only non-zero region for MRI
metadata_path = os.path.join(self.data_path, "dataset.json") self.normalize_intensity = transforms.NormalizeIntensity(nonzero=nonzero, channel_wise=True)
if self.args.exec_mode == "val": if self.args.exec_mode == "val":
dataset_json = json.load(open(metadata_path, "r")) dataset_json = json.load(open(metadata_path, "r"))
dataset_json["val"] = dataset_json["training"] dataset_json["val"] = dataset_json["training"]
with open(metadata_path, "w") as outfile: with open(metadata_path, "w") as outfile:
json.dump(dataset_json, outfile) json.dump(dataset_json, outfile)
self.metadata = json.load(open(metadata_path, "r"))
self.modality = self.metadata["modality"]["0"]
def run(self): def run(self):
make_empty_dir(self.results) make_empty_dir(self.results)
print(f"Preprocessing {self.data_path}") print(f"Preprocessing {self.data_path}")
try: try:
self.target_spacing = spacings[self.task_code] self.target_spacing = spacings[self.task_code]
except: except:
self.collect_spacings() self.collect_spacings()
print(f"Target spacing {self.target_spacing}") if self.verbose:
print(f"Target spacing {self.target_spacing}")
if self.modality == "CT": if self.modality == "CT":
try: try:
@ -77,7 +78,8 @@ class Preprocessor:
_mean = round(self.ct_mean, 2) _mean = round(self.ct_mean, 2)
_std = round(self.ct_std, 2) _std = round(self.ct_std, 2)
print(f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}") if self.verbose:
print(f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}")
self.run_parallel(self.preprocess_pair, self.args.exec_mode) self.run_parallel(self.preprocess_pair, self.args.exec_mode)
@ -86,7 +88,7 @@ class Preprocessor:
"patch_size": self.patch_size, "patch_size": self.patch_size,
"spacings": self.target_spacing, "spacings": self.target_spacing,
"n_class": len(self.metadata["labels"]), "n_class": len(self.metadata["labels"]),
"in_channels": len(self.metadata["modality"]), "in_channels": len(self.metadata["modality"]) + int(self.args.ohe),
}, },
open(os.path.join(self.results, "config.pkl"), "wb"), open(os.path.join(self.results, "config.pkl"), "wb"),
) )
@ -99,9 +101,10 @@ class Preprocessor:
image, label = data["image"], data["label"] image, label = data["image"], data["label"]
test_metadata = None test_metadata = None
else: else:
orig_shape = image.shape[1:]
bbox = transforms.utils.generate_spatial_bounding_box(image) bbox = transforms.utils.generate_spatial_bounding_box(image)
test_metadata = np.vstack([bbox, image.shape[1:]])
image = transforms.SpatialCrop(roi_start=bbox[0], roi_end=bbox[1])(image) image = transforms.SpatialCrop(roi_start=bbox[0], roi_end=bbox[1])(image)
test_metadata = np.vstack([bbox, orig_shape, image.shape[1:]])
if label is not None: if label is not None:
label = transforms.SpatialCrop(roi_start=bbox[0], roi_end=bbox[1])(label) label = transforms.SpatialCrop(roi_start=bbox[0], roi_end=bbox[1])(label)
if self.args.dim == 3: if self.args.dim == 3:
@ -111,11 +114,16 @@ class Preprocessor:
image = self.normalize(image) image = self.normalize(image)
if self.training: if self.training:
image, label = self.standardize(image, label) image, label = self.standardize(image, label)
if self.args.dilation:
new_lbl = np.zeros(label.shape, dtype=np.uint8) if self.args.ohe:
for depth in range(label.shape[1]): mask = np.ones(image.shape[1:], dtype=np.float32)
new_lbl[0, depth] = erosion(dilation(label[0, depth], square(3)), square(3)) for i in range(image.shape[0]):
label = new_lbl zeros = np.where(image[i] <= 0)
mask[zeros] *= 0.0
image = self.normalize_intensity(image).astype(np.float32)
mask = np.expand_dims(mask, 0)
image = np.concatenate([image, mask])
self.save(image, label, fname, test_metadata) self.save(image, label, fname, test_metadata)
def resample(self, image, label, image_spacings): def resample(self, image, label, image_spacings):
@ -145,7 +153,8 @@ class Preprocessor:
def save(self, image, label, fname, test_metadata): def save(self, image, label, fname, test_metadata):
mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(np.std(image, (1, 2, 3)), 2) mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(np.std(image, (1, 2, 3)), 2)
print(f"Saving {fname} shape {image.shape} mean {mean} std {std}") if self.verbose:
print(f"Saving {fname} shape {image.shape} mean {mean} std {std}")
self.save_npy(image, fname, "_x.npy") self.save_npy(image, fname, "_x.npy")
if label is not None: if label is not None:
self.save_npy(label, fname, "_y.npy") self.save_npy(label, fname, "_y.npy")

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB

View file

@ -14,9 +14,10 @@
import os import os
import nvidia_dlprof_pytorch_nvtx
import torch import torch
from pytorch_lightning import Trainer, seed_everything from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, early_stopping
from data_loading.data_module import DataModule from data_loading.data_module import DataModule
from models.nn_unet import NNUnet from models.nn_unet import NNUnet
@ -28,8 +29,6 @@ if __name__ == "__main__":
args = get_main_args() args = get_main_args()
if args.profile: if args.profile:
import nvidia_dlprof_pytorch_nvtx
nvidia_dlprof_pytorch_nvtx.init() nvidia_dlprof_pytorch_nvtx.init()
print("Profiling enabled") print("Profiling enabled")
@ -61,9 +60,13 @@ if __name__ == "__main__":
] ]
elif args.exec_mode == "train": elif args.exec_mode == "train":
model = NNUnet(args) model = NNUnet(args)
early_stopping = EarlyStopping(monitor="dice_mean", patience=args.patience, verbose=True, mode="max")
callbacks = [early_stopping]
if args.save_ckpt: if args.save_ckpt:
model_ckpt = ModelCheckpoint(monitor="dice_sum", mode="max", save_last=True) model_ckpt = ModelCheckpoint(
callbacks = [EarlyStopping(monitor="dice_sum", patience=args.patience, verbose=True, mode="max")] filename="{epoch}-{dice_mean:.2f}", monitor="dice_mean", mode="max", save_last=True
)
callbacks.append(model_ckpt)
else: # Evaluation or inference else: # Evaluation or inference
if ckpt_path is not None: if ckpt_path is not None:
model = NNUnet.load_from_checkpoint(ckpt_path) model = NNUnet.load_from_checkpoint(ckpt_path)
@ -76,8 +79,8 @@ if __name__ == "__main__":
precision=16 if args.amp else 32, precision=16 if args.amp else 32,
benchmark=True, benchmark=True,
deterministic=False, deterministic=False,
min_epochs=args.min_epochs, min_epochs=args.epochs,
max_epochs=args.max_epochs, max_epochs=args.epochs,
sync_batchnorm=args.sync_batchnorm, sync_batchnorm=args.sync_batchnorm,
gradient_clip_val=args.gradient_clip_val, gradient_clip_val=args.gradient_clip_val,
callbacks=callbacks, callbacks=callbacks,
@ -85,7 +88,6 @@ if __name__ == "__main__":
default_root_dir=args.results, default_root_dir=args.results,
resume_from_checkpoint=ckpt_path, resume_from_checkpoint=ckpt_path,
accelerator="ddp" if args.gpus > 1 else None, accelerator="ddp" if args.gpus > 1 else None,
checkpoint_callback=model_ckpt,
limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches, limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches, limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches, limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
@ -106,6 +108,9 @@ if __name__ == "__main__":
trainer.test(model, test_dataloaders=data_module.test_dataloader()) trainer.test(model, test_dataloaders=data_module.test_dataloader())
elif args.exec_mode == "train": elif args.exec_mode == "train":
trainer.fit(model, data_module) trainer.fit(model, data_module)
if is_main_process():
logname = args.logname if args.logname is not None else "train_log.json"
log(logname, torch.tensor(model.best_mean_dice), results=args.results)
elif args.exec_mode == "evaluate": elif args.exec_mode == "evaluate":
model.args = args model.args = args
trainer.test(model, test_dataloaders=data_module.val_dataloader()) trainer.test(model, test_dataloaders=data_module.val_dataloader())
@ -113,13 +118,14 @@ if __name__ == "__main__":
logname = args.logname if args.logname is not None else "eval_log.json" logname = args.logname if args.logname is not None else "eval_log.json"
log(logname, model.eval_dice, results=args.results) log(logname, model.eval_dice, results=args.results)
elif args.exec_mode == "predict": elif args.exec_mode == "predict":
model.args = args
if args.save_preds: if args.save_preds:
prec = "amp" if args.amp else "fp32" ckpt_name = "_".join(args.ckpt_path.split("/")[-1].split(".")[:-1])
dir_name = f"preds_task_{args.task}_dim_{args.dim}_fold_{args.fold}_{prec}" dir_name = f"predictions_{ckpt_name}"
dir_name += f"_task={model.args.task}_fold={model.args.fold}"
if args.tta: if args.tta:
dir_name += "_tta" dir_name += "_tta"
save_dir = os.path.join(args.results, dir_name) save_dir = os.path.join(args.results, dir_name)
model.save_dir = save_dir model.save_dir = save_dir
make_empty_dir(save_dir) make_empty_dir(save_dir)
model.args = args
trainer.test(model, test_dataloaders=data_module.test_dataloader()) trainer.test(model, test_dataloaders=data_module.test_dataloader())

View file

@ -67,7 +67,6 @@ def get_output_padding(kernel_size, stride, padding):
return out_padding if len(out_padding) > 1 else out_padding[0] return out_padding if len(out_padding) > 1 else out_padding[0]
class ConvLayer(nn.Module): class ConvLayer(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs): def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
super(ConvLayer, self).__init__() super(ConvLayer, self).__init__()
@ -94,30 +93,6 @@ class ConvBlock(nn.Module):
return out return out
class ResidBlock(nn.Module):
    """Two-convolution block with an additive skip connection.

    The main path is ``conv1`` (a ``ConvLayer``) followed by a plain
    convolution and a normalization layer.  The skip path is the raw input,
    or the input passed through a convolution + normalization when the block
    changes resolution (any stride entry > 1) or channel count.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the block.
        kernel_size: kernel size forwarded to every convolution in the block.
        stride: iterable of per-axis strides (``max(stride)`` is taken).
        **kwargs: must contain ``"dim"``, ``"norm"`` and ``"negative_slope"``;
            the full mapping is also forwarded to ``ConvLayer``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
        super().__init__()
        # Main path.
        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, **kwargs)
        self.conv2 = get_conv(out_channels, out_channels, kernel_size, 1, kwargs["dim"])
        self.norm = get_norm(kwargs["norm"], out_channels)
        self.lrelu = nn.LeakyReLU(negative_slope=kwargs["negative_slope"], inplace=True)
        # Skip path: identity unless the output shape differs from the input shape.
        self.downsample = None
        shape_changes = max(stride) > 1 or in_channels != out_channels
        if shape_changes:
            self.downsample = get_conv(in_channels, out_channels, kernel_size, stride, kwargs["dim"])
            self.norm_res = get_norm(kwargs["norm"], out_channels)

    def forward(self, input_data):
        """Return ``lrelu(main_path(input_data) + skip_path(input_data))``."""
        main = self.norm(self.conv2(self.conv1(input_data)))
        skip = input_data
        if self.downsample is not None:
            # Project the input so it can be added to the main-path output.
            skip = self.norm_res(self.downsample(skip))
        return self.lrelu(main + skip)
class UpsampleBlock(nn.Module): class UpsampleBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs): def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
super(UpsampleBlock, self).__init__() super(UpsampleBlock, self).__init__()

View file

@ -13,21 +13,32 @@
# limitations under the License. # limitations under the License.
import torch.nn as nn import torch.nn as nn
from monai.losses import DiceLoss, FocalLoss from monai.losses import DiceCELoss, DiceFocalLoss, DiceLoss, FocalLoss
class Loss(nn.Module): class Loss(nn.Module):
def __init__(self, focal): def __init__(self, focal):
super(Loss, self).__init__() super(Loss, self).__init__()
self.dice = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True) if focal:
self.focal = FocalLoss(gamma=2.0) self.loss = DiceFocalLoss(gamma=2.0, softmax=True, to_onehot_y=True, batch=True)
self.cross_entropy = nn.CrossEntropyLoss() else:
self.use_focal = focal self.loss = DiceCELoss(softmax=True, to_onehot_y=True, batch=True)
def forward(self, y_pred, y_true): def forward(self, y_pred, y_true):
loss = self.dice(y_pred, y_true) return self.loss(y_pred, y_true)
if self.use_focal:
loss += self.focal(y_pred, y_true)
else: class LossBraTS(nn.Module):
loss += self.cross_entropy(y_pred, y_true[:, 0].long()) def __init__(self, focal):
return loss super(LossBraTS, self).__init__()
self.dice = DiceLoss(sigmoid=True, batch=True)
self.ce = FocalLoss(gamma=2.0, to_onehot_y=False) if focal else nn.BCEWithLogitsLoss()
def _loss(self, p, y):
return self.dice(p, y) + self.ce(p, y.float())
def forward(self, p, y):
    """Sum ``self._loss`` over three binary targets derived from ``y``.

    Channel 0 of ``p`` is scored against ``y > 0``, channel 1 against
    ``y == 1 or y == 3`` and channel 2 against ``y == 3`` (named wt / tc / et
    in the original code - presumably the BraTS whole-tumor, tumor-core and
    enhancing-tumor regions; confirm against the dataset labels).
    Each channel is passed with a singleton channel axis restored at dim 1.
    """
    region_targets = (
        y > 0,
        ((y == 1) + (y == 3)) > 0,
        y == 3,
    )
    region_losses = [
        self._loss(p[:, channel].unsqueeze(1), region)
        for channel, region in enumerate(region_targets)
    ]
    wt_loss, tc_loss, et_loss = region_losses
    return wt_loss + tc_loss + et_loss

View file

@ -13,35 +13,61 @@
# limitations under the License. # limitations under the License.
import torch import torch
from pytorch_lightning.metrics.functional import stat_scores from torchmetrics import Metric
from pytorch_lightning.metrics.metric import Metric
class Dice(Metric): class Dice(Metric):
def __init__(self, nclass): def __init__(self, n_class, brats):
super().__init__(dist_sync_on_step=True) super().__init__(dist_sync_on_step=False)
self.add_state("n_updates", default=torch.zeros(1), dist_reduce_fx="sum") self.n_class = n_class
self.add_state("dice", default=torch.zeros((nclass,)), dist_reduce_fx="sum") self.brats = brats
self.add_state("steps", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("dice", default=torch.zeros((n_class,)), dist_reduce_fx="sum")
self.add_state("loss", default=torch.zeros(1), dist_reduce_fx="sum")
def update(self, pred, target): def update(self, preds, target, loss):
self.n_updates += 1 self.steps += 1
self.dice += self.compute_stats(pred, target) self.dice += self.compute_stats_brats(preds, target) if self.brats else self.compute_stats(preds, target)
self.loss += loss
def compute(self): def compute(self):
return 100 * self.dice / self.n_updates return 100 * self.dice / self.steps, self.loss / self.steps
@staticmethod def compute_stats_brats(self, p, y):
def compute_stats(pred, target): scores = torch.zeros(self.n_class, device=p.device, dtype=torch.float32)
num_classes = pred.shape[1] p = (torch.sigmoid(p) > 0.5).int()
scores = torch.zeros(num_classes - 1, device=pred.device, dtype=torch.float32) y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
for i in range(1, num_classes): y = torch.stack([y_wt, y_tc, y_et], dim=1)
if (target != i).all():
for i in range(self.n_class):
p_i, y_i = p[:, i], y[:, i]
if (y_i != 1).all():
# no foreground class # no foreground class
_, _pred = torch.max(pred, 1) scores[i - 1] += 1 if (p_i != 1).all() else 0
scores[i - 1] += 1 if (_pred != i).all() else 0
continue continue
_tp, _fp, _tn, _fn, _ = stat_scores(pred=pred, target=target, class_index=i) tp, fn, fp = self.get_stats(p_i, y_i, 1)
denom = (2 * _tp + _fp + _fn).to(torch.float) denom = (2 * tp + fp + fn).to(torch.float)
score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0 score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
scores[i - 1] += score_cls scores[i - 1] += score_cls
return scores return scores
def compute_stats(self, preds, target):
    """Per-class Dice scores for labels ``1 .. self.n_class``.

    ``preds`` carries per-class scores along dim 1 and is reduced with
    ``argmax``; ``target`` holds integer labels.  Label 0 is never scored.
    When a label is absent from ``target`` the score is 1 if it is also
    absent from the prediction and 0 otherwise.

    Returns a float32 tensor of length ``self.n_class`` on ``preds.device``.
    """
    labels = torch.argmax(preds, dim=1)
    scores = torch.zeros(self.n_class, device=preds.device, dtype=torch.float32)
    for slot in range(self.n_class):
        cls = slot + 1
        if (target != cls).all():
            # no foreground class: perfect only if nothing was predicted either
            if (labels != cls).all():
                scores[slot] += 1
        else:
            tp, fn, fp = self.get_stats(labels, target, cls)
            denom = (2 * tp + fp + fn).to(torch.float)
            if torch.is_nonzero(denom):
                scores[slot] += (2 * tp).to(torch.float) / denom
    return scores
@staticmethod
def get_stats(preds, target, class_idx):
tp = torch.logical_and(preds == class_idx, target == class_idx).sum()
fn = torch.logical_and(preds != class_idx, target == class_idx).sum()
fp = torch.logical_and(preds == class_idx, target != class_idx).sum()
return tp, fn, fp

View file

@ -20,8 +20,9 @@ import torch
import torch.nn as nn import torch.nn as nn
from apex.optimizers import FusedAdam, FusedSGD from apex.optimizers import FusedAdam, FusedSGD
from monai.inferers import sliding_window_inference from monai.inferers import sliding_window_inference
from scipy.special import expit, softmax
from skimage.transform import resize from skimage.transform import resize
from torch_optimizer import RAdam from utils.scheduler import WarmupCosineSchedule
from utils.utils import ( from utils.utils import (
flip, flip,
get_dllogger, get_dllogger,
@ -33,7 +34,7 @@ from utils.utils import (
layout_2d, layout_2d,
) )
from models.loss import Loss from models.loss import Loss, LossBraTS
from models.metrics import Dice from models.metrics import Dice
from models.unet import UNet from models.unet import UNet
@ -41,24 +42,25 @@ from models.unet import UNet
class NNUnet(pl.LightningModule): class NNUnet(pl.LightningModule):
def __init__(self, args, bermuda=False, data_dir=None): def __init__(self, args, bermuda=False, data_dir=None):
super(NNUnet, self).__init__() super(NNUnet, self).__init__()
self.save_hyperparameters()
self.args = args self.args = args
self.bermuda = bermuda self.bermuda = bermuda
if data_dir is not None: if data_dir is not None:
self.args.data = data_dir self.args.data = data_dir
self.save_hyperparameters()
self.build_nnunet() self.build_nnunet()
self.best_sum = 0 self.best_mean = 0
self.best_sum_epoch = 0 self.best_mean_epoch = 0
self.best_dice = self.n_class * [0] self.best_dice = self.n_class * [0]
self.best_epoch = self.n_class * [0] self.best_epoch = self.n_class * [0]
self.best_sum_dice = self.n_class * [0] self.best_mean_dice = self.n_class * [0]
self.test_idx = 0 self.test_idx = 0
self.test_imgs = [] self.test_imgs = []
if not self.bermuda: if not self.bermuda:
self.learning_rate = args.learning_rate self.learning_rate = args.learning_rate
self.loss = Loss(self.args.focal) loss = LossBraTS if self.args.brats else Loss
self.loss = loss(self.args.focal)
self.tta_flips = get_tta_flips(args.dim) self.tta_flips = get_tta_flips(args.dim)
self.dice = Dice(self.n_class) self.dice = Dice(self.n_class, self.args.brats)
if self.args.exec_mode in ["train", "evaluate"]: if self.args.exec_mode in ["train", "evaluate"]:
self.dllogger = get_dllogger(args.results) self.dllogger = get_dllogger(args.results)
@ -72,10 +74,17 @@ class NNUnet(pl.LightningModule):
return self.model(img) return self.model(img)
return self.tta_inference(img) if self.args.tta else self.do_inference(img) return self.tta_inference(img) if self.args.tta else self.do_inference(img)
def compute_loss(self, preds, label):
    """Training loss, optionally combined over deep-supervision outputs.

    Without ``self.args.deep_supervision`` this is just ``self.loss(preds, label)``.
    With it, ``preds`` must unpack into exactly three predictions; their
    losses are weighted 1, 0.5 and 0.25 and the sum is divided by 1.75
    (the sum of the weights).
    """
    if not self.args.deep_supervision:
        return self.loss(preds, label)
    main_pred, aux_pred1, aux_pred2 = preds
    weighted = self.loss(main_pred, label) + 0.5 * self.loss(aux_pred1, label) + 0.25 * self.loss(aux_pred2, label)
    return weighted / 1.75
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
img, lbl = self.get_train_data(batch) img, lbl = self.get_train_data(batch)
pred = self.model(img) pred = self.model(img)
loss = self.loss(pred, lbl) loss = self.compute_loss(pred, lbl)
return loss return loss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
@ -84,49 +93,50 @@ class NNUnet(pl.LightningModule):
img, lbl = batch["image"], batch["label"] img, lbl = batch["image"], batch["label"]
pred = self._forward(img) pred = self._forward(img)
loss = self.loss(pred, lbl) loss = self.loss(pred, lbl)
self.dice.update(pred, lbl[:, 0]) self.dice.update(pred, lbl[:, 0], loss)
return {"val_loss": loss}
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
if self.args.exec_mode == "evaluate": if self.args.exec_mode == "evaluate":
return self.validation_step(batch, batch_idx) return self.validation_step(batch, batch_idx)
img = batch["image"] img = batch["image"]
pred = self._forward(img) pred = self._forward(img).squeeze(0).cpu().detach().numpy()
if self.args.save_preds: if self.args.save_preds:
meta = batch["meta"][0].cpu().detach().numpy() meta = batch["meta"][0].cpu().detach().numpy()
original_shape = meta[2]
min_d, max_d = meta[0, 0], meta[1, 0] min_d, max_d = meta[0, 0], meta[1, 0]
min_h, max_h = meta[0, 1], meta[1, 1] min_h, max_h = meta[0, 1], meta[1, 1]
min_w, max_w = meta[0, 2], meta[1, 2] min_w, max_w = meta[0, 2], meta[1, 2]
n_class, original_shape, cropped_shape = pred.shape[0], meta[2], meta[3]
final_pred = torch.zeros((1, pred.shape[1], *original_shape), device=img.device) if not all(cropped_shape == pred.shape[1:]):
final_pred[:, :, min_d:max_d, min_h:max_h, min_w:max_w] = pred resized_pred = np.zeros((n_class, *cropped_shape))
final_pred = nn.functional.softmax(final_pred, dim=1) for i in range(n_class):
final_pred = final_pred.squeeze(0).cpu().detach().numpy()
if not all(original_shape == final_pred.shape[1:]):
class_ = final_pred.shape[0]
resized_pred = np.zeros((class_, *original_shape))
for i in range(class_):
resized_pred[i] = resize( resized_pred[i] = resize(
final_pred[i], original_shape, order=3, mode="edge", cval=0, clip=True, anti_aliasing=False pred[i], cropped_shape, order=3, mode="edge", cval=0, clip=True, anti_aliasing=False
) )
final_pred = resized_pred pred = resized_pred
final_pred = np.zeros((n_class, *original_shape))
final_pred[:, min_d:max_d, min_h:max_h, min_w:max_w] = pred
if self.args.brats:
final_pred = expit(final_pred)
else:
final_pred = softmax(final_pred, axis=0)
self.save_mask(final_pred) self.save_mask(final_pred)
def build_nnunet(self): def build_nnunet(self):
in_channels, n_class, kernels, strides, self.patch_size = get_unet_params(self.args) in_channels, n_class, kernels, strides, self.patch_size = get_unet_params(self.args)
self.n_class = n_class - 1 self.n_class = n_class - 1
if self.args.brats:
n_class = 3
self.model = UNet( self.model = UNet(
in_channels=in_channels, in_channels=in_channels,
n_class=n_class, n_class=n_class,
kernels=kernels, kernels=kernels,
strides=strides, strides=strides,
dimension=self.args.dim, dimension=self.args.dim,
residual=self.args.residual,
normalization_layer=self.args.norm, normalization_layer=self.args.norm,
negative_slope=self.args.negative_slope, negative_slope=self.args.negative_slope,
deep_supervision=self.args.deep_supervision,
more_chn=self.args.more_chn,
) )
if is_main_process(): if is_main_process():
print(f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}") print(f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}")
@ -180,39 +190,30 @@ class NNUnet(pl.LightningModule):
mode=self.args.blend, mode=self.args.blend,
) )
@staticmethod
def metric_mean(name, outputs):
return torch.stack([out[name] for out in outputs]).mean(dim=0)
def validation_epoch_end(self, outputs): def validation_epoch_end(self, outputs):
if self.current_epoch < self.args.skip_first_n_eval: dice, loss = self.dice.compute()
self.log("dice_sum", 0.001 * self.current_epoch) self.dice.reset()
self.dice.reset() dice_mean = torch.mean(dice)
return None if dice_mean >= self.best_mean:
loss = self.metric_mean("val_loss", outputs) self.best_mean = dice_mean
dice = self.dice.compute() self.best_mean_dice = dice[:]
dice_sum = torch.sum(dice) self.best_mean_epoch = self.current_epoch
if dice_sum >= self.best_sum:
self.best_sum = dice_sum
self.best_sum_dice = dice[:]
self.best_sum_epoch = self.current_epoch
for i, dice_i in enumerate(dice): for i, dice_i in enumerate(dice):
if dice_i > self.best_dice[i]: if dice_i > self.best_dice[i]:
self.best_dice[i], self.best_epoch[i] = dice_i, self.current_epoch self.best_dice[i], self.best_epoch[i] = dice_i, self.current_epoch
if is_main_process(): if is_main_process():
metrics = {} metrics = {}
metrics.update({"mean dice": round(torch.mean(dice).item(), 2)}) metrics.update({"Mean dice": round(torch.mean(dice).item(), 2)})
metrics.update({"TOP_mean": round(torch.mean(self.best_sum_dice).item(), 2)}) metrics.update({"Highest": round(torch.mean(self.best_mean_dice).item(), 2)})
if self.n_class > 1: if self.n_class > 1:
metrics.update({f"L{i+1}": round(m.item(), 2) for i, m in enumerate(dice)}) metrics.update({f"L{i+1}": round(m.item(), 2) for i, m in enumerate(dice)})
metrics.update({f"TOP_L{i+1}": round(m.item(), 2) for i, m in enumerate(self.best_sum_dice)})
metrics.update({"val_loss": round(loss.item(), 4)}) metrics.update({"val_loss": round(loss.item(), 4)})
self.dllogger.log(step=self.current_epoch, data=metrics) self.dllogger.log(step=self.current_epoch, data=metrics)
self.dllogger.flush() self.dllogger.flush()
self.log("val_loss", loss) self.log("val_loss", loss)
self.log("dice_sum", dice_sum) self.log("dice_mean", dice_mean)
def test_epoch_end(self, outputs): def test_epoch_end(self, outputs):
if self.args.exec_mode == "evaluate": if self.args.exec_mode == "evaluate":
@ -222,22 +223,20 @@ class NNUnet(pl.LightningModule):
optimizer = { optimizer = {
"sgd": FusedSGD(self.parameters(), lr=self.learning_rate, momentum=self.args.momentum), "sgd": FusedSGD(self.parameters(), lr=self.learning_rate, momentum=self.args.momentum),
"adam": FusedAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay), "adam": FusedAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
"radam": RAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
}[self.args.optimizer.lower()] }[self.args.optimizer.lower()]
scheduler = { if self.args.scheduler:
"none": None, scheduler = {
"multistep": torch.optim.lr_scheduler.MultiStepLR(optimizer, self.args.steps, gamma=self.args.factor), "scheduler": WarmupCosineSchedule(
"cosine": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.args.max_epochs), optimizer=optimizer,
"plateau": torch.optim.lr_scheduler.ReduceLROnPlateau( warmup_steps=250,
optimizer, factor=self.args.factor, patience=self.args.lr_patience t_total=self.args.epochs * len(self.trainer.datamodule.train_dataloader()),
), ),
}[self.args.scheduler.lower()] "interval": "step",
"frequency": 1,
opt_dict = {"optimizer": optimizer, "monitor": "val_loss"} }
if scheduler is not None: return {"optimizer": optimizer, "monitor": "val_loss", "lr_scheduler": scheduler}
opt_dict.update({"lr_scheduler": scheduler}) return {"optimizer": optimizer, "monitor": "val_loss"}
return opt_dict
def save_mask(self, pred): def save_mask(self, pred):
if self.test_idx == 0: if self.test_idx == 0:

View file

@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch
import torch.nn as nn import torch.nn as nn
from models.layers import ConvBlock, OutputBlock, ResidBlock, UpsampleBlock from models.layers import ConvBlock, OutputBlock, UpsampleBlock
class UNet(nn.Module): class UNet(nn.Module):
@ -25,34 +27,39 @@ class UNet(nn.Module):
strides, strides,
normalization_layer, normalization_layer,
negative_slope, negative_slope,
residual,
dimension, dimension,
deep_supervision,
more_chn,
): ):
super(UNet, self).__init__() super(UNet, self).__init__()
self.more_chn = more_chn
self.dim = dimension self.dim = dimension
self.n_class = n_class self.n_class = n_class
self.residual = residual
self.negative_slope = negative_slope self.negative_slope = negative_slope
self.norm = normalization_layer + f"norm{dimension}d" self.norm = normalization_layer + f"norm{dimension}d"
self.filters = [min(2 ** (5 + i), 320 if dimension == 3 else 512) for i in range(len(strides))] self.deep_supervision = deep_supervision
self.depth = len(strides)
if self.more_chn:
self.filters = [64, 96, 128, 192, 256, 384, 512, 768, 1024][: self.depth]
else:
self.filters = [min(2 ** (5 + i), 320 if dimension == 3 else 512) for i in range(self.depth)]
down_block = ResidBlock if self.residual else ConvBlock
self.input_block = self.get_conv_block( self.input_block = self.get_conv_block(
conv_block=down_block, conv_block=ConvBlock,
in_channels=in_channels, in_channels=in_channels,
out_channels=self.filters[0], out_channels=self.filters[0],
kernel_size=kernels[0], kernel_size=kernels[0],
stride=strides[0], stride=strides[0],
) )
self.downsamples = self.get_module_list( self.downsamples = self.get_module_list(
conv_block=down_block, conv_block=ConvBlock,
in_channels=self.filters[:-1], in_channels=self.filters[:-1],
out_channels=self.filters[1:], out_channels=self.filters[1:],
kernels=kernels[1:-1], kernels=kernels[1:-1],
strides=strides[1:-1], strides=strides[1:-1],
) )
self.bottleneck = self.get_conv_block( self.bottleneck = self.get_conv_block(
conv_block=down_block, conv_block=ConvBlock,
in_channels=self.filters[-2], in_channels=self.filters[-2],
out_channels=self.filters[-1], out_channels=self.filters[-1],
kernel_size=kernels[-1], kernel_size=kernels[-1],
@ -65,6 +72,8 @@ class UNet(nn.Module):
kernels=kernels[1:][::-1], kernels=kernels[1:][::-1],
strides=strides[1:][::-1], strides=strides[1:][::-1],
) )
self.deep_supervision_head1 = self.get_output_block(1)
self.deep_supervision_head2 = self.get_output_block(2)
self.output_block = self.get_output_block(decoder_level=0) self.output_block = self.get_output_block(decoder_level=0)
self.apply(self.initialize_weights) self.apply(self.initialize_weights)
self.n_layers = len(self.upsamples) - 1 self.n_layers = len(self.upsamples) - 1
@ -76,9 +85,17 @@ class UNet(nn.Module):
out = downsample(out) out = downsample(out)
encoder_outputs.append(out) encoder_outputs.append(out)
out = self.bottleneck(out) out = self.bottleneck(out)
for idx, upsample in enumerate(self.upsamples): decoder_outputs = []
out = upsample(out, encoder_outputs[self.n_layers - idx]) for i, upsample in enumerate(self.upsamples):
out = upsample(out, encoder_outputs[self.depth - i - 2])
decoder_outputs.append(out)
out = self.output_block(out) out = self.output_block(out)
if self.training and self.deep_supervision:
out1 = self.deep_supervision_head1(decoder_outputs[-2])
out2 = self.deep_supervision_head2(decoder_outputs[-3])
out1 = nn.functional.interpolate(out1, out.shape[2:], mode="trilinear", align_corners=True)
out2 = nn.functional.interpolate(out2, out.shape[2:], mode="trilinear", align_corners=True)
return torch.stack([out, out1, out2])
return out return out
def get_conv_block(self, conv_block, in_channels, out_channels, kernel_size, stride): def get_conv_block(self, conv_block, in_channels, out_channels, kernel_size, stride):

File diff suppressed because one or more lines are too long

View file

@ -29,7 +29,8 @@ parser.add_argument(
choices=["training", "val", "test"], choices=["training", "val", "test"],
help="Mode for data preprocessing", help="Mode for data preprocessing",
) )
parser.add_argument("--dilation", action="store_true", help="Perform morphological label dilation") parser.add_argument("--ohe", action="store_true", help="Add one-hot-encoding for foreground voxels (voxels > 0)")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--task", type=str, help="Number of task to be run. MSD uses numbers 01-10") parser.add_argument("--task", type=str, help="Number of task to be run. MSD uses numbers 01-10")
parser.add_argument("--dim", type=int, default=3, choices=[2, 3], help="Data dimension to prepare") parser.add_argument("--dim", type=int, default=3, choices=[2, 3], help="Data dimension to prepare")
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs for data preprocessing") parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs for data preprocessing")
@ -44,4 +45,4 @@ if __name__ == "__main__":
if args.exec_mode == "test": if args.exec_mode == "test":
path = os.path.join(path, "test") path = os.path.join(path, "test")
end = time.time() end = time.time()
print(f"Preprocessing time: {(end - start):.2f}") print(f"Pre-processing time: {(end - start):.2f}")

View file

@ -1,8 +1,7 @@
git+https://github.com/NVIDIA/dllogger git+https://github.com/NVIDIA/dllogger
nibabel==3.1.1 nibabel==3.2.1
joblib==0.16.0 joblib==1.0.1
scikit-learn==0.23.2 pytorch-lightning==1.3.8
pynvml==8.0.4 scikit-learn==1.0
pillow==6.2.0 scikit-image==0.18.3
fsspec==0.8.0 pynvml==11.0.0
pytorch_ranger==0.1.1

View file

@ -35,7 +35,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py") path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
cmd = "" cmd = ""
cmd += f"python main.py --task {args.task} --benchmark --max_epochs 2 --min_epochs 1 --optimizer adam " cmd += f"python main.py --task {args.task} --benchmark --epochs 2 "
cmd += f"--results {args.results} " cmd += f"--results {args.results} "
cmd += f"--logname {args.logname} " cmd += f"--logname {args.logname} "
cmd += f"--exec_mode {args.mode} " cmd += f"--exec_mode {args.mode} "

View file

@ -24,6 +24,7 @@ parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4],
parser.add_argument("--dim", type=int, required=True, choices=[2, 3], help="Dimension of UNet") parser.add_argument("--dim", type=int, required=True, choices=[2, 3], help="Dimension of UNet")
parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision") parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
parser.add_argument("--tta", action="store_true", help="Enable test time augmentation") parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
parser.add_argument("--resume_training", action="store_true", help="Resume training from checkpoint")
parser.add_argument("--results", type=str, default="/results", help="Path to results directory") parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
parser.add_argument("--logname", type=str, default="log", help="Name of dlloger output") parser.add_argument("--logname", type=str, default="log", help="Name of dlloger output")
@ -40,4 +41,5 @@ if __name__ == "__main__":
cmd += f"--gpus {args.gpus} " cmd += f"--gpus {args.gpus} "
cmd += "--amp " if args.amp else "" cmd += "--amp " if args.amp else ""
cmd += "--tta " if args.tta else "" cmd += "--tta " if args.tta else ""
cmd += "--resume_training " if args.resume_training else ""
call(cmd, shell=True) call(cmd, shell=True)

View file

@ -0,0 +1,17 @@
import math
from torch.optim.lr_scheduler import LambdaLR
class WarmupCosineSchedule(LambdaLR):
    """Learning-rate multiplier with linear warmup followed by cosine decay.

    For ``step < warmup_steps`` the multiplier grows linearly from 0 towards 1.
    Afterwards it follows ``0.5 * (1 + cos(pi * cycles * 2 * progress))``,
    clamped at 0, where ``progress`` is the fraction of the remaining
    ``t_total - warmup_steps`` steps already taken.

    Args:
        optimizer: wrapped optimizer, forwarded to ``LambdaLR``.
        warmup_steps: number of linear warmup steps.
        t_total: total number of scheduler steps.
        cycles: number of cosine cycles over the decay phase (0.5 = half a cosine).
        last_epoch: forwarded to ``LambdaLR``.
    """

    def __init__(self, optimizer, warmup_steps, t_total, cycles=0.5, last_epoch=-1):
        # Populate the attributes before handing self.lr_lambda to the base
        # constructor, which may already invoke it.
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the multiplier applied to the base learning rate at ``step``."""
        if step >= self.warmup_steps:
            # max(1, ...) guards against a zero-length decay phase.
            decay_span = float(max(1, self.t_total - self.warmup_steps))
            progress = float(step - self.warmup_steps) / decay_span
            cosine = math.cos(math.pi * float(self.cycles) * 2.0 * progress)
            return max(0.0, 0.5 * (1.0 + cosine))
        return float(step) / float(max(1.0, self.warmup_steps))

View file

@ -101,7 +101,7 @@ def get_unet_params(args):
strides, kernels, sizes = [], [], patch_size[:] strides, kernels, sizes = [], [], patch_size[:]
while True: while True:
spacing_ratio = [spacing / min(spacings) for spacing in spacings] spacing_ratio = [spacing / min(spacings) for spacing in spacings]
stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)] stride = [2 if ratio <= 2 and size >= 2 * args.min_fmap else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio] kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
if all(s == 1 for s in stride): if all(s == 1 for s in stride):
break break
@ -109,7 +109,7 @@ def get_unet_params(args):
spacings = [i * j for i, j in zip(spacings, stride)] spacings = [i * j for i, j in zip(spacings, stride)]
kernels.append(kernel) kernels.append(kernel)
strides.append(stride) strides.append(stride)
if len(strides) == 5: if len(strides) == 6:
break break
strides.insert(0, len(spacings) * [1]) strides.insert(0, len(spacings) * [1])
kernels.append(len(spacings) * [3]) kernels.append(len(spacings) * [3])
@ -184,13 +184,15 @@ def get_main_args(strings=None):
arg("--logname", type=str, default=None, help="Name of dlloger output") arg("--logname", type=str, default=None, help="Name of dlloger output")
arg("--task", type=str, help="Task number. MSD uses numbers 01-10") arg("--task", type=str, help="Task number. MSD uses numbers 01-10")
arg("--gpus", type=non_negative_int, default=1, help="Number of gpus") arg("--gpus", type=non_negative_int, default=1, help="Number of gpus")
arg("--learning_rate", type=float, default=0.001, help="Learning rate") arg("--learning_rate", type=float, default=0.0008, help="Learning rate")
arg("--gradient_clip_val", type=float, default=0, help="Gradient clipping norm value") arg("--gradient_clip_val", type=float, default=0, help="Gradient clipping norm value")
arg("--negative_slope", type=float, default=0.01, help="Negative slope for LeakyReLU") arg("--negative_slope", type=float, default=0.01, help="Negative slope for LeakyReLU")
arg("--tta", action="store_true", help="Enable test time augmentation") arg("--tta", action="store_true", help="Enable test time augmentation")
arg("--brats", action="store_true", help="Enable BraTS specific training and inference")
arg("--deep_supervision", action="store_true", help="Enable deep supervision")
arg("--more_chn", action="store_true", help="Create encoder with more channels")
arg("--amp", action="store_true", help="Enable automatic mixed precision") arg("--amp", action="store_true", help="Enable automatic mixed precision")
arg("--benchmark", action="store_true", help="Run model benchmarking") arg("--benchmark", action="store_true", help="Run model benchmarking")
arg("--residual", action="store_true", help="Enable residual block in encoder")
arg("--focal", action="store_true", help="Use focal loss instead of cross entropy") arg("--focal", action="store_true", help="Use focal loss instead of cross entropy")
arg("--sync_batchnorm", action="store_true", help="Enable synchronized batchnorm") arg("--sync_batchnorm", action="store_true", help="Enable synchronized batchnorm")
arg("--save_ckpt", action="store_true", help="Enable saving checkpoint") arg("--save_ckpt", action="store_true", help="Enable saving checkpoint")
@ -200,20 +202,16 @@ def get_main_args(strings=None):
arg("--ckpt_path", type=str, default=None, help="Path to checkpoint") arg("--ckpt_path", type=str, default=None, help="Path to checkpoint")
arg("--fold", type=non_negative_int, default=0, help="Fold number") arg("--fold", type=non_negative_int, default=0, help="Fold number")
arg("--patience", type=positive_int, default=100, help="Early stopping patience") arg("--patience", type=positive_int, default=100, help="Early stopping patience")
arg("--lr_patience", type=positive_int, default=70, help="Patience for ReduceLROnPlateau scheduler")
arg("--batch_size", type=positive_int, default=2, help="Batch size") arg("--batch_size", type=positive_int, default=2, help="Batch size")
arg("--val_batch_size", type=positive_int, default=4, help="Validation batch size") arg("--val_batch_size", type=positive_int, default=4, help="Validation batch size")
arg("--steps", nargs="+", type=positive_int, required=False, help="Steps for multistep scheduler")
arg("--profile", action="store_true", help="Run dlprof profiling") arg("--profile", action="store_true", help="Run dlprof profiling")
arg("--momentum", type=float, default=0.99, help="Momentum factor") arg("--momentum", type=float, default=0.99, help="Momentum factor")
arg("--weight_decay", type=float, default=0.0001, help="Weight decay (L2 penalty)") arg("--weight_decay", type=float, default=0.0001, help="Weight decay (L2 penalty)")
arg("--save_preds", action="store_true", help="Enable prediction saving") arg("--save_preds", action="store_true", help="Enable prediction saving")
arg("--dim", type=int, choices=[2, 3], default=3, help="UNet dimension") arg("--dim", type=int, choices=[2, 3], default=3, help="UNet dimension")
arg("--resume_training", action="store_true", help="Resume training from the last checkpoint") arg("--resume_training", action="store_true", help="Resume training from the last checkpoint")
arg("--factor", type=float, default=0.3, help="Scheduler factor")
arg("--num_workers", type=non_negative_int, default=8, help="Number of subprocesses to use for data loading") arg("--num_workers", type=non_negative_int, default=8, help="Number of subprocesses to use for data loading")
arg("--min_epochs", type=non_negative_int, default=30, help="Force training for at least these many epochs") arg("--epochs", type=non_negative_int, default=1000, help="Number of training epochs")
arg("--max_epochs", type=non_negative_int, default=10000, help="Stop training after this number of epochs")
arg("--warmup", type=non_negative_int, default=5, help="Warmup iterations before collecting statistics") arg("--warmup", type=non_negative_int, default=5, help="Warmup iterations before collecting statistics")
arg("--norm", type=str, choices=["instance", "batch", "group"], default="instance", help="Normalization layer") arg("--norm", type=str, choices=["instance", "batch", "group"], default="instance", help="Normalization layer")
arg("--nvol", type=positive_int, default=1, help="Number of volumes which come into single batch size for 2D model") arg("--nvol", type=positive_int, default=1, help="Number of volumes which come into single batch size for 2D model")
@ -252,18 +250,22 @@ def get_main_args(strings=None):
) )
arg( arg(
"--scheduler", "--scheduler",
type=str, action="store_true",
default="none", help="Enable cosine rate scheduler with warmup",
choices=["none", "multistep", "cosine", "plateau"],
help="Learning rate scheduler",
) )
arg( arg(
"--optimizer", "--optimizer",
type=str, type=str,
default="radam", default="adam",
choices=["sgd", "radam", "adam"], choices=["sgd", "adam"],
help="Optimizer", help="Optimizer",
) )
arg(
"--min_fmap",
type=non_negative_int,
default=4,
help="The minimal size that feature map can be reduced in bottleneck",
)
arg( arg(
"--blend", "--blend",
type=str, type=str,