DeepLearningExamples/TensorFlow/Segmentation/UNet_Industrial/model/unet.py
Pablo Ribalta 70f3e4362c Remove dllogger and fix bugs from GH
Signed-off-by: Pablo Ribalta <pribalta@nvidia.com>
2020-04-08 09:40:52 +02:00

513 lines
20 KiB
Python

# !/usr/bin/env python
# -*- coding: utf-8 -*-
# ==============================================================================
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import tensorflow as tf
import horovod.tensorflow as hvd
from model import layers
from model import blocks
from utils import hvd_utils
from utils import losses
from utils import metrics
from utils import image_processing
from dllogger import Logger
__all__ = ["UNet_v1"]
class UNet_v1(object):
authorized_weight_init_methods = [
"he_normal",
"he_uniform",
"glorot_normal",
"glorot_uniform",
"orthogonal",
]
authorized_models_variants = [
"original",
"tinyUNet",
]
def __init__(
self,
model_name,
compute_format,
input_format,
n_output_channels,
unet_variant,
activation_fn,
weight_init_method,
):
if unet_variant == "original": # Total Params: 36,950,273
input_filters = 64
unet_block_filters = [128, 256, 512]
bottleneck_filters = 1024
output_filters = 64
elif unet_variant == "tinyUNet": # Total Params: 1,824,945
input_filters = 32
unet_block_filters = [32, 64, 128]
bottleneck_filters = 256
output_filters = 32
else:
raise ValueError(
"Unknown `UNet` variant: %s. Authorized: %s" % (unet_variant, UNet_v1.authorized_models_variants)
)
if activation_fn not in blocks.authorized_activation_fn:
raise ValueError(
"Unknown activation function: %s - Authorised: %s" % (activation_fn, blocks.authorized_activation_fn)
)
self.model_hparams = tf.contrib.training.HParams(
compute_format=compute_format,
input_format=input_format,
input_filters=input_filters,
unet_block_filters=unet_block_filters,
bottleneck_filters=bottleneck_filters,
output_filters=output_filters,
n_output_channels=n_output_channels,
model_name=model_name,
)
self.conv2d_hparams = tf.contrib.training.HParams(
kernel_initializer=None, bias_initializer=tf.initializers.constant(0.0), activation_fn=activation_fn
)
if weight_init_method == "he_normal":
self.conv2d_hparams.kernel_initializer = tf.initializers.variance_scaling(
scale=2.0, distribution='truncated_normal', mode='fan_in'
)
elif weight_init_method == "he_uniform":
self.conv2d_hparams.kernel_initializer = tf.initializers.variance_scaling(
scale=2.0, distribution='uniform', mode='fan_in'
)
elif weight_init_method == "glorot_normal":
self.conv2d_hparams.kernel_initializer = tf.initializers.variance_scaling(
scale=1.0, distribution='truncated_normal', mode='fan_avg'
)
elif weight_init_method == "glorot_uniform":
self.conv2d_hparams.kernel_initializer = tf.initializers.variance_scaling(
scale=1.0, distribution='uniform', mode='fan_avg'
)
elif weight_init_method == "orthogonal":
self.conv2d_hparams.kernel_initializer = tf.initializers.orthogonal(gain=1.0)
else:
raise ValueError(
"Unknown weight init method: %s - Authorized: %s" %
(weight_init_method, UNet_v1.authorized_weight_init_methods)
)
def __call__(self, features, labels, mode, params):
if "debug_verbosity" not in params.keys():
raise RuntimeError("Parameter `debug_verbosity` is missing...")
if mode == tf.estimator.ModeKeys.TRAIN:
if "rmsprop_decay" not in params.keys():
raise RuntimeError("Parameter `rmsprop_decay` is missing...")
if "rmsprop_momentum" not in params.keys():
raise RuntimeError("Parameter `rmsprop_momentum` is missing...")
if "learning_rate" not in params.keys():
raise RuntimeError("Parameter `learning_rate` is missing...")
if "learning_rate_decay_steps" not in params.keys():
raise RuntimeError("Parameter `learning_rate` is missing...")
if "learning_rate_decay_factor" not in params.keys():
raise RuntimeError("Parameter `learning_rate` is missing...")
if "weight_decay" not in params.keys():
raise RuntimeError("Parameter `weight_decay` is missing...")
if "loss_fn_name" not in params.keys():
raise RuntimeError("Parameter `loss_fn_name` is missing...")
if mode == tf.estimator.ModeKeys.PREDICT:
y_pred, y_pred_logits = self.build_model(
features, training=False, reuse=False, debug_verbosity=params["debug_verbosity"]
)
predictions = {'logits': y_pred}
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
input_image, mask_image = features
with tf.device("/gpu:0"):
tf.identity(input_image, name="input_image_ref")
tf.identity(mask_image, name="mask_image_ref")
tf.identity(labels, name="labels_ref")
y_pred, y_pred_logits = self.build_model(
input_image,
training=mode == tf.estimator.ModeKeys.TRAIN,
reuse=False,
debug_verbosity=params["debug_verbosity"]
)
all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.trainable_variables()])
tf.identity(all_trainable_vars, name='trainable_parameters_count_ref')
if mode == tf.estimator.ModeKeys.EVAL:
eval_metrics = dict()
# ==================== Samples ==================== #
image_uint8 = tf.cast((input_image + 1) * 127.5, dtype=tf.uint8)
input_image_jpeg = tf.image.encode_jpeg(image_uint8[0], format='grayscale', quality=100)
tf.identity(input_image_jpeg, name="input_image_jpeg_ref")
for threshold in [None, 0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99]:
binarize_img, binarize_img_jpeg = image_processing.binarize_output(y_pred[0], threshold=threshold)
tf.identity(binarize_img_jpeg, name="output_sample_ths_%s_ref" % threshold)
tf.summary.image('output_sample_ths_%s' % threshold, binarize_img, 10)
# ==============+ Evaluation Metrics ==================== #
with tf.name_scope("IoU_Metrics"):
for threshold in [0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99]:
iou_score = metrics.iou_score(y_pred=y_pred, y_true=mask_image, threshold=threshold)
tf.identity(iou_score, name='iou_score_ths_%s_ref' % threshold)
tf.summary.scalar('iou_score_ths_%s' % threshold, iou_score)
if mode == tf.estimator.ModeKeys.EVAL:
eval_metrics["IoU_THS_%s" % threshold] = tf.metrics.mean(iou_score)
labels = tf.cast(labels, tf.float32)
labels_preds = tf.reduce_max(y_pred, axis=(1, 2, 3))
assert (
abs(labels_preds - tf.clip_by_value(labels_preds, 0, 1)) < 0.00001,
"Clipping labels_preds introduces non-trivial loss."
)
labels_preds = tf.clip_by_value(labels_preds, 0, 1)
with tf.variable_scope("Confusion_Matrix") as scope:
tp, update_tp = tf.metrics.true_positives_at_thresholds(
labels=labels,
predictions=labels_preds,
thresholds=[0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99],
)
tn, update_tn = tf.metrics.true_negatives_at_thresholds(
labels=labels,
predictions=labels_preds,
thresholds=[0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99],
)
fp, update_fp = tf.metrics.false_positives_at_thresholds(
labels=labels,
predictions=labels_preds,
thresholds=[0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99],
)
fn, update_fn = tf.metrics.false_negatives_at_thresholds(
labels=labels,
predictions=labels_preds,
thresholds=[0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99],
)
if mode == tf.estimator.ModeKeys.TRAIN:
local_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=scope.name)
confusion_matrix_reset_op = tf.initializers.variables(local_vars, name='reset_op')
with tf.control_dependencies([confusion_matrix_reset_op]):
with tf.control_dependencies([update_tp, update_tn, update_fp, update_fn]):
tp = tf.identity(tp)
tn = tf.identity(tn)
fp = tf.identity(fp)
fn = tf.identity(fn)
else:
eval_metrics["Confusion_Matrix_TP"] = tp, update_tp
eval_metrics["Confusion_Matrix_TN"] = tn, update_tn
eval_metrics["Confusion_Matrix_FP"] = fp, update_fp
eval_metrics["Confusion_Matrix_FN"] = fn, update_fn
tf.identity(tp, name='true_positives_ref') # Confusion_Matrix/true_positives_ref:0
tf.identity(tn, name='true_negatives_ref') # Confusion_Matrix/true_negatives_ref:0
tf.identity(fp, name='false_positives_ref') # Confusion_Matrix/false_positives_ref:0
tf.identity(fn, name='false_negatives_ref') # Confusion_Matrix/false_negatives_ref:0
tf.summary.scalar('true_positives', tp[3]) # For Ths = 0.5
tf.summary.scalar('true_negatives', tn[3]) # For Ths = 0.5
tf.summary.scalar('false_positives', fp[3]) # For Ths = 0.5
tf.summary.scalar('false_negatives', fn[3]) # For Ths = 0.5
binarized_mask, binarized_mask_jpeg = image_processing.binarize_output(mask_image[0], threshold=0.5)
tf.identity(binarized_mask_jpeg, name="mask_sample_ref")
tf.summary.image('sample_mask', binarized_mask, 10)
##########################
mask_max_val = tf.reduce_max(mask_image)
tf.identity(mask_max_val, name='mask_max_val_ref')
mask_min_val = tf.reduce_min(mask_image)
tf.identity(mask_min_val, name='mask_min_val_ref')
mask_mean_val = tf.reduce_mean(mask_image)
tf.identity(mask_mean_val, name='mask_mean_val_ref')
mask_std_val = tf.math.reduce_std(mask_image)
tf.identity(mask_std_val, name='mask_std_val_ref')
##########################
output_max_val = tf.reduce_max(y_pred)
tf.identity(output_max_val, name='output_max_val_ref')
output_min_val = tf.reduce_min(y_pred)
tf.identity(output_min_val, name='output_min_val_ref')
output_mean_val = tf.reduce_mean(y_pred)
tf.identity(output_mean_val, name='output_mean_val_ref')
output_std_val = tf.math.reduce_std(y_pred)
tf.identity(output_std_val, name='output_std_val_ref')
with tf.variable_scope("losses"):
# ==============+ Reconstruction Loss ==================== #
if params["loss_fn_name"] == "x-entropy":
reconstruction_loss = losses.reconstruction_x_entropy(y_pred=y_pred, y_true=mask_image)
elif params["loss_fn_name"] == "l2_loss":
reconstruction_loss = losses.reconstruction_l2loss(y_pred=y_pred, y_true=mask_image)
elif params["loss_fn_name"] == "dice_sorensen":
reconstruction_loss = 1 - losses.dice_coe(y_pred=y_pred, y_true=mask_image, loss_type='sorensen')
elif params["loss_fn_name"] == "dice_jaccard":
reconstruction_loss = 1 - losses.dice_coe(y_pred=y_pred, y_true=mask_image, loss_type='jaccard')
elif params["loss_fn_name"] == "adaptive_loss":
reconstruction_loss = losses.adaptive_loss(
y_pred=y_pred,
y_pred_logits=y_pred_logits,
y_true=mask_image,
switch_at_threshold=0.3,
loss_type='sorensen'
)
else:
raise ValueError("Unknown loss function received: %s" % params["loss_fn_name"])
tf.identity(reconstruction_loss, name='reconstruction_loss_ref')
tf.summary.scalar('reconstruction_loss', reconstruction_loss)
if mode == tf.estimator.ModeKeys.TRAIN:
# ============== Regularization Loss ==================== #
l2_loss = losses.regularization_l2loss(weight_decay=params["weight_decay"])
tf.identity(l2_loss, name='l2_loss_ref')
tf.summary.scalar('l2_loss', l2_loss)
total_loss = tf.add(reconstruction_loss, l2_loss, name="total_loss")
else:
total_loss = reconstruction_loss
tf.identity(total_loss, name='total_loss_ref')
tf.summary.scalar('total_loss', total_loss)
if mode == tf.estimator.ModeKeys.TRAIN:
with tf.variable_scope("optimizers"):
# Update Global Step
global_step = tf.train.get_or_create_global_step()
tf.identity(global_step, name="global_step_ref")
learning_rate = tf.train.exponential_decay(
learning_rate=params["learning_rate"],
decay_steps=params["learning_rate_decay_steps"],
decay_rate=params["learning_rate_decay_factor"],
global_step=global_step,
staircase=True
)
tf.identity(learning_rate, name="learning_rate_ref")
tf.summary.scalar('learning_rate_ref', learning_rate)
opt = tf.train.RMSPropOptimizer(
learning_rate=learning_rate,
use_locking=False,
centered=True,
decay=params["rmsprop_decay"],
momentum=params["rmsprop_momentum"],
)
if hvd_utils.is_using_hvd():
opt = hvd.DistributedOptimizer(opt, device_dense='/gpu:0')
if params["apply_manual_loss_scaling"]:
# if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
# Logger.log("Applying manual Loss Scaling ...")
loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
init_loss_scale=2**32, # 4,294,967,296
incr_every_n_steps=1000
)
opt = tf.contrib.mixed_precision.LossScaleOptimizer(opt, loss_scale_manager)
deterministic = True
gate_gradients = (tf.train.Optimizer.GATE_OP if deterministic else tf.train.Optimizer.GATE_NONE)
backprop_op = opt.minimize(total_loss, gate_gradients=gate_gradients, global_step=global_step)
train_op = tf.group(backprop_op, tf.get_collection(tf.GraphKeys.UPDATE_OPS))
return tf.estimator.EstimatorSpec(
mode,
loss=total_loss,
train_op=train_op,
)
elif mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(
mode, loss=total_loss, eval_metric_ops=eval_metrics, predictions={"output": y_pred}
)
else:
raise NotImplementedError('Unknown mode {}'.format(mode))
def build_model(self, inputs, training=True, reuse=False, debug_verbosity=0):
"""
U-Net: Convolutional Networks for Biomedical Image Segmentation
https://arxiv.org/pdf/1505.04597
"""
skip_connections = []
with tf.variable_scope(self.model_hparams.model_name, reuse=reuse):
with tf.variable_scope("input_reshape"):
with tf.variable_scope("initial_zero_padding"):
inputs = tf.image.resize_image_with_crop_or_pad(inputs, target_height=512, target_width=512)
if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
# This provides a large performance boost on GPU. See
# https://www.tensorflow.org/performance/performance_guide#data_formats
# Reshape inputs: NHWC => NCHW
net = tf.transpose(inputs, [0, 3, 1, 2])
elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':
# Reshape inputs: NCHW => NHWC
net = tf.transpose(inputs, [0, 2, 3, 1])
else:
net = inputs
# net, out = input_block(net, filters=64)
net, out = blocks.input_unet_block(
net,
filters=self.model_hparams.input_filters,
data_format=self.model_hparams.compute_format,
is_training=training,
conv2d_hparams=self.conv2d_hparams
)
skip_connections.append(out)
for idx, filters in enumerate(self.model_hparams.unet_block_filters):
# net, out = downsample_block(net, filters=filters, idx=idx)
net, skip_connect = blocks.downsample_unet_block(
net,
filters=filters,
data_format=self.model_hparams.compute_format,
is_training=training,
conv2d_hparams=self.conv2d_hparams,
block_name="downsample_block_%d" % (idx + 1)
)
skip_connections.append(skip_connect)
net = blocks.bottleneck_unet_block(
net,
filters=self.model_hparams.bottleneck_filters,
data_format=self.model_hparams.compute_format,
is_training=training,
conv2d_hparams=self.conv2d_hparams,
)
for idx, filters in enumerate(reversed(self.model_hparams.unet_block_filters)):
net = blocks.upsample_unet_block(
net,
residual_input=skip_connections.pop(),
filters=filters,
data_format=self.model_hparams.compute_format,
is_training=training,
conv2d_hparams=self.conv2d_hparams,
block_name='upsample_block_%d' % (idx + 1)
)
logits = blocks.output_unet_block(
inputs=net,
residual_input=skip_connections.pop(),
filters=self.model_hparams.output_filters,
n_output_channels=self.model_hparams.n_output_channels,
data_format=self.model_hparams.compute_format,
is_training=training,
conv2d_hparams=self.conv2d_hparams,
block_name='ouputs_block'
)
if self.model_hparams.compute_format == "NCHW":
logits = tf.transpose(logits, [0, 2, 3, 1])
outputs = layers.sigmoid(logits)
return outputs, logits