[nnUNet/PyT] Add Jupyter notebook with BraTS21 solution
This commit is contained in:
parent 8c98f155e0
commit 028534a5b9
@@ -1,30 +1,24 @@
 ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.02-py3
 FROM ${FROM_IMAGE_NAME}
 
-ADD . /workspace/nnunet_pyt
-WORKDIR /workspace/nnunet_pyt
+ADD ./triton/requirements.txt /
+RUN pip install --disable-pip-version-check -r /requirements.txt
+RUN apt-get update && apt-get install -y libb64-dev libb64-0d
 
 ADD ./requirements.txt /
-RUN pip install --upgrade pip
-RUN pip install --disable-pip-version-check -r requirements.txt
-RUN pip install --disable-pip-version-check -r triton/requirements.txt
-RUN pip install pytorch-lightning==1.0.0 --no-dependencies
-RUN pip install torchtext==0.6.0 --no-dependencies
-RUN pip install monai==0.4.0 --no-dependencies
-RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==0.30.0
-RUN pip install torch_optimizer==0.0.1a15 --no-dependencies
-RUN pip install numpy==1.20.3
+RUN pip install --disable-pip-version-check -r /requirements.txt
+RUN pip install monai==0.7.0 --no-dependencies
+RUN pip install numpy --upgrade
+RUN pip install nvidia-pyindex==1.0.9
+RUN pip install nvidia-dlprof==1.2.0
+RUN pip install nvidia_dlprof_pytorch_nvtx==1.2.0
+RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==1.5.0
 
 RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
 RUN unzip -qq awscliv2.zip
 RUN ./aws/install
 RUN rm -rf awscliv2.zip aws
 
-# Install Perf Client required library
-RUN apt-get update && apt-get install -y libb64-dev libb64-0d
-
-# Install Triton Client Python API and copy Perf Client
-#COPY --from=triton-client /workspace/install/ /workspace/install/
-#RUN pip install /workspace/install/python/triton*.whl
+WORKDIR /workspace/nnunet_pyt
+ADD . /workspace/nnunet_pyt
@@ -69,7 +69,7 @@ The following figure shows the architecture of the 3D U-Net model and its differ
 
 All convolution blocks in U-Net in both encoder and decoder are using two convolution layers followed by instance normalization and a leaky ReLU nonlinearity. For downsampling we are using stride convolution whereas transposed convolution for upsampling.
 
-All models were trained with RAdam optimizer, learning rate 0.001 and weight decay 0.0001. For loss function we use the average of [cross-entropy](https://en.wikipedia.org/wiki/Cross_entropy) and [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient).
+All models were trained with Adam optimizer, learning rate 0.0008 and weight decay 0.0001. For loss function we use the average of [cross-entropy](https://en.wikipedia.org/wiki/Cross_entropy) and [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient).
 
 Early stopping is triggered if validation dice score wasn't improved during the last 100 epochs.
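The combined loss described above (cross-entropy plus Dice) corresponds to MONAI's `DiceCELoss`, which this commit switches to. A minimal sketch, assuming MONAI 0.7 and synthetic inputs; shapes and class count are illustrative only:

```
import torch
from monai.losses import DiceCELoss

# Softmax over channels, labels converted to one-hot, Dice reduced over the batch.
loss_fn = DiceCELoss(softmax=True, to_onehot_y=True, batch=True)

logits = torch.randn(2, 3, 32, 32, 32)            # (N, C, D, H, W) raw network output
labels = torch.randint(0, 3, (2, 1, 32, 32, 32))  # (N, 1, D, H, W) integer class map
print(loss_fn(logits, labels))                    # scalar Dice + CE loss
```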
@@ -304,7 +304,7 @@ To see the full list of available options and their descriptions, use the `-h` o
 The following example output is printed when running the model:
 
 ```
-usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--residual] [--focal] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--skip_first_n_eval SKIP_FIRST_N_EVAL] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] [--warmup WARMUP] [--norm {instance,batch,group}] [--nvol NVOL] [--data2d_dim {2,3}] [--oversampling OVERSAMPLING] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,radam,adam}] [--blend {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES]
+usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--residual] [--focal] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--skip_first_n_eval SKIP_FIRST_N_EVAL] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] [--warmup WARMUP] [--norm {instance,batch,group}] [--nvol NVOL] [--data2d_dim {2,3}] [--oversampling OVERSAMPLING] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,adam}] [--blend {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES]
 
 optional arguments:
   -h, --help            show this help message and exit
@@ -370,8 +370,8 @@ optional arguments:
                         type of CPU affinity (default: socket_unique_interleaved)
   --scheduler {none,multistep,cosine,plateau}
                         Learning rate scheduler (default: none)
-  --optimizer {sgd,radam,adam}
-                        Optimizer (default: radam)
+  --optimizer {sgd,adam}
+                        Optimizer (default: adam)
   --blend {gaussian,constant}
                         How to blend output of overlapping windows (default: gaussian)
   --train_batches TRAIN_BATCHES
@@ -418,7 +418,7 @@ If you have dataset in other format or you want customize data preprocessing or
 ### Training process
 
 The model trains for at least `--min_epochs` and at most `--max_epochs` epochs. After each epoch, evaluation on the validation set is done and validation loss is monitored for early stopping (see `--patience` flag). Default training settings are:
-* RAdam optimizer with learning rate of 0.001 and weight decay 0.0001.
+* Adam optimizer with learning rate of 0.0008 and weight decay 0.0001.
 * Training batch size is set to 2 for 3D U-Net and 16 for 2D U-Net.
 
 This default parametrization is applied when running scripts from the `scripts/` directory and when running `main.py` without explicitly overriding these parameters. By default, the training is in full precision. To enable AMP, pass the `--amp` flag. AMP can be enabled for every mode of execution.
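For reference, a hedged sketch of how these defaults map onto the PyTorch Lightning pieces wired up in `main.py` (names follow the new code in this commit; values are the documented defaults):

```
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Early stopping watches the mean dice score logged at the end of each validation epoch.
early_stopping = EarlyStopping(monitor="dice_mean", patience=100, mode="max", verbose=True)

# Pass precision=16 and gpus=N to mirror the --amp and --gpus flags.
trainer = Trainer(max_epochs=1000, callbacks=[early_stopping])
```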
@@ -454,8 +454,6 @@ The script will then:
 
 ## Performance
 
 The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
 
 ### Benchmarking
 
 The following section shows how to run benchmarks to measure the model performance in training and inference modes.
@@ -633,6 +631,9 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
 
 ### Changelog
 
+October 2021
+- Add Jupyter Notebook with BraTS solution
+
 May 2021
 - Add Triton Inference Server support
 - Removed deep supervision, attention and drop block
@@ -25,7 +25,7 @@ from nvidia.dali.plugin.pytorch import DALIGenericIterator
 
 
 def get_numpy_reader(files, shard_id, num_shards, seed, shuffle):
-    return ops.NumpyReader(
+    return ops.readers.Numpy(
         seed=seed,
         files=files,
         device="cpu",
@@ -42,6 +42,7 @@ class TrainPipeline(Pipeline):
     def __init__(self, batch_size, num_threads, device_id, **kwargs):
         super(TrainPipeline, self).__init__(batch_size, num_threads, device_id)
         self.dim = kwargs["dim"]
+        self.internal_seed = kwargs["seed"]
         self.oversampling = kwargs["oversampling"]
         self.input_x = get_numpy_reader(
             num_shards=kwargs["gpus"],
@@ -60,6 +61,7 @@ class TrainPipeline(Pipeline):
         self.patch_size = kwargs["patch_size"]
         if self.dim == 2:
             self.patch_size = [kwargs["batch_size_2d"]] + self.patch_size
 
         self.crop_shape = types.Constant(np.array(self.patch_size), dtype=types.INT64)
         self.crop_shape_float = types.Constant(np.array(self.patch_size), dtype=types.FLOAT)
 
@@ -69,7 +71,7 @@ class TrainPipeline(Pipeline):
         return img, lbl
 
     def random_augmentation(self, probability, augmented, original):
-        condition = fn.cast(fn.coin_flip(probability=probability), dtype=types.DALIDataType.BOOL)
+        condition = fn.cast(fn.random.coin_flip(probability=probability), dtype=types.DALIDataType.BOOL)
         neg_condition = condition ^ True
         return condition * augmented + neg_condition * original
 
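`random_augmentation` implements a branchless stochastic switch: a Bernoulli draw is cast to 0/1 and used to blend the augmented and original tensors. A small NumPy sketch of the same idea (illustrative only, not DALI code):

```
import numpy as np

rng = np.random.default_rng(0)

def random_augmentation(probability, augmented, original):
    # condition is 1 with the given probability, else 0; the complement selects the original.
    condition = rng.random() < probability
    return condition * augmented + (not condition) * original

x = np.ones(3)
print(random_augmentation(0.15, x * 2.0, x))  # usually x, occasionally 2x
```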
@@ -77,45 +79,60 @@ class TrainPipeline(Pipeline):
     def slice_fn(img):
         return fn.slice(img, 1, 3, axes=[0])
 
-    def crop_fn(self, img, lbl):
-        center = fn.segmentation.random_mask_pixel(lbl, foreground=fn.coin_flip(probability=self.oversampling))
-        crop_anchor = self.slice_fn(center) - self.crop_shape // 2
-        adjusted_anchor = math.max(0, crop_anchor)
-        max_anchor = self.slice_fn(fn.shapes(lbl)) - self.crop_shape
-        crop_anchor = math.min(adjusted_anchor, max_anchor)
-        img = fn.slice(img.gpu(), crop_anchor, self.crop_shape, axis_names="DHW", out_of_bounds_policy="pad")
-        lbl = fn.slice(lbl.gpu(), crop_anchor, self.crop_shape, axis_names="DHW", out_of_bounds_policy="pad")
-        return img, lbl
+    def biased_crop_fn(self, img, label):
+        roi_start, roi_end = fn.segmentation.random_object_bbox(
+            label,
+            format="start_end",
+            foreground_prob=self.oversampling,
+            background=0,
+            seed=self.internal_seed,
+            device="cpu",
+            cache_objects=True,
+        )
+
+        anchor = fn.roi_random_crop(label, roi_start=roi_start, roi_end=roi_end, crop_shape=[1, *self.patch_size])
+        anchor = fn.slice(anchor, 1, 3, axes=[0])  # drop channels from anchor
+        img, label = fn.slice(
+            [img, label], anchor, self.crop_shape, axis_names="DHW", out_of_bounds_policy="pad", device="cpu"
+        )
+
+        return img.gpu(), label.gpu()
 
     def zoom_fn(self, img, lbl):
-        resized_shape = self.crop_shape * self.random_augmentation(0.15, fn.uniform(range=(0.7, 1.0)), 1.0)
-        img, lbl = fn.crop(img, crop=resized_shape), fn.crop(lbl, crop=resized_shape)
+        scale = self.random_augmentation(0.15, fn.random.uniform(range=(0.7, 1.0)), 1.0)
+        d, h, w = [scale * x for x in self.patch_size]
+        if self.dim == 2:
+            d = self.patch_size[0]
+        img, lbl = fn.crop(img, crop_h=h, crop_w=w, crop_d=d), fn.crop(lbl, crop_h=h, crop_w=w, crop_d=d)
         img = fn.resize(img, interp_type=types.DALIInterpType.INTERP_CUBIC, size=self.crop_shape_float)
         lbl = fn.resize(lbl, interp_type=types.DALIInterpType.INTERP_NN, size=self.crop_shape_float)
         return img, lbl
 
     def noise_fn(self, img):
-        img_noised = img + fn.random.normal(img, stddev=fn.uniform(range=(0.0, 0.33)))
+        img_noised = img + fn.random.normal(img, stddev=fn.random.uniform(range=(0.0, 0.33)))
         return self.random_augmentation(0.15, img_noised, img)
 
     def blur_fn(self, img):
-        img_blured = fn.gaussian_blur(img, sigma=fn.uniform(range=(0.5, 1.5)))
-        return self.random_augmentation(0.15, img_blured, img)
+        img_blurred = fn.gaussian_blur(img, sigma=fn.random.uniform(range=(0.5, 1.5)))
+        return self.random_augmentation(0.15, img_blurred, img)
 
     def brightness_fn(self, img):
-        brightness_scale = self.random_augmentation(0.15, fn.uniform(range=(0.7, 1.3)), 1.0)
+        brightness_scale = self.random_augmentation(0.15, fn.random.uniform(range=(0.7, 1.3)), 1.0)
         return img * brightness_scale
 
     def contrast_fn(self, img):
         min_, max_ = fn.reductions.min(img), fn.reductions.max(img)
-        scale = self.random_augmentation(0.15, fn.uniform(range=(0.65, 1.5)), 1.0)
+        scale = self.random_augmentation(0.15, fn.random.uniform(range=(0.65, 1.5)), 1.0)
         img = math.clamp(img * scale, min_, max_)
         return img
 
     def flips_fn(self, img, lbl):
-        kwargs = {"horizontal": fn.coin_flip(probability=0.33), "vertical": fn.coin_flip(probability=0.33)}
+        kwargs = {
+            "horizontal": fn.random.coin_flip(probability=0.33),
+            "vertical": fn.random.coin_flip(probability=0.33),
+        }
         if self.dim == 3:
-            kwargs.update({"depthwise": fn.coin_flip(probability=0.33)})
+            kwargs.update({"depthwise": fn.random.coin_flip(probability=0.33)})
         return fn.flip(img, **kwargs), fn.flip(lbl, **kwargs)
 
     def transpose_fn(self, img, lbl):
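The new `biased_crop_fn` replaces the hand-rolled anchor clamping with DALI's `random_object_bbox` + `roi_random_crop`: with probability `oversampling`, the crop window is forced to overlap a randomly chosen foreground object. A rough NumPy sketch of the anchor logic for one axis (hypothetical helper, illustrative only):

```
import numpy as np

rng = np.random.default_rng(0)

def biased_anchor(fg_lo, fg_hi, crop, size, oversampling):
    """Pick a 1D crop start; with prob. `oversampling` overlap [fg_lo, fg_hi)."""
    if rng.random() < oversampling:
        lo = max(0, fg_lo - crop + 1)     # earliest start that still overlaps
        hi = min(size - crop, fg_hi - 1)  # latest start that still overlaps
        if lo <= hi:
            return rng.integers(lo, hi + 1)
    return rng.integers(0, size - crop + 1)  # unbiased fallback

print(biased_anchor(fg_lo=40, fg_hi=60, crop=32, size=128, oversampling=1.0))
```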
@@ -124,7 +141,7 @@ class TrainPipeline(Pipeline):
 
     def define_graph(self):
         img, lbl = self.load_data()
-        img, lbl = self.crop_fn(img, lbl)
+        img, lbl = self.biased_crop_fn(img, lbl)
        img, lbl = self.zoom_fn(img, lbl)
        img, lbl = self.flips_fn(img, lbl)
        img = self.noise_fn(img)
@@ -141,15 +158,15 @@ class EvalPipeline(Pipeline):
         super(EvalPipeline, self).__init__(batch_size, num_threads, device_id)
         self.input_x = get_numpy_reader(
             files=kwargs["imgs"],
-            shard_id=device_id,
-            num_shards=kwargs["gpus"],
+            shard_id=0,
+            num_shards=1,
             seed=kwargs["seed"],
             shuffle=False,
         )
         self.input_y = get_numpy_reader(
             files=kwargs["lbls"],
-            shard_id=device_id,
-            num_shards=kwargs["gpus"],
+            shard_id=0,
+            num_shards=1,
             seed=kwargs["seed"],
             shuffle=False,
         )
@@ -281,6 +298,12 @@ def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
         imgs = list(itertools.chain(*(100 * [imgs])))[: nbs * kwargs["gpus"]]
         lbls = list(itertools.chain(*(100 * [lbls])))[: nbs * kwargs["gpus"]]
 
+    if mode == "eval":  # To avoid padding for the multigpu evaluation.
+        rank = int(os.getenv("LOCAL_RANK", "0"))
+        imgs, lbls = np.array_split(imgs, kwargs["gpus"]), np.array_split(lbls, kwargs["gpus"])
+        imgs, lbls = [list(x) for x in imgs], [list(x) for x in lbls]
+        imgs, lbls = imgs[rank], lbls[rank]
+
     pipe_kwargs = {
         "imgs": imgs,
         "lbls": lbls,
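The added `eval` branch shards the validation file lists manually so that each DDP rank gets a contiguous, unpadded subset (DALI's internal sharding would otherwise pad the last shard). The splitting itself is plain NumPy:

```
import numpy as np

imgs = [f"case_{i}_x.npy" for i in range(10)]  # hypothetical file list
gpus, rank = 4, 1

shards = np.array_split(imgs, gpus)         # sizes 3, 3, 2, 2 - no padding needed
print([list(s) for s in shards][rank])      # files handled by this rank
```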
@@ -57,8 +57,7 @@ class DataModule(LightningDataModule):
             self.val_imgs = get_split(imgs, val_idx)
             self.val_lbls = get_split(lbls, val_idx)
             if is_main_process():
-                ntrain, nval = len(self.train_imgs), len(self.val_imgs)
-                print(f"Number of examples: Train {ntrain} - Val {nval}")
+                print(f"Number of examples: Train {len(self.train_imgs)} - Val {len(self.val_imgs)}")
         elif is_main_process():
             print(f"Number of test examples: {len(self.test_imgs)}")
 
@@ -23,6 +23,8 @@ task = {
     "08": "Task08_HepaticVessel",
     "09": "Task09_Spleen",
     "10": "Task10_Colon",
+    "11": "BraTS2021_train",
+    "12": "BraTS2021_val",
 }
 
 patch_size = {
@@ -36,6 +38,8 @@ patch_size = {
     "08_3d": [64, 192, 192],
     "09_3d": [64, 192, 160],
     "10_3d": [56, 192, 160],
+    "11_3d": [128, 128, 128],
+    "12_3d": [128, 128, 128],
     "01_2d": [192, 160],
     "02_2d": [320, 256],
     "03_2d": [512, 512],
@@ -60,7 +64,8 @@ spacings = {
     "08_3d": [1.5, 0.8, 0.8],
     "09_3d": [1.6, 0.79, 0.79],
     "10_3d": [3, 0.78, 0.78],
-    "11_3d": [5, 0.741, 0.741],
+    "11_3d": [1.0, 1.0, 1.0],
+    "12_3d": [1.0, 1.0, 1.0],
     "01_2d": [1.0, 1.0],
     "02_2d": [1.25, 1.25],
     "03_2d": [0.7676, 0.7676],
@@ -80,7 +85,6 @@ ct_min = {
     "08": -3,
     "09": -41,
     "10": -30,
-    "11": -958,
 }
 
 ct_max = {
@@ -90,9 +94,8 @@ ct_max = {
     "08": 243,
     "09": 176,
     "10": 165.82,
-    "11": 93,
 }
 
-ct_mean = {"03": 99.4, "06": -158.58, "07": 77.9, "08": 104.37, "09": 99.29, "10": 62.18, "11": -547.7}
+ct_mean = {"03": 99.4, "06": -158.58, "07": 77.9, "08": 104.37, "09": 99.29, "10": 62.18}
 
-ct_std = {"03": 39.36, "06": 324.7, "07": 75.4, "08": 52.62, "09": 39.47, "10": 32.65, "11": 281.08}
+ct_std = {"03": 39.36, "06": 324.7, "07": 75.4, "08": 52.62, "09": 39.47, "10": 32.65}
@@ -22,7 +22,6 @@ import monai.transforms as transforms
 import nibabel
 import numpy as np
 from joblib import Parallel, delayed
-from skimage.morphology import dilation, erosion, square
 from skimage.transform import resize
 from utils.utils import get_task_code, make_empty_dir
 
@@ -39,32 +38,34 @@ class Preprocessor:
         self.target_spacing = None
         self.task = args.task
         self.task_code = get_task_code(args)
+        self.verbose = args.verbose
         self.patch_size = patch_size[self.task_code]
         self.training = args.exec_mode == "training"
         self.data_path = os.path.join(args.data, task[args.task])
         metadata_path = os.path.join(self.data_path, "dataset.json")
+        if self.args.exec_mode == "val":
+            dataset_json = json.load(open(metadata_path, "r"))
+            dataset_json["val"] = dataset_json["training"]
+            with open(metadata_path, "w") as outfile:
+                json.dump(dataset_json, outfile)
         self.metadata = json.load(open(metadata_path, "r"))
         self.modality = self.metadata["modality"]["0"]
         self.results = os.path.join(args.results, self.task_code)
         if not self.training:
             self.results = os.path.join(self.results, self.args.exec_mode)
         self.crop_foreg = transforms.CropForegroundd(keys=["image", "label"], source_key="image")
-        self.normalize_intensity = transforms.NormalizeIntensity(nonzero=False, channel_wise=True)
+        nonzero = True if self.modality != "CT" else False  # normalize only non-zero region for MRI
+        self.normalize_intensity = transforms.NormalizeIntensity(nonzero=nonzero, channel_wise=True)
 
     def run(self):
         make_empty_dir(self.results)
 
         print(f"Preprocessing {self.data_path}")
         try:
             self.target_spacing = spacings[self.task_code]
         except:
             self.collect_spacings()
-        print(f"Target spacing {self.target_spacing}")
+        if self.verbose:
+            print(f"Target spacing {self.target_spacing}")
 
         if self.modality == "CT":
             try:
@@ -77,7 +78,8 @@ class Preprocessor:
 
             _mean = round(self.ct_mean, 2)
             _std = round(self.ct_std, 2)
-            print(f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}")
+            if self.verbose:
+                print(f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}")
 
         self.run_parallel(self.preprocess_pair, self.args.exec_mode)
 
@@ -86,7 +88,7 @@ class Preprocessor:
                 "patch_size": self.patch_size,
                 "spacings": self.target_spacing,
                 "n_class": len(self.metadata["labels"]),
-                "in_channels": len(self.metadata["modality"]),
+                "in_channels": len(self.metadata["modality"]) + int(self.args.ohe),
             },
             open(os.path.join(self.results, "config.pkl"), "wb"),
         )
@@ -99,9 +101,10 @@ class Preprocessor:
             image, label = data["image"], data["label"]
             test_metadata = None
         else:
+            orig_shape = image.shape[1:]
             bbox = transforms.utils.generate_spatial_bounding_box(image)
-            test_metadata = np.vstack([bbox, image.shape[1:]])
             image = transforms.SpatialCrop(roi_start=bbox[0], roi_end=bbox[1])(image)
+            test_metadata = np.vstack([bbox, orig_shape, image.shape[1:]])
             if label is not None:
                 label = transforms.SpatialCrop(roi_start=bbox[0], roi_end=bbox[1])(label)
         if self.args.dim == 3:
@@ -111,11 +114,16 @@ class Preprocessor:
         image = self.normalize(image)
         if self.training:
             image, label = self.standardize(image, label)
-        if self.args.dilation:
-            new_lbl = np.zeros(label.shape, dtype=np.uint8)
-            for depth in range(label.shape[1]):
-                new_lbl[0, depth] = erosion(dilation(label[0, depth], square(3)), square(3))
-            label = new_lbl
+        if self.args.ohe:
+            mask = np.ones(image.shape[1:], dtype=np.float32)
+            for i in range(image.shape[0]):
+                zeros = np.where(image[i] <= 0)
+                mask[zeros] *= 0.0
+            image = self.normalize_intensity(image).astype(np.float32)
+            mask = np.expand_dims(mask, 0)
+            image = np.concatenate([image, mask])
 
         self.save(image, label, fname, test_metadata)
 
     def resample(self, image, label, image_spacings):
@@ -145,7 +153,8 @@ class Preprocessor:
 
     def save(self, image, label, fname, test_metadata):
         mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(np.std(image, (1, 2, 3)), 2)
-        print(f"Saving {fname} shape {image.shape} mean {mean} std {std}")
+        if self.verbose:
+            print(f"Saving {fname} shape {image.shape} mean {mean} std {std}")
         self.save_npy(image, fname, "_x.npy")
         if label is not None:
             self.save_npy(label, fname, "_y.npy")
BIN  PyTorch/Segmentation/nnUNet/images/unet-brats.jpg  (new file, 64 KiB)
Binary file not shown.
|
@ -14,9 +14,10 @@
|
|||
|
||||
import os
|
||||
|
||||
import nvidia_dlprof_pytorch_nvtx
|
||||
import torch
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, early_stopping
|
||||
|
||||
from data_loading.data_module import DataModule
|
||||
from models.nn_unet import NNUnet
|
||||
|
@@ -28,8 +29,6 @@ if __name__ == "__main__":
     args = get_main_args()
 
     if args.profile:
-        import nvidia_dlprof_pytorch_nvtx
-
         nvidia_dlprof_pytorch_nvtx.init()
         print("Profiling enabled")
 
@@ -61,9 +60,13 @@ if __name__ == "__main__":
         ]
     elif args.exec_mode == "train":
         model = NNUnet(args)
-        callbacks = [EarlyStopping(monitor="dice_sum", patience=args.patience, verbose=True, mode="max")]
+        early_stopping = EarlyStopping(monitor="dice_mean", patience=args.patience, verbose=True, mode="max")
+        callbacks = [early_stopping]
         if args.save_ckpt:
-            model_ckpt = ModelCheckpoint(monitor="dice_sum", mode="max", save_last=True)
+            model_ckpt = ModelCheckpoint(
+                filename="{epoch}-{dice_mean:.2f}", monitor="dice_mean", mode="max", save_last=True
+            )
             callbacks.append(model_ckpt)
     else:  # Evaluation or inference
         if ckpt_path is not None:
             model = NNUnet.load_from_checkpoint(ckpt_path)
@@ -76,8 +79,8 @@ if __name__ == "__main__":
         precision=16 if args.amp else 32,
         benchmark=True,
         deterministic=False,
-        min_epochs=args.min_epochs,
-        max_epochs=args.max_epochs,
+        min_epochs=args.epochs,
+        max_epochs=args.epochs,
         sync_batchnorm=args.sync_batchnorm,
         gradient_clip_val=args.gradient_clip_val,
         callbacks=callbacks,
@@ -85,7 +88,6 @@ if __name__ == "__main__":
         default_root_dir=args.results,
         resume_from_checkpoint=ckpt_path,
         accelerator="ddp" if args.gpus > 1 else None,
-        checkpoint_callback=model_ckpt,
         limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
         limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
         limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
@@ -106,6 +108,9 @@ if __name__ == "__main__":
         trainer.test(model, test_dataloaders=data_module.test_dataloader())
     elif args.exec_mode == "train":
         trainer.fit(model, data_module)
+        if is_main_process():
+            logname = args.logname if args.logname is not None else "train_log.json"
+            log(logname, torch.tensor(model.best_mean_dice), results=args.results)
     elif args.exec_mode == "evaluate":
         model.args = args
         trainer.test(model, test_dataloaders=data_module.val_dataloader())
@@ -113,13 +118,14 @@ if __name__ == "__main__":
         logname = args.logname if args.logname is not None else "eval_log.json"
         log(logname, model.eval_dice, results=args.results)
     elif args.exec_mode == "predict":
-        model.args = args
         if args.save_preds:
-            prec = "amp" if args.amp else "fp32"
-            dir_name = f"preds_task_{args.task}_dim_{args.dim}_fold_{args.fold}_{prec}"
+            ckpt_name = "_".join(args.ckpt_path.split("/")[-1].split(".")[:-1])
+            dir_name = f"predictions_{ckpt_name}"
+            dir_name += f"_task={model.args.task}_fold={model.args.fold}"
             if args.tta:
                 dir_name += "_tta"
             save_dir = os.path.join(args.results, dir_name)
             model.save_dir = save_dir
             make_empty_dir(save_dir)
+        model.args = args
         trainer.test(model, test_dataloaders=data_module.test_dataloader())
@@ -67,7 +67,6 @@ def get_output_padding(kernel_size, stride, padding):
     return out_padding if len(out_padding) > 1 else out_padding[0]
 
 
-
 class ConvLayer(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
         super(ConvLayer, self).__init__()
@@ -94,30 +93,6 @@ class ConvBlock(nn.Module):
         return out
 
 
-class ResidBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
-        super(ResidBlock, self).__init__()
-        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, **kwargs)
-        self.conv2 = get_conv(out_channels, out_channels, kernel_size, 1, kwargs["dim"])
-        self.norm = get_norm(kwargs["norm"], out_channels)
-        self.lrelu = nn.LeakyReLU(negative_slope=kwargs["negative_slope"], inplace=True)
-        self.downsample = None
-        if max(stride) > 1 or in_channels != out_channels:
-            self.downsample = get_conv(in_channels, out_channels, kernel_size, stride, kwargs["dim"])
-            self.norm_res = get_norm(kwargs["norm"], out_channels)
-
-    def forward(self, input_data):
-        residual = input_data
-        out = self.conv1(input_data)
-        out = self.conv2(out)
-        out = self.norm(out)
-        if self.downsample is not None:
-            residual = self.downsample(residual)
-            residual = self.norm_res(residual)
-        out = self.lrelu(out + residual)
-        return out
-
-
 class UpsampleBlock(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
         super(UpsampleBlock, self).__init__()
@@ -13,21 +13,32 @@
 # limitations under the License.
 
 import torch.nn as nn
-from monai.losses import DiceLoss, FocalLoss
+from monai.losses import DiceCELoss, DiceFocalLoss, DiceLoss, FocalLoss
 
 
 class Loss(nn.Module):
     def __init__(self, focal):
         super(Loss, self).__init__()
-        self.dice = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True)
-        self.focal = FocalLoss(gamma=2.0)
-        self.cross_entropy = nn.CrossEntropyLoss()
-        self.use_focal = focal
+        if focal:
+            self.loss = DiceFocalLoss(gamma=2.0, softmax=True, to_onehot_y=True, batch=True)
+        else:
+            self.loss = DiceCELoss(softmax=True, to_onehot_y=True, batch=True)
 
     def forward(self, y_pred, y_true):
-        loss = self.dice(y_pred, y_true)
-        if self.use_focal:
-            loss += self.focal(y_pred, y_true)
-        else:
-            loss += self.cross_entropy(y_pred, y_true[:, 0].long())
-        return loss
+        return self.loss(y_pred, y_true)
+
+
+class LossBraTS(nn.Module):
+    def __init__(self, focal):
+        super(LossBraTS, self).__init__()
+        self.dice = DiceLoss(sigmoid=True, batch=True)
+        self.ce = FocalLoss(gamma=2.0, to_onehot_y=False) if focal else nn.BCEWithLogitsLoss()
+
+    def _loss(self, p, y):
+        return self.dice(p, y) + self.ce(p, y.float())
+
+    def forward(self, p, y):
+        y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
+        p_wt, p_tc, p_et = p[:, 0].unsqueeze(1), p[:, 1].unsqueeze(1), p[:, 2].unsqueeze(1)
+        l_wt, l_tc, l_et = self._loss(p_wt, y_wt), self._loss(p_tc, y_tc), self._loss(p_et, y_et)
+        return l_wt + l_tc + l_et
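`LossBraTS` scores the three nested BraTS regions (whole tumor, tumor core, enhancing tumor) as independent binary channels built from the label map, so the network emits three sigmoid channels instead of softmax classes. A small sketch of the target construction, using the same label convention as the code above (3 = enhancing):

```
import torch

y = torch.tensor([[0, 1, 2, 3]])       # toy label map: background, necrosis, edema, enhancing
y_wt = y > 0                            # whole tumor: any tumor label
y_tc = ((y == 1) + (y == 3)) > 0        # tumor core: necrosis + enhancing
y_et = y == 3                           # enhancing tumor
print(torch.stack([y_wt, y_tc, y_et], dim=1).int())  # (N, 3, ...) binary targets
```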
@@ -13,35 +13,61 @@
 # limitations under the License.
 
 import torch
-from pytorch_lightning.metrics.functional import stat_scores
-from pytorch_lightning.metrics.metric import Metric
+from torchmetrics import Metric
 
 
 class Dice(Metric):
-    def __init__(self, nclass):
-        super().__init__(dist_sync_on_step=True)
-        self.add_state("n_updates", default=torch.zeros(1), dist_reduce_fx="sum")
-        self.add_state("dice", default=torch.zeros((nclass,)), dist_reduce_fx="sum")
+    def __init__(self, n_class, brats):
+        super().__init__(dist_sync_on_step=False)
+        self.n_class = n_class
+        self.brats = brats
+        self.add_state("steps", default=torch.zeros(1), dist_reduce_fx="sum")
+        self.add_state("dice", default=torch.zeros((n_class,)), dist_reduce_fx="sum")
+        self.add_state("loss", default=torch.zeros(1), dist_reduce_fx="sum")
 
-    def update(self, pred, target):
-        self.n_updates += 1
-        self.dice += self.compute_stats(pred, target)
+    def update(self, preds, target, loss):
+        self.steps += 1
+        self.dice += self.compute_stats_brats(preds, target) if self.brats else self.compute_stats(preds, target)
+        self.loss += loss
 
     def compute(self):
-        return 100 * self.dice / self.n_updates
+        return 100 * self.dice / self.steps, self.loss / self.steps
 
-    @staticmethod
-    def compute_stats(pred, target):
-        num_classes = pred.shape[1]
-        scores = torch.zeros(num_classes - 1, device=pred.device, dtype=torch.float32)
-        for i in range(1, num_classes):
-            if (target != i).all():
-                # no foreground class
-                _, _pred = torch.max(pred, 1)
-                scores[i - 1] += 1 if (_pred != i).all() else 0
-                continue
-            _tp, _fp, _tn, _fn, _ = stat_scores(pred=pred, target=target, class_index=i)
-            denom = (2 * _tp + _fp + _fn).to(torch.float)
-            score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
-            scores[i - 1] += score_cls
-        return scores
+    def compute_stats_brats(self, p, y):
+        scores = torch.zeros(self.n_class, device=p.device, dtype=torch.float32)
+        p = (torch.sigmoid(p) > 0.5).int()
+        y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
+        y = torch.stack([y_wt, y_tc, y_et], dim=1)
+
+        for i in range(self.n_class):
+            p_i, y_i = p[:, i], y[:, i]
+            if (y_i != 1).all():
+                # no foreground class
+                scores[i - 1] += 1 if (p_i != 1).all() else 0
+                continue
+            tp, fn, fp = self.get_stats(p_i, y_i, 1)
+            denom = (2 * tp + fp + fn).to(torch.float)
+            score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
+            scores[i - 1] += score_cls
+        return scores
+
+    def compute_stats(self, preds, target):
+        scores = torch.zeros(self.n_class, device=preds.device, dtype=torch.float32)
+        preds = torch.argmax(preds, dim=1)
+        for i in range(1, self.n_class + 1):
+            if (target != i).all():
+                # no foreground class
+                scores[i - 1] += 1 if (preds != i).all() else 0
+                continue
+            tp, fn, fp = self.get_stats(preds, target, i)
+            denom = (2 * tp + fp + fn).to(torch.float)
+            score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
+            scores[i - 1] += score_cls
+        return scores
+
+    @staticmethod
+    def get_stats(preds, target, class_idx):
+        tp = torch.logical_and(preds == class_idx, target == class_idx).sum()
+        fn = torch.logical_and(preds != class_idx, target == class_idx).sum()
+        fp = torch.logical_and(preds == class_idx, target != class_idx).sum()
+        return tp, fn, fp
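`get_stats` reduces each class to raw counts, from which the Dice score is 2·TP / (2·TP + FP + FN). A tiny worked example using the same counting (hypothetical predictions):

```
import torch

preds  = torch.tensor([0, 1, 1, 1, 0])
target = torch.tensor([0, 1, 1, 0, 1])

class_idx = 1
tp = torch.logical_and(preds == class_idx, target == class_idx).sum()  # 2
fn = torch.logical_and(preds != class_idx, target == class_idx).sum()  # 1
fp = torch.logical_and(preds == class_idx, target != class_idx).sum()  # 1
print(2 * tp / (2 * tp + fp + fn))  # 4/6 = 0.6667
```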
@@ -20,8 +20,9 @@ import torch
 import torch.nn as nn
 from apex.optimizers import FusedAdam, FusedSGD
 from monai.inferers import sliding_window_inference
+from scipy.special import expit, softmax
 from skimage.transform import resize
-from torch_optimizer import RAdam
+from utils.scheduler import WarmupCosineSchedule
 from utils.utils import (
     flip,
     get_dllogger,
@@ -33,7 +34,7 @@ from utils.utils import (
     layout_2d,
 )
 
-from models.loss import Loss
+from models.loss import Loss, LossBraTS
 from models.metrics import Dice
 from models.unet import UNet
 
@@ -41,24 +42,25 @@ from models.unet import UNet
 class NNUnet(pl.LightningModule):
     def __init__(self, args, bermuda=False, data_dir=None):
         super(NNUnet, self).__init__()
-        self.save_hyperparameters()
         self.args = args
         self.bermuda = bermuda
         if data_dir is not None:
             self.args.data = data_dir
+        self.save_hyperparameters()
         self.build_nnunet()
-        self.best_sum = 0
-        self.best_sum_epoch = 0
+        self.best_mean = 0
+        self.best_mean_epoch = 0
         self.best_dice = self.n_class * [0]
         self.best_epoch = self.n_class * [0]
-        self.best_sum_dice = self.n_class * [0]
+        self.best_mean_dice = self.n_class * [0]
         self.test_idx = 0
         self.test_imgs = []
         if not self.bermuda:
             self.learning_rate = args.learning_rate
-            self.loss = Loss(self.args.focal)
+            loss = LossBraTS if self.args.brats else Loss
+            self.loss = loss(self.args.focal)
             self.tta_flips = get_tta_flips(args.dim)
-            self.dice = Dice(self.n_class)
+            self.dice = Dice(self.n_class, self.args.brats)
             if self.args.exec_mode in ["train", "evaluate"]:
                 self.dllogger = get_dllogger(args.results)
 
@@ -72,10 +74,17 @@ class NNUnet(pl.LightningModule):
             return self.model(img)
         return self.tta_inference(img) if self.args.tta else self.do_inference(img)
 
+    def compute_loss(self, preds, label):
+        if self.args.deep_supervision:
+            pred0, pred1, pred2 = preds
+            loss = self.loss(pred0, label) + 0.5 * self.loss(pred1, label) + 0.25 * self.loss(pred2, label)
+            return loss / 1.75
+        return self.loss(preds, label)
+
     def training_step(self, batch, batch_idx):
         img, lbl = self.get_train_data(batch)
         pred = self.model(img)
-        loss = self.loss(pred, lbl)
+        loss = self.compute_loss(pred, lbl)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -84,49 +93,50 @@ class NNUnet(pl.LightningModule):
         img, lbl = batch["image"], batch["label"]
         pred = self._forward(img)
         loss = self.loss(pred, lbl)
-        self.dice.update(pred, lbl[:, 0])
-        return {"val_loss": loss}
+        self.dice.update(pred, lbl[:, 0], loss)
 
     def test_step(self, batch, batch_idx):
         if self.args.exec_mode == "evaluate":
             return self.validation_step(batch, batch_idx)
         img = batch["image"]
-        pred = self._forward(img)
+        pred = self._forward(img).squeeze(0).cpu().detach().numpy()
         if self.args.save_preds:
             meta = batch["meta"][0].cpu().detach().numpy()
-            original_shape = meta[2]
             min_d, max_d = meta[0, 0], meta[1, 0]
             min_h, max_h = meta[0, 1], meta[1, 1]
             min_w, max_w = meta[0, 2], meta[1, 2]
-
-            final_pred = torch.zeros((1, pred.shape[1], *original_shape), device=img.device)
-            final_pred[:, :, min_d:max_d, min_h:max_h, min_w:max_w] = pred
-            final_pred = nn.functional.softmax(final_pred, dim=1)
-            final_pred = final_pred.squeeze(0).cpu().detach().numpy()
-
-            if not all(original_shape == final_pred.shape[1:]):
-                class_ = final_pred.shape[0]
-                resized_pred = np.zeros((class_, *original_shape))
-                for i in range(class_):
+            n_class, original_shape, cropped_shape = pred.shape[0], meta[2], meta[3]
+            if not all(cropped_shape == pred.shape[1:]):
+                resized_pred = np.zeros((n_class, *cropped_shape))
+                for i in range(n_class):
                     resized_pred[i] = resize(
-                        final_pred[i], original_shape, order=3, mode="edge", cval=0, clip=True, anti_aliasing=False
+                        pred[i], cropped_shape, order=3, mode="edge", cval=0, clip=True, anti_aliasing=False
                     )
-                final_pred = resized_pred
+                pred = resized_pred
+            final_pred = np.zeros((n_class, *original_shape))
+            final_pred[:, min_d:max_d, min_h:max_h, min_w:max_w] = pred
+            if self.args.brats:
+                final_pred = expit(final_pred)
+            else:
+                final_pred = softmax(final_pred, axis=0)
 
             self.save_mask(final_pred)
 
     def build_nnunet(self):
         in_channels, n_class, kernels, strides, self.patch_size = get_unet_params(self.args)
         self.n_class = n_class - 1
+        if self.args.brats:
+            n_class = 3
         self.model = UNet(
             in_channels=in_channels,
             n_class=n_class,
             kernels=kernels,
             strides=strides,
             dimension=self.args.dim,
             residual=self.args.residual,
             normalization_layer=self.args.norm,
             negative_slope=self.args.negative_slope,
+            deep_supervision=self.args.deep_supervision,
+            more_chn=self.args.more_chn,
         )
         if is_main_process():
             print(f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}")
@@ -180,39 +190,30 @@ class NNUnet(pl.LightningModule):
             mode=self.args.blend,
         )
 
-    @staticmethod
-    def metric_mean(name, outputs):
-        return torch.stack([out[name] for out in outputs]).mean(dim=0)
-
     def validation_epoch_end(self, outputs):
         if self.current_epoch < self.args.skip_first_n_eval:
             self.log("dice_sum", 0.001 * self.current_epoch)
             self.dice.reset()
             return None
-        loss = self.metric_mean("val_loss", outputs)
-        dice = self.dice.compute()
-        dice_sum = torch.sum(dice)
-        if dice_sum >= self.best_sum:
-            self.best_sum = dice_sum
-            self.best_sum_dice = dice[:]
-            self.best_sum_epoch = self.current_epoch
+        dice, loss = self.dice.compute()
+        self.dice.reset()
+        dice_mean = torch.mean(dice)
+        if dice_mean >= self.best_mean:
+            self.best_mean = dice_mean
+            self.best_mean_dice = dice[:]
+            self.best_mean_epoch = self.current_epoch
         for i, dice_i in enumerate(dice):
             if dice_i > self.best_dice[i]:
                 self.best_dice[i], self.best_epoch[i] = dice_i, self.current_epoch
 
         if is_main_process():
             metrics = {}
-            metrics.update({"mean dice": round(torch.mean(dice).item(), 2)})
-            metrics.update({"TOP_mean": round(torch.mean(self.best_sum_dice).item(), 2)})
+            metrics.update({"Mean dice": round(torch.mean(dice).item(), 2)})
+            metrics.update({"Highest": round(torch.mean(self.best_mean_dice).item(), 2)})
             if self.n_class > 1:
                 metrics.update({f"L{i+1}": round(m.item(), 2) for i, m in enumerate(dice)})
-                metrics.update({f"TOP_L{i+1}": round(m.item(), 2) for i, m in enumerate(self.best_sum_dice)})
             metrics.update({"val_loss": round(loss.item(), 4)})
             self.dllogger.log(step=self.current_epoch, data=metrics)
             self.dllogger.flush()
 
         self.log("val_loss", loss)
-        self.log("dice_sum", dice_sum)
+        self.log("dice_mean", dice_mean)
 
     def test_epoch_end(self, outputs):
         if self.args.exec_mode == "evaluate":
@@ -222,22 +223,20 @@ class NNUnet(pl.LightningModule):
         optimizer = {
             "sgd": FusedSGD(self.parameters(), lr=self.learning_rate, momentum=self.args.momentum),
             "adam": FusedAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
-            "radam": RAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
         }[self.args.optimizer.lower()]
 
-        scheduler = {
-            "none": None,
-            "multistep": torch.optim.lr_scheduler.MultiStepLR(optimizer, self.args.steps, gamma=self.args.factor),
-            "cosine": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.args.max_epochs),
-            "plateau": torch.optim.lr_scheduler.ReduceLROnPlateau(
-                optimizer, factor=self.args.factor, patience=self.args.lr_patience
-            ),
-        }[self.args.scheduler.lower()]
-
-        opt_dict = {"optimizer": optimizer, "monitor": "val_loss"}
-        if scheduler is not None:
-            opt_dict.update({"lr_scheduler": scheduler})
-        return opt_dict
+        if self.args.scheduler:
+            scheduler = {
+                "scheduler": WarmupCosineSchedule(
+                    optimizer=optimizer,
+                    warmup_steps=250,
+                    t_total=self.args.epochs * len(self.trainer.datamodule.train_dataloader()),
+                ),
+                "interval": "step",
+                "frequency": 1,
+            }
+            return {"optimizer": optimizer, "monitor": "val_loss", "lr_scheduler": scheduler}
+        return {"optimizer": optimizer, "monitor": "val_loss"}
 
     def save_mask(self, pred):
         if self.test_idx == 0:
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import torch
 import torch.nn as nn
 
-from models.layers import ConvBlock, OutputBlock, ResidBlock, UpsampleBlock
+from models.layers import ConvBlock, OutputBlock, UpsampleBlock
 
 
 class UNet(nn.Module):
@@ -25,34 +27,39 @@ class UNet(nn.Module):
         strides,
         normalization_layer,
         negative_slope,
         residual,
         dimension,
+        deep_supervision,
+        more_chn,
     ):
         super(UNet, self).__init__()
+        self.more_chn = more_chn
         self.dim = dimension
         self.n_class = n_class
         self.residual = residual
         self.negative_slope = negative_slope
         self.norm = normalization_layer + f"norm{dimension}d"
-        self.filters = [min(2 ** (5 + i), 320 if dimension == 3 else 512) for i in range(len(strides))]
+        self.deep_supervision = deep_supervision
+        self.depth = len(strides)
+        if self.more_chn:
+            self.filters = [64, 96, 128, 192, 256, 384, 512, 768, 1024][: self.depth]
+        else:
+            self.filters = [min(2 ** (5 + i), 320 if dimension == 3 else 512) for i in range(self.depth)]
 
-        down_block = ResidBlock if self.residual else ConvBlock
         self.input_block = self.get_conv_block(
-            conv_block=down_block,
+            conv_block=ConvBlock,
             in_channels=in_channels,
             out_channels=self.filters[0],
             kernel_size=kernels[0],
             stride=strides[0],
         )
         self.downsamples = self.get_module_list(
-            conv_block=down_block,
+            conv_block=ConvBlock,
             in_channels=self.filters[:-1],
             out_channels=self.filters[1:],
             kernels=kernels[1:-1],
             strides=strides[1:-1],
         )
         self.bottleneck = self.get_conv_block(
-            conv_block=down_block,
+            conv_block=ConvBlock,
             in_channels=self.filters[-2],
             out_channels=self.filters[-1],
             kernel_size=kernels[-1],
@@ -65,6 +72,8 @@ class UNet(nn.Module):
             kernels=kernels[1:][::-1],
             strides=strides[1:][::-1],
         )
+        self.deep_supervision_head1 = self.get_output_block(1)
+        self.deep_supervision_head2 = self.get_output_block(2)
         self.output_block = self.get_output_block(decoder_level=0)
         self.apply(self.initialize_weights)
         self.n_layers = len(self.upsamples) - 1
@@ -76,9 +85,17 @@ class UNet(nn.Module):
             out = downsample(out)
             encoder_outputs.append(out)
         out = self.bottleneck(out)
-        for idx, upsample in enumerate(self.upsamples):
-            out = upsample(out, encoder_outputs[self.n_layers - idx])
+        decoder_outputs = []
+        for i, upsample in enumerate(self.upsamples):
+            out = upsample(out, encoder_outputs[self.depth - i - 2])
+            decoder_outputs.append(out)
         out = self.output_block(out)
+        if self.training and self.deep_supervision:
+            out1 = self.deep_supervision_head1(decoder_outputs[-2])
+            out2 = self.deep_supervision_head2(decoder_outputs[-3])
+            out1 = nn.functional.interpolate(out1, out.shape[2:], mode="trilinear", align_corners=True)
+            out2 = nn.functional.interpolate(out2, out.shape[2:], mode="trilinear", align_corners=True)
+            return torch.stack([out, out1, out2])
         return out
 
     def get_conv_block(self, conv_block, in_channels, out_channels, kernel_size, stride):
746  PyTorch/Segmentation/nnUNet/notebooks/BraTS21.ipynb  (new file)
File diff suppressed because one or more lines are too long
@@ -29,7 +29,8 @@ parser.add_argument(
     choices=["training", "val", "test"],
     help="Mode for data preprocessing",
 )
-parser.add_argument("--dilation", action="store_true", help="Perform morphological label dilation")
+parser.add_argument("--ohe", action="store_true", help="Add one-hot-encoding for foreground voxels (voxels > 0)")
+parser.add_argument("--verbose", action="store_true")
 parser.add_argument("--task", type=str, help="Number of task to be run. MSD uses numbers 01-10")
 parser.add_argument("--dim", type=int, default=3, choices=[2, 3], help="Data dimension to prepare")
 parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs for data preprocessing")
@@ -44,4 +45,4 @@ if __name__ == "__main__":
     if args.exec_mode == "test":
         path = os.path.join(path, "test")
     end = time.time()
-    print(f"Preprocessing time: {(end - start):.2f}")
+    print(f"Pre-processing time: {(end - start):.2f}")
@@ -1,8 +1,7 @@
 git+https://github.com/NVIDIA/dllogger
-nibabel==3.1.1
-joblib==0.16.0
-scikit-learn==0.23.2
-pynvml==8.0.4
-pillow==6.2.0
-fsspec==0.8.0
-pytorch_ranger==0.1.1
+nibabel==3.2.1
+joblib==1.0.1
+pytorch-lightning==1.3.8
+scikit-learn==1.0
+scikit-image==0.18.3
+pynvml==11.0.0
@@ -35,7 +35,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
     cmd = ""
-    cmd += f"python main.py --task {args.task} --benchmark --max_epochs 2 --min_epochs 1 --optimizer adam "
+    cmd += f"python main.py --task {args.task} --benchmark --epochs 2 "
     cmd += f"--results {args.results} "
     cmd += f"--logname {args.logname} "
     cmd += f"--exec_mode {args.mode} "
@@ -24,6 +24,7 @@ parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4],
 parser.add_argument("--dim", type=int, required=True, choices=[2, 3], help="Dimension of UNet")
 parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
 parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
+parser.add_argument("--resume_training", action="store_true", help="Resume training from checkpoint")
 parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
 parser.add_argument("--logname", type=str, default="log", help="Name of dlloger output")
@@ -40,4 +41,5 @@ if __name__ == "__main__":
     cmd += f"--gpus {args.gpus} "
     cmd += "--amp " if args.amp else ""
     cmd += "--tta " if args.tta else ""
+    cmd += "--resume_training " if args.resume_training else ""
     call(cmd, shell=True)
17  PyTorch/Segmentation/nnUNet/utils/scheduler.py  (new file)
@@ -0,0 +1,17 @@
+import math
+
+from torch.optim.lr_scheduler import LambdaLR
+
+
+class WarmupCosineSchedule(LambdaLR):
+    def __init__(self, optimizer, warmup_steps, t_total, cycles=0.5, last_epoch=-1):
+        self.warmup_steps = warmup_steps
+        self.t_total = t_total
+        self.cycles = cycles
+        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+    def lr_lambda(self, step):
+        if step < self.warmup_steps:
+            return float(step) / float(max(1.0, self.warmup_steps))
+        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
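Usage sketch for the new `WarmupCosineSchedule`: the learning rate ramps linearly for `warmup_steps`, then follows a half cosine down to zero over `t_total` steps. The toy optimizer and step counts below are illustrative; `configure_optimizers` in `nn_unet.py` uses `warmup_steps=250` and steps the scheduler once per optimizer step:

```
import torch
from utils.scheduler import WarmupCosineSchedule  # the class added above

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=5, t_total=20)

for step in range(20):
    optimizer.step()
    scheduler.step()  # per-step schedule ("interval": "step")
```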
@@ -101,7 +101,7 @@ def get_unet_params(args):
     strides, kernels, sizes = [], [], patch_size[:]
     while True:
         spacing_ratio = [spacing / min(spacings) for spacing in spacings]
-        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
+        stride = [2 if ratio <= 2 and size >= 2 * args.min_fmap else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
         kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
         if all(s == 1 for s in stride):
             break
@@ -109,7 +109,7 @@ def get_unet_params(args):
         spacings = [i * j for i, j in zip(spacings, stride)]
         kernels.append(kernel)
         strides.append(stride)
-        if len(strides) == 5:
+        if len(strides) == 6:
             break
     strides.insert(0, len(spacings) * [1])
     kernels.append(len(spacings) * [3])
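The loop above derives per-level strides and kernels from the patch size and voxel spacing: an axis is downsampled (stride 2) only while its spacing is within 2x of the finest axis and its feature map stays at least `2 * min_fmap`; the new `len(strides) == 6` cap lets the encoder go one level deeper. A simplified worked example for the BraTS config (patch 128^3, isotropic spacing, `min_fmap=4`), ignoring the anisotropy test since all ratios equal 1:

```
# patch [128, 128, 128], spacing [1.0, 1.0, 1.0], min_fmap = 4
sizes, strides = [128, 128, 128], []
while min(sizes) >= 2 * 4 and len(strides) < 6:
    strides.append([2, 2, 2])
    sizes = [s // 2 for s in sizes]
print(strides, sizes)  # five [2, 2, 2] levels; stops at sizes [4, 4, 4]
```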
@@ -184,13 +184,15 @@ def get_main_args(strings=None):
     arg("--logname", type=str, default=None, help="Name of dlloger output")
     arg("--task", type=str, help="Task number. MSD uses numbers 01-10")
     arg("--gpus", type=non_negative_int, default=1, help="Number of gpus")
-    arg("--learning_rate", type=float, default=0.001, help="Learning rate")
+    arg("--learning_rate", type=float, default=0.0008, help="Learning rate")
     arg("--gradient_clip_val", type=float, default=0, help="Gradient clipping norm value")
     arg("--negative_slope", type=float, default=0.01, help="Negative slope for LeakyReLU")
     arg("--tta", action="store_true", help="Enable test time augmentation")
+    arg("--brats", action="store_true", help="Enable BraTS specific training and inference")
+    arg("--deep_supervision", action="store_true", help="Enable deep supervision")
+    arg("--more_chn", action="store_true", help="Create encoder with more channels")
     arg("--amp", action="store_true", help="Enable automatic mixed precision")
     arg("--benchmark", action="store_true", help="Run model benchmarking")
     arg("--residual", action="store_true", help="Enable residual block in encoder")
     arg("--focal", action="store_true", help="Use focal loss instead of cross entropy")
     arg("--sync_batchnorm", action="store_true", help="Enable synchronized batchnorm")
     arg("--save_ckpt", action="store_true", help="Enable saving checkpoint")
@@ -200,20 +202,16 @@ def get_main_args(strings=None):
     arg("--ckpt_path", type=str, default=None, help="Path to checkpoint")
     arg("--fold", type=non_negative_int, default=0, help="Fold number")
     arg("--patience", type=positive_int, default=100, help="Early stopping patience")
-    arg("--lr_patience", type=positive_int, default=70, help="Patience for ReduceLROnPlateau scheduler")
     arg("--batch_size", type=positive_int, default=2, help="Batch size")
     arg("--val_batch_size", type=positive_int, default=4, help="Validation batch size")
-    arg("--steps", nargs="+", type=positive_int, required=False, help="Steps for multistep scheduler")
     arg("--profile", action="store_true", help="Run dlprof profiling")
     arg("--momentum", type=float, default=0.99, help="Momentum factor")
     arg("--weight_decay", type=float, default=0.0001, help="Weight decay (L2 penalty)")
     arg("--save_preds", action="store_true", help="Enable prediction saving")
     arg("--dim", type=int, choices=[2, 3], default=3, help="UNet dimension")
     arg("--resume_training", action="store_true", help="Resume training from the last checkpoint")
-    arg("--factor", type=float, default=0.3, help="Scheduler factor")
     arg("--num_workers", type=non_negative_int, default=8, help="Number of subprocesses to use for data loading")
-    arg("--min_epochs", type=non_negative_int, default=30, help="Force training for at least these many epochs")
-    arg("--max_epochs", type=non_negative_int, default=10000, help="Stop training after this number of epochs")
+    arg("--epochs", type=non_negative_int, default=1000, help="Number of training epochs")
     arg("--warmup", type=non_negative_int, default=5, help="Warmup iterations before collecting statistics")
     arg("--norm", type=str, choices=["instance", "batch", "group"], default="instance", help="Normalization layer")
     arg("--nvol", type=positive_int, default=1, help="Number of volumes which come into single batch size for 2D model")
@@ -252,18 +250,22 @@ def get_main_args(strings=None):
     )
     arg(
         "--scheduler",
-        type=str,
-        default="none",
-        choices=["none", "multistep", "cosine", "plateau"],
-        help="Learning rate scheduler",
+        action="store_true",
+        help="Enable cosine rate scheduler with warmup",
     )
     arg(
         "--optimizer",
         type=str,
-        default="radam",
-        choices=["sgd", "radam", "adam"],
+        default="adam",
+        choices=["sgd", "adam"],
         help="Optimizer",
     )
+    arg(
+        "--min_fmap",
+        type=non_negative_int,
+        default=4,
+        help="The minimal size that feature map can be reduced in bottleneck",
+    )
     arg(
         "--blend",
         type=str,