# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GNMT attention sequence-to-sequence model with dynamic RNN support."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn

import attention_wrapper
import block_lstm
import model
import model_helper
from utils import misc_utils as utils


class GNMTModel(model.BaseModel):
  """Sequence-to-sequence dynamic model with GNMT attention architecture."""

  def __init__(self, hparams, mode, features, scope=None, extra_args=None):
    self.is_gnmt_attention = (
        hparams.attention_architecture in ["gnmt", "gnmt_v2"])
    super(GNMTModel, self).__init__(
        hparams=hparams,
        mode=mode,
        features=features,
        scope=scope,
        extra_args=extra_args)

  def _prepare_beam_search_decoder_inputs(
      self, beam_width, memory, source_sequence_length, encoder_state):
    memory = tf.contrib.seq2seq.tile_batch(
        memory, multiplier=beam_width)
    source_sequence_length = tf.contrib.seq2seq.tile_batch(
        source_sequence_length, multiplier=beam_width)
    encoder_state = tf.contrib.seq2seq.tile_batch(
        encoder_state, multiplier=beam_width)
    batch_size = self.batch_size * beam_width
    return memory, source_sequence_length, encoder_state, batch_size

  def _build_encoder(self, hparams):
    """Build a GNMT encoder."""
    assert hparams.encoder_type == "gnmt"

    # Build GNMT encoder.
    num_bi_layers = 1
    num_uni_layers = self.num_encoder_layers - num_bi_layers
    utils.print_out("# Build a GNMT encoder")
    utils.print_out(" num_bi_layers = %d" % num_bi_layers)
    utils.print_out(" num_uni_layers = %d" % num_uni_layers)

    # source is batch-majored
    source = self.features["source"]
    import sys
    print('source.shape: %s' % source.shape, file=sys.stderr)
    if self.time_major:
      # Later rnn would use time-majored inputs
      source = tf.transpose(source)

    with tf.variable_scope("encoder"):
      dtype = self.dtype

      encoder_emb_inp = tf.cast(
          self.encoder_emb_lookup_fn(self.embedding_encoder, source), dtype)

      # Build 1st bidi layer.
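      # NOTE: assuming time_major=True (the usual configuration here),
      # encoder_emb_inp is [src_max_time, batch, emb_dim] and the
      # bi-directional layer below produces outputs of shape
      # [src_max_time, batch, 2 * num_units] plus a
      # (forward_state, backward_state) pair.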
      bi_encoder_outputs, bi_encoder_state = self._build_encoder_layers_bidi(
          encoder_emb_inp, self.features["source_sequence_length"], hparams,
          dtype)

      # Build all the remaining unidirectional layers.
      encoder_state, encoder_outputs = self._build_encoder_layers_unidi(
          bi_encoder_outputs, self.features["source_sequence_length"],
          num_uni_layers, hparams, dtype)

      # Pass all encoder states to the decoder except the first
      # bi-directional layer's forward state.
      encoder_state = (bi_encoder_state[1],) + (
          (encoder_state,) if num_uni_layers == 1 else encoder_state)

    return encoder_outputs, encoder_state

  def _build_encoder_layers_bidi(self, inputs, sequence_length, hparams,
                                 dtype):
    """Build the single bi-directional encoder layer."""
    if hparams.use_fused_lstm:
      fn = self._build_bidi_rnn_fused
    elif hparams.use_cudnn_lstm:
      fn = self._build_bidi_rnn_cudnn
    else:
      fn = self._build_bidi_rnn_base
    return fn(inputs, sequence_length, hparams, dtype)

  def _build_bidi_rnn_fused(self, inputs, sequence_length, hparams, dtype):
    if (not np.isclose(hparams.dropout, 0.)
        and self.mode == tf.contrib.learn.ModeKeys.TRAIN):
      inputs = tf.nn.dropout(inputs, keep_prob=1 - hparams.dropout)

    fwd_cell = block_lstm.LSTMBlockFusedCell(
        hparams.num_units, hparams.forget_bias, dtype=dtype)
    fwd_encoder_outputs, (fwd_final_c, fwd_final_h) = fwd_cell(
        inputs, dtype=dtype, sequence_length=sequence_length)

    inputs_r = tf.reverse_sequence(
        inputs, sequence_length, batch_axis=1, seq_axis=0)
    bak_cell = block_lstm.LSTMBlockFusedCell(
        hparams.num_units, hparams.forget_bias, dtype=dtype)
    bak_encoder_outputs, (bak_final_c, bak_final_h) = bak_cell(
        inputs_r, dtype=dtype, sequence_length=sequence_length)
    bak_encoder_outputs = tf.reverse_sequence(
        bak_encoder_outputs, sequence_length, batch_axis=1, seq_axis=0)

    bi_encoder_outputs = tf.concat(
        [fwd_encoder_outputs, bak_encoder_outputs], axis=-1)
    fwd_state = tf.nn.rnn_cell.LSTMStateTuple(fwd_final_c, fwd_final_h)
    bak_state = tf.nn.rnn_cell.LSTMStateTuple(bak_final_c, bak_final_h)
    bi_encoder_state = (fwd_state, bak_state)

    # Masks aren't applied to the outputs, but the final states are
    # post-masking.
    return bi_encoder_outputs, bi_encoder_state

  def _build_unidi_rnn_fused(self, inputs, state, sequence_length, hparams,
                             dtype):
    if (not np.isclose(hparams.dropout, 0.)
        and self.mode == tf.contrib.learn.ModeKeys.TRAIN):
      inputs = tf.nn.dropout(inputs, keep_prob=1 - hparams.dropout)

    cell = block_lstm.LSTMBlockFusedCell(
        hparams.num_units, hparams.forget_bias, dtype=dtype)
    outputs, (final_c, final_h) = cell(
        inputs, state, dtype=dtype, sequence_length=sequence_length)

    # Masks aren't applied to the outputs, but the final states are
    # post-masking.
    return outputs, tf.nn.rnn_cell.LSTMStateTuple(final_c, final_h)

  def _build_unidi_rnn_cudnn(self, inputs, state, sequence_length, dtype,
                             hparams, num_layers, is_fwd):
    # cudnn inputs only support time-major.
    if not self.time_major:
      inputs = tf.transpose(inputs, perm=[1, 0, 2])

    if num_layers == 1 and not np.isclose(hparams.dropout, 0.):
      # Special case when dropout is used and there is only one layer.
      dropout = 0.
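      # cudnn applies dropout only between stacked layers, so for a
      # single-layer cell the kernel-level dropout would be a no-op;
      # dropout is therefore applied to the inputs here instead.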
      inputs = tf.nn.dropout(inputs, keep_prob=1 - hparams.dropout)
    else:
      dropout = hparams.dropout

    # The outputs will be time-major.
    sequence_length = tf.transpose(sequence_length)

    if not is_fwd:
      inputs = tf.reverse_sequence(
          inputs, sequence_length, batch_axis=1, seq_axis=0)
    cell = tf.contrib.cudnn_rnn.CudnnLSTM(
        num_layers=num_layers,
        num_units=hparams.num_units,
        direction=cudnn_rnn.CUDNN_RNN_UNIDIRECTION,
        dtype=self.dtype,
        dropout=dropout)
    outputs, (h, c) = cell(inputs, initial_state=state)
    """
    # Mask outputs
    # [batch, time]
    mask = tf.sequence_mask(sequence_length, dtype=self.dtype)
    # [time, batch]
    mask = tf.transpose(mask)
    outputs *= mask
    """
    if not is_fwd:
      outputs = tf.reverse_sequence(
          outputs, sequence_length, batch_axis=1, seq_axis=0)

    # NOTICE! There's no way to get the "correct" masked cell state in cudnn
    # rnn.
    if num_layers == 1:
      h = tf.squeeze(h, axis=0)
      c = tf.squeeze(c, axis=0)
      return outputs, tf.nn.rnn_cell.LSTMStateTuple(c=c, h=h)

    # Split h and c to form a list of per-layer LSTMStateTuples.
    h.set_shape((num_layers, None, hparams.num_units))
    c.set_shape((num_layers, None, hparams.num_units))
    hs = tf.unstack(h)
    cs = tf.unstack(c)
    # The cell passed to bidi-dynamic-rnn is a MultiRNNCell consisting of 2
    # regular LSTMs, the state of each is a simple LSTMStateTuple. Thus the
    # state of the MultiRNNCell is a tuple of LSTMStateTuples.
    states = tuple(
        tf.nn.rnn_cell.LSTMStateTuple(c=c, h=h) for h, c in zip(hs, cs))
    # No need to transpose back.
    return outputs, states

  def _build_encoder_cell(self, hparams, num_layers, num_residual_layers,
                          dtype=None):
    """Build a multi-layer RNN cell that can be used by encoder."""
    return model_helper.create_rnn_cell(
        unit_type=hparams.unit_type,
        num_units=self.num_units,
        num_layers=num_layers,
        num_residual_layers=num_residual_layers,
        forget_bias=hparams.forget_bias,
        dropout=hparams.dropout,
        mode=self.mode,
        dtype=dtype,
        single_cell_fn=self.single_cell_fn,
        use_block_lstm=hparams.use_block_lstm)

  def _build_bidi_rnn_base(self, inputs, sequence_length, hparams, dtype):
    """Create and call bidirectional RNN cells."""
    # num_residual_layers: Number of residual layers from top to bottom. For
    # example, if `num_bi_layers=4` and `num_residual_layers=2`, the last 2
    # RNN layers in each RNN cell will be wrapped with `ResidualWrapper`.

    # Construct forward and backward cells.
    fw_cell = self._build_encoder_cell(hparams,
                                       1,  # num_bi_layers,
                                       0,  # num_bi_residual_layers,
                                       dtype)
    bw_cell = self._build_encoder_cell(hparams,
                                       1,  # num_bi_layers,
                                       0,  # num_bi_residual_layers,
                                       dtype)

    if hparams.use_dynamic_rnn:
      bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
          fw_cell,
          bw_cell,
          inputs,
          dtype=dtype,
          sequence_length=sequence_length,
          time_major=self.time_major,
          swap_memory=True)
    else:
      bi_outputs, bi_state = tf.contrib.recurrent.bidirectional_functional_rnn(
          fw_cell,
          bw_cell,
          inputs,
          dtype=dtype,
          sequence_length=sequence_length,
          time_major=self.time_major,
          use_tpu=False)
    return tf.concat(bi_outputs, -1), bi_state

  def _build_bidi_rnn_cudnn(self, inputs, sequence_length, hparams, dtype):
    # Note that cudnn rnn dropout is applied between layers (so with only 1
    # layer no dropout is applied by the kernel).
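    # Two construction paths follow: by default the layer is assembled from
    # two single-layer unidirectional cudnn LSTMs (a forward pass plus a
    # backward pass over reversed inputs) whose outputs are concatenated;
    # with use_loose_bidi_cudnn_lstm a single CUDNN_RNN_BIDIRECTION cell is
    # used instead.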
    if not np.isclose(hparams.dropout, 0.):
      inputs = tf.nn.dropout(inputs, keep_prob=1 - hparams.dropout)

    if not hparams.use_loose_bidi_cudnn_lstm:
      fwd_outputs, fwd_states = self._build_unidi_rnn_cudnn(
          inputs,
          None,  # initial_state
          sequence_length,
          dtype,
          hparams,
          1,  # num_layer
          is_fwd=True)
      bak_outputs, bak_states = self._build_unidi_rnn_cudnn(
          inputs,
          None,  # initial_state
          sequence_length,
          dtype,
          hparams,
          1,  # num_layer
          is_fwd=False)
      bi_outputs = tf.concat([fwd_outputs, bak_outputs], axis=-1)
      return bi_outputs, (fwd_states, bak_states)
    else:
      # Cudnn only accepts time-major inputs.
      if not self.time_major:
        inputs = tf.transpose(inputs, perm=[1, 0, 2])
      bi_outputs, (bi_h, bi_c) = tf.contrib.cudnn_rnn.CudnnLSTM(
          num_layers=1,  # num_bi_layers,
          num_units=hparams.num_units,
          direction=cudnn_rnn.CUDNN_RNN_BIDIRECTION,
          dropout=0.,  # one layer, dropout isn't applied anyway,
          seed=hparams.random_seed,
          dtype=self.dtype,
          kernel_initializer=tf.get_variable_scope().initializer,
          bias_initializer=tf.zeros_initializer())(inputs)
      # state shape is [num_layers * num_dir, batch, dim]
      bi_h.set_shape((2, None, hparams.num_units))
      bi_c.set_shape((2, None, hparams.num_units))
      fwd_h, bak_h = tf.unstack(bi_h)
      fwd_c, bak_c = tf.unstack(bi_c)
      # No need to transpose back.
      return bi_outputs, (tf.nn.rnn_cell.LSTMStateTuple(c=fwd_c, h=fwd_h),
                          tf.nn.rnn_cell.LSTMStateTuple(c=bak_c, h=bak_h))

  def _build_encoder_layers_unidi(self, inputs, sequence_length,
                                  num_uni_layers, hparams, dtype):
    """Build encoder layers all at once."""
    encoder_outputs = None
    encoder_state = tuple()
    if hparams.use_fused_lstm:
      for i in range(num_uni_layers):
        if (not np.isclose(hparams.dropout, 0.)
            and self.mode == tf.contrib.learn.ModeKeys.TRAIN):
          cell_inputs = tf.nn.dropout(inputs, keep_prob=1 - hparams.dropout)
        else:
          cell_inputs = inputs
        cell = block_lstm.LSTMBlockFusedCell(
            hparams.num_units, hparams.forget_bias, dtype=dtype)
        encoder_outputs, (final_c, final_h) = cell(
            cell_inputs, dtype=dtype, sequence_length=sequence_length)
        encoder_state += (tf.nn.rnn_cell.LSTMStateTuple(final_c, final_h),)
        if i >= num_uni_layers - self.num_encoder_residual_layers:
          # Add the pre-dropout inputs. Residual wrapper is applied after
          # dropout wrapper.
          encoder_outputs += inputs
        inputs = encoder_outputs
    elif hparams.use_cudnn_lstm:
      # Single layer cudnn rnn, dropout isn't applied in the kernel.
      for i in range(num_uni_layers):
        if (not np.isclose(hparams.dropout, 0.)
            and self.mode == tf.contrib.learn.ModeKeys.TRAIN):
          inputs = tf.nn.dropout(inputs, keep_prob=1 - hparams.dropout)
        encoder_outputs, encoder_states = self._build_unidi_rnn_cudnn(
            inputs,
            None,  # initial_state
            sequence_length,
            dtype,
            hparams,
            1,  # num_layer
            is_fwd=True)
        encoder_state += (tf.nn.rnn_cell.LSTMStateTuple(encoder_states.c,
                                                        encoder_states.h),)
        if i >= num_uni_layers - self.num_encoder_residual_layers:
          encoder_outputs += inputs
        inputs = encoder_outputs
    else:
      uni_cell = model_helper.create_rnn_cell(
          unit_type=hparams.unit_type,
          num_units=hparams.num_units,
          num_layers=num_uni_layers,
          num_residual_layers=self.num_encoder_residual_layers,
          forget_bias=hparams.forget_bias,
          dropout=hparams.dropout,
          dtype=dtype,
          mode=self.mode,
          single_cell_fn=self.single_cell_fn,
          use_block_lstm=hparams.use_block_lstm)
      if hparams.use_dynamic_rnn:
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
            uni_cell,
            inputs,
            dtype=dtype,
            sequence_length=sequence_length,
            time_major=self.time_major)
      else:
        encoder_outputs, encoder_state = tf.contrib.recurrent.functional_rnn(
            uni_cell,
            inputs,
            dtype=dtype,
            sequence_length=sequence_length,
            time_major=self.time_major,
            use_tpu=False)

    return encoder_state, encoder_outputs

  def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                          source_sequence_length):
    """Build an RNN cell with GNMT attention architecture."""
    # GNMT attention
    assert self.is_gnmt_attention
    attention_option = hparams.attention
    attention_architecture = hparams.attention_architecture
    assert attention_option == "normed_bahdanau"
    assert attention_architecture == "gnmt_v2"

    num_units = hparams.num_units
    infer_mode = hparams.infer_mode
    dtype = tf.float16 if hparams.use_fp16 else tf.float32

    if self.time_major:
      memory = tf.transpose(encoder_outputs, [1, 0, 2])
    else:
      memory = encoder_outputs

    if (self.mode == tf.contrib.learn.ModeKeys.INFER
        and infer_mode == "beam_search"):
      memory, source_sequence_length, encoder_state, batch_size = (
          self._prepare_beam_search_decoder_inputs(
              hparams.beam_width, memory, source_sequence_length,
              encoder_state))
    else:
      batch_size = self.batch_size

    attention_mechanism = model.create_attention_mechanism(
        num_units, memory, source_sequence_length, dtype=dtype)

    cell_list = model_helper._cell_list(  # pylint: disable=protected-access
        unit_type=hparams.unit_type,
        num_units=num_units,
        num_layers=self.num_decoder_layers,
        num_residual_layers=self.num_decoder_residual_layers,
        forget_bias=hparams.forget_bias,
        dropout=hparams.dropout,
        mode=self.mode,
        dtype=dtype,
        single_cell_fn=self.single_cell_fn,
        residual_fn=gnmt_residual_fn,
        use_block_lstm=hparams.use_block_lstm)

    # Only wrap the bottom layer with the attention mechanism.
    attention_cell = cell_list.pop(0)

    # Only generate alignment in greedy INFER mode.
    alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER
                         and infer_mode != "beam_search")
    attention_cell = attention_wrapper.AttentionWrapper(
        attention_cell,
        attention_mechanism,
        attention_layer_size=None,  # don't use attention layer.
        output_attention=False,
        alignment_history=alignment_history,
        name="attention")
    cell = GNMTAttentionMultiCell(attention_cell, cell_list)

    if hparams.pass_hidden_state:
      decoder_initial_state = tuple(
          zs.clone(cell_state=es)
          if isinstance(zs, attention_wrapper.AttentionWrapperState) else es
          for zs, es in zip(
              cell.zero_state(batch_size, dtype), encoder_state))
    else:
      decoder_initial_state = cell.zero_state(batch_size, dtype)

    return cell, decoder_initial_state

  def _build_decoder_cudnn(self, encoder_outputs, encoder_state, hparams):
    pass
    """
    # Training
    # Use dynamic_rnn to compute the 1st layer outputs and attention

    # GNMT attention
    with tf.variable_scope("decoder") as decoder_scope:
      assert self.is_gnmt_attention
      attention_option = hparams.attention
      attention_architecture = hparams.attention_architecture
      assert attention_option == "normed_bahdanau"
      assert attention_architecture == "gnmt_v2"

      num_units = hparams.num_units
      infer_mode = hparams.infer_mode
      dtype = tf.float16 if hparams.use_fp16 else tf.float32

      if self.time_major:
        memory = tf.transpose(encoder_outputs, [1, 0, 2])
      else:
        memory = encoder_outputs

      source_sequence_length = self.features["source_sequence_length"]
      if (self.mode == tf.contrib.learn.ModeKeys.INFER
          and infer_mode == "beam_search"):
        memory, source_sequence_length, encoder_state, batch_size = (
            self._prepare_beam_search_decoder_inputs(
                hparams.beam_width, memory, source_sequence_length,
                encoder_state))
      else:
        batch_size = self.batch_size

      attention_mechanism = model.create_attention_mechanism(
          num_units, memory, source_sequence_length, dtype=dtype)

      attention_cell = model_helper._cell_list(
          unit_type=hparams.unit_type,
          num_units=num_units,
          num_layers=1,  # just one layer
          num_residual_layers=0,  # 1st layer has no residual connection.
          forget_bias=hparams.forget_bias,
          dropout=hparams.dropout,
          mode=self.mode,
          dtype=dtype,
          single_cell_fn=self.single_cell_fn,
          residual_fn=gnmt_residual_fn,
          use_block_lstm=False)[0]

      # Only generate alignment in greedy INFER mode.
      alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER
                           and infer_mode != "beam_search")
      attention_cell = attention_wrapper.AttentionWrapper(
          attention_cell,
          attention_mechanism,
          attention_layer_size=None,  # don't use attention layer.
          output_attention=False,
          alignment_history=alignment_history,
          name="attention")
      decoder_attention_cell_initial_state = attention_cell.zero_state(
          batch_size, dtype).clone(cell_state=encoder_state[0])

      # TODO(jamesqin): support frnn
      # [batch, time]
      target_input = self.features["target_input"]
      if self.time_major:
        # If using time_major mode, then target_input should be [time, batch],
        # and the decoder_emb_inp would be [time, batch, dim].
        target_input = tf.transpose(target_input)
      decoder_emb_inp = tf.cast(
          tf.nn.embedding_lookup(self.embedding_decoder, target_input),
          self.dtype)

      attention_cell_outputs, attention_cell_state = tf.nn.dynamic_rnn(
          attention_cell,
          decoder_emb_inp,
          sequence_length=self.features["target_sequence_length"],
          initial_state=decoder_attention_cell_initial_state,
          dtype=self.dtype,
          scope=decoder_scope,
          parallel_iterations=hparams.parallel_iterations,
          time_major=self.time_major)

      attention = None
      inputs = tf.concat([target_input, attention_cell_outputs], axis=-1)
      initial_state = encoder_state[1:]
      num_bi_layers = 1
      num_unidi_decoder_layers = self.num_decoder_layers - num_bi_layers
      # 3 layers of uni cudnn
      for i in range(num_unidi_decoder_layers):
        # Concat input with attention
        if (not np.isclose(hparams.dropout, 0.)
            and self.mode == tf.contrib.learn.ModeKeys.TRAIN):
          inputs = tf.nn.dropout(inputs, keep_prob=1 - hparams.dropout)
        outputs, states = self._build_unidi_rnn_cudnn(
            inputs,
            initial_state[i],
            self.features["target_sequence_length"],
            self.dtype,
            hparams,
            1,  # num_layer
            is_fwd=True)
        if i >= num_unidi_decoder_layers - self.num_decoder_residual_layers:
          outputs += inputs
        inputs = outputs
    pass
    """

  def _build_decoder_fused_for_training(self, encoder_outputs, initial_state,
                                        decoder_emb_inp, hparams):
    assert self.mode == tf.contrib.learn.ModeKeys.TRAIN
    num_bi_layers = 1
    num_unidi_decoder_layers = self.num_decoder_layers - num_bi_layers
    assert num_unidi_decoder_layers == 3

    # The 1st LSTM layer
    if self.time_major:
      batch = tf.shape(encoder_outputs)[1]
      tgt_max_len = tf.shape(decoder_emb_inp)[0]
      # [batch_size] -> scalar
      initial_attention = tf.zeros(
          shape=[tgt_max_len, batch, hparams.num_units], dtype=self.dtype)
    else:
      batch = tf.shape(encoder_outputs)[0]
      tgt_max_len = tf.shape(decoder_emb_inp)[1]
      initial_attention = tf.zeros(
          shape=[batch, tgt_max_len, hparams.num_units], dtype=self.dtype)

    # Concat with initial attention
    dec_inp = tf.concat([decoder_emb_inp, initial_attention], axis=-1)

    # [tgt_time, batch, units]
    # var_scope naming chosen to agree with inference graph.
    with tf.variable_scope("multi_rnn_cell/cell_0_attention/attention"):
      outputs, _ = self._build_unidi_rnn_fused(
          dec_inp, initial_state[0],
          self.features["target_sequence_length"], hparams, self.dtype)

    # Get attention
    # Fused attention layer has memory of shape [batch, src_time, ...]
    if self.time_major:
      memory = tf.transpose(encoder_outputs, [1, 0, 2])
    else:
      memory = encoder_outputs
    fused_attention_layer = attention_wrapper.BahdanauAttentionFusedLayer(
        hparams.num_units,
        memory,
        memory_sequence_length=self.features["source_sequence_length"],
        dtype=self.dtype)
    # [batch, tgt_time, units]
    if self.time_major:
      queries = tf.transpose(outputs, [1, 0, 2])
    else:
      queries = outputs
    fused_attention = fused_attention_layer(queries)

    if self.time_major:
      # [tgt_time, batch, units]
      fused_attention = tf.transpose(fused_attention, [1, 0, 2])

    # 2-4th layer
    inputs = outputs
    for i in range(num_unidi_decoder_layers):
      # [tgt_time, batch, 2 * units]
      concat_inputs = tf.concat([inputs, fused_attention], axis=-1)
      # var_scope naming chosen to agree with inference graph.
      with tf.variable_scope("multi_rnn_cell/cell_%d" % (i + 1)):
        outputs, _ = self._build_unidi_rnn_fused(
            concat_inputs, initial_state[i + 1],
            self.features["target_sequence_length"], hparams, self.dtype)
      if i >= num_unidi_decoder_layers - self.num_decoder_residual_layers:
        # gnmt_v2 attention adds the original inputs.
        outputs += inputs
      inputs = outputs
    return outputs


class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell):
  """A MultiCell with GNMT attention style."""

  def __init__(self, attention_cell, cells):
    """Creates a GNMTAttentionMultiCell.

    Args:
      attention_cell: An instance of AttentionWrapper.
      cells: A list of RNNCell wrapped with AttentionInputWrapper.
""" cells = [attention_cell] + cells super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True) def __call__(self, inputs, state, scope=None): """Run the cell with bottom layer's attention copied to all upper layers.""" if not tf.contrib.framework.nest.is_sequence(state): raise ValueError( "Expected state to be a tuple of length %d, but received: %s" % (len(self.state_size), state)) with tf.variable_scope(scope or "multi_rnn_cell"): new_states = [] with tf.variable_scope("cell_0_attention"): attention_cell = self._cells[0] attention_state = state[0] cur_inp, new_attention_state = attention_cell(inputs, attention_state) new_states.append(new_attention_state) for i in range(1, len(self._cells)): with tf.variable_scope("cell_%d" % i): cell = self._cells[i] cur_state = state[i] cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1) cur_inp, new_state = cell(cur_inp, cur_state) new_states.append(new_state) return cur_inp, tuple(new_states) def gnmt_residual_fn(inputs, outputs): """Residual function that handles different inputs and outputs inner dims. Args: inputs: cell inputs, this is actual inputs concatenated with the attention vector. outputs: cell outputs Returns: outputs + actual inputs """ def split_input(inp, out): inp_dim = inp.get_shape().as_list()[-1] out_dim = out.get_shape().as_list()[-1] return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1) actual_inputs, _ = tf.contrib.framework.nest.map_structure( split_input, inputs, outputs) def assert_shape_match(inp, out): inp.get_shape().assert_is_compatible_with(out.get_shape()) tf.contrib.framework.nest.assert_same_structure(actual_inputs, outputs) tf.contrib.framework.nest.map_structure( assert_shape_match, actual_inputs, outputs) return tf.contrib.framework.nest.map_structure( lambda inp, out: inp + out, actual_inputs, outputs)