DeepLearningExamples/TensorFlow/Segmentation/UNet_3D_Medical/model/layers.py
Przemek Strzelczyk 79d4ced0be Adding 3DUnet/TF
2020-07-04 03:28:33 +02:00

136 lines
5.1 KiB
Python

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
def _normalization(inputs, name, mode):
training = mode == tf.estimator.ModeKeys.TRAIN
if name == 'instancenorm':
gamma_initializer = tf.constant_initializer(1.0)
return tf.contrib.layers.instance_norm(
inputs,
center=True,
scale=True,
epsilon=1e-6,
param_initializers={'gamma': gamma_initializer},
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
data_format='NHWC',
scope=None)
if name == 'groupnorm':
return tf.contrib.layers.group_norm(inputs=inputs,
groups=16,
channels_axis=-1,
reduction_axes=(-4, -3, -2),
activation_fn=None,
trainable=True)
if name == 'batchnorm':
return tf.keras.layers.BatchNormalization(axis=-1,
trainable=True,
virtual_batch_size=None)(inputs, training=training)
elif name == 'none':
return inputs
else:
raise ValueError('Invalid normalization layer')
def _activation(x, activation):
if activation == 'relu':
return tf.nn.relu(x)
elif activation == 'leaky_relu':
return tf.nn.leaky_relu(x, alpha=0.01)
elif activation == 'sigmoid':
return tf.nn.sigmoid(x)
elif activation == 'softmax':
return tf.nn.softmax(x, axis=-1)
elif activation == 'none':
return x
else:
raise ValueError("Unknown activation {}".format(activation))
def convolution(x,
out_channels,
kernel_size=3,
stride=1,
mode=tf.estimator.ModeKeys.TRAIN,
normalization='batchnorm',
activation='leaky_relu',
transpose=False):
if transpose:
conv = tf.keras.layers.Conv3DTranspose
else:
conv = tf.keras.layers.Conv3D
regularizer = None#tf.keras.regularizers.l2(1e-5)
x = conv(filters=out_channels,
kernel_size=kernel_size,
strides=stride,
activation=None,
padding='same',
data_format='channels_last',
kernel_initializer=tf.glorot_uniform_initializer(),
kernel_regularizer=regularizer,
bias_initializer=tf.zeros_initializer(),
bias_regularizer=regularizer)(x)
x = _normalization(x, normalization, mode)
return _activation(x, activation)
def upsample_block(x, skip_connection, out_channels, normalization, mode):
x = convolution(x, kernel_size=2, out_channels=out_channels, stride=2,
normalization='none', activation='none', transpose=True)
x = tf.keras.layers.Concatenate(axis=-1)([x, skip_connection])
x = convolution(x, out_channels=out_channels, normalization=normalization, mode=mode)
x = convolution(x, out_channels=out_channels, normalization=normalization, mode=mode)
return x
def input_block(x, out_channels, normalization, mode):
x = convolution(x, out_channels=out_channels, normalization=normalization, mode=mode)
x = convolution(x, out_channels=out_channels, normalization=normalization, mode=mode)
return x
def downsample_block(x, out_channels, normalization, mode):
x = convolution(x, out_channels=out_channels, normalization=normalization, mode=mode, stride=2)
return convolution(x, out_channels=out_channels, normalization=normalization, mode=mode)
def linear_block(x, out_channels, mode, activation='leaky_relu', normalization='none'):
x = convolution(x, out_channels=out_channels, normalization=normalization, mode=mode)
return convolution(x, out_channels=out_channels, activation=activation, mode=mode, normalization=normalization)
def output_layer(x, out_channels, activation):
x = tf.keras.layers.Conv3D(out_channels,
kernel_size=3,
activation=None,
padding='same',
kernel_regularizer=None,
kernel_initializer=tf.glorot_uniform_initializer(),
bias_initializer=tf.zeros_initializer(),
bias_regularizer=None)(x)
return _activation(x, activation)