# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic sequence-to-sequence model with dynamic RNN support."""
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import abc
|
||
|
import collections
|
||
|
import os
|
||
|
|
||
|
import tensorflow as tf
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
from tensorflow.python.framework import function
|
||
|
from tensorflow.python.ops import math_ops
|
||
|
|
||
|
import attention_wrapper
|
||
|
import model_helper
|
||
|
import beam_search_decoder
|
||
|
from utils import iterator_utils
|
||
|
from utils import math_utils
|
||
|
from utils import misc_utils as utils
|
||
|
from utils import vocab_utils
|
||
|
|
||
|
utils.check_tensorflow_version()
|
||
|
|
||
|
__all__ = ["BaseModel"]
|
||
|
|
||
|
|
||
|


def create_attention_mechanism(
    num_units, memory, source_sequence_length, dtype=None):
  """Create a normalized Bahdanau attention mechanism over `memory`."""
  attention_mechanism = attention_wrapper.BahdanauAttention(
      num_units,
      memory,
      memory_sequence_length=tf.to_int64(source_sequence_length),
      normalize=True,
      dtype=dtype)
  return attention_mechanism
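
# Illustrative usage (a sketch, not executed here): the mechanism returned
# above is typically wrapped around the decoder cell, e.g.
#
#   mechanism = create_attention_mechanism(
#       num_units=1024, memory=encoder_outputs,
#       source_sequence_length=source_sequence_length)
#   cell = attention_wrapper.AttentionWrapper(cell, mechanism, ...)
#
# The exact wrapper signature depends on the local attention_wrapper fork;
# see the subclass implementations of _build_decoder_cell.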


class BaseModel(object):
  """Sequence-to-sequence base class."""

  def __init__(self, hparams, mode, features, scope=None, extra_args=None):
    """Create the model.

    Args:
      hparams: Hyperparameter configurations.
      mode: TRAIN | EVAL | INFER
      features: a dict of input features.
      scope: scope of the model.
      extra_args: model_helper.ExtraArgs, for passing customizable functions.
    """
    self.hparams = hparams
    # Set params
    self._set_params_initializer(hparams, mode, features, scope, extra_args)

    # Build the train/eval/infer graph.
    res = self.build_graph(hparams, scope=scope)
    self._set_train_or_infer(res, hparams)

  def _set_params_initializer(self,
                              hparams,
                              mode,
                              features,
                              scope,
                              extra_args=None):
    """Set various params for self and initialize."""
    self.mode = mode
    self.src_vocab_size = hparams.src_vocab_size
    self.tgt_vocab_size = hparams.tgt_vocab_size
    self.features = features
    self.time_major = hparams.time_major

    if hparams.use_char_encode:
      assert (not self.time_major), ("Can't use time major for"
                                     " char-level inputs.")

    self.dtype = tf.float16 if hparams.use_fp16 else tf.float32

    # extra_args: to make it flexible for adding external customizable code.
    self.single_cell_fn = None
    if extra_args:
      self.single_cell_fn = extra_args.single_cell_fn

    # Set num units.
    self.num_units = hparams.num_units
    # Set num layers.
    self.num_encoder_layers = hparams.num_encoder_layers
    self.num_decoder_layers = hparams.num_decoder_layers
    assert self.num_encoder_layers
    assert self.num_decoder_layers

    # Set num residual layers.
    if hasattr(hparams, "num_residual_layers"):  # compatible common_test_utils
      self.num_encoder_residual_layers = hparams.num_residual_layers
      self.num_decoder_residual_layers = hparams.num_residual_layers
    else:
      self.num_encoder_residual_layers = hparams.num_encoder_residual_layers
      self.num_decoder_residual_layers = hparams.num_decoder_residual_layers

    # Batch size
    self.batch_size = tf.size(self.features["source_sequence_length"])

    # Global step
    global_step = tf.train.get_global_step()
    if global_step is not None:
      utils.print_out("global_step already created!")

    self.global_step = tf.train.get_or_create_global_step()
    utils.print_out("model.global_step.name: %s" % self.global_step.name)

    # Initializer
    self.random_seed = hparams.random_seed
    initializer = model_helper.get_initializer(
        hparams.init_op, self.random_seed, hparams.init_weight)
    tf.get_variable_scope().set_initializer(initializer)

    # Embeddings
    self.encoder_emb_lookup_fn = tf.nn.embedding_lookup
    self.init_embeddings(hparams, scope)

  def _set_train_or_infer(self, res, hparams):
    """Set up training or inference depending on mode."""
    loss = res[1]
    if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
      self.train_loss = loss
      self.word_count = tf.reduce_sum(
          self.features["source_sequence_length"]) + tf.reduce_sum(
              self.features["target_sequence_length"])
    elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
      self.eval_loss = loss
    elif self.mode == tf.contrib.learn.ModeKeys.INFER:
      self.infer_logits = res[0]
      self.infer_loss = loss
      self.sample_id = res[2]

    if self.mode != tf.contrib.learn.ModeKeys.INFER:
      ## Count the number of predicted words, used to compute ppl.
      self.predict_count = tf.reduce_sum(
          self.features["target_sequence_length"])

    # Gradients and SGD update operation for training the model.
    # Arrange for the embedding vars to appear at the beginning.
    # Backprop is only built here when training with dist_strategy; otherwise
    # the learning rate, grads and train_op are created in the estimator model
    # function.
    with tf.name_scope("learning_rate"):
      self.learning_rate = tf.constant(hparams.learning_rate)
      # warm-up
      self.learning_rate = self._get_learning_rate_warmup(hparams)
      # decay
      self.learning_rate = self._get_learning_rate_decay(hparams)

    if (hparams.use_dist_strategy and
        self.mode == tf.contrib.learn.ModeKeys.TRAIN):
      # Gradients
      params = tf.trainable_variables()
      # Print trainable variables
      utils.print_out("# Trainable variables")
      utils.print_out(
          "Format: <name>, <shape>, <dtype>, <(soft) device placement>")
      for param in params:
        utils.print_out(
            " %s, %s, %s, %s" % (param.name, str(param.get_shape()),
                                 param.dtype.name, param.op.device))
      # Assumes 4 bytes per parameter (fp32 master weights).
      utils.print_out("Total params size: %.2f GB" % (4. * np.sum([
          p.get_shape().num_elements()
          for p in params
          if p.shape.is_fully_defined()
      ]) / 2**30))

      # Optimizer
      if hparams.optimizer == "sgd":
        opt = tf.train.GradientDescentOptimizer(self.learning_rate)
      elif hparams.optimizer == "adam":
        opt = tf.train.AdamOptimizer(self.learning_rate)
      else:
        raise ValueError("Unknown optimizer type %s" % hparams.optimizer)
      assert opt is not None

      grads_and_vars = opt.compute_gradients(
          self.train_loss,
          params,
          colocate_gradients_with_ops=hparams.colocate_gradients_with_ops)
      gradients = [x for (x, _) in grads_and_vars]

      clipped_grads, grad_norm = model_helper.gradient_clip(
          gradients, max_gradient_norm=hparams.max_gradient_norm)
      self.grad_norm = grad_norm
      self.params = params
      self.grads = clipped_grads

      self.update = opt.apply_gradients(
          list(zip(clipped_grads, params)), global_step=self.global_step)
    else:
      self.grad_norm = None
      self.update = None
      self.params = None
      self.grads = None
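
  # Worked example of the clipping above (illustrative; assumes
  # model_helper.gradient_clip wraps tf.clip_by_global_norm as in the
  # reference NMT code): with max_gradient_norm=5.0 and a global gradient
  # norm of 10.0, every gradient is scaled by 5.0 / 10.0 = 0.5; when the
  # global norm is already <= 5.0, gradients pass through unchanged.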

  def _get_learning_rate_warmup(self, hparams):
    """Get learning rate warmup."""
    warmup_steps = hparams.warmup_steps
    warmup_scheme = hparams.warmup_scheme
    utils.print_out(" learning_rate=%g, warmup_steps=%d, warmup_scheme=%s" %
                    (hparams.learning_rate, warmup_steps, warmup_scheme))
    if not warmup_scheme:
      return self.learning_rate

    # Apply inverse decay if global steps less than warmup steps.
    # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3)
    # When step < warmup_steps,
    #   learning_rate *= warmup_factor ** (warmup_steps - step)
    if warmup_scheme == "t2t":
      # 0.01^(1/warmup_steps): start with a learning rate 100x smaller and
      # warm it up to the configured value over warmup_steps steps.
      warmup_factor = tf.exp(tf.log(0.01) / warmup_steps)
      inv_decay = warmup_factor**(tf.to_float(warmup_steps - self.global_step))
    else:
      raise ValueError("Unknown warmup scheme %s" % warmup_scheme)

    return tf.cond(
        self.global_step < hparams.warmup_steps,
        lambda: inv_decay * self.learning_rate,
        lambda: self.learning_rate,
        name="learning_rate_warmup_cond")
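
  # Worked example of the "t2t" warmup (illustrative numbers): with
  # warmup_steps=200, warmup_factor = exp(log(0.01) / 200) ~= 0.97724, so at
  # step 0 the multiplier is 0.97724**200 = 0.01 (lr starts 100x smaller),
  # at step 100 it is 0.97724**100 ~= 0.1, and it reaches 1.0 at step 200.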

  def _get_decay_info(self, hparams):
    """Return decay info based on decay_scheme."""
    # Computed up front (rather than inside the branch below) so that the
    # no-decay case can also reference num_train_steps.
    epoch_size, _, _ = iterator_utils.get_effective_epoch_size(hparams)
    num_train_steps = int(
        hparams.max_train_epochs * epoch_size / hparams.batch_size)
    if hparams.decay_scheme in [
        "luong5", "luong10", "luong234", "jamesqin1616"
    ]:
      decay_factor = 0.5
      if hparams.decay_scheme == "luong5":
        start_decay_step = int(num_train_steps / 2)
        decay_times = 5
        remain_steps = num_train_steps - start_decay_step
      elif hparams.decay_scheme == "luong10":
        start_decay_step = int(num_train_steps / 2)
        decay_times = 10
        remain_steps = num_train_steps - start_decay_step
      elif hparams.decay_scheme == "luong234":
        start_decay_step = int(num_train_steps * 2 / 3)
        decay_times = 4
        remain_steps = num_train_steps - start_decay_step
      elif hparams.decay_scheme == "jamesqin1616":
        # dehao@ reported that the TPU setting used max_epoch = 2 with
        # luong234, i.e. decay starts after 2 * 2/3 epochs and halves 4 times.
        # If max_epochs is kept at 8, decay should then start at
        # 8 * 2/(3 * 4) epochs and happen 4 * 4 = 16 times.
        decay_times = 16
        start_decay_step = int(num_train_steps / 16.)
        remain_steps = num_train_steps - start_decay_step
      decay_steps = int(remain_steps / decay_times)
    elif not hparams.decay_scheme:  # no decay
      start_decay_step = num_train_steps
      decay_steps = 0
      decay_factor = 1.0
    else:
      raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme)
    return start_decay_step, decay_steps, decay_factor
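
  # Worked example (illustrative numbers): with "luong10" and
  # num_train_steps=100000, decay starts at step 50000 and the remaining
  # 50000 steps are split into 10 intervals of 5000 steps, so the learning
  # rate is halved (decay_factor=0.5) every 5000 steps from step 50000 on.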

  def _get_learning_rate_decay(self, hparams):
    """Get learning rate decay."""
    start_decay_step, decay_steps, decay_factor = self._get_decay_info(hparams)
    utils.print_out(" decay_scheme=%s, start_decay_step=%d, decay_steps %d, "
                    "decay_factor %g" % (hparams.decay_scheme,
                                         start_decay_step,
                                         decay_steps, decay_factor))

    return tf.cond(
        self.global_step < start_decay_step,
        lambda: self.learning_rate,
        lambda: tf.train.exponential_decay(  # pylint: disable=g-long-lambda
            self.learning_rate,
            (self.global_step - start_decay_step),
            decay_steps, decay_factor, staircase=True),
        name="learning_rate_decay_cond")

  def init_embeddings(self, hparams, scope):
    """Init embeddings."""
    self.embedding_encoder, self.embedding_decoder = (
        model_helper.create_emb_for_encoder_and_decoder(
            share_vocab=hparams.share_vocab,
            src_vocab_size=self.src_vocab_size,
            tgt_vocab_size=self.tgt_vocab_size,
            src_embed_size=self.num_units,
            tgt_embed_size=self.num_units,
            dtype=self.dtype,
            num_enc_partitions=hparams.num_enc_emb_partitions,
            num_dec_partitions=hparams.num_dec_emb_partitions,
            src_vocab_file=hparams.src_vocab_file,
            tgt_vocab_file=hparams.tgt_vocab_file,
            src_embed_file=hparams.src_embed_file,
            tgt_embed_file=hparams.tgt_embed_file,
            use_char_encode=hparams.use_char_encode,
            scope=scope))

  def build_graph(self, hparams, scope=None):
    """Build the sequence-to-sequence graph.

    Creates a sequence-to-sequence model with dynamic RNN decoder API.
    Subclasses provide the encoder and decoder cell via _build_encoder and
    _build_decoder_cell.

    Args:
      hparams: Hyperparameter configurations.
      scope: VariableScope for the created subgraph; default "dynamic_seq2seq".

    Returns:
      A tuple of the form (logits, loss, sample_id), where:
        logits: float32 Tensor of size [time, batch_size, vocab_size] when
          time_major=True (a tf.no_op() at inference).
        loss: the total loss / batch_size (0.0 at inference).
        sample_id: sampling indices (None at train/eval).

    Raises:
      ValueError: if encoder_type is neither "mono" nor "bi".
    """
    utils.print_out("# Creating %s graph ..." % self.mode)

    # Projection
    with tf.variable_scope(scope or "build_network"):
      with tf.variable_scope("decoder/output_projection"):
        self.output_layer = tf.layers.Dense(
            self.tgt_vocab_size, use_bias=False, name="output_projection",
            dtype=self.dtype)

    with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype):
      # Encoder
      if hparams.language_model:  # no encoder for language modeling
        utils.print_out(" language modeling: no encoder")
        self.encoder_outputs = None
        encoder_state = None
      else:
        self.encoder_outputs, encoder_state = self._build_encoder(hparams)

      ## Decoder
      logits, sample_id = (
          self._build_decoder(self.encoder_outputs, encoder_state, hparams))

      ## Loss
      if self.mode != tf.contrib.learn.ModeKeys.INFER:
        loss = self._compute_loss(logits, hparams.label_smoothing)
      else:
        loss = tf.constant(0.0)

    return logits, loss, sample_id

  @abc.abstractmethod
  def _build_encoder(self, hparams):
    """Subclass must implement this.

    Build and run an RNN encoder.

    Args:
      hparams: Hyperparameter configurations.

    Returns:
      A tuple of encoder_outputs and encoder_state.
    """
    pass

  def _get_infer_maximum_iterations(self, hparams, source_sequence_length):
    """Maximum decoding steps at inference time."""
    if hparams.tgt_max_len_infer:
      maximum_iterations = hparams.tgt_max_len_infer
      utils.print_out(" decoding maximum_iterations %d" % maximum_iterations)
    else:
      # TODO(thangluong): add decoding_length_factor flag
      decoding_length_factor = 2.0
      max_encoder_length = tf.reduce_max(source_sequence_length)
      maximum_iterations = tf.to_int32(
          tf.round(tf.to_float(max_encoder_length) * decoding_length_factor))
    return maximum_iterations
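
  # Example (illustrative): with no tgt_max_len_infer set and a batch whose
  # longest source sentence has 25 tokens, decoding is capped at
  # round(25 * 2.0) = 50 steps.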

  def _build_decoder(self, encoder_outputs, encoder_state, hparams):
    """Build and run an RNN decoder with a final projection layer.

    Args:
      encoder_outputs: The outputs of encoder for every time step.
      encoder_state: The final state of the encoder.
      hparams: Hyperparameter configurations.

    Returns:
      A tuple of final logits and sample ids:
        logits: size [time, batch_size, vocab_size] when time_major=True
          (a tf.no_op() at inference).
        sample_id: predicted ids at inference, None otherwise.
    """
    ## Decoder.
    with tf.variable_scope("decoder") as decoder_scope:
      ## Train or eval
      if self.mode != tf.contrib.learn.ModeKeys.INFER:
        # [batch, time]
        target_input = self.features["target_input"]
        if self.time_major:
          # In time_major mode, target_input should be [time, batch] so that
          # decoder_emb_inp becomes [time, batch, dim].
          target_input = tf.transpose(target_input)
        decoder_emb_inp = tf.cast(
            tf.nn.embedding_lookup(self.embedding_decoder, target_input),
            self.dtype)

        if not hparams.use_fused_lstm_dec:
          cell, decoder_initial_state = self._build_decoder_cell(
              hparams, encoder_outputs, encoder_state,
              self.features["source_sequence_length"])

          if hparams.use_dynamic_rnn:
            final_rnn_outputs, _ = tf.nn.dynamic_rnn(
                cell,
                decoder_emb_inp,
                sequence_length=self.features["target_sequence_length"],
                initial_state=decoder_initial_state,
                dtype=self.dtype,
                scope=decoder_scope,
                parallel_iterations=hparams.parallel_iterations,
                time_major=self.time_major)
          else:
            final_rnn_outputs, _ = tf.contrib.recurrent.functional_rnn(
                cell,
                decoder_emb_inp,
                sequence_length=tf.to_int32(
                    self.features["target_sequence_length"]),
                initial_state=decoder_initial_state,
                dtype=self.dtype,
                scope=decoder_scope,
                time_major=self.time_major,
                use_tpu=False)
        else:
          if hparams.pass_hidden_state:
            decoder_initial_state = encoder_state
          else:
            decoder_initial_state = tuple(
                tf.nn.rnn_cell.LSTMStateTuple(
                    tf.zeros_like(s[0]), tf.zeros_like(s[1]))
                for s in encoder_state)
          final_rnn_outputs = self._build_decoder_fused_for_training(
              encoder_outputs, decoder_initial_state, decoder_emb_inp,
              self.hparams)

        # We chose to apply the output_layer to all timesteps for speed:
        #   10% improvement for small models & 20% for larger ones.
        # If memory is a concern, we should apply output_layer per timestep.
        logits = self.output_layer(final_rnn_outputs)
        sample_id = None
      ## Inference
      else:
        cell, decoder_initial_state = self._build_decoder_cell(
            hparams, encoder_outputs, encoder_state,
            self.features["source_sequence_length"])

        assert hparams.infer_mode == "beam_search"
        _, tgt_vocab_table = vocab_utils.create_vocab_tables(
            hparams.src_vocab_file, hparams.tgt_vocab_file,
            hparams.share_vocab)
        tgt_sos_id = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
        tgt_eos_id = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
        start_tokens = tf.fill([self.batch_size], tgt_sos_id)
        end_token = tgt_eos_id
        beam_width = hparams.beam_width
        length_penalty_weight = hparams.length_penalty_weight
        coverage_penalty_weight = hparams.coverage_penalty_weight

        my_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=cell,
            embedding=self.embedding_decoder,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=decoder_initial_state,
            beam_width=beam_width,
            output_layer=self.output_layer,
            length_penalty_weight=length_penalty_weight,
            coverage_penalty_weight=coverage_penalty_weight)

        # maximum_iterations: the maximum decoding steps.
        maximum_iterations = self._get_infer_maximum_iterations(
            hparams, self.features["source_sequence_length"])

        # Dynamic decoding
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
            my_decoder,
            maximum_iterations=maximum_iterations,
            output_time_major=self.time_major,
            swap_memory=True,
            scope=decoder_scope)

        # No logits are materialized during beam search; only sample ids.
        logits = tf.no_op()
        sample_id = outputs.predicted_ids

    return logits, sample_id
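
  # Note on shapes (assuming the local BeamSearchDecoder mirrors
  # tf.contrib.seq2seq.BeamSearchDecoder): outputs.predicted_ids is
  # [time, batch_size, beam_width] with output_time_major=True, and
  # [batch_size, time, beam_width] otherwise.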

  def get_max_time(self, tensor):
    """Return the size of the time axis, statically if known."""
    time_axis = 0 if self.time_major else 1
    return tensor.shape[time_axis].value or tf.shape(tensor)[time_axis]

  @abc.abstractmethod
  def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                          source_sequence_length):
    """Subclass must implement this.

    Args:
      hparams: Hyperparameter configurations.
      encoder_outputs: The outputs of encoder for every time step.
      encoder_state: The final state of the encoder.
      source_sequence_length: sequence length of encoder_outputs.

    Returns:
      A tuple of the multi-layer RNN cell used by the decoder and the initial
      state of the decoder RNN.
    """
    pass

  def _softmax_cross_entropy_loss(self, logits, labels, label_smoothing):
    """Compute softmax cross-entropy loss, optionally with label smoothing."""
    # Use .get so a missing env var means False instead of a KeyError.
    use_defun = os.environ.get("use_defun") == "true"
    use_xla = os.environ.get("use_xla") == "true"

    # @function.Defun(noinline=True, compiled=use_xla)
    def ComputePositiveCrossent(labels, logits):
      crossent = math_utils.sparse_softmax_crossent_with_logits(
          labels=labels, logits=logits)
      return crossent

    crossent = ComputePositiveCrossent(labels, logits)
    assert crossent.dtype == tf.float32

    def _safe_shape_div(x, y):
      """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
      return x // tf.maximum(y, 1)

    @function.Defun(tf.float32, tf.float32, compiled=use_xla)
    def ReduceSumGrad(x, grad):
      """Gradient for ReduceSum: tile grad back to the shape of x."""
      input_shape = tf.shape(x)
      # TODO(apassos) remove this once device placement for eager ops makes
      # more sense.
      with tf.colocate_with(input_shape):
        output_shape_kept_dims = math_ops.reduced_shape(input_shape, -1)
        tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
      grad = tf.reshape(grad, output_shape_kept_dims)
      return tf.tile(grad, tile_scaling)

    def ReduceSum(x):
      """Sum over the last axis."""
      return tf.reduce_sum(x, axis=-1)

    if use_defun:
      ReduceSum = function.Defun(
          tf.float32,
          compiled=use_xla,
          noinline=True,
          grad_func=ReduceSumGrad)(ReduceSum)

    if abs(label_smoothing) > 1e-3:
      # pylint:disable=invalid-name
      def ComputeNegativeCrossentFwd(logits):
        """Compute sum_j log p_j from [time, batch, dim] logits."""
        # [time, batch]
        max_logits = tf.reduce_max(logits, axis=-1)
        # [time, batch, dim]
        shifted_logits = logits - tf.expand_dims(max_logits, axis=-1)
        # Always compute loss in fp32.
        shifted_logits = tf.to_float(shifted_logits)
        # [time, batch]
        log_sum_exp = tf.log(ReduceSum(tf.exp(shifted_logits)))
        # [time, batch, dim] - [time, batch, 1] --> reduce_sum(-1) -->
        # [time, batch]
        neg_crossent = ReduceSum(
            shifted_logits - tf.expand_dims(log_sum_exp, axis=-1))
        return neg_crossent

      def ComputeNegativeCrossent(logits):
        return ComputeNegativeCrossentFwd(logits)

      if use_defun:
        ComputeNegativeCrossent = function.Defun(
            compiled=use_xla)(ComputeNegativeCrossent)

      neg_crossent = ComputeNegativeCrossent(logits)
      neg_crossent = tf.to_float(neg_crossent)
      num_labels = logits.shape[-1].value
      crossent = (1.0 - label_smoothing) * crossent - (
          label_smoothing / tf.to_float(num_labels) * neg_crossent)
      # pylint:enable=invalid-name

    return crossent
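
  # Derivation of the smoothed loss above: with smoothing eps and K classes,
  # the target distribution places (1 - eps) on the gold label and spreads
  # eps uniformly as eps / K over all classes, so
  #   loss = (1 - eps) * xent(gold) + (eps / K) * sum_j -log p_j
  #        = (1 - eps) * crossent - (eps / K) * neg_crossent,
  # where neg_crossent = sum_j log p_j as computed above.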

  def _compute_loss(self, logits, label_smoothing):
    """Compute optimization loss."""
    target_output = self.features["target_output"]
    if self.time_major:
      target_output = tf.transpose(target_output)
    max_time = self.get_max_time(target_output)
    self.batch_seq_len = max_time

    crossent = self._softmax_cross_entropy_loss(
        logits, target_output, label_smoothing)
    assert crossent.dtype == tf.float32

    target_weights = tf.sequence_mask(
        self.features["target_sequence_length"], max_time,
        dtype=crossent.dtype)
    if self.time_major:
      # Transpose to [time, batch], since crossent is [time, batch] in this
      # case.
      target_weights = tf.transpose(target_weights)

    loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(
        self.batch_size)

    return loss
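
  # Worked example of the masking above (illustrative): with
  # target_sequence_length=[2, 4] and max_time=4, tf.sequence_mask yields
  #   [[1, 1, 0, 0],
  #    [1, 1, 1, 1]]
  # so padded positions contribute zero to the summed cross-entropy before
  # the division by batch_size.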

  def build_encoder_states(self, include_embeddings=False):
    """Stack encoder states and return tensor [batch, length, layer, size]."""
    assert self.mode == tf.contrib.learn.ModeKeys.INFER
    if include_embeddings:
      stack_state_list = tf.stack(
          [self.encoder_emb_inp] + self.encoder_state_list, 2)
    else:
      stack_state_list = tf.stack(self.encoder_state_list, 2)

    # Transform from [length, batch, ...] -> [batch, length, ...]
    if self.time_major:
      stack_state_list = tf.transpose(stack_state_list, [1, 0, 2, 3])

    return stack_state_list