[Kaldi] Update to 21.08

Alec Kohlhoff 2021-10-12 17:45:31 -07:00
parent 26d8955cc5
commit 1a5c7556b5
30 changed files with 2258 additions and 1195 deletions

View File

@ -2,3 +2,4 @@ data/*
!data/README.md
.*.swp
.*.swo
.clang-format

View File

@ -1,4 +1,4 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -11,12 +11,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvcr.io/nvidia/kaldi:20.03-py3 as kb
FROM nvcr.io/nvidia/tritonserver:20.03-py3
ENV DEBIAN_FRONTEND=noninteractive
ARG TRITONSERVER_IMAGE=nvcr.io/nvidia/tritonserver:21.05-py3
ARG KALDI_IMAGE=nvcr.io/nvidia/kaldi:21.08-py3
ARG PYTHON_VER=3.8
#
# Kaldi shared library dependencies
#
FROM ${KALDI_IMAGE} as kaldi
#
# Builder image based on Triton Server SDK image
#
FROM ${TRITONSERVER_IMAGE}-sdk as builder
ARG PYTHON_VER
# Kaldi dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
RUN set -eux; \
apt-get update; \
apt-get install -yq --no-install-recommends \
automake \
autoconf \
cmake \
@ -24,29 +37,80 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
gawk \
libatlas3-base \
libtool \
python3.6 \
python3.6-dev \
python${PYTHON_VER} \
python${PYTHON_VER}-dev \
sox \
subversion \
unzip \
bc \
libatlas-base-dev \
zlib1g-dev
gfortran \
zlib1g-dev; \
rm -rf /var/lib/apt/lists/*
RUN mkdir /opt/trtis-kaldi && mkdir -p /workspace/model-repo/kaldi_online/1 && mkdir -p /mnt/model-repo
# Copying static files
# Add Kaldi dependency
COPY --from=kaldi /opt/kaldi /opt/kaldi
# Set up Atlas
RUN set -eux; \
ln -sf /usr/include/x86_64-linux-gnu/atlas /usr/local/include/atlas; \
ln -sf /usr/include/x86_64-linux-gnu/cblas.h /usr/local/include/cblas.h; \
ln -sf /usr/include/x86_64-linux-gnu/clapack.h /usr/local/include/clapack.h; \
ln -sf /usr/lib/x86_64-linux-gnu/atlas /usr/local/lib/atlas
#
# Kaldi custom backend build
#
FROM builder as backend-build
# Build the custom backend
COPY kaldi-asr-backend /workspace/triton-kaldi-backend
RUN set -eux; \
cd /workspace/triton-kaldi-backend; \
cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX="$(pwd)/install" \
-B build .; \
cmake --build build --parallel; \
cmake --install build
#
# Final server image
#
FROM ${TRITONSERVER_IMAGE}
ARG PYTHON_VER
# Kaldi dependencies
RUN set -eux; \
apt-get update; \
apt-get install -yq --no-install-recommends \
automake \
autoconf \
cmake \
flac \
gawk \
libatlas3-base \
libtool \
python${PYTHON_VER} \
python${PYTHON_VER}-dev \
sox \
subversion \
unzip \
bc \
libatlas-base-dev \
zlib1g-dev; \
rm -rf /var/lib/apt/lists/*
# Add Kaldi dependency
COPY --from=kaldi /opt/kaldi /opt/kaldi
# Add Kaldi custom backend shared library and scripts
COPY --from=backend-build /workspace/triton-kaldi-backend/install/backends/kaldi/libtriton_kaldi.so /workspace/model-repo/kaldi_online/1/
COPY scripts /workspace/scripts
# Moving Kaldi to container
COPY --from=kb /opt/kaldi /opt/kaldi
ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:$LD_LIBRARY_PATH
# Building the custom backend
COPY trtis-kaldi-backend /workspace/trtis-kaldi-backend
#COPY --from=cbe /workspace/install/custom-backend-sdk /workspace/trtis-kaldi-backend/custom-backend-sdk
RUN cd /workspace/trtis-kaldi-backend && wget https://github.com/NVIDIA/tensorrt-inference-server/releases/download/v1.9.0/v1.9.0_ubuntu1804.custombackend.tar.gz -O custom-backend-sdk.tar.gz && tar -xzf custom-backend-sdk.tar.gz
RUN cd /workspace/trtis-kaldi-backend/ && make && cp libkaldi-trtisbackend.so /workspace/model-repo/kaldi_online/1/ && cd - && rm -r /workspace/trtis-kaldi-backend
COPY scripts/nvidia_kaldi_trtis_entrypoint.sh /opt/trtis-kaldi
ENTRYPOINT ["/opt/trtis-kaldi/nvidia_kaldi_trtis_entrypoint.sh"]
# Setup entrypoint and environment
ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:/opt/tritonserver/lib:$LD_LIBRARY_PATH
COPY scripts/nvidia_kaldi_triton_entrypoint.sh /opt/triton-kaldi/
VOLUME /mnt/model-repo
ENTRYPOINT ["/opt/triton-kaldi/nvidia_kaldi_triton_entrypoint.sh"]
CMD ["tritonserver", "--model-repo=/workspace/model-repo"]

View File

@ -1,4 +1,4 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -11,11 +11,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvcr.io/nvidia/kaldi:20.03-py3 as kb
FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk
ARG TRITONSERVER_IMAGE=nvcr.io/nvidia/tritonserver:21.05-py3
ARG KALDI_IMAGE=nvcr.io/nvidia/kaldi:21.08-py3
ARG PYTHON_VER=3.8
#
# Kaldi shared library dependencies
#
FROM ${KALDI_IMAGE} as kaldi
#
# Builder image based on Triton Server SDK image
#
FROM ${TRITONSERVER_IMAGE}-sdk as builder
ARG PYTHON_VER
# Kaldi dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
RUN set -eux; \
apt-get update; \
apt-get install -yq --no-install-recommends \
automake \
autoconf \
cmake \
@ -23,21 +39,78 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
gawk \
libatlas3-base \
libtool \
python3.6 \
python3.6-dev \
python${PYTHON_VER} \
python${PYTHON_VER}-dev \
sox \
subversion \
unzip \
bc \
libatlas-base-dev \
zlib1g-dev
gfortran \
zlib1g-dev; \
rm -rf /var/lib/apt/lists/*
# Moving Kaldi to container
COPY --from=kb /opt/kaldi /opt/kaldi
ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:$LD_LIBRARY_PATH
# Add Kaldi dependency
COPY --from=kaldi /opt/kaldi /opt/kaldi
# Set up Atlas
RUN set -eux; \
ln -sf /usr/include/x86_64-linux-gnu/atlas /usr/local/include/atlas; \
ln -sf /usr/include/x86_64-linux-gnu/cblas.h /usr/local/include/cblas.h; \
ln -sf /usr/include/x86_64-linux-gnu/clapack.h /usr/local/include/clapack.h; \
ln -sf /usr/lib/x86_64-linux-gnu/atlas /usr/local/lib/atlas
#
# Triton Kaldi client build
#
FROM builder as client-build
# Build the clients
COPY kaldi-asr-client /workspace/triton-client
RUN set -eux; \
cd /workspace; \
echo 'add_subdirectory(../../../triton-client src/c++/triton-client)' \
>> /workspace/client/src/c++/CMakeLists.txt; \
cmake -DCMAKE_BUILD_TYPE=Release -B build client; \
cmake --build build --parallel --target cc-clients
#
# Final gRPC client image
#
FROM ${TRITONSERVER_IMAGE}
ARG PYTHON_VER
# Kaldi dependencies
RUN set -eux; \
apt-get update; \
apt-get install -yq --no-install-recommends \
automake \
autoconf \
cmake \
flac \
gawk \
libatlas3-base \
libtool \
python${PYTHON_VER} \
python${PYTHON_VER}-dev \
sox \
subversion \
unzip \
bc \
libatlas-base-dev \
zlib1g-dev; \
rm -rf /var/lib/apt/lists/*
# Add Kaldi dependency
COPY --from=kaldi /opt/kaldi /opt/kaldi
# Add Triton clients and scripts
COPY --from=client-build /workspace/build/cc-clients/src/c++/triton-client/kaldi-asr-parallel-client /usr/local/bin/
COPY scripts /workspace/scripts
COPY kaldi-asr-client /workspace/src/clients/c++/kaldi-asr-client
RUN echo "add_subdirectory(kaldi-asr-client)" >> "/workspace/src/clients/c++/CMakeLists.txt"
RUN cd /workspace/build/ && make -j16 trtis-clients
# Setup environment and entrypoint
ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:/opt/tritonserver/lib:$LD_LIBRARY_PATH
VOLUME /mnt/model-repo
ENTRYPOINT ["/usr/local/bin/kaldi-asr-parallel-client"]

View File

@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk
FROM nvcr.io/nvidia/tritonserver:21.05-py3-sdk
# Kaldi dependencies
RUN apt-get update && apt-get install -y jupyter \

View File

@ -46,15 +46,18 @@ A reference model is used by all test scripts and benchmarks presented in this r
Details about parameters can be found in the [Parameters](#parameters) section.
* `model path`: Configured to use the pretrained LibriSpeech model.
* `use_tensor_cores`: 1
* `main_q_capacity`: 30000
* `aux_q_capacity`: 400000
* `beam`: 10
* `num_channels`: 4000
* `lattice_beam`: 7
* `max_active`: 10000
* `frame_subsampling_factor`: 3
* `acoustic_scale`: 1.0
* `num_worker_threads`: 20
* `max_execution_batch_size`: 256
* `max_batch_size`: 4096
* `instance_group.count`: 2
* `num_worker_threads`: 40
* `max_batch_size`: 400
* `instance_group.count`: 1
## Setup
@ -134,9 +137,8 @@ The model configuration parameters are passed to the model and have an impact o
The inference engine configuration parameters configure the inference engine. They impact performance, but not accuracy.
* `max_batch_size`: The maximum number of inference channels opened at a given time. If set to `4096`, then one instance will handle at most 4096 concurrent users.
* `max_batch_size`: The size of one execution batch on the GPU. This parameter should be set as large as necessary to saturate the GPU, but not bigger. Larger batches will lead to a higher throughput, smaller batches to lower latency.
* `num_worker_threads`: The number of CPU threads for the postprocessing CPU tasks, such as lattice determinization and text generation from the lattice.
* `max_execution_batch_size`: The size of one execution batch on the GPU. This parameter should be set as large as necessary to saturate the GPU, but not bigger. Larger batches will lead to a higher throughput, smaller batches to lower latency.
* `input.WAV_DATA.dims`: The maximum number of samples per chunk. The value must be a multiple of `frame_subsampling_factor * chunks_per_frame`.
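A quick way to check which values are currently deployed is to look them up in the model configuration (a minimal sketch; assumes the default repository layout):
```bash
# List the engine-related settings in the deployed model configuration.
grep -n -E 'max_batch_size|num_worker_threads|WAV_DATA' \
    model-repo/kaldi_online/config.pbtxt
```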
### Inference process
@ -156,7 +158,7 @@ The client can be configured through a set of parameters that define its behavio
-u <URL for inference service and its gRPC port>
-o : Only feed each channel at realtime speed. Simulates online clients.
-p : Print text outputs
-b : Print partial (best path) text outputs
```
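For example, a typical online run through the helper script might look as follows (a sketch; the flag values are illustrative, and the script forwards its arguments to the client binary):
```bash
# Stream 1000 parallel audio channels at realtime speed and print both final
# and partial transcripts.
scripts/docker/launch_client.sh -c 1000 -o -p -b
```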
### Input/Output
@ -187,13 +189,8 @@ Even if only the best path is used, we are still generating a full lattice for b
Support for Kaldi ASR models that are different from the provided LibriSpeech model is experimental. However, it is possible to modify the [Model Path](#model-path) section of the config file `model-repo/kaldi_online/config.pbtxt` to set up your own model.
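As a hypothetical illustration (the `MyModel` paths are placeholders for your own model directory under `/data/models`), the model-path parameters can be redirected in place:
```bash
# Replace the default LibriSpeech model paths in the deployed configuration.
# "MyModel" is a placeholder for your own Kaldi model directory.
sed -i \
  -e 's|LibriSpeech/conf/online.conf|MyModel/conf/online.conf|' \
  -e 's|LibriSpeech/final.mdl|MyModel/final.mdl|' \
  -e 's|LibriSpeech/HCLG.fst|MyModel/HCLG.fst|' \
  -e 's|LibriSpeech/words.txt|MyModel/words.txt|' \
  model-repo/kaldi_online/config.pbtxt
```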
The models and Kaldi allocators are currently not shared between instances. This means that if your model is large, you may end up with not enough memory on the GPU to store two different instances. If that's the case,
you can set `count` to `1` in the [`instance_group` section](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/model_configuration.html#instance-groups) of the config file.
## Performance
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA's latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
### Metrics
@ -207,8 +204,7 @@ Latency is defined as the delay between the availability of the last chunk of au
4. *Server:* Compute inference of last chunk
5. *Server:* Generate the raw lattice for the full utterance
6. *Server:* Determinize the raw lattice
7. *Server:* Generate the text output associated with the best path in the determinized lattice
8. *Client:* Receive text output
8. *Client:* Receive lattice output
9. *Client:* Call callback with output
10. ***t1** <- Current time*
@ -219,20 +215,18 @@ The latency is defined as `latency = t1 - t0`.
Our results were obtained by:
1. Building and starting the server as described in [Quick Start Guide](#quick-start-guide).
2. Running `scripts/run_inference_all_v100.sh` and `scripts/run_inference_all_t4.sh`
2. Running `scripts/run_inference_all_a100.sh`, `scripts/run_inference_all_v100.sh` and `scripts/run_inference_all_t4.sh`
| GPU | Realtime I/O | Number of parallel audio channels | Throughput (RTFX) | Latency | | | |
| ------ | ------ | ------ | ------ | ------ | ------ | ------ |------ |
| | | | | 90% | 95% | 99% | Avg |
| V100 | No | 2000 | 1506.5 | N/A | N/A | N/A | N/A |
| V100 | Yes | 1500 | 1243.2 | 0.582 | 0.699 | 1.04 | 0.400 |
| V100 | Yes | 1000 | 884.1 | 0.379 | 0.393 | 0.788 | 0.333 |
| V100 | Yes | 800 | 660.2 | 0.334 | 0.340 | 0.438 | 0.288 |
| T4 | No | 1000 | 675.2 | N/A | N/A | N/A| N/A |
| T4 | Yes | 700 | 629.2 | 0.945 | 1.08 | 1.27 | 0.645 |
| T4 | Yes | 400 | 373.7 | 0.579 | 0.624 | 0.758 | 0.452 |
| GPU | Realtime I/O | Number of parallel audio channels | Latency (s) | | | |
| ----- | ------------ | --------------------------------- | ----------- | ----- | ----- | ----- |
| | | | 90% | 95% | 99% | Avg |
| A100 | Yes | 2000 | 0.11 | 0.12 | 0.14 | 0.09 |
| V100 | Yes | 2000 | 0.42 | 0.50 | 0.61 | 0.23 |
| V100 | Yes | 1000 | 0.09 | 0.09 | 0.11 | 0.07 |
| T4 | Yes | 600 | 0.17 | 0.18 | 0.22 | 0.14 |
| T4 | Yes | 400 | 0.12 | 0.13 | 0.15 | 0.10 |
## Release notes
### Changelog
@ -244,5 +238,9 @@ April 2020
* Printing WER accuracy in Triton client
* Using the latest Kaldi GPU ASR pipeline, extended support for features (ivectors, fbanks)
July 2021
* Significantly improve latency and throughput for the backend
* Update Triton to v2.10.0
### Known issues
* No multi-GPU support for the Triton integration

View File

@ -0,0 +1,132 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.17..3.20)
project(TritonKaldiBackend LANGUAGES C CXX)
#
# Options
#
# Must include options required for this project as well as any
# projects included in this one by FetchContent.
#
# GPU support is enabled by default because the Kaldi backend requires GPUs.
#
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
set(TRITON_COMMON_REPO_TAG "r21.05" CACHE STRING "Tag for triton-inference-server/common repo")
set(TRITON_CORE_REPO_TAG "r21.05" CACHE STRING "Tag for triton-inference-server/core repo")
set(TRITON_BACKEND_REPO_TAG "r21.05" CACHE STRING "Tag for triton-inference-server/backend repo")
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
#
# Dependencies
#
# FetchContent's composability isn't very good. We must include the
# transitive closure of all repos so that we can override the tag.
#
include(FetchContent)
FetchContent_Declare(
repo-common
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
GIT_TAG ${TRITON_COMMON_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-core
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
GIT_TAG ${TRITON_CORE_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-backend
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
#
# Shared library implementing the Triton Backend API
#
add_library(triton-kaldi-backend SHARED)
add_library(TritonKaldiBackend::triton-kaldi-backend ALIAS triton-kaldi-backend)
target_sources(triton-kaldi-backend
PRIVATE
triton-kaldi-backend.cc
kaldi-backend-utils.cc
kaldi-backend-utils.h
)
target_include_directories(triton-kaldi-backend SYSTEM
PRIVATE
$<$<BOOL:${TRITON_ENABLE_GPU}>:${CUDA_INCLUDE_DIRS}>
/opt/kaldi/src
/opt/kaldi/tools/openfst/include
)
target_include_directories(triton-kaldi-backend
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
target_compile_features(triton-kaldi-backend PRIVATE cxx_std_17)
target_compile_options(triton-kaldi-backend
PRIVATE
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:-Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror>
)
target_link_directories(triton-kaldi-backend
PRIVATE
/opt/kaldi/src/lib
)
target_link_libraries(triton-kaldi-backend
PRIVATE
TritonCore::triton-core-serverapi # from repo-core
TritonCore::triton-core-backendapi # from repo-core
TritonCore::triton-core-serverstub # from repo-core
TritonBackend::triton-backend-utils # from repo-backend
-lkaldi-cudadecoder
)
set_target_properties(triton-kaldi-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_kaldi
)
#
# Install
#
include(GNUInstallDirs)
install(
TARGETS
triton-kaldi-backend
EXPORT
triton-kaldi-backend-targets
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/kaldi
ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/kaldi
)
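For reference, a minimal out-of-container build sketch for this backend, mirroring the steps in the server Dockerfile (assumes Kaldi is available under /opt/kaldi and the Triton SDK toolchain is installed):
```bash
cd kaldi-asr-backend
cmake -DCMAKE_BUILD_TYPE=Release \
      -DCMAKE_INSTALL_PREFIX="$(pwd)/install" \
      -B build .
cmake --build build --parallel
cmake --install build
# The backend is installed as install/backends/kaldi/libtriton_kaldi.so,
# which the server Dockerfile copies into model-repo/kaldi_online/1/.
```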

View File

@ -0,0 +1,142 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "kaldi-backend-utils.h"
#include <triton/core/tritonserver.h>
using triton::common::TritonJson;
namespace triton {
namespace backend {
TRITONSERVER_Error* GetInputTensor(TRITONBACKEND_Request* request,
const std::string& input_name,
const size_t expected_byte_size,
std::vector<uint8_t>* buffer,
const void** out) {
buffer->clear(); // reset buffer
TRITONBACKEND_Input* input;
RETURN_IF_ERROR(
TRITONBACKEND_RequestInput(request, input_name.c_str(), &input));
uint64_t input_byte_size;
uint32_t input_buffer_count;
RETURN_IF_ERROR(
TRITONBACKEND_InputProperties(input, nullptr, nullptr, nullptr, nullptr,
&input_byte_size, &input_buffer_count));
RETURN_ERROR_IF_FALSE(
input_byte_size == expected_byte_size, TRITONSERVER_ERROR_INVALID_ARG,
std::string(std::string("unexpected byte size ") +
std::to_string(expected_byte_size) + " requested for " +
input_name.c_str() + ", received " +
std::to_string(input_byte_size)));
// The values for an input tensor are not necessarily in one
// contiguous chunk, so we might copy the chunks into the 'buffer' vector.
// If possible, we use the data in place.
uint64_t total_content_byte_size = 0;
for (uint32_t b = 0; b < input_buffer_count; ++b) {
const void* input_buffer = nullptr;
uint64_t input_buffer_byte_size = 0;
TRITONSERVER_MemoryType input_memory_type = TRITONSERVER_MEMORY_CPU;
int64_t input_memory_type_id = 0;
RETURN_IF_ERROR(TRITONBACKEND_InputBuffer(
input, b, &input_buffer, &input_buffer_byte_size, &input_memory_type,
&input_memory_type_id));
RETURN_ERROR_IF_FALSE(input_memory_type != TRITONSERVER_MEMORY_GPU,
TRITONSERVER_ERROR_INTERNAL,
std::string("expected input tensor in CPU memory"));
// Skip the copy if input already exists as a single contiguous
// block
if ((input_buffer_byte_size == expected_byte_size) && (b == 0)) {
*out = input_buffer;
return nullptr;
}
buffer->insert(
buffer->end(), static_cast<const uint8_t*>(input_buffer),
static_cast<const uint8_t*>(input_buffer) + input_buffer_byte_size);
total_content_byte_size += input_buffer_byte_size;
}
// Make sure we end up with exactly the amount of input we expect.
RETURN_ERROR_IF_FALSE(
total_content_byte_size == expected_byte_size,
TRITONSERVER_ERROR_INVALID_ARG,
std::string(std::string("total unexpected byte size ") +
std::to_string(expected_byte_size) + " requested for " +
input_name.c_str() + ", received " +
std::to_string(total_content_byte_size)));
*out = buffer->data();  // hand back the assembled copy
return nullptr;
}
void LatticeToString(fst::SymbolTable& word_syms,
const kaldi::CompactLattice& dlat, std::string* out_str) {
kaldi::CompactLattice best_path_clat;
kaldi::CompactLatticeShortestPath(dlat, &best_path_clat);
kaldi::Lattice best_path_lat;
fst::ConvertLattice(best_path_clat, &best_path_lat);
std::vector<int32> alignment;
std::vector<int32> words;
kaldi::LatticeWeight weight;
fst::GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
std::ostringstream oss;
for (size_t i = 0; i < words.size(); i++) {
std::string s = word_syms.Find(words[i]);
if (s == "") {
LOG_MESSAGE(
TRITONSERVER_LOG_WARN,
("word-id " + std::to_string(words[i]) + " not in symbol table")
.c_str());
}
oss << s << " ";
}
*out_str = std::move(oss.str());
}
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
const std::string& key, std::string* param) {
TritonJson::Value value;
RETURN_ERROR_IF_FALSE(
params.Find(key.c_str(), &value), TRITONSERVER_ERROR_INVALID_ARG,
std::string("model configuration is missing the parameter ") + key);
RETURN_IF_ERROR(value.MemberAsString("string_value", param));
return nullptr; // success
}
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
const std::string& key, int* param) {
std::string tmp;
RETURN_IF_ERROR(ReadParameter(params, key, &tmp));
*param = std::stoi(tmp);
return nullptr; // success
}
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
const std::string& key, float* param) {
std::string tmp;
RETURN_IF_ERROR(ReadParameter(params, key, &tmp));
*param = std::stof(tmp);
return nullptr; // success
}
} // namespace backend
} // namespace triton

View File

@ -0,0 +1,57 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <lat/lattice-functions.h>
#include <triton/backend/backend_common.h>
#include <triton/common/triton_json.h>
namespace triton {
namespace backend {
using triton::common::TritonJson;
#define RETURN_AND_LOG_IF_ERROR(X, MSG) \
do { \
TRITONSERVER_Error* rie_err__ = (X); \
if (rie_err__ != nullptr) { \
LOG_MESSAGE(TRITONSERVER_LOG_INFO, MSG); \
return rie_err__; \
} \
} while (false)
TRITONSERVER_Error* GetInputTensor(TRITONBACKEND_Request* request,
const std::string& input_name,
const size_t expected_byte_size,
std::vector<uint8_t>* input,
const void** out);
TRITONSERVER_Error* LatticeToString(TRITONBACKEND_Request* request,
const std::string& input_name, char* buffer,
size_t* buffer_byte_size);
void LatticeToString(fst::SymbolTable& word_syms,
const kaldi::CompactLattice& dlat, std::string* out_str);
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
const std::string& key, std::string* param);
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
const std::string& key, int* param);
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
const std::string& key, float* param);
} // namespace backend
} // namespace triton

File diff suppressed because it is too large

View File

@ -1,70 +1,78 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# http://www.apache.org/licenses/LICENSE-2.0
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.17..3.20)
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required (VERSION 3.5)
add_executable(kaldi_asr_parallel_client kaldi_asr_parallel_client.cc asr_client_imp.cc)
target_link_libraries(
kaldi_asr_parallel_client
PRIVATE request
)
target_link_libraries(
kaldi_asr_parallel_client
PRIVATE protobuf::libprotobuf
)
target_include_directories(
kaldi_asr_parallel_client
# gRPC client for custom Kaldi backend
#
add_executable(kaldi-asr-parallel-client)
add_executable(TritonKaldiGrpcClient::kaldi-asr-parallel-client ALIAS kaldi-asr-parallel-client)
target_sources(kaldi-asr-parallel-client
PRIVATE
/opt/kaldi/src/
kaldi_asr_parallel_client.cc
asr_client_imp.cc
)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") # openfst yields many warnings
target_include_directories(
kaldi_asr_parallel_client
target_include_directories(kaldi-asr-parallel-client SYSTEM
PRIVATE
/opt/kaldi/tools/openfst-1.6.7/include/
/opt/kaldi/src
/opt/kaldi/tools/openfst/include
)
target_include_directories(kaldi-asr-parallel-client
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)
target_compile_features(kaldi-asr-parallel-client PRIVATE cxx_std_17)
target_compile_options(kaldi-asr-parallel-client
PRIVATE
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:-Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror>
)
target_link_directories(kaldi-asr-parallel-client
PRIVATE
/opt/kaldi/src/lib
)
target_link_libraries(kaldi-asr-parallel-client
PRIVATE
TritonClient::grpcclient_static
-lkaldi-base
-lkaldi-util
-lkaldi-matrix
-lkaldi-feat
-lkaldi-lat
)
target_link_libraries(
kaldi_asr_parallel_client
PRIVATE /opt/kaldi/src/lib/libkaldi-feat.so
)
target_link_libraries(
kaldi_asr_parallel_client
PRIVATE /opt/kaldi/src/lib/libkaldi-util.so
)
target_link_libraries(
kaldi_asr_parallel_client
PRIVATE /opt/kaldi/src/lib/libkaldi-matrix.so
)
target_link_libraries(
kaldi_asr_parallel_client
PRIVATE /opt/kaldi/src/lib/libkaldi-base.so
)
target_link_libraries(
kaldi_asr_parallel_client
PRIVATE /opt/kaldi/src/lat/libkaldi-lat.so
)
#
# Install
#
include(GNUInstallDirs)
install(
TARGETS kaldi_asr_parallel_client
RUNTIME DESTINATION bin
TARGETS
kaldi-asr-parallel-client
EXPORT
kaldi-asr-parallel-client-targets
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin
)
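A minimal build sketch for the client, mirroring Dockerfile.client (assumes it runs inside the Triton SDK builder image, where the client SDK sources live under /workspace/client; the source path in the first line is a placeholder):
```bash
cp -r kaldi-asr-client /workspace/triton-client   # placeholder source path
echo 'add_subdirectory(../../../triton-client src/c++/triton-client)' \
  >> /workspace/client/src/c++/CMakeLists.txt
cd /workspace
cmake -DCMAKE_BUILD_TYPE=Release -B build client
cmake --build build --parallel --target cc-clients
# The resulting binary:
# /workspace/build/cc-clients/src/c++/triton-client/kaldi-asr-parallel-client
```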

View File

@ -1,4 +1,4 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,7 +13,9 @@
// limitations under the License.
#include "asr_client_imp.h"
#include <unistd.h>
#include <cmath>
#include <cstring>
#include <iomanip>
@ -33,117 +35,179 @@
} \
}
void TRTISASRClient::CreateClientContext() {
contextes_.emplace_back();
ClientContext& client = contextes_.back();
void TritonASRClient::CreateClientContext() {
clients_.emplace_back();
TritonClient& client = clients_.back();
FAIL_IF_ERR(nic::InferenceServerGrpcClient::Create(&client.triton_client,
url_, /*verbose*/ false),
"unable to create triton client");
FAIL_IF_ERR(
nic::InferGrpcStreamContext::Create(&client.trtis_context,
/*corr_id*/ -1, url_, model_name_,
/*model_version*/ -1,
/*verbose*/ false),
"unable to create context");
client.triton_client->StartStream(
[&](nic::InferResult* result) {
double end_timestamp = gettime_monotonic();
std::unique_ptr<nic::InferResult> result_ptr(result);
FAIL_IF_ERR(result_ptr->RequestStatus(),
"inference request failed");
std::string request_id;
FAIL_IF_ERR(result_ptr->Id(&request_id),
"unable to get request id for response");
uint64_t corr_id =
std::stoi(std::string(request_id, 0, request_id.find("_")));
bool end_of_stream = (request_id.back() == '1');
if (!end_of_stream) {
if (print_partial_results_) {
std::vector<std::string> text;
FAIL_IF_ERR(result_ptr->StringData("TEXT", &text),
"unable to get TEXT output");
std::lock_guard<std::mutex> lk(stdout_m_);
std::cout << "CORR_ID " << corr_id << "\t[partial]\t" << text[0]
<< '\n';
}
return;
}
double start_timestamp;
{
std::lock_guard<std::mutex> lk(start_timestamps_m_);
auto it = start_timestamps_.find(corr_id);
if (it != start_timestamps_.end()) {
start_timestamp = it->second;
start_timestamps_.erase(it);
} else {
std::cerr << "start_timestamp not found" << std::endl;
exit(1);
}
}
if (print_results_) {
std::vector<std::string> text;
FAIL_IF_ERR(result_ptr->StringData(ctm_ ? "CTM" : "TEXT", &text),
"unable to get TEXT or CTM output");
std::lock_guard<std::mutex> lk(stdout_m_);
std::cout << "CORR_ID " << corr_id;
std::cout << (ctm_ ? "\n" : "\t\t");
std::cout << text[0] << std::endl;
}
std::vector<std::string> lattice_bytes;
FAIL_IF_ERR(result_ptr->StringData("RAW_LATTICE", &lattice_bytes),
"unable to get RAW_LATTICE output");
{
double elapsed = end_timestamp - start_timestamp;
std::lock_guard<std::mutex> lk(results_m_);
results_.insert(
{corr_id, {std::move(lattice_bytes[0]), elapsed}});
}
n_in_flight_.fetch_sub(1, std::memory_order_relaxed);
},
false),
"unable to establish a streaming connection to server");
}
void TRTISASRClient::SendChunk(ni::CorrelationID corr_id,
bool start_of_sequence, bool end_of_sequence,
float* chunk, int chunk_byte_size) {
ClientContext* client = &contextes_[corr_id % ncontextes_];
nic::InferContext& context = *client->trtis_context;
if (start_of_sequence) n_in_flight_.fetch_add(1, std::memory_order_consume);
void TritonASRClient::SendChunk(uint64_t corr_id, bool start_of_sequence,
bool end_of_sequence, float* chunk,
int chunk_byte_size, const uint64_t index) {
// Setting options
std::unique_ptr<nic::InferContext::Options> options;
FAIL_IF_ERR(nic::InferContext::Options::Create(&options),
"unable to create inference options");
options->SetBatchSize(1);
options->SetFlags(0);
options->SetCorrelationId(corr_id);
if (start_of_sequence)
options->SetFlag(ni::InferRequestHeader::FLAG_SEQUENCE_START,
start_of_sequence);
if (end_of_sequence) {
options->SetFlag(ni::InferRequestHeader::FLAG_SEQUENCE_END,
end_of_sequence);
for (const auto& output : context.Outputs()) {
if (output->Name() == "TEXT" && !print_results_)
continue; // no need for text output if not printing
options->AddRawResult(output);
}
}
nic::InferOptions options(model_name_);
options.sequence_id_ = corr_id;
options.sequence_start_ = start_of_sequence;
options.sequence_end_ = end_of_sequence;
options.request_id_ = std::to_string(corr_id) + "_" + std::to_string(index) +
"_" + (start_of_sequence ? "1" : "0") + "_" +
(end_of_sequence ? "1" : "0");
FAIL_IF_ERR(context.SetRunOptions(*options), "unable to set context options");
std::shared_ptr<nic::InferContext::Input> in_wave_data, in_wave_data_dim;
FAIL_IF_ERR(context.GetInput("WAV_DATA", &in_wave_data),
"unable to get WAV_DATA");
FAIL_IF_ERR(context.GetInput("WAV_DATA_DIM", &in_wave_data_dim),
"unable to get WAV_DATA_DIM");
// Wave data input
FAIL_IF_ERR(in_wave_data->Reset(), "unable to reset WAVE_DATA");
// Initialize the inputs with the data.
nic::InferInput* wave_data_ptr;
std::vector<int64_t> wav_shape{1, samps_per_chunk_};
FAIL_IF_ERR(
nic::InferInput::Create(&wave_data_ptr, "WAV_DATA", wav_shape, "FP32"),
"unable to create 'WAV_DATA'");
std::shared_ptr<nic::InferInput> wave_data_in(wave_data_ptr);
FAIL_IF_ERR(wave_data_in->Reset(), "unable to reset 'WAV_DATA'");
uint8_t* wave_data = reinterpret_cast<uint8_t*>(chunk);
if (chunk_byte_size < max_chunk_byte_size_) {
std::memcpy(&chunk_buf_[0], chunk, chunk_byte_size);
wave_data = &chunk_buf_[0];
}
FAIL_IF_ERR(in_wave_data->SetRaw(wave_data, max_chunk_byte_size_),
"unable to set data for WAVE_DATA");
FAIL_IF_ERR(wave_data_in->AppendRaw(wave_data, max_chunk_byte_size_),
"unable to set data for 'WAV_DATA'");
// Dim
FAIL_IF_ERR(in_wave_data_dim->Reset(), "unable to reset WAVE_DATA_DIM");
nic::InferInput* dim_ptr;
std::vector<int64_t> shape{1, 1};
FAIL_IF_ERR(nic::InferInput::Create(&dim_ptr, "WAV_DATA_DIM", shape, "INT32"),
"unable to create 'WAV_DATA_DIM'");
std::shared_ptr<nic::InferInput> dim_in(dim_ptr);
FAIL_IF_ERR(dim_in->Reset(), "unable to reset WAVE_DATA_DIM");
int nsamples = chunk_byte_size / sizeof(float);
FAIL_IF_ERR(in_wave_data_dim->SetRaw(reinterpret_cast<uint8_t*>(&nsamples),
sizeof(int32_t)),
"unable to set data for WAVE_DATA_DIM");
FAIL_IF_ERR(
dim_in->AppendRaw(reinterpret_cast<uint8_t*>(&nsamples), sizeof(int32_t)),
"unable to set data for WAVE_DATA_DIM");
total_audio_ += (static_cast<double>(nsamples) / 16000.); // TODO freq
double start = gettime_monotonic();
FAIL_IF_ERR(context.AsyncRun([corr_id, end_of_sequence, start, this](
nic::InferContext* ctx,
const std::shared_ptr<
nic::InferContext::Request>& request) {
if (end_of_sequence) {
double elapsed = gettime_monotonic() - start;
std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
ctx->GetAsyncRunResults(request, &results);
std::vector<nic::InferInput*> inputs = {wave_data_in.get(), dim_in.get()};
if (results.empty()) {
std::cerr << "Warning: Could not read "
"output for corr_id "
<< corr_id << std::endl;
} else {
if (print_results_) {
std::string text;
FAIL_IF_ERR(results["TEXT"]->GetRawAtCursor(0, &text),
"unable to get TEXT output");
std::lock_guard<std::mutex> lk(stdout_m_);
std::cout << "CORR_ID " << corr_id << "\t\t" << text << std::endl;
}
std::vector<const nic::InferRequestedOutput*> outputs;
std::shared_ptr<nic::InferRequestedOutput> raw_lattice, text;
outputs.reserve(2);
if (end_of_sequence) {
nic::InferRequestedOutput* raw_lattice_ptr;
FAIL_IF_ERR(
nic::InferRequestedOutput::Create(&raw_lattice_ptr, "RAW_LATTICE"),
"unable to get 'RAW_LATTICE'");
raw_lattice.reset(raw_lattice_ptr);
outputs.push_back(raw_lattice.get());
std::string lattice_bytes;
FAIL_IF_ERR(results["RAW_LATTICE"]->GetRawAtCursor(0, &lattice_bytes),
"unable to get RAW_LATTICE output");
{
std::lock_guard<std::mutex> lk(results_m_);
results_.insert({corr_id, {std::move(lattice_bytes), elapsed}});
}
}
n_in_flight_.fetch_sub(1, std::memory_order_relaxed);
// Request the TEXT results only when required for printing
if (print_results_) {
nic::InferRequestedOutput* text_ptr;
FAIL_IF_ERR(
nic::InferRequestedOutput::Create(&text_ptr, ctm_ ? "CTM" : "TEXT"),
"unable to get 'TEXT' or 'CTM'");
text.reset(text_ptr);
outputs.push_back(text.get());
}
}),
} else if (print_partial_results_) {
nic::InferRequestedOutput* text_ptr;
FAIL_IF_ERR(nic::InferRequestedOutput::Create(&text_ptr, "TEXT"),
"unable to get 'TEXT'");
text.reset(text_ptr);
outputs.push_back(text.get());
}
total_audio_ += (static_cast<double>(nsamples) / samp_freq_);
if (start_of_sequence) {
n_in_flight_.fetch_add(1, std::memory_order_consume);
}
// Record the timestamp when the last chunk was made available.
if (end_of_sequence) {
std::lock_guard<std::mutex> lk(start_timestamps_m_);
start_timestamps_[corr_id] = gettime_monotonic();
}
TritonClient* client = &clients_[corr_id % nclients_];
// nic::InferenceServerGrpcClient& triton_client = *client->triton_client;
FAIL_IF_ERR(client->triton_client->AsyncStreamInfer(options, inputs, outputs),
"unable to run model");
}
void TRTISASRClient::WaitForCallbacks() {
int n;
while ((n = n_in_flight_.load(std::memory_order_consume))) {
void TritonASRClient::WaitForCallbacks() {
while (n_in_flight_.load(std::memory_order_consume)) {
usleep(1000);
}
}
void TRTISASRClient::PrintStats(bool print_latency_stats) {
void TritonASRClient::PrintStats(bool print_latency_stats,
bool print_throughput) {
double now = gettime_monotonic();
double diff = now - started_at_;
double rtf = total_audio_ / diff;
std::cout << "Throughput:\t" << rtf << " RTFX" << std::endl;
if (print_throughput)
std::cout << "Throughput:\t" << rtf << " RTFX" << std::endl;
std::vector<double> latencies;
{
std::lock_guard<std::mutex> lk(results_m_);
@ -176,20 +240,33 @@ void TRTISASRClient::PrintStats(bool print_latency_stats) {
}
}
TRTISASRClient::TRTISASRClient(const std::string& url,
const std::string& model_name,
const int ncontextes, bool print_results)
TritonASRClient::TritonASRClient(const std::string& url,
const std::string& model_name,
const int nclients, bool print_results,
bool print_partial_results, bool ctm,
float samp_freq)
: url_(url),
model_name_(model_name),
ncontextes_(ncontextes),
print_results_(print_results) {
ncontextes_ = std::max(ncontextes_, 1);
for (int i = 0; i < ncontextes_; ++i) CreateClientContext();
nclients_(nclients),
print_results_(print_results),
print_partial_results_(print_partial_results),
ctm_(ctm),
samp_freq_(samp_freq) {
nclients_ = std::max(nclients_, 1);
for (int i = 0; i < nclients_; ++i) CreateClientContext();
std::shared_ptr<nic::InferContext::Input> in_wave_data;
FAIL_IF_ERR(contextes_[0].trtis_context->GetInput("WAV_DATA", &in_wave_data),
"unable to get WAV_DATA");
max_chunk_byte_size_ = in_wave_data->ByteSize();
inference::ModelMetadataResponse model_metadata;
FAIL_IF_ERR(
clients_[0].triton_client->ModelMetadata(&model_metadata, model_name),
"unable to get model metadata");
for (const auto& in_tensor : model_metadata.inputs()) {
if (in_tensor.name().compare("WAV_DATA") == 0) {
samps_per_chunk_ = in_tensor.shape()[1];
}
}
max_chunk_byte_size_ = samps_per_chunk_ * sizeof(float);
chunk_buf_.resize(max_chunk_byte_size_);
shape_ = {max_chunk_byte_size_};
n_in_flight_.store(0);
@ -197,20 +274,24 @@ TRTISASRClient::TRTISASRClient(const std::string& url,
total_audio_ = 0;
}
void TRTISASRClient::WriteLatticesToFile(
void TritonASRClient::WriteLatticesToFile(
const std::string& clat_wspecifier,
const std::unordered_map<ni::CorrelationID, std::string>&
corr_id_and_keys) {
const std::unordered_map<uint64_t, std::string>& corr_id_and_keys) {
kaldi::CompactLatticeWriter clat_writer;
clat_writer.Open(clat_wspecifier);
std::unordered_map<std::string, size_t> key_count;
std::lock_guard<std::mutex> lk(results_m_);
for (auto& p : corr_id_and_keys) {
ni::CorrelationID corr_id = p.first;
const std::string& key = p.second;
uint64_t corr_id = p.first;
std::string key = p.second;
const auto iter = key_count[key]++;
if (iter > 0) {
key += std::to_string(iter);
}
auto it = results_.find(corr_id);
if(it == results_.end()) {
std::cerr << "Cannot find lattice for corr_id " << corr_id << std::endl;
continue;
if (it == results_.end()) {
std::cerr << "Cannot find lattice for corr_id " << corr_id << std::endl;
continue;
}
const std::string& raw_lattice = it->second.raw_lattice;
// We could in theory write directly the binary hold in raw_lattice (it is

View File

@ -1,4 +1,4 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,15 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <grpc_client.h>
#include <queue>
#include <string>
#include <vector>
#include <unordered_map>
#include <vector>
#include "request_grpc.h"
#ifndef TRITON_KALDI_ASR_CLIENT_H_
#define TRITON_KALDI_ASR_CLIENT_H_
#ifndef TRTIS_KALDI_ASR_CLIENT_H_
#define TRTIS_KALDI_ASR_CLIENT_H_
namespace ni = nvidia::inferenceserver;
namespace nic = nvidia::inferenceserver::client;
@ -33,16 +34,16 @@ double inline gettime_monotonic() {
return time;
}
class TRTISASRClient {
struct ClientContext {
std::unique_ptr<nic::InferContext> trtis_context;
class TritonASRClient {
struct TritonClient {
std::unique_ptr<nic::InferenceServerGrpcClient> triton_client;
};
std::string url_;
std::string model_name_;
std::vector<ClientContext> contextes_;
int ncontextes_;
std::vector<TritonClient> clients_;
int nclients_;
std::vector<uint8_t> chunk_buf_;
std::vector<int64_t> shape_;
int max_chunk_byte_size_;
@ -50,26 +51,36 @@ class TRTISASRClient {
double started_at_;
double total_audio_;
bool print_results_;
bool print_partial_results_;
bool ctm_;
std::mutex stdout_m_;
int samps_per_chunk_;
float samp_freq_;
struct Result {
std::string raw_lattice;
double latency;
};
std::unordered_map<ni::CorrelationID, Result> results_;
std::unordered_map<uint64_t, double> start_timestamps_;
std::mutex start_timestamps_m_;
std::unordered_map<uint64_t, Result> results_;
std::mutex results_m_;
public:
TritonASRClient(const std::string& url, const std::string& model_name,
const int ncontextes, bool print_results,
bool print_partial_results, bool ctm, float samp_freq);
void CreateClientContext();
void SendChunk(uint64_t corr_id, bool start_of_sequence, bool end_of_sequence,
float* chunk, int chunk_byte_size);
float* chunk, int chunk_byte_size, uint64_t index);
void WaitForCallbacks();
void PrintStats(bool print_latency_stats);
void WriteLatticesToFile(const std::string &clat_wspecifier, const std::unordered_map<ni::CorrelationID, std::string> &corr_id_and_keys);
TRTISASRClient(const std::string& url, const std::string& model_name,
const int ncontextes, bool print_results);
void PrintStats(bool print_latency_stats, bool print_throughput);
void WriteLatticesToFile(
const std::string& clat_wspecifier,
const std::unordered_map<uint64_t, std::string>& corr_id_and_keys);
};
#endif // TRTIS_KALDI_ASR_CLIENT_H_
#endif // TRITON_KALDI_ASR_CLIENT_H_

View File

@ -1,4 +1,4 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,9 +13,12 @@
// limitations under the License.
#include <unistd.h>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "asr_client_imp.h"
#include "feat/wave-reader.h" // to read the wav.scp
#include "util/kaldi-table.h"
@ -41,6 +44,8 @@ void Usage(char** argv, const std::string& msg = std::string()) {
"online clients."
<< std::endl;
std::cerr << "\t-p : Print text outputs" << std::endl;
std::cerr << "\t-b : Print partial (best path) text outputs" << std::endl;
//std::cerr << "\t-t : Print text with timings (CTM)" << std::endl;
std::cerr << std::endl;
exit(1);
@ -49,7 +54,7 @@ void Usage(char** argv, const std::string& msg = std::string()) {
int main(int argc, char** argv) {
std::cout << "\n";
std::cout << "==================================================\n"
<< "============= TRTIS Kaldi ASR Client =============\n"
<< "============= Triton Kaldi ASR Client ============\n"
<< "==================================================\n"
<< std::endl;
@ -65,14 +70,15 @@ int main(int argc, char** argv) {
size_t nchannels = 1000;
int niterations = 5;
bool verbose = false;
float samp_freq = 16000;
int ncontextes = 10;
int nclients = 10;
bool online = false;
bool print_results = false;
bool print_partial_results = false;
bool ctm = false;
// Parse commandline...
int opt;
while ((opt = getopt(argc, argv, "va:u:i:c:ophl:")) != -1) {
while ((opt = getopt(argc, argv, "va:u:i:c:otpbhl:")) != -1) {
switch (opt) {
case 'i':
niterations = std::atoi(optarg);
@ -95,6 +101,13 @@ int main(int argc, char** argv) {
case 'p':
print_results = true;
break;
case 'b':
print_partial_results = true;
break;
case 't':
ctm = true;
print_results = true;
break;
case 'l':
chunk_length = std::atoi(optarg);
break;
@ -116,19 +129,21 @@ int main(int argc, char** argv) {
std::cout << "Server URL\t\t\t: " << url << std::endl;
std::cout << "Print text outputs\t\t: " << (print_results ? "Yes" : "No")
<< std::endl;
std::cout << "Print partial text outputs\t: "
<< (print_partial_results ? "Yes" : "No") << std::endl;
std::cout << "Online - Realtime I/O\t\t: " << (online ? "Yes" : "No")
<< std::endl;
std::cout << std::endl;
float chunk_seconds = (double)chunk_length / samp_freq;
// need to read wav files
SequentialTableReader<WaveHolder> wav_reader(wav_rspecifier);
float samp_freq = 0;
double total_audio = 0;
// pre-loading data
// we don't want to measure I/O
std::vector<std::shared_ptr<WaveData>> all_wav;
std::vector<std::string> all_wav_keys;
// need to read wav files
SequentialTableReader<WaveHolder> wav_reader(wav_rspecifier);
{
std::cout << "Loading eval dataset..." << std::flush;
for (; !wav_reader.Done(); wav_reader.Next()) {
@ -138,90 +153,119 @@ int main(int argc, char** argv) {
all_wav.push_back(wave_data);
all_wav_keys.push_back(utt);
total_audio += wave_data->Duration();
samp_freq = wave_data->SampFreq();
}
std::cout << "done" << std::endl;
}
if (all_wav.empty()) {
std::cerr << "Empty dataset";
exit(0);
}
std::cout << "Loaded dataset with " << all_wav.size()
<< " utterances, frequency " << samp_freq << "hz, total audio "
<< total_audio << " seconds" << std::endl;
double chunk_seconds = (double)chunk_length / samp_freq;
double seconds_per_sample = chunk_seconds / chunk_length;
struct Stream {
std::shared_ptr<WaveData> wav;
ni::CorrelationID corr_id;
uint64_t corr_id;
int offset;
float send_next_chunk_at;
std::atomic<bool> received_output;
double send_next_chunk_at;
Stream(const std::shared_ptr<WaveData>& _wav, ni::CorrelationID _corr_id)
: wav(_wav), corr_id(_corr_id), offset(0), received_output(true) {
send_next_chunk_at = gettime_monotonic();
Stream(const std::shared_ptr<WaveData>& _wav, uint64_t _corr_id,
double _send_next_chunk_at)
: wav(_wav),
corr_id(_corr_id),
offset(0),
send_next_chunk_at(_send_next_chunk_at) {}
bool operator<(const Stream& other) const {
return (send_next_chunk_at > other.send_next_chunk_at);
}
};
std::cout << "Opening GRPC contextes..." << std::flush;
TRTISASRClient asr_client(url, model_name, ncontextes, print_results);
std::unordered_map<uint64_t, std::string> corr_id_and_keys;
TritonASRClient asr_client(url, model_name, nclients, print_results,
print_partial_results, ctm, samp_freq);
std::cout << "done" << std::endl;
std::cout << "Streaming utterances..." << std::flush;
std::vector<std::unique_ptr<Stream>> curr_tasks, next_tasks;
curr_tasks.reserve(nchannels);
next_tasks.reserve(nchannels);
std::cout << "Streaming utterances..." << std::endl;
std::priority_queue<Stream> streams;
size_t all_wav_i = 0;
size_t all_wav_max = all_wav.size() * niterations;
uint64_t index = 0;
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis(0.0, 1.0);
bool add_random_offset = true;
while (true) {
while (curr_tasks.size() < nchannels && all_wav_i < all_wav_max) {
while (streams.size() < nchannels && all_wav_i < all_wav_max) {
// Creating new tasks
uint64_t corr_id = static_cast<uint64_t>(all_wav_i) + 1;
auto all_wav_i_modulo = all_wav_i % (all_wav.size());
double stream_will_start_at = gettime_monotonic();
if (add_random_offset) stream_will_start_at += dis(gen);
double first_chunk_available_at =
stream_will_start_at +
std::min(static_cast<double>(all_wav[all_wav_i_modulo]->Duration()),
chunk_seconds);
std::unique_ptr<Stream> ptr(
new Stream(all_wav[all_wav_i % (all_wav.size())], corr_id));
curr_tasks.emplace_back(std::move(ptr));
corr_id_and_keys.insert({corr_id, all_wav_keys[all_wav_i_modulo]});
streams.emplace(all_wav[all_wav_i_modulo], corr_id,
first_chunk_available_at);
++all_wav_i;
}
// If still empty, done
if (curr_tasks.empty()) break;
if (streams.empty()) break;
for (size_t itask = 0; itask < curr_tasks.size(); ++itask) {
Stream& task = *(curr_tasks[itask]);
SubVector<BaseFloat> data(task.wav->Data(), 0);
int32 samp_offset = task.offset;
int32 nsamp = data.Dim();
int32 samp_remaining = nsamp - samp_offset;
int32 num_samp =
chunk_length < samp_remaining ? chunk_length : samp_remaining;
bool is_last_chunk = (chunk_length >= samp_remaining);
SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
bool is_first_chunk = (samp_offset == 0);
if (online) {
double now = gettime_monotonic();
double wait_for = task.send_next_chunk_at - now;
if (wait_for > 0) usleep(wait_for * 1e6);
}
asr_client.SendChunk(task.corr_id, is_first_chunk, is_last_chunk,
wave_part.Data(), wave_part.SizeInBytes());
task.send_next_chunk_at += chunk_seconds;
if (verbose)
std::cout << "Sending correlation_id=" << task.corr_id
<< " chunk offset=" << num_samp << std::endl;
task.offset += num_samp;
if (!is_last_chunk) next_tasks.push_back(std::move(curr_tasks[itask]));
auto task = streams.top();
streams.pop();
if (online) {
double wait_for = task.send_next_chunk_at - gettime_monotonic();
if (wait_for > 0) usleep(wait_for * 1e6);
}
add_random_offset = false;
SubVector<BaseFloat> data(task.wav->Data(), 0);
int32 samp_offset = task.offset;
int32 nsamp = data.Dim();
int32 samp_remaining = nsamp - samp_offset;
int32 num_samp =
chunk_length < samp_remaining ? chunk_length : samp_remaining;
bool is_last_chunk = (chunk_length >= samp_remaining);
SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
bool is_first_chunk = (samp_offset == 0);
asr_client.SendChunk(task.corr_id, is_first_chunk, is_last_chunk,
wave_part.Data(), wave_part.SizeInBytes(), index++);
if (verbose)
std::cout << "Sending correlation_id=" << task.corr_id
<< " chunk offset=" << num_samp << std::endl;
task.offset += num_samp;
int32 next_chunk_num_samp = std::min(nsamp - task.offset, chunk_length);
task.send_next_chunk_at += next_chunk_num_samp * seconds_per_sample;
if (!is_last_chunk) streams.push(task);
curr_tasks.swap(next_tasks);
next_tasks.clear();
// Showing activity if necessary
if (!print_results && !verbose) std::cout << "." << std::flush;
if (!print_results && !print_partial_results && !verbose &&
index % nchannels == 0)
std::cout << "." << std::flush;
}
std::cout << "done" << std::endl;
std::cout << "Waiting for all results..." << std::flush;
asr_client.WaitForCallbacks();
std::cout << "done" << std::endl;
asr_client.PrintStats(online);
std::unordered_map<ni::CorrelationID, std::string> corr_id_and_keys;
for (size_t all_wav_i = 0; all_wav_i < all_wav.size(); ++all_wav_i) {
ni::CorrelationID corr_id = static_cast<ni::CorrelationID>(all_wav_i) + 1;
corr_id_and_keys.insert({corr_id, all_wav_keys[all_wav_i]});
}
asr_client.PrintStats(
online,
!online); // Print latency if online, do not print throughput if online
asr_client.WriteLatticesToFile("ark:|gzip -c > /data/results/lat.cuda-asr.gz",
corr_id_and_keys);

View File

@ -1,80 +1,103 @@
name: "kaldi_online"
platform: "custom"
default_model_filename: "libkaldi-trtisbackend.so"
max_batch_size: 2200
parameters: {
key: "mfcc_filename"
value: {
string_value:"/data/models/LibriSpeech/conf/mfcc.conf"
}
}
parameters: {
key: "ivector_filename"
value: {
string_value:"/data/models/LibriSpeech/conf/ivector_extractor.conf"
}
}
parameters: {
key: "nnet3_rxfilename"
value: {
string_value: "/data/models/LibriSpeech/final.mdl"
}
}
parameters: {
key: "fst_rxfilename"
value: {
string_value: "/data/models/LibriSpeech/HCLG.fst"
}
}
parameters: {
key: "word_syms_rxfilename"
value: {
string_value:"/data/models/LibriSpeech/words.txt"
}
}
parameters: [{
key: "beam"
value: {
string_value:"10"
}
},{
key: "num_worker_threads"
value: {
string_value:"40"
}
},
{
key: "max_execution_batch_size"
value: {
string_value:"400"
}
}]
parameters: {
key: "lattice_beam"
value: {
string_value:"7"
}
}
parameters: {
key: "max_active"
value: {
string_value:"10000"
}
}
parameters: {
key: "frame_subsampling_factor"
value: {
string_value:"3"
}
}
parameters: {
key: "acoustic_scale"
value: {
string_value:"1.0"
}
backend: "kaldi"
max_batch_size: 600
model_transaction_policy {
decoupled: True
}
parameters [
{
key: "config_filename"
value {
string_value: "/data/models/LibriSpeech/conf/online.conf"
}
},
{
key: "nnet3_rxfilename"
value {
string_value: "/data/models/LibriSpeech/final.mdl"
}
},
{
key: "fst_rxfilename"
value {
string_value: "/data/models/LibriSpeech/HCLG.fst"
}
},
{
key: "word_syms_rxfilename"
value {
string_value: "/data/models/LibriSpeech/words.txt"
}
},
{
key: "lattice_postprocessor_rxfilename"
value {
string_value: ""
}
},
{
key: "use_tensor_cores"
value {
string_value: "1"
}
},
{
key: "main_q_capacity"
value {
string_value: "30000"
}
},
{
key: "aux_q_capacity"
value {
string_value: "400000"
}
},
{
key: "beam"
value {
string_value: "10"
}
},
{
key: "num_worker_threads"
value {
string_value: "40"
}
},
{
key: "num_channels"
value {
string_value: "4000"
}
},
{
key: "lattice_beam"
value {
string_value: "7"
}
},
{
key: "max_active"
value {
string_value: "10000"
}
},
{
key: "frame_subsampling_factor"
value {
string_value: "3"
}
},
{
key: "acoustic_scale"
value {
string_value: "1.0"
}
}
]
sequence_batching {
max_sequence_idle_microseconds:5000000
max_sequence_idle_microseconds: 5000000
control_input [
{
name: "START"
@ -108,16 +131,16 @@ max_sequence_idle_microseconds:5000000
control [
{
kind: CONTROL_SEQUENCE_CORRID
data_type: TYPE_UINT64
data_type: TYPE_UINT64
}
]
}
]
oldest {
max_candidate_sequences:2200
preferred_batch_size:[400]
max_queue_delay_microseconds:1000
}
oldest {
max_candidate_sequences: 4000
preferred_batch_size: [ 600 ]
max_queue_delay_microseconds: 1000
}
},
input [
@ -142,13 +165,17 @@ output [
name: "TEXT"
data_type: TYPE_STRING
dims: [ 1 ]
},
{
name: "CTM"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
instance_group [
{
count: 2
count: 1
kind: KIND_GPU
}
]

View File

@ -15,10 +15,10 @@ oovtok=$(cat $result_path/words.txt | grep "<unk>" | awk '{print $2}')
# convert lattice to transcript
/opt/kaldi/src/latbin/lattice-best-path \
"ark:gunzip -c $result_path/lat.cuda-asr.gz |"\
"ark,t:|gzip -c > $result_path/trans.cuda-asr.gz" 2> /dev/null
"ark,t:$result_path/trans.cuda-asr" 2> /dev/null
# calculate wer
/opt/kaldi/src/bin/compute-wer --mode=present \
"ark:$result_path/text_ints" \
"ark:gunzip -c $result_path/trans.cuda-asr.gz |" 2>&1
"ark:$result_path/trans.cuda-asr" 2> /dev/null

View File

@ -13,5 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
docker build . -f Dockerfile --rm -t trtis_kaldi_server
docker build . -f Dockerfile.client --rm -t trtis_kaldi_client
set -eu
# Use development branch of Kaldi for latest feature support
docker build . -f Dockerfile \
--rm -t triton_kaldi_server
docker build . -f Dockerfile.client \
--rm -t triton_kaldi_client

View File

@ -19,4 +19,5 @@ docker run --rm -it \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
-v $PWD/data:/data \
trtis_kaldi_client /workspace/scripts/docker/run_client.sh $@
--entrypoint /bin/bash \
triton_kaldi_client /workspace/scripts/docker/run_client.sh $@

View File

@ -1,6 +1,6 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Start TRTIS server container for download - need some kaldi tools
nvidia-docker run --rm \
# Start the Triton server container to download and prepare the dataset - some Kaldi tools are needed
docker run --rm \
--shm-size=1g \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
-v $PWD/data:/mnt/data \
trtis_kaldi_server /workspace/scripts/docker/dataset_setup.sh $(id -u) $(id -g)
triton_kaldi_server /workspace/scripts/docker/dataset_setup.sh $(id -u) $(id -g)
# --user $(id -u):$(id -g) \

View File

@ -1,6 +1,6 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -15,8 +15,9 @@
NV_VISIBLE_DEVICES=${NVIDIA_VISIBLE_DEVICES:-"0"}
# Start TRTIS server
nvidia-docker run --rm -it \
# Start Triton server
docker run --rm -it \
--gpus $NV_VISIBLE_DEVICES \
--shm-size=1g \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
@ -24,7 +25,6 @@ nvidia-docker run --rm -it \
-p8001:8001 \
-p8002:8002 \
--name trt_server_asr \
-e NVIDIA_VISIBLE_DEVICES=$NV_VISIBLE_DEVICES \
-v $PWD/data:/data \
-v $PWD/model-repo:/mnt/model-repo \
trtis_kaldi_server trtserver --model-repo=/workspace/model-repo/
triton_kaldi_server
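Because the server has to load the nnet3 acoustic model and the HCLG graph before it can accept work, scripts that run right after launch_server.sh (such as the benchmark runners below) may want to poll for readiness rather than sleep for a fixed time. A minimal sketch, assuming the HTTP port 8000 is published by the launch script:
#!/bin/bash
# Minimal sketch: block until Triton reports ready on its HTTP endpoint.
# Assumes port 8000 (Triton's default HTTP port) is published by the server container.
until curl -sf "http://localhost:8000/v2/health/ready" > /dev/null; do
  echo "waiting for triton_kaldi_server..."
  sleep 5
done
echo "triton_kaldi_server is ready"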

View File

@ -7,7 +7,7 @@ then
rm -rf $results_dir
fi
mkdir $results_dir
install/bin/kaldi_asr_parallel_client $@
kaldi-asr-parallel-client $@
echo "Computing WER..."
/workspace/scripts/compute_wer.sh
rm -rf $results_dir

View File

@ -19,4 +19,4 @@ if [ -d "/mnt/model-repo/kaldi_online" ]; then
ln -s /mnt/model-repo/kaldi_online/config.pbtxt /workspace/model-repo/kaldi_online/
fi
/opt/tensorrtserver/nvidia_entrypoint.sh $@
/opt/tritonserver/nvidia_entrypoint.sh $@
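The entrypoint above lets the served model configuration be overridden without rebuilding the image: if a kaldi_online directory exists in the repository mounted at /mnt/model-repo, its config.pbtxt is symlinked over the default one. A minimal sketch of using that hook from the host, assuming ./model-repo is the directory mounted by launch_server.sh and a locally edited copy of the configuration (my_config.pbtxt is a hypothetical file name):
#!/bin/bash
# Minimal sketch: provide a custom config.pbtxt that the entrypoint will pick up.
# Assumes ./model-repo on the host is the directory mounted at /mnt/model-repo.
set -eu
mkdir -p model-repo/kaldi_online
cp my_config.pbtxt model-repo/kaldi_online/config.pbtxt  # hypothetical edited copy
scripts/docker/launch_server.sh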

View File

@ -1,4 +1,6 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#!/bin/bash
# Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -11,11 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
{
global:
CustomErrorString;
CustomExecute;
CustomFinalize;
CustomInitialize;
local: *;
};
set -e
if [[ "$(docker ps | grep triton_kaldi_server | wc -l)" == "0" ]]; then
printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
exit 1
fi
printf "\nOffline benchmarks:\n"
scripts/docker/launch_client.sh -i 5 -c 4000
printf "\nOnline benchmarks:\n"
scripts/docker/launch_client.sh -i 10 -c 2000 -o

View File

@ -15,7 +15,7 @@
set -e
if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then
if [[ "$(docker ps | grep triton_kaldi_server | wc -l)" == "0" ]]; then
printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
exit 1
fi
@ -26,5 +26,5 @@ scripts/docker/launch_client.sh -i 5 -c 1000
printf "\nOnline benchmarks:\n"
scripts/docker/launch_client.sh -i 10 -c 700 -o
scripts/docker/launch_client.sh -i 10 -c 600 -o
scripts/docker/launch_client.sh -i 10 -c 400 -o

View File

@ -1,6 +1,6 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -15,7 +15,7 @@
set -e
if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then
if [[ "$(docker ps | grep triton_kaldi_server | wc -l)" == "0" ]]; then
printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
exit 1
fi
@ -26,6 +26,5 @@ scripts/docker/launch_client.sh -i 5 -c 2000
printf "\nOnline benchmarks:\n"
scripts/docker/launch_client.sh -i 10 -c 1500 -o
scripts/docker/launch_client.sh -i 10 -c 2000 -o
scripts/docker/launch_client.sh -i 10 -c 1000 -o
scripts/docker/launch_client.sh -i 5 -c 800 -o

View File

@ -1,5 +0,0 @@
.PHONY: all
all: kaldibackend
kaldibackend: kaldi-backend.cc kaldi-backend-utils.cc
g++ -fpic -shared -std=c++11 -o libkaldi-trtisbackend.so kaldi-backend.cc kaldi-backend-utils.cc -Icustom-backend-sdk/include custom-backend-sdk/lib/libcustombackend.a -I/opt/kaldi/src/ -I/usr/local/cuda/include -I/opt/kaldi/tools/openfst/include/ -L/opt/kaldi/src/lib/ -lkaldi-cudadecoder

View File

@ -1,155 +0,0 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "kaldi-backend-utils.h"
namespace nvidia {
namespace inferenceserver {
namespace custom {
namespace kaldi_cbe {
int GetInputTensor(CustomGetNextInputFn_t input_fn, void* input_context,
const char* name, const size_t expected_byte_size,
std::vector<uint8_t>* input, const void** out) {
input->clear(); // reset buffer
// The values for an input tensor are not necessarily in one
// contiguous chunk, so we might copy the chunks into 'input' vector.
// If possible, we use the data in place
uint64_t total_content_byte_size = 0;
while (true) {
const void* content;
uint64_t content_byte_size = expected_byte_size - total_content_byte_size;
if (!input_fn(input_context, name, &content, &content_byte_size)) {
return kInputContents;
}
// If 'content' returns nullptr we have all the input.
if (content == nullptr) break;
// If the total amount of content received exceeds what we expect
// then something is wrong.
total_content_byte_size += content_byte_size;
if (total_content_byte_size > expected_byte_size)
return kInputSize;
if (content_byte_size == expected_byte_size) {
*out = content;
return kSuccess;
}
input->insert(input->end(), static_cast<const uint8_t*>(content),
static_cast<const uint8_t*>(content) + content_byte_size);
}
// Make sure we end up with exactly the amount of input we expect.
if (total_content_byte_size != expected_byte_size) {
return kInputSize;
}
  *out = input->data();
return kSuccess;
}
void LatticeToString(fst::SymbolTable& word_syms,
const kaldi::CompactLattice& dlat, std::string* out_str) {
kaldi::CompactLattice best_path_clat;
kaldi::CompactLatticeShortestPath(dlat, &best_path_clat);
kaldi::Lattice best_path_lat;
fst::ConvertLattice(best_path_clat, &best_path_lat);
std::vector<int32> alignment;
std::vector<int32> words;
kaldi::LatticeWeight weight;
fst::GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
std::ostringstream oss;
for (size_t i = 0; i < words.size(); i++) {
std::string s = word_syms.Find(words[i]);
if (s == "") std::cerr << "Word-id " << words[i] << " not in symbol table.";
oss << s << " ";
}
*out_str = std::move(oss.str());
}
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
std::string* param) {
auto it = model_config_.parameters().find(key);
if (it == model_config_.parameters().end()) {
std::cerr << "Parameter \"" << key
<< "\" missing from config file. Exiting." << std::endl;
return kInvalidModelConfig;
}
*param = it->second.string_value();
return kSuccess;
}
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
int* param) {
std::string tmp;
int err = ReadParameter(model_config_, key, &tmp);
*param = std::stoi(tmp);
return err;
}
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
float* param) {
std::string tmp;
int err = ReadParameter(model_config_, key, &tmp);
*param = std::stof(tmp);
return err;
}
const char* CustomErrorString(int errcode) {
switch (errcode) {
case kSuccess:
return "success";
case kInvalidModelConfig:
return "invalid model configuration";
case kGpuNotSupported:
return "execution on GPU not supported";
case kSequenceBatcher:
return "model configuration must configure sequence batcher";
case kModelControl:
return "'START' and 'READY' must be configured as the control inputs";
case kInputOutput:
return "model must have four inputs and one output with shape [-1]";
case kInputName:
return "names for input don't exist";
case kOutputName:
return "model output must be named 'OUTPUT'";
case kInputOutputDataType:
return "model inputs or outputs data_type cannot be specified";
case kInputContents:
return "unable to get input tensor values";
case kInputSize:
return "unexpected size for input tensor";
case kOutputBuffer:
return "unable to get buffer for output tensor values";
case kBatchTooBig:
return "unable to execute batch larger than max-batch-size";
case kTimesteps:
return "unable to execute more than 1 timestep at a time";
case kChunkTooBig:
return "a chunk cannot contain more samples than the WAV_DATA dimension";
default:
break;
}
return "unknown error";
}
} // namespace kaldi_cbe
} // namespace custom
} // namespace inferenceserver
} // namespace nvidia

View File

@ -1,66 +0,0 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lat/lattice-functions.h"
#include "src/core/model_config.h"
#include "src/core/model_config.pb.h"
#include "src/custom/sdk/custom_instance.h"
namespace nvidia {
namespace inferenceserver {
namespace custom {
namespace kaldi_cbe {
enum ErrorCodes {
kSuccess,
kUnknown,
kInvalidModelConfig,
kGpuNotSupported,
kSequenceBatcher,
kModelControl,
kInputOutput,
kInputName,
kOutputName,
kInputOutputDataType,
kInputContents,
kInputSize,
kOutputBuffer,
kBatchTooBig,
kTimesteps,
kChunkTooBig
};
int GetInputTensor(CustomGetNextInputFn_t input_fn, void* input_context,
const char* name, const size_t expected_byte_size,
std::vector<uint8_t>* input, const void** out);
void LatticeToString(fst::SymbolTable& word_syms,
const kaldi::CompactLattice& dlat, std::string* out_str);
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
std::string* param);
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
int* param);
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
float* param);
const char* CustomErrorString(int errcode);
} // namespace kaldi_cbe
} // namespace custom
} // namespace inferenceserver
} // namespace nvidia

View File

@ -1,423 +0,0 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "kaldi-backend.h"
#include "kaldi-backend-utils.h"
namespace nvidia {
namespace inferenceserver {
namespace custom {
namespace kaldi_cbe {
Context::Context(const std::string& instance_name,
const ModelConfig& model_config, const int gpu_device)
: instance_name_(instance_name),
model_config_(model_config),
gpu_device_(gpu_device),
num_channels_(
model_config_
            .max_batch_size()),  // Kaldi's "channels" play the role of TRTIS's max batch size
int32_byte_size_(GetDataTypeByteSize(TYPE_INT32)),
int64_byte_size_(GetDataTypeByteSize(TYPE_INT64)) {}
Context::~Context() { delete word_syms_; }
int Context::ReadModelParameters() {
// Reading config
float beam, lattice_beam;
int max_active;
int frame_subsampling_factor;
float acoustic_scale;
int num_worker_threads;
int err =
ReadParameter(model_config_, "mfcc_filename",
&batched_decoder_config_.feature_opts.mfcc_config) ||
ReadParameter(
model_config_, "ivector_filename",
&batched_decoder_config_.feature_opts.ivector_extraction_config) ||
ReadParameter(model_config_, "beam", &beam) ||
ReadParameter(model_config_, "lattice_beam", &lattice_beam) ||
ReadParameter(model_config_, "max_active", &max_active) ||
ReadParameter(model_config_, "frame_subsampling_factor",
&frame_subsampling_factor) ||
ReadParameter(model_config_, "acoustic_scale", &acoustic_scale) ||
ReadParameter(model_config_, "nnet3_rxfilename", &nnet3_rxfilename_) ||
ReadParameter(model_config_, "fst_rxfilename", &fst_rxfilename_) ||
ReadParameter(model_config_, "word_syms_rxfilename",
&word_syms_rxfilename_) ||
ReadParameter(model_config_, "num_worker_threads", &num_worker_threads) ||
ReadParameter(model_config_, "max_execution_batch_size",
&max_batch_size_);
if (err) return err;
max_batch_size_ = std::max<int>(max_batch_size_, 1);
num_channels_ = std::max<int>(num_channels_, 1);
// Sanity checks
if (beam <= 0) return kInvalidModelConfig;
if (lattice_beam <= 0) return kInvalidModelConfig;
if (max_active <= 0) return kInvalidModelConfig;
if (acoustic_scale <= 0) return kInvalidModelConfig;
if (num_worker_threads <= 0) return kInvalidModelConfig;
if (num_channels_ <= max_batch_size_) return kInvalidModelConfig;
batched_decoder_config_.compute_opts.frame_subsampling_factor =
frame_subsampling_factor;
batched_decoder_config_.compute_opts.acoustic_scale = acoustic_scale;
batched_decoder_config_.decoder_opts.default_beam = beam;
batched_decoder_config_.decoder_opts.lattice_beam = lattice_beam;
batched_decoder_config_.decoder_opts.max_active = max_active;
batched_decoder_config_.num_worker_threads = num_worker_threads;
batched_decoder_config_.max_batch_size = max_batch_size_;
batched_decoder_config_.num_channels = num_channels_;
auto feature_config = batched_decoder_config_.feature_opts;
kaldi::OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
sample_freq_ = feature_info.mfcc_opts.frame_opts.samp_freq;
BaseFloat frame_shift = feature_info.FrameShiftInSeconds();
seconds_per_chunk_ = chunk_num_samps_ / sample_freq_;
int samp_per_frame = static_cast<int>(sample_freq_ * frame_shift);
  float n_input_framesf = static_cast<float>(chunk_num_samps_) / samp_per_frame;
bool is_integer = (n_input_framesf == std::floor(n_input_framesf));
if (!is_integer) {
std::cerr << "WAVE_DATA dim must be a multiple fo samples per frame ("
<< samp_per_frame << ")" << std::endl;
return kInvalidModelConfig;
}
int n_input_frames = static_cast<int>(std::floor(n_input_framesf));
batched_decoder_config_.compute_opts.frames_per_chunk = n_input_frames;
return kSuccess;
}
int Context::InitializeKaldiPipeline() {
batch_corr_ids_.reserve(max_batch_size_);
batch_wave_samples_.reserve(max_batch_size_);
batch_is_first_chunk_.reserve(max_batch_size_);
batch_is_last_chunk_.reserve(max_batch_size_);
wave_byte_buffers_.resize(max_batch_size_);
output_shape_ = {1, 1};
kaldi::CuDevice::Instantiate()
.SelectAndInitializeGpuIdWithExistingCudaContext(gpu_device_);
kaldi::CuDevice::Instantiate().AllowMultithreading();
// Loading models
{
bool binary;
kaldi::Input ki(nnet3_rxfilename_, &binary);
trans_model_.Read(ki.Stream(), binary);
am_nnet_.Read(ki.Stream(), binary);
kaldi::nnet3::SetBatchnormTestMode(true, &(am_nnet_.GetNnet()));
kaldi::nnet3::SetDropoutTestMode(true, &(am_nnet_.GetNnet()));
kaldi::nnet3::CollapseModel(kaldi::nnet3::CollapseModelConfig(),
&(am_nnet_.GetNnet()));
}
fst::Fst<fst::StdArc>* decode_fst = fst::ReadFstKaldiGeneric(fst_rxfilename_);
cuda_pipeline_.reset(
new kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline(
batched_decoder_config_, *decode_fst, am_nnet_, trans_model_));
delete decode_fst;
// Loading word syms for text output
if (word_syms_rxfilename_ != "") {
if (!(word_syms_ = fst::SymbolTable::ReadText(word_syms_rxfilename_))) {
std::cerr << "Could not read symbol table from file "
<< word_syms_rxfilename_;
return kInvalidModelConfig;
}
}
chunk_num_samps_ = cuda_pipeline_->GetNSampsPerChunk();
chunk_num_bytes_ = chunk_num_samps_ * sizeof(BaseFloat);
return kSuccess;
}
int Context::Init() {
return InputOutputSanityCheck() || ReadModelParameters() ||
InitializeKaldiPipeline();
}
bool Context::CheckPayloadError(const CustomPayload& payload) {
int err = payload.error_code;
if (err) std::cerr << "Error: " << CustomErrorString(err) << std::endl;
return (err != 0);
}
int Context::Execute(const uint32_t payload_cnt, CustomPayload* payloads,
CustomGetNextInputFn_t input_fn,
CustomGetOutputFn_t output_fn) {
// kaldi::Timer timer;
if (payload_cnt > num_channels_) return kBatchTooBig;
// Each payload is a chunk for one sequence
// Currently using dynamic batcher, not sequence batcher
for (uint32_t pidx = 0; pidx < payload_cnt; ++pidx) {
if (batch_corr_ids_.size() == max_batch_size_) FlushBatch();
CustomPayload& payload = payloads[pidx];
if (payload.batch_size != 1) payload.error_code = kTimesteps;
if (CheckPayloadError(payload)) continue;
// Get input tensors
int32_t start, dim, end, ready;
CorrelationID corr_id;
const BaseFloat* wave_buffer;
payload.error_code = GetSequenceInput(
input_fn, payload.input_context, &corr_id, &start, &ready, &dim, &end,
&wave_buffer, &wave_byte_buffers_[pidx]);
if (CheckPayloadError(payload)) continue;
if (!ready) continue;
if (dim > chunk_num_samps_) payload.error_code = kChunkTooBig;
if (CheckPayloadError(payload)) continue;
kaldi::SubVector<BaseFloat> wave_part(wave_buffer, dim);
// Initialize corr_id if first chunk
if (start) {
if (!cuda_pipeline_->TryInitCorrID(corr_id)) {
printf("ERR %i \n", __LINE__);
// TODO add error code
continue;
}
}
// Add to batch
batch_corr_ids_.push_back(corr_id);
batch_wave_samples_.push_back(wave_part);
batch_is_first_chunk_.push_back(start);
batch_is_last_chunk_.push_back(end);
if (end) {
// If last chunk, set the callback for that seq
cuda_pipeline_->SetLatticeCallback(
corr_id, [this, &output_fn, &payloads, pidx,
corr_id](kaldi::CompactLattice& clat) {
SetOutputs(clat, output_fn, payloads[pidx]);
});
}
}
FlushBatch();
cuda_pipeline_->WaitForLatticeCallbacks();
return kSuccess;
}
int Context::FlushBatch() {
if (!batch_corr_ids_.empty()) {
cuda_pipeline_->DecodeBatch(batch_corr_ids_, batch_wave_samples_,
batch_is_first_chunk_, batch_is_last_chunk_);
batch_corr_ids_.clear();
batch_wave_samples_.clear();
batch_is_first_chunk_.clear();
batch_is_last_chunk_.clear();
}
  return kSuccess;
}
int Context::InputOutputSanityCheck() {
if (!model_config_.has_sequence_batching()) {
return kSequenceBatcher;
}
auto& batcher = model_config_.sequence_batching();
if (batcher.control_input_size() != 4) {
return kModelControl;
}
std::set<std::string> control_input_names;
for (int i = 0; i < 4; ++i)
control_input_names.insert(batcher.control_input(i).name());
if (!(control_input_names.erase("START") &&
control_input_names.erase("END") &&
control_input_names.erase("CORRID") &&
control_input_names.erase("READY"))) {
return kModelControl;
}
if (model_config_.input_size() != 2) {
return kInputOutput;
}
if ((model_config_.input(0).dims().size() != 1) ||
(model_config_.input(0).dims(0) <= 0) ||
(model_config_.input(1).dims().size() != 1) ||
(model_config_.input(1).dims(0) != 1)) {
return kInputOutput;
}
chunk_num_samps_ = model_config_.input(0).dims(0);
chunk_num_bytes_ = chunk_num_samps_ * sizeof(float);
if ((model_config_.input(0).data_type() != DataType::TYPE_FP32) ||
(model_config_.input(1).data_type() != DataType::TYPE_INT32)) {
return kInputOutputDataType;
}
if ((model_config_.input(0).name() != "WAV_DATA") ||
(model_config_.input(1).name() != "WAV_DATA_DIM")) {
return kInputName;
}
if (model_config_.output_size() != 2) return kInputOutput;
for (int ioutput = 0; ioutput < 2; ++ioutput) {
if ((model_config_.output(ioutput).dims().size() != 1) ||
(model_config_.output(ioutput).dims(0) != 1)) {
return kInputOutput;
}
if (model_config_.output(ioutput).data_type() != DataType::TYPE_STRING) {
return kInputOutputDataType;
}
}
if (model_config_.output(0).name() != "RAW_LATTICE") return kOutputName;
if (model_config_.output(1).name() != "TEXT") return kOutputName;
return kSuccess;
}
int Context::GetSequenceInput(CustomGetNextInputFn_t& input_fn,
void* input_context, CorrelationID* corr_id,
int32_t* start, int32_t* ready, int32_t* dim,
int32_t* end, const BaseFloat** wave_buffer,
std::vector<uint8_t>* input_buffer) {
int err;
  // &input_buffer[0] is a byte pointer, so it can alias any type;
  // wave_data[0] will hold the struct.
// Get start of sequence tensor
const void* out;
err = GetInputTensor(input_fn, input_context, "WAV_DATA_DIM",
int32_byte_size_, &byte_buffer_, &out);
if (err != kSuccess) return err;
*dim = *reinterpret_cast<const int32_t*>(out);
err = GetInputTensor(input_fn, input_context, "END", int32_byte_size_,
&byte_buffer_, &out);
if (err != kSuccess) return err;
*end = *reinterpret_cast<const int32_t*>(out);
err = GetInputTensor(input_fn, input_context, "START", int32_byte_size_,
&byte_buffer_, &out);
if (err != kSuccess) return err;
*start = *reinterpret_cast<const int32_t*>(out);
err = GetInputTensor(input_fn, input_context, "READY", int32_byte_size_,
&byte_buffer_, &out);
if (err != kSuccess) return err;
*ready = *reinterpret_cast<const int32_t*>(out);
err = GetInputTensor(input_fn, input_context, "CORRID", int64_byte_size_,
&byte_buffer_, &out);
if (err != kSuccess) return err;
*corr_id = *reinterpret_cast<const CorrelationID*>(out);
// Get pointer to speech tensor
err = GetInputTensor(input_fn, input_context, "WAV_DATA", chunk_num_bytes_,
input_buffer, &out);
if (err != kSuccess) return err;
*wave_buffer = reinterpret_cast<const BaseFloat*>(out);
return kSuccess;
}
int Context::SetOutputs(kaldi::CompactLattice& clat,
CustomGetOutputFn_t output_fn, CustomPayload payload) {
int status = kSuccess;
if (payload.error_code != kSuccess) return payload.error_code;
for (int ioutput = 0; ioutput < payload.output_cnt; ++ioutput) {
const char* output_name = payload.required_output_names[ioutput];
if (!strcmp(output_name, "RAW_LATTICE")) {
std::ostringstream oss;
kaldi::WriteCompactLattice(oss, true, clat);
status = SetOutputByName(output_name, oss.str(), output_fn, payload);
if(status != kSuccess) return status;
} else if (!strcmp(output_name, "TEXT")) {
std::string output;
LatticeToString(*word_syms_, clat, &output);
status = SetOutputByName(output_name, output, output_fn, payload);
if(status != kSuccess) return status;
}
}
return status;
}
int Context::SetOutputByName(const char* output_name,
const std::string& out_bytes,
CustomGetOutputFn_t output_fn,
CustomPayload payload) {
uint32_t byte_size_with_size_int = out_bytes.size() + sizeof(int32);
void* obuffer; // output buffer
if (!output_fn(payload.output_context, output_name, output_shape_.size(),
&output_shape_[0], byte_size_with_size_int, &obuffer)) {
payload.error_code = kOutputBuffer;
return payload.error_code;
}
if (obuffer == nullptr) return kOutputBuffer;
int32* buffer_as_int = reinterpret_cast<int32*>(obuffer);
buffer_as_int[0] = out_bytes.size();
memcpy(&buffer_as_int[1], out_bytes.data(), out_bytes.size());
return kSuccess;
}
/////////////
extern "C" {
int CustomInitialize(const CustomInitializeData* data, void** custom_context) {
// Convert the serialized model config to a ModelConfig object.
ModelConfig model_config;
if (!model_config.ParseFromString(std::string(
data->serialized_model_config, data->serialized_model_config_size))) {
return kInvalidModelConfig;
}
// Create the context and validate that the model configuration is
// something that we can handle.
Context* context = new Context(std::string(data->instance_name), model_config,
data->gpu_device_id);
int err = context->Init();
if (err != kSuccess) {
return err;
}
*custom_context = static_cast<void*>(context);
return kSuccess;
}
int CustomFinalize(void* custom_context) {
if (custom_context != nullptr) {
Context* context = static_cast<Context*>(custom_context);
delete context;
}
return kSuccess;
}
const char* CustomErrorString(void* custom_context, int errcode) {
return CustomErrorString(errcode);
}
int CustomExecute(void* custom_context, const uint32_t payload_cnt,
CustomPayload* payloads, CustomGetNextInputFn_t input_fn,
CustomGetOutputFn_t output_fn) {
if (custom_context == nullptr) {
return kUnknown;
}
Context* context = static_cast<Context*>(custom_context);
return context->Execute(payload_cnt, payloads, input_fn, output_fn);
}
} // extern "C"
} // namespace kaldi_cbe
} // namespace custom
} // namespace inferenceserver
} // namespace nvidia

View File

@ -1,126 +0,0 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#define HAVE_CUDA 1 // Loading Kaldi headers with GPU
#include <cfloat>
#include <sstream>
#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include "nnet3/am-nnet-simple.h"
#include "nnet3/nnet-utils.h"
#include "util/kaldi-thread.h"
#include "src/core/model_config.h"
#include "src/core/model_config.pb.h"
#include "src/custom/sdk/custom_instance.h"
using kaldi::BaseFloat;
namespace nvidia {
namespace inferenceserver {
namespace custom {
namespace kaldi_cbe {
// Context object. All state must be kept in this object.
class Context {
public:
Context(const std::string& instance_name, const ModelConfig& config,
const int gpu_device);
virtual ~Context();
// Initialize the context. Validate that the model configuration,
// etc. is something that we can handle.
int Init();
// Perform custom execution on the payloads.
int Execute(const uint32_t payload_cnt, CustomPayload* payloads,
CustomGetNextInputFn_t input_fn, CustomGetOutputFn_t output_fn);
private:
// init kaldi pipeline
int InitializeKaldiPipeline();
int InputOutputSanityCheck();
int ReadModelParameters();
int GetSequenceInput(CustomGetNextInputFn_t& input_fn, void* input_context,
CorrelationID* corr_id, int32_t* start, int32_t* ready,
int32_t* dim, int32_t* end,
const kaldi::BaseFloat** wave_buffer,
std::vector<uint8_t>* input_buffer);
int SetOutputs(kaldi::CompactLattice& clat,
CustomGetOutputFn_t output_fn, CustomPayload payload);
int SetOutputByName(const char* output_name,
const std::string& out_bytes,
CustomGetOutputFn_t output_fn,
CustomPayload payload);
bool CheckPayloadError(const CustomPayload& payload);
int FlushBatch();
// The name of this instance of the backend.
const std::string instance_name_;
// The model configuration.
const ModelConfig model_config_;
// The GPU device ID to execute on or CUSTOM_NO_GPU_DEVICE if should
// execute on CPU.
const int gpu_device_;
// Models paths
std::string nnet3_rxfilename_, fst_rxfilename_;
std::string word_syms_rxfilename_;
// batch_size
int max_batch_size_;
int num_channels_;
int num_worker_threads_;
std::vector<CorrelationID> batch_corr_ids_;
std::vector<kaldi::SubVector<kaldi::BaseFloat>> batch_wave_samples_;
std::vector<bool> batch_is_first_chunk_;
std::vector<bool> batch_is_last_chunk_;
BaseFloat sample_freq_, seconds_per_chunk_;
int chunk_num_bytes_, chunk_num_samps_;
// feature_config includes configuration for the iVector adaptation,
// as well as the basic features.
kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipelineConfig
batched_decoder_config_;
std::unique_ptr<kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline>
cuda_pipeline_;
// Maintain the state of some shared objects
kaldi::TransitionModel trans_model_;
kaldi::nnet3::AmNnetSimple am_nnet_;
  fst::SymbolTable* word_syms_ = nullptr;
const uint64_t int32_byte_size_;
const uint64_t int64_byte_size_;
std::vector<int64_t> output_shape_;
std::vector<uint8_t> byte_buffer_;
std::vector<std::vector<uint8_t>> wave_byte_buffers_;
};
} // namespace kaldi_cbe
} // namespace custom
} // namespace inferenceserver
} // namespace nvidia