// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#define HAVE_CUDA 1  // Loading Kaldi headers with GPU

#include <triton/backend/backend_common.h>

#include <cfloat>
#include <chrono>
#include <sstream>
#include <thread>

#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h"
#include "fstext/fstext-lib.h"
#include "kaldi-backend-utils.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include "nnet3/am-nnet-simple.h"
#include "nnet3/nnet-utils.h"
#include "util/kaldi-thread.h"

using kaldi::BaseFloat;

namespace ni = triton::common;
namespace nib = triton::backend;

namespace {

#define RESPOND_AND_RETURN_IF_ERROR(REQUEST, X)                               \
  do {                                                                        \
    TRITONSERVER_Error* rarie_err__ = (X);                                    \
    if (rarie_err__ != nullptr) {                                             \
      TRITONBACKEND_Response* rarie_response__ = nullptr;                     \
      LOG_IF_ERROR(TRITONBACKEND_ResponseNew(&rarie_response__, REQUEST),     \
                   "failed to create response");                              \
      if (rarie_response__ != nullptr) {                                      \
        LOG_IF_ERROR(TRITONBACKEND_ResponseSend(                              \
                         rarie_response__,                                    \
                         TRITONSERVER_RESPONSE_COMPLETE_FINAL, rarie_err__),  \
                     "failed to send error response");                        \
      }                                                                       \
      TRITONSERVER_ErrorDelete(rarie_err__);                                  \
      return;                                                                 \
    }                                                                         \
  } while (false)

#define RESPOND_FACTORY_AND_RETURN_IF_ERROR(FACTORY, X)                       \
  do {                                                                        \
    TRITONSERVER_Error* rfarie_err__ = (X);                                   \
    if (rfarie_err__ != nullptr) {                                            \
      TRITONBACKEND_Response* rfarie_response__ = nullptr;                    \
      LOG_IF_ERROR(                                                           \
          TRITONBACKEND_ResponseNewFromFactory(&rfarie_response__, FACTORY),  \
          "failed to create response");                                       \
      if (rfarie_response__ != nullptr) {                                     \
        LOG_IF_ERROR(TRITONBACKEND_ResponseSend(                              \
                         rfarie_response__,                                   \
                         TRITONSERVER_RESPONSE_COMPLETE_FINAL, rfarie_err__), \
                     "failed to send error response");                        \
      }                                                                       \
      TRITONSERVER_ErrorDelete(rfarie_err__);                                 \
      return;                                                                 \
    }                                                                         \
  } while (false)

//
// ResponseOutput
//
// Bit flags for desired response outputs
//
enum ResponseOutput {
  kResponseOutputRawLattice = 1 << 0,
  kResponseOutputText = 1 << 1,
  kResponseOutputCTM = 1 << 2
};
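// The flags above are OR-ed together per request; e.g. a request that asks
// for both TEXT and CTM ends up with
// (kResponseOutputText | kResponseOutputCTM), as built in PrepareRequest()
// below.
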
//
// ModelParams
//
// The parameters parsed from the model configuration.
//
struct ModelParams {
  // Model paths
  std::string nnet3_rxfilename;
  std::string fst_rxfilename;
  std::string word_syms_rxfilename;
  std::string lattice_postprocessor_rxfilename;

  // Filenames
  std::string config_filename;

  uint64_t max_batch_size;
  int num_channels;
  int num_worker_threads;

  int use_tensor_cores;
  float beam;
  float lattice_beam;
  int max_active;
  int frame_subsampling_factor;
  float acoustic_scale;
  int main_q_capacity;
  int aux_q_capacity;

  int chunk_num_bytes;
  int chunk_num_samps;
};

//
// ModelState
//
// State associated with a model that is using this backend. An object
// of this class is created and associated with each
// TRITONBACKEND_Model.
//
class ModelState {
 public:
  static TRITONSERVER_Error* Create(TRITONBACKEND_Model* triton_model,
                                    ModelState** state);

  // Get the handle to the TRITONBACKEND model.
  TRITONBACKEND_Model* TritonModel() { return triton_model_; }

  // Validate and parse the model configuration
  TRITONSERVER_Error* ValidateModelConfig();

  // Obtain the parameters parsed from the model configuration
  const ModelParams* Parameters() { return &model_params_; }

 private:
  ModelState(TRITONBACKEND_Model* triton_model,
             ni::TritonJson::Value&& model_config);

  TRITONBACKEND_Model* triton_model_;
  ni::TritonJson::Value model_config_;

  ModelParams model_params_;
};

TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model,
                                       ModelState** state) {
  TRITONSERVER_Message* config_message;
  RETURN_IF_ERROR(TRITONBACKEND_ModelConfig(
      triton_model, 1 /* config_version */, &config_message));

  const char* buffer;
  size_t byte_size;
  RETURN_IF_ERROR(
      TRITONSERVER_MessageSerializeToJson(config_message, &buffer, &byte_size));

  ni::TritonJson::Value model_config;
  TRITONSERVER_Error* err = model_config.Parse(buffer, byte_size);
  RETURN_IF_ERROR(TRITONSERVER_MessageDelete(config_message));
  RETURN_IF_ERROR(err);

  *state = new ModelState(triton_model, std::move(model_config));
  return nullptr;  // success
}

ModelState::ModelState(TRITONBACKEND_Model* triton_model,
                       ni::TritonJson::Value&& model_config)
    : triton_model_(triton_model), model_config_(std::move(model_config)) {}
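// ValidateModelConfig() below expects a model configuration of roughly the
// following shape. This is an illustrative config.pbtxt sketch only (the
// numeric values and the oldest-strategy fields are placeholders, not values
// required by this backend):
//
//   max_batch_size: 600
//   model_transaction_policy { decoupled: true }
//   sequence_batching {
//     oldest { max_candidate_sequences: 600 }
//     control_input [ { name: "START" ... }, { name: "READY" ... },
//                     { name: "END" ... }, { name: "CORRID" ... } ]
//   }
//   input  [ { name: "WAV_DATA" ... }, { name: "WAV_DATA_DIM" ... } ]
//   output [ { name: "RAW_LATTICE" ... }, { name: "TEXT" ... },
//            { name: "CTM" ... } ]
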
TRITONSERVER_Error* ModelState::ValidateModelConfig() {
  // We have the json DOM for the model configuration...
  ni::TritonJson::WriteBuffer buffer;
  RETURN_AND_LOG_IF_ERROR(model_config_.PrettyWrite(&buffer),
                          "failed to pretty write model configuration");
  LOG_MESSAGE(
      TRITONSERVER_LOG_VERBOSE,
      (std::string("model configuration:\n") + buffer.Contents()).c_str());

  RETURN_AND_LOG_IF_ERROR(model_config_.MemberAsUInt(
                              "max_batch_size", &model_params_.max_batch_size),
                          "failed to get max batch size");

  ni::TritonJson::Value batcher;
  RETURN_ERROR_IF_FALSE(
      model_config_.Find("sequence_batching", &batcher),
      TRITONSERVER_ERROR_INVALID_ARG,
      std::string("model configuration must configure sequence batcher"));
  ni::TritonJson::Value control_inputs;
  RETURN_AND_LOG_IF_ERROR(
      batcher.MemberAsArray("control_input", &control_inputs),
      "failed to read control input array");
  std::set<std::string> control_input_names;
  for (uint32_t i = 0; i < control_inputs.ArraySize(); i++) {
    ni::TritonJson::Value control_input;
    RETURN_AND_LOG_IF_ERROR(control_inputs.IndexAsObject(i, &control_input),
                            "failed to get control input");
    std::string control_input_name;
    RETURN_AND_LOG_IF_ERROR(
        control_input.MemberAsString("name", &control_input_name),
        "failed to get control input name");
    control_input_names.insert(control_input_name);
  }

  RETURN_ERROR_IF_FALSE(
      (control_input_names.erase("START") && control_input_names.erase("END") &&
       control_input_names.erase("CORRID") &&
       control_input_names.erase("READY")),
      TRITONSERVER_ERROR_INVALID_ARG,
      std::string("missing control input names in the model configuration"));

  // Check the Model Transaction Policy
  ni::TritonJson::Value txn_policy;
  RETURN_ERROR_IF_FALSE(
      model_config_.Find("model_transaction_policy", &txn_policy),
      TRITONSERVER_ERROR_INVALID_ARG,
      std::string("model configuration must specify a transaction policy"));
  bool is_decoupled;
  RETURN_AND_LOG_IF_ERROR(txn_policy.MemberAsBool("decoupled", &is_decoupled),
                          "failed to read the decoupled txn policy");
  RETURN_ERROR_IF_FALSE(
      is_decoupled, TRITONSERVER_ERROR_INVALID_ARG,
      std::string("model configuration must use decoupled transaction policy"));

  // Check the Inputs and Outputs
  ni::TritonJson::Value inputs, outputs;
  RETURN_AND_LOG_IF_ERROR(model_config_.MemberAsArray("input", &inputs),
                          "failed to read input array");
  RETURN_AND_LOG_IF_ERROR(model_config_.MemberAsArray("output", &outputs),
                          "failed to read output array");

  // There must be 2 inputs and 3 outputs.
  RETURN_ERROR_IF_FALSE(inputs.ArraySize() == 2, TRITONSERVER_ERROR_INVALID_ARG,
                        std::string("expected 2 inputs, got ") +
                            std::to_string(inputs.ArraySize()));
  RETURN_ERROR_IF_FALSE(outputs.ArraySize() == 3,
                        TRITONSERVER_ERROR_INVALID_ARG,
                        std::string("expected 3 outputs, got ") +
                            std::to_string(outputs.ArraySize()));

  // Here we rely on the model configuration listing the inputs and
  // outputs in a specific order, which we shouldn't really require...
  // TODO use sets and loops
  ni::TritonJson::Value in0, in1, out0, out1, out2;
  RETURN_AND_LOG_IF_ERROR(inputs.IndexAsObject(0, &in0),
                          "failed to get the first input");
  RETURN_AND_LOG_IF_ERROR(inputs.IndexAsObject(1, &in1),
                          "failed to get the second input");
  RETURN_AND_LOG_IF_ERROR(outputs.IndexAsObject(0, &out0),
                          "failed to get the first output");
  RETURN_AND_LOG_IF_ERROR(outputs.IndexAsObject(1, &out1),
                          "failed to get the second output");
  RETURN_AND_LOG_IF_ERROR(outputs.IndexAsObject(2, &out2),
                          "failed to get the third output");

  // Check tensor names
  std::string in0_name, in1_name, out0_name, out1_name, out2_name;
  RETURN_AND_LOG_IF_ERROR(in0.MemberAsString("name", &in0_name),
                          "failed to get the first input name");
  RETURN_AND_LOG_IF_ERROR(in1.MemberAsString("name", &in1_name),
                          "failed to get the second input name");
  RETURN_AND_LOG_IF_ERROR(out0.MemberAsString("name", &out0_name),
                          "failed to get the first output name");
  RETURN_AND_LOG_IF_ERROR(out1.MemberAsString("name", &out1_name),
                          "failed to get the second output name");
  RETURN_AND_LOG_IF_ERROR(out2.MemberAsString("name", &out2_name),
                          "failed to get the third output name");

  RETURN_ERROR_IF_FALSE(
      in0_name == "WAV_DATA", TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected first input tensor name to be WAV_DATA, got ") +
          in0_name);
  RETURN_ERROR_IF_FALSE(
      in1_name == "WAV_DATA_DIM", TRITONSERVER_ERROR_INVALID_ARG,
      std::string(
          "expected second input tensor name to be WAV_DATA_DIM, got ") +
          in1_name);
  RETURN_ERROR_IF_FALSE(
      out0_name == "RAW_LATTICE", TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected first output tensor name to be RAW_LATTICE, got ") +
          out0_name);
  RETURN_ERROR_IF_FALSE(
      out1_name == "TEXT", TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected second output tensor name to be TEXT, got ") +
          out1_name);
  RETURN_ERROR_IF_FALSE(
      out2_name == "CTM", TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected third output tensor name to be CTM, got ") +
          out2_name);

  // Check shapes
  std::vector<int64_t> in0_shape, in1_shape, out0_shape, out1_shape;
  RETURN_AND_LOG_IF_ERROR(nib::ParseShape(in0, "dims", &in0_shape),
                          "failed to parse the first input shape");
  RETURN_AND_LOG_IF_ERROR(nib::ParseShape(in1, "dims", &in1_shape),
                          "failed to parse the second input shape");
  RETURN_AND_LOG_IF_ERROR(nib::ParseShape(out0, "dims", &out0_shape),
                          "failed to parse the first output shape");
  RETURN_AND_LOG_IF_ERROR(nib::ParseShape(out1, "dims", &out1_shape),
                          "failed to parse the second output shape");

  RETURN_ERROR_IF_FALSE(
      in0_shape.size() == 1, TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected WAV_DATA shape to have one dimension, got ") +
          nib::ShapeToString(in0_shape));
  RETURN_ERROR_IF_FALSE(
      in0_shape[0] > 0, TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected WAV_DATA shape to be greater than 0, got ") +
          nib::ShapeToString(in0_shape));
  model_params_.chunk_num_samps = in0_shape[0];
  model_params_.chunk_num_bytes = model_params_.chunk_num_samps * sizeof(float);

  RETURN_ERROR_IF_FALSE(
      ((in1_shape.size() == 1) && (in1_shape[0] == 1)),
      TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected WAV_DATA_DIM shape to be [1], got ") +
          nib::ShapeToString(in1_shape));
  RETURN_ERROR_IF_FALSE(
      ((out0_shape.size() == 1) && (out0_shape[0] == 1)),
      TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected RAW_LATTICE shape to be [1], got ") +
          nib::ShapeToString(out0_shape));
  RETURN_ERROR_IF_FALSE(((out1_shape.size() == 1) && (out1_shape[0] == 1)),
                        TRITONSERVER_ERROR_INVALID_ARG,
                        std::string("expected TEXT shape to be [1], got ") +
                            nib::ShapeToString(out1_shape));

  // Check datatypes
  std::string in0_dtype, in1_dtype, out0_dtype, out1_dtype;
  RETURN_AND_LOG_IF_ERROR(in0.MemberAsString("data_type", &in0_dtype),
                          "failed to get the first input data type");
  RETURN_AND_LOG_IF_ERROR(in1.MemberAsString("data_type", &in1_dtype),
                          "failed to get the second input data type");
  RETURN_AND_LOG_IF_ERROR(out0.MemberAsString("data_type", &out0_dtype),
                          "failed to get the first output data type");
  RETURN_AND_LOG_IF_ERROR(out1.MemberAsString("data_type", &out1_dtype),
                          "failed to get the second output data type");

  RETURN_ERROR_IF_FALSE(
      in0_dtype == "TYPE_FP32", TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected WAV_DATA datatype to be TYPE_FP32, got ") +
          in0_dtype);
  RETURN_ERROR_IF_FALSE(
      in1_dtype == "TYPE_INT32", TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected WAV_DATA_DIM datatype to be TYPE_INT32, got ") +
          in1_dtype);
  RETURN_ERROR_IF_FALSE(
      out0_dtype == "TYPE_STRING", TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected RAW_LATTICE datatype to be TYPE_STRING, got ") +
          out0_dtype);
  RETURN_ERROR_IF_FALSE(
      out1_dtype == "TYPE_STRING", TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected TEXT datatype to be TYPE_STRING, got ") +
          out1_dtype);

  // Validate and set parameters
  ni::TritonJson::Value params;
  RETURN_ERROR_IF_FALSE(
      (model_config_.Find("parameters", &params)),
      TRITONSERVER_ERROR_INVALID_ARG,
      std::string("missing parameters in the model configuration"));
  RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "config_filename",
                                             &(model_params_.config_filename)),
                          "config_filename");
  RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "use_tensor_cores",
                                             &(model_params_.use_tensor_cores)),
                          "cuda use tensor cores");
  RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "main_q_capacity",
                                             &(model_params_.main_q_capacity)),
                          "main_q_capacity");
  RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "aux_q_capacity",
                                             &(model_params_.aux_q_capacity)),
                          "aux_q_capacity");
  RETURN_AND_LOG_IF_ERROR(
      nib::ReadParameter(params, "beam", &(model_params_.beam)), "beam");
  RETURN_AND_LOG_IF_ERROR(
      nib::ReadParameter(params, "lattice_beam", &(model_params_.lattice_beam)),
      "lattice beam");
  RETURN_AND_LOG_IF_ERROR(
      nib::ReadParameter(params, "max_active", &(model_params_.max_active)),
      "max active");
  RETURN_AND_LOG_IF_ERROR(
      nib::ReadParameter(params, "frame_subsampling_factor",
                         &(model_params_.frame_subsampling_factor)),
      "frame_subsampling_factor");
  RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "acoustic_scale",
                                             &(model_params_.acoustic_scale)),
                          "acoustic_scale");
  RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "nnet3_rxfilename",
                                             &(model_params_.nnet3_rxfilename)),
                          "nnet3_rxfilename");
  RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "fst_rxfilename",
                                             &(model_params_.fst_rxfilename)),
                          "fst_rxfilename");
  RETURN_AND_LOG_IF_ERROR(
      nib::ReadParameter(params, "word_syms_rxfilename",
                         &(model_params_.word_syms_rxfilename)),
      "word_syms_rxfilename");
  RETURN_AND_LOG_IF_ERROR(
      nib::ReadParameter(params, "num_worker_threads",
                         &(model_params_.num_worker_threads)),
      "num_worker_threads");
  RETURN_AND_LOG_IF_ERROR(
      nib::ReadParameter(params, "num_channels", &(model_params_.num_channels)),
      "num_channels");

  RETURN_AND_LOG_IF_ERROR(
      nib::ReadParameter(params, "lattice_postprocessor_rxfilename",
                         &(model_params_.lattice_postprocessor_rxfilename)),
      "(optional) lattice postprocessor config file");

  model_params_.max_batch_size = std::max<int>(model_params_.max_batch_size, 1);
  model_params_.num_channels = std::max<int>(model_params_.num_channels, 1);

  // Sanity checks
  RETURN_ERROR_IF_FALSE(
      model_params_.beam > 0, TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected parameter \"beam\" to be greater than 0, got ") +
          std::to_string(model_params_.beam));
  RETURN_ERROR_IF_FALSE(
      model_params_.lattice_beam > 0, TRITONSERVER_ERROR_INVALID_ARG,
      std::string(
          "expected parameter \"lattice_beam\" to be greater than 0, got ") +
          std::to_string(model_params_.lattice_beam));
  RETURN_ERROR_IF_FALSE(
      model_params_.max_active > 0, TRITONSERVER_ERROR_INVALID_ARG,
      std::string(
          "expected parameter \"max_active\" to be greater than 0, got ") +
          std::to_string(model_params_.max_active));
  RETURN_ERROR_IF_FALSE(model_params_.main_q_capacity >= -1,
                        TRITONSERVER_ERROR_INVALID_ARG,
                        std::string("expected parameter \"main_q_capacity\" to "
                                    "be greater than or equal to -1, got ") +
                            std::to_string(model_params_.main_q_capacity));
  RETURN_ERROR_IF_FALSE(model_params_.aux_q_capacity >= -1,
                        TRITONSERVER_ERROR_INVALID_ARG,
                        std::string("expected parameter \"aux_q_capacity\" to "
                                    "be greater than or equal to -1, got ") +
                            std::to_string(model_params_.aux_q_capacity));
  RETURN_ERROR_IF_FALSE(
      model_params_.acoustic_scale > 0, TRITONSERVER_ERROR_INVALID_ARG,
      std::string(
          "expected parameter \"acoustic_scale\" to be greater than 0, got ") +
          std::to_string(model_params_.acoustic_scale));
  RETURN_ERROR_IF_FALSE(
      model_params_.num_worker_threads >= -1, TRITONSERVER_ERROR_INVALID_ARG,
      std::string("expected parameter \"num_worker_threads\" to be greater "
                  "than or equal to -1, got ") +
          std::to_string(model_params_.num_worker_threads));
  RETURN_ERROR_IF_FALSE(
      model_params_.num_channels > 0, TRITONSERVER_ERROR_INVALID_ARG,
      std::string(
          "expected parameter \"num_channels\" to be greater than 0, got ") +
          std::to_string(model_params_.num_channels));

  return nullptr;  // success
}

//
// ModelInstanceState
//
// State associated with a model instance. An object of this class is
// created and associated with each TRITONBACKEND_ModelInstance.
//
class ModelInstanceState {
 public:
  static TRITONSERVER_Error* Create(
      ModelState* model_state,
      TRITONBACKEND_ModelInstance* triton_model_instance,
      ModelInstanceState** state);
  ~ModelInstanceState();

  // Get the handle to the TRITONBACKEND model instance.
  TRITONBACKEND_ModelInstance* TritonModelInstance() {
    return triton_model_instance_;
  }

  // Get the name, kind and device ID of the instance.
  const std::string& Name() const { return name_; }
  TRITONSERVER_InstanceGroupKind Kind() const { return kind_; }
  int32_t DeviceId() const { return device_id_; }

  // Get the state of the model that corresponds to this instance.
  ModelState* StateForModel() const { return model_state_; }

  // Initialize this object
  TRITONSERVER_Error* Init();

  // Initialize the kaldi pipeline for this object
  TRITONSERVER_Error* InitializeKaldiPipeline();

  // Prepare a request for the kaldi pipeline
  void PrepareRequest(TRITONBACKEND_Request* request, uint32_t slot_idx);

  // Execute the batch on the decoder
  void FlushBatch();

  // Wait for all pipeline callbacks to complete
  void WaitForLatticeCallbacks();

 private:
  ModelInstanceState(ModelState* model_state,
                     TRITONBACKEND_ModelInstance* triton_model_instance,
                     const char* name,
                     const TRITONSERVER_InstanceGroupKind kind,
                     const int32_t device_id);

  TRITONSERVER_Error* GetSequenceInput(TRITONBACKEND_Request* request,
                                       int32_t* start, int32_t* ready,
                                       int32_t* dim, int32_t* end,
                                       uint64_t* corr_id,
                                       const BaseFloat** wave_buffer,
                                       std::vector<uint8_t>* input_buffer);

  void DeliverPartialResponse(const std::string& text,
                              TRITONBACKEND_ResponseFactory* response_factory,
                              uint8_t response_outputs);
  void DeliverResponse(
      std::vector<kaldi::cuda_decoder::CudaPipelineResult>& results,
      uint64_t corr_id, TRITONBACKEND_ResponseFactory* response_factory,
      uint8_t response_outputs);
  void SetPartialOutput(const std::string& text,
                        TRITONBACKEND_ResponseFactory* response_factory,
                        TRITONBACKEND_Response* response);
  void SetOutput(std::vector<kaldi::cuda_decoder::CudaPipelineResult>& results,
                 uint64_t corr_id, const std::string& output_name,
                 TRITONBACKEND_ResponseFactory* response_factory,
                 TRITONBACKEND_Response* response);

  void SetOutputBuffer(const std::string& out_bytes,
                       TRITONBACKEND_Response* response,
                       TRITONBACKEND_Output* response_output);

  ModelState* model_state_;
  TRITONBACKEND_ModelInstance* triton_model_instance_;
  const std::string name_;
  const TRITONSERVER_InstanceGroupKind kind_;
  const int32_t device_id_;

  std::mutex partial_resfactory_mu_;
  std::unordered_map<uint64_t,
                     std::queue<std::shared_ptr<TRITONBACKEND_ResponseFactory>>>
      partial_responsefactory_;
  std::vector<uint64_t> batch_corr_ids_;
  std::vector<kaldi::SubVector<kaldi::BaseFloat>> batch_wave_samples_;
  std::vector<bool> batch_is_first_chunk_;
  std::vector<bool> batch_is_last_chunk_;

  BaseFloat sample_freq_, seconds_per_chunk_;
  int chunk_num_bytes_, chunk_num_samps_;

  // feature_config includes configuration for the iVector adaptation,
  // as well as the basic features.
  kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipelineConfig
      batched_decoder_config_;
  std::unique_ptr<kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline>
      cuda_pipeline_;
  // Maintain the state of some shared objects
  kaldi::TransitionModel trans_model_;

  kaldi::nnet3::AmNnetSimple am_nnet_;
  fst::SymbolTable* word_syms_;

  std::vector<uint8_t> byte_buffer_;
  std::vector<std::vector<uint8_t>> wave_byte_buffers_;

  std::vector<int64_t> output_shape_;
  std::vector<std::string> request_outputs_;
};

TRITONSERVER_Error* ModelInstanceState::Create(
    ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance,
    ModelInstanceState** state) {
  const char* instance_name;
  RETURN_IF_ERROR(
      TRITONBACKEND_ModelInstanceName(triton_model_instance, &instance_name));

  TRITONSERVER_InstanceGroupKind instance_kind;
  RETURN_IF_ERROR(
      TRITONBACKEND_ModelInstanceKind(triton_model_instance, &instance_kind));

  int32_t instance_id;
  RETURN_IF_ERROR(
      TRITONBACKEND_ModelInstanceDeviceId(triton_model_instance, &instance_id));

  *state = new ModelInstanceState(model_state, triton_model_instance,
                                  instance_name, instance_kind, instance_id);
  return nullptr;  // success
}

ModelInstanceState::ModelInstanceState(
    ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance,
    const char* name, const TRITONSERVER_InstanceGroupKind kind,
    const int32_t device_id)
    : model_state_(model_state),
      triton_model_instance_(triton_model_instance),
      name_(name),
      kind_(kind),
      device_id_(device_id) {}

ModelInstanceState::~ModelInstanceState() { delete word_syms_; }

TRITONSERVER_Error* ModelInstanceState::Init() {
  const ModelParams* model_params = model_state_->Parameters();

  chunk_num_samps_ = model_params->chunk_num_samps;
  chunk_num_bytes_ = model_params->chunk_num_bytes;

  {
    std::ostringstream usage_str;
    usage_str << "Parsing config from '" << model_params->config_filename
              << "'";
    kaldi::ParseOptions po(usage_str.str().c_str());
    batched_decoder_config_.Register(&po);
    po.DisableOption("cuda-decoder-copy-threads");
    po.DisableOption("cuda-worker-threads");
    po.DisableOption("max-active");
    po.DisableOption("max-batch-size");
    po.DisableOption("num-channels");
    po.ReadConfigFile(model_params->config_filename);
  }
  kaldi::CuDevice::EnableTensorCores(bool(model_params->use_tensor_cores));

  batched_decoder_config_.compute_opts.frame_subsampling_factor =
      model_params->frame_subsampling_factor;
  batched_decoder_config_.compute_opts.acoustic_scale =
      model_params->acoustic_scale;
  batched_decoder_config_.decoder_opts.default_beam = model_params->beam;
  batched_decoder_config_.decoder_opts.lattice_beam =
      model_params->lattice_beam;
  batched_decoder_config_.decoder_opts.max_active = model_params->max_active;
  batched_decoder_config_.num_worker_threads = model_params->num_worker_threads;
  batched_decoder_config_.max_batch_size = model_params->max_batch_size;
  batched_decoder_config_.num_channels = model_params->num_channels;
  batched_decoder_config_.decoder_opts.main_q_capacity =
      model_params->main_q_capacity;
  batched_decoder_config_.decoder_opts.aux_q_capacity =
      model_params->aux_q_capacity;

  auto feature_config = batched_decoder_config_.feature_opts;
  kaldi::OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
  sample_freq_ = feature_info.mfcc_opts.frame_opts.samp_freq;
  BaseFloat frame_shift = feature_info.FrameShiftInSeconds();
  seconds_per_chunk_ = chunk_num_samps_ / sample_freq_;

  int samp_per_frame = static_cast<int>(sample_freq_ * frame_shift);
  // Use floating-point division so the multiple-of-samples-per-frame check
  // below is meaningful (integer division would always pass it).
  float n_input_framesf =
      static_cast<float>(chunk_num_samps_) / samp_per_frame;
  RETURN_ERROR_IF_FALSE(
      (n_input_framesf == std::floor(n_input_framesf)),
      TRITONSERVER_ERROR_INVALID_ARG,
      std::string("WAV_DATA dim must be a multiple of samples per frame (") +
          std::to_string(samp_per_frame) + std::string(")"));
  int n_input_frames = static_cast<int>(std::floor(n_input_framesf));
  batched_decoder_config_.compute_opts.frames_per_chunk = n_input_frames;

  return nullptr;
}
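// Worked example for the chunk/frame arithmetic above (illustrative numbers,
// not read from any particular config): with a 16 kHz model and a 10 ms frame
// shift, samp_per_frame = 16000 * 0.010 = 160 samples. A WAV_DATA dim of 8160
// samples is then valid (8160 / 160 = 51 frames per chunk,
// seconds_per_chunk_ = 0.51 s), while a dim of 8100 would be rejected because
// 8100 / 160 is not an integer.
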
TRITONSERVER_Error* ModelInstanceState::InitializeKaldiPipeline() {
  const ModelParams* model_params = model_state_->Parameters();

  batch_corr_ids_.reserve(model_params->max_batch_size);
  batch_wave_samples_.reserve(model_params->max_batch_size);
  batch_is_first_chunk_.reserve(model_params->max_batch_size);
  batch_is_last_chunk_.reserve(model_params->max_batch_size);
  wave_byte_buffers_.resize(model_params->max_batch_size);
  for (auto& wbb : wave_byte_buffers_) {
    wbb.resize(chunk_num_bytes_);
  }
  output_shape_ = {1, 1};
  kaldi::g_cuda_allocator.SetOptions(kaldi::g_allocator_options);
  kaldi::CuDevice::Instantiate()
      .SelectAndInitializeGpuIdWithExistingCudaContext(device_id_);
  kaldi::CuDevice::Instantiate().AllowMultithreading();

  // Loading models
  {
    bool binary;
    kaldi::Input ki(model_params->nnet3_rxfilename, &binary);
    trans_model_.Read(ki.Stream(), binary);
    am_nnet_.Read(ki.Stream(), binary);

    kaldi::nnet3::SetBatchnormTestMode(true, &(am_nnet_.GetNnet()));
    kaldi::nnet3::SetDropoutTestMode(true, &(am_nnet_.GetNnet()));
    kaldi::nnet3::CollapseModel(kaldi::nnet3::CollapseModelConfig(),
                                &(am_nnet_.GetNnet()));
  }
  fst::Fst<fst::StdArc>* decode_fst =
      fst::ReadFstKaldiGeneric(model_params->fst_rxfilename);
  cuda_pipeline_.reset(
      new kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline(
          batched_decoder_config_, *decode_fst, am_nnet_, trans_model_));
  delete decode_fst;

  // Loading word syms for text output
  if (model_params->word_syms_rxfilename != "") {
    RETURN_ERROR_IF_FALSE(
        (word_syms_ =
             fst::SymbolTable::ReadText(model_params->word_syms_rxfilename)),
        TRITONSERVER_ERROR_INVALID_ARG,
        std::string("could not read symbol table from file ") +
            model_params->word_syms_rxfilename);
    cuda_pipeline_->SetSymbolTable(*word_syms_);
  }

  // Load lattice postprocessor, required if using CTM
  if (!model_params->lattice_postprocessor_rxfilename.empty()) {
    LoadAndSetLatticePostprocessor(
        model_params->lattice_postprocessor_rxfilename, cuda_pipeline_.get());
  }
  chunk_num_samps_ = cuda_pipeline_->GetNSampsPerChunk();
  chunk_num_bytes_ = chunk_num_samps_ * sizeof(BaseFloat);

  return nullptr;
}

TRITONSERVER_Error* ModelInstanceState::GetSequenceInput(
    TRITONBACKEND_Request* request, int32_t* start, int32_t* ready,
    int32_t* dim, int32_t* end, uint64_t* corr_id,
    const BaseFloat** wave_buffer, std::vector<uint8_t>* input_buffer) {
  size_t dim_bsize = sizeof(*dim);
  RETURN_IF_ERROR(nib::ReadInputTensor(
      request, "WAV_DATA_DIM", reinterpret_cast<char*>(dim), &dim_bsize));

  size_t end_bsize = sizeof(*end);
  RETURN_IF_ERROR(nib::ReadInputTensor(
      request, "END", reinterpret_cast<char*>(end), &end_bsize));

  size_t start_bsize = sizeof(*start);
  RETURN_IF_ERROR(nib::ReadInputTensor(
      request, "START", reinterpret_cast<char*>(start), &start_bsize));

  size_t ready_bsize = sizeof(*ready);
  RETURN_IF_ERROR(nib::ReadInputTensor(
      request, "READY", reinterpret_cast<char*>(ready), &ready_bsize));

  size_t corrid_bsize = sizeof(*corr_id);
  RETURN_IF_ERROR(nib::ReadInputTensor(
      request, "CORRID", reinterpret_cast<char*>(corr_id), &corrid_bsize));

  // Get pointer to speech tensor
  size_t wavdata_bsize = input_buffer->size();
  RETURN_IF_ERROR(nib::ReadInputTensor(
      request, "WAV_DATA", reinterpret_cast<char*>(input_buffer->data()),
      &wavdata_bsize));
  *wave_buffer = reinterpret_cast<const BaseFloat*>(input_buffer->data());

  return nullptr;
}

void ModelInstanceState::PrepareRequest(TRITONBACKEND_Request* request,
                                        uint32_t slot_idx) {
  const ModelParams* model_params = model_state_->Parameters();

  if (batch_corr_ids_.size() == (uint32_t)model_params->max_batch_size) {
    FlushBatch();
  }

  int32_t start, dim, end, ready;
  uint64_t corr_id;
  const BaseFloat* wave_buffer;

  if (slot_idx >= (uint32_t)model_params->max_batch_size) {
    RESPOND_AND_RETURN_IF_ERROR(
        request, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
                                       "slot_idx exceeded"));
  }
  RESPOND_AND_RETURN_IF_ERROR(
      request, GetSequenceInput(request, &start, &ready, &dim, &end, &corr_id,
                                &wave_buffer, &wave_byte_buffers_[slot_idx]));

  uint32_t output_count;
  RESPOND_AND_RETURN_IF_ERROR(
      request, TRITONBACKEND_RequestOutputCount(request, &output_count));

  uint8_t response_outputs = 0;
  int kaldi_result_type = 0;
  for (uint32_t index = 0; index < output_count; index++) {
    const char* output_name;
    RESPOND_AND_RETURN_IF_ERROR(
        request, TRITONBACKEND_RequestOutputName(request, index, &output_name));
    std::string output_name_str = output_name;
    if (output_name_str == "RAW_LATTICE") {
      response_outputs |= kResponseOutputRawLattice;
      kaldi_result_type |=
          kaldi::cuda_decoder::CudaPipelineResult::RESULT_TYPE_LATTICE;
    } else if (output_name_str == "TEXT") {
      response_outputs |= kResponseOutputText;
      kaldi_result_type |=
          kaldi::cuda_decoder::CudaPipelineResult::RESULT_TYPE_LATTICE;
    } else if (output_name_str == "CTM") {
      response_outputs |= kResponseOutputCTM;
      kaldi_result_type |=
          kaldi::cuda_decoder::CudaPipelineResult::RESULT_TYPE_CTM;
    } else {
      TRITONSERVER_LogMessage(
          TRITONSERVER_LOG_WARN, __FILE__, __LINE__,
          ("unrecognized requested output " + output_name_str).c_str());
    }
  }

  if (dim > chunk_num_samps_) {
    RESPOND_AND_RETURN_IF_ERROR(
        request,
        TRITONSERVER_ErrorNew(
            TRITONSERVER_ERROR_INVALID_ARG,
            "a chunk cannot contain more samples than the WAV_DATA dimension"));
  }

  if (!ready) {
    RESPOND_AND_RETURN_IF_ERROR(
        request, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                                       "request is not yet ready"));
  }

  // Initialize corr_id if this is the first chunk of the sequence
  if (start) {
    if (!cuda_pipeline_->TryInitCorrID(corr_id)) {
      RESPOND_AND_RETURN_IF_ERROR(
          request, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                                         "failed to start cuda pipeline"));
    }

    {
      std::lock_guard<std::mutex> lock_partial_resfactory(
          partial_resfactory_mu_);
      cuda_pipeline_->SetBestPathCallback(
          corr_id, [this, corr_id](const std::string& str, bool partial,
                                   bool endpoint_detected) {
            // Bestpath callbacks are synchronous with regard to each
            // correlation ID, so the lock is only needed for acquiring a
            // reference to the queue.
            std::unique_lock<std::mutex> lock_partial_resfactory(
                partial_resfactory_mu_);
            auto& resfactory_queue = partial_responsefactory_.at(corr_id);
            if (!partial) {
              if (!endpoint_detected) {
                // while (!resfactory_queue.empty()) {
                //   auto response_factory = resfactory_queue.front();
                //   resfactory_queue.pop();
                //   if (response_factory != nullptr) {
                //     LOG_IF_ERROR(
                //         TRITONBACKEND_ResponseFactoryDelete(response_factory),
                //         "error deleting response factory");
                //   }
                // }
                partial_responsefactory_.erase(corr_id);
              }
              return;
            }
            if (resfactory_queue.empty()) {
              TRITONSERVER_LogMessage(
                  TRITONSERVER_LOG_WARN, __FILE__, __LINE__,
                  "response factory queue unexpectedly empty");
              return;
            }

            auto response_factory = resfactory_queue.front();
            resfactory_queue.pop();
            lock_partial_resfactory.unlock();
            if (response_factory == nullptr) return;

            DeliverPartialResponse(str, response_factory.get(),
                                   kResponseOutputText);
          });
      partial_responsefactory_.emplace(
          corr_id,
          std::queue<std::shared_ptr<TRITONBACKEND_ResponseFactory>>());
    }
  }

  kaldi::SubVector<BaseFloat> wave_part(wave_buffer, dim);

  // Add to batch
  batch_corr_ids_.push_back(corr_id);
  batch_wave_samples_.push_back(wave_part);
  batch_is_first_chunk_.push_back(start);
  batch_is_last_chunk_.push_back(end);

  TRITONBACKEND_ResponseFactory* response_factory_ptr;
  RESPOND_AND_RETURN_IF_ERROR(request, TRITONBACKEND_ResponseFactoryNew(
                                           &response_factory_ptr, request));
  std::shared_ptr<TRITONBACKEND_ResponseFactory> response_factory(
      response_factory_ptr, [](TRITONBACKEND_ResponseFactory* f) {
        LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryDelete(f),
                     "failed deleting response factory");
      });

  if (end) {
    auto segmented_lattice_callback_fn =
        [this, response_factory, response_outputs,
         corr_id](kaldi::cuda_decoder::SegmentedLatticeCallbackParams& params) {
          DeliverResponse(params.results, corr_id, response_factory.get(),
                          response_outputs);
        };
    cuda_pipeline_->SetLatticeCallback(corr_id, segmented_lattice_callback_fn,
                                       kaldi_result_type);
  } else if (response_outputs & kResponseOutputText) {
    std::lock_guard<std::mutex> lock_partial_resfactory(partial_resfactory_mu_);
    auto& resfactory_queue = partial_responsefactory_.at(corr_id);
    resfactory_queue.push(response_factory);
  } else {
    {
      std::lock_guard<std::mutex> lock_partial_resfactory(
          partial_resfactory_mu_);
      auto& resfactory_queue = partial_responsefactory_.at(corr_id);
      resfactory_queue.emplace(nullptr);
    }

    // Mark the response complete without sending any responses
    LOG_IF_ERROR(
        TRITONBACKEND_ResponseFactorySendFlags(
            response_factory.get(), TRITONSERVER_RESPONSE_COMPLETE_FINAL),
        "failed sending final response");
  }
}

void ModelInstanceState::FlushBatch() {
  if (!batch_corr_ids_.empty()) {
    cuda_pipeline_->DecodeBatch(batch_corr_ids_, batch_wave_samples_,
                                batch_is_first_chunk_, batch_is_last_chunk_);
    batch_corr_ids_.clear();
    batch_wave_samples_.clear();
    batch_is_first_chunk_.clear();
    batch_is_last_chunk_.clear();
  }
}

void ModelInstanceState::WaitForLatticeCallbacks() {
  cuda_pipeline_->WaitForLatticeCallbacks();
}

void ModelInstanceState::DeliverPartialResponse(
    const std::string& text, TRITONBACKEND_ResponseFactory* response_factory,
    uint8_t response_outputs) {
  if (response_outputs & kResponseOutputText) {
    TRITONBACKEND_Response* response;
    RESPOND_FACTORY_AND_RETURN_IF_ERROR(
        response_factory,
        TRITONBACKEND_ResponseNewFromFactory(&response, response_factory));
    SetPartialOutput(text, response_factory, response);
    LOG_IF_ERROR(TRITONBACKEND_ResponseSend(
                     response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr),
                 "failed sending response");
  } else {
    LOG_IF_ERROR(TRITONBACKEND_ResponseFactorySendFlags(
                     response_factory, TRITONSERVER_RESPONSE_COMPLETE_FINAL),
                 "failed to send final flag for partial result");
  }
}

void ModelInstanceState::DeliverResponse(
    std::vector<kaldi::cuda_decoder::CudaPipelineResult>& results,
    uint64_t corr_id, TRITONBACKEND_ResponseFactory* response_factory,
    uint8_t response_outputs) {
  TRITONBACKEND_Response* response;
  RESPOND_FACTORY_AND_RETURN_IF_ERROR(
      response_factory,
      TRITONBACKEND_ResponseNewFromFactory(&response, response_factory));
  if (response_outputs & kResponseOutputRawLattice) {
    SetOutput(results, corr_id, "RAW_LATTICE", response_factory, response);
  }
  if (response_outputs & kResponseOutputText) {
    SetOutput(results, corr_id, "TEXT", response_factory, response);
  }
  if (response_outputs & kResponseOutputCTM) {
    SetOutput(results, corr_id, "CTM", response_factory, response);
  }
  // Send the response.
  LOG_IF_ERROR(
      TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL,
                                 nullptr /* success */),
      "failed sending response");
}

void ModelInstanceState::SetPartialOutput(
    const std::string& text, TRITONBACKEND_ResponseFactory* response_factory,
    TRITONBACKEND_Response* response) {
  TRITONBACKEND_Output* response_output;
  RESPOND_FACTORY_AND_RETURN_IF_ERROR(
      response_factory, TRITONBACKEND_ResponseOutput(
                            response, &response_output, "TEXT",
                            TRITONSERVER_TYPE_BYTES, &output_shape_[0], 2));
  SetOutputBuffer(text, response, response_output);
}

void ModelInstanceState::SetOutput(
    std::vector<kaldi::cuda_decoder::CudaPipelineResult>& results,
    uint64_t corr_id, const std::string& output_name,
    TRITONBACKEND_ResponseFactory* response_factory,
    TRITONBACKEND_Response* response) {
  TRITONBACKEND_Output* response_output;
  RESPOND_FACTORY_AND_RETURN_IF_ERROR(
      response_factory,
      TRITONBACKEND_ResponseOutput(response, &response_output,
                                   output_name.c_str(), TRITONSERVER_TYPE_BYTES,
                                   &output_shape_[0], 2 /* dims_count */));

  if (output_name.compare("RAW_LATTICE") == 0) {
    assert(!results.empty());
    kaldi::CompactLattice& clat = results[0].GetLatticeResult();

    std::ostringstream oss;
    kaldi::WriteCompactLattice(oss, true, clat);
    SetOutputBuffer(oss.str(), response, response_output);
  } else if (output_name.compare("TEXT") == 0) {
    assert(!results.empty());
    kaldi::CompactLattice& clat = results[0].GetLatticeResult();
    std::string output;
    nib::LatticeToString(*word_syms_, clat, &output);
    SetOutputBuffer(output, response, response_output);
  } else if (output_name.compare("CTM") == 0) {
    std::ostringstream oss;
    MergeSegmentsToCTMOutput(results, std::to_string(corr_id), oss, word_syms_,
                             /* use segment offset */ false);
    SetOutputBuffer(oss.str(), response, response_output);
  }
}

void ModelInstanceState::SetOutputBuffer(
    const std::string& out_bytes, TRITONBACKEND_Response* response,
    TRITONBACKEND_Output* response_output) {
  TRITONSERVER_MemoryType actual_memory_type = TRITONSERVER_MEMORY_CPU;
  int64_t actual_memory_type_id = 0;
  uint32_t byte_size_with_size_int = out_bytes.size() + sizeof(int32);
  void* obuffer;  // output buffer
  auto err = TRITONBACKEND_OutputBuffer(
      response_output, &obuffer, byte_size_with_size_int, &actual_memory_type,
      &actual_memory_type_id);
  if (err != nullptr) {
    RESPOND_AND_SET_NULL_IF_ERROR(&response, err);
  }

  int32* buffer_as_int = reinterpret_cast<int32*>(obuffer);
  buffer_as_int[0] = out_bytes.size();
  memcpy(&buffer_as_int[1], out_bytes.data(), out_bytes.size());
}
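// Note on SetOutputBuffer(): a Triton BYTES (string) tensor element is
// serialized as a 4-byte length followed by the raw bytes, which is what the
// int32 length prefix plus memcpy above produces. A client reading a
// RAW_LATTICE, TEXT, or CTM element therefore reads the first 4 bytes as the
// length N and the next N bytes as the payload (descriptive sketch of the
// wire layout; the Triton client libraries handle this decoding).
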
}  // namespace

/////////////

extern "C" {

TRITONSERVER_Error* TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) {
  const char* cname;
  RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname));
  std::string name(cname);

  uint64_t version;
  RETURN_IF_ERROR(TRITONBACKEND_ModelVersion(model, &version));

  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
              (std::string("TRITONBACKEND_ModelInitialize: ") + name +
               " (version " + std::to_string(version) + ")")
                  .c_str());

  ModelState* model_state;
  RETURN_IF_ERROR(ModelState::Create(model, &model_state));
  RETURN_IF_ERROR(
      TRITONBACKEND_ModelSetState(model, reinterpret_cast<void*>(model_state)));

  RETURN_IF_ERROR(model_state->ValidateModelConfig());

  return nullptr;  // success
}

TRITONSERVER_Error* TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) {
  void* vstate;
  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate));
  ModelState* model_state = reinterpret_cast<ModelState*>(vstate);

  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
              "TRITONBACKEND_ModelFinalize: delete model state");

  delete model_state;

  return nullptr;  // success
}

TRITONSERVER_Error* TRITONBACKEND_ModelInstanceInitialize(
    TRITONBACKEND_ModelInstance* instance) {
  const char* cname;
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceName(instance, &cname));
  std::string name(cname);

  int32_t device_id;
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceDeviceId(instance, &device_id));

  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
              (std::string("TRITONBACKEND_ModelInstanceInitialize: ") + name +
               " (device " + std::to_string(device_id) + ")")
                  .c_str());

  TRITONBACKEND_Model* model;
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model));

  void* vmodelstate;
  RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate));
  ModelState* model_state = reinterpret_cast<ModelState*>(vmodelstate);

  ModelInstanceState* instance_state;
  RETURN_IF_ERROR(
      ModelInstanceState::Create(model_state, instance, &instance_state));
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(
      instance, reinterpret_cast<void*>(instance_state)));

  RETURN_IF_ERROR(instance_state->Init());
  RETURN_IF_ERROR(instance_state->InitializeKaldiPipeline());

  return nullptr;  // success
}

TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize(
    TRITONBACKEND_ModelInstance* instance) {
  void* vstate;
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate));
  ModelInstanceState* instance_state =
      reinterpret_cast<ModelInstanceState*>(vstate);

  LOG_MESSAGE(
      TRITONSERVER_LOG_INFO,
      "TRITONBACKEND_ModelInstanceFinalize: waiting for lattice callbacks");
  instance_state->WaitForLatticeCallbacks();

  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
              "TRITONBACKEND_ModelInstanceFinalize: delete instance state");
  delete instance_state;

  return nullptr;  // success
}

TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute(
    TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests,
    const uint32_t request_count) {
  ModelInstanceState* instance_state;
  RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(
      instance, reinterpret_cast<void**>(&instance_state)));

  LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
              (std::string("model instance ") + instance_state->Name() +
               ", executing " + std::to_string(request_count) + " requests")
                  .c_str());

  RETURN_ERROR_IF_FALSE(
      request_count <=
          instance_state->StateForModel()->Parameters()->max_batch_size,
      TRITONSERVER_ERROR_INVALID_ARG,
      std::string("request count exceeded the provided maximum batch size"));

  uint64_t exec_start_ns = 0;
  SET_TIMESTAMP(exec_start_ns);

  // Each request is a chunk for one sequence.
  // Using the oldest strategy in the sequence batcher ensures that
  // there will only be a single chunk for each sequence.
  for (uint32_t r = 0; r < request_count; ++r) {
    TRITONBACKEND_Request* request = requests[r];
    instance_state->PrepareRequest(request, r);
  }

  instance_state->FlushBatch();

  uint64_t exec_end_ns = 0;
  SET_TIMESTAMP(exec_end_ns);

  for (uint32_t r = 0; r < request_count; ++r) {
    TRITONBACKEND_Request* request = requests[r];
    LOG_IF_ERROR(
        TRITONBACKEND_ModelInstanceReportStatistics(
            instance_state->TritonModelInstance(), request, true /* success */,
            exec_start_ns, exec_start_ns, exec_end_ns, exec_end_ns),
        "failed reporting request statistics");
    LOG_IF_ERROR(
        TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL),
        "failed releasing request");
  }

  LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics(
                   instance_state->TritonModelInstance(), request_count,
                   exec_start_ns, exec_start_ns, exec_end_ns, exec_end_ns),
               "failed reporting batch request statistics");

  return nullptr;  // success
}

}  // extern "C"