2019-03-18 19:45:05 +01:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
# ==============================================================================
|
|
|
|
#
|
|
|
|
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
|
|
import tensorflow as tf
|
|
|
|
import horovod.tensorflow as hvd
|
|
|
|
|
|
|
|
from model import layers
|
|
|
|
from model import blocks
|
|
|
|
|
|
|
|
from utils import hvd_utils
|
|
|
|
|
|
|
|
from utils import losses
|
|
|
|
from utils import metrics
|
|
|
|
|
|
|
|
from utils import image_processing
|
|
|
|
|
2020-04-08 09:40:52 +02:00
|
|
|
from dllogger import Logger
|
2019-03-18 19:45:05 +01:00
|
|
|
|
|
|
|
__all__ = ["UNet_v1"]
|
|
|
|
|
|
|
|
|
|
|
|
class UNet_v1(object):
    """U-Net v1 segmentation network (TensorFlow 1.x / tf.estimator style).

    Instances are callable with the ``(features, labels, mode, params)``
    signature expected of an Estimator ``model_fn`` (see ``__call__``), and
    build the U-Net graph itself through ``build_model``.
    """

    # Accepted values for the `weight_init_method` argument of `__init__`;
    # each maps to a kernel initializer selected in `__init__`.
    authorized_weight_init_methods = [
        "he_normal",
        "he_uniform",
        "glorot_normal",
        "glorot_uniform",
        "orthogonal",
    ]

    # Accepted values for the `unet_variant` argument of `__init__`:
    # "original" is the full-size network, "tinyUNet" a reduced-filter variant.
    authorized_models_variants = [
        "original",
        "tinyUNet",
    ]
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
model_name,
|
|
|
|
compute_format,
|
|
|
|
input_format,
|
|
|
|
n_output_channels,
|
|
|
|
unet_variant,
|
|
|
|
activation_fn,
|
|
|
|
weight_init_method,
|
|
|
|
):
|
|
|
|
|
|
|
|
if unet_variant == "original": # Total Params: 36,950,273
|
|
|
|
input_filters = 64
|
|
|
|
unet_block_filters = [128, 256, 512]
|
|
|
|
bottleneck_filters = 1024
|
|
|
|
output_filters = 64
|
|
|
|
|
|
|
|
elif unet_variant == "tinyUNet": # Total Params: 1,824,945
|
|
|
|
input_filters = 32
|
|
|
|
unet_block_filters = [32, 64, 128]
|
|
|
|
bottleneck_filters = 256
|
|
|
|
output_filters = 32
|
|
|
|
|
|
|
|
else:
|
|
|
|
raise ValueError(
|
|
|
|
"Unknown `UNet` variant: %s. Authorized: %s" % (unet_variant, UNet_v1.authorized_models_variants)
|
|
|
|
)
|
|
|
|
|
|
|
|
if activation_fn not in blocks.authorized_activation_fn:
|
|
|
|
raise ValueError(
|
|
|
|
"Unknown activation function: %s - Authorised: %s" % (activation_fn, blocks.authorized_activation_fn)
|
|
|
|
)
|
|
|
|
|
|
|
|
self.model_hparams = tf.contrib.training.HParams(
|
|
|
|
compute_format=compute_format,
|
|
|
|
input_format=input_format,
|
|
|
|
input_filters=input_filters,
|
|
|
|
unet_block_filters=unet_block_filters,
|
|
|
|
bottleneck_filters=bottleneck_filters,
|
|
|
|
output_filters=output_filters,
|
|
|
|
n_output_channels=n_output_channels,
|
|
|
|
model_name=model_name,
|
|
|
|
)
|
|
|
|
|
|
|
|
self.conv2d_hparams = tf.contrib.training.HParams(
|
|
|
|
kernel_initializer=None, bias_initializer=tf.initializers.constant(0.0), activation_fn=activation_fn
|
|
|
|
)
|
|
|
|
|
|
|
|
if weight_init_method == "he_normal":
|
|
|
|
self.conv2d_hparams.kernel_initializer = tf.initializers.variance_scaling(
|
|
|
|
scale=2.0, distribution='truncated_normal', mode='fan_in'
|
|
|
|
)
|
|
|
|
|
|
|
|
elif weight_init_method == "he_uniform":
|
|
|
|
self.conv2d_hparams.kernel_initializer = tf.initializers.variance_scaling(
|
|
|
|
scale=2.0, distribution='uniform', mode='fan_in'
|
|
|
|
)
|
|
|
|
|
|
|
|
elif weight_init_method == "glorot_normal":
|
|
|
|
self.conv2d_hparams.kernel_initializer = tf.initializers.variance_scaling(
|
|
|
|
scale=1.0, distribution='truncated_normal', mode='fan_avg'
|
|
|
|
)
|
|
|
|
|
|
|
|
elif weight_init_method == "glorot_uniform":
|
|
|
|
self.conv2d_hparams.kernel_initializer = tf.initializers.variance_scaling(
|
|
|
|
scale=1.0, distribution='uniform', mode='fan_avg'
|
|
|
|
)
|
|
|
|
|
|
|
|
elif weight_init_method == "orthogonal":
|
|
|
|
self.conv2d_hparams.kernel_initializer = tf.initializers.orthogonal(gain=1.0)
|
|
|
|
|
|
|
|
else:
|
|
|
|
raise ValueError(
|
|
|
|
"Unknown weight init method: %s - Authorized: %s" %
|
|
|
|
(weight_init_method, UNet_v1.authorized_weight_init_methods)
|
|
|
|
)
|
|
|
|
|
|
|
|
def __call__(self, features, labels, mode, params):
|
|
|
|
|
|
|
|
if "debug_verbosity" not in params.keys():
|
|
|
|
raise RuntimeError("Parameter `debug_verbosity` is missing...")
|
|
|
|
|
|
|
|
if mode == tf.estimator.ModeKeys.TRAIN:
|
|
|
|
|
|
|
|
if "rmsprop_decay" not in params.keys():
|
|
|
|
raise RuntimeError("Parameter `rmsprop_decay` is missing...")
|
|
|
|
|
|
|
|
if "rmsprop_momentum" not in params.keys():
|
|
|
|
raise RuntimeError("Parameter `rmsprop_momentum` is missing...")
|
|
|
|
|
|
|
|
if "learning_rate" not in params.keys():
|
|
|
|
raise RuntimeError("Parameter `learning_rate` is missing...")
|
|
|
|
|
|
|
|
if "learning_rate_decay_steps" not in params.keys():
|
|
|
|
raise RuntimeError("Parameter `learning_rate` is missing...")
|
|
|
|
|
|
|
|
if "learning_rate_decay_factor" not in params.keys():
|
|
|
|
raise RuntimeError("Parameter `learning_rate` is missing...")
|
|
|
|
|
|
|
|
if "weight_decay" not in params.keys():
|
|
|
|
raise RuntimeError("Parameter `weight_decay` is missing...")
|
|
|
|
|
|
|
|
if "loss_fn_name" not in params.keys():
|
|
|
|
raise RuntimeError("Parameter `loss_fn_name` is missing...")
|
|
|
|
|
2019-10-09 14:45:45 +02:00
|
|
|
if mode == tf.estimator.ModeKeys.PREDICT:
|
|
|
|
y_pred, y_pred_logits = self.build_model(
|
|
|
|
features, training=False, reuse=False, debug_verbosity=params["debug_verbosity"]
|
|
|
|
)
|
|
|
|
|
|
|
|
predictions = {'logits': y_pred}
|
|
|
|
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
|
|
|
|
|
2019-03-18 19:45:05 +01:00
|
|
|
input_image, mask_image = features
|
|
|
|
|
|
|
|
with tf.device("/gpu:0"):
|
|
|
|
|
|
|
|
tf.identity(input_image, name="input_image_ref")
|
|
|
|
tf.identity(mask_image, name="mask_image_ref")
|
|
|
|
tf.identity(labels, name="labels_ref")
|
|
|
|
|
|
|
|
y_pred, y_pred_logits = self.build_model(
|
|
|
|
input_image,
|
|
|
|
training=mode == tf.estimator.ModeKeys.TRAIN,
|
|
|
|
reuse=False,
|
|
|
|
debug_verbosity=params["debug_verbosity"]
|
|
|
|
)
|
|
|
|
|
|
|
|
all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.trainable_variables()])
|
|
|
|
tf.identity(all_trainable_vars, name='trainable_parameters_count_ref')
|
|
|
|
|
|
|
|
if mode == tf.estimator.ModeKeys.EVAL:
|
|
|
|
eval_metrics = dict()
|
|
|
|
|
|
|
|
# ==================== Samples ==================== #
|
|
|
|
|
|
|
|
image_uint8 = tf.cast((input_image + 1) * 127.5, dtype=tf.uint8)
|
|
|
|
input_image_jpeg = tf.image.encode_jpeg(image_uint8[0], format='grayscale', quality=100)
|
|
|
|
tf.identity(input_image_jpeg, name="input_image_jpeg_ref")
|
|
|
|
|
|
|
|
for threshold in [None, 0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99]:
|
|
|
|
binarize_img, binarize_img_jpeg = image_processing.binarize_output(y_pred[0], threshold=threshold)
|
|
|
|
|
|
|
|
tf.identity(binarize_img_jpeg, name="output_sample_ths_%s_ref" % threshold)
|
|
|
|
tf.summary.image('output_sample_ths_%s' % threshold, binarize_img, 10)
|
|
|
|
|
|
|
|
# ==============+ Evaluation Metrics ==================== #
|
|
|
|
|
|
|
|
with tf.name_scope("IoU_Metrics"):
|
|
|
|
|
|
|
|
for threshold in [0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99]:
|
|
|
|
|
|
|
|
iou_score = metrics.iou_score(y_pred=y_pred, y_true=mask_image, threshold=threshold)
|
|
|
|
|
|
|
|
tf.identity(iou_score, name='iou_score_ths_%s_ref' % threshold)
|
|
|
|
tf.summary.scalar('iou_score_ths_%s' % threshold, iou_score)
|
|
|
|
|
|
|
|
if mode == tf.estimator.ModeKeys.EVAL:
|
|
|
|
eval_metrics["IoU_THS_%s" % threshold] = tf.metrics.mean(iou_score)
|
|
|
|
|
|
|
|
labels = tf.cast(labels, tf.float32)
|
|
|
|
labels_preds = tf.reduce_max(y_pred, axis=(1, 2, 3))
|
|
|
|
|
2020-04-08 09:40:52 +02:00
|
|
|
assert (
|
|
|
|
abs(labels_preds - tf.clip_by_value(labels_preds, 0, 1)) < 0.00001,
|
|
|
|
"Clipping labels_preds introduces non-trivial loss."
|
|
|
|
)
|
|
|
|
labels_preds = tf.clip_by_value(labels_preds, 0, 1)
|
|
|
|
|
2019-03-18 19:45:05 +01:00
|
|
|
with tf.variable_scope("Confusion_Matrix") as scope:
|
|
|
|
|
|
|
|
tp, update_tp = tf.metrics.true_positives_at_thresholds(
|
|
|
|
labels=labels,
|
|
|
|
predictions=labels_preds,
|
|
|
|
thresholds=[0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99],
|
|
|
|
)
|
|
|
|
|
|
|
|
tn, update_tn = tf.metrics.true_negatives_at_thresholds(
|
|
|
|
labels=labels,
|
|
|
|
predictions=labels_preds,
|
|
|
|
thresholds=[0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99],
|
|
|
|
)
|
|
|
|
|
|
|
|
fp, update_fp = tf.metrics.false_positives_at_thresholds(
|
|
|
|
labels=labels,
|
|
|
|
predictions=labels_preds,
|
|
|
|
thresholds=[0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99],
|
|
|
|
)
|
|
|
|
|
|
|
|
fn, update_fn = tf.metrics.false_negatives_at_thresholds(
|
|
|
|
labels=labels,
|
|
|
|
predictions=labels_preds,
|
|
|
|
thresholds=[0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99],
|
|
|
|
)
|
|
|
|
|
|
|
|
if mode == tf.estimator.ModeKeys.TRAIN:
|
|
|
|
local_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=scope.name)
|
|
|
|
confusion_matrix_reset_op = tf.initializers.variables(local_vars, name='reset_op')
|
|
|
|
|
|
|
|
with tf.control_dependencies([confusion_matrix_reset_op]):
|
|
|
|
with tf.control_dependencies([update_tp, update_tn, update_fp, update_fn]):
|
|
|
|
tp = tf.identity(tp)
|
|
|
|
tn = tf.identity(tn)
|
|
|
|
fp = tf.identity(fp)
|
|
|
|
fn = tf.identity(fn)
|
|
|
|
|
|
|
|
else:
|
|
|
|
eval_metrics["Confusion_Matrix_TP"] = tp, update_tp
|
|
|
|
eval_metrics["Confusion_Matrix_TN"] = tn, update_tn
|
|
|
|
eval_metrics["Confusion_Matrix_FP"] = fp, update_fp
|
|
|
|
eval_metrics["Confusion_Matrix_FN"] = fn, update_fn
|
|
|
|
|
|
|
|
tf.identity(tp, name='true_positives_ref') # Confusion_Matrix/true_positives_ref:0
|
|
|
|
tf.identity(tn, name='true_negatives_ref') # Confusion_Matrix/true_negatives_ref:0
|
|
|
|
tf.identity(fp, name='false_positives_ref') # Confusion_Matrix/false_positives_ref:0
|
|
|
|
tf.identity(fn, name='false_negatives_ref') # Confusion_Matrix/false_negatives_ref:0
|
|
|
|
|
|
|
|
tf.summary.scalar('true_positives', tp[3]) # For Ths = 0.5
|
|
|
|
tf.summary.scalar('true_negatives', tn[3]) # For Ths = 0.5
|
|
|
|
tf.summary.scalar('false_positives', fp[3]) # For Ths = 0.5
|
|
|
|
tf.summary.scalar('false_negatives', fn[3]) # For Ths = 0.5
|
|
|
|
|
|
|
|
binarized_mask, binarized_mask_jpeg = image_processing.binarize_output(mask_image[0], threshold=0.5)
|
|
|
|
tf.identity(binarized_mask_jpeg, name="mask_sample_ref")
|
|
|
|
tf.summary.image('sample_mask', binarized_mask, 10)
|
|
|
|
|
|
|
|
##########################
|
|
|
|
|
|
|
|
mask_max_val = tf.reduce_max(mask_image)
|
|
|
|
tf.identity(mask_max_val, name='mask_max_val_ref')
|
|
|
|
|
|
|
|
mask_min_val = tf.reduce_min(mask_image)
|
|
|
|
tf.identity(mask_min_val, name='mask_min_val_ref')
|
|
|
|
|
|
|
|
mask_mean_val = tf.reduce_mean(mask_image)
|
|
|
|
tf.identity(mask_mean_val, name='mask_mean_val_ref')
|
|
|
|
|
|
|
|
mask_std_val = tf.math.reduce_std(mask_image)
|
|
|
|
tf.identity(mask_std_val, name='mask_std_val_ref')
|
|
|
|
|
|
|
|
##########################
|
|
|
|
|
|
|
|
output_max_val = tf.reduce_max(y_pred)
|
|
|
|
tf.identity(output_max_val, name='output_max_val_ref')
|
|
|
|
|
|
|
|
output_min_val = tf.reduce_min(y_pred)
|
|
|
|
tf.identity(output_min_val, name='output_min_val_ref')
|
|
|
|
|
|
|
|
output_mean_val = tf.reduce_mean(y_pred)
|
|
|
|
tf.identity(output_mean_val, name='output_mean_val_ref')
|
|
|
|
|
|
|
|
output_std_val = tf.math.reduce_std(y_pred)
|
|
|
|
tf.identity(output_std_val, name='output_std_val_ref')
|
|
|
|
|
|
|
|
with tf.variable_scope("losses"):
|
|
|
|
|
|
|
|
# ==============+ Reconstruction Loss ==================== #
|
|
|
|
|
|
|
|
if params["loss_fn_name"] == "x-entropy":
|
|
|
|
reconstruction_loss = losses.reconstruction_x_entropy(y_pred=y_pred, y_true=mask_image)
|
|
|
|
|
|
|
|
elif params["loss_fn_name"] == "l2_loss":
|
|
|
|
reconstruction_loss = losses.reconstruction_l2loss(y_pred=y_pred, y_true=mask_image)
|
|
|
|
|
|
|
|
elif params["loss_fn_name"] == "dice_sorensen":
|
|
|
|
reconstruction_loss = 1 - losses.dice_coe(y_pred=y_pred, y_true=mask_image, loss_type='sorensen')
|
|
|
|
|
|
|
|
elif params["loss_fn_name"] == "dice_jaccard":
|
|
|
|
reconstruction_loss = 1 - losses.dice_coe(y_pred=y_pred, y_true=mask_image, loss_type='jaccard')
|
|
|
|
|
|
|
|
elif params["loss_fn_name"] == "adaptive_loss":
|
|
|
|
reconstruction_loss = losses.adaptive_loss(
|
|
|
|
y_pred=y_pred,
|
|
|
|
y_pred_logits=y_pred_logits,
|
|
|
|
y_true=mask_image,
|
|
|
|
switch_at_threshold=0.3,
|
|
|
|
loss_type='sorensen'
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
raise ValueError("Unknown loss function received: %s" % params["loss_fn_name"])
|
|
|
|
|
|
|
|
tf.identity(reconstruction_loss, name='reconstruction_loss_ref')
|
|
|
|
tf.summary.scalar('reconstruction_loss', reconstruction_loss)
|
|
|
|
|
|
|
|
if mode == tf.estimator.ModeKeys.TRAIN:
|
|
|
|
|
|
|
|
# ============== Regularization Loss ==================== #
|
|
|
|
|
|
|
|
l2_loss = losses.regularization_l2loss(weight_decay=params["weight_decay"])
|
|
|
|
|
|
|
|
tf.identity(l2_loss, name='l2_loss_ref')
|
|
|
|
tf.summary.scalar('l2_loss', l2_loss)
|
|
|
|
|
|
|
|
total_loss = tf.add(reconstruction_loss, l2_loss, name="total_loss")
|
|
|
|
|
|
|
|
else:
|
|
|
|
total_loss = reconstruction_loss
|
|
|
|
|
|
|
|
tf.identity(total_loss, name='total_loss_ref')
|
|
|
|
tf.summary.scalar('total_loss', total_loss)
|
|
|
|
|
|
|
|
if mode == tf.estimator.ModeKeys.TRAIN:
|
|
|
|
|
|
|
|
with tf.variable_scope("optimizers"):
|
|
|
|
|
|
|
|
# Update Global Step
|
|
|
|
global_step = tf.train.get_or_create_global_step()
|
|
|
|
tf.identity(global_step, name="global_step_ref")
|
|
|
|
|
|
|
|
learning_rate = tf.train.exponential_decay(
|
|
|
|
learning_rate=params["learning_rate"],
|
|
|
|
decay_steps=params["learning_rate_decay_steps"],
|
|
|
|
decay_rate=params["learning_rate_decay_factor"],
|
|
|
|
global_step=global_step,
|
|
|
|
staircase=True
|
|
|
|
)
|
|
|
|
|
|
|
|
tf.identity(learning_rate, name="learning_rate_ref")
|
|
|
|
tf.summary.scalar('learning_rate_ref', learning_rate)
|
|
|
|
|
|
|
|
opt = tf.train.RMSPropOptimizer(
|
|
|
|
learning_rate=learning_rate,
|
|
|
|
use_locking=False,
|
|
|
|
centered=True,
|
|
|
|
decay=params["rmsprop_decay"],
|
|
|
|
momentum=params["rmsprop_momentum"],
|
|
|
|
)
|
|
|
|
|
|
|
|
if hvd_utils.is_using_hvd():
|
|
|
|
opt = hvd.DistributedOptimizer(opt, device_dense='/gpu:0')
|
|
|
|
|
|
|
|
if params["apply_manual_loss_scaling"]:
|
|
|
|
|
2020-04-08 09:40:52 +02:00
|
|
|
# if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
|
|
|
|
# Logger.log("Applying manual Loss Scaling ...")
|
2019-03-18 19:45:05 +01:00
|
|
|
|
|
|
|
loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
|
|
|
|
init_loss_scale=2**32, # 4,294,967,296
|
|
|
|
incr_every_n_steps=1000
|
|
|
|
)
|
|
|
|
opt = tf.contrib.mixed_precision.LossScaleOptimizer(opt, loss_scale_manager)
|
|
|
|
|
|
|
|
deterministic = True
|
|
|
|
gate_gradients = (tf.train.Optimizer.GATE_OP if deterministic else tf.train.Optimizer.GATE_NONE)
|
|
|
|
|
|
|
|
backprop_op = opt.minimize(total_loss, gate_gradients=gate_gradients, global_step=global_step)
|
|
|
|
|
|
|
|
train_op = tf.group(backprop_op, tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
|
|
|
|
|
|
|
return tf.estimator.EstimatorSpec(
|
|
|
|
mode,
|
|
|
|
loss=total_loss,
|
|
|
|
train_op=train_op,
|
|
|
|
)
|
|
|
|
|
|
|
|
elif mode == tf.estimator.ModeKeys.EVAL:
|
|
|
|
|
|
|
|
return tf.estimator.EstimatorSpec(
|
|
|
|
mode, loss=total_loss, eval_metric_ops=eval_metrics, predictions={"output": y_pred}
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
raise NotImplementedError('Unknown mode {}'.format(mode))
|
|
|
|
|
|
|
|
def build_model(self, inputs, training=True, reuse=False, debug_verbosity=0):
|
|
|
|
"""
|
|
|
|
U-Net: Convolutional Networks for Biomedical Image Segmentation
|
|
|
|
https://arxiv.org/pdf/1505.04597
|
|
|
|
"""
|
|
|
|
|
|
|
|
skip_connections = []
|
|
|
|
|
|
|
|
with tf.variable_scope(self.model_hparams.model_name, reuse=reuse):
|
|
|
|
|
|
|
|
with tf.variable_scope("input_reshape"):
|
|
|
|
|
|
|
|
with tf.variable_scope("initial_zero_padding"):
|
|
|
|
inputs = tf.image.resize_image_with_crop_or_pad(inputs, target_height=512, target_width=512)
|
|
|
|
|
|
|
|
if self.model_hparams.input_format == 'NHWC' and self.model_hparams.compute_format == 'NCHW':
|
|
|
|
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
|
|
|
|
# This provides a large performance boost on GPU. See
|
|
|
|
# https://www.tensorflow.org/performance/performance_guide#data_formats
|
|
|
|
|
|
|
|
# Reshape inputs: NHWC => NCHW
|
|
|
|
net = tf.transpose(inputs, [0, 3, 1, 2])
|
|
|
|
|
|
|
|
elif self.model_hparams.input_format == 'NCHW' and self.model_hparams.compute_format == 'NHWC':
|
|
|
|
|
|
|
|
# Reshape inputs: NCHW => NHWC
|
|
|
|
net = tf.transpose(inputs, [0, 2, 3, 1])
|
|
|
|
|
|
|
|
else:
|
|
|
|
net = inputs
|
|
|
|
|
|
|
|
# net, out = input_block(net, filters=64)
|
|
|
|
|
|
|
|
net, out = blocks.input_unet_block(
|
|
|
|
net,
|
|
|
|
filters=self.model_hparams.input_filters,
|
|
|
|
data_format=self.model_hparams.compute_format,
|
|
|
|
is_training=training,
|
|
|
|
conv2d_hparams=self.conv2d_hparams
|
|
|
|
)
|
|
|
|
|
|
|
|
skip_connections.append(out)
|
|
|
|
|
|
|
|
for idx, filters in enumerate(self.model_hparams.unet_block_filters):
|
|
|
|
# net, out = downsample_block(net, filters=filters, idx=idx)
|
|
|
|
|
|
|
|
net, skip_connect = blocks.downsample_unet_block(
|
|
|
|
net,
|
|
|
|
filters=filters,
|
|
|
|
data_format=self.model_hparams.compute_format,
|
|
|
|
is_training=training,
|
|
|
|
conv2d_hparams=self.conv2d_hparams,
|
|
|
|
block_name="downsample_block_%d" % (idx + 1)
|
|
|
|
)
|
|
|
|
|
|
|
|
skip_connections.append(skip_connect)
|
|
|
|
|
|
|
|
net = blocks.bottleneck_unet_block(
|
|
|
|
net,
|
|
|
|
filters=self.model_hparams.bottleneck_filters,
|
|
|
|
data_format=self.model_hparams.compute_format,
|
|
|
|
is_training=training,
|
|
|
|
conv2d_hparams=self.conv2d_hparams,
|
|
|
|
)
|
|
|
|
|
|
|
|
for idx, filters in enumerate(reversed(self.model_hparams.unet_block_filters)):
|
|
|
|
net = blocks.upsample_unet_block(
|
|
|
|
net,
|
|
|
|
residual_input=skip_connections.pop(),
|
|
|
|
filters=filters,
|
|
|
|
data_format=self.model_hparams.compute_format,
|
|
|
|
is_training=training,
|
|
|
|
conv2d_hparams=self.conv2d_hparams,
|
|
|
|
block_name='upsample_block_%d' % (idx + 1)
|
|
|
|
)
|
|
|
|
|
|
|
|
logits = blocks.output_unet_block(
|
|
|
|
inputs=net,
|
|
|
|
residual_input=skip_connections.pop(),
|
|
|
|
filters=self.model_hparams.output_filters,
|
|
|
|
n_output_channels=self.model_hparams.n_output_channels,
|
|
|
|
data_format=self.model_hparams.compute_format,
|
|
|
|
is_training=training,
|
|
|
|
conv2d_hparams=self.conv2d_hparams,
|
|
|
|
block_name='ouputs_block'
|
|
|
|
)
|
|
|
|
|
|
|
|
if self.model_hparams.compute_format == "NCHW":
|
|
|
|
logits = tf.transpose(logits, [0, 2, 3, 1])
|
|
|
|
|
|
|
|
outputs = layers.sigmoid(logits)
|
|
|
|
|
|
|
|
return outputs, logits
|