# DeepLearningExamples/TensorFlow/Segmentation/VNet/utils/data_loader.py
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
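
"""Data loading pipeline for the VNet example.

Reads a Medical Segmentation Decathlon (MSD) style dataset.json, loads the
referenced NIfTI volumes with SimpleITK, resamples them to a fixed spatial
size, and exposes Horovod-aware tf.data input pipelines for training,
evaluation, and test.
"""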

import json
import math
import multiprocessing
import os

import SimpleITK as sitk
import horovod.tensorflow as hvd
import numpy as np
import tensorflow as tf
from scipy import stats


def parse_nifti(path, dtype, dst_size, interpolator, normalization=None, modality=None):
    # Note: `dtype` is accepted for interface symmetry but not used here;
    # volumes are cast to float32 when loaded (see load_image).
    sitk_image = load_image(path)
    sitk_image = resize_image(sitk_image,
                              dst_size=dst_size,
                              interpolator=interpolator)

    image = sitk_to_np(sitk_image)

    # Z-score normalization is applied to non-CT modalities only (e.g. MRI);
    # intensity handling for CT is not implemented.
    if modality and 'CT' not in modality:
        if normalization:
            image = stats.zscore(image, axis=None)
    elif modality:
        raise NotImplementedError

    return image


def make_ref_image(img_path, dst_size, interpolator):
    ref_image = load_image(img_path)
    ref_image = resize_image(ref_image, dst_size=dst_size,
                             interpolator=interpolator)
    # Rescale intensities to [0, 255]; the max must be taken over the numpy
    # array, not the SimpleITK image object.
    ref_array = sitk_to_np(ref_image)
    return ref_array / np.max(ref_array) * 255


def make_interpolator(interpolator):
    # Only linear interpolation is supported for images; labels always use
    # nearest neighbor (see MSDDataset._parse_label).
    if interpolator == 'linear':
        return sitk.sitkLinear
    raise ValueError("Unknown interpolator type: {}".format(interpolator))


def load_image(img_path):
    image = sitk.ReadImage(img_path)

    # 4D NIfTI volumes (e.g. multi-channel MRI) are reduced to their last 3D volume.
    if image.GetDimension() == 4:
        image = sitk.GetImageFromArray(sitk.GetArrayFromImage(image)[-1, :, :, :])

    if image.GetPixelID() != sitk.sitkFloat32:
        return sitk.Cast(image, sitk.sitkFloat32)

    return image


def sitk_to_np(sitk_img):
    # SimpleITK arrays come back indexed (z, y, x); transpose to (x, y, z).
    return np.transpose(sitk.GetArrayFromImage(sitk_img), [2, 1, 0])


def resize_image(sitk_img,
                 dst_size=(128, 128, 64),
                 interpolator=sitk.sitkNearestNeighbor):
    reference_image = sitk.Image(dst_size, sitk_img.GetPixelIDValue())
    reference_image.SetOrigin(sitk_img.GetOrigin())
    reference_image.SetDirection(sitk_img.GetDirection())
    # Rescale the spacing so the resampled image covers the same physical extent.
    reference_image.SetSpacing(
        [sz * spc / nsz for nsz, sz, spc in zip(dst_size, sitk_img.GetSize(), sitk_img.GetSpacing())])

    return sitk.Resample(sitk_img, reference_image, sitk.Transform(3, sitk.sitkIdentity), interpolator)
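
# For example (illustrative numbers): resampling a 512x512x100 volume with
# 0.7 x 0.7 x 3.0 mm spacing to the default (128, 128, 64) grid yields a
# spacing of (2.8, 2.8, 4.6875) mm, preserving the physical field of view.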


class MSDJsonParser:
    """Parses a dataset.json file in the Medical Segmentation Decathlon (MSD) format."""

    def __init__(self, json_path):
        with open(json_path) as f:
            data = json.load(f)

        # All paths in dataset.json are relative to the json file's directory.
        self._labels = data.get('labels')
        self._x_train = [os.path.join(os.path.dirname(json_path), p['image']) for p in data.get('training')]
        self._y_train = [os.path.join(os.path.dirname(json_path), p['label']) for p in data.get('training')]
        self._x_test = [os.path.join(os.path.dirname(json_path), p) for p in data.get('test')]
        self._modality = list(data.get('modality').values())

    @property
    def labels(self):
        return self._labels

    @property
    def x_train(self):
        return self._x_train

    @property
    def y_train(self):
        return self._y_train

    @property
    def x_test(self):
        return self._x_test

    @property
    def modality(self):
        return self._modality
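
# An abbreviated sketch of the dataset.json layout consumed above (file names
# follow MSD conventions and are illustrative; label/modality values vary per task):
#
#   {
#     "labels":   {"0": "background", "1": "..."},
#     "modality": {"0": "MRI"},
#     "training": [{"image": "./imagesTr/x.nii.gz", "label": "./labelsTr/x.nii.gz"}, ...],
#     "test":     ["./imagesTs/y.nii.gz", ...]
#   }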


def make_split(json_parser, train_split, split_seed=0):
    np.random.seed(split_seed)

    train_size = int(len(json_parser.x_train) * train_split)

    # Head/tail split: the first train_size samples train, the rest evaluate.
    return np.array(json_parser.x_train)[:train_size], \
           np.array(json_parser.y_train)[:train_size], \
           np.array(json_parser.x_train)[train_size:], \
           np.array(json_parser.y_train)[train_size:]


class MSDDataset(object):
    def __init__(self, json_path,
                 dst_size=(128, 128, 64),
                 seed=None,
                 interpolator=None,
                 data_normalization=None,
                 batch_size=1,
                 train_split=1.0,
                 split_seed=0):
        self._json_parser = MSDJsonParser(json_path)
        self._interpolator = make_interpolator(interpolator)

        self._ref_image = make_ref_image(img_path=self._json_parser.x_test[0],
                                         dst_size=dst_size,
                                         interpolator=self._interpolator)

        np.random.seed(split_seed)

        self._train_img, self._train_label, \
            self._eval_img, self._eval_label = make_split(self._json_parser, train_split)

        self._test_img = np.array(self._json_parser.x_test)

        self._dst_size = dst_size
        self._seed = seed
        self._batch_size = batch_size
        self._train_split = train_split
        self._data_normalization = data_normalization

        np.random.seed(self._seed)

    @property
    def labels(self):
        return self._json_parser.labels

    @property
    def train_steps(self):
        # Each Horovod worker trains on 1/hvd.size() of the data, so steps are
        # counted against the global batch size: e.g. 64 volumes, batch_size=2,
        # and 4 workers give ceil(64 / 8) = 8 steps per epoch.
        global_batch_size = hvd.size() * self._batch_size
        return math.ceil(len(self._train_img) / global_batch_size)

    @property
    def eval_steps(self):
        return math.ceil(len(self._eval_img) / self._batch_size)

    @property
    def test_steps(self):
        return math.ceil(len(self._test_img) / self._batch_size)

    def _parse_image(self, img):
        return parse_nifti(path=img,
                           dst_size=self._dst_size,
                           dtype=tf.float32,
                           interpolator=self._interpolator,
                           normalization=self._data_normalization,
                           modality=self._json_parser.modality)

    def _parse_label(self, label):
        # Labels are resampled with nearest neighbor to keep class ids intact.
        return parse_nifti(path=label,
                           dst_size=self._dst_size,
                           dtype=tf.int32,
                           interpolator=sitk.sitkNearestNeighbor)

    def _augment(self, x, y):
        # Each flip is sampled once and applied to image and label alike, so
        # the segmentation stays aligned with the volume. The flips act on the
        # first two axes of the (x, y, z) tensor; z is treated as channels.

        # Horizontal flip
        h_flip = tf.random_uniform([]) > 0.5
        x = tf.cond(h_flip, lambda: tf.image.flip_left_right(x), lambda: x)
        y = tf.cond(h_flip, lambda: tf.image.flip_left_right(y), lambda: y)

        # Vertical flip
        v_flip = tf.random_uniform([]) > 0.5
        x = tf.cond(v_flip, lambda: tf.image.flip_up_down(x), lambda: x)
        y = tf.cond(v_flip, lambda: tf.image.flip_up_down(y), lambda: y)

        return x, y

    def _img_generator(self, collection):
        for element in collection:
            yield self._parse_image(element)

    def _label_generator(self, collection):
        for element in collection:
            yield self._parse_label(element)

    def train_fn(self, augment):
        # The generators yield volumes resampled to dst_size, so the declared
        # static shapes must follow the configured size.
        images = tf.data.Dataset.from_generator(generator=lambda: self._img_generator(self._train_img),
                                                output_types=tf.float32,
                                                output_shapes=tuple(self._dst_size))
        labels = tf.data.Dataset.from_generator(generator=lambda: self._label_generator(self._train_label),
                                                output_types=tf.int32,
                                                output_shapes=tuple(self._dst_size))

        dataset = tf.data.Dataset.zip((images, labels))
        dataset = dataset.cache()
        dataset = dataset.repeat()
        dataset = dataset.shuffle(buffer_size=self._batch_size * 2,
                                  reshuffle_each_iteration=True,
                                  seed=self._seed)
        # Shard across Horovod workers so each rank sees a distinct subset.
        dataset = dataset.shard(hvd.size(), hvd.rank())

        if augment:
            # map_and_batch fuses augmentation and batching into one stage.
            dataset = dataset.apply(
                tf.data.experimental.map_and_batch(map_func=self._augment,
                                                   batch_size=self._batch_size,
                                                   drop_remainder=True,
                                                   num_parallel_calls=multiprocessing.cpu_count()))
        else:
            dataset = dataset.batch(self._batch_size)

        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

        return dataset

    def eval_fn(self):
        images = tf.data.Dataset.from_generator(generator=lambda: self._img_generator(self._eval_img),
                                                output_types=tf.float32,
                                                output_shapes=tuple(self._dst_size))
        labels = tf.data.Dataset.from_generator(generator=lambda: self._label_generator(self._eval_label),
                                                output_types=tf.int32,
                                                output_shapes=tuple(self._dst_size))

        dataset = tf.data.Dataset.zip((images, labels))
        dataset = dataset.cache()
        dataset = dataset.batch(self._batch_size, drop_remainder=True)
        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

        return dataset

    def test_fn(self, count=1):
        dataset = tf.data.Dataset.from_generator(generator=lambda: self._img_generator(self._test_img),
                                                 output_types=tf.float32,
                                                 output_shapes=tuple(self._dst_size))
        dataset = dataset.cache()
        dataset = dataset.repeat(count=count)
        dataset = dataset.batch(self._batch_size, drop_remainder=True)
        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

        return dataset
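

# Example usage (a minimal sketch; the dataset path below is illustrative, and
# hvd.init() is assumed to be called by the training script):
#
#   hvd.init()
#   dataset = MSDDataset(json_path='/data/Task04_Hippocampus/dataset.json',
#                        dst_size=(32, 32, 32),
#                        interpolator='linear',
#                        data_normalization='zscore',
#                        batch_size=2,
#                        train_split=0.9)
#   train_data = dataset.train_fn(augment=True)   # tf.data.Dataset of (image, label) batches
#   eval_data = dataset.eval_fn()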