[FT] 1. Fix the bug of TensorRT plugin of FasterTransformer encoder. (#640)
* [FT] 1. Fix the bug of TensorRT plugin of FasterTransformer encoder.
This commit is contained in:
parent
280e75c63e
commit
1aa6813450
|
@ -269,13 +269,13 @@ class TransformerPlugin: public IPluginV2
|
|||
|
||||
bool supportsFormat(nvinfer1::DataType type, PluginFormat format) const override
|
||||
{
|
||||
return type == nvinfer1::DataType::kFLOAT && format == PluginFormat::kNCHW;
|
||||
return type == TransformerTrtTraits<T>::DataType && format == PluginFormat::kNCHW;
|
||||
}
|
||||
|
||||
void configureWithFormat(const Dims* pInputDim, int nInputDim, const Dims* pOutputDim,
|
||||
int nOutputDim, nvinfer1::DataType dataType, nvinfer1::PluginFormat pluginFormat, int maxBatchSize) override
|
||||
{
|
||||
assert(dataType == nvinfer1::DataType::kFLOAT && pluginFormat == nvinfer1::PluginFormat::kNCHW);
|
||||
assert(dataType == TransformerTrtTraits<T>::DataType && pluginFormat == nvinfer1::PluginFormat::kNCHW);
|
||||
assert(nInputDim == 2);
|
||||
assert(pInputDim[0].nbDims == 2 && pInputDim[0].d[0] == seq_len_ && pInputDim[0].d[1] == hidden_dim_);
|
||||
assert(pInputDim[1].nbDims == 2 && pInputDim[1].d[0] == seq_len_ && pInputDim[1].d[1] == seq_len_);
|
||||
|
|
|
@ -104,7 +104,7 @@ class TRT_Transformer
|
|||
|
||||
builder->setMaxBatchSize(batch_size_);
|
||||
builder->setMaxWorkspaceSize(1 << 20);
|
||||
builder->setFp16Mode(false);
|
||||
builder->setFp16Mode(sizeof(T) == 2);
|
||||
|
||||
engine_ = builder->buildCudaEngine(*network);
|
||||
assert(engine_);
|
||||
|
|
Loading…
Reference in a new issue