[Kaldi] Update to 21.08
This commit is contained in:
parent
26d8955cc5
commit
1a5c7556b5
|
@ -2,3 +2,4 @@ data/*
|
|||
!data/README.md
|
||||
.*.swp
|
||||
.*.swo
|
||||
.clang-format
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
@ -11,12 +11,25 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
FROM nvcr.io/nvidia/kaldi:20.03-py3 as kb
|
||||
FROM nvcr.io/nvidia/tritonserver:20.03-py3
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TRITONSERVER_IMAGE=nvcr.io/nvidia/tritonserver:21.05-py3
|
||||
ARG KALDI_IMAGE=nvcr.io/nvidia/kaldi:21.08-py3
|
||||
ARG PYTHON_VER=3.8
|
||||
|
||||
#
|
||||
# Kaldi shared library dependencies
|
||||
#
|
||||
FROM ${KALDI_IMAGE} as kaldi
|
||||
|
||||
#
|
||||
# Builder image based on Triton Server SDK image
|
||||
#
|
||||
FROM ${TRITONSERVER_IMAGE}-sdk as builder
|
||||
ARG PYTHON_VER
|
||||
|
||||
# Kaldi dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
RUN set -eux; \
|
||||
apt-get update; \
|
||||
apt-get install -yq --no-install-recommends \
|
||||
automake \
|
||||
autoconf \
|
||||
cmake \
|
||||
|
@ -24,29 +37,80 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
gawk \
|
||||
libatlas3-base \
|
||||
libtool \
|
||||
python3.6 \
|
||||
python3.6-dev \
|
||||
python${PYTHON_VER} \
|
||||
python${PYTHON_VER}-dev \
|
||||
sox \
|
||||
subversion \
|
||||
unzip \
|
||||
bc \
|
||||
libatlas-base-dev \
|
||||
zlib1g-dev
|
||||
gfortran \
|
||||
zlib1g-dev; \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN mkdir /opt/trtis-kaldi && mkdir -p /workspace/model-repo/kaldi_online/1 && mkdir -p /mnt/model-repo
|
||||
# Copying static files
|
||||
# Add Kaldi dependency
|
||||
COPY --from=kaldi /opt/kaldi /opt/kaldi
|
||||
|
||||
# Set up Atlas
|
||||
RUN set -eux; \
|
||||
ln -sf /usr/include/x86_64-linux-gnu/atlas /usr/local/include/atlas; \
|
||||
ln -sf /usr/include/x86_64-linux-gnu/cblas.h /usr/local/include/cblas.h; \
|
||||
ln -sf /usr/include/x86_64-linux-gnu/clapack.h /usr/local/include/clapack.h; \
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/atlas /usr/local/lib/atlas
|
||||
|
||||
|
||||
#
|
||||
# Kaldi custom backend build
|
||||
#
|
||||
FROM builder as backend-build
|
||||
|
||||
# Build the custom backend
|
||||
COPY kaldi-asr-backend /workspace/triton-kaldi-backend
|
||||
RUN set -eux; \
|
||||
cd /workspace/triton-kaldi-backend; \
|
||||
cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX="$(pwd)/install" \
|
||||
-B build .; \
|
||||
cmake --build build --parallel; \
|
||||
cmake --install build
|
||||
|
||||
|
||||
#
|
||||
# Final server image
|
||||
#
|
||||
FROM ${TRITONSERVER_IMAGE}
|
||||
ARG PYTHON_VER
|
||||
|
||||
# Kaldi dependencies
|
||||
RUN set -eux; \
|
||||
apt-get update; \
|
||||
apt-get install -yq --no-install-recommends \
|
||||
automake \
|
||||
autoconf \
|
||||
cmake \
|
||||
flac \
|
||||
gawk \
|
||||
libatlas3-base \
|
||||
libtool \
|
||||
python${PYTHON_VER} \
|
||||
python${PYTHON_VER}-dev \
|
||||
sox \
|
||||
subversion \
|
||||
unzip \
|
||||
bc \
|
||||
libatlas-base-dev \
|
||||
zlib1g-dev; \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add Kaldi dependency
|
||||
COPY --from=kaldi /opt/kaldi /opt/kaldi
|
||||
|
||||
# Add Kaldi custom backend shared library and scripts
|
||||
COPY --from=backend-build /workspace/triton-kaldi-backend/install/backends/kaldi/libtriton_kaldi.so /workspace/model-repo/kaldi_online/1/
|
||||
COPY scripts /workspace/scripts
|
||||
|
||||
# Moving Kaldi to container
|
||||
COPY --from=kb /opt/kaldi /opt/kaldi
|
||||
ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:$LD_LIBRARY_PATH
|
||||
|
||||
# Building the custom backend
|
||||
COPY trtis-kaldi-backend /workspace/trtis-kaldi-backend
|
||||
#COPY --from=cbe /workspace/install/custom-backend-sdk /workspace/trtis-kaldi-backend/custom-backend-sdk
|
||||
RUN cd /workspace/trtis-kaldi-backend && wget https://github.com/NVIDIA/tensorrt-inference-server/releases/download/v1.9.0/v1.9.0_ubuntu1804.custombackend.tar.gz -O custom-backend-sdk.tar.gz && tar -xzf custom-backend-sdk.tar.gz
|
||||
RUN cd /workspace/trtis-kaldi-backend/ && make && cp libkaldi-trtisbackend.so /workspace/model-repo/kaldi_online/1/ && cd - && rm -r /workspace/trtis-kaldi-backend
|
||||
|
||||
COPY scripts/nvidia_kaldi_trtis_entrypoint.sh /opt/trtis-kaldi
|
||||
|
||||
ENTRYPOINT ["/opt/trtis-kaldi/nvidia_kaldi_trtis_entrypoint.sh"]
|
||||
# Setup entrypoint and environment
|
||||
ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:/opt/tritonserver/lib:$LD_LIBRARY_PATH
|
||||
COPY scripts/nvidia_kaldi_triton_entrypoint.sh /opt/triton-kaldi/
|
||||
VOLUME /mnt/model-repo
|
||||
ENTRYPOINT ["/opt/triton-kaldi/nvidia_kaldi_triton_entrypoint.sh"]
|
||||
CMD ["tritonserver", "--model-repo=/workspace/model-repo"]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
@ -11,11 +11,27 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
FROM nvcr.io/nvidia/kaldi:20.03-py3 as kb
|
||||
FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk
|
||||
ARG TRITONSERVER_IMAGE=nvcr.io/nvidia/tritonserver:21.05-py3
|
||||
ARG KALDI_IMAGE=nvcr.io/nvidia/kaldi:21.08-py3
|
||||
ARG PYTHON_VER=3.8
|
||||
|
||||
|
||||
#
|
||||
# Kaldi shared library dependencies
|
||||
#
|
||||
FROM ${KALDI_IMAGE} as kaldi
|
||||
|
||||
|
||||
#
|
||||
# Builder image based on Triton Server SDK image
|
||||
#
|
||||
FROM ${TRITONSERVER_IMAGE}-sdk as builder
|
||||
ARG PYTHON_VER
|
||||
|
||||
# Kaldi dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
RUN set -eux; \
|
||||
apt-get update; \
|
||||
apt-get install -yq --no-install-recommends \
|
||||
automake \
|
||||
autoconf \
|
||||
cmake \
|
||||
|
@ -23,21 +39,78 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
gawk \
|
||||
libatlas3-base \
|
||||
libtool \
|
||||
python3.6 \
|
||||
python3.6-dev \
|
||||
python${PYTHON_VER} \
|
||||
python${PYTHON_VER}-dev \
|
||||
sox \
|
||||
subversion \
|
||||
unzip \
|
||||
bc \
|
||||
libatlas-base-dev \
|
||||
zlib1g-dev
|
||||
gfortran \
|
||||
zlib1g-dev; \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Moving Kaldi to container
|
||||
COPY --from=kb /opt/kaldi /opt/kaldi
|
||||
ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:$LD_LIBRARY_PATH
|
||||
# Add Kaldi dependency
|
||||
COPY --from=kaldi /opt/kaldi /opt/kaldi
|
||||
|
||||
# Set up Atlas
|
||||
RUN set -eux; \
|
||||
ln -sf /usr/include/x86_64-linux-gnu/atlas /usr/local/include/atlas; \
|
||||
ln -sf /usr/include/x86_64-linux-gnu/cblas.h /usr/local/include/cblas.h; \
|
||||
ln -sf /usr/include/x86_64-linux-gnu/clapack.h /usr/local/include/clapack.h; \
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/atlas /usr/local/lib/atlas
|
||||
|
||||
|
||||
#
|
||||
# Triton Kaldi client build
|
||||
#
|
||||
FROM builder as client-build
|
||||
|
||||
# Build the clients
|
||||
COPY kaldi-asr-client /workspace/triton-client
|
||||
RUN set -eux; \
|
||||
cd /workspace; \
|
||||
echo 'add_subdirectory(../../../triton-client src/c++/triton-client)' \
|
||||
>> /workspace/client/src/c++/CMakeLists.txt; \
|
||||
cmake -DCMAKE_BUILD_TYPE=Release -B build client; \
|
||||
cmake --build build --parallel --target cc-clients
|
||||
|
||||
|
||||
#
|
||||
# Final gRPC client image
|
||||
#
|
||||
FROM ${TRITONSERVER_IMAGE}
|
||||
ARG PYTHON_VER
|
||||
|
||||
# Kaldi dependencies
|
||||
RUN set -eux; \
|
||||
apt-get update; \
|
||||
apt-get install -yq --no-install-recommends \
|
||||
automake \
|
||||
autoconf \
|
||||
cmake \
|
||||
flac \
|
||||
gawk \
|
||||
libatlas3-base \
|
||||
libtool \
|
||||
python${PYTHON_VER} \
|
||||
python${PYTHON_VER}-dev \
|
||||
sox \
|
||||
subversion \
|
||||
unzip \
|
||||
bc \
|
||||
libatlas-base-dev \
|
||||
zlib1g-dev; \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add Kaldi dependency
|
||||
COPY --from=kaldi /opt/kaldi /opt/kaldi
|
||||
|
||||
# Add Triton clients and scripts
|
||||
COPY --from=client-build /workspace/build/cc-clients/src/c++/triton-client/kaldi-asr-parallel-client /usr/local/bin/
|
||||
COPY scripts /workspace/scripts
|
||||
|
||||
COPY kaldi-asr-client /workspace/src/clients/c++/kaldi-asr-client
|
||||
RUN echo "add_subdirectory(kaldi-asr-client)" >> "/workspace/src/clients/c++/CMakeLists.txt"
|
||||
RUN cd /workspace/build/ && make -j16 trtis-clients
|
||||
# Setup environment and entrypoint
|
||||
ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:/opt/tritonserver/lib:$LD_LIBRARY_PATH
|
||||
VOLUME /mnt/model-repo
|
||||
ENTRYPOINT ["/usr/local/bin/kaldi-asr-parallel-client"]
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk
|
||||
FROM nvcr.io/nvidia/tritonserver:21.05-py3-sdk
|
||||
|
||||
# Kaldi dependencies
|
||||
RUN apt-get update && apt-get install -y jupyter \
|
||||
|
|
|
@ -46,15 +46,18 @@ A reference model is used by all test scripts and benchmarks presented in this r
|
|||
Details about parameters can be found in the [Parameters](#parameters) section.
|
||||
|
||||
* `model path`: Configured to use the pretrained LibriSpeech model.
|
||||
* `use_tensor_cores`: 1
|
||||
* `main_q_capacity`: 30000
|
||||
* `aux_q_capacity`: 400000
|
||||
* `beam`: 10
|
||||
* `num_channels`: 4000
|
||||
* `lattice_beam`: 7
|
||||
* `max_active`: 10,000
|
||||
* `frame_subsampling_factor`: 3
|
||||
* `acoustic_scale`: 1.0
|
||||
* `num_worker_threads`: 20
|
||||
* `max_execution_batch_size`: 256
|
||||
* `max_batch_size`: 4096
|
||||
* `instance_group.count`: 2
|
||||
* `num_worker_threads`: 40
|
||||
* `max_batch_size`: 400
|
||||
* `instance_group.count`: 1
|
||||
|
||||
## Setup
|
||||
|
||||
|
@ -134,9 +137,8 @@ The model configuration parameters are passed to the model and have an impact o
|
|||
|
||||
The inference engine configuration parameters configure the inference engine. They impact performance, but not accuracy.
|
||||
|
||||
* `max_batch_size`: The maximum number of inference channels opened at a given time. If set to `4096`, then one instance will handle at most 4096 concurrent users.
|
||||
* `max_batch_size`: The size of one execution batch on the GPU. This parameter should be set as large as necessary to saturate the GPU, but not bigger. Larger batches will lead to a higher throughput, smaller batches to lower latency.
|
||||
* `num_worker_threads`: The number of CPU threads for the postprocessing CPU tasks, such as lattice determinization and text generation from the lattice.
|
||||
* `max_execution_batch_size`: The size of one execution batch on the GPU. This parameter should be set as large as necessary to saturate the GPU, but not bigger. Larger batches will lead to a higher throughput, smaller batches to lower latency.
|
||||
* `input.WAV_DATA.dims`: The maximum number of samples per chunk. The value must be a multiple of `frame_subsampling_factor * chunks_per_frame`.
|
||||
|
||||
### Inference process
|
||||
|
@ -156,7 +158,7 @@ The client can be configured through a set of parameters that define its behavio
|
|||
-u <URL for inference service and its gRPC port>
|
||||
-o : Only feed each channel at realtime speed. Simulates online clients.
|
||||
-p : Print text outputs
|
||||
|
||||
-b : Print partial (best path) text outputs
|
||||
```
|
||||
|
||||
### Input/Output
|
||||
|
@ -187,13 +189,8 @@ Even if only the best path is used, we are still generating a full lattice for b
|
|||
|
||||
Support for Kaldi ASR models that are different from the provided LibriSpeech model is experimental. However, it is possible to modify the [Model Path](#model-path) section of the config file `model-repo/kaldi_online/config.pbtxt` to set up your own model.
|
||||
|
||||
The models and Kaldi allocators are currently not shared between instances. This means that if your model is large, you may end up with not enough memory on the GPU to store two different instances. If that's the case,
|
||||
you can set `count` to `1` in the [`instance_group` section](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/model_configuration.html#instance-groups) of the config file.
|
||||
|
||||
## Performance
|
||||
|
||||
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
|
||||
|
||||
|
||||
### Metrics
|
||||
|
||||
|
@ -207,8 +204,7 @@ Latency is defined as the delay between the availability of the last chunk of au
|
|||
4. *Server:* Compute inference of last chunk
|
||||
5. *Server:* Generate the raw lattice for the full utterance
|
||||
6. *Server:* Determinize the raw lattice
|
||||
7. *Server:* Generate the text output associated with the best path in the determinized lattice
|
||||
8. *Client:* Receive text output
|
||||
8. *Client:* Receive lattice output
|
||||
9. *Client:* Call callback with output
|
||||
10. ***t1** <- Current time*
|
||||
|
||||
|
@ -219,20 +215,18 @@ The latency is defined such as `latency = t1 - t0`.
|
|||
Our results were obtained by:
|
||||
|
||||
1. Building and starting the server as described in [Quick Start Guide](#quick-start-guide).
|
||||
2. Running `scripts/run_inference_all_v100.sh` and `scripts/run_inference_all_t4.sh`
|
||||
2. Running `scripts/run_inference_all_a100.sh`, `scripts/run_inference_all_v100.sh` and `scripts/run_inference_all_t4.sh`
|
||||
|
||||
|
||||
| GPU | Realtime I/O | Number of parallel audio channels | Throughput (RTFX) | Latency | | | |
|
||||
| ------ | ------ | ------ | ------ | ------ | ------ | ------ |------ |
|
||||
| | | | | 90% | 95% | 99% | Avg |
|
||||
| V100 | No | 2000 | 1506.5 | N/A | N/A | N/A | N/A |
|
||||
| V100 | Yes | 1500 | 1243.2 | 0.582 | 0.699 | 1.04 | 0.400 |
|
||||
| V100 | Yes | 1000 | 884.1 | 0.379 | 0.393 | 0.788 | 0.333 |
|
||||
| V100 | Yes | 800 | 660.2 | 0.334 | 0.340 | 0.438 | 0.288 |
|
||||
| T4 | No | 1000 | 675.2 | N/A | N/A | N/A| N/A |
|
||||
| T4 | Yes | 700 | 629.2 | 0.945 | 1.08 | 1.27 | 0.645 |
|
||||
| T4 | Yes | 400 | 373.7 | 0.579 | 0.624 | 0.758 | 0.452 |
|
||||
|
||||
| GPU | Realtime I/O | Number of parallel audio channels | Latency (s) | | | |
|
||||
| ----- | ------------ | --------------------------------- | ----------- | ----- | ----- | ----- |
|
||||
| | | | 90% | 95% | 99% | Avg |
|
||||
| A100 | Yes | 2000 | 0.11 | 0.12 | 0.14 | 0.09 |
|
||||
| V100 | Yes | 2000 | 0.42 | 0.50 | 0.61 | 0.23 |
|
||||
| V100 | Yes | 1000 | 0.09 | 0.09 | 0.11 | 0.07 |
|
||||
| T4 | Yes | 600 | 0.17 | 0.18 | 0.22 | 0.14 |
|
||||
| T4 | Yes | 400 | 0.12 | 0.13 | 0.15 | 0.10 |
|
||||
|
||||
## Release notes
|
||||
|
||||
### Changelog
|
||||
|
@ -244,5 +238,9 @@ April 2020
|
|||
* Printing WER accuracy in Triton client
|
||||
* Using the latest Kaldi GPU ASR pipeline, extended support for features (ivectors, fbanks)
|
||||
|
||||
July 2021
|
||||
* Significantly improve latency and throughput for the backend
|
||||
* Update Triton to v2.10.0
|
||||
|
||||
### Known issues
|
||||
* No multi-gpu support for the Triton integration
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.17..3.20)
|
||||
|
||||
project(TritonKaldiBackend LANGUAGES C CXX)
|
||||
|
||||
#
|
||||
# Options
|
||||
#
|
||||
# Must include options required for this project as well as any
|
||||
# projects included in this one by FetchContent.
|
||||
#
|
||||
# GPU support is enabled by default because the Kaldi backend requires GPUs.
|
||||
#
|
||||
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
|
||||
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
|
||||
|
||||
set(TRITON_COMMON_REPO_TAG "r21.05" CACHE STRING "Tag for triton-inference-server/common repo")
|
||||
set(TRITON_CORE_REPO_TAG "r21.05" CACHE STRING "Tag for triton-inference-server/core repo")
|
||||
set(TRITON_BACKEND_REPO_TAG "r21.05" CACHE STRING "Tag for triton-inference-server/backend repo")
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Dependencies
|
||||
#
|
||||
# FetchContent's composibility isn't very good. We must include the
|
||||
# transitive closure of all repos so that we can override the tag.
|
||||
#
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
repo-common
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
|
||||
GIT_TAG ${TRITON_COMMON_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-core
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
|
||||
GIT_TAG ${TRITON_CORE_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Declare(
|
||||
repo-backend
|
||||
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
|
||||
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
|
||||
|
||||
#
|
||||
# Shared library implementing the Triton Backend API
|
||||
#
|
||||
add_library(triton-kaldi-backend SHARED)
|
||||
add_library(TritonKaldiBackend::triton-kaldi-backend ALIAS triton-kaldi-backend)
|
||||
target_sources(triton-kaldi-backend
|
||||
PRIVATE
|
||||
triton-kaldi-backend.cc
|
||||
kaldi-backend-utils.cc
|
||||
kaldi-backend-utils.h
|
||||
)
|
||||
target_include_directories(triton-kaldi-backend SYSTEM
|
||||
PRIVATE
|
||||
$<$<BOOL:${TRITON_ENABLE_GPU}>:${CUDA_INCLUDE_DIRS}>
|
||||
/opt/kaldi/src
|
||||
/opt/kaldi/tools/openfst/include
|
||||
)
|
||||
target_include_directories(triton-kaldi-backend
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
target_compile_features(triton-kaldi-backend PRIVATE cxx_std_17)
|
||||
target_compile_options(triton-kaldi-backend
|
||||
PRIVATE
|
||||
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:-Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror>
|
||||
)
|
||||
target_link_directories(triton-kaldi-backend
|
||||
PRIVATE
|
||||
/opt/kaldi/src/lib
|
||||
)
|
||||
target_link_libraries(triton-kaldi-backend
|
||||
PRIVATE
|
||||
TritonCore::triton-core-serverapi # from repo-core
|
||||
TritonCore::triton-core-backendapi # from repo-core
|
||||
TritonCore::triton-core-serverstub # from repo-core
|
||||
TritonBackend::triton-backend-utils # from repo-backend
|
||||
-lkaldi-cudadecoder
|
||||
)
|
||||
set_target_properties(triton-kaldi-backend PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
OUTPUT_NAME triton_kaldi
|
||||
)
|
||||
|
||||
#
|
||||
# Install
|
||||
#
|
||||
include(GNUInstallDirs)
|
||||
|
||||
install(
|
||||
TARGETS
|
||||
triton-kaldi-backend
|
||||
EXPORT
|
||||
triton-kaldi-backend-targets
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/kaldi
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/kaldi
|
||||
)
|
|
@ -0,0 +1,142 @@
|
|||
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "kaldi-backend-utils.h"
|
||||
|
||||
#include <triton/core/tritonserver.h>
|
||||
|
||||
using triton::common::TritonJson;
|
||||
|
||||
namespace triton {
|
||||
namespace backend {
|
||||
|
||||
TRITONSERVER_Error* GetInputTensor(TRITONBACKEND_Request* request,
|
||||
const std::string& input_name,
|
||||
const size_t expected_byte_size,
|
||||
std::vector<uint8_t>* buffer,
|
||||
const void** out) {
|
||||
buffer->clear(); // reset buffer
|
||||
|
||||
TRITONBACKEND_Input* input;
|
||||
RETURN_IF_ERROR(
|
||||
TRITONBACKEND_RequestInput(request, input_name.c_str(), &input));
|
||||
|
||||
uint64_t input_byte_size;
|
||||
uint32_t input_buffer_count;
|
||||
RETURN_IF_ERROR(
|
||||
TRITONBACKEND_InputProperties(input, nullptr, nullptr, nullptr, nullptr,
|
||||
&input_byte_size, &input_buffer_count));
|
||||
RETURN_ERROR_IF_FALSE(
|
||||
input_byte_size == expected_byte_size, TRITONSERVER_ERROR_INVALID_ARG,
|
||||
std::string(std::string("unexpected byte size ") +
|
||||
std::to_string(expected_byte_size) + " requested for " +
|
||||
input_name.c_str() + ", received " +
|
||||
std::to_string(input_byte_size)));
|
||||
|
||||
// The values for an input tensor are not necessarily in one
|
||||
// contiguous chunk, so we might copy the chunks into 'input' vector.
|
||||
// If possible, we use the data in place
|
||||
uint64_t total_content_byte_size = 0;
|
||||
for (uint32_t b = 0; b < input_buffer_count; ++b) {
|
||||
const void* input_buffer = nullptr;
|
||||
uint64_t input_buffer_byte_size = 0;
|
||||
TRITONSERVER_MemoryType input_memory_type = TRITONSERVER_MEMORY_CPU;
|
||||
int64_t input_memory_type_id = 0;
|
||||
RETURN_IF_ERROR(TRITONBACKEND_InputBuffer(
|
||||
input, b, &input_buffer, &input_buffer_byte_size, &input_memory_type,
|
||||
&input_memory_type_id));
|
||||
RETURN_ERROR_IF_FALSE(input_memory_type != TRITONSERVER_MEMORY_GPU,
|
||||
TRITONSERVER_ERROR_INTERNAL,
|
||||
std::string("expected input tensor in CPU memory"));
|
||||
|
||||
// Skip the copy if input already exists as a single contiguous
|
||||
// block
|
||||
if ((input_buffer_byte_size == expected_byte_size) && (b == 0)) {
|
||||
*out = input_buffer;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
buffer->insert(
|
||||
buffer->end(), static_cast<const uint8_t*>(input_buffer),
|
||||
static_cast<const uint8_t*>(input_buffer) + input_buffer_byte_size);
|
||||
total_content_byte_size += input_buffer_byte_size;
|
||||
}
|
||||
|
||||
// Make sure we end up with exactly the amount of input we expect.
|
||||
RETURN_ERROR_IF_FALSE(
|
||||
total_content_byte_size == expected_byte_size,
|
||||
TRITONSERVER_ERROR_INVALID_ARG,
|
||||
std::string(std::string("total unexpected byte size ") +
|
||||
std::to_string(expected_byte_size) + " requested for " +
|
||||
input_name.c_str() + ", received " +
|
||||
std::to_string(total_content_byte_size)));
|
||||
*out = &buffer[0];
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void LatticeToString(fst::SymbolTable& word_syms,
|
||||
const kaldi::CompactLattice& dlat, std::string* out_str) {
|
||||
kaldi::CompactLattice best_path_clat;
|
||||
kaldi::CompactLatticeShortestPath(dlat, &best_path_clat);
|
||||
|
||||
kaldi::Lattice best_path_lat;
|
||||
fst::ConvertLattice(best_path_clat, &best_path_lat);
|
||||
|
||||
std::vector<int32> alignment;
|
||||
std::vector<int32> words;
|
||||
kaldi::LatticeWeight weight;
|
||||
fst::GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
|
||||
std::ostringstream oss;
|
||||
for (size_t i = 0; i < words.size(); i++) {
|
||||
std::string s = word_syms.Find(words[i]);
|
||||
if (s == "") {
|
||||
LOG_MESSAGE(
|
||||
TRITONSERVER_LOG_WARN,
|
||||
("word-id " + std::to_string(words[i]) + " not in symbol table")
|
||||
.c_str());
|
||||
}
|
||||
oss << s << " ";
|
||||
}
|
||||
*out_str = std::move(oss.str());
|
||||
}
|
||||
|
||||
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
|
||||
const std::string& key, std::string* param) {
|
||||
TritonJson::Value value;
|
||||
RETURN_ERROR_IF_FALSE(
|
||||
params.Find(key.c_str(), &value), TRITONSERVER_ERROR_INVALID_ARG,
|
||||
std::string("model configuration is missing the parameter ") + key);
|
||||
RETURN_IF_ERROR(value.MemberAsString("string_value", param));
|
||||
return nullptr; // success
|
||||
}
|
||||
|
||||
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
|
||||
const std::string& key, int* param) {
|
||||
std::string tmp;
|
||||
RETURN_IF_ERROR(ReadParameter(params, key, &tmp));
|
||||
*param = std::stoi(tmp);
|
||||
return nullptr; // success
|
||||
}
|
||||
|
||||
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
|
||||
const std::string& key, float* param) {
|
||||
std::string tmp;
|
||||
RETURN_IF_ERROR(ReadParameter(params, key, &tmp));
|
||||
*param = std::stof(tmp);
|
||||
return nullptr; // success
|
||||
}
|
||||
|
||||
} // namespace backend
|
||||
} // namespace triton
|
|
@ -0,0 +1,57 @@
|
|||
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <lat/lattice-functions.h>
|
||||
#include <triton/backend/backend_common.h>
|
||||
#include <triton/common/triton_json.h>
|
||||
|
||||
namespace triton {
|
||||
namespace backend {
|
||||
|
||||
using triton::common::TritonJson;
|
||||
|
||||
#define RETURN_AND_LOG_IF_ERROR(X, MSG) \
|
||||
do { \
|
||||
TRITONSERVER_Error* rie_err__ = (X); \
|
||||
if (rie_err__ != nullptr) { \
|
||||
LOG_MESSAGE(TRITONSERVER_LOG_INFO, MSG); \
|
||||
return rie_err__; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
TRITONSERVER_Error* GetInputTensor(TRITONBACKEND_Request* request,
|
||||
const std::string& input_name,
|
||||
const size_t expected_byte_size,
|
||||
std::vector<uint8_t>* input,
|
||||
const void** out);
|
||||
|
||||
TRITONSERVER_Error* LatticeToString(TRITONBACKEND_Request* request,
|
||||
const std::string& input_name, char* buffer,
|
||||
size_t* buffer_byte_size);
|
||||
|
||||
void LatticeToString(fst::SymbolTable& word_syms,
|
||||
const kaldi::CompactLattice& dlat, std::string* out_str);
|
||||
|
||||
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
|
||||
const std::string& key, std::string* param);
|
||||
|
||||
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
|
||||
const std::string& key, int* param);
|
||||
TRITONSERVER_Error* ReadParameter(TritonJson::Value& params,
|
||||
const std::string& key, float* param);
|
||||
|
||||
} // namespace backend
|
||||
} // namespace triton
|
File diff suppressed because it is too large
Load Diff
|
@ -1,70 +1,78 @@
|
|||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.17..3.20)
|
||||
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
cmake_minimum_required (VERSION 3.5)
|
||||
|
||||
add_executable(kaldi_asr_parallel_client kaldi_asr_parallel_client.cc asr_client_imp.cc)
|
||||
|
||||
target_link_libraries(
|
||||
kaldi_asr_parallel_client
|
||||
PRIVATE request
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
kaldi_asr_parallel_client
|
||||
PRIVATE protobuf::libprotobuf
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
kaldi_asr_parallel_client
|
||||
# gRPC client for custom Kaldi backend
|
||||
#
|
||||
add_executable(kaldi-asr-parallel-client)
|
||||
add_executable(TritonKaldiGrpcClient::kaldi-asr-parallel-client ALIAS kaldi-asr-parallel-client)
|
||||
target_sources(kaldi-asr-parallel-client
|
||||
PRIVATE
|
||||
/opt/kaldi/src/
|
||||
kaldi_asr_parallel_client.cc
|
||||
asr_client_imp.cc
|
||||
)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") # openfst yields many warnings
|
||||
target_include_directories(
|
||||
kaldi_asr_parallel_client
|
||||
target_include_directories(kaldi-asr-parallel-client SYSTEM
|
||||
PRIVATE
|
||||
/opt/kaldi/tools/openfst-1.6.7/include/
|
||||
/opt/kaldi/src
|
||||
/opt/kaldi/tools/openfst/include
|
||||
)
|
||||
target_include_directories(kaldi-asr-parallel-client
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
)
|
||||
target_compile_features(kaldi-asr-parallel-client PRIVATE cxx_std_17)
|
||||
target_compile_options(kaldi-asr-parallel-client
|
||||
PRIVATE
|
||||
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:-Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror>
|
||||
)
|
||||
target_link_directories(kaldi-asr-parallel-client
|
||||
PRIVATE
|
||||
/opt/kaldi/src/lib
|
||||
)
|
||||
target_link_libraries(kaldi-asr-parallel-client
|
||||
PRIVATE
|
||||
TritonClient::grpcclient_static
|
||||
-lkaldi-base
|
||||
-lkaldi-util
|
||||
-lkaldi-matrix
|
||||
-lkaldi-feat
|
||||
-lkaldi-lat
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
kaldi_asr_parallel_client
|
||||
PRIVATE /opt/kaldi/src/lib/libkaldi-feat.so
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
kaldi_asr_parallel_client
|
||||
PRIVATE /opt/kaldi/src/lib/libkaldi-util.so
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
kaldi_asr_parallel_client
|
||||
PRIVATE /opt/kaldi/src/lib/libkaldi-matrix.so
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
kaldi_asr_parallel_client
|
||||
PRIVATE /opt/kaldi/src/lib/libkaldi-base.so
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
kaldi_asr_parallel_client
|
||||
PRIVATE /opt/kaldi/src/lat/libkaldi-lat.so
|
||||
)
|
||||
#
|
||||
# Install
|
||||
#
|
||||
include(GNUInstallDirs)
|
||||
|
||||
install(
|
||||
TARGETS kaldi_asr_parallel_client
|
||||
RUNTIME DESTINATION bin
|
||||
TARGETS
|
||||
kaldi-asr-parallel-client
|
||||
EXPORT
|
||||
kaldi-asr-parallel-client-targets
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin
|
||||
)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
@ -13,7 +13,9 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "asr_client_imp.h"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
|
@ -33,117 +35,179 @@
|
|||
} \
|
||||
}
|
||||
|
||||
void TRTISASRClient::CreateClientContext() {
|
||||
contextes_.emplace_back();
|
||||
ClientContext& client = contextes_.back();
|
||||
void TritonASRClient::CreateClientContext() {
|
||||
clients_.emplace_back();
|
||||
TritonClient& client = clients_.back();
|
||||
FAIL_IF_ERR(nic::InferenceServerGrpcClient::Create(&client.triton_client,
|
||||
url_, /*verbose*/ false),
|
||||
"unable to create triton client");
|
||||
|
||||
FAIL_IF_ERR(
|
||||
nic::InferGrpcStreamContext::Create(&client.trtis_context,
|
||||
/*corr_id*/ -1, url_, model_name_,
|
||||
/*model_version*/ -1,
|
||||
/*verbose*/ false),
|
||||
"unable to create context");
|
||||
client.triton_client->StartStream(
|
||||
[&](nic::InferResult* result) {
|
||||
double end_timestamp = gettime_monotonic();
|
||||
std::unique_ptr<nic::InferResult> result_ptr(result);
|
||||
FAIL_IF_ERR(result_ptr->RequestStatus(),
|
||||
"inference request failed");
|
||||
std::string request_id;
|
||||
FAIL_IF_ERR(result_ptr->Id(&request_id),
|
||||
"unable to get request id for response");
|
||||
uint64_t corr_id =
|
||||
std::stoi(std::string(request_id, 0, request_id.find("_")));
|
||||
bool end_of_stream = (request_id.back() == '1');
|
||||
if (!end_of_stream) {
|
||||
if (print_partial_results_) {
|
||||
std::vector<std::string> text;
|
||||
FAIL_IF_ERR(result_ptr->StringData("TEXT", &text),
|
||||
"unable to get TEXT output");
|
||||
std::lock_guard<std::mutex> lk(stdout_m_);
|
||||
std::cout << "CORR_ID " << corr_id << "\t[partial]\t" << text[0]
|
||||
<< '\n';
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
double start_timestamp;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(start_timestamps_m_);
|
||||
auto it = start_timestamps_.find(corr_id);
|
||||
if (it != start_timestamps_.end()) {
|
||||
start_timestamp = it->second;
|
||||
start_timestamps_.erase(it);
|
||||
} else {
|
||||
std::cerr << "start_timestamp not found" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (print_results_) {
|
||||
std::vector<std::string> text;
|
||||
FAIL_IF_ERR(result_ptr->StringData(ctm_ ? "CTM" : "TEXT", &text),
|
||||
"unable to get TEXT or CTM output");
|
||||
std::lock_guard<std::mutex> lk(stdout_m_);
|
||||
std::cout << "CORR_ID " << corr_id;
|
||||
std::cout << (ctm_ ? "\n" : "\t\t");
|
||||
std::cout << text[0] << std::endl;
|
||||
}
|
||||
|
||||
std::vector<std::string> lattice_bytes;
|
||||
FAIL_IF_ERR(result_ptr->StringData("RAW_LATTICE", &lattice_bytes),
|
||||
"unable to get RAW_LATTICE output");
|
||||
|
||||
{
|
||||
double elapsed = end_timestamp - start_timestamp;
|
||||
std::lock_guard<std::mutex> lk(results_m_);
|
||||
results_.insert(
|
||||
{corr_id, {std::move(lattice_bytes[0]), elapsed}});
|
||||
}
|
||||
|
||||
n_in_flight_.fetch_sub(1, std::memory_order_relaxed);
|
||||
},
|
||||
false),
|
||||
"unable to establish a streaming connection to server");
|
||||
}
|
||||
|
||||
void TRTISASRClient::SendChunk(ni::CorrelationID corr_id,
|
||||
bool start_of_sequence, bool end_of_sequence,
|
||||
float* chunk, int chunk_byte_size) {
|
||||
ClientContext* client = &contextes_[corr_id % ncontextes_];
|
||||
nic::InferContext& context = *client->trtis_context;
|
||||
if (start_of_sequence) n_in_flight_.fetch_add(1, std::memory_order_consume);
|
||||
|
||||
void TritonASRClient::SendChunk(uint64_t corr_id, bool start_of_sequence,
|
||||
bool end_of_sequence, float* chunk,
|
||||
int chunk_byte_size, const uint64_t index) {
|
||||
// Setting options
|
||||
std::unique_ptr<nic::InferContext::Options> options;
|
||||
FAIL_IF_ERR(nic::InferContext::Options::Create(&options),
|
||||
"unable to create inference options");
|
||||
options->SetBatchSize(1);
|
||||
options->SetFlags(0);
|
||||
options->SetCorrelationId(corr_id);
|
||||
if (start_of_sequence)
|
||||
options->SetFlag(ni::InferRequestHeader::FLAG_SEQUENCE_START,
|
||||
start_of_sequence);
|
||||
if (end_of_sequence) {
|
||||
options->SetFlag(ni::InferRequestHeader::FLAG_SEQUENCE_END,
|
||||
end_of_sequence);
|
||||
for (const auto& output : context.Outputs()) {
|
||||
if (output->Name() == "TEXT" && !print_results_)
|
||||
continue; // no need for text output if not printing
|
||||
options->AddRawResult(output);
|
||||
}
|
||||
}
|
||||
nic::InferOptions options(model_name_);
|
||||
options.sequence_id_ = corr_id;
|
||||
options.sequence_start_ = start_of_sequence;
|
||||
options.sequence_end_ = end_of_sequence;
|
||||
options.request_id_ = std::to_string(corr_id) + "_" + std::to_string(index) +
|
||||
"_" + (start_of_sequence ? "1" : "0") + "_" +
|
||||
(end_of_sequence ? "1" : "0");
|
||||
|
||||
FAIL_IF_ERR(context.SetRunOptions(*options), "unable to set context options");
|
||||
std::shared_ptr<nic::InferContext::Input> in_wave_data, in_wave_data_dim;
|
||||
FAIL_IF_ERR(context.GetInput("WAV_DATA", &in_wave_data),
|
||||
"unable to get WAV_DATA");
|
||||
FAIL_IF_ERR(context.GetInput("WAV_DATA_DIM", &in_wave_data_dim),
|
||||
"unable to get WAV_DATA_DIM");
|
||||
|
||||
// Wave data input
|
||||
FAIL_IF_ERR(in_wave_data->Reset(), "unable to reset WAVE_DATA");
|
||||
// Initialize the inputs with the data.
|
||||
nic::InferInput* wave_data_ptr;
|
||||
std::vector<int64_t> wav_shape{1, samps_per_chunk_};
|
||||
FAIL_IF_ERR(
|
||||
nic::InferInput::Create(&wave_data_ptr, "WAV_DATA", wav_shape, "FP32"),
|
||||
"unable to create 'WAV_DATA'");
|
||||
std::shared_ptr<nic::InferInput> wave_data_in(wave_data_ptr);
|
||||
FAIL_IF_ERR(wave_data_in->Reset(), "unable to reset 'WAV_DATA'");
|
||||
uint8_t* wave_data = reinterpret_cast<uint8_t*>(chunk);
|
||||
if (chunk_byte_size < max_chunk_byte_size_) {
|
||||
std::memcpy(&chunk_buf_[0], chunk, chunk_byte_size);
|
||||
wave_data = &chunk_buf_[0];
|
||||
}
|
||||
FAIL_IF_ERR(in_wave_data->SetRaw(wave_data, max_chunk_byte_size_),
|
||||
"unable to set data for WAVE_DATA");
|
||||
FAIL_IF_ERR(wave_data_in->AppendRaw(wave_data, max_chunk_byte_size_),
|
||||
"unable to set data for 'WAV_DATA'");
|
||||
|
||||
// Dim
|
||||
FAIL_IF_ERR(in_wave_data_dim->Reset(), "unable to reset WAVE_DATA_DIM");
|
||||
nic::InferInput* dim_ptr;
|
||||
std::vector<int64_t> shape{1, 1};
|
||||
FAIL_IF_ERR(nic::InferInput::Create(&dim_ptr, "WAV_DATA_DIM", shape, "INT32"),
|
||||
"unable to create 'WAV_DATA_DIM'");
|
||||
std::shared_ptr<nic::InferInput> dim_in(dim_ptr);
|
||||
FAIL_IF_ERR(dim_in->Reset(), "unable to reset WAVE_DATA_DIM");
|
||||
int nsamples = chunk_byte_size / sizeof(float);
|
||||
FAIL_IF_ERR(in_wave_data_dim->SetRaw(reinterpret_cast<uint8_t*>(&nsamples),
|
||||
sizeof(int32_t)),
|
||||
"unable to set data for WAVE_DATA_DIM");
|
||||
FAIL_IF_ERR(
|
||||
dim_in->AppendRaw(reinterpret_cast<uint8_t*>(&nsamples), sizeof(int32_t)),
|
||||
"unable to set data for WAVE_DATA_DIM");
|
||||
|
||||
total_audio_ += (static_cast<double>(nsamples) / 16000.); // TODO freq
|
||||
double start = gettime_monotonic();
|
||||
FAIL_IF_ERR(context.AsyncRun([corr_id, end_of_sequence, start, this](
|
||||
nic::InferContext* ctx,
|
||||
const std::shared_ptr<
|
||||
nic::InferContext::Request>& request) {
|
||||
if (end_of_sequence) {
|
||||
double elapsed = gettime_monotonic() - start;
|
||||
std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
|
||||
ctx->GetAsyncRunResults(request, &results);
|
||||
std::vector<nic::InferInput*> inputs = {wave_data_in.get(), dim_in.get()};
|
||||
|
||||
if (results.empty()) {
|
||||
std::cerr << "Warning: Could not read "
|
||||
"output for corr_id "
|
||||
<< corr_id << std::endl;
|
||||
} else {
|
||||
if (print_results_) {
|
||||
std::string text;
|
||||
FAIL_IF_ERR(results["TEXT"]->GetRawAtCursor(0, &text),
|
||||
"unable to get TEXT output");
|
||||
std::lock_guard<std::mutex> lk(stdout_m_);
|
||||
std::cout << "CORR_ID " << corr_id << "\t\t" << text << std::endl;
|
||||
}
|
||||
std::vector<const nic::InferRequestedOutput*> outputs;
|
||||
std::shared_ptr<nic::InferRequestedOutput> raw_lattice, text;
|
||||
outputs.reserve(2);
|
||||
if (end_of_sequence) {
|
||||
nic::InferRequestedOutput* raw_lattice_ptr;
|
||||
FAIL_IF_ERR(
|
||||
nic::InferRequestedOutput::Create(&raw_lattice_ptr, "RAW_LATTICE"),
|
||||
"unable to get 'RAW_LATTICE'");
|
||||
raw_lattice.reset(raw_lattice_ptr);
|
||||
outputs.push_back(raw_lattice.get());
|
||||
|
||||
std::string lattice_bytes;
|
||||
FAIL_IF_ERR(results["RAW_LATTICE"]->GetRawAtCursor(0, &lattice_bytes),
|
||||
"unable to get RAW_LATTICE output");
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(results_m_);
|
||||
results_.insert({corr_id, {std::move(lattice_bytes), elapsed}});
|
||||
}
|
||||
}
|
||||
n_in_flight_.fetch_sub(1, std::memory_order_relaxed);
|
||||
// Request the TEXT results only when required for printing
|
||||
if (print_results_) {
|
||||
nic::InferRequestedOutput* text_ptr;
|
||||
FAIL_IF_ERR(
|
||||
nic::InferRequestedOutput::Create(&text_ptr, ctm_ ? "CTM" : "TEXT"),
|
||||
"unable to get 'TEXT' or 'CTM'");
|
||||
text.reset(text_ptr);
|
||||
outputs.push_back(text.get());
|
||||
}
|
||||
}),
|
||||
} else if (print_partial_results_) {
|
||||
nic::InferRequestedOutput* text_ptr;
|
||||
FAIL_IF_ERR(nic::InferRequestedOutput::Create(&text_ptr, "TEXT"),
|
||||
"unable to get 'TEXT'");
|
||||
text.reset(text_ptr);
|
||||
outputs.push_back(text.get());
|
||||
}
|
||||
|
||||
total_audio_ += (static_cast<double>(nsamples) / samp_freq_);
|
||||
|
||||
if (start_of_sequence) {
|
||||
n_in_flight_.fetch_add(1, std::memory_order_consume);
|
||||
}
|
||||
|
||||
// Record the timestamp when the last chunk was made available.
|
||||
if (end_of_sequence) {
|
||||
std::lock_guard<std::mutex> lk(start_timestamps_m_);
|
||||
start_timestamps_[corr_id] = gettime_monotonic();
|
||||
}
|
||||
|
||||
TritonClient* client = &clients_[corr_id % nclients_];
|
||||
// nic::InferenceServerGrpcClient& triton_client = *client->triton_client;
|
||||
FAIL_IF_ERR(client->triton_client->AsyncStreamInfer(options, inputs, outputs),
|
||||
"unable to run model");
|
||||
}
|
||||
|
||||
void TRTISASRClient::WaitForCallbacks() {
|
||||
int n;
|
||||
while ((n = n_in_flight_.load(std::memory_order_consume))) {
|
||||
void TritonASRClient::WaitForCallbacks() {
|
||||
while (n_in_flight_.load(std::memory_order_consume)) {
|
||||
usleep(1000);
|
||||
}
|
||||
}
|
||||
|
||||
void TRTISASRClient::PrintStats(bool print_latency_stats) {
|
||||
void TritonASRClient::PrintStats(bool print_latency_stats,
|
||||
bool print_throughput) {
|
||||
double now = gettime_monotonic();
|
||||
double diff = now - started_at_;
|
||||
double rtf = total_audio_ / diff;
|
||||
std::cout << "Throughput:\t" << rtf << " RTFX" << std::endl;
|
||||
if (print_throughput)
|
||||
std::cout << "Throughput:\t" << rtf << " RTFX" << std::endl;
|
||||
std::vector<double> latencies;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(results_m_);
|
||||
|
@ -176,20 +240,33 @@ void TRTISASRClient::PrintStats(bool print_latency_stats) {
|
|||
}
|
||||
}
|
||||
|
||||
TRTISASRClient::TRTISASRClient(const std::string& url,
|
||||
const std::string& model_name,
|
||||
const int ncontextes, bool print_results)
|
||||
TritonASRClient::TritonASRClient(const std::string& url,
|
||||
const std::string& model_name,
|
||||
const int nclients, bool print_results,
|
||||
bool print_partial_results, bool ctm,
|
||||
float samp_freq)
|
||||
: url_(url),
|
||||
model_name_(model_name),
|
||||
ncontextes_(ncontextes),
|
||||
print_results_(print_results) {
|
||||
ncontextes_ = std::max(ncontextes_, 1);
|
||||
for (int i = 0; i < ncontextes_; ++i) CreateClientContext();
|
||||
nclients_(nclients),
|
||||
print_results_(print_results),
|
||||
print_partial_results_(print_partial_results),
|
||||
ctm_(ctm),
|
||||
samp_freq_(samp_freq) {
|
||||
nclients_ = std::max(nclients_, 1);
|
||||
for (int i = 0; i < nclients_; ++i) CreateClientContext();
|
||||
|
||||
std::shared_ptr<nic::InferContext::Input> in_wave_data;
|
||||
FAIL_IF_ERR(contextes_[0].trtis_context->GetInput("WAV_DATA", &in_wave_data),
|
||||
"unable to get WAV_DATA");
|
||||
max_chunk_byte_size_ = in_wave_data->ByteSize();
|
||||
inference::ModelMetadataResponse model_metadata;
|
||||
FAIL_IF_ERR(
|
||||
clients_[0].triton_client->ModelMetadata(&model_metadata, model_name),
|
||||
"unable to get model metadata");
|
||||
|
||||
for (const auto& in_tensor : model_metadata.inputs()) {
|
||||
if (in_tensor.name().compare("WAV_DATA") == 0) {
|
||||
samps_per_chunk_ = in_tensor.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
max_chunk_byte_size_ = samps_per_chunk_ * sizeof(float);
|
||||
chunk_buf_.resize(max_chunk_byte_size_);
|
||||
shape_ = {max_chunk_byte_size_};
|
||||
n_in_flight_.store(0);
|
||||
|
@ -197,20 +274,24 @@ TRTISASRClient::TRTISASRClient(const std::string& url,
|
|||
total_audio_ = 0;
|
||||
}
|
||||
|
||||
void TRTISASRClient::WriteLatticesToFile(
|
||||
void TritonASRClient::WriteLatticesToFile(
|
||||
const std::string& clat_wspecifier,
|
||||
const std::unordered_map<ni::CorrelationID, std::string>&
|
||||
corr_id_and_keys) {
|
||||
const std::unordered_map<uint64_t, std::string>& corr_id_and_keys) {
|
||||
kaldi::CompactLatticeWriter clat_writer;
|
||||
clat_writer.Open(clat_wspecifier);
|
||||
std::unordered_map<std::string, size_t> key_count;
|
||||
std::lock_guard<std::mutex> lk(results_m_);
|
||||
for (auto& p : corr_id_and_keys) {
|
||||
ni::CorrelationID corr_id = p.first;
|
||||
const std::string& key = p.second;
|
||||
uint64_t corr_id = p.first;
|
||||
std::string key = p.second;
|
||||
const auto iter = key_count[key]++;
|
||||
if (iter > 0) {
|
||||
key += std::to_string(iter);
|
||||
}
|
||||
auto it = results_.find(corr_id);
|
||||
if(it == results_.end()) {
|
||||
std::cerr << "Cannot find lattice for corr_id " << corr_id << std::endl;
|
||||
continue;
|
||||
if (it == results_.end()) {
|
||||
std::cerr << "Cannot find lattice for corr_id " << corr_id << std::endl;
|
||||
continue;
|
||||
}
|
||||
const std::string& raw_lattice = it->second.raw_lattice;
|
||||
// We could in theory write directly the binary hold in raw_lattice (it is
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
@ -12,15 +12,16 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <grpc_client.h>
|
||||
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "request_grpc.h"
|
||||
#ifndef TRITON_KALDI_ASR_CLIENT_H_
|
||||
#define TRITON_KALDI_ASR_CLIENT_H_
|
||||
|
||||
#ifndef TRTIS_KALDI_ASR_CLIENT_H_
|
||||
#define TRTIS_KALDI_ASR_CLIENT_H_
|
||||
namespace ni = nvidia::inferenceserver;
|
||||
namespace nic = nvidia::inferenceserver::client;
|
||||
|
||||
|
@ -33,16 +34,16 @@ double inline gettime_monotonic() {
|
|||
return time;
|
||||
}
|
||||
|
||||
class TRTISASRClient {
|
||||
struct ClientContext {
|
||||
std::unique_ptr<nic::InferContext> trtis_context;
|
||||
class TritonASRClient {
|
||||
struct TritonClient {
|
||||
std::unique_ptr<nic::InferenceServerGrpcClient> triton_client;
|
||||
};
|
||||
|
||||
std::string url_;
|
||||
std::string model_name_;
|
||||
|
||||
std::vector<ClientContext> contextes_;
|
||||
int ncontextes_;
|
||||
std::vector<TritonClient> clients_;
|
||||
int nclients_;
|
||||
std::vector<uint8_t> chunk_buf_;
|
||||
std::vector<int64_t> shape_;
|
||||
int max_chunk_byte_size_;
|
||||
|
@ -50,26 +51,36 @@ class TRTISASRClient {
|
|||
double started_at_;
|
||||
double total_audio_;
|
||||
bool print_results_;
|
||||
bool print_partial_results_;
|
||||
bool ctm_;
|
||||
std::mutex stdout_m_;
|
||||
int samps_per_chunk_;
|
||||
float samp_freq_;
|
||||
|
||||
struct Result {
|
||||
std::string raw_lattice;
|
||||
double latency;
|
||||
};
|
||||
|
||||
std::unordered_map<ni::CorrelationID, Result> results_;
|
||||
std::unordered_map<uint64_t, double> start_timestamps_;
|
||||
std::mutex start_timestamps_m_;
|
||||
|
||||
std::unordered_map<uint64_t, Result> results_;
|
||||
std::mutex results_m_;
|
||||
|
||||
public:
|
||||
TritonASRClient(const std::string& url, const std::string& model_name,
|
||||
const int ncontextes, bool print_results,
|
||||
bool print_partial_results, bool ctm, float samp_freq);
|
||||
|
||||
void CreateClientContext();
|
||||
void SendChunk(uint64_t corr_id, bool start_of_sequence, bool end_of_sequence,
|
||||
float* chunk, int chunk_byte_size);
|
||||
float* chunk, int chunk_byte_size, uint64_t index);
|
||||
void WaitForCallbacks();
|
||||
void PrintStats(bool print_latency_stats);
|
||||
void WriteLatticesToFile(const std::string &clat_wspecifier, const std::unordered_map<ni::CorrelationID, std::string> &corr_id_and_keys);
|
||||
|
||||
TRTISASRClient(const std::string& url, const std::string& model_name,
|
||||
const int ncontextes, bool print_results);
|
||||
void PrintStats(bool print_latency_stats, bool print_throughput);
|
||||
void WriteLatticesToFile(
|
||||
const std::string& clat_wspecifier,
|
||||
const std::unordered_map<uint64_t, std::string>& corr_id_and_keys);
|
||||
};
|
||||
|
||||
#endif // TRTIS_KALDI_ASR_CLIENT_H_
|
||||
#endif // TRITON_KALDI_ASR_CLIENT_H_
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
@ -13,9 +13,12 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "asr_client_imp.h"
|
||||
#include "feat/wave-reader.h" // to read the wav.scp
|
||||
#include "util/kaldi-table.h"
|
||||
|
@ -41,6 +44,8 @@ void Usage(char** argv, const std::string& msg = std::string()) {
|
|||
"online clients."
|
||||
<< std::endl;
|
||||
std::cerr << "\t-p : Print text outputs" << std::endl;
|
||||
std::cerr << "\t-b : Print partial (best path) text outputs" << std::endl;
|
||||
//std::cerr << "\t-t : Print text with timings (CTM)" << std::endl;
|
||||
|
||||
std::cerr << std::endl;
|
||||
exit(1);
|
||||
|
@ -49,7 +54,7 @@ void Usage(char** argv, const std::string& msg = std::string()) {
|
|||
int main(int argc, char** argv) {
|
||||
std::cout << "\n";
|
||||
std::cout << "==================================================\n"
|
||||
<< "============= TRTIS Kaldi ASR Client =============\n"
|
||||
<< "============= Triton Kaldi ASR Client ============\n"
|
||||
<< "==================================================\n"
|
||||
<< std::endl;
|
||||
|
||||
|
@ -65,14 +70,15 @@ int main(int argc, char** argv) {
|
|||
size_t nchannels = 1000;
|
||||
int niterations = 5;
|
||||
bool verbose = false;
|
||||
float samp_freq = 16000;
|
||||
int ncontextes = 10;
|
||||
int nclients = 10;
|
||||
bool online = false;
|
||||
bool print_results = false;
|
||||
bool print_partial_results = false;
|
||||
bool ctm = false;
|
||||
|
||||
// Parse commandline...
|
||||
int opt;
|
||||
while ((opt = getopt(argc, argv, "va:u:i:c:ophl:")) != -1) {
|
||||
while ((opt = getopt(argc, argv, "va:u:i:c:otpbhl:")) != -1) {
|
||||
switch (opt) {
|
||||
case 'i':
|
||||
niterations = std::atoi(optarg);
|
||||
|
@ -95,6 +101,13 @@ int main(int argc, char** argv) {
|
|||
case 'p':
|
||||
print_results = true;
|
||||
break;
|
||||
case 'b':
|
||||
print_partial_results = true;
|
||||
break;
|
||||
case 't':
|
||||
ctm = true;
|
||||
print_results = true;
|
||||
break;
|
||||
case 'l':
|
||||
chunk_length = std::atoi(optarg);
|
||||
break;
|
||||
|
@ -116,19 +129,21 @@ int main(int argc, char** argv) {
|
|||
std::cout << "Server URL\t\t\t: " << url << std::endl;
|
||||
std::cout << "Print text outputs\t\t: " << (print_results ? "Yes" : "No")
|
||||
<< std::endl;
|
||||
std::cout << "Print partial text outputs\t: "
|
||||
<< (print_partial_results ? "Yes" : "No") << std::endl;
|
||||
std::cout << "Online - Realtime I/O\t\t: " << (online ? "Yes" : "No")
|
||||
<< std::endl;
|
||||
std::cout << std::endl;
|
||||
|
||||
float chunk_seconds = (double)chunk_length / samp_freq;
|
||||
// need to read wav files
|
||||
SequentialTableReader<WaveHolder> wav_reader(wav_rspecifier);
|
||||
|
||||
float samp_freq = 0;
|
||||
double total_audio = 0;
|
||||
// pre-loading data
|
||||
// we don't want to measure I/O
|
||||
std::vector<std::shared_ptr<WaveData>> all_wav;
|
||||
std::vector<std::string> all_wav_keys;
|
||||
|
||||
// need to read wav files
|
||||
SequentialTableReader<WaveHolder> wav_reader(wav_rspecifier);
|
||||
{
|
||||
std::cout << "Loading eval dataset..." << std::flush;
|
||||
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||
|
@ -138,90 +153,119 @@ int main(int argc, char** argv) {
|
|||
all_wav.push_back(wave_data);
|
||||
all_wav_keys.push_back(utt);
|
||||
total_audio += wave_data->Duration();
|
||||
samp_freq = wave_data->SampFreq();
|
||||
}
|
||||
std::cout << "done" << std::endl;
|
||||
}
|
||||
|
||||
if (all_wav.empty()) {
|
||||
std::cerr << "Empty dataset";
|
||||
exit(0);
|
||||
}
|
||||
|
||||
std::cout << "Loaded dataset with " << all_wav.size()
|
||||
<< " utterances, frequency " << samp_freq << "hz, total audio "
|
||||
<< total_audio << " seconds" << std::endl;
|
||||
|
||||
double chunk_seconds = (double)chunk_length / samp_freq;
|
||||
double seconds_per_sample = chunk_seconds / chunk_length;
|
||||
|
||||
struct Stream {
|
||||
std::shared_ptr<WaveData> wav;
|
||||
ni::CorrelationID corr_id;
|
||||
uint64_t corr_id;
|
||||
int offset;
|
||||
float send_next_chunk_at;
|
||||
std::atomic<bool> received_output;
|
||||
double send_next_chunk_at;
|
||||
|
||||
Stream(const std::shared_ptr<WaveData>& _wav, ni::CorrelationID _corr_id)
|
||||
: wav(_wav), corr_id(_corr_id), offset(0), received_output(true) {
|
||||
send_next_chunk_at = gettime_monotonic();
|
||||
Stream(const std::shared_ptr<WaveData>& _wav, uint64_t _corr_id,
|
||||
double _send_next_chunk_at)
|
||||
: wav(_wav),
|
||||
corr_id(_corr_id),
|
||||
offset(0),
|
||||
send_next_chunk_at(_send_next_chunk_at) {}
|
||||
|
||||
bool operator<(const Stream& other) const {
|
||||
return (send_next_chunk_at > other.send_next_chunk_at);
|
||||
}
|
||||
};
|
||||
|
||||
std::cout << "Opening GRPC contextes..." << std::flush;
|
||||
TRTISASRClient asr_client(url, model_name, ncontextes, print_results);
|
||||
std::unordered_map<uint64_t, std::string> corr_id_and_keys;
|
||||
TritonASRClient asr_client(url, model_name, nclients, print_results,
|
||||
print_partial_results, ctm, samp_freq);
|
||||
std::cout << "done" << std::endl;
|
||||
std::cout << "Streaming utterances..." << std::flush;
|
||||
std::vector<std::unique_ptr<Stream>> curr_tasks, next_tasks;
|
||||
curr_tasks.reserve(nchannels);
|
||||
next_tasks.reserve(nchannels);
|
||||
std::cout << "Streaming utterances..." << std::endl;
|
||||
std::priority_queue<Stream> streams;
|
||||
size_t all_wav_i = 0;
|
||||
size_t all_wav_max = all_wav.size() * niterations;
|
||||
uint64_t index = 0;
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<> dis(0.0, 1.0);
|
||||
|
||||
bool add_random_offset = true;
|
||||
while (true) {
|
||||
while (curr_tasks.size() < nchannels && all_wav_i < all_wav_max) {
|
||||
while (streams.size() < nchannels && all_wav_i < all_wav_max) {
|
||||
// Creating new tasks
|
||||
uint64_t corr_id = static_cast<uint64_t>(all_wav_i) + 1;
|
||||
auto all_wav_i_modulo = all_wav_i % (all_wav.size());
|
||||
double stream_will_start_at = gettime_monotonic();
|
||||
if (add_random_offset) stream_will_start_at += dis(gen);
|
||||
double first_chunk_available_at =
|
||||
stream_will_start_at +
|
||||
std::min(static_cast<double>(all_wav[all_wav_i_modulo]->Duration()),
|
||||
chunk_seconds);
|
||||
|
||||
std::unique_ptr<Stream> ptr(
|
||||
new Stream(all_wav[all_wav_i % (all_wav.size())], corr_id));
|
||||
curr_tasks.emplace_back(std::move(ptr));
|
||||
corr_id_and_keys.insert({corr_id, all_wav_keys[all_wav_i_modulo]});
|
||||
streams.emplace(all_wav[all_wav_i_modulo], corr_id,
|
||||
first_chunk_available_at);
|
||||
++all_wav_i;
|
||||
}
|
||||
// If still empty, done
|
||||
if (curr_tasks.empty()) break;
|
||||
if (streams.empty()) break;
|
||||
|
||||
for (size_t itask = 0; itask < curr_tasks.size(); ++itask) {
|
||||
Stream& task = *(curr_tasks[itask]);
|
||||
|
||||
SubVector<BaseFloat> data(task.wav->Data(), 0);
|
||||
int32 samp_offset = task.offset;
|
||||
int32 nsamp = data.Dim();
|
||||
int32 samp_remaining = nsamp - samp_offset;
|
||||
int32 num_samp =
|
||||
chunk_length < samp_remaining ? chunk_length : samp_remaining;
|
||||
bool is_last_chunk = (chunk_length >= samp_remaining);
|
||||
SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
|
||||
bool is_first_chunk = (samp_offset == 0);
|
||||
if (online) {
|
||||
double now = gettime_monotonic();
|
||||
double wait_for = task.send_next_chunk_at - now;
|
||||
if (wait_for > 0) usleep(wait_for * 1e6);
|
||||
}
|
||||
asr_client.SendChunk(task.corr_id, is_first_chunk, is_last_chunk,
|
||||
wave_part.Data(), wave_part.SizeInBytes());
|
||||
task.send_next_chunk_at += chunk_seconds;
|
||||
if (verbose)
|
||||
std::cout << "Sending correlation_id=" << task.corr_id
|
||||
<< " chunk offset=" << num_samp << std::endl;
|
||||
|
||||
task.offset += num_samp;
|
||||
if (!is_last_chunk) next_tasks.push_back(std::move(curr_tasks[itask]));
|
||||
auto task = streams.top();
|
||||
streams.pop();
|
||||
if (online) {
|
||||
double wait_for = task.send_next_chunk_at - gettime_monotonic();
|
||||
if (wait_for > 0) usleep(wait_for * 1e6);
|
||||
}
|
||||
add_random_offset = false;
|
||||
|
||||
SubVector<BaseFloat> data(task.wav->Data(), 0);
|
||||
int32 samp_offset = task.offset;
|
||||
int32 nsamp = data.Dim();
|
||||
int32 samp_remaining = nsamp - samp_offset;
|
||||
int32 num_samp =
|
||||
chunk_length < samp_remaining ? chunk_length : samp_remaining;
|
||||
bool is_last_chunk = (chunk_length >= samp_remaining);
|
||||
SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
|
||||
bool is_first_chunk = (samp_offset == 0);
|
||||
asr_client.SendChunk(task.corr_id, is_first_chunk, is_last_chunk,
|
||||
wave_part.Data(), wave_part.SizeInBytes(), index++);
|
||||
|
||||
if (verbose)
|
||||
std::cout << "Sending correlation_id=" << task.corr_id
|
||||
<< " chunk offset=" << num_samp << std::endl;
|
||||
|
||||
task.offset += num_samp;
|
||||
int32 next_chunk_num_samp = std::min(nsamp - task.offset, chunk_length);
|
||||
task.send_next_chunk_at += next_chunk_num_samp * seconds_per_sample;
|
||||
|
||||
if (!is_last_chunk) streams.push(task);
|
||||
|
||||
curr_tasks.swap(next_tasks);
|
||||
next_tasks.clear();
|
||||
// Showing activity if necessary
|
||||
if (!print_results && !verbose) std::cout << "." << std::flush;
|
||||
if (!print_results && !print_partial_results && !verbose &&
|
||||
index % nchannels == 0)
|
||||
std::cout << "." << std::flush;
|
||||
}
|
||||
std::cout << "done" << std::endl;
|
||||
std::cout << "Waiting for all results..." << std::flush;
|
||||
asr_client.WaitForCallbacks();
|
||||
std::cout << "done" << std::endl;
|
||||
|
||||
asr_client.PrintStats(online);
|
||||
|
||||
std::unordered_map<ni::CorrelationID, std::string> corr_id_and_keys;
|
||||
for (size_t all_wav_i = 0; all_wav_i < all_wav.size(); ++all_wav_i) {
|
||||
ni::CorrelationID corr_id = static_cast<ni::CorrelationID>(all_wav_i) + 1;
|
||||
corr_id_and_keys.insert({corr_id, all_wav_keys[all_wav_i]});
|
||||
}
|
||||
asr_client.PrintStats(
|
||||
online,
|
||||
!online); // Print latency if online, do not print throughput if online
|
||||
asr_client.WriteLatticesToFile("ark:|gzip -c > /data/results/lat.cuda-asr.gz",
|
||||
corr_id_and_keys);
|
||||
|
||||
|
|
|
@ -1,80 +1,103 @@
|
|||
name: "kaldi_online"
|
||||
platform: "custom"
|
||||
default_model_filename: "libkaldi-trtisbackend.so"
|
||||
max_batch_size: 2200
|
||||
parameters: {
|
||||
key: "mfcc_filename"
|
||||
value: {
|
||||
string_value:"/data/models/LibriSpeech/conf/mfcc.conf"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "ivector_filename"
|
||||
value: {
|
||||
string_value:"/data/models/LibriSpeech/conf/ivector_extractor.conf"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "nnet3_rxfilename"
|
||||
value: {
|
||||
string_value: "/data/models/LibriSpeech/final.mdl"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "fst_rxfilename"
|
||||
value: {
|
||||
string_value: "/data/models/LibriSpeech/HCLG.fst"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "word_syms_rxfilename"
|
||||
value: {
|
||||
string_value:"/data/models/LibriSpeech/words.txt"
|
||||
}
|
||||
}
|
||||
parameters: [{
|
||||
key: "beam"
|
||||
value: {
|
||||
string_value:"10"
|
||||
}
|
||||
},{
|
||||
key: "num_worker_threads"
|
||||
value: {
|
||||
string_value:"40"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "max_execution_batch_size"
|
||||
value: {
|
||||
string_value:"400"
|
||||
}
|
||||
}]
|
||||
parameters: {
|
||||
key: "lattice_beam"
|
||||
value: {
|
||||
string_value:"7"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "max_active"
|
||||
value: {
|
||||
string_value:"10000"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "frame_subsampling_factor"
|
||||
value: {
|
||||
string_value:"3"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "acoustic_scale"
|
||||
value: {
|
||||
string_value:"1.0"
|
||||
}
|
||||
backend: "kaldi"
|
||||
max_batch_size: 600
|
||||
model_transaction_policy {
|
||||
decoupled: True
|
||||
}
|
||||
parameters [
|
||||
{
|
||||
key: "config_filename"
|
||||
value {
|
||||
string_value: "/data/models/LibriSpeech/conf/online.conf"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "nnet3_rxfilename"
|
||||
value {
|
||||
string_value: "/data/models/LibriSpeech/final.mdl"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "fst_rxfilename"
|
||||
value {
|
||||
string_value: "/data/models/LibriSpeech/HCLG.fst"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "word_syms_rxfilename"
|
||||
value {
|
||||
string_value: "/data/models/LibriSpeech/words.txt"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "lattice_postprocessor_rxfilename"
|
||||
value {
|
||||
string_value: ""
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "use_tensor_cores"
|
||||
value {
|
||||
string_value: "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "main_q_capacity"
|
||||
value {
|
||||
string_value: "30000"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "aux_q_capacity"
|
||||
value {
|
||||
string_value: "400000"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "beam"
|
||||
value {
|
||||
string_value: "10"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "num_worker_threads"
|
||||
value {
|
||||
string_value: "40"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "num_channels"
|
||||
value {
|
||||
string_value: "4000"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "lattice_beam"
|
||||
value {
|
||||
string_value: "7"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "max_active"
|
||||
value {
|
||||
string_value: "10000"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "frame_subsampling_factor"
|
||||
value {
|
||||
string_value: "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
key: "acoustic_scale"
|
||||
value {
|
||||
string_value: "1.0"
|
||||
}
|
||||
}
|
||||
]
|
||||
sequence_batching {
|
||||
max_sequence_idle_microseconds:5000000
|
||||
max_sequence_idle_microseconds: 5000000
|
||||
control_input [
|
||||
{
|
||||
name: "START"
|
||||
|
@ -108,16 +131,16 @@ max_sequence_idle_microseconds:5000000
|
|||
control [
|
||||
{
|
||||
kind: CONTROL_SEQUENCE_CORRID
|
||||
data_type: TYPE_UINT64
|
||||
data_type: TYPE_UINT64
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
oldest {
|
||||
max_candidate_sequences:2200
|
||||
preferred_batch_size:[400]
|
||||
max_queue_delay_microseconds:1000
|
||||
}
|
||||
oldest {
|
||||
max_candidate_sequences: 4000
|
||||
preferred_batch_size: [ 600 ]
|
||||
max_queue_delay_microseconds: 1000
|
||||
}
|
||||
},
|
||||
|
||||
input [
|
||||
|
@ -142,13 +165,17 @@ output [
|
|||
name: "TEXT"
|
||||
data_type: TYPE_STRING
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "CTM"
|
||||
data_type: TYPE_STRING
|
||||
dims: [ 1 ]
|
||||
}
|
||||
|
||||
]
|
||||
instance_group [
|
||||
{
|
||||
count: 2
|
||||
count: 1
|
||||
kind: KIND_GPU
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -15,10 +15,10 @@ oovtok=$(cat $result_path/words.txt | grep "<unk>" | awk '{print $2}')
|
|||
# convert lattice to transcript
|
||||
/opt/kaldi/src/latbin/lattice-best-path \
|
||||
"ark:gunzip -c $result_path/lat.cuda-asr.gz |"\
|
||||
"ark,t:|gzip -c > $result_path/trans.cuda-asr.gz" 2> /dev/null
|
||||
"ark,t:$result_path/trans.cuda-asr" 2> /dev/null
|
||||
|
||||
# calculate wer
|
||||
/opt/kaldi/src/bin/compute-wer --mode=present \
|
||||
"ark:$result_path/text_ints" \
|
||||
"ark:gunzip -c $result_path/trans.cuda-asr.gz |" 2>&1
|
||||
"ark:$result_path/trans.cuda-asr" 2> /dev/null
|
||||
|
||||
|
|
|
@ -13,5 +13,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
docker build . -f Dockerfile --rm -t trtis_kaldi_server
|
||||
docker build . -f Dockerfile.client --rm -t trtis_kaldi_client
|
||||
set -eu
|
||||
|
||||
# Use development branch of Kaldi for latest feature support
|
||||
docker build . -f Dockerfile \
|
||||
--rm -t triton_kaldi_server
|
||||
docker build . -f Dockerfile.client \
|
||||
--rm -t triton_kaldi_client
|
||||
|
|
|
@ -19,4 +19,5 @@ docker run --rm -it \
|
|||
--ulimit memlock=-1 \
|
||||
--ulimit stack=67108864 \
|
||||
-v $PWD/data:/data \
|
||||
trtis_kaldi_client /workspace/scripts/docker/run_client.sh $@
|
||||
--entrypoint /bin/bash \
|
||||
triton_kaldi_client /workspace/scripts/docker/run_client.sh $@
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
@ -13,12 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Start TRTIS server container for download - need some kaldi tools
|
||||
nvidia-docker run --rm \
|
||||
# Start Triton server container for download - need some kaldi tools
|
||||
docker run --rm \
|
||||
--shm-size=1g \
|
||||
--ulimit memlock=-1 \
|
||||
--ulimit stack=67108864 \
|
||||
-v $PWD/data:/mnt/data \
|
||||
trtis_kaldi_server /workspace/scripts/docker/dataset_setup.sh $(id -u) $(id -g)
|
||||
triton_kaldi_server /workspace/scripts/docker/dataset_setup.sh $(id -u) $(id -g)
|
||||
|
||||
# --user $(id -u):$(id -g) \
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
@ -15,8 +15,9 @@
|
|||
|
||||
NV_VISIBLE_DEVICES=${NVIDIA_VISIBLE_DEVICES:-"0"}
|
||||
|
||||
# Start TRTIS server
|
||||
nvidia-docker run --rm -it \
|
||||
# Start Triton server
|
||||
docker run --rm -it \
|
||||
--gpus $NV_VISIBLE_DEVICES \
|
||||
--shm-size=1g \
|
||||
--ulimit memlock=-1 \
|
||||
--ulimit stack=67108864 \
|
||||
|
@ -24,7 +25,6 @@ nvidia-docker run --rm -it \
|
|||
-p8001:8001 \
|
||||
-p8002:8002 \
|
||||
--name trt_server_asr \
|
||||
-e NVIDIA_VISIBLE_DEVICES=$NV_VISIBLE_DEVICES \
|
||||
-v $PWD/data:/data \
|
||||
-v $PWD/model-repo:/mnt/model-repo \
|
||||
trtis_kaldi_server trtserver --model-repo=/workspace/model-repo/
|
||||
triton_kaldi_server
|
||||
|
|
|
@ -7,7 +7,7 @@ then
|
|||
rm -rf $results_dir
|
||||
fi
|
||||
mkdir $results_dir
|
||||
install/bin/kaldi_asr_parallel_client $@
|
||||
kaldi-asr-parallel-client $@
|
||||
echo "Computing WER..."
|
||||
/workspace/scripts/compute_wer.sh
|
||||
rm -rf $results_dir
|
||||
|
|
|
@ -19,4 +19,4 @@ if [ -d "/mnt/model-repo/kaldi_online" ]; then
|
|||
ln -s /mnt/model-repo/kaldi_online/config.pbtxt /workspace/model-repo/kaldi_online/
|
||||
fi
|
||||
|
||||
/opt/tensorrtserver/nvidia_entrypoint.sh $@
|
||||
/opt/tritonserver/nvidia_entrypoint.sh $@
|
26
Kaldi/SpeechRecognition/trtis-kaldi-backend/libkaldi_online.ldscript → Kaldi/SpeechRecognition/scripts/run_inference_all_a100.sh
Normal file → Executable file
26
Kaldi/SpeechRecognition/trtis-kaldi-backend/libkaldi_online.ldscript → Kaldi/SpeechRecognition/scripts/run_inference_all_a100.sh
Normal file → Executable file
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
@ -11,11 +13,17 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
{
|
||||
global:
|
||||
CustomErrorString;
|
||||
CustomExecute;
|
||||
CustomFinalize;
|
||||
CustomInitialize;
|
||||
local: *;
|
||||
};
|
||||
set -e
|
||||
|
||||
if [[ "$(docker ps | grep triton_kaldi_server | wc -l)" == "0" ]]; then
|
||||
printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "\nOffline benchmarks:\n"
|
||||
|
||||
scripts/docker/launch_client.sh -i 5 -c 4000
|
||||
|
||||
printf "\nOnline benchmarks:\n"
|
||||
|
||||
scripts/docker/launch_client.sh -i 10 -c 2000 -o
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
set -e
|
||||
|
||||
if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then
|
||||
if [[ "$(docker ps | grep triton_kaldi_server | wc -l)" == "0" ]]; then
|
||||
printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
|
||||
exit 1
|
||||
fi
|
||||
|
@ -26,5 +26,5 @@ scripts/docker/launch_client.sh -i 5 -c 1000
|
|||
|
||||
printf "\nOnline benchmarks:\n"
|
||||
|
||||
scripts/docker/launch_client.sh -i 10 -c 700 -o
|
||||
scripts/docker/launch_client.sh -i 10 -c 600 -o
|
||||
scripts/docker/launch_client.sh -i 10 -c 400 -o
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
@ -15,7 +15,7 @@
|
|||
|
||||
set -e
|
||||
|
||||
if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then
|
||||
if [[ "$(docker ps | grep triton_kaldi_server | wc -l)" == "0" ]]; then
|
||||
printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
|
||||
exit 1
|
||||
fi
|
||||
|
@ -26,6 +26,5 @@ scripts/docker/launch_client.sh -i 5 -c 2000
|
|||
|
||||
printf "\nOnline benchmarks:\n"
|
||||
|
||||
scripts/docker/launch_client.sh -i 10 -c 1500 -o
|
||||
scripts/docker/launch_client.sh -i 10 -c 2000 -o
|
||||
scripts/docker/launch_client.sh -i 10 -c 1000 -o
|
||||
scripts/docker/launch_client.sh -i 5 -c 800 -o
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
.PHONY: all
|
||||
all: kaldibackend
|
||||
|
||||
kaldibackend: kaldi-backend.cc kaldi-backend-utils.cc
|
||||
g++ -fpic -shared -std=c++11 -o libkaldi-trtisbackend.so kaldi-backend.cc kaldi-backend-utils.cc -Icustom-backend-sdk/include custom-backend-sdk/lib/libcustombackend.a -I/opt/kaldi/src/ -I/usr/local/cuda/include -I/opt/kaldi/tools/openfst/include/ -L/opt/kaldi/src/lib/ -lkaldi-cudadecoder
|
|
@ -1,155 +0,0 @@
|
|||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "kaldi-backend-utils.h"
|
||||
|
||||
namespace nvidia {
|
||||
namespace inferenceserver {
|
||||
namespace custom {
|
||||
namespace kaldi_cbe {
|
||||
|
||||
int GetInputTensor(CustomGetNextInputFn_t input_fn, void* input_context,
|
||||
const char* name, const size_t expected_byte_size,
|
||||
std::vector<uint8_t>* input, const void** out) {
|
||||
input->clear(); // reset buffer
|
||||
// The values for an input tensor are not necessarily in one
|
||||
// contiguous chunk, so we might copy the chunks into 'input' vector.
|
||||
// If possible, we use the data in place
|
||||
uint64_t total_content_byte_size = 0;
|
||||
while (true) {
|
||||
const void* content;
|
||||
uint64_t content_byte_size = expected_byte_size - total_content_byte_size;
|
||||
if (!input_fn(input_context, name, &content, &content_byte_size)) {
|
||||
return kInputContents;
|
||||
}
|
||||
|
||||
// If 'content' returns nullptr we have all the input.
|
||||
if (content == nullptr) break;
|
||||
|
||||
// If the total amount of content received exceeds what we expect
|
||||
// then something is wrong.
|
||||
total_content_byte_size += content_byte_size;
|
||||
if (total_content_byte_size > expected_byte_size)
|
||||
return kInputSize;
|
||||
|
||||
if (content_byte_size == expected_byte_size) {
|
||||
*out = content;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
input->insert(input->end(), static_cast<const uint8_t*>(content),
|
||||
static_cast<const uint8_t*>(content) + content_byte_size);
|
||||
}
|
||||
|
||||
// Make sure we end up with exactly the amount of input we expect.
|
||||
if (total_content_byte_size != expected_byte_size) {
|
||||
return kInputSize;
|
||||
}
|
||||
*out = &input[0];
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
void LatticeToString(fst::SymbolTable& word_syms,
|
||||
const kaldi::CompactLattice& dlat, std::string* out_str) {
|
||||
kaldi::CompactLattice best_path_clat;
|
||||
kaldi::CompactLatticeShortestPath(dlat, &best_path_clat);
|
||||
|
||||
kaldi::Lattice best_path_lat;
|
||||
fst::ConvertLattice(best_path_clat, &best_path_lat);
|
||||
|
||||
std::vector<int32> alignment;
|
||||
std::vector<int32> words;
|
||||
kaldi::LatticeWeight weight;
|
||||
fst::GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
|
||||
std::ostringstream oss;
|
||||
for (size_t i = 0; i < words.size(); i++) {
|
||||
std::string s = word_syms.Find(words[i]);
|
||||
if (s == "") std::cerr << "Word-id " << words[i] << " not in symbol table.";
|
||||
oss << s << " ";
|
||||
}
|
||||
*out_str = std::move(oss.str());
|
||||
}
|
||||
|
||||
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
|
||||
std::string* param) {
|
||||
auto it = model_config_.parameters().find(key);
|
||||
if (it == model_config_.parameters().end()) {
|
||||
std::cerr << "Parameter \"" << key
|
||||
<< "\" missing from config file. Exiting." << std::endl;
|
||||
return kInvalidModelConfig;
|
||||
}
|
||||
*param = it->second.string_value();
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
|
||||
int* param) {
|
||||
std::string tmp;
|
||||
int err = ReadParameter(model_config_, key, &tmp);
|
||||
*param = std::stoi(tmp);
|
||||
return err;
|
||||
}
|
||||
|
||||
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
|
||||
float* param) {
|
||||
std::string tmp;
|
||||
int err = ReadParameter(model_config_, key, &tmp);
|
||||
*param = std::stof(tmp);
|
||||
return err;
|
||||
}
|
||||
|
||||
const char* CustomErrorString(int errcode) {
|
||||
switch (errcode) {
|
||||
case kSuccess:
|
||||
return "success";
|
||||
case kInvalidModelConfig:
|
||||
return "invalid model configuration";
|
||||
case kGpuNotSupported:
|
||||
return "execution on GPU not supported";
|
||||
case kSequenceBatcher:
|
||||
return "model configuration must configure sequence batcher";
|
||||
case kModelControl:
|
||||
return "'START' and 'READY' must be configured as the control inputs";
|
||||
case kInputOutput:
|
||||
return "model must have four inputs and one output with shape [-1]";
|
||||
case kInputName:
|
||||
return "names for input don't exist";
|
||||
case kOutputName:
|
||||
return "model output must be named 'OUTPUT'";
|
||||
case kInputOutputDataType:
|
||||
return "model inputs or outputs data_type cannot be specified";
|
||||
case kInputContents:
|
||||
return "unable to get input tensor values";
|
||||
case kInputSize:
|
||||
return "unexpected size for input tensor";
|
||||
case kOutputBuffer:
|
||||
return "unable to get buffer for output tensor values";
|
||||
case kBatchTooBig:
|
||||
return "unable to execute batch larger than max-batch-size";
|
||||
case kTimesteps:
|
||||
return "unable to execute more than 1 timestep at a time";
|
||||
case kChunkTooBig:
|
||||
return "a chunk cannot contain more samples than the WAV_DATA dimension";
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return "unknown error";
|
||||
}
|
||||
|
||||
} // kaldi
|
||||
} // custom
|
||||
} // inferenceserver
|
||||
} // nvidia
|
|
@ -1,66 +0,0 @@
|
|||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "lat/lattice-functions.h"
|
||||
#include "src/core/model_config.h"
|
||||
#include "src/core/model_config.pb.h"
|
||||
#include "src/custom/sdk/custom_instance.h"
|
||||
|
||||
namespace nvidia {
|
||||
namespace inferenceserver {
|
||||
namespace custom {
|
||||
namespace kaldi_cbe {
|
||||
|
||||
enum ErrorCodes {
|
||||
kSuccess,
|
||||
kUnknown,
|
||||
kInvalidModelConfig,
|
||||
kGpuNotSupported,
|
||||
kSequenceBatcher,
|
||||
kModelControl,
|
||||
kInputOutput,
|
||||
kInputName,
|
||||
kOutputName,
|
||||
kInputOutputDataType,
|
||||
kInputContents,
|
||||
kInputSize,
|
||||
kOutputBuffer,
|
||||
kBatchTooBig,
|
||||
kTimesteps,
|
||||
kChunkTooBig
|
||||
};
|
||||
|
||||
int GetInputTensor(CustomGetNextInputFn_t input_fn, void* input_context,
|
||||
const char* name, const size_t expected_byte_size,
|
||||
std::vector<uint8_t>* input, const void** out);
|
||||
|
||||
void LatticeToString(fst::SymbolTable& word_syms,
|
||||
const kaldi::CompactLattice& dlat, std::string* out_str);
|
||||
|
||||
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
|
||||
std::string* param);
|
||||
|
||||
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
|
||||
int* param);
|
||||
int ReadParameter(const ModelConfig& model_config_, const std::string& key,
|
||||
float* param);
|
||||
|
||||
const char* CustomErrorString(int errcode);
|
||||
|
||||
} // kaldi
|
||||
} // custom
|
||||
} // inferenceserver
|
||||
} // nvidia
|
|
@ -1,423 +0,0 @@
|
|||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "kaldi-backend.h"
|
||||
#include "kaldi-backend-utils.h"
|
||||
|
||||
namespace nvidia {
|
||||
namespace inferenceserver {
|
||||
namespace custom {
|
||||
namespace kaldi_cbe {
|
||||
|
||||
Context::Context(const std::string& instance_name,
|
||||
const ModelConfig& model_config, const int gpu_device)
|
||||
: instance_name_(instance_name),
|
||||
model_config_(model_config),
|
||||
gpu_device_(gpu_device),
|
||||
num_channels_(
|
||||
model_config_
|
||||
.max_batch_size()), // diff in def between kaldi and trtis
|
||||
int32_byte_size_(GetDataTypeByteSize(TYPE_INT32)),
|
||||
int64_byte_size_(GetDataTypeByteSize(TYPE_INT64)) {}
|
||||
|
||||
Context::~Context() { delete word_syms_; }
|
||||
|
||||
int Context::ReadModelParameters() {
|
||||
// Reading config
|
||||
float beam, lattice_beam;
|
||||
int max_active;
|
||||
int frame_subsampling_factor;
|
||||
float acoustic_scale;
|
||||
int num_worker_threads;
|
||||
int err =
|
||||
ReadParameter(model_config_, "mfcc_filename",
|
||||
&batched_decoder_config_.feature_opts.mfcc_config) ||
|
||||
ReadParameter(
|
||||
model_config_, "ivector_filename",
|
||||
&batched_decoder_config_.feature_opts.ivector_extraction_config) ||
|
||||
ReadParameter(model_config_, "beam", &beam) ||
|
||||
ReadParameter(model_config_, "lattice_beam", &lattice_beam) ||
|
||||
ReadParameter(model_config_, "max_active", &max_active) ||
|
||||
ReadParameter(model_config_, "frame_subsampling_factor",
|
||||
&frame_subsampling_factor) ||
|
||||
ReadParameter(model_config_, "acoustic_scale", &acoustic_scale) ||
|
||||
ReadParameter(model_config_, "nnet3_rxfilename", &nnet3_rxfilename_) ||
|
||||
ReadParameter(model_config_, "fst_rxfilename", &fst_rxfilename_) ||
|
||||
ReadParameter(model_config_, "word_syms_rxfilename",
|
||||
&word_syms_rxfilename_) ||
|
||||
ReadParameter(model_config_, "num_worker_threads", &num_worker_threads) ||
|
||||
ReadParameter(model_config_, "max_execution_batch_size",
|
||||
&max_batch_size_);
|
||||
if (err) return err;
|
||||
max_batch_size_ = std::max<int>(max_batch_size_, 1);
|
||||
num_channels_ = std::max<int>(num_channels_, 1);
|
||||
|
||||
// Sanity checks
|
||||
if (beam <= 0) return kInvalidModelConfig;
|
||||
if (lattice_beam <= 0) return kInvalidModelConfig;
|
||||
if (max_active <= 0) return kInvalidModelConfig;
|
||||
if (acoustic_scale <= 0) return kInvalidModelConfig;
|
||||
if (num_worker_threads <= 0) return kInvalidModelConfig;
|
||||
if (num_channels_ <= max_batch_size_) return kInvalidModelConfig;
|
||||
|
||||
batched_decoder_config_.compute_opts.frame_subsampling_factor =
|
||||
frame_subsampling_factor;
|
||||
batched_decoder_config_.compute_opts.acoustic_scale = acoustic_scale;
|
||||
batched_decoder_config_.decoder_opts.default_beam = beam;
|
||||
batched_decoder_config_.decoder_opts.lattice_beam = lattice_beam;
|
||||
batched_decoder_config_.decoder_opts.max_active = max_active;
|
||||
batched_decoder_config_.num_worker_threads = num_worker_threads;
|
||||
batched_decoder_config_.max_batch_size = max_batch_size_;
|
||||
batched_decoder_config_.num_channels = num_channels_;
|
||||
|
||||
auto feature_config = batched_decoder_config_.feature_opts;
|
||||
kaldi::OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
|
||||
sample_freq_ = feature_info.mfcc_opts.frame_opts.samp_freq;
|
||||
BaseFloat frame_shift = feature_info.FrameShiftInSeconds();
|
||||
seconds_per_chunk_ = chunk_num_samps_ / sample_freq_;
|
||||
|
||||
int samp_per_frame = static_cast<int>(sample_freq_ * frame_shift);
|
||||
float n_input_framesf = chunk_num_samps_ / samp_per_frame;
|
||||
bool is_integer = (n_input_framesf == std::floor(n_input_framesf));
|
||||
if (!is_integer) {
|
||||
std::cerr << "WAVE_DATA dim must be a multiple fo samples per frame ("
|
||||
<< samp_per_frame << ")" << std::endl;
|
||||
return kInvalidModelConfig;
|
||||
}
|
||||
int n_input_frames = static_cast<int>(std::floor(n_input_framesf));
|
||||
batched_decoder_config_.compute_opts.frames_per_chunk = n_input_frames;
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
int Context::InitializeKaldiPipeline() {
|
||||
batch_corr_ids_.reserve(max_batch_size_);
|
||||
batch_wave_samples_.reserve(max_batch_size_);
|
||||
batch_is_first_chunk_.reserve(max_batch_size_);
|
||||
batch_is_last_chunk_.reserve(max_batch_size_);
|
||||
wave_byte_buffers_.resize(max_batch_size_);
|
||||
output_shape_ = {1, 1};
|
||||
kaldi::CuDevice::Instantiate()
|
||||
.SelectAndInitializeGpuIdWithExistingCudaContext(gpu_device_);
|
||||
kaldi::CuDevice::Instantiate().AllowMultithreading();
|
||||
|
||||
// Loading models
|
||||
{
|
||||
bool binary;
|
||||
kaldi::Input ki(nnet3_rxfilename_, &binary);
|
||||
trans_model_.Read(ki.Stream(), binary);
|
||||
am_nnet_.Read(ki.Stream(), binary);
|
||||
|
||||
kaldi::nnet3::SetBatchnormTestMode(true, &(am_nnet_.GetNnet()));
|
||||
kaldi::nnet3::SetDropoutTestMode(true, &(am_nnet_.GetNnet()));
|
||||
kaldi::nnet3::CollapseModel(kaldi::nnet3::CollapseModelConfig(),
|
||||
&(am_nnet_.GetNnet()));
|
||||
}
|
||||
fst::Fst<fst::StdArc>* decode_fst = fst::ReadFstKaldiGeneric(fst_rxfilename_);
|
||||
cuda_pipeline_.reset(
|
||||
new kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline(
|
||||
batched_decoder_config_, *decode_fst, am_nnet_, trans_model_));
|
||||
delete decode_fst;
|
||||
|
||||
// Loading word syms for text output
|
||||
if (word_syms_rxfilename_ != "") {
|
||||
if (!(word_syms_ = fst::SymbolTable::ReadText(word_syms_rxfilename_))) {
|
||||
std::cerr << "Could not read symbol table from file "
|
||||
<< word_syms_rxfilename_;
|
||||
return kInvalidModelConfig;
|
||||
}
|
||||
}
|
||||
chunk_num_samps_ = cuda_pipeline_->GetNSampsPerChunk();
|
||||
chunk_num_bytes_ = chunk_num_samps_ * sizeof(BaseFloat);
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
int Context::Init() {
|
||||
return InputOutputSanityCheck() || ReadModelParameters() ||
|
||||
InitializeKaldiPipeline();
|
||||
}
|
||||
|
||||
bool Context::CheckPayloadError(const CustomPayload& payload) {
|
||||
int err = payload.error_code;
|
||||
if (err) std::cerr << "Error: " << CustomErrorString(err) << std::endl;
|
||||
return (err != 0);
|
||||
}
|
||||
|
||||
int Context::Execute(const uint32_t payload_cnt, CustomPayload* payloads,
|
||||
CustomGetNextInputFn_t input_fn,
|
||||
CustomGetOutputFn_t output_fn) {
|
||||
// kaldi::Timer timer;
|
||||
if (payload_cnt > num_channels_) return kBatchTooBig;
|
||||
// Each payload is a chunk for one sequence
|
||||
// Currently using dynamic batcher, not sequence batcher
|
||||
for (uint32_t pidx = 0; pidx < payload_cnt; ++pidx) {
|
||||
if (batch_corr_ids_.size() == max_batch_size_) FlushBatch();
|
||||
|
||||
CustomPayload& payload = payloads[pidx];
|
||||
if (payload.batch_size != 1) payload.error_code = kTimesteps;
|
||||
if (CheckPayloadError(payload)) continue;
|
||||
|
||||
// Get input tensors
|
||||
int32_t start, dim, end, ready;
|
||||
CorrelationID corr_id;
|
||||
const BaseFloat* wave_buffer;
|
||||
payload.error_code = GetSequenceInput(
|
||||
input_fn, payload.input_context, &corr_id, &start, &ready, &dim, &end,
|
||||
&wave_buffer, &wave_byte_buffers_[pidx]);
|
||||
if (CheckPayloadError(payload)) continue;
|
||||
if (!ready) continue;
|
||||
if (dim > chunk_num_samps_) payload.error_code = kChunkTooBig;
|
||||
if (CheckPayloadError(payload)) continue;
|
||||
|
||||
kaldi::SubVector<BaseFloat> wave_part(wave_buffer, dim);
|
||||
// Initialize corr_id if first chunk
|
||||
if (start) {
|
||||
if (!cuda_pipeline_->TryInitCorrID(corr_id)) {
|
||||
printf("ERR %i \n", __LINE__);
|
||||
// TODO add error code
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// Add to batch
|
||||
batch_corr_ids_.push_back(corr_id);
|
||||
batch_wave_samples_.push_back(wave_part);
|
||||
batch_is_first_chunk_.push_back(start);
|
||||
batch_is_last_chunk_.push_back(end);
|
||||
|
||||
if (end) {
|
||||
// If last chunk, set the callback for that seq
|
||||
cuda_pipeline_->SetLatticeCallback(
|
||||
corr_id, [this, &output_fn, &payloads, pidx,
|
||||
corr_id](kaldi::CompactLattice& clat) {
|
||||
SetOutputs(clat, output_fn, payloads[pidx]);
|
||||
});
|
||||
}
|
||||
}
|
||||
FlushBatch();
|
||||
cuda_pipeline_->WaitForLatticeCallbacks();
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
int Context::FlushBatch() {
|
||||
if (!batch_corr_ids_.empty()) {
|
||||
cuda_pipeline_->DecodeBatch(batch_corr_ids_, batch_wave_samples_,
|
||||
batch_is_first_chunk_, batch_is_last_chunk_);
|
||||
batch_corr_ids_.clear();
|
||||
batch_wave_samples_.clear();
|
||||
batch_is_first_chunk_.clear();
|
||||
batch_is_last_chunk_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
int Context::InputOutputSanityCheck() {
|
||||
if (!model_config_.has_sequence_batching()) {
|
||||
return kSequenceBatcher;
|
||||
}
|
||||
|
||||
auto& batcher = model_config_.sequence_batching();
|
||||
if (batcher.control_input_size() != 4) {
|
||||
return kModelControl;
|
||||
}
|
||||
|
||||
std::set<std::string> control_input_names;
|
||||
for (int i = 0; i < 4; ++i)
|
||||
control_input_names.insert(batcher.control_input(i).name());
|
||||
if (!(control_input_names.erase("START") &&
|
||||
control_input_names.erase("END") &&
|
||||
control_input_names.erase("CORRID") &&
|
||||
control_input_names.erase("READY"))) {
|
||||
return kModelControl;
|
||||
}
|
||||
|
||||
if (model_config_.input_size() != 2) {
|
||||
return kInputOutput;
|
||||
}
|
||||
if ((model_config_.input(0).dims().size() != 1) ||
|
||||
(model_config_.input(0).dims(0) <= 0) ||
|
||||
(model_config_.input(1).dims().size() != 1) ||
|
||||
(model_config_.input(1).dims(0) != 1)) {
|
||||
return kInputOutput;
|
||||
}
|
||||
chunk_num_samps_ = model_config_.input(0).dims(0);
|
||||
chunk_num_bytes_ = chunk_num_samps_ * sizeof(float);
|
||||
|
||||
if ((model_config_.input(0).data_type() != DataType::TYPE_FP32) ||
|
||||
(model_config_.input(1).data_type() != DataType::TYPE_INT32)) {
|
||||
return kInputOutputDataType;
|
||||
}
|
||||
if ((model_config_.input(0).name() != "WAV_DATA") ||
|
||||
(model_config_.input(1).name() != "WAV_DATA_DIM")) {
|
||||
return kInputName;
|
||||
}
|
||||
|
||||
if (model_config_.output_size() != 2) return kInputOutput;
|
||||
|
||||
for (int ioutput = 0; ioutput < 2; ++ioutput) {
|
||||
if ((model_config_.output(ioutput).dims().size() != 1) ||
|
||||
(model_config_.output(ioutput).dims(0) != 1)) {
|
||||
return kInputOutput;
|
||||
}
|
||||
if (model_config_.output(ioutput).data_type() != DataType::TYPE_STRING) {
|
||||
return kInputOutputDataType;
|
||||
}
|
||||
}
|
||||
|
||||
if (model_config_.output(0).name() != "RAW_LATTICE") return kOutputName;
|
||||
if (model_config_.output(1).name() != "TEXT") return kOutputName;
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
int Context::GetSequenceInput(CustomGetNextInputFn_t& input_fn,
|
||||
void* input_context, CorrelationID* corr_id,
|
||||
int32_t* start, int32_t* ready, int32_t* dim,
|
||||
int32_t* end, const BaseFloat** wave_buffer,
|
||||
std::vector<uint8_t>* input_buffer) {
|
||||
int err;
|
||||
//&input_buffer[0]: char pointer -> alias with any types
|
||||
// wave_data[0] will holds the struct
|
||||
|
||||
// Get start of sequence tensor
|
||||
const void* out;
|
||||
err = GetInputTensor(input_fn, input_context, "WAV_DATA_DIM",
|
||||
int32_byte_size_, &byte_buffer_, &out);
|
||||
if (err != kSuccess) return err;
|
||||
*dim = *reinterpret_cast<const int32_t*>(out);
|
||||
|
||||
err = GetInputTensor(input_fn, input_context, "END", int32_byte_size_,
|
||||
&byte_buffer_, &out);
|
||||
if (err != kSuccess) return err;
|
||||
*end = *reinterpret_cast<const int32_t*>(out);
|
||||
|
||||
err = GetInputTensor(input_fn, input_context, "START", int32_byte_size_,
|
||||
&byte_buffer_, &out);
|
||||
if (err != kSuccess) return err;
|
||||
*start = *reinterpret_cast<const int32_t*>(out);
|
||||
|
||||
err = GetInputTensor(input_fn, input_context, "READY", int32_byte_size_,
|
||||
&byte_buffer_, &out);
|
||||
if (err != kSuccess) return err;
|
||||
*ready = *reinterpret_cast<const int32_t*>(out);
|
||||
|
||||
err = GetInputTensor(input_fn, input_context, "CORRID", int64_byte_size_,
|
||||
&byte_buffer_, &out);
|
||||
if (err != kSuccess) return err;
|
||||
*corr_id = *reinterpret_cast<const CorrelationID*>(out);
|
||||
|
||||
// Get pointer to speech tensor
|
||||
err = GetInputTensor(input_fn, input_context, "WAV_DATA", chunk_num_bytes_,
|
||||
input_buffer, &out);
|
||||
if (err != kSuccess) return err;
|
||||
*wave_buffer = reinterpret_cast<const BaseFloat*>(out);
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
int Context::SetOutputs(kaldi::CompactLattice& clat,
|
||||
CustomGetOutputFn_t output_fn, CustomPayload payload) {
|
||||
int status = kSuccess;
|
||||
if (payload.error_code != kSuccess) return payload.error_code;
|
||||
for (int ioutput = 0; ioutput < payload.output_cnt; ++ioutput) {
|
||||
const char* output_name = payload.required_output_names[ioutput];
|
||||
if (!strcmp(output_name, "RAW_LATTICE")) {
|
||||
std::ostringstream oss;
|
||||
kaldi::WriteCompactLattice(oss, true, clat);
|
||||
status = SetOutputByName(output_name, oss.str(), output_fn, payload);
|
||||
if(status != kSuccess) return status;
|
||||
} else if (!strcmp(output_name, "TEXT")) {
|
||||
std::string output;
|
||||
LatticeToString(*word_syms_, clat, &output);
|
||||
status = SetOutputByName(output_name, output, output_fn, payload);
|
||||
if(status != kSuccess) return status;
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
int Context::SetOutputByName(const char* output_name,
|
||||
const std::string& out_bytes,
|
||||
CustomGetOutputFn_t output_fn,
|
||||
CustomPayload payload) {
|
||||
uint32_t byte_size_with_size_int = out_bytes.size() + sizeof(int32);
|
||||
void* obuffer; // output buffer
|
||||
if (!output_fn(payload.output_context, output_name, output_shape_.size(),
|
||||
&output_shape_[0], byte_size_with_size_int, &obuffer)) {
|
||||
payload.error_code = kOutputBuffer;
|
||||
return payload.error_code;
|
||||
}
|
||||
if (obuffer == nullptr) return kOutputBuffer;
|
||||
|
||||
int32* buffer_as_int = reinterpret_cast<int32*>(obuffer);
|
||||
buffer_as_int[0] = out_bytes.size();
|
||||
memcpy(&buffer_as_int[1], out_bytes.data(), out_bytes.size());
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
/////////////
|
||||
|
||||
extern "C" {
|
||||
|
||||
int CustomInitialize(const CustomInitializeData* data, void** custom_context) {
|
||||
// Convert the serialized model config to a ModelConfig object.
|
||||
ModelConfig model_config;
|
||||
if (!model_config.ParseFromString(std::string(
|
||||
data->serialized_model_config, data->serialized_model_config_size))) {
|
||||
return kInvalidModelConfig;
|
||||
}
|
||||
|
||||
// Create the context and validate that the model configuration is
|
||||
// something that we can handle.
|
||||
Context* context = new Context(std::string(data->instance_name), model_config,
|
||||
data->gpu_device_id);
|
||||
int err = context->Init();
|
||||
if (err != kSuccess) {
|
||||
return err;
|
||||
}
|
||||
|
||||
*custom_context = static_cast<void*>(context);
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
int CustomFinalize(void* custom_context) {
|
||||
if (custom_context != nullptr) {
|
||||
Context* context = static_cast<Context*>(custom_context);
|
||||
delete context;
|
||||
}
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
const char* CustomErrorString(void* custom_context, int errcode) {
|
||||
return CustomErrorString(errcode);
|
||||
}
|
||||
|
||||
int CustomExecute(void* custom_context, const uint32_t payload_cnt,
|
||||
CustomPayload* payloads, CustomGetNextInputFn_t input_fn,
|
||||
CustomGetOutputFn_t output_fn) {
|
||||
if (custom_context == nullptr) {
|
||||
return kUnknown;
|
||||
}
|
||||
|
||||
Context* context = static_cast<Context*>(custom_context);
|
||||
return context->Execute(payload_cnt, payloads, input_fn, output_fn);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
} // namespace kaldi_cbe
|
||||
} // namespace custom
|
||||
} // namespace inferenceserver
|
||||
} // namespace nvidia
|
|
@ -1,126 +0,0 @@
|
|||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#define HAVE_CUDA 1 // Loading Kaldi headers with GPU
|
||||
|
||||
#include <cfloat>
|
||||
#include <sstream>
|
||||
#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h"
|
||||
#include "fstext/fstext-lib.h"
|
||||
#include "lat/kaldi-lattice.h"
|
||||
#include "lat/lattice-functions.h"
|
||||
#include "nnet3/am-nnet-simple.h"
|
||||
#include "nnet3/nnet-utils.h"
|
||||
#include "util/kaldi-thread.h"
|
||||
|
||||
#include "src/core/model_config.h"
|
||||
#include "src/core/model_config.pb.h"
|
||||
#include "src/custom/sdk/custom_instance.h"
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
|
||||
namespace nvidia {
|
||||
namespace inferenceserver {
|
||||
namespace custom {
|
||||
namespace kaldi_cbe {
|
||||
|
||||
// Context object. All state must be kept in this object.
|
||||
class Context {
|
||||
public:
|
||||
Context(const std::string& instance_name, const ModelConfig& config,
|
||||
const int gpu_device);
|
||||
virtual ~Context();
|
||||
|
||||
// Initialize the context. Validate that the model configuration,
|
||||
// etc. is something that we can handle.
|
||||
int Init();
|
||||
|
||||
// Perform custom execution on the payloads.
|
||||
int Execute(const uint32_t payload_cnt, CustomPayload* payloads,
|
||||
CustomGetNextInputFn_t input_fn, CustomGetOutputFn_t output_fn);
|
||||
|
||||
private:
|
||||
// init kaldi pipeline
|
||||
int InitializeKaldiPipeline();
|
||||
int InputOutputSanityCheck();
|
||||
int ReadModelParameters();
|
||||
int GetSequenceInput(CustomGetNextInputFn_t& input_fn, void* input_context,
|
||||
CorrelationID* corr_id, int32_t* start, int32_t* ready,
|
||||
int32_t* dim, int32_t* end,
|
||||
const kaldi::BaseFloat** wave_buffer,
|
||||
std::vector<uint8_t>* input_buffer);
|
||||
|
||||
int SetOutputs(kaldi::CompactLattice& clat,
|
||||
CustomGetOutputFn_t output_fn, CustomPayload payload);
|
||||
|
||||
int SetOutputByName(const char* output_name,
|
||||
const std::string& out_bytes,
|
||||
CustomGetOutputFn_t output_fn,
|
||||
CustomPayload payload);
|
||||
|
||||
bool CheckPayloadError(const CustomPayload& payload);
|
||||
int FlushBatch();
|
||||
|
||||
// The name of this instance of the backend.
|
||||
const std::string instance_name_;
|
||||
|
||||
// The model configuration.
|
||||
const ModelConfig model_config_;
|
||||
|
||||
// The GPU device ID to execute on or CUSTOM_NO_GPU_DEVICE if should
|
||||
// execute on CPU.
|
||||
const int gpu_device_;
|
||||
|
||||
// Models paths
|
||||
std::string nnet3_rxfilename_, fst_rxfilename_;
|
||||
std::string word_syms_rxfilename_;
|
||||
|
||||
// batch_size
|
||||
int max_batch_size_;
|
||||
int num_channels_;
|
||||
int num_worker_threads_;
|
||||
std::vector<CorrelationID> batch_corr_ids_;
|
||||
std::vector<kaldi::SubVector<kaldi::BaseFloat>> batch_wave_samples_;
|
||||
std::vector<bool> batch_is_first_chunk_;
|
||||
std::vector<bool> batch_is_last_chunk_;
|
||||
|
||||
BaseFloat sample_freq_, seconds_per_chunk_;
|
||||
int chunk_num_bytes_, chunk_num_samps_;
|
||||
|
||||
// feature_config includes configuration for the iVector adaptation,
|
||||
// as well as the basic features.
|
||||
kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipelineConfig
|
||||
batched_decoder_config_;
|
||||
std::unique_ptr<kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline>
|
||||
cuda_pipeline_;
|
||||
// Maintain the state of some shared objects
|
||||
kaldi::TransitionModel trans_model_;
|
||||
|
||||
kaldi::nnet3::AmNnetSimple am_nnet_;
|
||||
fst::SymbolTable* word_syms_;
|
||||
|
||||
const uint64_t int32_byte_size_;
|
||||
const uint64_t int64_byte_size_;
|
||||
std::vector<int64_t> output_shape_;
|
||||
|
||||
std::vector<uint8_t> byte_buffer_;
|
||||
std::vector<std::vector<uint8_t>> wave_byte_buffers_;
|
||||
};
|
||||
|
||||
} // namespace kaldi_cbe
|
||||
} // namespace custom
|
||||
} // namespace inferenceserver
|
||||
} // namespace nvidia
|
Loading…
Reference in New Issue