From 1aa68134507a4fe65317cef2f398d4af0b6d1e8c Mon Sep 17 00:00:00 2001 From: BO-YANG HSUEH Date: Thu, 6 Aug 2020 20:15:49 +0800 Subject: [PATCH] [FT] 1. Fix the bug of TensorRT plugin of FasterTransformer encoder. (#640) * [FT] 1. Fix the bug of TensorRT plugin of FasterTransformer encoder. --- .../fastertransformer/trt_plugin/bert_transformer_plugin.h | 4 ++-- .../v2.1/fastertransformer/trt_plugin/trt_model.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/FasterTransformer/v2.1/fastertransformer/trt_plugin/bert_transformer_plugin.h b/FasterTransformer/v2.1/fastertransformer/trt_plugin/bert_transformer_plugin.h index 40d532b8..c04b24ae 100644 --- a/FasterTransformer/v2.1/fastertransformer/trt_plugin/bert_transformer_plugin.h +++ b/FasterTransformer/v2.1/fastertransformer/trt_plugin/bert_transformer_plugin.h @@ -269,13 +269,13 @@ class TransformerPlugin: public IPluginV2 bool supportsFormat(nvinfer1::DataType type, PluginFormat format) const override { - return type == nvinfer1::DataType::kFLOAT && format == PluginFormat::kNCHW; + return type == TransformerTrtTraits::DataType && format == PluginFormat::kNCHW; } void configureWithFormat(const Dims* pInputDim, int nInputDim, const Dims* pOutputDim, int nOutputDim, nvinfer1::DataType dataType, nvinfer1::PluginFormat pluginFormat, int maxBatchSize) override { - assert(dataType == nvinfer1::DataType::kFLOAT && pluginFormat == nvinfer1::PluginFormat::kNCHW); + assert(dataType == TransformerTrtTraits::DataType && pluginFormat == nvinfer1::PluginFormat::kNCHW); assert(nInputDim == 2); assert(pInputDim[0].nbDims == 2 && pInputDim[0].d[0] == seq_len_ && pInputDim[0].d[1] == hidden_dim_); assert(pInputDim[1].nbDims == 2 && pInputDim[1].d[0] == seq_len_ && pInputDim[1].d[1] == seq_len_); diff --git a/FasterTransformer/v2.1/fastertransformer/trt_plugin/trt_model.h b/FasterTransformer/v2.1/fastertransformer/trt_plugin/trt_model.h index ea28972d..a2cfdfcf 100644 --- a/FasterTransformer/v2.1/fastertransformer/trt_plugin/trt_model.h +++ b/FasterTransformer/v2.1/fastertransformer/trt_plugin/trt_model.h @@ -104,7 +104,7 @@ class TRT_Transformer builder->setMaxBatchSize(batch_size_); builder->setMaxWorkspaceSize(1 << 20); - builder->setFp16Mode(false); + builder->setFp16Mode(sizeof(T) == 2); engine_ = builder->buildCudaEngine(*network); assert(engine_);