[FT] 1. Fix a bug in the TensorRT plugin of the FasterTransformer encoder. (#640)

* [FT] 1. Fix a bug in the TensorRT plugin of the FasterTransformer encoder.
BO-YANG HSUEH 2020-08-06 20:15:49 +08:00 committed by GitHub
parent 280e75c63e
commit 1aa6813450
2 changed files with 3 additions and 3 deletions


@@ -269,13 +269,13 @@ class TransformerPlugin: public IPluginV2
     bool supportsFormat(nvinfer1::DataType type, PluginFormat format) const override
     {
-        return type == nvinfer1::DataType::kFLOAT && format == PluginFormat::kNCHW;
+        return type == TransformerTrtTraits<T>::DataType && format == PluginFormat::kNCHW;
     }
     void configureWithFormat(const Dims* pInputDim, int nInputDim, const Dims* pOutputDim,
                              int nOutputDim, nvinfer1::DataType dataType, nvinfer1::PluginFormat pluginFormat, int maxBatchSize) override
     {
-        assert(dataType == nvinfer1::DataType::kFLOAT && pluginFormat == nvinfer1::PluginFormat::kNCHW);
+        assert(dataType == TransformerTrtTraits<T>::DataType && pluginFormat == nvinfer1::PluginFormat::kNCHW);
         assert(nInputDim == 2);
         assert(pInputDim[0].nbDims == 2 && pInputDim[0].d[0] == seq_len_ && pInputDim[0].d[1] == hidden_dim_);
         assert(pInputDim[1].nbDims == 2 && pInputDim[1].d[0] == seq_len_ && pInputDim[1].d[1] == seq_len_);
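
The bug: both supportsFormat() and configureWithFormat() hard-coded nvinfer1::DataType::kFLOAT, so a plugin instantiated with T = half rejected the very format it was built for. The fix routes both checks through TransformerTrtTraits<T>::DataType. A minimal sketch of what such a traits template could look like, assuming the usual float/half specializations (the real definition lives elsewhere in the FasterTransformer sources):

    // Sketch only: assumed shape of the traits template referenced above.
    // It maps the plugin's template parameter T to the matching TensorRT
    // tensor data type, so one plugin class serves both FP32 and FP16.
    #include <NvInfer.h>
    #include <cuda_fp16.h>

    template <typename T>
    struct TransformerTrtTraits;

    template <>
    struct TransformerTrtTraits<float>
    {
        static const nvinfer1::DataType DataType = nvinfer1::DataType::kFLOAT;
    };

    template <>
    struct TransformerTrtTraits<half>
    {
        static const nvinfer1::DataType DataType = nvinfer1::DataType::kHALF;
    };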


@@ -104,7 +104,7 @@ class TRT_Transformer
     builder->setMaxBatchSize(batch_size_);
     builder->setMaxWorkspaceSize(1 << 20);
-    builder->setFp16Mode(false);
+    builder->setFp16Mode(sizeof(T) == 2);
     engine_ = builder->buildCudaEngine(*network);
     assert(engine_);
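
The same template parameter now drives the builder configuration: sizeof(T) == 2 holds exactly when T is half, so FP16 mode is enabled only for half-precision engines instead of being forced off. A minimal sketch of the pattern against the legacy (pre-TensorRT-8) IBuilder API used in this file; configureBuilder is a hypothetical helper, not part of the repository:

    // Sketch only: the FP16 flag tracks the template type rather than a
    // hard-coded constant.
    #include <NvInfer.h>
    #include <cuda_fp16.h>

    template <typename T>
    void configureBuilder(nvinfer1::IBuilder* builder, int batch_size)
    {
        builder->setMaxBatchSize(batch_size);
        builder->setMaxWorkspaceSize(1 << 20);  // 1 MiB of scratch workspace
        builder->setFp16Mode(sizeof(T) == 2);   // true for T = half, false for T = float
    }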