# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A decoder that performs beam search."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np
import tensorflow as tf

import attention_wrapper

__all__ = [
    "BeamSearchDecoderOutput",
    "BeamSearchDecoderState",
    "BeamSearchDecoder",
    "FinalBeamSearchDecoderOutput",
    "tile_batch",
]


class BeamSearchDecoderState(
    collections.namedtuple("BeamSearchDecoderState",
                           ("cell_state", "log_probs", "finished", "lengths",
                            "accumulated_attention_probs"))):
  pass


class BeamSearchDecoderOutput(
    collections.namedtuple("BeamSearchDecoderOutput",
                           ("scores", "predicted_ids", "parent_ids"))):
  pass


class FinalBeamSearchDecoderOutput(
    collections.namedtuple("FinalBeamDecoderOutput",
                           ["predicted_ids", "beam_search_decoder_output"])):
  """Final outputs returned by the beam search after all decoding is finished.

  Args:
    predicted_ids: The final prediction. A tensor of shape
      `[batch_size, T, beam_width]` (or `[T, batch_size, beam_width]` if
      `output_time_major` is True). Beams are ordered from best to worst.
    beam_search_decoder_output: An instance of `BeamSearchDecoderOutput` that
      describes the state of the beam search.
  """
  pass


def _tile_batch(t, multiplier):
  """Core single-tensor implementation of tile_batch."""
  t = tf.convert_to_tensor(t, name="t")
  shape_t = tf.shape(t)
  if t.shape.ndims is None or t.shape.ndims < 1:
    raise ValueError("t must have statically known rank")
  tiling = [1] * (t.shape.ndims + 1)
  tiling[1] = multiplier
  tiled_static_batch_size = (
      t.shape[0].value * multiplier if t.shape[0].value is not None else None)
  tiled = tf.tile(tf.expand_dims(t, 1), tiling)
  tiled = tf.reshape(
      tiled, tf.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
  tiled.set_shape(
      tf.TensorShape([tiled_static_batch_size]).concatenate(
          t.shape[1:]))
  return tiled


def tile_batch(t, multiplier, name=None):
  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.

  For each tensor t in a (possibly nested structure) of tensors,
  this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
  minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
  `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
  `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
  `multiplier` times.

  Args:
    t: `Tensor` shaped `[batch_size, ...]`.
    multiplier: Python int.
    name: Name scope for any created operations.

  Returns:
    A (possibly nested structure of) `Tensor` shaped
    `[batch_size * multiplier, ...]`.

  Raises:
    ValueError: if tensor(s) `t` do not have a statically known rank or
      the rank is < 1.
  """
  flat_t = tf.contrib.framework.nest.flatten(t)
  with tf.name_scope(name, "tile_batch", flat_t + [multiplier]):
    return tf.contrib.framework.nest.map_structure(
        lambda t_: _tile_batch(t_, multiplier), t)
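

# Illustrative usage sketch (added for exposition; the constant below is
# hypothetical): `tile_batch` repeats each minibatch entry in place, which is
# the layout the beam search expects, whereas `tf.tile` would append whole
# copies of the batch instead.
def _tile_batch_example():
  x = tf.constant([[1, 2], [3, 4]])  # shape [2, 2]
  tiled = tile_batch(x, multiplier=3)  # shape [6, 2]
  # tiled evaluates to
  # [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]],
  # i.e. entries are ordered t[0], t[0], t[0], t[1], t[1], t[1].
  return tiled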


def gather_tree_from_array(t, parent_ids, sequence_length):
  """Calculates the full beams for `TensorArray`s.

  Args:
    t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
      shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
      where `s` is the depth shape.
    parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
    sequence_length: The sequence length of shape `[batch_size, beam_width]`.

  Returns:
    A `Tensor` which is a stacked `TensorArray` of the same size and type as
    `t` and where beams are sorted in each `Tensor` according to `parent_ids`.
  """
  max_time = parent_ids.shape[0].value or tf.shape(parent_ids)[0]
  batch_size = parent_ids.shape[1].value or tf.shape(parent_ids)[1]
  beam_width = parent_ids.shape[2].value or tf.shape(parent_ids)[2]

  # Generate beam ids that will be reordered by gather_tree.
  beam_ids = tf.expand_dims(
      tf.expand_dims(tf.range(beam_width), 0), 0)
  beam_ids = tf.tile(beam_ids, [max_time, batch_size, 1])

  max_sequence_lengths = tf.to_int32(tf.reduce_max(sequence_length, axis=1))
  sorted_beam_ids = tf.contrib.seq2seq.gather_tree(
      step_ids=beam_ids,
      parent_ids=parent_ids,
      max_sequence_lengths=max_sequence_lengths,
      end_token=beam_width + 1)

  # For out of range steps, simply copy the same beam.
  in_bound_steps = tf.transpose(
      tf.sequence_mask(sequence_length, maxlen=max_time),
      perm=[2, 0, 1])
  sorted_beam_ids = tf.where(
      in_bound_steps, x=sorted_beam_ids, y=beam_ids)

  # Generate indices for gather_nd.
  time_ind = tf.tile(tf.reshape(
      tf.range(max_time), [-1, 1, 1]), [1, batch_size, beam_width])
  batch_ind = tf.tile(tf.reshape(
      tf.range(batch_size), [-1, 1, 1]), [1, max_time, beam_width])
  batch_ind = tf.transpose(batch_ind, perm=[1, 0, 2])
  indices = tf.stack([time_ind, batch_ind, sorted_beam_ids], -1)

  # Gather from a tensor with collapsed additional dimensions.
  gather_from = t
  final_shape = tf.shape(gather_from)
  gather_from = tf.reshape(
      gather_from, [max_time, batch_size, beam_width, -1])
  ordered = tf.gather_nd(gather_from, indices)
  ordered = tf.reshape(ordered, final_shape)

  return ordered
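

# Illustrative sketch (hypothetical numbers, added for exposition): the
# backtracking that `gather_tree` performs. For one batch entry with
# beam_width=2, `parent_ids[t][b]` names the beam slot at step t-1 that the
# hypothesis in slot b at step t continued from (row 0 has no parent and is
# ignored here).
def _parent_backtrack_example():
  parent_ids = np.array([[0, 0],
                         [1, 0],
                         [0, 1]])  # [max_time, beam_width]
  beam = 0  # trace the best final beam backwards
  path = []
  for t in range(parent_ids.shape[0] - 1, 0, -1):
    path.append(beam)
    beam = parent_ids[t][beam]
  path.append(beam)
  path.reverse()
  # path == [1, 0, 0]: the hypothesis ending in beam 0 sat in slot 1 at
  # step 0, then slot 0 at steps 1 and 2.
  return path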


def _check_maybe(t):
  if t.shape.ndims is None:
    raise ValueError(
        "Expected tensor (%s) to have known rank, but ndims == None." % t)


def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
  """Logs a warning and returns False if dimensions are known statically and
  cannot be reshaped to [batch_size, beam_size, -1]; returns True otherwise.
  """
  reshaped_shape = tf.TensorShape([batch_size, beam_width, None])
  if (batch_size is not None and shape[0].value is not None
      and (shape[0] != batch_size * beam_width
           or (shape.ndims >= 2 and shape[1].value is not None
               and (shape[0] != batch_size or shape[1] != beam_width)))):
    tf.logging.warn("TensorArray reordering expects elements to be "
                    "reshapable to %s which is incompatible with the "
                    "current shape %s. Consider setting "
                    "reorder_tensor_arrays to False to disable TensorArray "
                    "reordering during the beam search."
                    % (reshaped_shape, shape))
    return False
  return True


def _check_batch_beam(t, batch_size, beam_width):
  """Returns an Assert operation checking that the elements of the stacked
  TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
  the TensorArray elements have a known rank of at least 1.
  """
  error_message = ("TensorArray reordering expects elements to be "
                   "reshapable to [batch_size, beam_size, -1] which is "
                   "incompatible with the dynamic shape of %s elements. "
                   "Consider setting reorder_tensor_arrays to False to disable "
                   "TensorArray reordering during the beam search."
                   % (t.name))
  rank = t.shape.ndims
  shape = tf.shape(t)
  if rank == 2:
    condition = tf.equal(shape[1], batch_size * beam_width)
  else:
    condition = tf.logical_or(
        tf.equal(shape[1], batch_size * beam_width),
        tf.logical_and(
            tf.equal(shape[1], batch_size),
            tf.equal(shape[2], beam_width)))
  return tf.Assert(condition, [error_message])


class BeamSearchDecoder(tf.contrib.seq2seq.Decoder):
  """BeamSearch sampling decoder.

  **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
  `AttentionWrapper`, then you must ensure that:

  - The encoder output has been tiled to `beam_width` via
    `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
  - The `batch_size` argument passed to the `zero_state` method of this
    wrapper is equal to `true_batch_size * beam_width`.
  - The initial state created with `zero_state` above contains a
    `cell_state` value containing properly tiled final state from the
    encoder.

  An example:

  ```
  tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs, multiplier=beam_width)
  tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
      encoder_final_state, multiplier=beam_width)
  tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
      sequence_length, multiplier=beam_width)
  attention_mechanism = MyFavoriteAttentionMechanism(
      num_units=attention_depth,
      memory=tiled_encoder_outputs,
      memory_sequence_length=tiled_sequence_length)
  attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
  decoder_initial_state = attention_cell.zero_state(
      dtype, batch_size=true_batch_size * beam_width)
  decoder_initial_state = decoder_initial_state.clone(
      cell_state=tiled_encoder_final_state)
  ```

  Meanwhile, when using `AttentionWrapper`, it is recommended to enable the
  coverage penalty when computing scores
  (https://arxiv.org/pdf/1609.08144.pdf). The coverage penalty encourages the
  translation to cover all inputs.
  """

  def __init__(self,
               cell,
               embedding,
               start_tokens,
               end_token,
               initial_state,
               beam_width,
               output_layer=None,
               length_penalty_weight=0.0,
               coverage_penalty_weight=0.0,
               reorder_tensor_arrays=True):
    """Initialize the BeamSearchDecoder.

    Args:
      cell: An `RNNCell` instance.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
      beam_width: Python integer, the number of beams.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`. Optional layer to apply to the RNN output prior
        to storing the result or sampling.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
        state will be reordered according to the beam search path. If the
        `TensorArray` can be reordered, the stacked form will be returned.
        Otherwise, the `TensorArray` will be returned as is. Set this flag to
        `False` if the cell state contains `TensorArray`s that are not amenable
        to reordering.

    Raises:
      TypeError: if `output_layer` is not an instance of `tf.layers.Layer`.
      ValueError: If `start_tokens` is not a vector or
        `end_token` is not a scalar.
    """
    if (output_layer is not None and
        not isinstance(output_layer, tf.layers.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._output_layer = output_layer
    self._reorder_tensor_arrays = reorder_tensor_arrays

    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: tf.nn.embedding_lookup(embedding, ids))

    self._start_tokens = tf.convert_to_tensor(
        start_tokens, dtype=tf.int32, name="start_tokens")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._end_token = tf.convert_to_tensor(
        end_token, dtype=tf.int32, name="end_token")
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")

    self._batch_size = tf.size(start_tokens)
    self._beam_width = beam_width
    self._length_penalty_weight = length_penalty_weight
    self._coverage_penalty_weight = coverage_penalty_weight
    self._initial_cell_state = tf.contrib.framework.nest.map_structure(
        self._maybe_split_batch_beams, initial_state, self._cell.state_size)
    self._start_tokens = tf.tile(
        tf.expand_dims(self._start_tokens, 1), [1, self._beam_width])
    self._start_inputs = self._embedding_fn(self._start_tokens)

    self._finished = tf.one_hot(
        tf.zeros([self._batch_size], dtype=tf.int32),
        depth=self._beam_width,
        on_value=False,
        off_value=True,
        dtype=tf.bool)

  @property
  def batch_size(self):
    return self._batch_size

  def _rnn_output_size(self):
    size = self._cell.output_size
    if self._output_layer is None:
      return size
    else:
      # To use layer's compute_output_shape, we need to convert the
      # RNNCell's output_size entries into shapes with an unknown
      # batch size. We then pass this through the layer's
      # compute_output_shape and read off all but the first (batch)
      # dimensions to get the output size of the rnn with the layer
      # applied to the top.
      output_shape_with_unknown_batch = tf.contrib.framework.nest.map_structure(
          lambda s: tf.TensorShape([None]).concatenate(s), size)
      layer_output_shape = self._output_layer.compute_output_shape(
          output_shape_with_unknown_batch)
      return tf.contrib.framework.nest.map_structure(
          lambda s: s[1:], layer_output_shape)

  @property
  def tracks_own_finished(self):
    """The BeamSearchDecoder shuffles its beams and their finished state.

    For this reason, it conflicts with the `dynamic_decode` function's
    tracking of finished states. Setting this property to true avoids
    early stopping of decoding due to mismanagement of the finished state
    in `dynamic_decode`.

    Returns:
      `True`.
    """
    return True

  @property
  def output_size(self):
    # Return the cell output and the id
    return BeamSearchDecoderOutput(
        scores=tf.TensorShape([self._beam_width]),
        predicted_ids=tf.TensorShape([self._beam_width]),
        parent_ids=tf.TensorShape([self._beam_width]))

  @property
  def output_dtype(self):
    # Assume the dtype of the cell is the output_size structure
    # containing the input_state's first component's dtype.
    # Return that structure and int32 (the id)
    dtype = tf.contrib.framework.nest.flatten(self._initial_cell_state)[0].dtype
    return BeamSearchDecoderOutput(
        scores=tf.contrib.framework.nest.map_structure(
            lambda _: dtype, self._rnn_output_size()),
        predicted_ids=tf.int32,
        parent_ids=tf.int32)

  def initialize(self, name=None):
    """Initialize the decoder.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, start_inputs, initial_state)`.
    """
    finished, start_inputs = self._finished, self._start_inputs

    dtype = tf.contrib.framework.nest.flatten(self._initial_cell_state)[0].dtype
    log_probs = tf.one_hot(  # shape(batch_sz, beam_sz)
        tf.zeros([self._batch_size], dtype=tf.int32),
        depth=self._beam_width,
        on_value=tf.convert_to_tensor(0.0, dtype=dtype),
        off_value=tf.convert_to_tensor(-np.Inf, dtype=dtype),
        dtype=dtype)
    init_attention_probs = get_attention_probs(
        self._initial_cell_state, self._coverage_penalty_weight)
    if init_attention_probs is None:
      init_attention_probs = ()

    initial_state = BeamSearchDecoderState(
        cell_state=self._initial_cell_state,
        log_probs=log_probs,
        finished=finished,
        lengths=tf.zeros(
            [self._batch_size, self._beam_width], dtype=tf.int64),
        accumulated_attention_probs=init_attention_probs)

    return (finished, start_inputs, initial_state)

  def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState. Passed through to the
        output.
      sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.
        **NOTE** These are ignored; the updated sequence lengths are stored in
        `final_state.lengths`.

    Returns:
      outputs: An instance of `FinalBeamSearchDecoderOutput` where the
        predicted_ids are the result of calling `gather_tree`.
      final_state: The same input instance of `BeamSearchDecoderState`.
    """
    del sequence_lengths
    # Get max_sequence_length across all beams for each batch.
    max_sequence_lengths = tf.to_int32(
        tf.reduce_max(final_state.lengths, axis=1))
    predicted_ids = tf.contrib.seq2seq.gather_tree(
        outputs.predicted_ids,
        outputs.parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=self._end_token)
    if self._reorder_tensor_arrays:
      final_state = final_state._replace(
          cell_state=tf.contrib.framework.nest.map_structure(
              lambda t: self._maybe_sort_array_beams(
                  t, outputs.parent_ids, final_state.lengths),
              final_state.cell_state))
    outputs = FinalBeamSearchDecoderOutput(
        beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
    return outputs, final_state

  def _merge_batch_beams(self, t, s=None):
    """Merges the tensor from a batch of beams into a batch by beams.

    More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We
    reshape this into [batch_size*beam_width, s].

    Args:
      t: Tensor of dimension [batch_size, beam_width, s].
      s: (Possibly known) depth shape.

    Returns:
      A reshaped version of t with dimension [batch_size * beam_width, s].
    """
    if isinstance(s, tf.Tensor):
      s = tf.contrib.util.constant_value(s)
      if not isinstance(s, tf.TensorShape):
        s = tf.TensorShape(s)
    else:
      s = tf.TensorShape(s)
    t_shape = tf.shape(t)
    static_batch_size = tf.contrib.util.constant_value(self._batch_size)
    batch_size_beam_width = (
        None
        if static_batch_size is None else static_batch_size * self._beam_width)
    reshaped_t = tf.reshape(
        t,
        tf.concat(([self._batch_size * self._beam_width], t_shape[2:]), 0))
    reshaped_t.set_shape(
        (tf.TensorShape([batch_size_beam_width]).concatenate(s)))
    return reshaped_t

  def _split_batch_beams(self, t, s=None):
    """Splits the tensor from a batch by beams into a batch of beams.

    More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
    reshape this into [batch_size, beam_width, s].

    Args:
      t: Tensor of dimension [batch_size*beam_width, s].
      s: (Possibly known) depth shape.

    Returns:
      A reshaped version of t with dimension [batch_size, beam_width, s].

    Raises:
      ValueError: If, after reshaping, the new tensor is not shaped
        `[batch_size, beam_width, s]` (assuming batch_size and beam_width
        are known statically).
    """
    if isinstance(s, tf.Tensor):
      s = tf.TensorShape(tf.contrib.util.constant_value(s))
    else:
      s = tf.TensorShape(s)
    t_shape = tf.shape(t)
    reshaped_t = tf.reshape(
        t,
        tf.concat(([self._batch_size, self._beam_width], t_shape[1:]), 0))
    static_batch_size = tf.contrib.util.constant_value(self._batch_size)
    expected_reshaped_shape = tf.TensorShape(
        [static_batch_size, self._beam_width]).concatenate(s)
    if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape):
      raise ValueError("Unexpected behavior when reshaping between beam width "
                       "and batch size. The reshaped tensor has shape: %s. "
                       "We expected it to have shape "
                       "(batch_size, beam_width, depth) == %s. Perhaps you "
                       "forgot to create a zero_state with "
                       "batch_size=encoder_batch_size * beam_width?" %
                       (reshaped_t.shape, expected_reshaped_shape))
    reshaped_t.set_shape(expected_reshaped_shape)
    return reshaped_t

  def _maybe_split_batch_beams(self, t, s):
    """Maybe splits the tensor from a batch by beams into a batch of beams.

    We do this so that we can use nest and not run into problems with shapes.

    Args:
      t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`.
      s: `Tensor`, Python int, or `TensorShape`.

    Returns:
      If `t` is a matrix or higher order tensor, then the return value is
      `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is
      returned unchanged.

    Raises:
      ValueError: If the rank of `t` is not statically known.
    """
    if isinstance(t, tf.TensorArray):
      return t
    _check_maybe(t)
    if t.shape.ndims >= 1:
      return self._split_batch_beams(t, s)
    else:
      return t

  def _maybe_merge_batch_beams(self, t, s):
    """Maybe merges the tensor from a batch of beams into a batch by beams.

    More exactly, if `t` is a tensor of dimension `[batch_size, beam_width]
    + s`, then we reshape it to `[batch_size * beam_width] + s`.

    Args:
      t: `Tensor` of dimension `[batch_size, beam_width] + s`.
      s: `Tensor`, Python int, or `TensorShape`.

    Returns:
      A reshaped version of t with shape `[batch_size * beam_width] + s`.

    Raises:
      ValueError: If the rank of `t` is not statically known.
    """
    if isinstance(t, tf.TensorArray):
      return t
    _check_maybe(t)
    if t.shape.ndims >= 2:
      return self._merge_batch_beams(t, s)
    else:
      return t

  def _maybe_sort_array_beams(self, t, parent_ids, sequence_length):
    """Maybe sorts beams within a `TensorArray`.

    Args:
      t: A `TensorArray` of size `max_time` that contains `Tensor`s of shape
        `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]` where
        `s` is the depth shape.
      parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
      sequence_length: The sequence length of shape `[batch_size, beam_width]`.

    Returns:
      A `TensorArray` where beams are sorted in each `Tensor` or `t` itself if
      it is not a `TensorArray` or does not meet shape requirements.
    """
    if not isinstance(t, tf.TensorArray):
      return t
    # pylint: disable=protected-access
    if (not t._infer_shape or not t._element_shape
        or t._element_shape[0].ndims is None
        or t._element_shape[0].ndims < 1):
      shape = (
          t._element_shape[0] if t._infer_shape and t._element_shape
          else tf.TensorShape(None))
      tf.logging.warn("The TensorArray %s in the cell state is not amenable to "
                      "sorting based on the beam search result. For a "
                      "TensorArray to be sorted, its elements shape must be "
                      "defined and have at least a rank of 1, but saw shape: %s"
                      % (t.handle.name, shape))
      return t
    shape = t._element_shape[0]
    # pylint: enable=protected-access
    if not _check_static_batch_beam_maybe(
        shape, tf.contrib.util.constant_value(self._batch_size),
        self._beam_width):
      return t
    t = t.stack()
    with tf.control_dependencies(
        [_check_batch_beam(t, self._batch_size, self._beam_width)]):
      return gather_tree_from_array(t, parent_ids, sequence_length)

  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    batch_size = self._batch_size
    beam_width = self._beam_width
    end_token = self._end_token
    length_penalty_weight = self._length_penalty_weight
    coverage_penalty_weight = self._coverage_penalty_weight

    with tf.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
      cell_state = state.cell_state
      inputs = tf.contrib.framework.nest.map_structure(
          lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
      cell_state = tf.contrib.framework.nest.map_structure(
          self._maybe_merge_batch_beams, cell_state, self._cell.state_size)
      cell_outputs, next_cell_state = self._cell(inputs, cell_state)
      cell_outputs = tf.contrib.framework.nest.map_structure(
          lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
      next_cell_state = tf.contrib.framework.nest.map_structure(
          self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)

      beam_search_output, beam_search_state = _beam_search_step(
          time=time,
          logits=cell_outputs,
          next_cell_state=next_cell_state,
          beam_state=state,
          batch_size=batch_size,
          beam_width=beam_width,
          end_token=end_token,
          length_penalty_weight=length_penalty_weight,
          coverage_penalty_weight=coverage_penalty_weight)

      finished = beam_search_state.finished
      sample_ids = beam_search_output.predicted_ids
      next_inputs = tf.cond(
          tf.reduce_all(finished), lambda: self._start_inputs,
          lambda: self._embedding_fn(sample_ids))

    return (beam_search_output, beam_search_state, next_inputs, finished)
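

# Illustrative usage sketch (added for exposition): wiring the decoder into
# `tf.contrib.seq2seq.dynamic_decode`. The arguments (`decoder_cell`,
# `embedding_matrix`, `tiled_initial_state`, etc.) are assumed to be built by
# the caller; in particular `tiled_initial_state` must already be tiled to
# `batch_size * beam_width`, as described in the class docstring. The penalty
# weight and iteration cap below are hypothetical values.
def _beam_search_decoder_usage_sketch(decoder_cell, embedding_matrix,
                                      tiled_initial_state, start_token_id,
                                      end_token_id, batch_size, beam_width):
  decoder = BeamSearchDecoder(
      cell=decoder_cell,
      embedding=embedding_matrix,
      start_tokens=tf.fill([batch_size], start_token_id),
      end_token=end_token_id,
      initial_state=tiled_initial_state,
      beam_width=beam_width,
      length_penalty_weight=0.6)
  # `outputs` is a FinalBeamSearchDecoderOutput; predicted_ids has shape
  # [batch_size, T, beam_width] with beams ordered from best to worst.
  outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
      decoder, maximum_iterations=100)
  return outputs.predicted_ids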


def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight,
                      coverage_penalty_weight):
  """Performs a single step of Beam Search Decoding.

  Args:
    time: Beam search time step, should start at 0. At time 0 we assume
      that all beams are equal and consider only the first beam for
      continuations.
    logits: Logits at the current time step. A tensor of shape
      `[batch_size, beam_width, vocab_size]`.
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    beam_state: Current state of the beam search.
      An instance of `BeamSearchDecoderState`.
    batch_size: The batch size for this input.
    beam_width: Python int. The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.

  Returns:
    A tuple `(output, next_state)`, where `output` is a
    `BeamSearchDecoderOutput` for this step and `next_state` is the new
    `BeamSearchDecoderState`.
  """
  static_batch_size = tf.contrib.util.constant_value(batch_size)

  # Calculate the current lengths of the predictions
  prediction_lengths = beam_state.lengths
  previously_finished = beam_state.finished
  not_finished = tf.logical_not(previously_finished)

  # Calculate the total log probs for the new hypotheses
  # Final Shape: [batch_size, beam_width, vocab_size]
  step_log_probs = tf.nn.log_softmax(logits)
  step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
  total_probs = tf.expand_dims(beam_state.log_probs, 2) + step_log_probs

  # Calculate the continuation lengths by adding to all continuing beams.
  vocab_size = logits.shape[-1].value or tf.shape(logits)[-1]
  lengths_to_add = tf.one_hot(
      indices=tf.fill([batch_size, beam_width], end_token),
      depth=vocab_size,
      on_value=np.int64(0),
      off_value=np.int64(1),
      dtype=tf.int64)
  add_mask = tf.to_int64(not_finished)
  lengths_to_add *= tf.expand_dims(add_mask, 2)
  new_prediction_lengths = (
      lengths_to_add + tf.expand_dims(prediction_lengths, 2))

  # Calculate the accumulated attention probabilities if coverage penalty is
  # enabled.
  accumulated_attention_probs = None
  attention_probs = get_attention_probs(
      next_cell_state, coverage_penalty_weight)
  if attention_probs is not None:
    attention_probs *= tf.expand_dims(tf.to_float(not_finished), 2)
    accumulated_attention_probs = (
        beam_state.accumulated_attention_probs + attention_probs)

  batch_finished = tf.reduce_all(
      previously_finished, axis=1, keepdims=True)
  any_batch_finished = tf.reduce_any(batch_finished)
  batch_finished = tf.tile(tf.expand_dims(batch_finished, 2),
                           [1, beam_width, vocab_size])
  def _normalized_scores():
    return _get_scores(
        log_probs=total_probs,
        sequence_lengths=new_prediction_lengths,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
        finished=batch_finished,
        accumulated_attention_probs=accumulated_attention_probs)

  # Normalize the scores of finished batches.
  scores = tf.cond(any_batch_finished, _normalized_scores, lambda: total_probs)

  time = tf.convert_to_tensor(time, name="time")
  # During the first time step we only consider the initial beam
  scores_flat = tf.reshape(scores, [batch_size, -1])

  # Pick the next beams according to the specified successors function
  next_beam_size = tf.convert_to_tensor(
      beam_width, dtype=tf.int32, name="beam_width")
  next_beam_scores, word_indices = tf.nn.top_k(scores_flat, k=next_beam_size)

  next_beam_scores.set_shape([static_batch_size, beam_width])
  word_indices.set_shape([static_batch_size, beam_width])

  # Pick out the probs, beam_ids, and states according to the chosen predictions
  next_beam_probs = _tensor_gather_helper(
      gather_indices=word_indices,
      gather_from=total_probs,
      batch_size=batch_size,
      range_size=beam_width * vocab_size,
      gather_shape=[-1],
      name="next_beam_probs")
  # Note: just doing the following
  #   tf.to_int32(word_indices % vocab_size,
  #       name="next_beam_word_ids")
  # would be a lot cleaner but for reasons unclear, that hides the results of
  # the op which prevents capturing it with tfdbg debug ops.
  raw_next_word_ids = tf.mod(
      word_indices, vocab_size, name="next_beam_word_ids")
  next_word_ids = tf.to_int32(raw_next_word_ids)
  next_beam_ids = tf.to_int32(
      word_indices / vocab_size, name="next_beam_parent_ids")

  # Append new ids to current predictions
  previously_finished = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=previously_finished,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_finished = tf.logical_or(
      previously_finished,
      tf.equal(next_word_ids, end_token),
      name="next_beam_finished")

  # Calculate the length of the next predictions.
  # 1. Finished beams remain unchanged.
  # 2. Beams that are now finished (EOS predicted) have their length
  #    increased by 1.
  # 3. Beams that are not yet finished have their length increased by 1.
  lengths_to_add = tf.to_int64(tf.logical_not(previously_finished))
  next_prediction_len = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=beam_state.lengths,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_prediction_len += lengths_to_add
  next_accumulated_attention_probs = ()
  if accumulated_attention_probs is not None:
    next_accumulated_attention_probs = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=accumulated_attention_probs,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[batch_size * beam_width, -1],
        name="next_accumulated_attention_probs")

  # Pick out the cell_states according to the next_beam_ids. We use a
  # different gather_shape here because the cell_state tensors, i.e.
  # the tensors that would be gathered from, all have dimension
  # greater than two and we need to preserve those dimensions.
  # pylint: disable=g-long-lambda
  next_cell_state = tf.contrib.framework.nest.map_structure(
      lambda gather_from: _maybe_tensor_gather_helper(
          gather_indices=next_beam_ids,
          gather_from=gather_from,
          batch_size=batch_size,
          range_size=beam_width,
          gather_shape=[batch_size * beam_width, -1]),
      next_cell_state)
  # pylint: enable=g-long-lambda

  next_state = BeamSearchDecoderState(
      cell_state=next_cell_state,
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished,
      accumulated_attention_probs=next_accumulated_attention_probs)

  output = BeamSearchDecoderOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      parent_ids=next_beam_ids)

  return output, next_state
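

# Illustrative sketch (hypothetical numbers, added for exposition): the top-k
# above runs over the flattened `beam_width * vocab_size` axis, so each chosen
# index encodes both a parent beam and a word id, recovered with div and mod.
def _flat_index_decomposition_example():
  vocab_size = 10
  word_indices = np.array([23, 7])              # two hypothetical flat indices
  parent_beam_ids = word_indices // vocab_size  # [2, 0]
  word_ids = word_indices % vocab_size          # [3, 7]
  # Flat index 23 continues beam 2 with word id 3; flat index 7 continues
  # beam 0 with word id 7.
  return parent_beam_ids, word_ids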


def get_attention_probs(next_cell_state, coverage_penalty_weight):
  """Get attention probabilities from the cell state.

  Args:
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.

  Returns:
    The attention probabilities with shape `[batch_size, beam_width, max_time]`
    if coverage penalty is enabled. Otherwise, returns None.

  Raises:
    ValueError: If no cell is attentional but coverage penalty is enabled.
  """
  if coverage_penalty_weight == 0.0:
    return None

  # Attention probabilities of each attention layer. Each with shape
  # `[batch_size, beam_width, max_time]`.
  probs_per_attn_layer = []
  if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
    probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
  elif isinstance(next_cell_state, tuple):
    for state in next_cell_state:
      if isinstance(state, attention_wrapper.AttentionWrapperState):
        probs_per_attn_layer.append(attention_probs_from_attn_state(state))

  if not probs_per_attn_layer:
    raise ValueError(
        "coverage_penalty_weight must be 0.0 if no cell is attentional.")

  if len(probs_per_attn_layer) == 1:
    attention_probs = probs_per_attn_layer[0]
  else:
    # Calculate the average attention probabilities from all attention layers.
    attention_probs = [
        tf.expand_dims(prob, -1) for prob in probs_per_attn_layer]
    attention_probs = tf.concat(attention_probs, -1)
    attention_probs = tf.reduce_mean(attention_probs, -1)

  return attention_probs


def _get_scores(log_probs, sequence_lengths, length_penalty_weight,
                coverage_penalty_weight, finished, accumulated_attention_probs):
  """Calculates scores for beam search hypotheses.

  Args:
    log_probs: The log probabilities with shape
      `[batch_size, beam_width, vocab_size]`.
    sequence_lengths: The array of sequence lengths.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
    coverage_penalty_weight: Float weight to penalize the coverage of source
      sentence. Disabled with 0.0.
    finished: A boolean tensor of shape `[batch_size, beam_width, vocab_size]`
      that specifies which elements in the beam are finished already.
    accumulated_attention_probs: Accumulated attention probabilities up to the
      current time step, with shape `[batch_size, beam_width, max_time]` if
      coverage_penalty_weight is not 0.0.

  Returns:
    The scores normalized by the length_penalty and coverage_penalty.

  Raises:
    ValueError: If accumulated_attention_probs is None when coverage penalty
      is enabled.
  """
  length_penalty_ = _length_penalty(
      sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)

  coverage_penalty_weight = tf.convert_to_tensor(
      coverage_penalty_weight, name="coverage_penalty_weight")
  if coverage_penalty_weight.shape.ndims != 0:
    raise ValueError("coverage_penalty_weight should be a scalar, "
                     "but saw shape: %s" % coverage_penalty_weight.shape)

  if tf.contrib.util.constant_value(coverage_penalty_weight) == 0.0:
    # Coverage penalty is disabled: only normalize finished predictions by
    # the length penalty.
    return tf.where(finished, log_probs / length_penalty_, log_probs)

  if accumulated_attention_probs is None:
    raise ValueError(
        "accumulated_attention_probs can be None only if coverage penalty is "
        "disabled.")

  # Add source sequence length mask before computing coverage penalty.
  accumulated_attention_probs = tf.where(
      tf.equal(accumulated_attention_probs, 0.0),
      tf.ones_like(accumulated_attention_probs),
      accumulated_attention_probs)

  # coverage penalty =
  #     sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
  coverage_penalty = tf.reduce_sum(
      tf.log(tf.minimum(accumulated_attention_probs, 1.0)), 2)
  # Apply coverage penalty to finished predictions.
  weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
  # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
  weighted_coverage_penalty = tf.expand_dims(
      weighted_coverage_penalty, 2)

  # Normalize the scores of finished predictions.
  return tf.where(
      finished, log_probs / length_penalty_ + weighted_coverage_penalty,
      log_probs)
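

# Worked example (hypothetical numbers, added for exposition): for a finished
# hypothesis, `_get_scores` computes
#     log_prob / length_penalty + coverage_penalty_weight
#         * sum_t(log(min(accumulated_attention_t, 1.0)))
# so attention mass below 1.0 on any source position lowers the score.
def _gnmt_score_example():
  log_prob = -4.2                             # summed log probability
  length_penalty = ((5. + 11.) / 6.) ** 0.6   # ~1.8012 for length 11
  attention_mass = np.array([0.9, 1.3, 0.4])  # accumulated per source step
  coverage_penalty = np.sum(np.log(np.minimum(attention_mass, 1.0)))
  return log_prob / length_penalty + 0.1 * coverage_penalty  # ~-2.43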


def attention_probs_from_attn_state(attention_state):
  """Calculates the average attention probabilities.

  Args:
    attention_state: An instance of `AttentionWrapperState`.

  Returns:
    The attention probabilities in the given AttentionWrapperState.
    If there are multiple attention mechanisms, return the average value from
    all attention mechanisms.
  """
  # Attention probabilities over time steps, with shape
  # `[batch_size, beam_width, max_time]`.
  attention_probs = attention_state.alignments
  if isinstance(attention_probs, tuple):
    attention_probs = [
        tf.expand_dims(prob, -1) for prob in attention_probs]
    attention_probs = tf.concat(attention_probs, -1)
    attention_probs = tf.reduce_mean(attention_probs, -1)
  return attention_probs


def _length_penalty(sequence_lengths, penalty_factor):
  """Calculates the length penalty. See https://arxiv.org/abs/1609.08144.

  Returns the length penalty tensor:
  ```
  [(5+sequence_lengths)/6]**penalty_factor
  ```
  where all operations are performed element-wise.

  Args:
    sequence_lengths: `Tensor`, the sequence lengths of each hypothesis.
    penalty_factor: A scalar that weights the length penalty.

  Returns:
    If the penalty is `0`, returns the scalar `1.0`. Otherwise returns
    the length penalty factor, a tensor with the same shape as
    `sequence_lengths`.
  """
  penalty_factor = tf.convert_to_tensor(penalty_factor, name="penalty_factor")
  penalty_factor.set_shape(())  # penalty should be a scalar.
  static_penalty = tf.contrib.util.constant_value(penalty_factor)
  if static_penalty is not None and static_penalty == 0:
    return 1.0
  return tf.div((5. + tf.to_float(sequence_lengths))
                **penalty_factor, (5. + 1.)**penalty_factor)
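

# Worked example (hypothetical numbers, added for exposition): with
# penalty_factor=0.6, a commonly used value from
# https://arxiv.org/abs/1609.08144, a hypothesis of length 11 gets penalty
# ((5 + 11) / 6) ** 0.6 ~= 1.80, so its summed log probability is divided by
# ~1.80 rather than by its raw length.
def _length_penalty_example():
  penalty = _length_penalty(
      sequence_lengths=tf.constant([11], dtype=tf.int64),
      penalty_factor=0.6)
  # penalty evaluates to approximately [1.8012].
  return penalty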


def _mask_probs(probs, eos_token, finished):
  """Masks log probabilities.

  The result is that finished beams allocate all probability mass to eos and
  unfinished beams remain unchanged.

  Args:
    probs: Log probabilities of shape `[batch_size, beam_width, vocab_size]`.
    eos_token: An int32 id corresponding to the EOS token to allocate
      probability to.
    finished: A boolean tensor of shape `[batch_size, beam_width]` that
      specifies which elements in the beam are finished already.

  Returns:
    A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished
    beams stay unchanged and finished beams are replaced with a tensor with all
    probability on the EOS token.
  """
  vocab_size = tf.shape(probs)[2]
  # All finished examples are replaced with a vector that has all
  # probability on EOS
  finished_row = tf.one_hot(
      eos_token,
      vocab_size,
      dtype=probs.dtype,
      on_value=tf.convert_to_tensor(0., dtype=probs.dtype),
      off_value=probs.dtype.min)
  finished_probs = tf.tile(
      tf.reshape(finished_row, [1, 1, -1]),
      tf.concat([tf.shape(finished), [1]], 0))
  finished_mask = tf.tile(
      tf.expand_dims(finished, 2), [1, 1, vocab_size])

  return tf.where(finished_mask, finished_probs, probs)
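

# Illustrative sketch (hypothetical numbers, added for exposition): once a
# beam is finished, its row keeps probability mass only on EOS, so extending
# it with anything else becomes maximally unlikely.
def _mask_probs_example():
  # One batch entry, two beams, a vocabulary of 3 tokens; beam 1 is finished.
  probs = tf.constant([[[-0.1, -2.0, -3.0],
                        [-0.5, -1.0, -4.0]]])
  masked = _mask_probs(probs, eos_token=2,
                       finished=tf.constant([[False, True]]))
  # masked[0, 0] is unchanged; masked[0, 1] becomes [min, min, 0.0], where
  # `min` is the most negative value of the dtype.
  return masked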


def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
                                range_size, gather_shape):
  """Maybe applies _tensor_gather_helper.

  This applies _tensor_gather_helper when the gather_from dims is at least as
  big as the length of gather_shape. This is used in conjunction with nest so
  that we don't apply _tensor_gather_helper to inapplicable values like scalars.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve the
      correct values. An example is when gather_from is the attention from an
      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
      There, we want to preserve the attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.

  Returns:
    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
      or the original tensor if its dimensions are too small.
  """
  if isinstance(gather_from, tf.TensorArray):
    return gather_from
  _check_maybe(gather_from)
  if gather_from.shape.ndims >= len(gather_shape):
    return _tensor_gather_helper(
        gather_indices=gather_indices,
        gather_from=gather_from,
        batch_size=batch_size,
        range_size=range_size,
        gather_shape=gather_shape)
  else:
    return gather_from


def _tensor_gather_helper(gather_indices,
                          gather_from,
                          batch_size,
                          range_size,
                          gather_shape,
                          name=None):
  """Helper for gathering the right indices from the tensor.

  This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
  gathering from that according to the gather_indices, which are offset by
  the right amounts in order to preserve the batch order.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The input batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve the
      correct values. An example is when gather_from is the attention from an
      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
      There, we want to preserve the attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.
    name: The tensor name for set of operations. By default this is
      'tensor_gather_helper'. The final output is named 'output'.

  Returns:
    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
  """
  with tf.name_scope(name, "tensor_gather_helper"):
    range_ = tf.expand_dims(tf.range(batch_size) * range_size, 1)
    gather_indices = tf.reshape(gather_indices + range_, [-1])
    output = tf.gather(
        tf.reshape(gather_from, gather_shape), gather_indices)
    final_shape = tf.shape(gather_from)[:1 + len(gather_shape)]
    static_batch_size = tf.contrib.util.constant_value(batch_size)
    final_static_shape = (
        tf.TensorShape([static_batch_size]).concatenate(
            gather_from.shape[1:1 + len(gather_shape)]))
    output = tf.reshape(output, final_shape, name="output")
    output.set_shape(final_static_shape)
    return output
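

# Illustrative sketch (hypothetical numbers, added for exposition): the offset
# trick above turns a per-batch-row gather into one flat gather. With
# batch_size=2 and range_size=3, per-row indices [[0, 2, 1], [1, 0, 2]] get
# offsets [0, 3] added, producing flat indices [0, 2, 1, 4, 3, 5].
def _tensor_gather_helper_example():
  gather_from = tf.constant([[10, 11, 12],
                             [20, 21, 22]])  # [batch_size, beam_width]
  output = _tensor_gather_helper(
      gather_indices=tf.constant([[0, 2, 1], [1, 0, 2]]),
      gather_from=gather_from,
      batch_size=tf.constant(2),
      range_size=3,
      gather_shape=[-1])
  # output evaluates to [[10, 12, 11], [21, 20, 22]].
  return output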