DeepLearningExamples/TensorFlow/Segmentation/UNet_3D_Medical/model/losses.py
Przemek Strzelczyk 79d4ced0be Adding 3DUnet/TF
2020-07-04 03:28:33 +02:00

84 lines
2.7 KiB
Python

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
def make_loss(params, y_true, y_pred):
    """Build the training loss selected by ``params.loss``.

    Args:
        params: configuration object; only its ``loss`` attribute is read.
            Recognised values are ``'dice'``, ``'ce'`` and ``'dice+ce'``.
        y_true: ground-truth tensor, forwarded unchanged to the loss helpers.
        y_pred: prediction tensor, forwarded unchanged to the loss helpers.

    Returns:
        A scalar loss tensor. For ``'dice+ce'`` it is the sum of the
        cross-entropy and Dice terms, named ``total_loss_ref``.

    Raises:
        ValueError: if ``params.loss`` is not one of the recognised names.
    """
    loss_name = params.loss
    if loss_name == 'dice':
        return _dice(y_true, y_pred)
    elif loss_name == 'ce':
        return _ce(y_true, y_pred)
    elif loss_name == 'dice+ce':
        # Cross-entropy is built before Dice, matching tf.add's argument order.
        ce_term = _ce(y_true, y_pred)
        dice_term = _dice(y_true, y_pred)
        return tf.add(ce_term, dice_term, name="total_loss_ref")
    raise ValueError('Unknown loss: {}'.format(loss_name))
def _ce(y_true, y_pred):
    """Binary cross-entropy loss averaged over axes 0-3, then summed.

    ``y_true`` is cast to float32. ``y_pred`` is passed as the ``output``
    argument of ``tf.keras.backend.binary_crossentropy``. That function's
    default ``from_logits=False`` means it is treated as probabilities.

    The element-wise loss is averaged over axes 0, 1, 2 and 3, and the
    remaining values are summed into a scalar. The inputs are presumably
    5-D channels-last tensors (batch, depth, height, width, classes), so the
    sum runs over classes. TODO confirm against the model.

    Returns:
        Scalar tensor named ``crossentropy_loss_ref``.
    """
    targets = tf.cast(y_true, tf.float32)
    elementwise_ce = tf.keras.backend.binary_crossentropy(targets, y_pred)
    mean_ce = tf.reduce_mean(elementwise_ce, axis=[0, 1, 2, 3])
    return tf.reduce_sum(mean_ce, name='crossentropy_loss_ref')
def _dice(y_true, y_pred):
    """Dice training loss: the output of ``dice_loss`` summed into a scalar.

    Returns:
        Scalar tensor named ``dice_loss_ref``.
    """
    dice_terms = dice_loss(predictions=y_pred, targets=y_true)
    return tf.reduce_sum(dice_terms, name='dice_loss_ref')
def eval_dice(y_true, y_pred):
    """Dice coefficient for evaluation, computed as ``1 - dice_loss(...)``.

    No further reduction is applied, so the result has the same shape as
    the value returned by ``dice_loss``.
    """
    loss_terms = dice_loss(predictions=y_pred, targets=y_true)
    return 1 - loss_terms
def dice_loss(predictions,
              targets,
              squared_pred=False,
              smooth=1e-5,
              top_smooth=0.0,
              is_channels_first=False):
    """Soft Dice loss per channel, averaged over the batch axis.

    For each sample and channel the score is
    ``(2 * sum(t * p) + top_smooth) / (sum(t) + sum(p) + smooth)``.
    The sums run over every axis except the batch axis (0) and the channel
    axis. The scores are averaged over axis 0 and subtracted from 1.

    Args:
        predictions: prediction tensor. Its static rank is read via
            ``get_shape()``, so the rank must be known.
        targets: ground-truth tensor, multiplied element-wise with
            ``predictions``. It is not cast here, so presumably it is already
            a float tensor compatible with ``predictions``. TODO confirm.
        squared_pred: if True, the denominator uses sums of squared targets
            and predictions. The intersection is computed before squaring
            and is unaffected.
        smooth: added to the denominator to avoid division by zero.
        top_smooth: added to the numerator. With the default 0.0, a channel
            that is empty in both targets and predictions scores 0 (loss 1).
        is_channels_first: if True, axis 1 is the channel axis and axes
            ``2 .. rank-1`` are reduced. Otherwise the last axis is the
            channel axis and axes ``1 .. rank-2`` are reduced. Defaults to
            False, the previously hard-coded behaviour.

    Returns:
        Tensor ``1 - mean_over_axis0(score)``, one value per channel.
    """
    n_len = len(predictions.get_shape())
    # Reduce over the spatial axes only; keep the batch and channel axes.
    if is_channels_first:
        reduce_axis = list(range(2, n_len))
    else:
        reduce_axis = list(range(1, n_len - 1))
    intersection = tf.reduce_sum(targets * predictions, axis=reduce_axis)
    if squared_pred:
        # Squaring happens after the intersection, so only the denominator changes.
        targets = tf.square(targets)
        predictions = tf.square(predictions)
    y_true_o = tf.reduce_sum(targets, axis=reduce_axis)
    y_pred_o = tf.reduce_sum(predictions, axis=reduce_axis)
    denominator = y_true_o + y_pred_o
    f = (2.0 * intersection + top_smooth) / (denominator + smooth)
    return 1 - tf.reduce_mean(f, axis=0)
def total_dice(predictions,
               targets,
               smooth=1e-5,
               top_smooth=0.0):
    """Dice score of the channel-collapsed maps, averaged over the batch.

    Both tensors are first summed over their last axis, collapsing the
    channels into a single map. For mutually exclusive one-hot channels,
    that map is their union. Dice is then computed per sample over axes
    ``1 .. rank-2`` of the collapsed map, where ``rank`` is the rank of
    the original ``predictions``. The per-sample scores are averaged into
    a scalar.

    Note: this returns the Dice score itself, not ``1 - score``.

    Args:
        predictions: prediction tensor with a statically known rank.
        targets: ground-truth tensor compatible with ``predictions``.
        smooth: added to the denominator to avoid division by zero.
        top_smooth: added to the numerator.

    Returns:
        Scalar tensor with the batch-averaged Dice score.
    """
    rank = len(predictions.get_shape())
    # The collapsed maps have rank ``rank - 1``; skip axis 0 (batch).
    spatial_axes = list(range(1, rank - 1))
    merged_targets = tf.reduce_sum(targets, axis=-1)
    merged_predictions = tf.reduce_sum(predictions, axis=-1)
    overlap = tf.reduce_sum(merged_targets * merged_predictions, axis=spatial_axes)
    target_volume = tf.reduce_sum(merged_targets, axis=spatial_axes)
    prediction_volume = tf.reduce_sum(merged_predictions, axis=spatial_axes)
    volume_sum = target_volume + prediction_volume
    numerator = 2.0 * overlap + top_smooth
    per_sample_dice = numerator / (volume_sum + smooth)
    return tf.reduce_mean(per_sample_dice)