[FasterTransformer] Add translation_sample, fix some bugs.

bhsueh 2020-03-05 11:21:22 +08:00
parent b162523c38
commit bd89bca344
26 changed files with 38797 additions and 202 deletions

View file

@ -37,6 +37,16 @@ FasterTransformer V1 will be deprecated on July 2020.
### Changelog
March 2020
- Add features in FasterTransformer 2.0
- Fix the bug that the maximum sequence length of the decoder cannot be larger than 128.
- Add `translate_sample.py` to demonstrate how to translate a sentence by restoring the pretrained model of OpenNMT-tf.
- Fix the bug that decoding does not check whether it has finished after each step.
- Fix the decoder bug related to max_seq_len.
- Modify the decoding model structure to fit the OpenNMT-tf decoding model.
- Add a layer normalization layer after the decoder.
- Add a normalization for the inputs of the decoder.
February 2020
* Release the FasterTransformer 2.0
* Provide a highly optimized OpenNMT-tf based decoder and decoding, including C++ API and TensorFlow OP.

View file

@ -23,6 +23,7 @@ This repository provides a script and recipe to run the highly optimized transfo
* [Inference process](#inference-process)
* [Encoder process](#encoder-process)
* [Decoder and Decoding process](#decoder-and-decoding-process)
* [Translation process](#translation-process)
- [Performance](#performance)
* [Encoder performance](#encoder-performance)
* [Decoder performance on T4](#decoder-performance-on-t4)
@ -51,7 +52,7 @@ FasterTransformer is built on top of CUDA and cuBLAS, providing the C++ API and
The following configurations are supported in the FasterTransformer encoder.
- Batch size (B<sub>1</sub>): smaller than or equal to 512
- Sequence length (S): smaller than or equal to 128
- Sequence length (S): larger than 3 and smaller than or equal to 1024
- Head number (H) and size per head (N):
- 12 heads * 64 per head
- 4 heads * 32 per head
@ -60,7 +61,7 @@ The following configurations are supported in the FasterTransformer encoder.
The following configurations are supported in the FasterTransformer decoder and decoding.
- Batch size (B<sub>1</sub>) * beam width (B<sub>2</sub>): smaller than 1024
- Sequence length (S): smaller than or equal to 128
- Sequence length (S): smaller than 1024
- Head number (H): 8 and 12
- Size per head (N): 64
- Vocabulary size (V): from 64 to 30000
@ -154,10 +155,9 @@ nvidia-docker run -ti nvcr.io/nvidia/tensorflow:19.07-py2 bash
```bash
git clone https://github.com/NVIDIA/DeepLearningExamples
cd DeepLearningExamples
cd DeepLearningExamples/FasterTransformer/v2
git submodule init
git submodule update
cd FasterTransformer/v2
```
3. Build the project.
@ -356,6 +356,7 @@ The `sample/` folder contains useful sample codes for FasterTransformer:
* `sample/tensorflow/decoding_sample.py` - TensorFlow decoding sample codes
* `sample/tensorflow/encoder_decoder_sample.py` - TensorFlow `encoder_decoder` sample codes
* `sample/tensorflow/encoder_decoding_sample.py` - TensorFlow `encoder_decoding` sample codes
* `sample/tensorflow/translate_sample.py` - TensorFlow translation sample codes
### Command-line options
@ -367,6 +368,7 @@ python decoder_sample.py --help
python decoding_sample.py --help
python encoder_decoder_sample.py --help
python encoder_decoding_sample.py --help
python translate_sample.py --help
```
### Inference process
@ -540,14 +542,16 @@ python decoder_sample.py \
The outputs should be similar to the following:
```bash
[[INFO][PYTHON] step:][1][max diff: ][9.77516174e-06][True]
[[INFO][PYTHON] step:][2][max diff: ][1.04904175e-05][True]
[[INFO][PYTHON] step:][0][max diff: ][5.00679e-06][ op val: ][2.3735888][ tf val: ][2.37359381][True]
[[INFO][PYTHON] step:][1][max diff: ][4.64916229e-06][ op val: ][-0.588810563][ tf val: ][-0.588815212][True]
[[INFO][PYTHON] step:][2][max diff: ][5.36441803e-06][ op val: ][-1.46514082][ tf val: ][-1.46514618][True]
...
[[INFO][PYTHON] step:][31][max diff: ][1.21593475e-05][True]
[[INFO][PYTHON] step:][32][max diff: ][1.04382634e-05][True]
[[INFO][PYTHON] step:][29][max diff: ][4.529953e-06][ op val: ][2.88768935][ tf val: ][2.88769388][True]
[[INFO][PYTHON] step:][30][max diff: ][4.17232513e-06][ op val: ][-1.28717053][ tf val: ][-1.2871747][True]
[[INFO][PYTHON] step:][31][max diff: ][4.05311584e-06][ op val: ][-1.01830876][ tf val: ][-1.01831281][True]
```
The results show that the differences between the decoder of TensorFlow and decoder are smaller than threshold.
The results show that the differences between the TensorFlow decoder and the FasterTransformer decoder are smaller than the threshold. Note that the differences are absolute differences, so they may be large when the op val is large. In that case, the differences exceed the threshold and the check returns "False", but this may not affect the final results.
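The caveat about absolute differences can be illustrated with a small sketch. The threshold value and the helper names `abs_check` and `rel_check` are assumptions made for illustration; they are not taken from `decoder_sample.py`. The first value pair comes from the step 0 line of the sample output, and the second pair is the same pair scaled by 10^4.

```python
import math

THRESHOLD = 1e-5  # hypothetical threshold, not the value used by the sample


def abs_check(op_val, tf_val, threshold=THRESHOLD):
    """Absolute-difference check, the kind of check the note describes."""
    return abs(op_val - tf_val) < threshold


def rel_check(op_val, tf_val, rel_tol=1e-5):
    """Relative-difference check, shown only for comparison."""
    return math.isclose(op_val, tf_val, rel_tol=rel_tol, abs_tol=0.0)


small = (2.3735888, 2.37359381)   # op val / tf val from step 0 above
large = (23735.888, 23735.9381)   # same relative error, values 1e4 times larger

print(abs_check(*small), rel_check(*small))
print(abs_check(*large), rel_check(*large))
```

With these assumed tolerances, the large pair fails the absolute check while still passing the relative one, which is the situation the note warns about.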
The option `decoder_type` selects whether to use the TensorFlow decoder or the FasterTransformer decoder. `decoder_type 2` uses both decoders and compares their results.
@ -606,15 +610,13 @@ python decoding_sample.py \
The outputs should be similar to the following:
```bash
[INFO] Before finalize:
result before finalize cross-check: True
Output ids cross-check: True
Parent ids cross-check: True
sequence lengths cross-check: True
Sequence lengths cross-check: True
[INFO] After finalize:
result after cross-check: True
Finalized output ids cross-check: True
```
Note that the results of OP and the results of TensorFlow often differ when the inputs and weights are random.
@ -635,6 +637,34 @@ python encoder_decoding_sample.py \
--data_type fp32
```
#### Translation process
For translation, we need to use some tools and libraries of OpenNMT-tf to preprocess the source sentence and build the encoder.
Because the encoder of FasterTransformer is based on BERT, it cannot restore the pretrained OpenNMT-tf model, so translation requires the encoder of OpenNMT-tf.
1. Prepare the pretrained model and the data for translation.
```bash
bash utils/translation/download_model_data.sh
```
`download_model_data.sh` will prepare the `opennmt` folder, which contains the input embedding and the encoder, download the pretrained model, and download the test data into the `translation` folder. This is because the encoder of FasterTransformer is based on BERT rather than OpenNMT-tf, so we cannot restore the pretrained model of OpenNMT-tf for the encoder. Therefore, translation requires the encoder of OpenNMT-tf.
Another problem is that the implementations of our tf_decoding and OpenNMT-tf decoding are slightly different. For example, OpenNMT-tf uses a single GEMM to compute the query, key, and value at once, whereas tf_decoding splits them into three GEMMs. So, the tool `utils/dump_model.py` converts the pretrained model to fit the model structure of the FasterTransformer decoder.
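The difference between one fused GEMM and three separate GEMMs can be sketched as follows. This is a minimal illustration, not the logic of `utils/dump_model.py`: the function names are hypothetical, and the layout of the fused kernel (shape `[hidden, 3 * hidden]` with query, key, and value concatenated along the output axis in that order) is an assumption.

```python
def split_fused_qkv(fused_kernel, hidden_units):
    """Split an assumed [hidden, 3 * hidden] kernel into Q, K and V kernels."""
    q = [row[:hidden_units] for row in fused_kernel]
    k = [row[hidden_units:2 * hidden_units] for row in fused_kernel]
    v = [row[2 * hidden_units:] for row in fused_kernel]
    return q, k, v


def matvec(x, w):
    """Compute x @ w for a vector x of length `in` and a matrix w of shape [in][out]."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


hidden_units = 2
fused = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
         [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]
x = [1.0, 2.0]

q_kernel, k_kernel, v_kernel = split_fused_qkv(fused, hidden_units)

# One fused GEMM versus three separate GEMMs on the split kernels.
fused_out = matvec(x, fused)
split_out = matvec(x, q_kernel) + matvec(x, k_kernel) + matvec(x, v_kernel)
print(fused_out == split_out)
```

Under the assumed layout, concatenating the three separate products gives the same values as the fused product, so a converter only has to slice the weights.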
```bash
./bin/decoding_gemm 1 4 8 64 32001 100 512 0
python translate_sample.py
```
The outputs should be similar to the following:
```bash
[INFO] opennmt: ▁28 - jährige r ▁Chef koch ▁to t ▁in ▁San ▁Francisco </s>
[INFO] tf : ▁28 - jährige r ▁Chef koch ▁to t ▁in ▁San ▁Francisco </s>
[INFO] op : ▁28 - jährige r ▁Chef koch ▁to t ▁in ▁San ▁Francisco </s>
```
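The output lines above are subword pieces rather than plain text. A minimal sketch of turning them back into a sentence is shown below. It assumes SentencePiece-style pieces, where the `▁` marker denotes the start of a word and `</s>` ends the sentence; the `detokenize` helper is hypothetical and is not part of `translate_sample.py`.

```python
WORD_START = "▁"  # word-boundary marker as it appears in the sample output


def detokenize(pieces, eos="</s>"):
    """Join subword pieces into a sentence, dropping everything from `eos` on."""
    if eos in pieces:
        pieces = pieces[:pieces.index(eos)]
    return "".join(pieces).replace(WORD_START, " ").strip()


op_output = "▁28 - jährige r ▁Chef koch ▁to t ▁in ▁San ▁Francisco </s>"
sentence = detokenize(op_output.split())
print(sentence)
```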
## Performance
Hardware settings:
@ -752,6 +782,16 @@ bash scripts/profile_decoding_op_performance.sh
### Changelog
March 2020
- Add features in FasterTransformer 2.0
- Fix the bug that the maximum sequence length of the decoder cannot be larger than 128.
- Add `translate_sample.py` to demonstrate how to translate a sentence by restoring the pretrained model of OpenNMT-tf.
- Fix the bug that decoding does not check whether it has finished after each step.
- Fix the decoder bug related to max_seq_len.
- Modify the decoding model structure to fit the OpenNMT-tf decoding model.
- Add a layer normalization layer after the decoder.
- Add a normalization for the inputs of the decoder.
February 2020
- Release the FasterTransformer 2.0
- Provide a highly optimized OpenNMT-tf based decoder and decoding, including C++ API and TensorFlow op.
@ -764,10 +804,8 @@ July 2019
### Known issues
- sequence length of Decoder and Decoding should be smaller than 128.
- batch_size should be smaller than 1024 in Decoder.
- batch_size x beam_width should be smaller than 1024 in Decoding.
- Results of TensorFlow and OP would be different in decoding. This problem is caused by the accumulated log probability, and we do not avoid this problem.
- batch_size should be smaller than or equal to 1024 in Decoder.
- batch_size x beam_width should be smaller than or equal to 1024 in Decoding.
- Results of TensorFlow and OP would be different in decoding. This problem is caused by the accumulated log probability, and we do not avoid this problem.
- Cmake 15 or Cmake 16 fail to build this project. Cmake 14 is no problem.
- Max sequence length of encoder and decoder should be the same.
- Max sequence length of encoder and decoder should be the same.

View file

@ -104,7 +104,7 @@ public:
auto flat = buf.flat<uint8>();
void *ptr = (void *)flat.data();
cudaMemset(ptr, 0, size);
cudaMemset(ptr, 0, buf_size);
return ptr;
}

View file

@ -40,7 +40,9 @@ void BeamSearch_OpenNMT(
int *output_ids,
const int batch_size, const int beam_width,
const int vocab_size, const int hidden_dim, const int step,
const int cache_size, const int decoder_layers, cudaStream_t stream)
const int cache_size, const int decoder_layers, cudaStream_t stream,
const int end_id,
int *finished_count)
{
#ifdef NDEBUG
/* adding cum_log_probs to log_probs */
@ -75,11 +77,15 @@ void BeamSearch_OpenNMT(
#endif
#ifdef NDEBUG
update(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length, word_ids, output_ids,
batch_size, beam_width, vocab_size, stream);
update(log_probs, cum_log_probs, ids, finished,
parent_ids, sequence_length, word_ids, output_ids,
batch_size, beam_width, vocab_size, stream,
end_id, finished_count);
#else
update(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length, word_ids, output_ids,
batch_size, beam_width, vocab_size, stream);
update(log_probs, cum_log_probs, ids, finished,
parent_ids, sequence_length, word_ids, output_ids,
batch_size, beam_width, vocab_size, stream,
end_id, finished_count);
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
@ -89,13 +95,17 @@ void BeamSearch_OpenNMT(
Note that update_kernel_check contains update, so users do not need to call it again.
*/
// update_kernel_check(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length, word_ids, output_ids,
// batch_size, beam_width, vocab_size, stream);
// batch_size, beam_width, vocab_size, stream, end_id, finished_count);
#endif
#ifdef NDEBUG
update_KV_cache<T>(key_cache, value_cache, parent_ids, batch_size, beam_width, hidden_dim, step, cache_size, decoder_layers, stream);
update_KV_cache<T>(key_cache, value_cache, parent_ids, batch_size,
beam_width, hidden_dim, step, cache_size,
decoder_layers, stream);
#else
update_KV_cache<T>(key_cache, value_cache, parent_ids, batch_size, beam_width, hidden_dim, step, cache_size, decoder_layers, stream);
update_KV_cache<T>(key_cache, value_cache, parent_ids, batch_size,
beam_width, hidden_dim, step, cache_size,
decoder_layers, stream);
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());

View file

@ -69,13 +69,12 @@ T blockReduceSum(T val)
if(lane == 0)
shared[wid] = val;
__syncthreads();
val = (threadIdx.x < (blockDim.x >> 5 )) ? shared[lane] : (T)0.0f;
val = warpReduceSum(val);
return val;
}
template <typename T>
__inline__ __device__
T warpReduceMax(T val)
@ -386,12 +385,15 @@ void topK(const float* log_probs, int* ids, const int batch_size, const int beam
template <typename T>
__global__
void update_kernel(T* log_probs, T* cum_log_probs, int* ids, bool* finished, int* parent_ids, int* sequence_length,
int* word_ids, int* output_ids,
const int batch_size, const int beam_width, const int vocab_size)
void update_kernel(T* log_probs, T* cum_log_probs,
int* ids, bool* finished,
int* parent_ids, int* sequence_length,
int* word_ids, int* output_ids,
const int batch_size, const int beam_width,
const int vocab_size, const int end_id,
int* finished_count)
{
int tid = threadIdx.x;
sequence_length[tid] = finished[tid] ? sequence_length[tid] : sequence_length[tid] + 1;
int beam_id = ids[tid];
@ -401,10 +403,14 @@ void update_kernel(T* log_probs, T* cum_log_probs, int* ids, bool* finished, int
cum_log_probs[tid] = log_probs[ids[tid]];
sequence_length[tid] = sequence_length[beam_id];
finished[tid] = finished[beam_id];
finished[tid] = word_id == end_id ? 1 : 0;
parent_ids[tid] = beam_id;
word_ids[tid] = word_id;
output_ids[tid] = word_id;
// TODO use reduce sum to compute how many sentences are finished
// int fi = finished[tid]
// int total_finish = reduceSum(fi);
}
template <typename T>
@ -415,9 +421,13 @@ __global__ void embedding_lookup_kernel(const T* embedding_table, const int* wor
from_tensor[write_pos] = embedding_table[word_ids[blockIdx.x] * hidden_units + threadIdx.x];
}
void update(float* log_probs, float* cum_log_probs, int* ids, bool* finished, int* parent_ids, int* sequence_length,
int* word_ids, int* output_ids,
const int batch_size, const int beam_width, const int vocab_size, cudaStream_t stream)
void update(float* log_probs, float* cum_log_probs,
int* ids, bool* finished,
int* parent_ids, int* sequence_length,
int* word_ids, int* output_ids,
const int batch_size, const int beam_width,
const int vocab_size, cudaStream_t stream,
const int end_id, int* finished_count)
{
dim3 grid(1);
@ -425,9 +435,11 @@ void update(float* log_probs, float* cum_log_probs, int* ids, bool* finished, in
assert(block.x <= 1024);
update_kernel<float><<<grid, block, 0, stream>>>(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length,
word_ids, output_ids,
batch_size, beam_width, vocab_size);
update_kernel<float><<<grid, block, 0, stream>>>(log_probs, cum_log_probs, ids,
finished, parent_ids, sequence_length,
word_ids, output_ids, batch_size,
beam_width, vocab_size, end_id,
finished_count);
}
template <typename T>
@ -565,14 +577,17 @@ __global__
void sine_position_encoder_kernel(T* output, int step, int n){
int tid = threadIdx.x;
int bid = blockIdx.x;
int half_n = n / 2;
float half_n = (float)n / 2.;
float log_timescale_increment = __logf(10000) / (( half_n - 1) * 1.f);
float inv_timescales = __expf( (tid % half_n) * -1 * log_timescale_increment );
// input = input * hidden_dim**0.5
output[bid * n + tid] = output[bid * n + tid] * (T)sqrtf(float(n));
float log_timescale_increment = __logf(10000) / (half_n - 1.f);
float inv_timescales = __expf( (tid % (int)half_n) * -1 * log_timescale_increment );
float scaled_time = inv_timescales * step;
T encoding_val = (tid < half_n) ? (T) __sinf(scaled_time) : (T) __cosf(scaled_time);
output[bid * n + tid] = output[bid * n + tid] + encoding_val;
output[bid * n + tid] = output[bid * n + tid] + encoding_val;
}
template<typename T>
@ -584,7 +599,7 @@ void sine_position_encoder(
dim3 grid(m);
dim3 block(n);
assert(n <= 1024);
sine_position_encoder_kernel<T><<<grid, block, 0, stream>>>(output, step + 1, n);
sine_position_encoder_kernel<T><<<grid, block, 0, stream>>>(output, step, n);
}
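As a reading aid for the kernel change above, here is a host-side sketch of the same arithmetic for one row of the output. It follows the lines of the hunk that scale the input by the square root of the hidden dimension and use a floating-point `half_n`. The function name is hypothetical, and the sketch assumes `n` is even and at least 4 so that `half_n - 1` is non-zero.

```python
import math


def sine_position_encoding(hidden, step):
    """Scale one row by sqrt(n) and add the sinusoidal encoding for `step`."""
    n = len(hidden)
    half_n = n // 2  # assumes n is even
    log_timescale_increment = math.log(10000.0) / (half_n - 1)
    out = []
    for tid, x in enumerate(hidden):
        inv_timescale = math.exp(-(tid % half_n) * log_timescale_increment)
        scaled_time = inv_timescale * step
        # First half of the row uses sine, second half uses cosine.
        enc = math.sin(scaled_time) if tid < half_n else math.cos(scaled_time)
        out.append(x * math.sqrt(n) + enc)
    return out


row = sine_position_encoding([1.0, 1.0, 1.0, 1.0], step=0)
print(row)
```

At `step=0` the sine terms are 0 and the cosine terms are 1, which makes the effect of the launcher passing `step` instead of `step + 1` easy to check by hand.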
template void add_bias_act_kernelLauncher<float>(

View file

@ -16,38 +16,53 @@
#pragma once
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace fastertransformer{
template <typename T>
void add_bias_act_kernelLauncher(T* out, const T* bias, int m, int n, cudaStream_t stream);
template <typename T>
void add_bias_input_layernorm_kernelLauncher(T* out, const T* input_tensor, const T* bias,
const T* gamma, const T* beta, int m, int n, cudaStream_t stream);
void add_bias_input_layernorm_kernelLauncher(T* out, const T* input_tensor,
const T* bias, const T* gamma,
const T* beta, int m, int n,
cudaStream_t stream);
void broadcast_kernelLauncher(float* log_probs, float* cum_log_probs, const int batch_size, const int beam_width,
const int vocab_size, cudaStream_t stream);
void broadcast_kernelLauncher(float* log_probs, float* cum_log_probs,
const int batch_size, const int beam_width,
const int vocab_size, cudaStream_t stream);
void topK(const float* log_probs, int* ids, const int batch_size, const int beam_width, const int vocab_size,
cudaStream_t stream);
void topK(const float* log_probs, int* ids, const int batch_size,
const int beam_width, const int vocab_size, cudaStream_t stream);
void update(float* log_probs, float* cum_log_probs, int* ids, bool* finished, int* parent_ids, int* sequence_length,
int* word_ids, int* output_ids,
const int batch_size, const int beam_width,
const int vocab_size, cudaStream_t stream);
void update(float* log_probs, float* cum_log_probs, int* ids,
bool* finished, int* parent_ids, int* sequence_length,
int* word_ids, int* output_ids,
const int batch_size, const int beam_width,
const int vocab_size, cudaStream_t stream,
const int end_id,
int* finished_count);
template <typename T>
void embedding_lookup(const T* embedding_table, const int* word_ids, T* from_tensor, const int batch_size, const int beam_width,
const int hidden_units, cudaStream_t stream);
void embedding_lookup(const T* embedding_table, const int* word_ids,
T* from_tensor, const int batch_size,
const int beam_width, const int hidden_units,
cudaStream_t stream);
void update_logits(float* logits, const float* bias, const int end_ids, const bool* finished, const int m, const int n, cudaStream_t stream);
void update_logits(float* logits, const float* bias, const int end_ids,
const bool* finished, const int m, const int n,
cudaStream_t stream);
void init(bool* finished, int* sequence_length, int* word_ids, float* cum_log_probs, const int sentence_id, const int batch_size,
const int beam_width, cudaStream_t stream);
void init(bool* finished, int* sequence_length, int* word_ids,
float* cum_log_probs, const int sentence_id,
const int batch_size, const int beam_width, cudaStream_t stream);
template <typename T>
void update_KV_cache(T** key_cache, T** value_cache, const int* beam_ids, const int batch_size, const int beam_width,
const int hidden_dim, const int step, const int cache_size, const int decoder_layers, cudaStream_t stream);
void update_KV_cache(T** key_cache, T** value_cache, const int* beam_ids,
const int batch_size, const int beam_width,
const int hidden_dim, const int step,
const int cache_size, const int decoder_layers,
cudaStream_t stream);
template <typename T>
void sine_position_encoder(T* output, int step, int m, int n, cudaStream_t stream);

View file

@ -310,7 +310,8 @@ void topK_kernel_check(const float *log_probs, int *ids, const int batch_size, c
void update_kernel_check(float *log_probs, float *cum_log_probs, int *ids, bool *finished, int *parent_ids, int *sequence_length,
int *word_ids, int *output_ids,
const int batch_size, const int beam_width,
const int vocab_size, cudaStream_t stream)
const int vocab_size, cudaStream_t stream,
const int end_id, int* finished_count)
{
printf("[INFO] decoding update check. \n");
@ -346,7 +347,7 @@ void update_kernel_check(float *log_probs, float *cum_log_probs, int *ids, bool
// compute on GPU and copy to GPU output
update(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length, word_ids, output_ids,
batch_size, beam_width, vocab_size, stream);
batch_size, beam_width, vocab_size, stream, end_id, finished_count);
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());

View file

@ -79,7 +79,8 @@ void topK_kernel_check(const float* log_probs, int* ids, const int batch_size, c
void update_kernel_check(float* log_probs, float* cum_log_probs, int* ids, bool* finished, int* parent_ids, int* sequence_length,
int* word_ids, int* output_ids,
const int batch_size, const int beam_width,
const int vocab_size, cudaStream_t stream);
const int vocab_size, cudaStream_t stream,
const int end_id, int* finished_count);
template <typename T>
void update_KV_cache_kernel_check(T** key_cache, T** value_cache, const int* beam_ids, const int batch_size, const int beam_width, const int hidden_dim,

View file

@ -41,6 +41,8 @@ public:
const T *memory_tensor;
const int *memory_sequence_length;
LayerNormWeight<T> layernorm;
int *output_ids;
int *parent_ids;
int *sequence_length;
@ -77,6 +79,7 @@ private:
DataType_ **V_mem_cache_;
DataType_ *from_tensor_[2];
DataType_ *decoder_buf_;
DataType_ *decoder_normed_result_buf_;
float *logits_buf_;
float *cum_log_buf_;
int *word_ids_buf_;
@ -85,18 +88,18 @@ private:
void *buf_;
int start_id_;
int end_id_;
int memory_hidden_units_;
int *finished_count_buf_;
bool *h_finished_buf_;
public:
DecodingOpenNMT(const IAllocator &allocator, const int batch_size,
const int beam_width, const int seq_len,
const int head_num, const int size_per_head,
const int vocab_size, const int decoder_layers,
const int memory_hidden_units,
const int memory_hidden_units, const int memory_max_seq_len,
const int start_id, const int end_id) : allocator_(allocator), batch_size_(batch_size), beam_width_(beam_width),
seq_len_(seq_len), head_num_(head_num), size_per_head_(size_per_head),
vocab_size_(vocab_size), decoder_layers_(decoder_layers),
memory_hidden_units_(memory_hidden_units),
start_id_(start_id), end_id_(end_id)
{
#ifndef NDEBUG
@ -109,16 +112,19 @@ public:
V_mem_cache_ = new DataType_ *[decoder_layers_];
hidden_units_ = head_num_ * size_per_head_;
decoder_ = new OpenDecoder<OpType_>(allocator, batch_size * beam_width, seq_len, head_num, size_per_head, memory_hidden_units_);
decoder_ = new OpenDecoder<OpType_>(allocator, batch_size * beam_width, memory_max_seq_len,
head_num, size_per_head, memory_hidden_units);
int from_tensor_size = batch_size_ * beam_width_ * hidden_units_; // type T
int decoder_workspace_size = decoder_->getWorkspaceSize(); // type T
int decoder_normed_result_buffer_size = batch_size_ * beam_width_ * hidden_units_; // type T
int cache_size = batch_size_ * beam_width_ * seq_len_ * hidden_units_; // type T
int logits_buf_size = batch_size_ * beam_width_ * vocab_size_; // type T
int logits_buf_size = batch_size_ * beam_width_ * vocab_size_; // type float
int cum_log_buf_size = batch_size_ * beam_width_; // type float
int word_ids_buf_size = batch_size_ * beam_width_; //type int
int finished_buf_size = batch_size_ * beam_width_; //type bool
int finished_count_size = (int)(ceil(1 / 4.)) * 4; // type int
//type int
int topk_ids_buf_size = batch_size_ * beam_width_ * (ceil)((beam_width_ * vocab_size_ * 1.0) / 1024.0);
@ -127,16 +133,18 @@ public:
word_ids_buf_size = (int)(ceil(word_ids_buf_size / 4.)) * 4;
finished_buf_size = (int)(ceil(finished_buf_size / 32.)) * 32;
topk_ids_buf_size = (int)(ceil(topk_ids_buf_size / 4.)) * 4;
int datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size +
cache_size * 6 * decoder_layers_;
cache_size * 6 * decoder_layers_ + decoder_normed_result_buffer_size;
buf_ = reinterpret_cast<void *>(allocator_.malloc(
sizeof(DataType_) * datatype_buf_size +
sizeof(float) * (logits_buf_size + cum_log_buf_size) +
sizeof(int) * word_ids_buf_size +
sizeof(bool) * finished_buf_size +
sizeof(int) * topk_ids_buf_size));
sizeof(int) * topk_ids_buf_size +
sizeof(int) * finished_count_size ));
from_tensor_[0] = (DataType_ *)buf_;
from_tensor_[1] = (DataType_ *)(from_tensor_[0] + from_tensor_size);
@ -154,11 +162,15 @@ public:
V_cache_[1] = V_mem_cache_[decoder_layers - 1] + cache_size + 3 * cache_size * decoder_layers_;
decoder_buf_ = V_cache_[1] + cache_size * decoder_layers_;
logits_buf_ = (float *)(decoder_buf_ + decoder_workspace_size);
decoder_normed_result_buf_ = (decoder_buf_ + decoder_workspace_size);
logits_buf_ = (float *)(decoder_normed_result_buf_ + decoder_normed_result_buffer_size);
cum_log_buf_ = (float *)(logits_buf_ + logits_buf_size);
word_ids_buf_ = (int *)(cum_log_buf_ + cum_log_buf_size);
finished_buf_ = (bool *)(word_ids_buf_ + word_ids_buf_size);
topk_ids_buf_ = (int *)(finished_buf_ + finished_buf_size);
finished_count_buf_ = (int *)(topk_ids_buf_ + topk_ids_buf_size);
h_finished_buf_ = new bool[finished_buf_size];
FILE *fd = fopen("decoding_gemm_config.in", "r");
int err = 0;
@ -304,6 +316,8 @@ public:
check_cuda_error(cudaGetLastError());
#endif
}
decoder_->decoder_norm1(from_tensor_[out_id], decoding_params.layernorm.gamma,
decoding_params.layernorm.beta, decoder_normed_result_buf_, m, k);
float alpha = (float)1.0f;
float beta = (float)0.0f;
@ -313,7 +327,7 @@ public:
n, m, k,
&alpha,
decoding_params.embedding_kernel, AType_, n,
from_tensor_[out_id], BType_, k,
decoder_normed_result_buf_, BType_, k,
&beta,
logits_buf_, CUDA_R_32F, n,
CUDA_R_32F,
@ -342,17 +356,28 @@ public:
logits_buf_, cum_log_buf_, finished_buf_,
K_cache_,
V_cache_,
decoding_params.parent_ids + step * batch_size_ * beam_width_,
decoding_params.parent_ids + (step - 1) * batch_size_ * beam_width_,
decoding_params.sequence_length,
word_ids_buf_,
topk_ids_buf_,
decoding_params.output_ids + step * batch_size_ * beam_width_,
batch_size_, beam_width_, vocab_size_, hidden_units_, step, cache_size, decoder_layers_, decoding_params.stream);
decoding_params.output_ids + (step - 1) * batch_size_ * beam_width_,
batch_size_, beam_width_, vocab_size_, hidden_units_, step, cache_size, decoder_layers_, decoding_params.stream,
end_id_,
finished_count_buf_);
#ifndef NDEBUG
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
#endif
// TODO
// Find a better method to check the is_finished
cudaMemcpy(h_finished_buf_, finished_buf_, sizeof(bool) * batch_size_ * beam_width_, cudaMemcpyDeviceToHost);
int sum = 0;
for(int i = 0; i < batch_size_ * beam_width_; i++){
sum += (int)h_finished_buf_[i];
}
if(sum == batch_size_ * beam_width_) break;
}
}
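The early-exit check added to the decoding loop can be sketched on the host as follows. This is an illustration under stated assumptions: `run_decoding` and the scripted per-step word ids are hypothetical, and each beam is marked finished when its word id for the current step equals `end_id`, mirroring the `finished[tid] = word_id == end_id` line in `update_kernel`.

```python
def run_decoding(step_outputs, end_id):
    """Return how many steps run before every beam has emitted end_id."""
    steps_run = 0
    for word_ids in step_outputs:
        steps_run += 1
        finished = [word_id == end_id for word_id in word_ids]
        # Same test as the C++ loop: the sum of the flags equals the beam count.
        if sum(int(f) for f in finished) == len(finished):
            break
    return steps_run


# Two beams, four scripted steps; both beams emit end_id (2) at the third step.
scripted = [[5, 7], [2, 9], [2, 2], [4, 4]]
steps = run_decoding(scripted, end_id=2)
print(steps)
```

The C++ version copies the flags from device to host on every step, which is what the TODO about finding a better method refers to.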

View file

@ -97,9 +97,9 @@ public:
PRINT_FUNC_NAME_();
#endif
hidden_units_ = head_num_ * size_per_head_;
if (max_seq_len_ > 128 || batch_size_ > 1024 || hidden_units_ > 1024)
if (batch_size_ > 1024 || hidden_units_ > 1024)
{
printf("[ERROR] Decoder does not support max_seq_len_ > 128 or batch_size_ > 1024 or hidden_units_ > 1024. \n");
printf("[ERROR] Decoder does not support batch_size_ > 1024 or hidden_units_ > 1024. \n");
exit(-1);
}

View file

@ -66,14 +66,12 @@ REGISTER_OP("Decoder")
.Output("new_self_cache: T")
.Output("new_mem_cache: T")
.Attr("T: {float, half}")
.Attr("max_seq_len: int >= 1")
.Attr("head_num: int >= 1")
.Attr("size_per_head: int >= 1")
.Attr("memory_hidden_dim: int >= 1")
.SetShapeFn([](shape_inference::InferenceContext *c) {
c->set_output(0, c->input(0));
c->set_output(1, c->input(1));
c->set_output(2, c->input(2));
c->set_output(1, c->input(29));
c->set_output(2, c->input(30));
return Status::OK();
});
template <typename Device, typename T>
@ -82,15 +80,18 @@ class DecoderOp : public CommonOp<T>
public:
explicit DecoderOp(OpKernelConstruction *context) : CommonOp<T>(context)
{
OP_REQUIRES_OK(context, context->GetAttr("max_seq_len", &max_seq_len_));
OP_REQUIRES_OK(context, context->GetAttr("head_num", &head_num_));
OP_REQUIRES_OK(context, context->GetAttr("size_per_head", &size_per_head_));
OP_REQUIRES_OK(context, context->GetAttr("memory_hidden_dim", &memory_hidden_dim_));
}
void Compute(OpKernelContext *context) override
{
batch_size_ = (int)context->input(0).dim_size(0);
// input(1): memory_tensor: [batch_size, memory_max_seq_len, memory_hidden_dim]
assert((int)(context->input(1).dims()) == 3);
const int batch_size_ = (int)context->input(1).dim_size(0);
const int max_seq_len_ = (int)context->input(1).dim_size(1);
const int memory_hidden_dim_ = (int)context->input(1).dim_size(2);
typedef DecoderTransformerTraits<traits_::OpType> DecoderTraits_;
OpenDecoder<DecoderTraits_::OpType> *decoder_;
fastertransformer::Allocator<AllocatorType::TF> allocator_(context);
@ -168,7 +169,7 @@ public:
DataType_ *K_mem_cache = memory_cache;
DataType_ *V_mem_cache = memory_cache + batch_size_ * max_seq_len_ * head_num_ * size_per_head_;
const int decoder_buffer_size = decoder_->getWorkspaceSize() * sizeof(DataType_);
DataType_ *decoder_buffer = (DataType_ *)allocator_.malloc(sizeof(DataType_) * decoder_buffer_size);
DataType_ *decoder_buffer = (DataType_ *)allocator_.malloc(decoder_buffer_size);
OP_REQUIRES_OK(
context,
@ -187,7 +188,7 @@ public:
}
private:
int batch_size_, max_seq_len_, head_num_, size_per_head_, memory_hidden_dim_;
int head_num_, size_per_head_;
typedef TFTraits<T> traits_;
typedef typename traits_::DataType DataType_;
};

View file

@ -58,6 +58,8 @@ REGISTER_OP("Decoding")
.Input("ffn_kernel2: T")
.Input("ffn_bias2: T")
.Input("embedding_table: T")
.Input("decoding_beta: T")
.Input("decoding_gamma: T")
.Input("embedding_kernel: T")
.Input("embedding_bias: float32")
.Output("output_ids: int32")
@ -79,8 +81,8 @@ REGISTER_OP("Decoding")
c->GetAttr("batch_size", &batch_size);
c->GetAttr("beam_width", &beam_width);
c->GetAttr("max_seq_len", &max_seq_len);
c->set_output(0, c->MakeShape({batch_size * beam_width * (max_seq_len + 1)}));
c->set_output(1, c->MakeShape({batch_size * beam_width * (max_seq_len + 1)}));
c->set_output(0, c->MakeShape({batch_size * beam_width * max_seq_len}));
c->set_output(1, c->MakeShape({batch_size * beam_width * max_seq_len}));
c->set_output(2, c->MakeShape({batch_size * beam_width}));
return Status::OK();
});
@ -96,7 +98,6 @@ public:
OP_REQUIRES_OK(context, context->GetAttr("head_num", &head_num_));
OP_REQUIRES_OK(context, context->GetAttr("size_per_head", &size_per_head_));
OP_REQUIRES_OK(context, context->GetAttr("num_layer", &num_layer_));
OP_REQUIRES_OK(context, context->GetAttr("memory_hidden_dim", &memory_hidden_dim_));
OP_REQUIRES_OK(context, context->GetAttr("vocab_size", &vocab_size_));
OP_REQUIRES_OK(context, context->GetAttr("start_id", &start_id_));
OP_REQUIRES_OK(context, context->GetAttr("end_id", &end_id_));
@ -104,17 +105,22 @@ public:
void Compute(OpKernelContext *context) override
{
// input(0): memory_tensor: [batch_size * beam_width, memory_max_seq_len, memory_hidden_dim]
assert((int)(context->input(0).dims()) == 3);
const int memory_max_seq_len = (int)context->input(0).dim_size(1);
const int memory_hidden_dim_ = (int)context->input(0).dim_size(2);
DecodingInitParam<DataType_> decoding_params;
decoding_params.cublas_handle = this->get_cublas_handler();
Tensor *output_ids = nullptr;
OP_REQUIRES_OK(
context,
context->allocate_output(0, {max_seq_len_ + 1, batch_size_ * beam_width_}, &output_ids));
context->allocate_output(0, {max_seq_len_, batch_size_ * beam_width_}, &output_ids));
Tensor *parent_ids = nullptr;
OP_REQUIRES_OK(
context,
context->allocate_output(1, {max_seq_len_ + 1, batch_size_ * beam_width_}, &parent_ids));
context->allocate_output(1, {max_seq_len_, batch_size_ * beam_width_}, &parent_ids));
Tensor *sequence_length = nullptr;
OP_REQUIRES_OK(
@ -125,8 +131,8 @@ public:
decoding_params.parent_ids = reinterpret_cast<int *>(parent_ids->flat<int>().data());
decoding_params.sequence_length = reinterpret_cast<int *>(sequence_length->flat<int>().data());
check_cuda_error(cudaMemset(decoding_params.output_ids, 0, sizeof(int) * (max_seq_len_ + 1) * batch_size_ * beam_width_));
check_cuda_error(cudaMemset(decoding_params.parent_ids, 0, sizeof(int) * (max_seq_len_ + 1) * batch_size_ * beam_width_));
check_cuda_error(cudaMemset(decoding_params.output_ids, 0, sizeof(int) * max_seq_len_ * batch_size_ * beam_width_));
check_cuda_error(cudaMemset(decoding_params.parent_ids, 0, sizeof(int) * max_seq_len_ * batch_size_ * beam_width_));
check_cuda_error(cudaMemset(decoding_params.sequence_length, 0, sizeof(int) * batch_size_ * beam_width_));
typedef DecoderTransformerTraits<traits_::OpType> DecodingTraits_;
@ -138,7 +144,7 @@ public:
allocator_, batch_size_, beam_width_,
max_seq_len_, head_num_, size_per_head_,
vocab_size_, num_layer_,
memory_hidden_dim_,
memory_hidden_dim_, memory_max_seq_len,
start_id_, end_id_);
}
catch (std::runtime_error &error)
@ -146,7 +152,7 @@ public:
OP_REQUIRES(context, false, errors::Internal(error.what()));
}
OP_REQUIRES(context, context->num_inputs() == 31, errors::InvalidArgument("[ERROR] Less or more input arguments"));
OP_REQUIRES(context, context->num_inputs() == 33, errors::InvalidArgument("[ERROR] Less or more input arguments"));
this->get_tensor(context, 0, &decoding_params.memory_tensor);
decoding_params.memory_sequence_length = reinterpret_cast<const int *>(context->input(1).flat<int>().data());
@ -185,10 +191,12 @@ public:
this->get_tensor(context, 27, &params[i].ffn.output_weight.bias, i * hidden_unit);
}
this->get_tensor(context, 28, &decoding_params.embedding_table);
this->get_tensor(context, 29, &decoding_params.embedding_kernel);
this->get_tensor(context, 28, &decoding_params.layernorm.beta);
this->get_tensor(context, 29, &decoding_params.layernorm.gamma);
this->get_tensor(context, 30, &decoding_params.embedding_table);
this->get_tensor(context, 31, &decoding_params.embedding_kernel);
decoding_params.embedding_bias = reinterpret_cast<const float *>(context->input(30).flat<float>().data());
decoding_params.embedding_bias = reinterpret_cast<const float *>(context->input(32).flat<float>().data());
OP_REQUIRES(context, decoding_params.embedding_bias != nullptr, errors::InvalidArgument("embedding_bias is null"));
OP_REQUIRES_OK(


@ -96,6 +96,7 @@ void decoding_sample(int batch_size,
int memory_hidden_units)
{
const int max_seq_len = seq_len;
const int memory_seq_len = seq_len;
const int start_id = 1;
const int end_id = 2;
const int hidden_units = head_num * size_per_head;
@ -123,7 +124,7 @@ void decoding_sample(int batch_size,
T *d_self_gamma, *d_self_beta;
T *d_cross_gamma, *d_cross_beta;
T *d_ffn_gamma, *d_ffn_beta;
device_malloc(&d_self_Q_kernel, sizeof(T) * hidden_units * hidden_units);
device_malloc(&d_self_K_kernel, sizeof(T) * hidden_units * hidden_units);
device_malloc(&d_self_V_kernel, sizeof(T) * hidden_units * hidden_units);
@ -194,15 +195,18 @@ void decoding_sample(int batch_size,
int* d_parent_ids;
int* d_sequence_lengths;
int* d_memory_sequence_lengths;
T *d_gamma, *d_beta;
device_malloc(&d_memory_tensor, sizeof(T) * hidden_units * seq_len * batch_size * beam_width);
device_malloc(&d_embedding_table, sizeof(T) * hidden_units * vocab_size);
device_malloc(&d_embedding_kernel, sizeof(T) * vocab_size * hidden_units);
check_cuda_error(cudaMalloc((void**)&d_embedding_bias, sizeof(float) * vocab_size));
check_cuda_error(cudaMalloc((void**)&d_output_ids, sizeof(int) * (max_seq_len + 1) * batch_size * beam_width));
check_cuda_error(cudaMalloc((void**)&d_parent_ids, sizeof(int) * (max_seq_len + 1) * batch_size * beam_width));
check_cuda_error(cudaMalloc((void**)&d_output_ids, sizeof(int) * (max_seq_len) * batch_size * beam_width));
check_cuda_error(cudaMalloc((void**)&d_parent_ids, sizeof(int) * (max_seq_len) * batch_size * beam_width));
check_cuda_error(cudaMalloc((void**)&d_sequence_lengths, sizeof(int) * batch_size * beam_width));
check_cuda_error(cudaMalloc((void**)&d_memory_sequence_lengths, sizeof(int) * batch_size * beam_width));
device_malloc(&d_gamma, sizeof(T) * hidden_units);
device_malloc(&d_beta, sizeof(T) * hidden_units);
int *h_memory_sequence_lengths = new int[batch_size * beam_width];
for(int i = 0; i < batch_size * beam_width; i++) h_memory_sequence_lengths[i] = seq_len;
@ -218,6 +222,8 @@ void decoding_sample(int batch_size,
decoding_params.parent_ids = d_parent_ids;
decoding_params.sequence_length = d_sequence_lengths;
decoding_params.memory_sequence_length = d_memory_sequence_lengths;
decoding_params.layernorm.gamma = d_gamma;
decoding_params.layernorm.beta = d_beta;
const fastertransformer::OperationType type = sizeof(T) == sizeof(float) ? OperationType::FP32 : OperationType::FP16;
@ -225,7 +231,8 @@ void decoding_sample(int batch_size,
DecodingOpenNMT<type>(allocator, batch_size, beam_width,
max_seq_len, head_num, size_per_head,
vocab_size, decoder_layers,
memory_hidden_units, start_id, end_id);
memory_hidden_units, memory_seq_len,
start_id, end_id);
//warm up
int ite = 100;


@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
@ -58,7 +58,6 @@ if __name__ == "__main__":
np.random.seed(1)
tf.set_random_seed(1)
kernel_initializer_range = 0.02
initializer_range = 0.02
bias_initializer_range = 0.02
batch_size = args.batch_size
@ -90,8 +89,8 @@ if __name__ == "__main__":
encoder_hidden_dim=memory_hidden_dim,
dtype=tf_datatype)
embedding_table = np.random.rand(vocab_size, hidden_dim).astype(
np_datatype) # a [vocab_size, hidden_dim] table
embedding_table = np.random.randn(vocab_size, hidden_dim).astype(
np_datatype) * 0.01 # a [vocab_size, hidden_dim] table
embedding_table = tf.convert_to_tensor(embedding_table)
memory, memory_sequence_length = generate_encoder_result(
batch_size, max_seq_len, memory_hidden_dim, tf_datatype)


@ -121,7 +121,7 @@ if __name__ == "__main__":
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())
if args.cross_check == 1:
finalized_tf_output_ids_result, tf_output_ids_result, tf_parent_ids_result, \
tf_sequence_lengths_result = sess.run(
@ -130,15 +130,15 @@ if __name__ == "__main__":
op_sequence_lengths_result = sess.run(
[finalized_op_output_ids, op_output_ids, op_parent_ids, op_sequence_lengths])
int_result_cross_check("Output ids",
tf_output_ids_result, op_output_ids_result, shape=[1 + max_seq_len, batch_size * beam_width])
int_result_cross_check("Parent ids",
tf_parent_ids_result, op_parent_ids_result, shape=[1 + max_seq_len, batch_size * beam_width])
int_result_cross_check("Sequence lengths",
tf_sequence_lengths_result, op_sequence_lengths_result, shape=[1, batch_size * beam_width])
int_result_cross_check("Finalized output ids",
finalized_tf_output_ids_result.T, finalized_op_output_ids_result.T,
shape=[1 + max_seq_len, batch_size * beam_width])
int_result_cross_check("Output ids", tf_output_ids_result, op_output_ids_result,
shape=[max_seq_len, batch_size * beam_width])
int_result_cross_check("Parent ids", tf_parent_ids_result, op_parent_ids_result,
shape=[max_seq_len, batch_size * beam_width])
int_result_cross_check("Sequence lengths", tf_sequence_lengths_result,
op_sequence_lengths_result, shape=[1, batch_size * beam_width])
int_result_cross_check("Finalized output ids", finalized_tf_output_ids_result.T,
finalized_op_output_ids_result.T,
shape=[max_seq_len, batch_size * beam_width])
if args.test_time == 1 or args.test_tf_time == 1 or args.test_op_time == 1:
if args.test_time == 1 or args.test_tf_time == 1:


@ -132,7 +132,7 @@ if __name__ == "__main__":
all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
decoder_var_start_id = 0
while all_vars[decoder_var_start_id].name.find("transformer/decoder") == -1:
while all_vars[decoder_var_start_id].name.find("transformer/decoding") == -1:
decoder_var_start_id += 1
encoder_variables = all_vars[:decoder_var_start_id]
decoder_variables = all_vars[decoder_var_start_id:]


@ -0,0 +1,260 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import tensorflow as tf
import numpy as np
import os
import math
import six
import argparse
from utils.common import time_test, DecodingArgument, int_result_cross_check, TransformerArgument
from utils.decoding import tf_decoding, generate_encoder_result, op_decoding
from utils.encoder import tf_encoder, op_encoder
from utils.decoder import build_sequence_mask
from opennmt.utils import misc
from opennmt.utils import compat
from opennmt.encoders.self_attention_encoder import SelfAttentionEncoder
from opennmt.decoders.self_attention_decoder import SelfAttentionDecoder
from opennmt.inputters import WordEmbedder
from opennmt.inputters import ExampleInputter
def resotre_model_by_pkl(sess, variables):
import pickle as pkl
with open("model.pkl", 'rb') as model_file:
model_dict = pkl.load(model_file)
assign_op_list = []
for var in variables:
print(var.name, end=' ')
if var.name in model_dict:
print("restore", end=' ')
assign_op_list.append(tf.assign(var, np.reshape(model_dict[var.name], var.shape)))
print("mean: {} , std: {} . ".format(np.mean(model_dict[var.name]), np.std(model_dict[var.name])), end=' ')
print()
assert(len(assign_op_list) == len(variables))
sess.run(assign_op_list)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-batch', '--batch_size', type=int, default=1, metavar='NUMBER',
help='batch size (default: 1)')
parser.add_argument('-beam', '--beam_width', type=int, default=4, metavar='NUMBER',
help='beam width (default: 4)')
parser.add_argument('-s', '--max_seq_len', type=int, default=200, metavar='NUMBER',
help='max sequence length (default: 200)')
parser.add_argument('-encoder_head', '--encoder_head_number', type=int, default=8, metavar='NUMBER',
help='encoder head number (default: 8)')
parser.add_argument('-encoder_size', '--encoder_size_per_head', type=int, default=64, metavar='NUMBER',
help='encoder size per head (default: 64)')
parser.add_argument('-decoder_head', '--decoder_head_number', type=int, default=8, metavar='NUMBER',
help='decoder head number (default: 8)')
parser.add_argument('-decoder_size', '--decoder_size_per_head', type=int, default=64, metavar='NUMBER',
help='decoder size per head (default: 64)')
parser.add_argument('-encoder_layer', '--encoder_num_layer', type=int, default=6, metavar='NUMBER',
help='number of layers (default: 6)')
parser.add_argument('-decoder_layer', '--decoder_num_layer', type=int, default=6, metavar='NUMBER',
help='number of layers (default: 6)')
parser.add_argument('-d', '--data_type', type=str, default="fp32", metavar='STRING',
help='data type (default: fp32)')
args = parser.parse_args()
print("\n=============== Argument ===============")
print(args)
start_of_sentence_id = 1
end_of_sentence_id = 2
np.random.seed(1)
tf.set_random_seed(1)
kernel_initializer_range = 0.02
initializer_range = 0.02
bias_initializer_range = 0.02
batch_size = args.batch_size
beam_width = args.beam_width
max_seq_len = args.max_seq_len
encoder_head_num = args.encoder_head_number
encoder_size_per_head = args.encoder_size_per_head
decoder_head_num = args.decoder_head_number
decoder_size_per_head = args.decoder_size_per_head
encoder_num_layer = args.encoder_num_layer
decoder_num_layer = args.decoder_num_layer
encoder_hidden_dim = encoder_head_num * encoder_size_per_head
decoder_hidden_dim = decoder_head_num * decoder_size_per_head
tf_datatype = tf.float32
np_datatype = np.float32
atol_threshold = 2e-5
if args.data_type == "fp16":
tf_datatype = tf.float16
np_datatype = np.float16
atol_threshold = 2e-2
initializer_range = 0.02
source_inputter = WordEmbedder("source_vocabulary", embedding_size=512)
target_inputter = WordEmbedder("target_vocabulary", embedding_size=512)
inputter = ExampleInputter(source_inputter, target_inputter)
inputter.initialize({
"source_vocabulary": "./utils/translation/wmtende.vocab",
"target_vocabulary": "./utils/translation/wmtende.vocab"
})
vocab_size = target_inputter.vocabulary_size
source_file = "./utils/translation/test.en"
decoding_args = DecodingArgument(batch_size=batch_size,
beam_width=beam_width,
head_num=decoder_head_num,
size_per_head=decoder_size_per_head,
num_layer=decoder_num_layer,
max_seq_len=max_seq_len,
vocab_size=vocab_size,
start_id=start_of_sentence_id,
end_id=end_of_sentence_id,
encoder_hidden_dim=encoder_head_num * encoder_size_per_head,
dtype=tf_datatype)
mode = tf.estimator.ModeKeys.PREDICT
with tf.variable_scope("transformer/encoder"):
dataset = inputter.make_inference_dataset(source_file, batch_size)
iterator = dataset.make_initializable_iterator()
source = iterator.get_next()
source_embedding = source_inputter.make_inputs(source)
memory_sequence_length = source["length"]
encoder = SelfAttentionEncoder(
num_layers=encoder_num_layer,
num_units=512,
num_heads=8,
ffn_inner_dim=2048,
dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1)
memory, _, _ = encoder.encode(source_embedding, memory_sequence_length, mode=mode)
tf_encoder_result = memory
tf_encoder_result = tf.reshape(
tf_encoder_result, [batch_size, -1, encoder_hidden_dim])
with tf.variable_scope("transformer/decoder", reuse=tf.AUTO_REUSE):
target_inputter.build()
with tf.variable_scope("transformer/decoder", reuse=tf.AUTO_REUSE):
decoder = SelfAttentionDecoder(
num_layers=6,
num_units=512,
num_heads=8,
ffn_inner_dim=2048,
dropout=0.0,
attention_dropout=0.0,
relu_dropout=0.0)
start_tokens = tf.fill([batch_size], start_of_sentence_id)
end_token = end_of_sentence_id
target_ids, _, target_length, _ = decoder.dynamic_decode_and_search(
target_inputter.embedding,
start_tokens,
end_token,
vocab_size=vocab_size,
beam_width=beam_width,
memory=memory,
memory_sequence_length=memory_sequence_length)
target_vocab_rev = target_inputter.vocabulary_lookup_reverse()
target_tokens = target_vocab_rev.lookup(tf.cast(target_ids, tf.int64))
opennmt_target_length = target_length
opennmt_target_tokens = target_tokens
opennmt_target_ids = target_ids
opennmt_variables = tf.global_variables()
## TF Decoding ###
finalized_tf_output_ids, finalized_tf_sequence_lengths, tf_output_ids, \
tf_parent_ids, tf_sequence_lengths = tf_decoding(tf_encoder_result,
memory_sequence_length,
target_inputter.embedding,
decoding_args,
decoder_type=1,
kernel_initializer_range=kernel_initializer_range,
bias_initializer_range=bias_initializer_range)
tf_target_ids = finalized_tf_output_ids
tf_target_length = finalized_tf_sequence_lengths
tf_target_tokens = target_vocab_rev.lookup(tf.cast(tf_target_ids, tf.int64))
## end of tf decoding ##
## op decoding ##
all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
decoder_var_start_id = 0
while all_vars[decoder_var_start_id].name.find("transformer/decoding") == -1:
decoder_var_start_id += 1
encoder_variables = all_vars[:decoder_var_start_id]
decoder_variables = all_vars[decoder_var_start_id:]
finalized_op_output_ids, finalized_op_sequence_lengths, op_output_ids, \
op_parent_ids, op_sequence_lengths = op_decoding(tf_encoder_result,
memory_sequence_length,
target_inputter.embedding,
decoder_variables, # first one is embedding table
decoding_args)
op_target_ids = finalized_op_output_ids
op_target_length = finalized_op_sequence_lengths
op_target_tokens = target_vocab_rev.lookup(tf.cast(op_target_ids, tf.int64))
## end of op decoding
opennmt_target_ids = tf.cast(opennmt_target_ids, tf.int32)
tf_target_ids = tf.cast(tf_target_ids, tf.int32)
op_target_ids = tf.cast(op_target_ids, tf.int32)
opennmt_target_length = tf.minimum(opennmt_target_length + 1, tf.shape(opennmt_target_ids)[2])
tf_target_length = tf.minimum(tf_target_length + 1, tf.shape(tf_target_ids)[2])
op_target_length = tf.minimum(op_target_length + 1, tf.shape(op_target_ids)[2])
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
saver = tf.train.Saver(opennmt_variables)
sess.run(tf.global_variables_initializer())
saver.restore(sess, "translation/ckpt/model.ckpt-500000")
sess.run(tf.tables_initializer())
sess.run(iterator.initializer)
resotre_model_by_pkl(sess, decoder_variables)
iteration = 0
while iteration < 3:
try:
opennmt_batch_tokens, opennmt_batch_length, \
tf_batch_tokens, tf_batch_length, \
op_batch_tokens, op_batch_length, source_result = sess.run([opennmt_target_tokens, opennmt_target_length,
tf_target_tokens, tf_target_length,
op_target_tokens, op_target_length, source])
print("[INFO] opennmt: ", end='')
for tokens, length in zip(opennmt_batch_tokens, opennmt_batch_length):
misc.print_bytes(b" ".join(tokens[0][:length[0] - 1]))
print("[INFO] tf : ", end='')
for tokens, length in zip(tf_batch_tokens, tf_batch_length):
misc.print_bytes(b" ".join(tokens[0][:length[0] - 1]))
print("[INFO] op : ", end='')
for tokens, length in zip(op_batch_tokens, op_batch_length):
misc.print_bytes(b" ".join(tokens[0][:length[0] - 1]))
iteration += 1
except tf.errors.OutOfRangeError:
break
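
The script above separates encoder and decoder variables by scanning the global variable list for the first name containing `transformer/decoding`. A minimal, framework-free sketch of that partitioning follows. The helper name `split_at_scope` is hypothetical and not part of the script, and the bounds check is an addition: the original loop assumes a matching variable always exists.

```python
def split_at_scope(names, scope="transformer/decoding"):
    """Split a list of variable names at the first name containing `scope`.

    Hypothetical helper for illustration only; it works on plain strings
    instead of tf.Variable objects.
    """
    start = 0
    # Advance until the first name that belongs to the decoding scope.
    # The bounds check is an addition; the original loop has none.
    while start < len(names) and scope not in names[start]:
        start += 1
    # Everything before the match is treated as encoder-side variables.
    return names[:start], names[start:]


if __name__ == "__main__":
    example = [
        "transformer/encoder/w_embs:0",
        "transformer/decoding/decoder/layer_0/ffn/conv1d/kernel:0",
    ]
    encoder_names, decoder_names = split_at_scope(example)
    print(encoder_names)
    print(decoder_names)
```

The split relies on variable creation order: every decoding variable must be created after all encoder variables, otherwise the slices mix the two groups.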


@ -15,8 +15,6 @@
import os
import tensorflow as tf
from common import create_initializer
from utils.position import SinusoidalPositionEncoder
def norm(inputs):
"""Layer normalizes :obj:`inputs`."""
@ -74,12 +72,6 @@ def tf_decoder(decoder_args,
cache=None,
kernel_initializer_range=0.02,
bias_initializer_range=0):
position_encoder = SinusoidalPositionEncoder()
if position_encoder is not None:
inputs = position_encoder(
inputs, position=step + 1 if step is not None else None)
memory_mask = None # has something
if memory is not None and not tf.contrib.framework.nest.is_sequence(memory):
@ -340,10 +332,6 @@ def op_decoder(inputs,
decoder_op_module = tf.load_op_library(
os.path.join('./lib/libtf_decoder.so'))
position_encoder = SinusoidalPositionEncoder()
inputs = position_encoder(
inputs, position=step + 1 if step is not None else None)
op_self_cache = tf.concat([op_self_cache, tf.zeros([decoder_args.num_layer, 2, 1,
decoder_args.batch_size * decoder_args.beam_width,
decoder_args.hidden_dim], dtype=decoder_args.dtype)], axis=2)
@ -366,8 +354,8 @@ def op_decoder(inputs,
decoder_vars[24 + 26 * i], decoder_vars[25 + 26 * i],
op_self_cache[i], op_mem_cache[i],
psuedo_input, # pass tf_result as an input so the OP and TF paths do not run in parallel and produce wrong results
max_seq_len=decoder_args.max_seq_len, head_num=decoder_args.head_num,
size_per_head=decoder_args.size_per_head, memory_hidden_dim=memory_hidden_dim)
head_num=decoder_args.head_num,
size_per_head=decoder_args.size_per_head)
inputs = op_result
return op_result, op_self_cache, op_mem_cache


@ -3,14 +3,14 @@ import tensorflow as tf
import os
from decoder import tf_decoder, op_decoder, init_op_cache, init_tf_cache
from common import create_initializer, _get_shape_invariants
from utils.position import SinusoidalPositionEncoder
def initialize_decoding_variables(decoding_args):
start_ids = tf.fill([decoding_args.decoder_args.batch_size * decoding_args.decoder_args.beam_width],
decoding_args.start_id) # [batch_size * beam_width]
step = tf.constant(1, dtype=tf.int32)
step = tf.constant(0, dtype=tf.int32)
# save the output ids for each step
outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
@ -39,11 +39,11 @@ def generate_encoder_result(batch_size,
memory_sequence_length = np.random.randint(
1, max_seq_len + 1, size=batch_size).astype(np.int32)
outter_embbeding = np.random.rand(memory_hidden_dim)
outter_embbeding = np.random.randn(memory_hidden_dim) * 0.01
memory = []
for i in range(batch_size):
data = np.random.rand(max_seq_len, memory_hidden_dim)
data = np.random.randn(max_seq_len, memory_hidden_dim) * 0.01
for j in range(memory_sequence_length[i], max_seq_len):
data[j] = outter_embbeding
memory.append(data)
@ -103,11 +103,13 @@ def beam_search(beam_width,
return word_ids, cum_log_probs, finished, cache, tuple(extra_vars), op_self_cache
def finalize(beam_width, parent_ids, sequence_lengths, outputs, end_id, max_seq_len):
def finalize(beam_width, parent_ids, sequence_lengths, outputs, end_id, max_seq_len=None):
maximum_lengths = tf.reduce_max(tf.reshape(
sequence_lengths, [-1, beam_width]), axis=-1)
array_shape = [max_seq_len + 1, -1, beam_width]
if max_seq_len is not None:
array_shape = [max_seq_len, -1, beam_width]
else:
array_shape = [maximum_lengths[0], -1, beam_width]
step_ids = tf.reshape(outputs, array_shape)
parent_ids = tf.reshape(parent_ids, array_shape)
@ -146,24 +148,29 @@ def op_decoding(memory_tensor,
output_ids, parent_ids, sequence_lengths = decoding_op_module.decoding(
extended_memory, extended_memory_sequence_length,
decoding_vars_in_differ_layers[0], decoding_vars_in_differ_layers[1], decoding_vars_in_differ_layers[
2], decoding_vars_in_differ_layers[3], decoding_vars_in_differ_layers[4],
decoding_vars_in_differ_layers[5], decoding_vars_in_differ_layers[6], decoding_vars_in_differ_layers[
7], decoding_vars_in_differ_layers[8], decoding_vars_in_differ_layers[9],
decoding_vars_in_differ_layers[10], decoding_vars_in_differ_layers[11], decoding_vars_in_differ_layers[
12], decoding_vars_in_differ_layers[13], decoding_vars_in_differ_layers[14],
decoding_vars_in_differ_layers[15], decoding_vars_in_differ_layers[16], decoding_vars_in_differ_layers[
17], decoding_vars_in_differ_layers[18], decoding_vars_in_differ_layers[19],
decoding_vars_in_differ_layers[20], decoding_vars_in_differ_layers[21], decoding_vars_in_differ_layers[
22], decoding_vars_in_differ_layers[23], decoding_vars_in_differ_layers[24],
decoding_vars_in_differ_layers[25],
embedding_table, decoding_vars[-2], tf.cast(
decoding_vars[-1], dtype=tf.float32),
batch_size=decoding_args.decoder_args.batch_size, beam_width=decoding_args.decoder_args.beam_width,
decoding_vars_in_differ_layers[0], decoding_vars_in_differ_layers[1],
decoding_vars_in_differ_layers[2], decoding_vars_in_differ_layers[3],
decoding_vars_in_differ_layers[4], decoding_vars_in_differ_layers[5],
decoding_vars_in_differ_layers[6], decoding_vars_in_differ_layers[7],
decoding_vars_in_differ_layers[8], decoding_vars_in_differ_layers[9],
decoding_vars_in_differ_layers[10], decoding_vars_in_differ_layers[11],
decoding_vars_in_differ_layers[12], decoding_vars_in_differ_layers[13],
decoding_vars_in_differ_layers[14], decoding_vars_in_differ_layers[15],
decoding_vars_in_differ_layers[16], decoding_vars_in_differ_layers[17],
decoding_vars_in_differ_layers[18], decoding_vars_in_differ_layers[19],
decoding_vars_in_differ_layers[20], decoding_vars_in_differ_layers[21],
decoding_vars_in_differ_layers[22], decoding_vars_in_differ_layers[23],
decoding_vars_in_differ_layers[24], decoding_vars_in_differ_layers[25],
decoding_vars[-4], decoding_vars[-3], embedding_table,
decoding_vars[-2], tf.cast(decoding_vars[-1], dtype=tf.float32),
batch_size=decoding_args.decoder_args.batch_size,
beam_width=decoding_args.decoder_args.beam_width,
max_seq_len=decoding_args.decoder_args.max_seq_len,
head_num=decoding_args.decoder_args.head_num, size_per_head=decoding_args.decoder_args.size_per_head,
head_num=decoding_args.decoder_args.head_num,
size_per_head=decoding_args.decoder_args.size_per_head,
num_layer=decoding_args.decoder_args.num_layer,
memory_hidden_dim=decoding_args.encoder_hidden_dim, vocab_size=decoding_args.vocab_size,
memory_hidden_dim=decoding_args.encoder_hidden_dim,
vocab_size=decoding_args.vocab_size,
start_id=decoding_args.start_id, end_id=decoding_args.end_id
)
parent_ids = parent_ids % decoding_args.decoder_args.beam_width
@ -175,6 +182,9 @@ def op_decoding(memory_tensor,
decoding_args.end_id,
decoding_args.decoder_args.max_seq_len)
finalized_sequence_lengths = tf.minimum(
finalized_sequence_lengths + 1, tf.shape(finalized_output_ids)[2])
return finalized_output_ids, finalized_sequence_lengths, output_ids, parent_ids, sequence_lengths
@ -187,7 +197,7 @@ def tf_decoding(memory_tensor,
bias_initializer_range,
atol_threshold=1e-6):
with tf.variable_scope("transformer/decoder", reuse=tf.AUTO_REUSE):
with tf.variable_scope("transformer/decoding", reuse=tf.AUTO_REUSE):
# copy memory and memory_sequence_length by beam_width times
# if memory is [a, b, c], beam_width = 3, then the result is: [a a a b b b c c c ]
extended_memory = tf.contrib.seq2seq.tile_batch(
@ -203,37 +213,47 @@ def tf_decoding(memory_tensor,
inputs = tf.nn.embedding_lookup(embedding_table, word_ids)
# [batch_size * beam_width, 1, hidden_dim]
inputs = tf.expand_dims(inputs, 1)
inputs *= decoding_args.decoder_args.hidden_dim**0.5
position_encoder = SinusoidalPositionEncoder()
if position_encoder is not None:
inputs = position_encoder(
inputs, position=step + 1 if step is not None else None)
with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
tf_result = tf_decoder(decoder_args=decoding_args.decoder_args,
inputs=inputs,
memory=extended_memory,
memory_sequence_length=extended_memory_sequence_length,
step=step,
cache=my_cache,
kernel_initializer_range=kernel_initializer_range,
bias_initializer_range=bias_initializer_range)
tf_result = tf_decoder(decoder_args=decoding_args.decoder_args,
inputs=inputs,
memory=extended_memory,
memory_sequence_length=extended_memory_sequence_length,
step=step,
cache=my_cache,
kernel_initializer_range=kernel_initializer_range,
bias_initializer_range=bias_initializer_range)
decoder_vars = tf.global_variables()
decoder_vars_start_id = 0
while decoder_vars_start_id < len(decoder_vars):
if decoder_vars[decoder_vars_start_id].name.find("transformer/decoder") != -1:
break
decoder_vars_start_id += 1
decoder_vars = decoder_vars[decoder_vars_start_id:]
psuedo_input = []
if decoder_type == 2:
psuedo_input = tf_result
if decoder_type != 0:
decoder_vars = tf.global_variables()
decoder_vars_start_id = 0
while decoder_vars_start_id < len(decoder_vars):
if decoder_vars[decoder_vars_start_id].name.find("transformer/decoding/decoder") != -1:
break
decoder_vars_start_id += 1
decoder_vars = decoder_vars[decoder_vars_start_id:]
op_result, op_self_cache, op_mem_cache = op_decoder(inputs,
step,
extended_memory,
extended_memory_sequence_length,
op_self_cache,
op_mem_cache,
psuedo_input,
decoder_vars,
decoding_args.decoder_args,
decoding_args.encoder_hidden_dim)
psuedo_input = []
if decoder_type == 2:
psuedo_input = tf_result
op_result, op_self_cache, op_mem_cache = op_decoder(inputs,
step,
extended_memory,
extended_memory_sequence_length,
op_self_cache,
op_mem_cache,
psuedo_input,
decoder_vars,
decoding_args.decoder_args,
decoding_args.encoder_hidden_dim)
result = None
if decoder_type == 0:
@ -243,13 +263,20 @@ def tf_decoding(memory_tensor,
elif decoder_type == 2:
result = tf_result
result_2 = op_result
diff = tf.math.reduce_max(tf.math.abs(result - result_2))
result = tf.Print(result, ["[INFO][PYTHON] step:", step, "max diff: ",
diff, tf.cond(diff < atol_threshold, lambda: "True", lambda: "False")])
flatten_result = tf.reshape(result, [-1])
flatten_result_2 = tf.reshape(result_2, [-1])
abs_diff = tf.math.abs(flatten_result - flatten_result_2)
argmax = tf.math.argmax(abs_diff)
result = tf.Print(result, ["[INFO][PYTHON] step:", step, "max diff: ", abs_diff[argmax],
" op val: ", flatten_result_2[argmax],
" tf val: ", flatten_result[argmax],
tf.cond(abs_diff[argmax] < atol_threshold, lambda: "True", lambda: "False")])
else:
print("[TF][ERROR] decoder type must be 0, 1 or 2.")
exit(-1)
result = tf.contrib.layers.layer_norm(result, begin_norm_axis=-1)
# [batch_size * beam_width, hidden_dim]
result = tf.squeeze(result, axis=1)
logits = tf.layers.dense(result,
@ -340,8 +367,7 @@ def tf_decoding(memory_tensor,
tf_parent_ids,
tf_sequence_lengths,
tf_output_ids,
decoding_args.end_id,
decoding_args.decoder_args.max_seq_len)
decoding_args.end_id)
finalized_tf_output_ids = tf.cast(
finalized_tf_output_ids, start_ids.dtype)
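
Two small conventions in the decoding code above are easy to misread: how the memory is tiled per beam (`[a, b, c]` with `beam_width = 3` becomes `[a a a b b b c c c]`), and how finalized sequence lengths are clipped with `min(length + 1, max_len)`. The sketch below restates both in plain Python. The helper names are hypothetical, and reading the `+ 1` as making room for the end token is an assumption, not something the diff states.

```python
def tile_batch(items, beam_width):
    """Repeat each batch entry beam_width times, keeping entries adjacent."""
    return [item for item in items for _ in range(beam_width)]


def clip_length(length, max_len):
    """Mirror tf.minimum(length + 1, max_len) for a single sequence.

    Assumption: the +1 accounts for the end token; the result never
    exceeds the length of the finalized output buffer.
    """
    return min(length + 1, max_len)


if __name__ == "__main__":
    print(tile_batch(["a", "b", "c"], 3))
    print(clip_length(4, 10), clip_length(10, 10))
```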


@ -0,0 +1,161 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import tensorflow as tf
import numpy as np
import os
import math
import six
from datetime import datetime
import sys
import time
import pickle
if len(sys.argv) != 2:
print("[ERROR] dump_pruned_model.py needs a ckpt file as input. \n e.g. python dump_pruned_model.py model.ckpt")
sys.exit(1)
ckpt_name = sys.argv[1]
with tf.Session() as sess:
saver = tf.train.import_meta_graph(ckpt_name + ".meta")
saver.restore(sess, ckpt_name)
def dumpModel_new():
all_variables = tf.trainable_variables()
ckpt = {}
for i, var in enumerate(all_variables):
print("[INFO] %d/%d" %(i, len(all_variables)), end='\r')
sys.stdout.flush()
if var in tf.trainable_variables():
val = sess.run(var)
name = None
if var.name.find("Adam") != -1:
continue
elif var.name.find("encoder") != -1:
# transformer/encoder/layer_x/multi_head/conv1d/kernel:0 -> transformer/encoder/layer_x/attention/self/query, key, value/kernel:0
# transformer/encoder/layer_x/multi_head/conv1d_1/kernel:0 -> transformer/encoder/layer_x/attention/output/kernel:0
# transformer/encoder/layer_x/multi_head/LayerNorm/gamma:0 -> transformer/encoder/layer_x/attention/output/LayerNorm/gamma:0
if var.name.find("multi_head/conv1d/") != -1:
dim = val.shape[-1] // 3
Q, K, V = np.split(val, [dim, dim * 2], axis=-1)
ckpt[var.name.replace("multi_head/conv1d/", "attention/self/query/")] = Q
ckpt[var.name.replace("multi_head/conv1d/", "attention/self/key/")] = K
ckpt[var.name.replace("multi_head/conv1d/", "attention/self/value/")] = V
elif var.name.find("multi_head/conv1d_1/") != -1:
name = var.name.replace("multi_head/conv1d_1/", "attention/output/")
ckpt[name] = val
elif var.name.find("multi_head/LayerNorm/") != -1:
name = var.name.replace("multi_head/LayerNorm/", "attention/output/LayerNorm/")
ckpt[name] = val
# transformer/encoder/layer_x/ffn/conv1d/kernel:0 -> transformer/encoder/layer_x/intermediate/dense/kernel:0
# transformer/encoder/layer_x/ffn/LayerNorm/beta:0 -> transformer/encoder/layer_x/output/LayerNorm/beta:0
# transformer/encoder/layer_x/ffn/conv1d_1/kernel:0 -> transformer/encoder/layer_x/output/dense/kernel:0
elif var.name.find("ffn/conv1d/") != -1:
name = var.name.replace("ffn/conv1d/", "intermediate/dense/")
ckpt[name] = val
elif var.name.find("ffn/LayerNorm/") != -1:
name = var.name.replace("ffn/LayerNorm/", "output/LayerNorm/")
ckpt[name] = val
elif var.name.find("ffn/conv1d_1/") != -1:
name = var.name.replace("ffn/conv1d_1/", "output/dense/")
ckpt[name] = val
elif var.name.find("transformer/encoder/w_embs") != -1:
name = var.name
ckpt[name] = val
elif var.name.find("decoder") != -1:
pre_name = var.name.replace("decoder", "decoding/decoder")
# transformer/decoder/layer_x/masked_multi_head/conv1d/kernel:0 -> transformer/decoding/decoder/layer_x/masked_multi_head/query, key, value/kernel:0
# transformer/decoder/layer_x/masked_multi_head/conv1d_1/kernel:0 -> transformer/decoding/decoder/layer_x/masked_multi_head/conv1d/kernel:0
# transformer/decoder/layer_x/masked_multi_head/LayerNorm/gamma:0 -> transformer/decoding/decoder/layer_x/masked_multi_head/LayerNorm/gamma:0
if var.name.find("masked_multi_head/conv1d/") != -1:
dim = val.shape[-1] // 3
Q, K, V = np.split(val, [dim, dim * 2], axis=-1)
ckpt[pre_name.replace("masked_multi_head/conv1d/", "masked_multi_head/query/")] = Q
ckpt[pre_name.replace("masked_multi_head/conv1d/", "masked_multi_head/key/")] = K
ckpt[pre_name.replace("masked_multi_head/conv1d/", "masked_multi_head/value/")] = V
elif var.name.find("masked_multi_head/conv1d_1/") != -1:
name = pre_name.replace("masked_multi_head/conv1d_1/", "masked_multi_head/conv1d/")
ckpt[name] = val
elif var.name.find("masked_multi_head/LayerNorm/") != -1:
name = pre_name
ckpt[name] = val
# transformer/decoder/layer_x/multi_head/conv1d/kernel:0 -> transformer/decoding/decoder/layer_x/multi_head/query/kernel:0
# transformer/decoder/layer_x/multi_head/conv1d_1/kernel:0 -> transformer/decoding/decoder/layer_x/multi_head/key, value/kernel:0
# transformer/decoder/layer_x/multi_head/conv1d_2/kernel:0 -> transformer/decoding/decoder/layer_x/multi_head/conv1d/kernel:0
# transformer/decoder/layer_x/multi_head/LayerNorm/gamma:0 -> transformer/decoding/decoder/layer_x/multi_head/LayerNorm/gamma:0
elif var.name.find("multi_head/conv1d/") != -1:
name = pre_name.replace("multi_head/conv1d/", "multi_head/query/")
ckpt[name] = val
elif var.name.find("multi_head/conv1d_1/") != -1:
dim = val.shape[-1] // 2
K, V = np.split(val, [dim], axis=-1)
ckpt[pre_name.replace("multi_head/conv1d_1/", "multi_head/key/")] = K
ckpt[pre_name.replace("multi_head/conv1d_1/", "multi_head/value/")] = V
elif var.name.find("multi_head/conv1d_2/") != -1:
name = pre_name.replace("multi_head/conv1d_2/", "multi_head/conv1d/")
ckpt[name] = val
elif var.name.find("multi_head/LayerNorm/") != -1:
name = pre_name
ckpt[name] = val
# transformer/decoder/layer_x/ffn/conv1d/kernel:0 -> transformer/decoder/layer_x/intermediate/dense/kernel:0
# transformer/decoder/layer_x/ffn/LayerNorm/beta:0 -> transformer/decoder/layer_x/output/LayerNorm/beta:0
# transformer/decoder/layer_x/ffn/conv1d_1/kernel:0 -> transformer/decoder/layer_x/output/dense/kernel:0
elif var.name.find("ffn/conv1d/") != -1:
# name = var.name.replace("ffn/conv1d/", "intermediate/dense/")
name = pre_name
ckpt[name] = val
elif var.name.find("ffn/LayerNorm/") != -1:
# name = var.name.replace("ffn/LayerNorm/", "output/LayerNorm/")
name = pre_name
ckpt[name] = val
elif var.name.find("ffn/conv1d_1/") != -1:
# name = var.name.replace("ffn/conv1d_1/", "output/dense/")
name = pre_name
ckpt[name] = val
elif var.name.find("transformer/decoder/w_embs") != -1:
name = var.name.replace("decoder", "decoding")
ckpt[name] = val
elif var.name.find("transformer/decoder/dense/") != -1:
name = var.name.replace("decoder", "decoding")
ckpt[name] = val
elif var.name.find("transformer/decoder/LayerNorm/") != -1:
name = var.name.replace("decoder", "decoding")
ckpt[name] = val
if name is not None:
print("[INFO] {} -> {} ".format(var.name, name))
for key in ckpt:
print(key)
with open('model.pkl', 'wb') as f:
pickle.dump(ckpt, f, 0)
dumpModel_new()
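
The fused-kernel split applied above to `masked_multi_head/conv1d/` can be tried in isolation. The following is a minimal sketch, assuming the fused kernel stores Q, K and V side by side along its last axis; the helper name `split_fused_qkv` and the toy shapes are hypothetical and not part of `dump_model.py`.

```python
import numpy as np


def split_fused_qkv(kernel):
    """Split a fused [in_dim, 3 * out_dim] kernel into Q, K, V (hypothetical helper)."""
    if kernel.shape[-1] % 3 != 0:
        raise ValueError("last axis must be divisible by 3, got %d" % kernel.shape[-1])
    # Integer division keeps the split indices integral under Python 2 and 3.
    dim = kernel.shape[-1] // 3
    q, k, v = np.split(kernel, [dim, dim * 2], axis=-1)
    return q, k, v


# Toy kernel: 2 input features, 3 * 2 fused output features.
fused = np.arange(2 * 6, dtype=np.float32).reshape(2, 6)
q, k, v = split_fused_qkv(fused)
print(q.shape, k.shape, v.shape)
```

Concatenating the three pieces along the last axis gives back the fused kernel, which is a cheap sanity check to run on a converted checkpoint.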

View file

@ -18,7 +18,7 @@ import math
import six
import os
from common import create_initializer
def gelu(x):
cdf = 0.5 * (1.0 + tf.tanh(
@ -156,7 +156,8 @@ def tf_encoder(input_tensor,
attention_mask=None,
intermediate_act_fn=gelu,
initializer_range=0.02):
intermediate_size = encoder_args.hidden_dim * 4
if encoder_args.hidden_dim % encoder_args.head_num != 0:
raise ValueError(
@ -299,7 +300,6 @@ def op_encoder(inputs,
encoder_args,
encoder_vars,
attention_mask):
transformer_op_module = tf.load_op_library(
os.path.join('./lib/libtf_fastertransformer.so'))
for layer_idx in range(encoder_args.num_layer):

View file

@ -43,14 +43,14 @@ class PositionEncoder(tf.keras.layers.Layer):
"""
batch_size = tf.shape(inputs)[0]
timesteps = tf.shape(inputs)[1]
input_dim = inputs.shape[-1]
input_dim = inputs.get_shape().as_list()[-1] # return int
positions = tf.range(timesteps) + 1 if position is None else [position]
position_encoding = self._encode([positions], input_dim)
position_encoding = self._encode([positions], input_dim, dtype=inputs.dtype)
position_encoding = tf.tile(position_encoding, [batch_size, 1, 1])
return self.reducer([inputs, position_encoding])
@abc.abstractmethod
def _encode(self, positions, depth):
def _encode(self, positions, depth, dtype):
"""Creates position encodings.
Args:
positions: The positions to encode of shape :math:`[B, ...]`.
@ -66,7 +66,7 @@ class SinusoidalPositionEncoder(PositionEncoder):
https://arxiv.org/abs/1706.03762.
"""
def _encode(self, positions, depth):
def _encode(self, positions, depth, dtype):
if depth % 2 != 0:
raise ValueError("SinusoidalPositionEncoder expects the depth to be divisible "
"by 2 but got %d" % depth)
@ -74,13 +74,13 @@ class SinusoidalPositionEncoder(PositionEncoder):
batch_size = tf.shape(positions)[0]
positions = tf.cast(positions, tf.float32)
log_timescale_increment = math.log(10000) / ((int)(depth / 2 - 1))
log_timescale_increment = math.log(10000) / (depth / 2 - 1)
inv_timescales = tf.exp(
tf.cast(tf.range(depth / 2), tf.float32) * -log_timescale_increment)
tf.cast(tf.range(depth / 2), dtype=tf.float32) * -log_timescale_increment)
inv_timescales = tf.reshape(
tf.tile(inv_timescales, [batch_size]), [batch_size, -1])
scaled_time = tf.expand_dims(
positions, -1) * tf.expand_dims(inv_timescales, 1)
encoding = tf.concat(
[tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
return tf.cast(encoding, self.dtype)
return tf.cast(encoding, dtype)
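
The arithmetic in `_encode` can be checked without TensorFlow. The following is a minimal sketch for a single position, assuming the same layout as the code above (all sine terms first, then all cosine terms); the function name `sinusoidal_encoding` and the `depth < 4` guard are additions for illustration, the latter because `depth / 2 - 1` is zero when `depth` is 2.

```python
import math


def sinusoidal_encoding(position, depth):
    """Return the sinusoidal encoding of one position as a list of floats (sketch)."""
    if depth % 2 != 0 or depth < 4:
        raise ValueError("depth must be even and at least 4, got %d" % depth)
    half = depth // 2
    # Timescales are spaced geometrically from 1 down to 1/10000.
    log_timescale_increment = math.log(10000) / (half - 1)
    inv_timescales = [math.exp(-i * log_timescale_increment) for i in range(half)]
    scaled_time = [position * t for t in inv_timescales]
    # Same layout as tf.concat([sin, cos], axis=2): sines first, then cosines.
    return [math.sin(s) for s in scaled_time] + [math.cos(s) for s in scaled_time]


enc = sinusoidal_encoding(1, 8)
print(len(enc))
```

Positions start at 1 in `PositionEncoder.call` (`tf.range(timesteps) + 1`), so the first token corresponds to `sinusoidal_encoding(1, depth)` in this sketch.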

View file

@ -0,0 +1,22 @@
# Clone the OpenNMT-tf repo
git clone https://github.com/OpenNMT/OpenNMT-tf -b r1
cp -r OpenNMT-tf/opennmt/ .
rm -rf OpenNMT-tf
# Download the vocabulary and test data
# wget https://s3.amazonaws.com/opennmt-trainingdata/wmt_ende_sp.tar.gz
# Download the pretrained model
wget https://s3.amazonaws.com/opennmt-models/averaged-ende-ckpt500k.tar.gz
mkdir translation
mkdir translation/ckpt
# mkdir translation/data
# tar xf wmt_ende_sp.tar.gz -C translation/data
tar xf averaged-ende-ckpt500k.tar.gz -C translation/ckpt
# -f: wmt_ende_sp.tar.gz is only present if the optional download above is enabled
rm -f wmt_ende_sp.tar.gz averaged-ende-ckpt500k.tar.gz
# head -n 5 translation/data/test.en > test.en
# head -n 5 translation/data/test.de > test.de
# convert the pretrained model to fit our model structure
python utils/dump_model.py translation/ckpt/model.ckpt-500000

File diff suppressed because it is too large Load diff
