# Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Basic sequence-to-sequence model with dynamic RNN support.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import abc import collections import os import tensorflow as tf import numpy as np from tensorflow.python.framework import function from tensorflow.python.ops import math_ops import attention_wrapper import model_helper import beam_search_decoder from utils import iterator_utils from utils import math_utils from utils import misc_utils as utils from utils import vocab_utils utils.check_tensorflow_version() __all__ = ["BaseModel"] def create_attention_mechanism( num_units, memory, source_sequence_length, dtype=None): """Create attention mechanism based on the attention_option.""" # Mechanism attention_mechanism = attention_wrapper.BahdanauAttention( num_units, memory, memory_sequence_length=tf.to_int64(source_sequence_length), normalize=True, dtype=dtype) return attention_mechanism class BaseModel(object): """Sequence-to-sequence base class. """ def __init__(self, hparams, mode, features, scope=None, extra_args=None): """Create the model. Args: hparams: Hyperparameter configurations. mode: TRAIN | EVAL | INFER features: a dict of input features. scope: scope of the model. extra_args: model_helper.ExtraArgs, for passing customizable functions. """ self.hparams = hparams # Set params self._set_params_initializer(hparams, mode, features, scope, extra_args) # Train graph res = self.build_graph(hparams, scope=scope) self._set_train_or_infer(res, hparams) def _set_params_initializer(self, hparams, mode, features, scope, extra_args=None): """Set various params for self and initialize.""" self.mode = mode self.src_vocab_size = hparams.src_vocab_size self.tgt_vocab_size = hparams.tgt_vocab_size self.features = features self.time_major = hparams.time_major if hparams.use_char_encode: assert (not self.time_major), ("Can't use time major for" " char-level inputs.") self.dtype = tf.float16 if hparams.use_fp16 else tf.float32 # extra_args: to make it flexible for adding external customizable code self.single_cell_fn = None if extra_args: self.single_cell_fn = extra_args.single_cell_fn # Set num units self.num_units = hparams.num_units # Set num layers self.num_encoder_layers = hparams.num_encoder_layers self.num_decoder_layers = hparams.num_decoder_layers assert self.num_encoder_layers assert self.num_decoder_layers # Set num residual layers if hasattr(hparams, "num_residual_layers"): # compatible common_test_utils self.num_encoder_residual_layers = hparams.num_residual_layers self.num_decoder_residual_layers = hparams.num_residual_layers else: self.num_encoder_residual_layers = hparams.num_encoder_residual_layers self.num_decoder_residual_layers = hparams.num_decoder_residual_layers # Batch size self.batch_size = tf.size(self.features["source_sequence_length"]) # Global step global_step = tf.train.get_global_step() if global_step is not None: utils.print_out("global_step already created!") self.global_step = tf.train.get_or_create_global_step() utils.print_out("model.global_step.name: %s" % self.global_step.name) # Initializer self.random_seed = hparams.random_seed initializer = model_helper.get_initializer( hparams.init_op, self.random_seed, hparams.init_weight) tf.get_variable_scope().set_initializer(initializer) # Embeddings self.encoder_emb_lookup_fn = tf.nn.embedding_lookup self.init_embeddings(hparams, scope) def _set_train_or_infer(self, res, hparams): """Set up training.""" loss = res[1] if self.mode == tf.contrib.learn.ModeKeys.TRAIN: self.train_loss = loss self.word_count = tf.reduce_sum( self.features["source_sequence_length"]) + tf.reduce_sum( self.features["target_sequence_length"]) elif self.mode == tf.contrib.learn.ModeKeys.EVAL: self.eval_loss = loss elif self.mode == tf.contrib.learn.ModeKeys.INFER: self.infer_logits = res[0] self.infer_loss = loss self.sample_id = res[2] if self.mode != tf.contrib.learn.ModeKeys.INFER: ## Count the number of predicted words for compute ppl. self.predict_count = tf.reduce_sum( self.features["target_sequence_length"]) # Gradients and SGD update operation for training the model. # Arrange for the embedding vars to appear at the beginning. # Only build bprop if running on GPU and using dist_strategy, in which # case learning rate, grads and train_op are created in estimator model # function. with tf.name_scope("learning_rate"): self.learning_rate = tf.constant(hparams.learning_rate) # warm-up self.learning_rate = self._get_learning_rate_warmup(hparams) # decay self.learning_rate = self._get_learning_rate_decay(hparams) if (hparams.use_dist_strategy and self.mode == tf.contrib.learn.ModeKeys.TRAIN): # Gradients params = tf.trainable_variables() # Print trainable variables utils.print_out("# Trainable variables") utils.print_out( "Format: , , , <(soft) device placement>") for param in params: utils.print_out( " %s, %s, %s, %s" % (param.name, str(param.get_shape()), param.dtype.name, param.op.device)) utils.print_out("Total params size: %.2f GB" % (4. * np.sum([ p.get_shape().num_elements() for p in params if p.shape.is_fully_defined() ]) / 2**30)) # Optimizer if hparams.optimizer == "sgd": opt = tf.train.GradientDescentOptimizer(self.learning_rate) elif hparams.optimizer == "adam": opt = tf.train.AdamOptimizer(self.learning_rate) else: raise ValueError("Unknown optimizer type %s" % hparams.optimizer) assert opt is not None grads_and_vars = opt.compute_gradients( self.train_loss, params, colocate_gradients_with_ops=hparams.colocate_gradients_with_ops) gradients = [x for (x, _) in grads_and_vars] clipped_grads, grad_norm = model_helper.gradient_clip( gradients, max_gradient_norm=hparams.max_gradient_norm) self.grad_norm = grad_norm self.params = params self.grads = clipped_grads self.update = opt.apply_gradients( list(zip(clipped_grads, params)), global_step=self.global_step) else: self.grad_norm = None self.update = None self.params = None self.grads = None def _get_learning_rate_warmup(self, hparams): """Get learning rate warmup.""" warmup_steps = hparams.warmup_steps warmup_scheme = hparams.warmup_scheme utils.print_out(" learning_rate=%g, warmup_steps=%d, warmup_scheme=%s" % (hparams.learning_rate, warmup_steps, warmup_scheme)) if not warmup_scheme: return self.learning_rate # Apply inverse decay if global steps less than warmup steps. # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3) # When step < warmup_steps, # learing_rate *= warmup_factor ** (warmup_steps - step) if warmup_scheme == "t2t": # 0.01^(1/warmup_steps): we start with a lr, 100 times smaller warmup_factor = tf.exp(tf.log(0.01) / warmup_steps) inv_decay = warmup_factor**(tf.to_float(warmup_steps - self.global_step)) else: raise ValueError("Unknown warmup scheme %s" % warmup_scheme) return tf.cond( self.global_step < hparams.warmup_steps, lambda: inv_decay * self.learning_rate, lambda: self.learning_rate, name="learning_rate_warump_cond") def _get_decay_info(self, hparams): """Return decay info based on decay_scheme.""" if hparams.decay_scheme in [ "luong5", "luong10", "luong234", "jamesqin1616" ]: epoch_size, _, _ = iterator_utils.get_effective_epoch_size(hparams) num_train_steps = int(hparams.max_train_epochs * epoch_size / hparams.batch_size) decay_factor = 0.5 if hparams.decay_scheme == "luong5": start_decay_step = int(num_train_steps / 2) decay_times = 5 remain_steps = num_train_steps - start_decay_step elif hparams.decay_scheme == "luong10": start_decay_step = int(num_train_steps / 2) decay_times = 10 remain_steps = num_train_steps - start_decay_step elif hparams.decay_scheme == "luong234": start_decay_step = int(num_train_steps * 2 / 3) decay_times = 4 remain_steps = num_train_steps - start_decay_step elif hparams.decay_scheme == "jamesqin1616": # dehao@ reported TPU setting max_epoch = 2 and use luong234. # They start decay after 2 * 2/3 epochs for 4 times. # If keep max_epochs = 8 then decay should start at 8 * 2/(3 * 4) epochs # and for (4 *4 = 16) times. decay_times = 16 start_decay_step = int(num_train_steps / 16.) remain_steps = num_train_steps - start_decay_step decay_steps = int(remain_steps / decay_times) elif not hparams.decay_scheme: # no decay start_decay_step = num_train_steps decay_steps = 0 decay_factor = 1.0 elif hparams.decay_scheme: raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme) return start_decay_step, decay_steps, decay_factor def _get_learning_rate_decay(self, hparams): """Get learning rate decay.""" start_decay_step, decay_steps, decay_factor = self._get_decay_info(hparams) utils.print_out(" decay_scheme=%s, start_decay_step=%d, decay_steps %d, " "decay_factor %g" % (hparams.decay_scheme, start_decay_step, decay_steps, decay_factor)) return tf.cond( self.global_step < start_decay_step, lambda: self.learning_rate, lambda: tf.train.exponential_decay( # pylint: disable=g-long-lambda self.learning_rate, (self.global_step - start_decay_step), decay_steps, decay_factor, staircase=True), name="learning_rate_decay_cond") def init_embeddings(self, hparams, scope): """Init embeddings.""" self.embedding_encoder, self.embedding_decoder = ( model_helper.create_emb_for_encoder_and_decoder( share_vocab=hparams.share_vocab, src_vocab_size=self.src_vocab_size, tgt_vocab_size=self.tgt_vocab_size, src_embed_size=self.num_units, tgt_embed_size=self.num_units, dtype=self.dtype, num_enc_partitions=hparams.num_enc_emb_partitions, num_dec_partitions=hparams.num_dec_emb_partitions, src_vocab_file=hparams.src_vocab_file, tgt_vocab_file=hparams.tgt_vocab_file, src_embed_file=hparams.src_embed_file, tgt_embed_file=hparams.tgt_embed_file, use_char_encode=hparams.use_char_encode, scope=scope, )) def build_graph(self, hparams, scope=None): """Subclass must implement this method. Creates a sequence-to-sequence model with dynamic RNN decoder API. Args: hparams: Hyperparameter configurations. scope: VariableScope for the created subgraph; default "dynamic_seq2seq". Returns: A tuple of the form (logits, loss_tuple, final_context_state, sample_id), where: logits: float32 Tensor [batch_size x num_decoder_symbols]. loss: loss = the total loss / batch_size. final_context_state: the final state of decoder RNN. sample_id: sampling indices. Raises: ValueError: if encoder_type differs from mono and bi, or attention_option is not (luong | scaled_luong | bahdanau | normed_bahdanau). """ utils.print_out("# Creating %s graph ..." % self.mode) # Projection with tf.variable_scope(scope or "build_network"): with tf.variable_scope("decoder/output_projection"): self.output_layer = tf.layers.Dense( self.tgt_vocab_size, use_bias=False, name="output_projection", dtype=self.dtype) with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype): # Encoder if hparams.language_model: # no encoder for language modeling utils.print_out(" language modeling: no encoder") self.encoder_outputs = None encoder_state = None else: self.encoder_outputs, encoder_state = self._build_encoder(hparams) ## Decoder logits, sample_id = ( self._build_decoder(self.encoder_outputs, encoder_state, hparams)) ## Loss if self.mode != tf.contrib.learn.ModeKeys.INFER: loss = self._compute_loss(logits, hparams.label_smoothing) else: loss = tf.constant(0.0) return logits, loss, sample_id @abc.abstractmethod def _build_encoder(self, hparams): """Subclass must implement this. Build and run an RNN encoder. Args: hparams: Hyperparameters configurations. Returns: A tuple of encoder_outputs and encoder_state. """ pass def _get_infer_maximum_iterations(self, hparams, source_sequence_length): """Maximum decoding steps at inference time.""" if hparams.tgt_max_len_infer: maximum_iterations = hparams.tgt_max_len_infer utils.print_out(" decoding maximum_iterations %d" % maximum_iterations) else: # TODO(thangluong): add decoding_length_factor flag decoding_length_factor = 2.0 max_encoder_length = tf.reduce_max(source_sequence_length) maximum_iterations = tf.to_int32( tf.round(tf.to_float(max_encoder_length) * decoding_length_factor)) return maximum_iterations def _build_decoder(self, encoder_outputs, encoder_state, hparams): """Build and run a RNN decoder with a final projection layer. Args: encoder_outputs: The outputs of encoder for every time step. encoder_state: The final state of the encoder. hparams: The Hyperparameters configurations. Returns: A tuple of final logits and final decoder state: logits: size [time, batch_size, vocab_size] when time_major=True. """ ## Decoder. with tf.variable_scope("decoder") as decoder_scope: ## Train or eval if self.mode != tf.contrib.learn.ModeKeys.INFER: # [batch, time] target_input = self.features["target_input"] if self.time_major: # If using time_major mode, then target_input should be [time, batch] # then the decoder_emb_inp would be [time, batch, dim] target_input = tf.transpose(target_input) decoder_emb_inp = tf.cast( tf.nn.embedding_lookup(self.embedding_decoder, target_input), self.dtype) if not hparams.use_fused_lstm_dec: cell, decoder_initial_state = self._build_decoder_cell( hparams, encoder_outputs, encoder_state, self.features["source_sequence_length"]) if hparams.use_dynamic_rnn: final_rnn_outputs, _ = tf.nn.dynamic_rnn( cell, decoder_emb_inp, sequence_length=self.features["target_sequence_length"], initial_state=decoder_initial_state, dtype=self.dtype, scope=decoder_scope, parallel_iterations=hparams.parallel_iterations, time_major=self.time_major) else: final_rnn_outputs, _ = tf.contrib.recurrent.functional_rnn( cell, decoder_emb_inp, sequence_length=tf.to_int32( self.features["target_sequence_length"]), initial_state=decoder_initial_state, dtype=self.dtype, scope=decoder_scope, time_major=self.time_major, use_tpu=False) else: if hparams.pass_hidden_state: decoder_initial_state = encoder_state else: decoder_initial_state = tuple((tf.nn.rnn_cell.LSTMStateTuple( tf.zeros_like(s[0]), tf.zeros_like(s[1])) for s in encoder_state)) final_rnn_outputs = self._build_decoder_fused_for_training( encoder_outputs, decoder_initial_state, decoder_emb_inp, self.hparams) # We chose to apply the output_layer to all timesteps for speed: # 10% improvements for small models & 20% for larger ones. # If memory is a concern, we should apply output_layer per timestep. logits = self.output_layer(final_rnn_outputs) sample_id = None ## Inference else: cell, decoder_initial_state = self._build_decoder_cell( hparams, encoder_outputs, encoder_state, self.features["source_sequence_length"]) assert hparams.infer_mode == "beam_search" _, tgt_vocab_table = vocab_utils.create_vocab_tables( hparams.src_vocab_file, hparams.tgt_vocab_file, hparams.share_vocab) tgt_sos_id = tf.cast( tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32) tgt_eos_id = tf.cast( tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32) start_tokens = tf.fill([self.batch_size], tgt_sos_id) end_token = tgt_eos_id beam_width = hparams.beam_width length_penalty_weight = hparams.length_penalty_weight coverage_penalty_weight = hparams.coverage_penalty_weight my_decoder = beam_search_decoder.BeamSearchDecoder( cell=cell, embedding=self.embedding_decoder, start_tokens=start_tokens, end_token=end_token, initial_state=decoder_initial_state, beam_width=beam_width, output_layer=self.output_layer, length_penalty_weight=length_penalty_weight, coverage_penalty_weight=coverage_penalty_weight) # maximum_iteration: The maximum decoding steps. maximum_iterations = self._get_infer_maximum_iterations( hparams, self.features["source_sequence_length"]) # Dynamic decoding outputs, _, _ = tf.contrib.seq2seq.dynamic_decode( my_decoder, maximum_iterations=maximum_iterations, output_time_major=self.time_major, swap_memory=True, scope=decoder_scope) logits = tf.no_op() sample_id = outputs.predicted_ids return logits, sample_id def get_max_time(self, tensor): time_axis = 0 if self.time_major else 1 return tensor.shape[time_axis].value or tf.shape(tensor)[time_axis] @abc.abstractmethod def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state, source_sequence_length): """Subclass must implement this. Args: hparams: Hyperparameters configurations. encoder_outputs: The outputs of encoder for every time step. encoder_state: The final state of the encoder. source_sequence_length: sequence length of encoder_outputs. Returns: A tuple of a multi-layer RNN cell used by decoder and the initial state of the decoder RNN. """ pass def _softmax_cross_entropy_loss(self, logits, labels, label_smoothing): """Compute softmax loss or sampled softmax loss.""" use_defun = os.environ["use_defun"] == "true" use_xla = os.environ["use_xla"] == "true" # @function.Defun(noinline=True, compiled=use_xla) def ComputePositiveCrossent(labels, logits): crossent = math_utils.sparse_softmax_crossent_with_logits( labels=labels, logits=logits) return crossent crossent = ComputePositiveCrossent(labels, logits) assert crossent.dtype == tf.float32 def _safe_shape_div(x, y): """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`.""" return x // tf.maximum(y, 1) @function.Defun(tf.float32, tf.float32, compiled=use_xla) def ReduceSumGrad(x, grad): """docstring.""" input_shape = tf.shape(x) # TODO(apassos) remove this once device placement for eager ops makes more # sense. with tf.colocate_with(input_shape): output_shape_kept_dims = math_ops.reduced_shape(input_shape, -1) tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims) grad = tf.reshape(grad, output_shape_kept_dims) return tf.tile(grad, tile_scaling) def ReduceSum(x): """docstring.""" return tf.reduce_sum(x, axis=-1) if use_defun: ReduceSum = function.Defun( tf.float32, compiled=use_xla, noinline=True, grad_func=ReduceSumGrad)(ReduceSum) if abs(label_smoothing) > 1e-3: # pylint:disable=invalid-name def ComputeNegativeCrossentFwd(logits): """docstring.""" # [time, batch, dim] # [time, batch] max_logits = tf.reduce_max(logits, axis=-1) # [time, batch, dim] shifted_logits = logits - tf.expand_dims(max_logits, axis=-1) # Always compute loss in fp32 shifted_logits = tf.to_float(shifted_logits) # [time, batch] log_sum_exp = tf.log(ReduceSum(tf.exp(shifted_logits))) # [time, batch, dim] - [time, batch, 1] --> reduce_sum(-1) --> # [time, batch] neg_crossent = ReduceSum( shifted_logits - tf.expand_dims(log_sum_exp, axis=-1)) return neg_crossent def ComputeNegativeCrossent(logits): return ComputeNegativeCrossentFwd(logits) if use_defun: ComputeNegativeCrossent = function.Defun( compiled=use_xla)(ComputeNegativeCrossent) neg_crossent = ComputeNegativeCrossent(logits) neg_crossent = tf.to_float(neg_crossent) num_labels = logits.shape[-1].value crossent = (1.0 - label_smoothing) * crossent - ( label_smoothing / tf.to_float(num_labels) * neg_crossent) # pylint:enable=invalid-name return crossent def _compute_loss(self, logits, label_smoothing): """Compute optimization loss.""" target_output = self.features["target_output"] if self.time_major: target_output = tf.transpose(target_output) max_time = self.get_max_time(target_output) self.batch_seq_len = max_time crossent = self._softmax_cross_entropy_loss( logits, target_output, label_smoothing) assert crossent.dtype == tf.float32 target_weights = tf.sequence_mask( self.features["target_sequence_length"], max_time, dtype=crossent.dtype) if self.time_major: # [time, batch] if time_major, since the crossent is [time, batch] in this # case. target_weights = tf.transpose(target_weights) loss = tf.reduce_sum(crossent * target_weights) / tf.to_float( self.batch_size) return loss def build_encoder_states(self, include_embeddings=False): """Stack encoder states and return tensor [batch, length, layer, size].""" assert self.mode == tf.contrib.learn.ModeKeys.INFER if include_embeddings: stack_state_list = tf.stack( [self.encoder_emb_inp] + self.encoder_state_list, 2) else: stack_state_list = tf.stack(self.encoder_state_list, 2) # transform from [length, batch, ...] -> [batch, length, ...] if self.time_major: stack_state_list = tf.transpose(stack_state_list, [1, 0, 2, 3]) return stack_state_list