# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf


def _normalization(inputs, name, mode):
    """Apply the normalization layer selected by ``name`` to ``inputs``.

    Args:
        inputs: Input tensor in channels-last layout (the callers use 3-D
            convolutions, so presumably rank-5 NDHWC — TODO confirm; note
            ``instance_norm`` below is passed ``data_format='NHWC'``).
        name: One of ``'instancenorm'``, ``'groupnorm'``, ``'batchnorm'``
            or ``'none'``.
        mode: A ``tf.estimator.ModeKeys`` value. Only the batchnorm branch
            distinguishes training from inference.

    Returns:
        The normalized tensor, or ``inputs`` unchanged for ``'none'``.

    Raises:
        ValueError: If ``name`` is not a recognized normalization type.
    """
    if name == 'instancenorm':
        gamma_initializer = tf.constant_initializer(1.0)
        return tf.contrib.layers.instance_norm(
            inputs,
            center=True,
            scale=True,
            epsilon=1e-6,
            param_initializers={'gamma': gamma_initializer},
            reuse=None,
            variables_collections=None,
            outputs_collections=None,
            trainable=True,
            data_format='NHWC',
            scope=None)
    elif name == 'groupnorm':
        return tf.contrib.layers.group_norm(inputs=inputs,
                                            groups=16,
                                            channels_axis=-1,
                                            reduction_axes=(-4, -3, -2),
                                            activation_fn=None,
                                            trainable=True)
    elif name == 'batchnorm':
        # Batch norm is the only variant whose behavior depends on the mode
        # (moving statistics at inference), so compute the flag only here.
        training = mode == tf.estimator.ModeKeys.TRAIN
        return tf.keras.layers.BatchNormalization(axis=-1,
                                                  trainable=True,
                                                  virtual_batch_size=None)(inputs, training=training)
    elif name == 'none':
        return inputs
    else:
        raise ValueError('Invalid normalization layer')
def _activation(x, activation):
    """Apply the activation selected by ``activation`` to tensor ``x``.

    Supported names: 'relu', 'leaky_relu' (alpha=0.01), 'sigmoid',
    'softmax' (over the last axis), and 'none' (identity).

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    dispatch = {
        'relu': lambda t: tf.nn.relu(t),
        'leaky_relu': lambda t: tf.nn.leaky_relu(t, alpha=0.01),
        'sigmoid': lambda t: tf.nn.sigmoid(t),
        'softmax': lambda t: tf.nn.softmax(t, axis=-1),
        'none': lambda t: t,
    }
    fn = dispatch.get(activation)
    if fn is None:
        raise ValueError("Unknown activation {}".format(activation))
    return fn(x)
def convolution(x,
                out_channels,
                kernel_size=3,
                stride=1,
                mode=tf.estimator.ModeKeys.TRAIN,
                normalization='batchnorm',
                activation='leaky_relu',
                transpose=False):
    """Conv3D (or Conv3DTranspose) followed by normalization and activation.

    Args:
        x: Input tensor, channels-last.
        out_channels: Number of output filters.
        kernel_size: Spatial kernel size (default 3).
        stride: Convolution stride (default 1).
        mode: tf.estimator.ModeKeys value forwarded to the normalization.
        normalization: Name understood by ``_normalization``.
        activation: Name understood by ``_activation``.
        transpose: If True, use a transposed (up-sampling) convolution.

    Returns:
        The activated, normalized convolution output.
    """
    conv_layer = tf.keras.layers.Conv3DTranspose if transpose else tf.keras.layers.Conv3D
    regularizer = None  # tf.keras.regularizers.l2(1e-5)

    out = conv_layer(filters=out_channels,
                     kernel_size=kernel_size,
                     strides=stride,
                     activation=None,
                     padding='same',
                     data_format='channels_last',
                     kernel_initializer=tf.glorot_uniform_initializer(),
                     kernel_regularizer=regularizer,
                     bias_initializer=tf.zeros_initializer(),
                     bias_regularizer=regularizer)(x)

    out = _normalization(out, normalization, mode)
    return _activation(out, activation)
def upsample_block(x, skip_connection, out_channels, normalization, mode):
    """Decoder block: 2x transpose-conv upsample, concat skip, two convs.

    Args:
        x: Decoder feature tensor to upsample.
        skip_connection: Encoder tensor concatenated along the channel axis;
            its spatial dims must match the upsampled ``x``.
        out_channels: Number of filters for every convolution in the block.
        normalization: Normalization name forwarded to ``convolution``.
        mode: tf.estimator.ModeKeys value forwarded to ``convolution``.

    Returns:
        The processed tensor with ``out_channels`` channels.
    """
    # Forward ``mode`` here too: the original relied on the TRAIN default,
    # which is harmless only while normalization stays 'none'.
    x = convolution(x, kernel_size=2, out_channels=out_channels, stride=2,
                    normalization='none', activation='none', transpose=True,
                    mode=mode)
    x = tf.keras.layers.Concatenate(axis=-1)([x, skip_connection])

    x = convolution(x, out_channels=out_channels, normalization=normalization, mode=mode)
    x = convolution(x, out_channels=out_channels, normalization=normalization, mode=mode)
    return x
def input_block(x, out_channels, normalization, mode):
    """Stem of the network: two stride-1 convolutions on the raw input."""
    for _ in range(2):
        x = convolution(x, out_channels=out_channels,
                        normalization=normalization, mode=mode)
    return x
def downsample_block(x, out_channels, normalization, mode):
    """Encoder block: stride-2 conv halves spatial dims, then a stride-1 conv."""
    downsampled = convolution(x, out_channels=out_channels,
                              normalization=normalization, mode=mode, stride=2)
    return convolution(downsampled, out_channels=out_channels,
                       normalization=normalization, mode=mode)
def linear_block(x, out_channels, mode, activation='leaky_relu', normalization='none'):
    """Two convolutions; only the second applies the requested ``activation``.

    The first convolution keeps ``convolution``'s default leaky_relu
    activation; ``activation`` controls the final nonlinearity only.
    """
    hidden = convolution(x, out_channels=out_channels,
                         normalization=normalization, mode=mode)
    return convolution(hidden, out_channels=out_channels, activation=activation,
                       mode=mode, normalization=normalization)
def output_layer(x, out_channels, activation):
    """Final 3x3x3 projection to ``out_channels`` followed by ``activation``.

    No normalization is applied here; ``activation`` is any name understood
    by ``_activation`` (e.g. 'softmax' or 'sigmoid' for class probabilities).
    """
    logits = tf.keras.layers.Conv3D(
        out_channels,
        kernel_size=3,
        activation=None,
        padding='same',
        kernel_regularizer=None,
        kernel_initializer=tf.glorot_uniform_initializer(),
        bias_initializer=tf.zeros_initializer(),
        bias_regularizer=None)(x)
    return _activation(logits, activation)