DeepLearningExamples/FasterTransformer/v3.0/sample/tensorflow/utils/decoding.py
byshiue b2e89e6e80
[FT] FasterTransformer 3.0 Release (#696)
[FT] feat: Add FasterTransformer v3.0

1. Add supporting of INT8 quantization of cpp and TensorFlow op.
2. Provide the tools to quantize the model.
3. Fix the bugs that cmake 3.15 and 3.16 cannot build this project. 
4. Deprecate the FasterTransformer v1
2020-09-23 10:03:37 +08:00

835 lines
51 KiB
Python

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import tensorflow as tf
import os
import pickle
import sys
from utils.decoder import tf_decoder
from utils.decoder import op_decoder
from utils.decoder import init_op_cache
from utils.decoder import init_tf_cache
from utils.common import create_initializer
from utils.common import _get_shape_invariants
from utils.position import SinusoidalPositionEncoder
from utils.beam_search import search_word
from utils.sampling import Sampling
def initialize_decoding_variables(decoding_args, batchxbeam):
start_ids = tf.fill([batchxbeam], decoding_args.start_id) # [batch_size * beam_width]
step = tf.constant(0, dtype=tf.int32)
# save the output ids for each step
outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
cache = init_tf_cache(batchxbeam,
decoding_args.decoder_args.head_num, decoding_args.decoder_args.size_per_head,
decoding_args.decoder_args.num_layer, dtype=decoding_args.decoder_args.dtype, num_sources=1)
finished = tf.zeros([batchxbeam], dtype=tf.bool) # [batch_size * beam_width], record that a sentence is finished or not
initial_log_probs = tf.cast(tf.tile([0.] + [-float("inf")] * (decoding_args.decoder_args.beam_width - 1),
[batchxbeam / decoding_args.decoder_args.beam_width]), dtype=tf.float32) # [batch_size * beam_width]
# [batch_size * beam_width], record the lengths of all sentences
sequence_lengths = tf.zeros([batchxbeam], dtype=tf.int32)
# record the beam search indices, used for rebuild the whole sentence in the final
parent_ids = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
extra_vars = tuple([parent_ids, sequence_lengths])
return start_ids, step, outputs, cache, finished, initial_log_probs, sequence_lengths, extra_vars
def generate_encoder_result(batch_size,
max_seq_len,
memory_hidden_dim,
dtype):
memory_sequence_length = np.random.randint(
1, max_seq_len + 1, size=batch_size).astype(np.int32)
memory_sequence_length[np.random.randint(0, batch_size)] = max_seq_len
outter_embbeding = np.random.randn(memory_hidden_dim) * 0.01
memory = []
mem_max_seq_len = np.max(memory_sequence_length)
for i in range(batch_size):
data = np.random.randn(mem_max_seq_len, memory_hidden_dim) * 0.01
for j in range(memory_sequence_length[i], mem_max_seq_len):
data[j] = outter_embbeding
memory.append(data)
memory = np.asarray(memory)
memory = tf.convert_to_tensor(memory, dtype=dtype)
return memory, memory_sequence_length
def finalize(beam_width, parent_ids, sequence_lengths, outputs, end_id, max_seq_len=None):
maximum_lengths = tf.reduce_max(tf.reshape(
sequence_lengths, [-1, beam_width]), axis=-1)
if max_seq_len != None:
array_shape = [max_seq_len, -1, beam_width]
else:
array_shape = [tf.reduce_max(maximum_lengths), -1, beam_width]
step_ids = tf.reshape(outputs, array_shape)
parent_ids = tf.reshape(parent_ids, array_shape)
ids = tf.contrib.seq2seq.gather_tree(
step_ids, parent_ids, maximum_lengths, end_id)
ids = tf.transpose(ids, perm=[1, 2, 0])
lengths = tf.not_equal(ids, end_id)
lengths = tf.cast(lengths, tf.int32)
lengths = tf.reduce_sum(lengths, axis=-1)
return ids, lengths
def decoding_body(word_ids,
step,
memory,
memory_sequence_length,
my_cache,
op_self_cache,
op_mem_cache,
embedding_table,
decoding_args,
decoder_type):
decoder_args = decoding_args.decoder_args
hidden_dim = decoder_args.hidden_dim
k_init_range = decoder_args.kernel_init_range
data_type = decoder_args.dtype
batchxbeam = tf.shape(word_ids)[0]
# [batch_size * beam_width, hidden_dim]
inputs = tf.nn.embedding_lookup(embedding_table, word_ids)
# [batch_size * beam_width, 1, hidden_dim]
inputs = tf.expand_dims(inputs, 1)
inputs *= hidden_dim**0.5
position_encoder = SinusoidalPositionEncoder()
if position_encoder is not None:
position_encoding_table = position_encoder._create_position_encoding_table(decoding_args.max_seq_len, hidden_dim, data_type)
position_encoding_val = position_encoding_table[step]
position_encoding_val = tf.reshape(position_encoding_val, [1, 1, -1])
position_encoding_val = tf.tile(position_encoding_val, [batchxbeam, 1, 1])
inputs = inputs + position_encoding_val
with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
tf_result = tf_decoder(decoder_args=decoder_args,
inputs=inputs,
memory=memory,
memory_sequence_length=memory_sequence_length,
step=step,
cache=my_cache)
if decoder_type != 0:
decoder_vars = tf.global_variables()
decoder_vars_start_id = 0
while decoder_vars_start_id < len(decoder_vars):
if decoder_vars[decoder_vars_start_id].name.find("transformer/decoder/layer") != -1:
break
decoder_vars_start_id += 1
decoder_vars = decoder_vars[decoder_vars_start_id:]
decoder_var_dict = {}
for v in decoder_vars:
decoder_var_dict[v.name] = v
psuedo_input = []
if decoder_type == 2:
psuedo_input = tf_result
op_result, op_self_cache, op_mem_cache = op_decoder(inputs,
memory,
memory_sequence_length,
op_self_cache,
op_mem_cache,
psuedo_input,
decoder_var_dict,
decoder_args)
result = None
if decoder_type == 0:
result = tf_result
elif decoder_type == 1:
result = op_result
elif decoder_type == 2:
result = tf_result
result_2 = op_result
flatten_result = tf.reshape(result, [-1])
flatten_result_2 = tf.reshape(result_2, [-1])
abs_diff = tf.math.abs(flatten_result - flatten_result_2)
abs_argmax = tf.math.argmax(abs_diff)
result = tf.Print(result, ["[INFO][PYTHON] step:", step,
tf.cond(abs_diff[abs_argmax] / (tf.math.abs(flatten_result[abs_argmax]) + 1e-6) < decoder_args.check_threshold,
lambda: "True", lambda: "False"),
"max abs diff: ", abs_diff[abs_argmax],
" op val: ", flatten_result_2[abs_argmax],
" tf val: ", flatten_result[abs_argmax] ])
else:
print("[TF][ERROR] decoder type is only 0 or 1 or 2.")
exit(-1)
result = tf.contrib.layers.layer_norm(result, begin_norm_axis=-1)
# [batch_size * beam_width, hidden_dim]
result = tf.squeeze(result, axis=1)
logits = tf.layers.dense(result,
decoding_args.vocab_size,
use_bias=True,
bias_initializer=create_initializer(0.0, data_type),
kernel_initializer=create_initializer(k_init_range, data_type),
activation=None)
return logits, my_cache, op_self_cache, op_mem_cache
def tf_beamsearch_decoding(memory_tensor,
memory_sequence_length,
embedding_table,
decoding_args,
decoder_type):
'''
Run the decoding with beam search by TensorFlow.
Args:
memory_tensor: A tf.tensor with shape [batch_size * beam_width, max(memory_sequence_length), encoder_hidden_dimension].
The results of encoder transformer layer. The rank must be 3.
Note that it must be extended by beam_width times.
memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int.
The lenght of each sentence of results of encoder.
Note that it must be extended by beam_width times.
embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension].
The embedding table of embedding lookup for each step.
decoder_args: The arguments for decoding. The details are in the class "DecodingBeamsearchArgument" of common.py
decoder_type: A int value. Choose to using TensorFlow decoder, FasterTransformer decoder, or both.
If it is 0, then using the TensorFlow decoder only.
If it is 1, then using the FasterTransformer decoder only.
If it is 2, then using both decoder and compare their result.
Outputs:
finalized_tf_output_ids: A tf.Tensor with shape [batch_size, beam_width, max(tf_sequence_lengths)], with tf.int type.
Finalized tf_output_ids by beam search algorithm and tf_parent_ids.
finalized_tf_sequence_lengths: A tf.Tensor with shape [batch_size * beam_width], with int type.
Finalized tf_sequence_lengths by beam search algorithm and tf_parent_ids.
tf_output_ids: A tf.Tensor with shape [batch_size, beam_width, max(tf_sequence_lengths)], with tf.int type.
The results of decoding. It contains the id of token of vocabulary.
tf_parent_ids: A tf.Tensor with shape [batch_size, beam_width, max(tf_sequence_lengths)], with tf.int type.
The beam index of output ids for each step.
tf_sequence_lengths: A tf.Tensor with shape [batch_size * beam_width], with int type.
'''
decoder_args = decoding_args.decoder_args
beam_width = decoder_args.beam_width
search_method = decoding_args.search_method
with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE):
# copy memory and memory_sequence_length by beam_width times
# if memory is [a, b, c], beam_width = 3, then the result is: [a a a b b b c c c ]
extended_memory = tf.contrib.seq2seq.tile_batch(memory_tensor, multiplier=beam_width)
extended_memory_sequence_length = tf.contrib.seq2seq.tile_batch(
memory_sequence_length, multiplier=beam_width)
def _cond(word_ids, cum_log_probs, finished, step, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache):
return tf.reduce_any(tf.logical_not(finished))
def _body(word_ids, cum_log_probs, finished, step, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache):
logits, my_cache, op_self_cache, op_mem_cache = decoding_body(word_ids,
step,
extended_memory,
extended_memory_sequence_length,
my_cache,
op_self_cache,
op_mem_cache,
embedding_table,
decoding_args,
decoder_type)
end_ids = tf.fill([tf.shape(logits)[0]], decoding_args.end_id) # [batch_size * beam_width]
eos_max_prob = tf.one_hot(end_ids, decoding_args.vocab_size,
on_value=decoder_args.dtype.max,
off_value=decoder_args.dtype.min) # [batch_size * beam_width, vocab_size]
# [batch_size * beam_width, vocab_size]
logits = tf.where(finished, x=eos_max_prob, y=logits)
logits = tf.cast(logits, tf.float32)
output_id, next_cum_log_probs, finished, my_cache, \
extra_vars, op_self_cache = search_word(beam_width,
decoding_args.vocab_size,
step,
logits,
cum_log_probs,
finished,
my_cache,
extra_vars,
op_self_cache,
search_method=search_method)
cum_log_probs = tf.where(finished, x=cum_log_probs, y=next_cum_log_probs)
outputs = outputs.write(step, output_id)
finished = tf.logical_or(finished, tf.equal(output_id, decoding_args.end_id))
return output_id, cum_log_probs, finished, step + 1, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache
# initialization
batchxbeam = tf.shape(extended_memory)[0]
start_ids, step, outputs, tf_decoder_cache, finished, initial_log_probs, \
tf_sequence_lengths, extra_vars = initialize_decoding_variables(decoding_args, batchxbeam)
word_ids = tf.identity(start_ids, name="word_ids")
cum_log_probs = tf.identity(initial_log_probs, name="cum_log_probs")
# if use_op == False, these two caches are useless
op_self_cache, op_mem_cache = init_op_cache(decoder_args, batchxbeam, tf.reduce_max(memory_sequence_length))
_, _, _, _, outputs, _, extra_vars, _, _ = tf.while_loop(
_cond,
_body,
loop_vars=(
word_ids,
cum_log_probs,
finished,
step,
outputs,
tf_decoder_cache,
extra_vars,
op_self_cache,
op_mem_cache
),
back_prop=False,
maximum_iterations=decoding_args.max_seq_len,
shape_invariants=(
start_ids.shape,
initial_log_probs.shape,
finished.shape,
step.shape,
tf.TensorShape(None),
tf.contrib.framework.nest.map_structure(_get_shape_invariants, tf_decoder_cache),
tf.contrib.framework.nest.map_structure(_get_shape_invariants, extra_vars),
tf.contrib.framework.nest.map_structure(_get_shape_invariants, op_self_cache),
tf.contrib.framework.nest.map_structure(_get_shape_invariants, op_mem_cache))
)
tf_parent_ids = extra_vars[0].stack()
tf_sequence_lengths = extra_vars[1]
tf_output_ids = outputs.stack()
finalized_tf_output_ids, finalized_tf_sequence_lengths = finalize(beam_width,
tf_parent_ids,
tf_sequence_lengths,
tf_output_ids,
decoding_args.end_id)
finalized_tf_output_ids = tf.cast(finalized_tf_output_ids, start_ids.dtype)
finalized_tf_sequence_lengths = tf.minimum(
finalized_tf_sequence_lengths + 1, tf.shape(finalized_tf_output_ids)[2])
return finalized_tf_output_ids, finalized_tf_sequence_lengths, tf_output_ids, tf_parent_ids, tf_sequence_lengths
def tf_sampling_decoding(memory_tensor,
memory_sequence_length,
embedding_table,
decoding_args,
decoder_type):
'''
Run the decoding with sampling by TensorFlow.
Args:
memory_tensor: A tf.tensor with shape [batch_size, max(memory_sequence_length), encoder_hidden_dimension].
The results of encoder transformer layer. The rank must be 3.
memory_sequence_length: A tf.Tensor with shape [batch_size], type tf.int.
The lenght of each sentence of results of encoder.
embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension].
The embedding table of embedding lookup for each step.
decoder_args: The arguments for decoding. The details are in the class "DecodingSamplingArgument" of common.py
decoder_type: A int value. Choose to using TensorFlow decoder, FasterTransformer decoder, or both.
If it is 0, then using the TensorFlow decoder only.
If it is 1, then using the FasterTransformer decoder only.
If it is 2, then using both decoder and compare their result.
Outputs:
tf_output_ids: A tf.Tensor with shape [batch_size, max(sequence_lengths)], with int type.
The results of decoding. It contains the id of token of vocabulary.
sequence_lengths: A tf.Tensor with shape [batch_size], with int type.
'''
decoder_args = decoding_args.decoder_args
with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE):
batch_size = tf.shape(memory_tensor)[0]
def _cond(word_ids, finished, step, outputs, my_cache, sequence_lengths, op_self_cache, op_mem_cache):
return tf.reduce_any(tf.logical_not(finished))
def _body(word_ids, finished, step, outputs, my_cache, sequence_lengths, op_self_cache, op_mem_cache):
logits, my_cache, op_self_cache, op_mem_cache = decoding_body(word_ids,
step,
memory_tensor,
memory_sequence_length,
my_cache,
op_self_cache,
op_mem_cache,
embedding_table,
decoding_args,
decoder_type)
end_ids = tf.fill([batch_size],decoding_args.end_id) # [batch_size * beam_width]
eos_max_prob = tf.one_hot(end_ids, decoding_args.vocab_size,
on_value=decoder_args.dtype.max,
off_value=decoder_args.dtype.min) # [batch_size * beam_width, vocab_size]
# [batch_size, vocab_size]
logits = tf.where(finished, x=eos_max_prob, y=logits)
logits = tf.cast(logits, tf.float32)
# sampling
if decoding_args.top_k != 0:
sampling_method = Sampling("top_k")
output_id = sampling_method.sample(logits, threshold=decoding_args.top_k)
elif decoding_args.top_p != 0.0:
sampling_method = Sampling("top_p")
output_id = sampling_method.sample(logits, threshold=decoding_args.top_p)
sequence_lengths = tf.where(finished, x=sequence_lengths, y=sequence_lengths + 1)
outputs = outputs.write(step, output_id)
finished = tf.logical_or(finished, tf.equal(output_id, decoding_args.end_id))
# return output_id, cum_log_probs, finished, step + 1, outputs, my_cache, extra_vars, op_self_cache, op_mem_cache
return output_id, finished, step + 1, outputs, my_cache, sequence_lengths, op_self_cache, op_mem_cache
# initialization
start_ids, step, outputs, tf_decoder_cache, finished, _, \
_, extra_vars = initialize_decoding_variables(decoding_args, batch_size)
sequence_lengths = extra_vars[1]
word_ids = tf.identity(start_ids, name="word_ids")
# if use_op == False, these two caches are useless
op_self_cache, op_mem_cache = init_op_cache(decoder_args, batch_size, tf.reduce_max(memory_sequence_length))
_, _, _, outputs, _, sequence_lengths, _, _ = tf.while_loop(
_cond,
_body,
loop_vars=(
word_ids,
finished,
step,
outputs,
tf_decoder_cache,
sequence_lengths,
op_self_cache,
op_mem_cache
),
back_prop=False,
maximum_iterations=decoding_args.max_seq_len,
shape_invariants=(
start_ids.shape,
finished.shape,
step.shape,
tf.TensorShape(None),
tf.contrib.framework.nest.map_structure(
_get_shape_invariants, tf_decoder_cache),
tf.contrib.framework.nest.map_structure(
_get_shape_invariants, sequence_lengths),
tf.contrib.framework.nest.map_structure(
_get_shape_invariants, op_self_cache),
tf.contrib.framework.nest.map_structure(_get_shape_invariants, op_mem_cache))
)
tf_output_ids = outputs.stack()
tf_sequence_lengths = sequence_lengths
tf_output_ids = tf.reshape(tf_output_ids, [-1, batch_size])
tf_output_ids = tf.transpose(tf_output_ids, [1, 0])
tf_output_ids = tf.cast(tf_output_ids, start_ids.dtype)
return tf_output_ids, sequence_lengths
def preprocess_decoder_var(decoding_vars,
num_layer,
using_model_var,
checkpoint_filename,
data_type,
fuse_qkv=True):
'''
Args:
decoding_vars: A list of tf.Tensor. The variables of decoding.
num_layer: A int value. The number of transformer layer of decoder in decoding
using_model_var: A bool value. Using the model variables of TensorFlow or not.
If True, then putting the model variables of TensorFlow decoding model into decoding op directly.
The data type is tensor of TensorFlow in this case.
If False, then restoring the values of variables from the checkpoint_filename, and putting
the values into decoding op.
The data type is numpy is this case.
checkpoint_file: A string. The checkpoint file name of storing the values of model. The checkpoint should be stored in
pickle, and the name of checkpoint should be xxx.pkl.
The model is saved by dict.
The key of the dict is the name of variables
The value of the dict is the values of variables
For example, decoding_vars[0]=<tf.Variable 'transformer/decoder/layer_0/masked_multi_head/LayerNorm/beta:0' shape=(512,) dtype=float32_ref>,
then the key is 'transformer/decoder/layer_0/masked_multi_head/LayerNorm/beta:0'; the value is sess.run(decoding_vars[0])
data_type: tf.float32 or tf.float16.
Only used when using_model_var is False. Convert the numpy data to the data type of model.
Outputs:
vars_in_diff_layers_dict: A dict to store the variables by their name.
For decoder variables, the key is like 'transformer/decoder/layer/masked_multi_head/LayerNorm/beta:0',
which is similar to the name of variables, except we use 'layer' but not 'layer_x'. The value is a list,
which contains 'transformer/decoder/layer_%d/masked_multi_head/LayerNorm/beta:0' % i for i in range(num_layer)
For other variables, the key is the name of variable, and the value is the correspoding weight.
Note that we return the concated weights. The concat operation would bring other overhead, and this should be optimized in
the real application. The recommended method is pre-processing the weights as numpy format. Because TensorFlow do the operations
for each inference if using the TensorFlow to pre-process the weights.
'''
var_dict = {}
if using_model_var == False:
# restore the model from the checkpoint file
if(checkpoint_filename == None):
print("[ERROR] checkpoint_filename cannot be None when using_model_var is False.")
exit(-1)
with open(checkpoint_filename, 'rb') as f:
ckpt = pickle.load(f)
for var in decoding_vars:
var_dict[var.name] = ckpt[var.name]
else:
for var in decoding_vars:
var_dict[var.name] = var
vars_in_diff_layers_dict = {}
vars_in_diff_layers_dict["transformer/decoder/LayerNorm/beta:0"] = var_dict["transformer/decoder/LayerNorm/beta:0"]
vars_in_diff_layers_dict["transformer/decoder/LayerNorm/gamma:0"] = var_dict["transformer/decoder/LayerNorm/gamma:0"]
vars_in_diff_layers_dict["transformer/decoder/dense/kernel:0"] = var_dict["transformer/decoder/dense/kernel:0"]
vars_in_diff_layers_dict["transformer/decoder/dense/bias:0"] = tf.cast(var_dict["transformer/decoder/dense/bias:0"], dtype=tf.float32)
for i in range(num_layer):
'''
Handling the names of q, k, v kernel and bias because their names
are different for fusing the qkv or not.
'''
layer_prefix_name = "transformer/decoder/layer_%d/" % i
if fuse_qkv == True:
var_dict[layer_prefix_name + 'masked_multi_head/query/kernel:0'], \
var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0'], \
var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0'] = \
tf.split(var_dict[layer_prefix_name + 'masked_multi_head/conv1d/kernel:0'], 3, axis=-1)
var_dict[layer_prefix_name + 'masked_multi_head/query/bias:0'], \
var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0'], \
var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0'] = \
tf.split(var_dict[layer_prefix_name + 'masked_multi_head/conv1d/bias:0'], 3, axis=-1)
var_dict[layer_prefix_name + 'multi_head/query/kernel:0'] = \
var_dict[layer_prefix_name + 'multi_head/conv1d/kernel:0']
var_dict[layer_prefix_name + 'multi_head/query/bias:0'] = \
var_dict[layer_prefix_name + 'multi_head/conv1d/bias:0']
var_dict[layer_prefix_name + 'multi_head/key/kernel:0'], \
var_dict[layer_prefix_name + 'multi_head/value/kernel:0'] = \
tf.split(var_dict[layer_prefix_name + 'multi_head/conv1d_1/kernel:0'], 2, axis=-1)
var_dict[layer_prefix_name + 'multi_head/key/bias:0'], \
var_dict[layer_prefix_name + 'multi_head/value/bias:0'] = \
tf.split(var_dict[layer_prefix_name + 'multi_head/conv1d_1/bias:0'], 2, axis=-1)
else:
var_dict[layer_prefix_name + 'masked_multi_head/query/kernel:0'] = \
var_dict[layer_prefix_name + 'masked_multi_head/conv1d/kernel:0']
var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0'] = \
var_dict[layer_prefix_name + 'masked_multi_head/key/kernel:0']
var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0'] = \
var_dict[layer_prefix_name + 'masked_multi_head/value/kernel:0']
var_dict[layer_prefix_name + 'masked_multi_head/query/bias:0'] = \
var_dict[layer_prefix_name + 'masked_multi_head/conv1d/bias:0']
var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0'] = \
var_dict[layer_prefix_name + 'masked_multi_head/key/bias:0']
var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0'] = \
var_dict[layer_prefix_name + 'masked_multi_head/value/bias:0']
var_dict[layer_prefix_name + 'multi_head/query/kernel:0'] = \
var_dict[layer_prefix_name + 'multi_head/conv1d/kernel:0']
var_dict[layer_prefix_name + 'multi_head/query/bias:0'] = \
var_dict[layer_prefix_name + 'multi_head/conv1d/bias:0']
var_dict[layer_prefix_name + 'multi_head/key/kernel:0'] = \
var_dict[layer_prefix_name + 'multi_head/conv1d_1/kernel:0']
var_dict[layer_prefix_name + 'multi_head/key/bias:0'] = \
var_dict[layer_prefix_name + 'multi_head/conv1d_1/bias:0']
var_dict[layer_prefix_name + 'multi_head/value/kernel:0'] = \
var_dict[layer_prefix_name + 'multi_head/value/kernel:0']
var_dict[layer_prefix_name + 'multi_head/value/bias:0'] = \
var_dict[layer_prefix_name + 'multi_head/value/bias:0']
layer_prefix_name = 'transformer/decoder/layer'
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/LayerNorm/beta:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/LayerNorm/beta:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/LayerNorm/gamma:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/LayerNorm/gamma:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/query/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/query/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/query/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/query/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/key/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/key/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/key/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/key/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/value/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/value/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/value/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/value/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/conv1d_1/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/conv1d_1/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/masked_multi_head/conv1d_1/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/masked_multi_head/conv1d_1/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/LayerNorm/beta:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/LayerNorm/beta:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/LayerNorm/gamma:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/LayerNorm/gamma:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/query/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/query/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/query/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/query/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/key/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/key/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/key/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/key/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/value/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/value/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/value/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/value/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/conv1d_2/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/conv1d_2/kernel:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/multi_head/conv1d_2/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/multi_head/conv1d_2/bias:0' % i] for i in range(num_layer) ], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/ffn/LayerNorm/beta:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/LayerNorm/beta:0' % i] for i in range(num_layer)], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/ffn/LayerNorm/gamma:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/LayerNorm/gamma:0' % i] for i in range(num_layer)], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/ffn/conv1d/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/conv1d/kernel:0' % i] for i in range(num_layer)], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/ffn/conv1d/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/conv1d/bias:0' % i] for i in range(num_layer)], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/ffn/conv1d_1/kernel:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/conv1d_1/kernel:0' % i] for i in range(num_layer)], axis=0), dtype=data_type)
vars_in_diff_layers_dict[layer_prefix_name + '/ffn/conv1d_1/bias:0'] = \
tf.cast(tf.concat([ var_dict[layer_prefix_name + '_%d/ffn/conv1d_1/bias:0' % i] for i in range(num_layer)], axis=0), dtype=data_type)
return vars_in_diff_layers_dict
def op_beamsearch_decoding(memory_tensor,
memory_sequence_length,
embedding_table,
decoding_vars,
decoding_args,
using_model_var=True,
checkpoint_filename=None):
'''
Run the decoding with beam search by TensorFlow.
Args:
memory_tensor: A tf.tensor with shape [batch_size * beam_width, max(memory_sequence_length), encoder_hidden_dimension].
The results of encoder transformer layer. The rank must be 3.
Note that it must be extended by beam_width times.
memory_sequence_length: A tf.Tensor with shape [batch_size * beam_width], type tf.int.
The lenght of each sentence of results of encoder.
Note that it must be extended by beam_width times.
embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension].
The embedding table of embedding lookup for each step.
decoder_vars: A list of tf.Tensor. The variables for decoding. A list of model variables of TensorFlow model.
decoder_args: The arguments for decoding. The details are in the class "DecodingBeamsearchArgument" of common.py
using_model_var: A bool value. Using the model variables of TensorFlow or not.
The details are described in 'preprocess_decoder_var' function in the following.
checkpoint_filename: A string. The checkpoint file name of storing the values of model.
The details are described in 'preprocess_decoder_var' function in the following.
Outputs:
finalized_output_ids: A tf.Tensor with shape [batch_size, beam_width, max(sequence_lengths)], with tf.int type.
Finalized output_ids by beam search algorithm and parent_ids.
finalized_sequence_lengths: A tf.Tensor with shape [batch_size * beam_width], with int type.
Finalized sequence_lengths by beam search algorithm and parent_ids.
output_ids: A tf.Tensor with shape [batch_size, beam_width, max(sequence_lengths)], with tf.int type.
The results of decoding. It contains the id of token of vocabulary.
parent_ids: A tf.Tensor with shape [batch_size, beam_width, max(sequence_lengths)], with tf.int type.
The beam index of output ids for each step.
sequence_lengths: A tf.Tensor with shape [batch_size * beam_width], with int type.
'''
decoder_args = decoding_args.decoder_args
decoding_op_module = tf.load_op_library(os.path.join('./lib/libtf_decoding_beamsearch.so'))
vars_dict_in_differ_layers = preprocess_decoder_var(decoding_vars,
decoder_args.num_layer,
using_model_var,
checkpoint_filename,
decoder_args.dtype,
decoder_args.fuse_qkv)
extended_memory = tf.contrib.seq2seq.tile_batch(
memory_tensor, multiplier=decoder_args.beam_width)
extended_memory_sequence_length = tf.contrib.seq2seq.tile_batch(
memory_sequence_length, multiplier=decoder_args.beam_width)
position_encoder = SinusoidalPositionEncoder()
position_encoding_table = position_encoder._create_position_encoding_table(
decoding_args.max_seq_len, decoder_args.head_num * decoder_args.size_per_head, decoder_args.dtype)
# shape of position_encoding_table: [max_seq_len, hidden_dim]
output_ids, parent_ids, sequence_lengths = decoding_op_module.decoding(
extended_memory, # 0
extended_memory_sequence_length, # 1
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/LayerNorm/beta:0'], # 2
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/LayerNorm/gamma:0'], # 3
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/query/kernel:0'], # 4
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/query/bias:0'], # 5
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/key/kernel:0'], # 6
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/key/bias:0'], # 7
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/value/kernel:0'], # 8
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/value/bias:0'], # 9
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/conv1d_1/kernel:0'], # 10
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/conv1d_1/bias:0'], # 11
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/LayerNorm/beta:0'], # 12
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/LayerNorm/gamma:0'], # 13
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/query/kernel:0'], # 14
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/query/bias:0'], # 15
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/key/kernel:0'], # 16
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/key/bias:0'], # 17
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/value/kernel:0'], # 18
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/value/bias:0'], # 19
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/conv1d_2/kernel:0'], # 20
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/conv1d_2/bias:0'], # 21
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/LayerNorm/beta:0'], # 22
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/LayerNorm/gamma:0'], # 23
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d/kernel:0'], # 24
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d/bias:0'], # 25
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d_1/kernel:0'], # 26
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d_1/bias:0'], # 27
vars_dict_in_differ_layers['transformer/decoder/LayerNorm/beta:0'], # 28
vars_dict_in_differ_layers['transformer/decoder/LayerNorm/gamma:0'], # 29
embedding_table, # 30
vars_dict_in_differ_layers['transformer/decoder/dense/kernel:0'], # 31
vars_dict_in_differ_layers['transformer/decoder/dense/bias:0'], # 32
position_encoding_table, # 33
beam_width=decoder_args.beam_width,
max_seq_len=decoding_args.max_seq_len,
head_num=decoder_args.head_num,
size_per_head=decoder_args.size_per_head,
num_layer=decoder_args.num_layer,
start_id=decoding_args.start_id,
end_id=decoding_args.end_id,
beam_search_diversity_rate=decoding_args.beam_search_diversity_rate
)
parent_ids = parent_ids % decoder_args.beam_width
finalized_output_ids, finalized_sequence_lengths = finalize(decoder_args.beam_width,
parent_ids,
sequence_lengths,
output_ids,
decoding_args.end_id,
decoding_args.max_seq_len)
finalized_sequence_lengths = tf.minimum(
finalized_sequence_lengths + 1, tf.shape(finalized_output_ids)[2])
return finalized_output_ids, finalized_sequence_lengths, output_ids, parent_ids, sequence_lengths
def op_sampling_decoding(memory_tensor,
memory_sequence_length,
embedding_table,
decoding_vars,
decoding_args,
using_model_var=True,
checkpoint_filename=None):
'''
Run the decoding with sampling by FasterTransformer.
Args:
memory_tensor: A tf.tensor with shape [batch_size, max(memory_sequence_length), encoder_hidden_dimension].
The results of encoder transformer layer. The rank must be 3.
memory_sequence_length: A tf.Tensor with shape [batch_size], type tf.int.
The lenght of each sentence of results of encoder.
embedding_table: A tf.Tensor with shape [vocab_size, hidden_dimension].
The embedding table of embedding lookup for each step.
decoder_vars: A list of tf.Tensor. The variables for decoding. A list of model variables of TensorFlow model.
decoder_args: The arguments for decoding. The details are in the class "DecodingSamplingArgument" of common.py
using_model_var: A bool value. Using the model variables of TensorFlow or not.
The details are described in 'preprocess_decoder_var' function in the following.
checkpoint_filename: A string. The checkpoint file name of storing the values of model.
The details are described in 'preprocess_decoder_var' function in the following.
Outputs:
output_ids: A tf.Tensor with shape [batch_size, max(sequence_lengths)], with int type.
The results of decoding. It contains the id of token of vocabulary.
sequence_lengths: A tf.Tensor with shape [batch_size], with int type.
'''
decoder_args = decoding_args.decoder_args
decoding_op_module = tf.load_op_library(os.path.join('./lib/libtf_decoding_sampling.so'))
vars_dict_in_differ_layers = preprocess_decoder_var(decoding_vars,
decoding_args.decoder_args.num_layer,
using_model_var,
checkpoint_filename,
decoder_args.dtype,
decoder_args.fuse_qkv)
position_encoder = SinusoidalPositionEncoder()
position_encoding_table = position_encoder._create_position_encoding_table(
decoding_args.max_seq_len, decoder_args.head_num * decoder_args.size_per_head, decoder_args.dtype)
# shape of position_encoding_table: [max_seq_len, hidden_dim]
output_ids, sequence_lengths = decoding_op_module.decoding_sampling(
memory_tensor, # 0
memory_sequence_length, # 1
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/LayerNorm/beta:0'], # 2
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/LayerNorm/gamma:0'], # 3
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/query/kernel:0'], # 4
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/query/bias:0'], # 5
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/key/kernel:0'], # 6
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/key/bias:0'], # 7
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/value/kernel:0'], # 8
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/value/bias:0'], # 9
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/conv1d_1/kernel:0'], # 10
vars_dict_in_differ_layers['transformer/decoder/layer/masked_multi_head/conv1d_1/bias:0'], # 11
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/LayerNorm/beta:0'], # 12
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/LayerNorm/gamma:0'], # 13
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/query/kernel:0'], # 14
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/query/bias:0'], # 15
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/key/kernel:0'], # 16
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/key/bias:0'], # 17
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/value/kernel:0'], # 18
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/value/bias:0'], # 19
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/conv1d_2/kernel:0'], # 20
vars_dict_in_differ_layers['transformer/decoder/layer/multi_head/conv1d_2/bias:0'], # 21
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/LayerNorm/beta:0'], # 22
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/LayerNorm/gamma:0'], # 23
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d/kernel:0'], # 24
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d/bias:0'], # 25
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d_1/kernel:0'], # 26
vars_dict_in_differ_layers['transformer/decoder/layer/ffn/conv1d_1/bias:0'], # 27
vars_dict_in_differ_layers['transformer/decoder/LayerNorm/beta:0'], # 28
vars_dict_in_differ_layers['transformer/decoder/LayerNorm/gamma:0'], # 29
embedding_table, # 30
vars_dict_in_differ_layers['transformer/decoder/dense/kernel:0'], # 31
vars_dict_in_differ_layers['transformer/decoder/dense/bias:0'], # 32
position_encoding_table, # 33
max_seq_len=decoding_args.max_seq_len,
candidate_num=decoding_args.top_k,
probability_threshold=decoding_args.top_p,
head_num=decoder_args.head_num,
size_per_head=decoder_args.size_per_head,
num_layer=decoder_args.num_layer,
start_id=decoding_args.start_id,
end_id=decoding_args.end_id
)
batch_size = tf.shape(memory_tensor)[0]
output_ids = tf.reshape(output_ids, [-1, batch_size])
output_ids = tf.transpose(output_ids, [1, 0])
return output_ids, sequence_lengths