# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A decoder that performs beam search."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
import tensorflow as tf
import attention_wrapper
__all__ = [
"BeamSearchDecoderOutput",
"BeamSearchDecoderState",
"BeamSearchDecoder",
"FinalBeamSearchDecoderOutput",
"tile_batch",
]
class BeamSearchDecoderState(
collections.namedtuple("BeamSearchDecoderState",
("cell_state", "log_probs", "finished", "lengths",
"accumulated_attention_probs"))):
pass
class BeamSearchDecoderOutput(
collections.namedtuple("BeamSearchDecoderOutput",
("scores", "predicted_ids", "parent_ids"))):
pass
class FinalBeamSearchDecoderOutput(
    collections.namedtuple("FinalBeamSearchDecoderOutput",
                           ["predicted_ids", "beam_search_decoder_output"])):
"""Final outputs returned by the beam search after all decoding is finished.
Args:
predicted_ids: The final prediction. A tensor of shape
`[batch_size, T, beam_width]` (or `[T, batch_size, beam_width]` if
`output_time_major` is True). Beams are ordered from best to worst.
beam_search_decoder_output: An instance of `BeamSearchDecoderOutput` that
describes the state of the beam search.
"""
pass
def _tile_batch(t, multiplier):
"""Core single-tensor implementation of tile_batch."""
t = tf.convert_to_tensor(t, name="t")
shape_t = tf.shape(t)
if t.shape.ndims is None or t.shape.ndims < 1:
raise ValueError("t must have statically known rank")
tiling = [1] * (t.shape.ndims + 1)
tiling[1] = multiplier
tiled_static_batch_size = (
t.shape[0].value * multiplier if t.shape[0].value is not None else None)
tiled = tf.tile(tf.expand_dims(t, 1), tiling)
tiled = tf.reshape(
tiled, tf.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
tiled.set_shape(
tf.TensorShape([tiled_static_batch_size]).concatenate(
t.shape[1:]))
return tiled
def tile_batch(t, multiplier, name=None):
"""Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
For each tensor t in a (possibly nested structure) of tensors,
this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
`[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
`t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
`multiplier` times.
Args:
t: `Tensor` shaped `[batch_size, ...]`.
multiplier: Python int.
name: Name scope for any created operations.
Returns:
A (possibly nested structure of) `Tensor` shaped
`[batch_size * multiplier, ...]`.
Raises:
ValueError: if tensor(s) `t` do not have a statically known rank or
the rank is < 1.
"""
flat_t = tf.contrib.framework.nest.flatten(t)
with tf.name_scope(name, "tile_batch", flat_t + [multiplier]):
return tf.contrib.framework.nest.map_structure(
lambda t_: _tile_batch(t_, multiplier), t)
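# Illustrative usage of tile_batch (a sketch; `encoder_outputs` and
# `beam_width` are placeholder names, not defined in this module). A memory of
# shape [batch_size, max_time, depth] becomes
# [batch_size * beam_width, max_time, depth], ordered t[0], t[0], ..., t[1],
# t[1], ... as described in the docstring above:
#
#   tiled_memory = tile_batch(encoder_outputs, multiplier=beam_width)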
def gather_tree_from_array(t, parent_ids, sequence_length):
"""Calculates the full beams for `TensorArray`s.
Args:
t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
where `s` is the depth shape.
parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
sequence_length: The sequence length of shape `[batch_size, beam_width]`.
Returns:
A `Tensor` which is a stacked `TensorArray` of the same size and type as
`t` and where beams are sorted in each `Tensor` according to `parent_ids`.
"""
max_time = parent_ids.shape[0].value or tf.shape(parent_ids)[0]
batch_size = parent_ids.shape[1].value or tf.shape(parent_ids)[1]
beam_width = parent_ids.shape[2].value or tf.shape(parent_ids)[2]
# Generate beam ids that will be reordered by gather_tree.
beam_ids = tf.expand_dims(
tf.expand_dims(tf.range(beam_width), 0), 0)
beam_ids = tf.tile(beam_ids, [max_time, batch_size, 1])
max_sequence_lengths = tf.to_int32(tf.reduce_max(sequence_length, axis=1))
sorted_beam_ids = tf.contrib.seq2seq.gather_tree(
step_ids=beam_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=beam_width + 1)
# For out of range steps, simply copy the same beam.
in_bound_steps = tf.transpose(
tf.sequence_mask(sequence_length, maxlen=max_time),
perm=[2, 0, 1])
sorted_beam_ids = tf.where(
in_bound_steps, x=sorted_beam_ids, y=beam_ids)
# Generate indices for gather_nd.
time_ind = tf.tile(tf.reshape(
tf.range(max_time), [-1, 1, 1]), [1, batch_size, beam_width])
batch_ind = tf.tile(tf.reshape(
tf.range(batch_size), [-1, 1, 1]), [1, max_time, beam_width])
batch_ind = tf.transpose(batch_ind, perm=[1, 0, 2])
indices = tf.stack([time_ind, batch_ind, sorted_beam_ids], -1)
# Gather from a tensor with collapsed additional dimensions.
gather_from = t
final_shape = tf.shape(gather_from)
gather_from = tf.reshape(
gather_from, [max_time, batch_size, beam_width, -1])
ordered = tf.gather_nd(gather_from, indices)
ordered = tf.reshape(ordered, final_shape)
return ordered
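# A small trace of the reordering above (a sketch, with beam_width=2 and one
# batch entry over two time steps): if parent_ids is [[0, 0], [1, 0]], then
# final beam 0 descends from beam 1 of the previous step, so gather_tree
# backtracks to sorted_beam_ids [[1, 0], [0, 1]] along the time axis, and
# gather_nd pulls the stacked TensorArray slices into that backtracked order.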
def _check_maybe(t):
if t.shape.ndims is None:
raise ValueError(
"Expected tensor (%s) to have known rank, but ndims == None." % t)
def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
"""Raises an exception if dimensions are known statically and can not be
reshaped to [batch_size, beam_size, -1].
"""
reshaped_shape = tf.TensorShape([batch_size, beam_width, None])
if (batch_size is not None and shape[0].value is not None
and (shape[0] != batch_size * beam_width
or (shape.ndims >= 2 and shape[1].value is not None
and (shape[0] != batch_size or shape[1] != beam_width)))):
tf.logging.warn("TensorArray reordering expects elements to be "
"reshapable to %s which is incompatible with the "
"current shape %s. Consider setting "
"reorder_tensor_arrays to False to disable TensorArray "
"reordering during the beam search."
% (reshaped_shape, shape))
return False
return True
def _check_batch_beam(t, batch_size, beam_width):
"""Returns an Assert operation checking that the elements of the stacked
TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
the TensorArray elements have a known rank of at least 1.
"""
error_message = ("TensorArray reordering expects elements to be "
"reshapable to [batch_size, beam_size, -1] which is "
"incompatible with the dynamic shape of %s elements. "
"Consider setting reorder_tensor_arrays to False to disable "
"TensorArray reordering during the beam search."
% (t.name))
rank = t.shape.ndims
shape = tf.shape(t)
if rank == 2:
condition = tf.equal(shape[1], batch_size * beam_width)
else:
condition = tf.logical_or(
tf.equal(shape[1], batch_size * beam_width),
tf.logical_and(
tf.equal(shape[1], batch_size),
tf.equal(shape[2], beam_width)))
return tf.Assert(condition, [error_message])
class BeamSearchDecoder(tf.contrib.seq2seq.Decoder):
"""BeamSearch sampling decoder.
**NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
`AttentionWrapper`, then you must ensure that:
- The encoder output has been tiled to `beam_width` via
`tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
- The `batch_size` argument passed to the `zero_state` method of this
wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `zero_state` above contains a
`cell_state` value containing properly tiled final state from the
encoder.
An example:
```
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
sequence_length, multiplier=beam_width)
attention_mechanism = MyFavoriteAttentionMechanism(
num_units=attention_depth,
      memory=tiled_encoder_outputs,
memory_sequence_length=tiled_sequence_length)
attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
decoder_initial_state = attention_cell.zero_state(
dtype, batch_size=true_batch_size * beam_width)
decoder_initial_state = decoder_initial_state.clone(
cell_state=tiled_encoder_final_state)
```
  Meanwhile, when using `AttentionWrapper`, it is recommended to enable the
  coverage penalty when computing scores (https://arxiv.org/pdf/1609.08144.pdf),
  as it encourages the translation to cover all inputs.
"""
def __init__(self,
cell,
embedding,
start_tokens,
end_token,
initial_state,
beam_width,
output_layer=None,
length_penalty_weight=0.0,
coverage_penalty_weight=0.0,
reorder_tensor_arrays=True):
"""Initialize the BeamSearchDecoder.
Args:
cell: An `RNNCell` instance.
embedding: A callable that takes a vector tensor of `ids` (argmax ids),
or the `params` argument for `embedding_lookup`.
start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
end_token: `int32` scalar, the token that marks end of decoding.
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
beam_width: Python integer, the number of beams.
output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
to storing the result or sampling.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
coverage_penalty_weight: Float weight to penalize the coverage of source
sentence. Disabled with 0.0.
reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
state will be reordered according to the beam search path. If the
`TensorArray` can be reordered, the stacked form will be returned.
Otherwise, the `TensorArray` will be returned as is. Set this flag to
`False` if the cell state contains `TensorArray`s that are not amenable
to reordering.
Raises:
TypeError: if `cell` is not an instance of `RNNCell`,
or `output_layer` is not an instance of `tf.layers.Layer`.
ValueError: If `start_tokens` is not a vector or
`end_token` is not a scalar.
"""
if (output_layer is not None and
not isinstance(output_layer, tf.layers.Layer)):
raise TypeError(
"output_layer must be a Layer, received: %s" % type(output_layer))
self._cell = cell
self._output_layer = output_layer
self._reorder_tensor_arrays = reorder_tensor_arrays
if callable(embedding):
self._embedding_fn = embedding
else:
self._embedding_fn = (
lambda ids: tf.nn.embedding_lookup(embedding, ids))
self._start_tokens = tf.convert_to_tensor(
start_tokens, dtype=tf.int32, name="start_tokens")
if self._start_tokens.get_shape().ndims != 1:
raise ValueError("start_tokens must be a vector")
self._end_token = tf.convert_to_tensor(
end_token, dtype=tf.int32, name="end_token")
if self._end_token.get_shape().ndims != 0:
raise ValueError("end_token must be a scalar")
self._batch_size = tf.size(start_tokens)
self._beam_width = beam_width
self._length_penalty_weight = length_penalty_weight
self._coverage_penalty_weight = coverage_penalty_weight
self._initial_cell_state = tf.contrib.framework.nest.map_structure(
self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = tf.tile(
tf.expand_dims(self._start_tokens, 1), [1, self._beam_width])
self._start_inputs = self._embedding_fn(self._start_tokens)
self._finished = tf.one_hot(
tf.zeros([self._batch_size], dtype=tf.int32),
depth=self._beam_width,
on_value=False,
off_value=True,
dtype=tf.bool)
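  # A minimal end-to-end sketch (assuming `attention_cell`,
  # `decoder_initial_state`, `embedding_matrix`, `sos_id`, `eos_id`, and
  # `max_decode_length` are defined as in the class docstring; none of these
  # names come from this module):
  #
  #   decoder = BeamSearchDecoder(
  #       cell=attention_cell,
  #       embedding=embedding_matrix,
  #       start_tokens=tf.fill([true_batch_size], sos_id),
  #       end_token=eos_id,
  #       initial_state=decoder_initial_state,
  #       beam_width=beam_width,
  #       length_penalty_weight=0.6)
  #   outputs, final_state, lengths = tf.contrib.seq2seq.dynamic_decode(
  #       decoder, maximum_iterations=max_decode_length)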
@property
def batch_size(self):
return self._batch_size
def _rnn_output_size(self):
size = self._cell.output_size
if self._output_layer is None:
return size
else:
# To use layer's compute_output_shape, we need to convert the
# RNNCell's output_size entries into shapes with an unknown
# batch size. We then pass this through the layer's
# compute_output_shape and read off all but the first (batch)
# dimensions to get the output size of the rnn with the layer
# applied to the top.
output_shape_with_unknown_batch = tf.contrib.framework.nest.map_structure(
lambda s: tf.TensorShape([None]).concatenate(s), size)
layer_output_shape = self._output_layer.compute_output_shape(
output_shape_with_unknown_batch)
return tf.contrib.framework.nest.map_structure(
lambda s: s[1:], layer_output_shape)
@property
def tracks_own_finished(self):
"""The BeamSearchDecoder shuffles its beams and their finished state.
For this reason, it conflicts with the `dynamic_decode` function's
tracking of finished states. Setting this property to true avoids
early stopping of decoding due to mismanagement of the finished state
in `dynamic_decode`.
Returns:
`True`.
"""
return True
@property
def output_size(self):
# Return the cell output and the id
return BeamSearchDecoderOutput(
scores=tf.TensorShape([self._beam_width]),
predicted_ids=tf.TensorShape([self._beam_width]),
parent_ids=tf.TensorShape([self._beam_width]))
@property
def output_dtype(self):
    # Assume the scores share the dtype of the first component of the
    # initial cell state, and use int32 for the predicted and parent ids.
dtype = tf.contrib.framework.nest.flatten(self._initial_cell_state)[0].dtype
return BeamSearchDecoderOutput(
scores=tf.contrib.framework.nest.map_structure(
lambda _: dtype, self._rnn_output_size()),
predicted_ids=tf.int32,
parent_ids=tf.int32)
def initialize(self, name=None):
"""Initialize the decoder.
Args:
name: Name scope for any created operations.
Returns:
`(finished, start_inputs, initial_state)`.
"""
finished, start_inputs = self._finished, self._start_inputs
dtype = tf.contrib.framework.nest.flatten(self._initial_cell_state)[0].dtype
log_probs = tf.one_hot( # shape(batch_sz, beam_sz)
tf.zeros([self._batch_size], dtype=tf.int32),
depth=self._beam_width,
on_value=tf.convert_to_tensor(0.0, dtype=dtype),
off_value=tf.convert_to_tensor(-np.Inf, dtype=dtype),
dtype=dtype)
init_attention_probs = get_attention_probs(
self._initial_cell_state, self._coverage_penalty_weight)
if init_attention_probs is None:
init_attention_probs = ()
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=log_probs,
finished=finished,
lengths=tf.zeros(
[self._batch_size, self._beam_width], dtype=tf.int64),
accumulated_attention_probs=init_attention_probs)
return (finished, start_inputs, initial_state)
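  # Note on the one-hot initialization above: beam 0 starts with log
  # probability 0.0 and every other beam with -inf, so the first top-k step in
  # _beam_search_step can only select continuations of beam 0; e.g. with
  # beam_width=3 each initial log_probs row is [0.0, -inf, -inf].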
def finalize(self, outputs, final_state, sequence_lengths):
"""Finalize and return the predicted_ids.
Args:
outputs: An instance of BeamSearchDecoderOutput.
final_state: An instance of BeamSearchDecoderState. Passed through to the
output.
sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
The sequence lengths determined for each beam during decode.
**NOTE** These are ignored; the updated sequence lengths are stored in
`final_state.lengths`.
Returns:
      outputs: An instance of `FinalBeamSearchDecoderOutput` where the
        predicted_ids are the result of calling `tf.contrib.seq2seq.gather_tree`.
final_state: The same input instance of `BeamSearchDecoderState`.
"""
del sequence_lengths
# Get max_sequence_length across all beams for each batch.
max_sequence_lengths = tf.to_int32(
tf.reduce_max(final_state.lengths, axis=1))
predicted_ids = tf.contrib.seq2seq.gather_tree(
outputs.predicted_ids,
outputs.parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=self._end_token)
if self._reorder_tensor_arrays:
final_state = final_state._replace(
cell_state=tf.contrib.framework.nest.map_structure(
lambda t: self._maybe_sort_array_beams(
t, outputs.parent_ids, final_state.lengths),
final_state.cell_state))
outputs = FinalBeamSearchDecoderOutput(
beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
return outputs, final_state
def _merge_batch_beams(self, t, s=None):
"""Merges the tensor from a batch of beams into a batch by beams.
More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We
reshape this into [batch_size*beam_width, s]
Args:
t: Tensor of dimension [batch_size, beam_width, s]
s: (Possibly known) depth shape.
Returns:
A reshaped version of t with dimension [batch_size * beam_width, s].
"""
if isinstance(s, tf.Tensor):
s = tf.contrib.util.constant_value(s)
if isinstance(s, tf.TensorShape):
return s
else:
s = tf.TensorShape(s)
else:
s = tf.TensorShape(s)
t_shape = tf.shape(t)
static_batch_size = tf.contrib.util.constant_value(self._batch_size)
batch_size_beam_width = (
None
if static_batch_size is None else static_batch_size * self._beam_width)
reshaped_t = tf.reshape(
t,
tf.concat(([self._batch_size * self._beam_width], t_shape[2:]), 0))
reshaped_t.set_shape(
(tf.TensorShape([batch_size_beam_width]).concatenate(s)))
return reshaped_t
def _split_batch_beams(self, t, s=None):
"""Splits the tensor from a batch by beams into a batch of beams.
More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
reshape this into [batch_size, beam_width, s]
Args:
t: Tensor of dimension [batch_size*beam_width, s].
s: (Possibly known) depth shape.
Returns:
A reshaped version of t with dimension [batch_size, beam_width, s].
Raises:
ValueError: If, after reshaping, the new tensor is not shaped
`[batch_size, beam_width, s]` (assuming batch_size and beam_width
are known statically).
"""
if isinstance(s, tf.Tensor):
s = tf.TensorShape(tf.contrib.util.constant_value(s))
else:
s = tf.TensorShape(s)
t_shape = tf.shape(t)
reshaped_t = tf.reshape(
t,
tf.concat(([self._batch_size, self._beam_width], t_shape[1:]), 0))
static_batch_size = tf.contrib.util.constant_value(self._batch_size)
expected_reshaped_shape = tf.TensorShape(
[static_batch_size, self._beam_width]).concatenate(s)
if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape):
raise ValueError("Unexpected behavior when reshaping between beam width "
"and batch size. The reshaped tensor has shape: %s. "
"We expected it to have shape "
"(batch_size, beam_width, depth) == %s. Perhaps you "
"forgot to create a zero_state with "
"batch_size=encoder_batch_size * beam_width?" %
(reshaped_t.shape, expected_reshaped_shape))
reshaped_t.set_shape(expected_reshaped_shape)
return reshaped_t
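  # Shape round trip of the two reshape helpers above (a sketch with
  # batch_size=4, beam_width=3, depth=256):
  #   _merge_batch_beams: [4, 3, 256] -> [12, 256]  (flattened for the cell)
  #   _split_batch_beams: [12, 256] -> [4, 3, 256]  (restored for beam math)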
def _maybe_split_batch_beams(self, t, s):
"""Maybe splits the tensor from a batch by beams into a batch of beams.
We do this so that we can use nest and not run into problems with shapes.
Args:
t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`.
s: `Tensor`, Python int, or `TensorShape`.
Returns:
If `t` is a matrix or higher order tensor, then the return value is
`t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is
returned unchanged.
Raises:
ValueError: If the rank of `t` is not statically known.
"""
if isinstance(t, tf.TensorArray):
return t
_check_maybe(t)
if t.shape.ndims >= 1:
return self._split_batch_beams(t, s)
else:
return t
def _maybe_merge_batch_beams(self, t, s):
"""Splits the tensor from a batch by beams into a batch of beams.
More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`,
then we reshape it to `[batch_size, beam_width] + s`.
Args:
t: `Tensor` of dimension `[batch_size * beam_width] + s`.
s: `Tensor`, Python int, or `TensorShape`.
Returns:
A reshaped version of t with shape `[batch_size, beam_width] + s`.
Raises:
ValueError: If the rank of `t` is not statically known.
"""
if isinstance(t, tf.TensorArray):
return t
_check_maybe(t)
if t.shape.ndims >= 2:
return self._merge_batch_beams(t, s)
else:
return t
def _maybe_sort_array_beams(self, t, parent_ids, sequence_length):
"""Maybe sorts beams within a `TensorArray`.
Args:
t: A `TensorArray` of size `max_time` that contains `Tensor`s of shape
`[batch_size, beam_width, s]` or `[batch_size * beam_width, s]` where
`s` is the depth shape.
parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
sequence_length: The sequence length of shape `[batch_size, beam_width]`.
Returns:
A `TensorArray` where beams are sorted in each `Tensor` or `t` itself if
it is not a `TensorArray` or does not meet shape requirements.
"""
if not isinstance(t, tf.TensorArray):
return t
# pylint: disable=protected-access
if (not t._infer_shape or not t._element_shape
or t._element_shape[0].ndims is None
or t._element_shape[0].ndims < 1):
shape = (
t._element_shape[0] if t._infer_shape and t._element_shape
else tf.TensorShape(None))
tf.logging.warn("The TensorArray %s in the cell state is not amenable to "
"sorting based on the beam search result. For a "
"TensorArray to be sorted, its elements shape must be "
"defined and have at least a rank of 1, but saw shape: %s"
% (t.handle.name, shape))
return t
shape = t._element_shape[0]
# pylint: enable=protected-access
if not _check_static_batch_beam_maybe(
shape, tf.contrib.util.constant_value(self._batch_size),
self._beam_width):
return t
t = t.stack()
with tf.control_dependencies(
[_check_batch_beam(t, self._batch_size, self._beam_width)]):
return gather_tree_from_array(t, parent_ids, sequence_length)
def step(self, time, inputs, state, name=None):
"""Perform a decoding step.
Args:
time: scalar `int32` tensor.
inputs: A (structure of) input tensors.
state: A (structure of) state tensors and TensorArrays.
name: Name scope for any created operations.
Returns:
`(outputs, next_state, next_inputs, finished)`.
"""
batch_size = self._batch_size
beam_width = self._beam_width
end_token = self._end_token
length_penalty_weight = self._length_penalty_weight
coverage_penalty_weight = self._coverage_penalty_weight
with tf.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
cell_state = state.cell_state
inputs = tf.contrib.framework.nest.map_structure(
lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
cell_state = tf.contrib.framework.nest.map_structure(
self._maybe_merge_batch_beams, cell_state, self._cell.state_size)
cell_outputs, next_cell_state = self._cell(inputs, cell_state)
cell_outputs = tf.contrib.framework.nest.map_structure(
lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
next_cell_state = tf.contrib.framework.nest.map_structure(
self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)
if self._output_layer is not None:
cell_outputs = self._output_layer(cell_outputs)
beam_search_output, beam_search_state = _beam_search_step(
time=time,
logits=cell_outputs,
next_cell_state=next_cell_state,
beam_state=state,
batch_size=batch_size,
beam_width=beam_width,
end_token=end_token,
length_penalty_weight=length_penalty_weight,
coverage_penalty_weight=coverage_penalty_weight)
finished = beam_search_state.finished
sample_ids = beam_search_output.predicted_ids
next_inputs = tf.cond(
tf.reduce_all(finished), lambda: self._start_inputs,
lambda: self._embedding_fn(sample_ids))
return (beam_search_output, beam_search_state, next_inputs, finished)
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
beam_width, end_token, length_penalty_weight,
coverage_penalty_weight):
"""Performs a single step of Beam Search Decoding.
Args:
time: Beam search time step, should start at 0. At time 0 we assume
that all beams are equal and consider only the first beam for
continuations.
logits: Logits at the current time step. A tensor of shape
`[batch_size, beam_width, vocab_size]`
next_cell_state: The next state from the cell, e.g. an instance of
AttentionWrapperState if the cell is attentional.
beam_state: Current state of the beam search.
An instance of `BeamSearchDecoderState`.
batch_size: The batch size for this input.
beam_width: Python int. The size of the beams.
end_token: The int32 end token.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
coverage_penalty_weight: Float weight to penalize the coverage of source
sentence. Disabled with 0.0.
Returns:
    A `(BeamSearchDecoderOutput, BeamSearchDecoderState)` tuple containing the
    beam search output and the new beam state for this step.
"""
static_batch_size = tf.contrib.util.constant_value(batch_size)
# Calculate the current lengths of the predictions
prediction_lengths = beam_state.lengths
previously_finished = beam_state.finished
not_finished = tf.logical_not(previously_finished)
# Calculate the total log probs for the new hypotheses
# Final Shape: [batch_size, beam_width, vocab_size]
step_log_probs = tf.nn.log_softmax(logits)
step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
total_probs = tf.expand_dims(beam_state.log_probs, 2) + step_log_probs
# Calculate the continuation lengths by adding to all continuing beams.
vocab_size = logits.shape[-1].value or tf.shape(logits)[-1]
lengths_to_add = tf.one_hot(
indices=tf.fill([batch_size, beam_width], end_token),
depth=vocab_size,
on_value=np.int64(0),
off_value=np.int64(1),
dtype=tf.int64)
add_mask = tf.to_int64(not_finished)
lengths_to_add *= tf.expand_dims(add_mask, 2)
new_prediction_lengths = (
lengths_to_add + tf.expand_dims(prediction_lengths, 2))
# Calculate the accumulated attention probabilities if coverage penalty is
# enabled.
accumulated_attention_probs = None
attention_probs = get_attention_probs(
next_cell_state, coverage_penalty_weight)
if attention_probs is not None:
attention_probs *= tf.expand_dims(tf.to_float(not_finished), 2)
accumulated_attention_probs = (
beam_state.accumulated_attention_probs + attention_probs)
batch_finished = tf.reduce_all(
previously_finished, axis=1, keepdims=True)
any_batch_finished = tf.reduce_any(batch_finished)
batch_finished = tf.tile(tf.expand_dims(batch_finished, 2),
[1, beam_width, vocab_size])
def _normalized_scores():
return _get_scores(
log_probs=total_probs,
sequence_lengths=new_prediction_lengths,
length_penalty_weight=length_penalty_weight,
coverage_penalty_weight=coverage_penalty_weight,
finished=batch_finished,
accumulated_attention_probs=accumulated_attention_probs)
# Normalize the scores of finished batches.
scores = tf.cond(any_batch_finished, _normalized_scores, lambda: total_probs)
time = tf.convert_to_tensor(time, name="time")
# During the first time step we only consider the initial beam
scores_flat = tf.reshape(scores, [batch_size, -1])
# Pick the next beams according to the specified successors function
next_beam_size = tf.convert_to_tensor(
beam_width, dtype=tf.int32, name="beam_width")
next_beam_scores, word_indices = tf.nn.top_k(scores_flat, k=next_beam_size)
next_beam_scores.set_shape([static_batch_size, beam_width])
word_indices.set_shape([static_batch_size, beam_width])
# Pick out the probs, beam_ids, and states according to the chosen predictions
next_beam_probs = _tensor_gather_helper(
gather_indices=word_indices,
gather_from=total_probs,
batch_size=batch_size,
range_size=beam_width * vocab_size,
gather_shape=[-1],
name="next_beam_probs")
# Note: just doing the following
# tf.to_int32(word_indices % vocab_size,
# name="next_beam_word_ids")
# would be a lot cleaner but for reasons unclear, that hides the results of
# the op which prevents capturing it with tfdbg debug ops.
raw_next_word_ids = tf.mod(
word_indices, vocab_size, name="next_beam_word_ids")
next_word_ids = tf.to_int32(raw_next_word_ids)
next_beam_ids = tf.to_int32(
word_indices / vocab_size, name="next_beam_parent_ids")
# Append new ids to current predictions
previously_finished = _tensor_gather_helper(
gather_indices=next_beam_ids,
gather_from=previously_finished,
batch_size=batch_size,
range_size=beam_width,
gather_shape=[-1])
next_finished = tf.logical_or(
previously_finished,
tf.equal(next_word_ids, end_token),
name="next_beam_finished")
# Calculate the length of the next predictions.
# 1. Finished beams remain unchanged.
# 2. Beams that are now finished (EOS predicted) have their length
# increased by 1.
# 3. Beams that are not yet finished have their length increased by 1.
lengths_to_add = tf.to_int64(tf.logical_not(previously_finished))
next_prediction_len = _tensor_gather_helper(
gather_indices=next_beam_ids,
gather_from=beam_state.lengths,
batch_size=batch_size,
range_size=beam_width,
gather_shape=[-1])
next_prediction_len += lengths_to_add
next_accumulated_attention_probs = ()
if accumulated_attention_probs is not None:
next_accumulated_attention_probs = _tensor_gather_helper(
gather_indices=next_beam_ids,
gather_from=accumulated_attention_probs,
batch_size=batch_size,
range_size=beam_width,
gather_shape=[batch_size * beam_width, -1],
name="next_accumulated_attention_probs")
# Pick out the cell_states according to the next_beam_ids. We use a
# different gather_shape here because the cell_state tensors, i.e.
# the tensors that would be gathered from, all have dimension
# greater than two and we need to preserve those dimensions.
# pylint: disable=g-long-lambda
next_cell_state = tf.contrib.framework.nest.map_structure(
lambda gather_from: _maybe_tensor_gather_helper(
gather_indices=next_beam_ids,
gather_from=gather_from,
batch_size=batch_size,
range_size=beam_width,
gather_shape=[batch_size * beam_width, -1]),
next_cell_state)
# pylint: enable=g-long-lambda
next_state = BeamSearchDecoderState(
cell_state=next_cell_state,
log_probs=next_beam_probs,
lengths=next_prediction_len,
finished=next_finished,
accumulated_attention_probs=next_accumulated_attention_probs)
output = BeamSearchDecoderOutput(
scores=next_beam_scores,
predicted_ids=next_word_ids,
parent_ids=next_beam_ids)
return output, next_state
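# Decoding the flat top-k indices above (a sketch with beam_width=2 and
# vocab_size=5): scores_flat has 2 * 5 = 10 columns per batch entry, so a
# winning index of 7 decodes to parent beam 7 // 5 = 1 and word id 7 % 5 = 2,
# which is what the tf.mod / integer-division pair recovers.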
def get_attention_probs(next_cell_state, coverage_penalty_weight):
"""Get attention probabilities from the cell state.
Args:
next_cell_state: The next state from the cell, e.g. an instance of
AttentionWrapperState if the cell is attentional.
coverage_penalty_weight: Float weight to penalize the coverage of source
sentence. Disabled with 0.0.
Returns:
The attention probabilities with shape `[batch_size, beam_width, max_time]`
if coverage penalty is enabled. Otherwise, returns None.
Raises:
ValueError: If no cell is attentional but coverage penalty is enabled.
"""
if coverage_penalty_weight == 0.0:
return None
# Attention probabilities of each attention layer. Each with shape
# `[batch_size, beam_width, max_time]`.
probs_per_attn_layer = []
if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
elif isinstance(next_cell_state, tuple):
for state in next_cell_state:
if isinstance(state, attention_wrapper.AttentionWrapperState):
probs_per_attn_layer.append(attention_probs_from_attn_state(state))
if not probs_per_attn_layer:
raise ValueError(
"coverage_penalty_weight must be 0.0 if no cell is attentional.")
if len(probs_per_attn_layer) == 1:
attention_probs = probs_per_attn_layer[0]
else:
# Calculate the average attention probabilities from all attention layers.
attention_probs = [
tf.expand_dims(prob, -1) for prob in probs_per_attn_layer]
attention_probs = tf.concat(attention_probs, -1)
attention_probs = tf.reduce_mean(attention_probs, -1)
return attention_probs
def _get_scores(log_probs, sequence_lengths, length_penalty_weight,
coverage_penalty_weight, finished, accumulated_attention_probs):
"""Calculates scores for beam search hypotheses.
Args:
log_probs: The log probabilities with shape
`[batch_size, beam_width, vocab_size]`.
sequence_lengths: The array of sequence lengths.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
coverage_penalty_weight: Float weight to penalize the coverage of source
sentence. Disabled with 0.0.
finished: A boolean tensor of shape `[batch_size, beam_width, vocab_size]`
that specifies which elements in the beam are finished already.
accumulated_attention_probs: Accumulated attention probabilities up to the
current time step, with shape `[batch_size, beam_width, max_time]` if
coverage_penalty_weight is not 0.0.
Returns:
The scores normalized by the length_penalty and coverage_penalty.
Raises:
ValueError: accumulated_attention_probs is None when coverage penalty is
enabled.
"""
length_penalty_ = _length_penalty(
sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)
coverage_penalty_weight = tf.convert_to_tensor(
coverage_penalty_weight, name="coverage_penalty_weight")
if coverage_penalty_weight.shape.ndims != 0:
raise ValueError("coverage_penalty_weight should be a scalar, "
"but saw shape: %s" % coverage_penalty_weight.shape)
  # With coverage penalty disabled, apply only the length penalty; returning
  # early here keeps the default configuration (coverage_penalty_weight=0.0)
  # from hitting the ValueError below.
  if tf.contrib.util.constant_value(coverage_penalty_weight) == 0.0:
    return tf.where(finished, log_probs / length_penalty_, log_probs)
  if accumulated_attention_probs is None:
    raise ValueError(
        "accumulated_attention_probs can be None only if coverage penalty is "
        "disabled.")
# Add source sequence length mask before computing coverage penalty.
accumulated_attention_probs = tf.where(
tf.equal(accumulated_attention_probs, 0.0),
tf.ones_like(accumulated_attention_probs),
accumulated_attention_probs)
# coverage penalty =
# sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
coverage_penalty = tf.reduce_sum(
tf.log(tf.minimum(accumulated_attention_probs, 1.0)), 2)
# Apply coverage penalty to finished predictions.
weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
# Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
weighted_coverage_penalty = tf.expand_dims(
weighted_coverage_penalty, 2)
# Normalize the scores of finished predictions.
return tf.where(
finished, log_probs / length_penalty_ + weighted_coverage_penalty,
log_probs)
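# Putting the two penalties together (a sketch following the GNMT paper's
# s(Y, X) = log P(Y|X) / lp(Y) + cp(X; Y)): a finished hypothesis with log
# probability -6.0, length penalty 1.5, and weighted coverage penalty -0.4
# scores -6.0 / 1.5 + (-0.4) = -4.4, while unfinished hypotheses keep their
# raw log probability until they finish.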
def attention_probs_from_attn_state(attention_state):
"""Calculates the average attention probabilities.
Args:
attention_state: An instance of `AttentionWrapperState`.
Returns:
The attention probabilities in the given AttentionWrapperState.
If there're multiple attention mechanisms, return the average value from
all attention mechanisms.
"""
# Attention probabilities over time steps, with shape
# `[batch_size, beam_width, max_time]`.
attention_probs = attention_state.alignments
if isinstance(attention_probs, tuple):
attention_probs = [
tf.expand_dims(prob, -1) for prob in attention_probs]
attention_probs = tf.concat(attention_probs, -1)
attention_probs = tf.reduce_mean(attention_probs, -1)
return attention_probs
def _length_penalty(sequence_lengths, penalty_factor):
"""Calculates the length penalty. See https://arxiv.org/abs/1609.08144.
Returns the length penalty tensor:
```
[(5+sequence_lengths)/6]**penalty_factor
```
where all operations are performed element-wise.
Args:
    sequence_lengths: `Tensor`, the sequence lengths of each hypothesis.
penalty_factor: A scalar that weights the length penalty.
Returns:
If the penalty is `0`, returns the scalar `1.0`. Otherwise returns
the length penalty factor, a tensor with the same shape as
`sequence_lengths`.
"""
penalty_factor = tf.convert_to_tensor(penalty_factor, name="penalty_factor")
penalty_factor.set_shape(()) # penalty should be a scalar.
static_penalty = tf.contrib.util.constant_value(penalty_factor)
if static_penalty is not None and static_penalty == 0:
return 1.0
return tf.div((5. + tf.to_float(sequence_lengths))
**penalty_factor, (5. + 1.)**penalty_factor)
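# Numeric sanity check (a sketch): with penalty_factor=0.6 and a hypothesis of
# length 10, the penalty is ((5 + 10) / 6) ** 0.6 = 2.5 ** 0.6 ~= 1.73, a
# milder normalizer than dividing scores by the raw length 10.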
def _mask_probs(probs, eos_token, finished):
"""Masks log probabilities.
The result is that finished beams allocate all probability mass to eos and
unfinished beams remain unchanged.
Args:
probs: Log probabilities of shape `[batch_size, beam_width, vocab_size]`
eos_token: An int32 id corresponding to the EOS token to allocate
probability to.
finished: A boolean tensor of shape `[batch_size, beam_width]` that
specifies which elements in the beam are finished already.
Returns:
A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished
beams stay unchanged and finished beams are replaced with a tensor with all
probability on the EOS token.
"""
vocab_size = tf.shape(probs)[2]
# All finished examples are replaced with a vector that has all
# probability on EOS
finished_row = tf.one_hot(
eos_token,
vocab_size,
dtype=probs.dtype,
on_value=tf.convert_to_tensor(0., dtype=probs.dtype),
off_value=probs.dtype.min)
finished_probs = tf.tile(
tf.reshape(finished_row, [1, 1, -1]),
tf.concat([tf.shape(finished), [1]], 0))
finished_mask = tf.tile(
tf.expand_dims(finished, 2), [1, 1, vocab_size])
return tf.where(finished_mask, finished_probs, probs)
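# Concretely (a sketch with vocab_size=4 and eos_token=3): every finished
# beam's row becomes [min, min, min, 0.0] in the dtype's scale, so the only
# viable continuation is EOS, and it adds nothing to the accumulated log
# probability; rows of unfinished beams pass through unchanged.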
def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
range_size, gather_shape):
"""Maybe applies _tensor_gather_helper.
This applies _tensor_gather_helper when the gather_from dims is at least as
big as the length of gather_shape. This is used in conjunction with nest so
that we don't apply _tensor_gather_helper to inapplicable values like scalars.
Args:
gather_indices: The tensor indices that we use to gather.
gather_from: The tensor that we are gathering from.
batch_size: The batch size.
range_size: The number of values in each range. Likely equal to beam_width.
gather_shape: What we should reshape gather_from to in order to preserve the
correct values. An example is when gather_from is the attention from an
AttentionWrapperState with shape [batch_size, beam_width, attention_size].
There, we want to preserve the attention_size elements, so gather_shape is
[batch_size * beam_width, -1]. Then, upon reshape, we still have the
attention_size as desired.
Returns:
output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
or the original tensor if its dimensions are too small.
"""
if isinstance(gather_from, tf.TensorArray):
return gather_from
_check_maybe(gather_from)
if gather_from.shape.ndims >= len(gather_shape):
return _tensor_gather_helper(
gather_indices=gather_indices,
gather_from=gather_from,
batch_size=batch_size,
range_size=range_size,
gather_shape=gather_shape)
else:
return gather_from
def _tensor_gather_helper(gather_indices,
gather_from,
batch_size,
range_size,
gather_shape,
name=None):
"""Helper for gathering the right indices from the tensor.
This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
gathering from that according to the gather_indices, which are offset by
the right amounts in order to preserve the batch order.
Args:
gather_indices: The tensor indices that we use to gather.
gather_from: The tensor that we are gathering from.
batch_size: The input batch size.
range_size: The number of values in each range. Likely equal to beam_width.
gather_shape: What we should reshape gather_from to in order to preserve the
correct values. An example is when gather_from is the attention from an
AttentionWrapperState with shape [batch_size, beam_width, attention_size].
There, we want to preserve the attention_size elements, so gather_shape is
[batch_size * beam_width, -1]. Then, upon reshape, we still have the
attention_size as desired.
name: The tensor name for set of operations. By default this is
'tensor_gather_helper'. The final output is named 'output'.
Returns:
output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
"""
with tf.name_scope(name, "tensor_gather_helper"):
range_ = tf.expand_dims(tf.range(batch_size) * range_size, 1)
gather_indices = tf.reshape(gather_indices + range_, [-1])
output = tf.gather(
tf.reshape(gather_from, gather_shape), gather_indices)
final_shape = tf.shape(gather_from)[:1 + len(gather_shape)]
static_batch_size = tf.contrib.util.constant_value(batch_size)
final_static_shape = (
tf.TensorShape([static_batch_size]).concatenate(
gather_from.shape[1:1 + len(gather_shape)]))
output = tf.reshape(output, final_shape, name="output")
output.set_shape(final_static_shape)
return output
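# Worked example of the offset trick above (a sketch): with batch_size=2 and
# range_size=3, range_ is [[0], [3]], so gather_indices [[2, 0, 1], [1, 1, 0]]
# flatten to [2, 0, 1, 4, 4, 3] and select within each batch entry's own
# stride of the flattened gather_from, never mixing batch entries.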