DeepLearningExamples/TensorFlow/Segmentation/UNet_3D_Medical/model/model_fn.py
Przemek Strzelczyk 79d4ced0be Adding 3DUnet/TF
2020-07-04 03:28:33 +02:00

73 lines
2.8 KiB
Python

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import horovod.tensorflow as hvd
import tensorflow as tf
from model.unet3d import Builder
from model.losses import make_loss, eval_dice, total_dice
from dataset.data_loader import CLASSES
def unet_3d(features, labels, mode, params):
    """Estimator model_fn for the 3D-UNet segmentation network.

    Args:
        features: Input image batch tensor, fed directly to the UNet Builder.
        labels: One-hot ground-truth segmentation tensor; unused in PREDICT.
        mode: A ``tf.estimator.ModeKeys`` value (PREDICT / EVAL / TRAIN).
        params: Hyper-parameter namespace providing ``normalization``,
            ``include_background``, ``learning_rate`` and ``use_amp``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` configured for the requested mode.
    """
    # NOTE(review): n_classes is hard-coded to 4 (presumably 3 tumor classes
    # + background for BraTS) — consider deriving it from the dataset.
    logits = Builder(n_classes=4, normalization=params.normalization, mode=mode)(features)

    if mode == tf.estimator.ModeKeys.PREDICT:
        prediction = tf.argmax(input=logits, axis=-1, output_type=tf.dtypes.int32)
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={'predictions': tf.cast(prediction, tf.int8)})

    labels = tf.cast(labels, tf.float32)
    if not params.include_background:
        # Drop the background channel so loss/metrics cover only foreground classes.
        labels = labels[..., 1:]
        logits = logits[..., 1:]

    if mode == tf.estimator.ModeKeys.EVAL:
        # NOTE(review): tf.round assumes the network outputs per-class values
        # in [0, 1] (e.g. sigmoid activations) — confirm against the Builder.
        eval_acc = eval_dice(y_true=labels, y_pred=tf.round(logits))
        total_eval_acc = total_dice(tf.round(logits), labels)
        metrics = {CLASSES[i]: tf.metrics.mean(eval_acc[i])
                   for i in range(eval_acc.shape[-1])}
        metrics['WholeTumor'] = tf.metrics.mean(total_eval_acc)
        # The mean per-class dice is reported through the `loss` slot of the
        # EstimatorSpec (higher is better here, despite the name).
        return tf.estimator.EstimatorSpec(mode=mode, loss=tf.reduce_mean(eval_acc),
                                          eval_metric_ops=metrics)

    # TRAIN path.
    loss = make_loss(params, y_pred=logits, y_true=labels)
    loss = tf.identity(loss, name="total_loss_ref")

    global_step = tf.compat.v1.train.get_or_create_global_step()
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=params.learning_rate)
    optimizer = hvd.DistributedOptimizer(optimizer)

    # NGC has TF_ENABLE_AUTO_MIXED_PRECISION enabled by default. We cannot use
    # both graph_rewrite and envar, so if we're not in NGC we do graph_rewrite.
    try:
        amp_envar = int(os.environ['TF_ENABLE_AUTO_MIXED_PRECISION']) == 1
    except (KeyError, ValueError):
        # FIX: also tolerate non-integer values (e.g. "true") — the original
        # caught only KeyError and crashed with ValueError on such input.
        amp_envar = False

    if params.use_amp and not amp_envar:
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimizer,
            loss_scale='dynamic'
        )

    # FIX: use the compat.v1 endpoints, consistent with the rest of this
    # function — `tf.get_collection` / `tf.GraphKeys` do not exist in TF2.
    # Running UPDATE_OPS first keeps batch-norm statistics in sync with training.
    with tf.control_dependencies(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op)