# DeepLearningExamples/TensorFlow/Segmentation/UNet_3D_Medical/dataset/data_loader.py
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import horovod.tensorflow as hvd
import numpy as np
import tensorflow as tf

from dataset.transforms import NormalizeImages, OneHotLabels, apply_transforms, PadXYZ, RandomCrop3D, \
    RandomHorizontalFlip, RandomGammaCorrection, RandomVerticalFlip, RandomBrightnessCorrection, CenterCrop, \
    apply_test_transforms, Cast
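
# Names of the three BraTS tumor sub-regions (the one-hot labels below also
# include a background class, hence n_classes=4).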
CLASSES = {0: "TumorCore", 1: "PeritumoralEdema", 2: "EnhancingTumor"}


def cross_validation(x: np.ndarray, fold_idx: int, n_folds: int):
    """Split `x` into a (train, eval) pair for k-fold cross-validation.

    The array is partitioned into `n_folds` chunks; chunk `fold_idx` becomes
    the evaluation split and the remaining chunks are concatenated for training.
    """
    if fold_idx < 0 or fold_idx >= n_folds:
        raise ValueError('Fold index has to be [0, n_folds). Received index {} for {} folds'.format(fold_idx, n_folds))
    _folders = np.array_split(x, n_folds)
    return np.concatenate(_folders[:fold_idx] + _folders[fold_idx + 1:]), _folders[fold_idx]
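
# A quick illustration of the split (hypothetical file counts): with 10 files,
# n_folds=5 and fold_idx=0, np.array_split produces five 2-file chunks, so the
# call returns 8 training paths and the first 2 paths as the evaluation split.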


class Dataset:
    """Input pipeline over BraTS TFRecord volumes with k-fold cross-validation."""

    def __init__(self, data_dir, batch_size=2, fold_idx=0, n_folds=5, seed=0, pipeline_factor=1, params=None):
        self._folders = np.array([os.path.join(data_dir, path) for path in os.listdir(data_dir)])
        self._train, self._eval = cross_validation(self._folders, fold_idx=fold_idx, n_folds=n_folds)
        self._pipeline_factor = pipeline_factor
        self._data_dir = data_dir
        self.params = params
        self._batch_size = batch_size
        self._seed = seed
        # Full BraTS volume shapes: four MRI modalities stacked channel-last.
        self._xshape = (240, 240, 155, 4)
        self._yshape = (240, 240, 155)

    def parse(self, serialized):
        """Parse one serialized example into (image, label, mean, stdev) tensors."""
        features = {
            'X': tf.io.FixedLenFeature([], tf.string),
            'Y': tf.io.FixedLenFeature([], tf.string),
            'mean': tf.io.FixedLenFeature([4], tf.float32),
            'stdev': tf.io.FixedLenFeature([4], tf.float32)
        }
        parsed_example = tf.io.parse_single_example(serialized=serialized,
                                                    features=features)
        # decode_raw already yields tf.uint8 tensors, so only a reshape is needed.
        x = tf.reshape(tf.io.decode_raw(parsed_example['X'], tf.uint8), self._xshape)
        y = tf.reshape(tf.io.decode_raw(parsed_example['Y'], tf.uint8), self._yshape)
        mean = parsed_example['mean']
        stdev = parsed_example['stdev']
        return x, y, mean, stdev

    def parse_x(self, serialized):
        """Parse one serialized example, returning only the image and its statistics."""
        features = {'X': tf.io.FixedLenFeature([], tf.string),
                    'Y': tf.io.FixedLenFeature([], tf.string),
                    'mean': tf.io.FixedLenFeature([4], tf.float32),
                    'stdev': tf.io.FixedLenFeature([4], tf.float32)}
        parsed_example = tf.io.parse_single_example(serialized=serialized,
                                                    features=features)
        x = tf.reshape(tf.io.decode_raw(parsed_example['X'], tf.uint8), self._xshape)
        mean = parsed_example['mean']
        stdev = parsed_example['stdev']
        return x, mean, stdev

    def train_fn(self):
        """Build the training pipeline: shard, cache, shuffle, augment, and
        batch random 128^3 crops indefinitely."""
        assert len(self._train) > 0, "Training data not found."
        ds = tf.data.TFRecordDataset(filenames=self._train)
        # Each Horovod worker reads a disjoint shard of the training files.
        ds = ds.shard(hvd.size(), hvd.rank())
        ds = ds.cache()
        ds = ds.shuffle(buffer_size=self._batch_size * 8, seed=self._seed)
        ds = ds.repeat()
        ds = ds.map(self.parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        transforms = [
            RandomCrop3D((128, 128, 128)),
            RandomHorizontalFlip() if self.params.augment else None,
            Cast(dtype=tf.float32),
            NormalizeImages(),
            RandomBrightnessCorrection() if self.params.augment else None,
            OneHotLabels(n_classes=4),
        ]
        ds = ds.map(map_func=lambda x, y, mean, stdev: apply_transforms(x, y, mean, stdev, transforms=transforms),
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.batch(batch_size=self._batch_size,
                      drop_remainder=True)
        ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return ds

    def eval_fn(self):
        """Build the evaluation pipeline over the held-out fold (deterministic, no augmentation)."""
        assert len(self._eval) > 0, "Evaluation data not found. Did you specify --fold flag?"
        ds = tf.data.TFRecordDataset(filenames=self._eval)
        ds = ds.cache()
        ds = ds.map(self.parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        transforms = [
            CenterCrop((224, 224, 155)),
            Cast(dtype=tf.float32),
            NormalizeImages(),
            OneHotLabels(n_classes=4),
            PadXYZ()
        ]
        ds = ds.map(map_func=lambda x, y, mean, stdev: apply_transforms(x, y, mean, stdev, transforms=transforms),
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.batch(batch_size=self._batch_size,
                      drop_remainder=False)
        ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return ds

    def test_fn(self, count=1, drop_remainder=False):
        """Build the inference pipeline (images only) over the held-out fold."""
        assert len(self._eval) > 0, "Evaluation data not found. Did you specify --fold flag?"
        ds = tf.data.TFRecordDataset(filenames=self._eval)
        ds = ds.repeat(count)
        ds = ds.map(self.parse_x, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        transforms = [
            CenterCrop((224, 224, 155)),
            Cast(dtype=tf.float32),
            NormalizeImages(),
            # Pad the 155-voxel depth axis up to 160.
            PadXYZ((224, 224, 160))
        ]
        ds = ds.map(map_func=lambda x, mean, stdev: apply_test_transforms(x, mean, stdev, transforms=transforms),
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.batch(batch_size=self._batch_size,
                      drop_remainder=drop_remainder)
        ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return ds

    def synth_train_fn(self):
        """Synthetic data function for testing."""
        inputs = tf.random.uniform(self._xshape, dtype=tf.int32, minval=0, maxval=255, seed=self._seed,
                                   name='synth_inputs')
        masks = tf.random.uniform(self._yshape, dtype=tf.int32, minval=0, maxval=4, seed=self._seed,
                                  name='synth_masks')
        ds = tf.data.Dataset.from_tensors((inputs, masks))
        ds = ds.repeat()
        transforms = [
            Cast(dtype=tf.uint8),
            RandomCrop3D((128, 128, 128)),
            RandomHorizontalFlip() if self.params.augment else None,
            Cast(dtype=tf.float32),
            NormalizeImages(),
            RandomBrightnessCorrection() if self.params.augment else None,
            OneHotLabels(n_classes=4),
        ]
        # apply_transforms takes (x, y, mean, stdev, transforms), as in train_fn;
        # pass neutral per-channel statistics (mean 0, stdev 1) for synthetic data.
        ds = ds.map(map_func=lambda x, y: apply_transforms(x, y, tf.zeros(4), tf.ones(4), transforms=transforms),
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.batch(self._batch_size)
        ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return ds

    def synth_predict_fn(self, count=1):
        """Synthetic data function for testing."""
        inputs = tf.random.truncated_normal((64, 64, 64, 4), dtype=tf.float32, mean=0.0, stddev=1.0, seed=self._seed,
                                            name='synth_inputs')
        ds = tf.data.Dataset.from_tensors(inputs)
        ds = ds.repeat(count)
        ds = ds.batch(self._batch_size)
        ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return ds

    @property
    def train_size(self):
        return len(self._train)

    @property
    def eval_size(self):
        return len(self._eval)

    @property
    def test_size(self):
        # Inference runs over the same held-out fold as evaluation.
        return len(self._eval)
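
# Minimal usage sketch (the data path is illustrative; `params` must expose an
# `augment` flag, which train_fn and synth_train_fn read):
#   hvd.init()
#   dataset = Dataset(data_dir='/data/BraTS19_tfrecord', batch_size=2, fold_idx=0, n_folds=5, params=params)
#   train_ds = dataset.train_fn()  # batches of 128^3 crops with one-hot labels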


def main():
    """Smoke test: iterate over the test pipeline and dump one slice to disk."""
    from time import time
    import matplotlib.pyplot as plt
    import numpy.ma as ma

    hvd.init()
    dataset = Dataset(data_dir='/data/BraTS19_tfrecord', batch_size=3)
    it = dataset.test_fn().make_initializable_iterator()
    sess = tf.Session()
    sess.run(it.initializer)
    next_element = it.get_next()
    for _ in range(200):
        t0 = time()
        res = sess.run(next_element)
        a = res[0]
        # Show one axial slice (z=80) of the first channel of the first volume,
        # masking out the zero-valued background.
        a = a[0, :, :, 80, 0]
        a = ma.masked_array(a, mask=a == 0)
        plt.clf()
        plt.imshow(a)
        plt.colorbar()
        plt.savefig("/opt/project/img.png")
        print(time() - t0)


if __name__ == '__main__':
    main()