From 1a5c7556b5b351d4c9930a92141e5328a9077b44 Mon Sep 17 00:00:00 2001 From: Alec Kohlhoff Date: Tue, 12 Oct 2021 17:45:31 -0700 Subject: [PATCH] [Kaldi] Update to 21.08 --- Kaldi/SpeechRecognition/.gitignore | 1 + Kaldi/SpeechRecognition/Dockerfile | 110 +- Kaldi/SpeechRecognition/Dockerfile.client | 99 +- Kaldi/SpeechRecognition/Dockerfile.notebook | 2 +- Kaldi/SpeechRecognition/README.md | 50 +- Kaldi/SpeechRecognition/data/README.md | 0 .../kaldi-asr-backend/CMakeLists.txt | 132 ++ .../kaldi-asr-backend/kaldi-backend-utils.cc | 142 ++ .../kaldi-asr-backend/kaldi-backend-utils.h | 57 + .../kaldi-asr-backend/triton-kaldi-backend.cc | 1187 +++++++++++++++++ .../kaldi-asr-client/CMakeLists.txt | 126 +- .../kaldi-asr-client/asr_client_imp.cc | 291 ++-- .../kaldi-asr-client/asr_client_imp.h | 47 +- .../kaldi_asr_parallel_client.cc | 166 ++- .../model-repo/kaldi_online/config.pbtxt | 193 +-- .../SpeechRecognition/scripts/compute_wer.sh | 4 +- .../SpeechRecognition/scripts/docker/build.sh | 9 +- .../scripts/docker/launch_client.sh | 3 +- .../scripts/docker/launch_download.sh | 8 +- .../scripts/docker/launch_server.sh | 10 +- .../scripts/docker/run_client.sh | 2 +- ...t.sh => nvidia_kaldi_triton_entrypoint.sh} | 2 +- .../run_inference_all_a100.sh} | 26 +- .../scripts/run_inference_all_t4.sh | 4 +- .../scripts/run_inference_all_v100.sh | 7 +- .../trtis-kaldi-backend/Makefile | 5 - .../kaldi-backend-utils.cc | 155 --- .../trtis-kaldi-backend/kaldi-backend-utils.h | 66 - .../trtis-kaldi-backend/kaldi-backend.cc | 423 ------ .../trtis-kaldi-backend/kaldi-backend.h | 126 -- 30 files changed, 2258 insertions(+), 1195 deletions(-) delete mode 100644 Kaldi/SpeechRecognition/data/README.md create mode 100644 Kaldi/SpeechRecognition/kaldi-asr-backend/CMakeLists.txt create mode 100644 Kaldi/SpeechRecognition/kaldi-asr-backend/kaldi-backend-utils.cc create mode 100644 Kaldi/SpeechRecognition/kaldi-asr-backend/kaldi-backend-utils.h create mode 100644 Kaldi/SpeechRecognition/kaldi-asr-backend/triton-kaldi-backend.cc rename Kaldi/SpeechRecognition/scripts/{nvidia_kaldi_trtis_entrypoint.sh => nvidia_kaldi_triton_entrypoint.sh} (94%) rename Kaldi/SpeechRecognition/{trtis-kaldi-backend/libkaldi_online.ldscript => scripts/run_inference_all_a100.sh} (55%) mode change 100644 => 100755 delete mode 100644 Kaldi/SpeechRecognition/trtis-kaldi-backend/Makefile delete mode 100644 Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend-utils.cc delete mode 100644 Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend-utils.h delete mode 100644 Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.cc delete mode 100644 Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.h diff --git a/Kaldi/SpeechRecognition/.gitignore b/Kaldi/SpeechRecognition/.gitignore index 10b251fb..ace33d56 100644 --- a/Kaldi/SpeechRecognition/.gitignore +++ b/Kaldi/SpeechRecognition/.gitignore @@ -2,3 +2,4 @@ data/* !data/README.md .*.swp .*.swo +.clang-format diff --git a/Kaldi/SpeechRecognition/Dockerfile b/Kaldi/SpeechRecognition/Dockerfile index 6723cace..8b60e0f4 100644 --- a/Kaldi/SpeechRecognition/Dockerfile +++ b/Kaldi/SpeechRecognition/Dockerfile @@ -1,4 +1,4 @@ -# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,12 +11,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM nvcr.io/nvidia/kaldi:20.03-py3 as kb -FROM nvcr.io/nvidia/tritonserver:20.03-py3 -ENV DEBIAN_FRONTEND=noninteractive +ARG TRITONSERVER_IMAGE=nvcr.io/nvidia/tritonserver:21.05-py3 +ARG KALDI_IMAGE=nvcr.io/nvidia/kaldi:21.08-py3 +ARG PYTHON_VER=3.8 + +# +# Kaldi shared library dependencies +# +FROM ${KALDI_IMAGE} as kaldi + +# +# Builder image based on Triton Server SDK image +# +FROM ${TRITONSERVER_IMAGE}-sdk as builder +ARG PYTHON_VER # Kaldi dependencies -RUN apt-get update && apt-get install -y --no-install-recommends \ +RUN set -eux; \ + apt-get update; \ + apt-get install -yq --no-install-recommends \ automake \ autoconf \ cmake \ @@ -24,29 +37,80 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ gawk \ libatlas3-base \ libtool \ - python3.6 \ - python3.6-dev \ + python${PYTHON_VER} \ + python${PYTHON_VER}-dev \ sox \ subversion \ unzip \ bc \ libatlas-base-dev \ - zlib1g-dev + gfortran \ + zlib1g-dev; \ + rm -rf /var/lib/apt/lists/* -RUN mkdir /opt/trtis-kaldi && mkdir -p /workspace/model-repo/kaldi_online/1 && mkdir -p /mnt/model-repo -# Copying static files +# Add Kaldi dependency +COPY --from=kaldi /opt/kaldi /opt/kaldi + +# Set up Atlas +RUN set -eux; \ + ln -sf /usr/include/x86_64-linux-gnu/atlas /usr/local/include/atlas; \ + ln -sf /usr/include/x86_64-linux-gnu/cblas.h /usr/local/include/cblas.h; \ + ln -sf /usr/include/x86_64-linux-gnu/clapack.h /usr/local/include/clapack.h; \ + ln -sf /usr/lib/x86_64-linux-gnu/atlas /usr/local/lib/atlas + + +# +# Kaldi custom backend build +# +FROM builder as backend-build + +# Build the custom backend +COPY kaldi-asr-backend /workspace/triton-kaldi-backend +RUN set -eux; \ + cd /workspace/triton-kaldi-backend; \ + cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX="$(pwd)/install" \ + -B build .; \ + cmake --build build --parallel; \ + cmake --install build + + +# +# Final server image +# +FROM ${TRITONSERVER_IMAGE} +ARG PYTHON_VER + +# Kaldi dependencies +RUN set -eux; \ + apt-get update; \ + apt-get install -yq --no-install-recommends \ + automake \ + autoconf \ + cmake \ + flac \ + gawk \ + libatlas3-base \ + libtool \ + python${PYTHON_VER} \ + python${PYTHON_VER}-dev \ + sox \ + subversion \ + unzip \ + bc \ + libatlas-base-dev \ + zlib1g-dev; \ + rm -rf /var/lib/apt/lists/* + +# Add Kaldi dependency +COPY --from=kaldi /opt/kaldi /opt/kaldi + +# Add Kaldi custom backend shared library and scripts +COPY --from=backend-build /workspace/triton-kaldi-backend/install/backends/kaldi/libtriton_kaldi.so /workspace/model-repo/kaldi_online/1/ COPY scripts /workspace/scripts -# Moving Kaldi to container -COPY --from=kb /opt/kaldi /opt/kaldi -ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:$LD_LIBRARY_PATH - -# Building the custom backend -COPY trtis-kaldi-backend /workspace/trtis-kaldi-backend -#COPY --from=cbe /workspace/install/custom-backend-sdk /workspace/trtis-kaldi-backend/custom-backend-sdk -RUN cd /workspace/trtis-kaldi-backend && wget https://github.com/NVIDIA/tensorrt-inference-server/releases/download/v1.9.0/v1.9.0_ubuntu1804.custombackend.tar.gz -O custom-backend-sdk.tar.gz && tar -xzf custom-backend-sdk.tar.gz -RUN cd /workspace/trtis-kaldi-backend/ && make && cp libkaldi-trtisbackend.so /workspace/model-repo/kaldi_online/1/ && cd - && rm -r /workspace/trtis-kaldi-backend - -COPY scripts/nvidia_kaldi_trtis_entrypoint.sh /opt/trtis-kaldi - -ENTRYPOINT ["/opt/trtis-kaldi/nvidia_kaldi_trtis_entrypoint.sh"] +# Setup entrypoint and environment +ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:/opt/tritonserver/lib:$LD_LIBRARY_PATH +COPY scripts/nvidia_kaldi_triton_entrypoint.sh /opt/triton-kaldi/ +VOLUME /mnt/model-repo +ENTRYPOINT ["/opt/triton-kaldi/nvidia_kaldi_triton_entrypoint.sh"] +CMD ["tritonserver", "--model-repo=/workspace/model-repo"] diff --git a/Kaldi/SpeechRecognition/Dockerfile.client b/Kaldi/SpeechRecognition/Dockerfile.client index dc22efc0..5f23f3ca 100644 --- a/Kaldi/SpeechRecognition/Dockerfile.client +++ b/Kaldi/SpeechRecognition/Dockerfile.client @@ -1,4 +1,4 @@ -# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,11 +11,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM nvcr.io/nvidia/kaldi:20.03-py3 as kb -FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk +ARG TRITONSERVER_IMAGE=nvcr.io/nvidia/tritonserver:21.05-py3 +ARG KALDI_IMAGE=nvcr.io/nvidia/kaldi:21.08-py3 +ARG PYTHON_VER=3.8 + + +# +# Kaldi shared library dependencies +# +FROM ${KALDI_IMAGE} as kaldi + + +# +# Builder image based on Triton Server SDK image +# +FROM ${TRITONSERVER_IMAGE}-sdk as builder +ARG PYTHON_VER # Kaldi dependencies -RUN apt-get update && apt-get install -y --no-install-recommends \ +RUN set -eux; \ + apt-get update; \ + apt-get install -yq --no-install-recommends \ automake \ autoconf \ cmake \ @@ -23,21 +39,78 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ gawk \ libatlas3-base \ libtool \ - python3.6 \ - python3.6-dev \ + python${PYTHON_VER} \ + python${PYTHON_VER}-dev \ sox \ subversion \ unzip \ bc \ libatlas-base-dev \ - zlib1g-dev + gfortran \ + zlib1g-dev; \ + rm -rf /var/lib/apt/lists/* -# Moving Kaldi to container -COPY --from=kb /opt/kaldi /opt/kaldi -ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:$LD_LIBRARY_PATH +# Add Kaldi dependency +COPY --from=kaldi /opt/kaldi /opt/kaldi +# Set up Atlas +RUN set -eux; \ + ln -sf /usr/include/x86_64-linux-gnu/atlas /usr/local/include/atlas; \ + ln -sf /usr/include/x86_64-linux-gnu/cblas.h /usr/local/include/cblas.h; \ + ln -sf /usr/include/x86_64-linux-gnu/clapack.h /usr/local/include/clapack.h; \ + ln -sf /usr/lib/x86_64-linux-gnu/atlas /usr/local/lib/atlas + + +# +# Triton Kaldi client build +# +FROM builder as client-build + +# Build the clients +COPY kaldi-asr-client /workspace/triton-client +RUN set -eux; \ + cd /workspace; \ + echo 'add_subdirectory(../../../triton-client src/c++/triton-client)' \ + >> /workspace/client/src/c++/CMakeLists.txt; \ + cmake -DCMAKE_BUILD_TYPE=Release -B build client; \ + cmake --build build --parallel --target cc-clients + + +# +# Final gRPC client image +# +FROM ${TRITONSERVER_IMAGE} +ARG PYTHON_VER + +# Kaldi dependencies +RUN set -eux; \ + apt-get update; \ + apt-get install -yq --no-install-recommends \ + automake \ + autoconf \ + cmake \ + flac \ + gawk \ + libatlas3-base \ + libtool \ + python${PYTHON_VER} \ + python${PYTHON_VER}-dev \ + sox \ + subversion \ + unzip \ + bc \ + libatlas-base-dev \ + zlib1g-dev; \ + rm -rf /var/lib/apt/lists/* + +# Add Kaldi dependency +COPY --from=kaldi /opt/kaldi /opt/kaldi + +# Add Triton clients and scripts +COPY --from=client-build /workspace/build/cc-clients/src/c++/triton-client/kaldi-asr-parallel-client /usr/local/bin/ COPY scripts /workspace/scripts -COPY kaldi-asr-client /workspace/src/clients/c++/kaldi-asr-client -RUN echo "add_subdirectory(kaldi-asr-client)" >> "/workspace/src/clients/c++/CMakeLists.txt" -RUN cd /workspace/build/ && make -j16 trtis-clients +# Setup environment and entrypoint +ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:/opt/tritonserver/lib:$LD_LIBRARY_PATH +VOLUME /mnt/model-repo +ENTRYPOINT ["/usr/local/bin/kaldi-asr-parallel-client"] diff --git a/Kaldi/SpeechRecognition/Dockerfile.notebook b/Kaldi/SpeechRecognition/Dockerfile.notebook index 94271290..1d6825ff 100644 --- a/Kaldi/SpeechRecognition/Dockerfile.notebook +++ b/Kaldi/SpeechRecognition/Dockerfile.notebook @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk +FROM nvcr.io/nvidia/tritonserver:21.05-py3-sdk # Kaldi dependencies RUN apt-get update && apt-get install -y jupyter \ diff --git a/Kaldi/SpeechRecognition/README.md b/Kaldi/SpeechRecognition/README.md index f3a22646..6f959259 100644 --- a/Kaldi/SpeechRecognition/README.md +++ b/Kaldi/SpeechRecognition/README.md @@ -46,15 +46,18 @@ A reference model is used by all test scripts and benchmarks presented in this r Details about parameters can be found in the [Parameters](#parameters) section. * `model path`: Configured to use the pretrained LibriSpeech model. +* `use_tensor_cores`: 1 +* `main_q_capacity`: 30000 +* `aux_q_capacity`: 400000 * `beam`: 10 +* `num_channels`: 4000 * `lattice_beam`: 7 * `max_active`: 10,000 * `frame_subsampling_factor`: 3 * `acoustic_scale`: 1.0 -* `num_worker_threads`: 20 -* `max_execution_batch_size`: 256 -* `max_batch_size`: 4096 -* `instance_group.count`: 2 +* `num_worker_threads`: 40 +* `max_batch_size`: 400 +* `instance_group.count`: 1 ## Setup @@ -134,9 +137,8 @@ The model configuration parameters are passed to the model and have an impact o The inference engine configuration parameters configure the inference engine. They impact performance, but not accuracy. -* `max_batch_size`: The maximum number of inference channels opened at a given time. If set to `4096`, then one instance will handle at most 4096 concurrent users. +* `max_batch_size`: The size of one execution batch on the GPU. This parameter should be set as large as necessary to saturate the GPU, but not bigger. Larger batches will lead to a higher throughput, smaller batches to lower latency. * `num_worker_threads`: The number of CPU threads for the postprocessing CPU tasks, such as lattice determinization and text generation from the lattice. -* `max_execution_batch_size`: The size of one execution batch on the GPU. This parameter should be set as large as necessary to saturate the GPU, but not bigger. Larger batches will lead to a higher throughput, smaller batches to lower latency. * `input.WAV_DATA.dims`: The maximum number of samples per chunk. The value must be a multiple of `frame_subsampling_factor * chunks_per_frame`. ### Inference process @@ -156,7 +158,7 @@ The client can be configured through a set of parameters that define its behavio -u -o : Only feed each channel at realtime speed. Simulates online clients. -p : Print text outputs - + -b : Print partial (best path) text outputs ``` ### Input/Output @@ -187,13 +189,8 @@ Even if only the best path is used, we are still generating a full lattice for b Support for Kaldi ASR models that are different from the provided LibriSpeech model is experimental. However, it is possible to modify the [Model Path](#model-path) section of the config file `model-repo/kaldi_online/config.pbtxt` to set up your own model. -The models and Kaldi allocators are currently not shared between instances. This means that if your model is large, you may end up with not enough memory on the GPU to store two different instances. If that's the case, -you can set `count` to `1` in the [`instance_group` section](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/model_configuration.html#instance-groups) of the config file. - ## Performance -The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference). - ### Metrics @@ -207,8 +204,7 @@ Latency is defined as the delay between the availability of the last chunk of au 4. *Server:* Compute inference of last chunk 5. *Server:* Generate the raw lattice for the full utterance 6. *Server:* Determinize the raw lattice -7. *Server:* Generate the text output associated with the best path in the determinized lattice -8. *Client:* Receive text output +8. *Client:* Receive lattice output 9. *Client:* Call callback with output 10. ***t1** <- Current time* @@ -219,20 +215,18 @@ The latency is defined such as `latency = t1 - t0`. Our results were obtained by: 1. Building and starting the server as described in [Quick Start Guide](#quick-start-guide). -2. Running `scripts/run_inference_all_v100.sh` and `scripts/run_inference_all_t4.sh` +2. Running `scripts/run_inference_all_a100.sh`, `scripts/run_inference_all_v100.sh` and `scripts/run_inference_all_t4.sh` -| GPU | Realtime I/O | Number of parallel audio channels | Throughput (RTFX) | Latency | | | | -| ------ | ------ | ------ | ------ | ------ | ------ | ------ |------ | -| | | | | 90% | 95% | 99% | Avg | -| V100 | No | 2000 | 1506.5 | N/A | N/A | N/A | N/A | -| V100 | Yes | 1500 | 1243.2 | 0.582 | 0.699 | 1.04 | 0.400 | -| V100 | Yes | 1000 | 884.1 | 0.379 | 0.393 | 0.788 | 0.333 | -| V100 | Yes | 800 | 660.2 | 0.334 | 0.340 | 0.438 | 0.288 | -| T4 | No | 1000 | 675.2 | N/A | N/A | N/A| N/A | -| T4 | Yes | 700 | 629.2 | 0.945 | 1.08 | 1.27 | 0.645 | -| T4 | Yes | 400 | 373.7 | 0.579 | 0.624 | 0.758 | 0.452 | - +| GPU | Realtime I/O | Number of parallel audio channels | Latency (s) | | | | +| ----- | ------------ | --------------------------------- | ----------- | ----- | ----- | ----- | +| | | | 90% | 95% | 99% | Avg | +| A100 | Yes | 2000 | 0.11 | 0.12 | 0.14 | 0.09 | +| V100 | Yes | 2000 | 0.42 | 0.50 | 0.61 | 0.23 | +| V100 | Yes | 1000 | 0.09 | 0.09 | 0.11 | 0.07 | +| T4 | Yes | 600 | 0.17 | 0.18 | 0.22 | 0.14 | +| T4 | Yes | 400 | 0.12 | 0.13 | 0.15 | 0.10 | + ## Release notes ### Changelog @@ -244,5 +238,9 @@ April 2020 * Printing WER accuracy in Triton client * Using the latest Kaldi GPU ASR pipeline, extended support for features (ivectors, fbanks) +July 2021 +* Significantly improve latency and throughput for the backend +* Update Triton to v2.10.0 + ### Known issues * No multi-gpu support for the Triton integration diff --git a/Kaldi/SpeechRecognition/data/README.md b/Kaldi/SpeechRecognition/data/README.md deleted file mode 100644 index e69de29b..00000000 diff --git a/Kaldi/SpeechRecognition/kaldi-asr-backend/CMakeLists.txt b/Kaldi/SpeechRecognition/kaldi-asr-backend/CMakeLists.txt new file mode 100644 index 00000000..7ed6bf5c --- /dev/null +++ b/Kaldi/SpeechRecognition/kaldi-asr-backend/CMakeLists.txt @@ -0,0 +1,132 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required(VERSION 3.17..3.20) + +project(TritonKaldiBackend LANGUAGES C CXX) + +# +# Options +# +# Must include options required for this project as well as any +# projects included in this one by FetchContent. +# +# GPU support is enabled by default because the Kaldi backend requires GPUs. +# +option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON) +option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON) + +set(TRITON_COMMON_REPO_TAG "r21.05" CACHE STRING "Tag for triton-inference-server/common repo") +set(TRITON_CORE_REPO_TAG "r21.05" CACHE STRING "Tag for triton-inference-server/core repo") +set(TRITON_BACKEND_REPO_TAG "r21.05" CACHE STRING "Tag for triton-inference-server/backend repo") + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +# +# Dependencies +# +# FetchContent's composibility isn't very good. We must include the +# transitive closure of all repos so that we can override the tag. +# +include(FetchContent) +FetchContent_Declare( + repo-common + GIT_REPOSITORY https://github.com/triton-inference-server/common.git + GIT_TAG ${TRITON_COMMON_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-core + GIT_REPOSITORY https://github.com/triton-inference-server/core.git + GIT_TAG ${TRITON_CORE_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-backend + GIT_REPOSITORY https://github.com/triton-inference-server/backend.git + GIT_TAG ${TRITON_BACKEND_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_MakeAvailable(repo-common repo-core repo-backend) + +# +# Shared library implementing the Triton Backend API +# +add_library(triton-kaldi-backend SHARED) +add_library(TritonKaldiBackend::triton-kaldi-backend ALIAS triton-kaldi-backend) +target_sources(triton-kaldi-backend + PRIVATE + triton-kaldi-backend.cc + kaldi-backend-utils.cc + kaldi-backend-utils.h +) +target_include_directories(triton-kaldi-backend SYSTEM + PRIVATE + $<$:${CUDA_INCLUDE_DIRS}> + /opt/kaldi/src + /opt/kaldi/tools/openfst/include +) +target_include_directories(triton-kaldi-backend + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) +target_compile_features(triton-kaldi-backend PRIVATE cxx_std_17) +target_compile_options(triton-kaldi-backend + PRIVATE + $<$,$,$>:-Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror> +) +target_link_directories(triton-kaldi-backend + PRIVATE + /opt/kaldi/src/lib +) +target_link_libraries(triton-kaldi-backend + PRIVATE + TritonCore::triton-core-serverapi # from repo-core + TritonCore::triton-core-backendapi # from repo-core + TritonCore::triton-core-serverstub # from repo-core + TritonBackend::triton-backend-utils # from repo-backend + -lkaldi-cudadecoder +) +set_target_properties(triton-kaldi-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_kaldi +) + +# +# Install +# +include(GNUInstallDirs) + +install( + TARGETS + triton-kaldi-backend + EXPORT + triton-kaldi-backend-targets + LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/kaldi + ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/kaldi +) diff --git a/Kaldi/SpeechRecognition/kaldi-asr-backend/kaldi-backend-utils.cc b/Kaldi/SpeechRecognition/kaldi-asr-backend/kaldi-backend-utils.cc new file mode 100644 index 00000000..05e1bc80 --- /dev/null +++ b/Kaldi/SpeechRecognition/kaldi-asr-backend/kaldi-backend-utils.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kaldi-backend-utils.h" + +#include + +using triton::common::TritonJson; + +namespace triton { +namespace backend { + +TRITONSERVER_Error* GetInputTensor(TRITONBACKEND_Request* request, + const std::string& input_name, + const size_t expected_byte_size, + std::vector* buffer, + const void** out) { + buffer->clear(); // reset buffer + + TRITONBACKEND_Input* input; + RETURN_IF_ERROR( + TRITONBACKEND_RequestInput(request, input_name.c_str(), &input)); + + uint64_t input_byte_size; + uint32_t input_buffer_count; + RETURN_IF_ERROR( + TRITONBACKEND_InputProperties(input, nullptr, nullptr, nullptr, nullptr, + &input_byte_size, &input_buffer_count)); + RETURN_ERROR_IF_FALSE( + input_byte_size == expected_byte_size, TRITONSERVER_ERROR_INVALID_ARG, + std::string(std::string("unexpected byte size ") + + std::to_string(expected_byte_size) + " requested for " + + input_name.c_str() + ", received " + + std::to_string(input_byte_size))); + + // The values for an input tensor are not necessarily in one + // contiguous chunk, so we might copy the chunks into 'input' vector. + // If possible, we use the data in place + uint64_t total_content_byte_size = 0; + for (uint32_t b = 0; b < input_buffer_count; ++b) { + const void* input_buffer = nullptr; + uint64_t input_buffer_byte_size = 0; + TRITONSERVER_MemoryType input_memory_type = TRITONSERVER_MEMORY_CPU; + int64_t input_memory_type_id = 0; + RETURN_IF_ERROR(TRITONBACKEND_InputBuffer( + input, b, &input_buffer, &input_buffer_byte_size, &input_memory_type, + &input_memory_type_id)); + RETURN_ERROR_IF_FALSE(input_memory_type != TRITONSERVER_MEMORY_GPU, + TRITONSERVER_ERROR_INTERNAL, + std::string("expected input tensor in CPU memory")); + + // Skip the copy if input already exists as a single contiguous + // block + if ((input_buffer_byte_size == expected_byte_size) && (b == 0)) { + *out = input_buffer; + return nullptr; + } + + buffer->insert( + buffer->end(), static_cast(input_buffer), + static_cast(input_buffer) + input_buffer_byte_size); + total_content_byte_size += input_buffer_byte_size; + } + + // Make sure we end up with exactly the amount of input we expect. + RETURN_ERROR_IF_FALSE( + total_content_byte_size == expected_byte_size, + TRITONSERVER_ERROR_INVALID_ARG, + std::string(std::string("total unexpected byte size ") + + std::to_string(expected_byte_size) + " requested for " + + input_name.c_str() + ", received " + + std::to_string(total_content_byte_size))); + *out = &buffer[0]; + + return nullptr; +} + +void LatticeToString(fst::SymbolTable& word_syms, + const kaldi::CompactLattice& dlat, std::string* out_str) { + kaldi::CompactLattice best_path_clat; + kaldi::CompactLatticeShortestPath(dlat, &best_path_clat); + + kaldi::Lattice best_path_lat; + fst::ConvertLattice(best_path_clat, &best_path_lat); + + std::vector alignment; + std::vector words; + kaldi::LatticeWeight weight; + fst::GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); + std::ostringstream oss; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms.Find(words[i]); + if (s == "") { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + ("word-id " + std::to_string(words[i]) + " not in symbol table") + .c_str()); + } + oss << s << " "; + } + *out_str = std::move(oss.str()); +} + +TRITONSERVER_Error* ReadParameter(TritonJson::Value& params, + const std::string& key, std::string* param) { + TritonJson::Value value; + RETURN_ERROR_IF_FALSE( + params.Find(key.c_str(), &value), TRITONSERVER_ERROR_INVALID_ARG, + std::string("model configuration is missing the parameter ") + key); + RETURN_IF_ERROR(value.MemberAsString("string_value", param)); + return nullptr; // success +} + +TRITONSERVER_Error* ReadParameter(TritonJson::Value& params, + const std::string& key, int* param) { + std::string tmp; + RETURN_IF_ERROR(ReadParameter(params, key, &tmp)); + *param = std::stoi(tmp); + return nullptr; // success +} + +TRITONSERVER_Error* ReadParameter(TritonJson::Value& params, + const std::string& key, float* param) { + std::string tmp; + RETURN_IF_ERROR(ReadParameter(params, key, &tmp)); + *param = std::stof(tmp); + return nullptr; // success +} + +} // namespace backend +} // namespace triton diff --git a/Kaldi/SpeechRecognition/kaldi-asr-backend/kaldi-backend-utils.h b/Kaldi/SpeechRecognition/kaldi-asr-backend/kaldi-backend-utils.h new file mode 100644 index 00000000..f23f849d --- /dev/null +++ b/Kaldi/SpeechRecognition/kaldi-asr-backend/kaldi-backend-utils.h @@ -0,0 +1,57 @@ +// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace triton { +namespace backend { + +using triton::common::TritonJson; + +#define RETURN_AND_LOG_IF_ERROR(X, MSG) \ + do { \ + TRITONSERVER_Error* rie_err__ = (X); \ + if (rie_err__ != nullptr) { \ + LOG_MESSAGE(TRITONSERVER_LOG_INFO, MSG); \ + return rie_err__; \ + } \ + } while (false) + +TRITONSERVER_Error* GetInputTensor(TRITONBACKEND_Request* request, + const std::string& input_name, + const size_t expected_byte_size, + std::vector* input, + const void** out); + +TRITONSERVER_Error* LatticeToString(TRITONBACKEND_Request* request, + const std::string& input_name, char* buffer, + size_t* buffer_byte_size); + +void LatticeToString(fst::SymbolTable& word_syms, + const kaldi::CompactLattice& dlat, std::string* out_str); + +TRITONSERVER_Error* ReadParameter(TritonJson::Value& params, + const std::string& key, std::string* param); + +TRITONSERVER_Error* ReadParameter(TritonJson::Value& params, + const std::string& key, int* param); +TRITONSERVER_Error* ReadParameter(TritonJson::Value& params, + const std::string& key, float* param); + +} // namespace backend +} // namespace triton diff --git a/Kaldi/SpeechRecognition/kaldi-asr-backend/triton-kaldi-backend.cc b/Kaldi/SpeechRecognition/kaldi-asr-backend/triton-kaldi-backend.cc new file mode 100644 index 00000000..4eee6c09 --- /dev/null +++ b/Kaldi/SpeechRecognition/kaldi-asr-backend/triton-kaldi-backend.cc @@ -0,0 +1,1187 @@ +// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define HAVE_CUDA 1 // Loading Kaldi headers with GPU + +#include + +#include +#include +#include +#include + +#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h" +#include "fstext/fstext-lib.h" +#include "kaldi-backend-utils.h" +#include "lat/kaldi-lattice.h" +#include "lat/lattice-functions.h" +#include "nnet3/am-nnet-simple.h" +#include "nnet3/nnet-utils.h" +#include "util/kaldi-thread.h" + +using kaldi::BaseFloat; + +namespace ni = triton::common; +namespace nib = triton::backend; + +namespace { + +#define RESPOND_AND_RETURN_IF_ERROR(REQUEST, X) \ + do { \ + TRITONSERVER_Error* rarie_err__ = (X); \ + if (rarie_err__ != nullptr) { \ + TRITONBACKEND_Response* rarie_response__ = nullptr; \ + LOG_IF_ERROR(TRITONBACKEND_ResponseNew(&rarie_response__, REQUEST), \ + "failed to create response"); \ + if (rarie_response__ != nullptr) { \ + LOG_IF_ERROR(TRITONBACKEND_ResponseSend( \ + rarie_response__, \ + TRITONSERVER_RESPONSE_COMPLETE_FINAL, rarie_err__), \ + "failed to send error response"); \ + } \ + TRITONSERVER_ErrorDelete(rarie_err__); \ + return; \ + } \ + } while (false) + +#define RESPOND_FACTORY_AND_RETURN_IF_ERROR(FACTORY, X) \ + do { \ + TRITONSERVER_Error* rfarie_err__ = (X); \ + if (rfarie_err__ != nullptr) { \ + TRITONBACKEND_Response* rfarie_response__ = nullptr; \ + LOG_IF_ERROR( \ + TRITONBACKEND_ResponseNewFromFactory(&rfarie_response__, FACTORY), \ + "failed to create response"); \ + if (rfarie_response__ != nullptr) { \ + LOG_IF_ERROR(TRITONBACKEND_ResponseSend( \ + rfarie_response__, \ + TRITONSERVER_RESPONSE_COMPLETE_FINAL, rfarie_err__), \ + "failed to send error response"); \ + } \ + TRITONSERVER_ErrorDelete(rfarie_err__); \ + return; \ + } \ + } while (false) + +// +// ResponseOutput +// +// Bit flags for desired response outputs +// +enum ResponseOutput { + kResponseOutputRawLattice = 1 << 0, + kResponseOutputText = 1 << 1, + kResponseOutputCTM = 1 << 2 +}; + +// +// ModelParams +// +// The parameters parsed from the model configuration. +// +struct ModelParams { + // Model paths + std::string nnet3_rxfilename; + std::string fst_rxfilename; + std::string word_syms_rxfilename; + std::string lattice_postprocessor_rxfilename; + + // Filenames + std::string config_filename; + + uint64_t max_batch_size; + int num_channels; + int num_worker_threads; + + int use_tensor_cores; + float beam; + float lattice_beam; + int max_active; + int frame_subsampling_factor; + float acoustic_scale; + int main_q_capacity; + int aux_q_capacity; + + int chunk_num_bytes; + int chunk_num_samps; +}; + +// +// ModelState +// +// State associated with a model that is using this backend. An object +// of this class is created and associated with each +// TRITONBACKEND_Model. +// +class ModelState { + public: + static TRITONSERVER_Error* Create(TRITONBACKEND_Model* triton_model, + ModelState** state); + + // Get the handle to the TRITONBACKEND model. + TRITONBACKEND_Model* TritonModel() { return triton_model_; } + + // Validate and parse the model configuration + TRITONSERVER_Error* ValidateModelConfig(); + + // Obtain the parameters parsed from the model configuration + const ModelParams* Parameters() { return &model_params_; } + + private: + ModelState(TRITONBACKEND_Model* triton_model, + ni::TritonJson::Value&& model_config); + + TRITONBACKEND_Model* triton_model_; + ni::TritonJson::Value model_config_; + + ModelParams model_params_; +}; + +TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model, + ModelState** state) { + TRITONSERVER_Message* config_message; + RETURN_IF_ERROR(TRITONBACKEND_ModelConfig( + triton_model, 1 /* config_version */, &config_message)); + + const char* buffer; + size_t byte_size; + RETURN_IF_ERROR( + TRITONSERVER_MessageSerializeToJson(config_message, &buffer, &byte_size)); + + ni::TritonJson::Value model_config; + TRITONSERVER_Error* err = model_config.Parse(buffer, byte_size); + RETURN_IF_ERROR(TRITONSERVER_MessageDelete(config_message)); + RETURN_IF_ERROR(err); + + *state = new ModelState(triton_model, std::move(model_config)); + return nullptr; // success +} + +ModelState::ModelState(TRITONBACKEND_Model* triton_model, + ni::TritonJson::Value&& model_config) + : triton_model_(triton_model), model_config_(std::move(model_config)) {} + +TRITONSERVER_Error* ModelState::ValidateModelConfig() { + // We have the json DOM for the model configuration... + ni::TritonJson::WriteBuffer buffer; + RETURN_AND_LOG_IF_ERROR(model_config_.PrettyWrite(&buffer), + "failed to pretty write model configuration"); + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("model configuration:\n") + buffer.Contents()).c_str()); + + RETURN_AND_LOG_IF_ERROR(model_config_.MemberAsUInt( + "max_batch_size", &model_params_.max_batch_size), + "failed to get max batch size"); + + ni::TritonJson::Value batcher; + RETURN_ERROR_IF_FALSE( + model_config_.Find("sequence_batching", &batcher), + TRITONSERVER_ERROR_INVALID_ARG, + std::string("model configuration must configure sequence batcher")); + ni::TritonJson::Value control_inputs; + RETURN_AND_LOG_IF_ERROR( + batcher.MemberAsArray("control_input", &control_inputs), + "failed to read control input array"); + std::set control_input_names; + for (uint32_t i = 0; i < control_inputs.ArraySize(); i++) { + ni::TritonJson::Value control_input; + RETURN_AND_LOG_IF_ERROR(control_inputs.IndexAsObject(i, &control_input), + "failed to get control input"); + std::string control_input_name; + RETURN_AND_LOG_IF_ERROR( + control_input.MemberAsString("name", &control_input_name), + "failed to get control input name"); + control_input_names.insert(control_input_name); + } + + RETURN_ERROR_IF_FALSE( + (control_input_names.erase("START") && control_input_names.erase("END") && + control_input_names.erase("CORRID") && + control_input_names.erase("READY")), + TRITONSERVER_ERROR_INVALID_ARG, + std::string("missing control input names in the model configuration")); + + // Check the Model Transaction Policy + ni::TritonJson::Value txn_policy; + RETURN_ERROR_IF_FALSE( + model_config_.Find("model_transaction_policy", &txn_policy), + TRITONSERVER_ERROR_INVALID_ARG, + std::string("model configuration must specify a transaction policy")); + bool is_decoupled; + RETURN_AND_LOG_IF_ERROR(txn_policy.MemberAsBool("decoupled", &is_decoupled), + "failed to read the decouled txn policy"); + RETURN_ERROR_IF_FALSE( + is_decoupled, TRITONSERVER_ERROR_INVALID_ARG, + std::string("model configuration must use decoupled transaction policy")); + + // Check the Inputs and Outputs + ni::TritonJson::Value inputs, outputs; + RETURN_AND_LOG_IF_ERROR(model_config_.MemberAsArray("input", &inputs), + "failed to read input array"); + RETURN_AND_LOG_IF_ERROR(model_config_.MemberAsArray("output", &outputs), + "failed to read output array"); + + // There must be 2 inputs and 3 outputs. + RETURN_ERROR_IF_FALSE(inputs.ArraySize() == 2, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected 2 inputs, got ") + + std::to_string(inputs.ArraySize())); + RETURN_ERROR_IF_FALSE(outputs.ArraySize() == 3, + TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected 3 outputs, got ") + + std::to_string(outputs.ArraySize())); + + // Here we rely on the model configuation listing the inputs and + // outputs in a specific order, which we shouldn't really require... + // TODO use sets and loops + ni::TritonJson::Value in0, in1, out0, out1, out2; + RETURN_AND_LOG_IF_ERROR(inputs.IndexAsObject(0, &in0), + "failed to get the first input"); + RETURN_AND_LOG_IF_ERROR(inputs.IndexAsObject(1, &in1), + "failed to get the second input"); + RETURN_AND_LOG_IF_ERROR(outputs.IndexAsObject(0, &out0), + "failed to get the first output"); + RETURN_AND_LOG_IF_ERROR(outputs.IndexAsObject(1, &out1), + "failed to get the second output"); + RETURN_AND_LOG_IF_ERROR(outputs.IndexAsObject(2, &out2), + "failed to get the third output"); + + // Check tensor names + std::string in0_name, in1_name, out0_name, out1_name, out2_name; + RETURN_AND_LOG_IF_ERROR(in0.MemberAsString("name", &in0_name), + "failed to get the first input name"); + RETURN_AND_LOG_IF_ERROR(in1.MemberAsString("name", &in1_name), + "failed to get the second input name"); + RETURN_AND_LOG_IF_ERROR(out0.MemberAsString("name", &out0_name), + "failed to get the first output name"); + RETURN_AND_LOG_IF_ERROR(out1.MemberAsString("name", &out1_name), + "failed to get the second output name"); + RETURN_AND_LOG_IF_ERROR(out2.MemberAsString("name", &out2_name), + "failed to get the third output name"); + + RETURN_ERROR_IF_FALSE( + in0_name == "WAV_DATA", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected first input tensor name to be WAV_DATA, got ") + + in0_name); + RETURN_ERROR_IF_FALSE( + in1_name == "WAV_DATA_DIM", TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "expected second input tensor name to be WAV_DATA_DIM, got ") + + in1_name); + RETURN_ERROR_IF_FALSE( + out0_name == "RAW_LATTICE", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected first output tensor name to be RAW_LATTICE, got ") + + out0_name); + RETURN_ERROR_IF_FALSE( + out1_name == "TEXT", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected second output tensor name to be TEXT, got ") + + out1_name); + RETURN_ERROR_IF_FALSE( + out2_name == "CTM", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected second output tensor name to be CTM, got ") + + out2_name); + + // Check shapes + std::vector in0_shape, in1_shape, out0_shape, out1_shape; + RETURN_AND_LOG_IF_ERROR(nib::ParseShape(in0, "dims", &in0_shape), + " first input shape"); + RETURN_AND_LOG_IF_ERROR(nib::ParseShape(in1, "dims", &in1_shape), + " second input shape"); + RETURN_AND_LOG_IF_ERROR(nib::ParseShape(out0, "dims", &out0_shape), + " first output shape"); + RETURN_AND_LOG_IF_ERROR(nib::ParseShape(out1, "dims", &out1_shape), + " second ouput shape"); + + RETURN_ERROR_IF_FALSE( + in0_shape.size() == 1, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected WAV_DATA shape to have one dimension, got ") + + nib::ShapeToString(in0_shape)); + RETURN_ERROR_IF_FALSE( + in0_shape[0] > 0, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected WAV_DATA shape to be greater than 0, got ") + + nib::ShapeToString(in0_shape)); + model_params_.chunk_num_samps = in0_shape[0]; + model_params_.chunk_num_bytes = model_params_.chunk_num_samps * sizeof(float); + + RETURN_ERROR_IF_FALSE( + ((in1_shape.size() == 1) && (in1_shape[0] == 1)), + TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected WAV_DATA_DIM shape to be [1], got ") + + nib::ShapeToString(in1_shape)); + RETURN_ERROR_IF_FALSE( + ((out0_shape.size() == 1) && (out0_shape[0] == 1)), + TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected RAW_LATTICE shape to be [1], got ") + + nib::ShapeToString(out0_shape)); + RETURN_ERROR_IF_FALSE(((out1_shape.size() == 1) && (out1_shape[0] == 1)), + TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected TEXT shape to be [1], got ") + + nib::ShapeToString(out1_shape)); + + // Check datatypes + std::string in0_dtype, in1_dtype, out0_dtype, out1_dtype; + RETURN_AND_LOG_IF_ERROR(in0.MemberAsString("data_type", &in0_dtype), + "first input data type"); + RETURN_AND_LOG_IF_ERROR(in1.MemberAsString("data_type", &in1_dtype), + "second input datatype"); + RETURN_AND_LOG_IF_ERROR(out0.MemberAsString("data_type", &out0_dtype), + "first output datatype"); + RETURN_AND_LOG_IF_ERROR(out1.MemberAsString("data_type", &out1_dtype), + "second output datatype"); + + RETURN_ERROR_IF_FALSE( + in0_dtype == "TYPE_FP32", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected IN datatype to be INT32, got ") + in0_dtype); + RETURN_ERROR_IF_FALSE( + in1_dtype == "TYPE_INT32", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected DELAY datatype to be UINT32, got ") + in1_dtype); + RETURN_ERROR_IF_FALSE( + out0_dtype == "TYPE_STRING", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected WAIT datatype to be UINT32, got ") + out0_dtype); + RETURN_ERROR_IF_FALSE( + out1_dtype == "TYPE_STRING", TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected OUT datatype to be INT32, got ") + out1_dtype); + + // Validate and set parameters + ni::TritonJson::Value params; + RETURN_ERROR_IF_FALSE( + (model_config_.Find("parameters", ¶ms)), + TRITONSERVER_ERROR_INVALID_ARG, + std::string("missing parameters in the model configuration")); + RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "config_filename", + &(model_params_.config_filename)), + "config_filename"); + RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "use_tensor_cores", + &(model_params_.use_tensor_cores)), + "cuda use tensor cores"); + RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "main_q_capacity", + &(model_params_.main_q_capacity)), + "cuda use tensor cores"); + RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "aux_q_capacity", + &(model_params_.aux_q_capacity)), + "cuda use tensor cores"); + RETURN_AND_LOG_IF_ERROR( + nib::ReadParameter(params, "beam", &(model_params_.beam)), "beam"); + RETURN_AND_LOG_IF_ERROR( + nib::ReadParameter(params, "lattice_beam", &(model_params_.lattice_beam)), + "lattice beam"); + RETURN_AND_LOG_IF_ERROR( + nib::ReadParameter(params, "max_active", &(model_params_.max_active)), + "max active"); + RETURN_AND_LOG_IF_ERROR( + nib::ReadParameter(params, "frame_subsampling_factor", + &(model_params_.frame_subsampling_factor)), + "frame_subsampling_factor"); + RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "acoustic_scale", + &(model_params_.acoustic_scale)), + "acoustic_scale"); + RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "nnet3_rxfilename", + &(model_params_.nnet3_rxfilename)), + "nnet3_rxfilename"); + RETURN_AND_LOG_IF_ERROR(nib::ReadParameter(params, "fst_rxfilename", + &(model_params_.fst_rxfilename)), + "fst_rxfilename"); + RETURN_AND_LOG_IF_ERROR( + nib::ReadParameter(params, "word_syms_rxfilename", + &(model_params_.word_syms_rxfilename)), + "word_syms_rxfilename"); + RETURN_AND_LOG_IF_ERROR( + nib::ReadParameter(params, "num_worker_threads", + &(model_params_.num_worker_threads)), + "num_worker_threads"); + RETURN_AND_LOG_IF_ERROR( + nib::ReadParameter(params, "num_channels", &(model_params_.num_channels)), + "num_channels"); + + RETURN_AND_LOG_IF_ERROR( + nib::ReadParameter(params, "lattice_postprocessor_rxfilename", + &(model_params_.lattice_postprocessor_rxfilename)), + "(optional) lattice postprocessor config file"); + + model_params_.max_batch_size = std::max(model_params_.max_batch_size, 1); + model_params_.num_channels = std::max(model_params_.num_channels, 1); + + // Sanity checks + RETURN_ERROR_IF_FALSE( + model_params_.beam > 0, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected parameter \"beam\" to be greater than 0, got ") + + std::to_string(model_params_.beam)); + RETURN_ERROR_IF_FALSE( + model_params_.lattice_beam > 0, TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "expected parameter \"lattice_beam\" to be greater than 0, got ") + + std::to_string(model_params_.lattice_beam)); + RETURN_ERROR_IF_FALSE( + model_params_.max_active > 0, TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "expected parameter \"max_active\" to be greater than 0, got ") + + std::to_string(model_params_.max_active)); + RETURN_ERROR_IF_FALSE(model_params_.main_q_capacity >= -1, + TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected parameter \"main_q_capacity\" to " + "be greater than or equal to -1, got ") + + std::to_string(model_params_.main_q_capacity)); + RETURN_ERROR_IF_FALSE(model_params_.aux_q_capacity >= -1, + TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected parameter \"aux_q_capacity\" to " + "be greater than or equal to -1, got ") + + std::to_string(model_params_.aux_q_capacity)); + RETURN_ERROR_IF_FALSE( + model_params_.acoustic_scale > 0, TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "expected parameter \"acoustic_scale\" to be greater than 0, got ") + + std::to_string(model_params_.acoustic_scale)); + RETURN_ERROR_IF_FALSE( + model_params_.num_worker_threads >= -1, TRITONSERVER_ERROR_INVALID_ARG, + std::string("expected parameter \"num_worker_threads\" to be greater " + "than or equal to -1, got ") + + std::to_string(model_params_.num_worker_threads)); + RETURN_ERROR_IF_FALSE( + model_params_.num_channels > 0, TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "expected parameter \"num_channels\" to be greater than 0, got ") + + std::to_string(model_params_.num_channels)); + + return nullptr; // success +} + +// +// ModelInstanceState +// +// State associated with a model instance. An object of this class is +// created and associated with each TRITONBACKEND_ModelInstance. +// +class ModelInstanceState { + public: + static TRITONSERVER_Error* Create( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + ~ModelInstanceState(); + + // Get the handle to the TRITONBACKEND model instance. + TRITONBACKEND_ModelInstance* TritonModelInstance() { + return triton_model_instance_; + } + + // Get the name, kind and device ID of the instance. + const std::string& Name() const { return name_; } + TRITONSERVER_InstanceGroupKind Kind() const { return kind_; } + int32_t DeviceId() const { return device_id_; } + + // Get the state of the model that corresponds to this instance. + ModelState* StateForModel() const { return model_state_; } + + // Initialize this object + TRITONSERVER_Error* Init(); + + // Initialize kaldi pipeline with this object + TRITONSERVER_Error* InitializeKaldiPipeline(); + + // Prepares the requests for kaldi pipeline + void PrepareRequest(TRITONBACKEND_Request* request, uint32_t slot_idx); + + // Executes the batch on the decoder + void FlushBatch(); + + // Waits for all pipeline callbacks to complete + void WaitForLatticeCallbacks(); + + private: + ModelInstanceState(ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + const char* name, + const TRITONSERVER_InstanceGroupKind kind, + const int32_t device_id); + + TRITONSERVER_Error* GetSequenceInput(TRITONBACKEND_Request* request, + int32_t* start, int32_t* ready, + int32_t* dim, int32_t* end, + uint64_t* corr_id, + const BaseFloat** wave_buffer, + std::vector* input_buffer); + + void DeliverPartialResponse(const std::string& text, + TRITONBACKEND_ResponseFactory* response_factory, + uint8_t response_outputs); + void DeliverResponse( + std::vector& results, + uint64_t corr_id, TRITONBACKEND_ResponseFactory* response_factory, + uint8_t response_outputs); + void SetPartialOutput(const std::string& text, + TRITONBACKEND_ResponseFactory* response_factory, + TRITONBACKEND_Response* response); + void SetOutput(std::vector& results, + uint64_t corr_id, const std::string& output_name, + TRITONBACKEND_ResponseFactory* response_factory, + TRITONBACKEND_Response* response); + + void SetOutputBuffer(const std::string& out_bytes, + TRITONBACKEND_Response* response, + TRITONBACKEND_Output* response_output); + + ModelState* model_state_; + TRITONBACKEND_ModelInstance* triton_model_instance_; + const std::string name_; + const TRITONSERVER_InstanceGroupKind kind_; + const int32_t device_id_; + + std::mutex partial_resfactory_mu_; + std::unordered_map>> + partial_responsefactory_; + std::vector batch_corr_ids_; + std::vector> batch_wave_samples_; + std::vector batch_is_first_chunk_; + std::vector batch_is_last_chunk_; + + BaseFloat sample_freq_, seconds_per_chunk_; + int chunk_num_bytes_, chunk_num_samps_; + + // feature_config includes configuration for the iVector adaptation, + // as well as the basic features. + kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipelineConfig + batched_decoder_config_; + std::unique_ptr + cuda_pipeline_; + // Maintain the state of some shared objects + kaldi::TransitionModel trans_model_; + + kaldi::nnet3::AmNnetSimple am_nnet_; + fst::SymbolTable* word_syms_; + + std::vector byte_buffer_; + std::vector> wave_byte_buffers_; + + std::vector output_shape_; + std::vector request_outputs_; +}; + +TRITONSERVER_Error* ModelInstanceState::Create( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state) { + const char* instance_name; + RETURN_IF_ERROR( + TRITONBACKEND_ModelInstanceName(triton_model_instance, &instance_name)); + + TRITONSERVER_InstanceGroupKind instance_kind; + RETURN_IF_ERROR( + TRITONBACKEND_ModelInstanceKind(triton_model_instance, &instance_kind)); + + int32_t instance_id; + RETURN_IF_ERROR( + TRITONBACKEND_ModelInstanceDeviceId(triton_model_instance, &instance_id)); + + *state = new ModelInstanceState(model_state, triton_model_instance, + instance_name, instance_kind, instance_id); + return nullptr; // success +} + +ModelInstanceState::ModelInstanceState( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + const char* name, const TRITONSERVER_InstanceGroupKind kind, + const int32_t device_id) + : model_state_(model_state), + triton_model_instance_(triton_model_instance), + name_(name), + kind_(kind), + device_id_(device_id) {} + +ModelInstanceState::~ModelInstanceState() { delete word_syms_; } + +TRITONSERVER_Error* ModelInstanceState::Init() { + const ModelParams* model_params = model_state_->Parameters(); + + chunk_num_samps_ = model_params->chunk_num_samps; + chunk_num_bytes_ = model_params->chunk_num_bytes; + + + { + std::ostringstream usage_str; + usage_str << "Parsing config from " << "from '" << model_params->config_filename << "'"; + kaldi::ParseOptions po(usage_str.str().c_str()); + batched_decoder_config_.Register(&po); + po.DisableOption("cuda-decoder-copy-threads"); + po.DisableOption("cuda-worker-threads"); + po.DisableOption("max-active"); + po.DisableOption("max-batch-size"); + po.DisableOption("num-channels"); + po.ReadConfigFile(model_params->config_filename); + } + kaldi::CuDevice::EnableTensorCores(bool(model_params->use_tensor_cores)); + + batched_decoder_config_.compute_opts.frame_subsampling_factor = + model_params->frame_subsampling_factor; + batched_decoder_config_.compute_opts.acoustic_scale = + model_params->acoustic_scale; + batched_decoder_config_.decoder_opts.default_beam = model_params->beam; + batched_decoder_config_.decoder_opts.lattice_beam = + model_params->lattice_beam; + batched_decoder_config_.decoder_opts.max_active = model_params->max_active; + batched_decoder_config_.num_worker_threads = model_params->num_worker_threads; + batched_decoder_config_.max_batch_size = model_params->max_batch_size; + batched_decoder_config_.num_channels = model_params->num_channels; + batched_decoder_config_.decoder_opts.main_q_capacity = + model_params->main_q_capacity; + batched_decoder_config_.decoder_opts.aux_q_capacity = + model_params->aux_q_capacity; + + auto feature_config = batched_decoder_config_.feature_opts; + kaldi::OnlineNnet2FeaturePipelineInfo feature_info(feature_config); + sample_freq_ = feature_info.mfcc_opts.frame_opts.samp_freq; + BaseFloat frame_shift = feature_info.FrameShiftInSeconds(); + seconds_per_chunk_ = chunk_num_samps_ / sample_freq_; + + int samp_per_frame = static_cast(sample_freq_ * frame_shift); + float n_input_framesf = chunk_num_samps_ / samp_per_frame; + RETURN_ERROR_IF_FALSE( + (n_input_framesf == std::floor(n_input_framesf)), + TRITONSERVER_ERROR_INVALID_ARG, + std::string("WAVE_DATA dim must be a multiple fo samples per frame (") + + std::to_string(samp_per_frame) + std::string(")")); + int n_input_frames = static_cast(std::floor(n_input_framesf)); + batched_decoder_config_.compute_opts.frames_per_chunk = n_input_frames; + + return nullptr; +} + +TRITONSERVER_Error* ModelInstanceState::InitializeKaldiPipeline() { + const ModelParams* model_params = model_state_->Parameters(); + + batch_corr_ids_.reserve(model_params->max_batch_size); + batch_wave_samples_.reserve(model_params->max_batch_size); + batch_is_first_chunk_.reserve(model_params->max_batch_size); + batch_is_last_chunk_.reserve(model_params->max_batch_size); + wave_byte_buffers_.resize(model_params->max_batch_size); + for (auto& wbb : wave_byte_buffers_) { + wbb.resize(chunk_num_bytes_); + } + output_shape_ = {1, 1}; + kaldi::g_cuda_allocator.SetOptions(kaldi::g_allocator_options); + kaldi::CuDevice::Instantiate() + .SelectAndInitializeGpuIdWithExistingCudaContext(device_id_); + kaldi::CuDevice::Instantiate().AllowMultithreading(); + + // Loading models + { + bool binary; + kaldi::Input ki(model_params->nnet3_rxfilename, &binary); + trans_model_.Read(ki.Stream(), binary); + am_nnet_.Read(ki.Stream(), binary); + + kaldi::nnet3::SetBatchnormTestMode(true, &(am_nnet_.GetNnet())); + kaldi::nnet3::SetDropoutTestMode(true, &(am_nnet_.GetNnet())); + kaldi::nnet3::CollapseModel(kaldi::nnet3::CollapseModelConfig(), + &(am_nnet_.GetNnet())); + } + fst::Fst* decode_fst = + fst::ReadFstKaldiGeneric(model_params->fst_rxfilename); + cuda_pipeline_.reset( + new kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline( + batched_decoder_config_, *decode_fst, am_nnet_, trans_model_)); + delete decode_fst; + + // Loading word syms for text output + if (model_params->word_syms_rxfilename != "") { + RETURN_ERROR_IF_FALSE( + (word_syms_ = + fst::SymbolTable::ReadText(model_params->word_syms_rxfilename)), + TRITONSERVER_ERROR_INVALID_ARG, + std::string("could not read symbol table from file ") + + model_params->word_syms_rxfilename); + cuda_pipeline_->SetSymbolTable(*word_syms_); + } + + // Load lattice postprocessor, required if using CTM + if (!model_params->lattice_postprocessor_rxfilename.empty()) { + LoadAndSetLatticePostprocessor( + model_params->lattice_postprocessor_rxfilename, cuda_pipeline_.get()); + } + chunk_num_samps_ = cuda_pipeline_->GetNSampsPerChunk(); + chunk_num_bytes_ = chunk_num_samps_ * sizeof(BaseFloat); + + return nullptr; +} + +TRITONSERVER_Error* ModelInstanceState::GetSequenceInput( + TRITONBACKEND_Request* request, int32_t* start, int32_t* ready, + int32_t* dim, int32_t* end, uint64_t* corr_id, + const BaseFloat** wave_buffer, std::vector* input_buffer) { + size_t dim_bsize = sizeof(*dim); + RETURN_IF_ERROR(nib::ReadInputTensor( + request, "WAV_DATA_DIM", reinterpret_cast(dim), &dim_bsize)); + + size_t end_bsize = sizeof(*end); + RETURN_IF_ERROR(nib::ReadInputTensor( + request, "END", reinterpret_cast(end), &end_bsize)); + + size_t start_bsize = sizeof(*start); + RETURN_IF_ERROR(nib::ReadInputTensor( + request, "START", reinterpret_cast(start), &start_bsize)); + + size_t ready_bsize = sizeof(*ready); + RETURN_IF_ERROR(nib::ReadInputTensor( + request, "READY", reinterpret_cast(ready), &ready_bsize)); + + size_t corrid_bsize = sizeof(*corr_id); + RETURN_IF_ERROR(nib::ReadInputTensor( + request, "CORRID", reinterpret_cast(corr_id), &corrid_bsize)); + + // Get pointer to speech tensor + size_t wavdata_bsize = input_buffer->size(); + RETURN_IF_ERROR(nib::ReadInputTensor( + request, "WAV_DATA", reinterpret_cast(input_buffer->data()), + &wavdata_bsize)); + *wave_buffer = reinterpret_cast(input_buffer->data()); + + return nullptr; +} + +void ModelInstanceState::PrepareRequest(TRITONBACKEND_Request* request, + uint32_t slot_idx) { + const ModelParams* model_params = model_state_->Parameters(); + + if (batch_corr_ids_.size() == (uint32_t)model_params->max_batch_size) { + FlushBatch(); + } + + int32_t start, dim, end, ready; + uint64_t corr_id; + const BaseFloat* wave_buffer; + + if (slot_idx >= (uint32_t)model_params->max_batch_size) { + RESPOND_AND_RETURN_IF_ERROR( + request, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, + "slot_idx exceeded")); + } + RESPOND_AND_RETURN_IF_ERROR( + request, GetSequenceInput(request, &start, &ready, &dim, &end, &corr_id, + &wave_buffer, &wave_byte_buffers_[slot_idx])); + + uint32_t output_count; + RESPOND_AND_RETURN_IF_ERROR( + request, TRITONBACKEND_RequestOutputCount(request, &output_count)); + + uint8_t response_outputs = 0; + int kaldi_result_type = 0; + for (uint32_t index = 0; index < output_count; index++) { + const char* output_name; + RESPOND_AND_RETURN_IF_ERROR( + request, TRITONBACKEND_RequestOutputName(request, index, &output_name)); + std::string output_name_str = output_name; + if (output_name_str == "RAW_LATTICE") { + response_outputs |= kResponseOutputRawLattice; + kaldi_result_type |= + kaldi::cuda_decoder::CudaPipelineResult::RESULT_TYPE_LATTICE; + } else if (output_name_str == "TEXT") { + response_outputs |= kResponseOutputText; + kaldi_result_type |= + kaldi::cuda_decoder::CudaPipelineResult::RESULT_TYPE_LATTICE; + } else if (output_name_str == "CTM") { + response_outputs |= kResponseOutputCTM; + kaldi_result_type |= + kaldi::cuda_decoder::CudaPipelineResult::RESULT_TYPE_CTM; + } else { + TRITONSERVER_LogMessage( + TRITONSERVER_LOG_WARN, __FILE__, __LINE__, + ("unrecognized requested output " + output_name_str).c_str()); + } + } + + if (dim > chunk_num_samps_) { + RESPOND_AND_RETURN_IF_ERROR( + request, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "a chunk cannot contain more samples than the WAV_DATA dimension")); + } + + if (!ready) { + RESPOND_AND_RETURN_IF_ERROR( + request, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, + "request is not yet ready")); + } + + // Initialize corr_id if first chunk + if (start) { + if (!cuda_pipeline_->TryInitCorrID(corr_id)) { + RESPOND_AND_RETURN_IF_ERROR( + request, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, + "failed to start cuda pipeline")); + } + + { + std::lock_guard lock_partial_resfactory( + partial_resfactory_mu_); + cuda_pipeline_->SetBestPathCallback( + corr_id, [this, corr_id](const std::string& str, bool partial, + bool endpoint_detected) { + // Bestpath callbacks are synchronous in regards to each correlation + // ID, so the lock is only needed for acquiring a reference to the + // queue. + std::unique_lock lock_partial_resfactory( + partial_resfactory_mu_); + auto& resfactory_queue = partial_responsefactory_.at(corr_id); + if (!partial) { + if (!endpoint_detected) { + // while (!resfactory_queue.empty()) { + // auto response_factory = resfactory_queue.front(); + // resfactory_queue.pop(); + // if (response_factory != nullptr) { + // LOG_IF_ERROR( + // TRITONBACKEND_ResponseFactoryDelete(response_factory), + // "error deleting response factory"); + // } + // } + partial_responsefactory_.erase(corr_id); + } + return; + } + if (resfactory_queue.empty()) { + TRITONSERVER_LogMessage( + TRITONSERVER_LOG_WARN, __FILE__, __LINE__, + "response factory queue unexpectedly empty"); + return; + } + + auto response_factory = resfactory_queue.front(); + resfactory_queue.pop(); + lock_partial_resfactory.unlock(); + if (response_factory == nullptr) return; + + DeliverPartialResponse(str, response_factory.get(), + kResponseOutputText); + }); + partial_responsefactory_.emplace( + corr_id, + std::queue>()); + } + } + + kaldi::SubVector wave_part(wave_buffer, dim); + + // Add to batch + batch_corr_ids_.push_back(corr_id); + batch_wave_samples_.push_back(wave_part); + batch_is_first_chunk_.push_back(start); + batch_is_last_chunk_.push_back(end); + + TRITONBACKEND_ResponseFactory* response_factory_ptr; + RESPOND_AND_RETURN_IF_ERROR(request, TRITONBACKEND_ResponseFactoryNew( + &response_factory_ptr, request)); + std::shared_ptr response_factory( + response_factory_ptr, [](TRITONBACKEND_ResponseFactory* f) { + LOG_IF_ERROR(TRITONBACKEND_ResponseFactoryDelete(f), + "failed deleting response factory"); + }); + + if (end) { + auto segmented_lattice_callback_fn = + [this, response_factory, response_outputs, + corr_id](kaldi::cuda_decoder::SegmentedLatticeCallbackParams& params) { + DeliverResponse(params.results, corr_id, response_factory.get(), + response_outputs); + }; + cuda_pipeline_->SetLatticeCallback(corr_id, segmented_lattice_callback_fn, + kaldi_result_type); + } else if (response_outputs & kResponseOutputText) { + std::lock_guard lock_partial_resfactory(partial_resfactory_mu_); + auto& resfactory_queue = partial_responsefactory_.at(corr_id); + resfactory_queue.push(response_factory); + } else { + { + std::lock_guard lock_partial_resfactory( + partial_resfactory_mu_); + auto& resfactory_queue = partial_responsefactory_.at(corr_id); + resfactory_queue.emplace(nullptr); + } + + // Mark the response complete without sending any responses + LOG_IF_ERROR( + TRITONBACKEND_ResponseFactorySendFlags( + response_factory.get(), TRITONSERVER_RESPONSE_COMPLETE_FINAL), + "failed sending final response"); + } +} + +void ModelInstanceState::FlushBatch() { + if (!batch_corr_ids_.empty()) { + cuda_pipeline_->DecodeBatch(batch_corr_ids_, batch_wave_samples_, + batch_is_first_chunk_, batch_is_last_chunk_); + batch_corr_ids_.clear(); + batch_wave_samples_.clear(); + batch_is_first_chunk_.clear(); + batch_is_last_chunk_.clear(); + } +} + +void ModelInstanceState::WaitForLatticeCallbacks() { + cuda_pipeline_->WaitForLatticeCallbacks(); +} + +void ModelInstanceState::DeliverPartialResponse( + const std::string& text, TRITONBACKEND_ResponseFactory* response_factory, + uint8_t response_outputs) { + if (response_outputs & kResponseOutputText) { + TRITONBACKEND_Response* response; + RESPOND_FACTORY_AND_RETURN_IF_ERROR( + response_factory, + TRITONBACKEND_ResponseNewFromFactory(&response, response_factory)); + SetPartialOutput(text, response_factory, response); + LOG_IF_ERROR(TRITONBACKEND_ResponseSend( + response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr), + "failed sending response"); + } else { + LOG_IF_ERROR(TRITONBACKEND_ResponseFactorySendFlags( + response_factory, TRITONSERVER_RESPONSE_COMPLETE_FINAL), + "failed to send final flag for partial result"); + } +} + +void ModelInstanceState::DeliverResponse( + std::vector& results, + uint64_t corr_id, TRITONBACKEND_ResponseFactory* response_factory, + uint8_t response_outputs) { + TRITONBACKEND_Response* response; + RESPOND_FACTORY_AND_RETURN_IF_ERROR( + response_factory, + TRITONBACKEND_ResponseNewFromFactory(&response, response_factory)); + if (response_outputs & kResponseOutputRawLattice) { + SetOutput(results, corr_id, "RAW_LATTICE", response_factory, response); + } + if (response_outputs & kResponseOutputText) { + SetOutput(results, corr_id, "TEXT", response_factory, response); + } + if (response_outputs & kResponseOutputCTM) { + SetOutput(results, corr_id, "CTM", response_factory, response); + } + // Send the response. + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, + nullptr /* success */), + "failed sending response"); +} + +void ModelInstanceState::SetPartialOutput( + const std::string& text, TRITONBACKEND_ResponseFactory* response_factory, + TRITONBACKEND_Response* response) { + TRITONBACKEND_Output* response_output; + RESPOND_FACTORY_AND_RETURN_IF_ERROR( + response_factory, TRITONBACKEND_ResponseOutput( + response, &response_output, "TEXT", + TRITONSERVER_TYPE_BYTES, &output_shape_[0], 2)); + SetOutputBuffer(text, response, response_output); +} + +void ModelInstanceState::SetOutput( + std::vector& results, + uint64_t corr_id, const std::string& output_name, + TRITONBACKEND_ResponseFactory* response_factory, + TRITONBACKEND_Response* response) { + TRITONBACKEND_Output* response_output; + RESPOND_FACTORY_AND_RETURN_IF_ERROR( + response_factory, + TRITONBACKEND_ResponseOutput(response, &response_output, + output_name.c_str(), TRITONSERVER_TYPE_BYTES, + &output_shape_[0], 2 /* dims_count */)); + + if (output_name.compare("RAW_LATTICE") == 0) { + assert(!results.empty()); + kaldi::CompactLattice& clat = results[0].GetLatticeResult(); + + std::ostringstream oss; + kaldi::WriteCompactLattice(oss, true, clat); + SetOutputBuffer(oss.str(), response, response_output); + } else if (output_name.compare("TEXT") == 0) { + assert(!results.empty()); + kaldi::CompactLattice& clat = results[0].GetLatticeResult(); + std::string output; + nib::LatticeToString(*word_syms_, clat, &output); + SetOutputBuffer(output, response, response_output); + } else if (output_name.compare("CTM") == 0) { + std::ostringstream oss; + MergeSegmentsToCTMOutput(results, std::to_string(corr_id), oss, word_syms_, + /* use segment offset*/ false); + SetOutputBuffer(oss.str(), response, response_output); + } +} + +void ModelInstanceState::SetOutputBuffer( + const std::string& out_bytes, TRITONBACKEND_Response* response, + TRITONBACKEND_Output* response_output) { + TRITONSERVER_MemoryType actual_memory_type = TRITONSERVER_MEMORY_CPU; + int64_t actual_memory_type_id = 0; + uint32_t byte_size_with_size_int = out_bytes.size() + sizeof(int32); + void* obuffer; // output buffer + auto err = TRITONBACKEND_OutputBuffer( + response_output, &obuffer, byte_size_with_size_int, &actual_memory_type, + &actual_memory_type_id); + if (err != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR(&response, err); + } + + int32* buffer_as_int = reinterpret_cast(obuffer); + buffer_as_int[0] = out_bytes.size(); + memcpy(&buffer_as_int[1], out_bytes.data(), out_bytes.size()); +} + +} // namespace + +///////////// + +extern "C" { + +TRITONSERVER_Error* TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) { + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname)); + std::string name(cname); + + uint64_t version; + RETURN_IF_ERROR(TRITONBACKEND_ModelVersion(model, &version)); + + LOG_MESSAGE(TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_ModelInitialize: ") + name + + " (version " + std::to_string(version) + ")") + .c_str()); + + ModelState* model_state; + RETURN_IF_ERROR(ModelState::Create(model, &model_state)); + RETURN_IF_ERROR( + TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); + + RETURN_IF_ERROR(model_state->ValidateModelConfig()); + + return nullptr; // success +} + +TRITONSERVER_Error* TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) { + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); + ModelState* model_state = reinterpret_cast(vstate); + + LOG_MESSAGE(TRITONSERVER_LOG_INFO, + "TRITONBACKEND_ModelFinalize: delete model state"); + + delete model_state; + + return nullptr; // success +} + +TRITONSERVER_Error* TRITONBACKEND_ModelInstanceInitialize( + TRITONBACKEND_ModelInstance* instance) { + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceName(instance, &cname)); + std::string name(cname); + + int32_t device_id; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceDeviceId(instance, &device_id)); + + LOG_MESSAGE(TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_ModelInstanceInitialize: ") + name + + " (device " + std::to_string(device_id) + ")") + .c_str()); + + TRITONBACKEND_Model* model; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); + + void* vmodelstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); + ModelState* model_state = reinterpret_cast(vmodelstate); + + ModelInstanceState* instance_state; + RETURN_IF_ERROR( + ModelInstanceState::Create(model_state, instance, &instance_state)); + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState( + instance, reinterpret_cast(instance_state))); + + RETURN_IF_ERROR(instance_state->Init()); + RETURN_IF_ERROR(instance_state->InitializeKaldiPipeline()); + + return nullptr; // success +} + +TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize( + TRITONBACKEND_ModelInstance* instance) { + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); + ModelInstanceState* instance_state = + reinterpret_cast(vstate); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + "TRITONBACKEND_ModelInstanceFinalize: waiting for lattice callbacks"); + instance_state->WaitForLatticeCallbacks(); + + LOG_MESSAGE(TRITONSERVER_LOG_INFO, + "TRITONBACKEND_ModelInstanceFinalize: delete instance state"); + delete instance_state; + + return nullptr; // success +} + +TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, + const uint32_t request_count) { + ModelInstanceState* instance_state; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState( + instance, reinterpret_cast(&instance_state))); + + LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, + (std::string("model instance ") + instance_state->Name() + + ", executing " + std::to_string(request_count) + " requests") + .c_str()); + + RETURN_ERROR_IF_FALSE( + request_count <= + instance_state->StateForModel()->Parameters()->max_batch_size, + TRITONSERVER_ERROR_INVALID_ARG, + std::string("request count exceeded the provided maximum batch size")); + + uint64_t exec_start_ns = 0; + SET_TIMESTAMP(exec_start_ns); + + // Each request is a chunk for one sequence + // Using the oldest strategy in the sequence batcher ensures that + // there will only be a single chunk for each sequence. + for (uint32_t r = 0; r < request_count; ++r) { + TRITONBACKEND_Request* request = requests[r]; + instance_state->PrepareRequest(request, r); + } + + instance_state->FlushBatch(); + + uint64_t exec_end_ns = 0; + SET_TIMESTAMP(exec_end_ns); + + for (uint32_t r = 0; r < request_count; ++r) { + TRITONBACKEND_Request* request = requests[r]; + LOG_IF_ERROR( + TRITONBACKEND_ModelInstanceReportStatistics( + instance_state->TritonModelInstance(), request, true /* success */, + exec_start_ns, exec_start_ns, exec_end_ns, exec_end_ns), + "failed reporting request statistics"); + LOG_IF_ERROR( + TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), + "failed releasing request"); + } + + LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics( + instance_state->TritonModelInstance(), request_count, + exec_start_ns, exec_start_ns, exec_end_ns, exec_end_ns), + "failed reporting batch request statistics"); + + return nullptr; // success +} + +} // extern "C" diff --git a/Kaldi/SpeechRecognition/kaldi-asr-client/CMakeLists.txt b/Kaldi/SpeechRecognition/kaldi-asr-client/CMakeLists.txt index 9349a059..8df19ca0 100644 --- a/Kaldi/SpeechRecognition/kaldi-asr-client/CMakeLists.txt +++ b/Kaldi/SpeechRecognition/kaldi-asr-client/CMakeLists.txt @@ -1,70 +1,78 @@ -# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. # -# http://www.apache.org/licenses/LICENSE-2.0 +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required(VERSION 3.17..3.20) + # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -cmake_minimum_required (VERSION 3.5) - -add_executable(kaldi_asr_parallel_client kaldi_asr_parallel_client.cc asr_client_imp.cc) - -target_link_libraries( - kaldi_asr_parallel_client - PRIVATE request -) - -target_link_libraries( - kaldi_asr_parallel_client - PRIVATE protobuf::libprotobuf -) - -target_include_directories( - kaldi_asr_parallel_client +# gRPC client for custom Kaldi backend +# +add_executable(kaldi-asr-parallel-client) +add_executable(TritonKaldiGrpcClient::kaldi-asr-parallel-client ALIAS kaldi-asr-parallel-client) +target_sources(kaldi-asr-parallel-client PRIVATE - /opt/kaldi/src/ + kaldi_asr_parallel_client.cc + asr_client_imp.cc ) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") # openfst yields many warnings -target_include_directories( - kaldi_asr_parallel_client +target_include_directories(kaldi-asr-parallel-client SYSTEM PRIVATE - /opt/kaldi/tools/openfst-1.6.7/include/ + /opt/kaldi/src + /opt/kaldi/tools/openfst/include +) +target_include_directories(kaldi-asr-parallel-client + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) +target_compile_features(kaldi-asr-parallel-client PRIVATE cxx_std_17) +target_compile_options(kaldi-asr-parallel-client + PRIVATE + $<$,$,$>:-Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror> +) +target_link_directories(kaldi-asr-parallel-client + PRIVATE + /opt/kaldi/src/lib +) +target_link_libraries(kaldi-asr-parallel-client + PRIVATE + TritonClient::grpcclient_static + -lkaldi-base + -lkaldi-util + -lkaldi-matrix + -lkaldi-feat + -lkaldi-lat ) -target_link_libraries( - kaldi_asr_parallel_client - PRIVATE /opt/kaldi/src/lib/libkaldi-feat.so -) - -target_link_libraries( - kaldi_asr_parallel_client - PRIVATE /opt/kaldi/src/lib/libkaldi-util.so -) - -target_link_libraries( - kaldi_asr_parallel_client - PRIVATE /opt/kaldi/src/lib/libkaldi-matrix.so -) - -target_link_libraries( - kaldi_asr_parallel_client - PRIVATE /opt/kaldi/src/lib/libkaldi-base.so -) - -target_link_libraries( - kaldi_asr_parallel_client - PRIVATE /opt/kaldi/src/lat/libkaldi-lat.so -) +# +# Install +# +include(GNUInstallDirs) install( - TARGETS kaldi_asr_parallel_client - RUNTIME DESTINATION bin + TARGETS + kaldi-asr-parallel-client + EXPORT + kaldi-asr-parallel-client-targets + RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin ) diff --git a/Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.cc b/Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.cc index 2c1cd026..3c85e02c 100644 --- a/Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.cc +++ b/Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,7 +13,9 @@ // limitations under the License. #include "asr_client_imp.h" + #include + #include #include #include @@ -33,117 +35,179 @@ } \ } -void TRTISASRClient::CreateClientContext() { - contextes_.emplace_back(); - ClientContext& client = contextes_.back(); +void TritonASRClient::CreateClientContext() { + clients_.emplace_back(); + TritonClient& client = clients_.back(); + FAIL_IF_ERR(nic::InferenceServerGrpcClient::Create(&client.triton_client, + url_, /*verbose*/ false), + "unable to create triton client"); + FAIL_IF_ERR( - nic::InferGrpcStreamContext::Create(&client.trtis_context, - /*corr_id*/ -1, url_, model_name_, - /*model_version*/ -1, - /*verbose*/ false), - "unable to create context"); + client.triton_client->StartStream( + [&](nic::InferResult* result) { + double end_timestamp = gettime_monotonic(); + std::unique_ptr result_ptr(result); + FAIL_IF_ERR(result_ptr->RequestStatus(), + "inference request failed"); + std::string request_id; + FAIL_IF_ERR(result_ptr->Id(&request_id), + "unable to get request id for response"); + uint64_t corr_id = + std::stoi(std::string(request_id, 0, request_id.find("_"))); + bool end_of_stream = (request_id.back() == '1'); + if (!end_of_stream) { + if (print_partial_results_) { + std::vector text; + FAIL_IF_ERR(result_ptr->StringData("TEXT", &text), + "unable to get TEXT output"); + std::lock_guard lk(stdout_m_); + std::cout << "CORR_ID " << corr_id << "\t[partial]\t" << text[0] + << '\n'; + } + return; + } + + double start_timestamp; + { + std::lock_guard lk(start_timestamps_m_); + auto it = start_timestamps_.find(corr_id); + if (it != start_timestamps_.end()) { + start_timestamp = it->second; + start_timestamps_.erase(it); + } else { + std::cerr << "start_timestamp not found" << std::endl; + exit(1); + } + } + + if (print_results_) { + std::vector text; + FAIL_IF_ERR(result_ptr->StringData(ctm_ ? "CTM" : "TEXT", &text), + "unable to get TEXT or CTM output"); + std::lock_guard lk(stdout_m_); + std::cout << "CORR_ID " << corr_id; + std::cout << (ctm_ ? "\n" : "\t\t"); + std::cout << text[0] << std::endl; + } + + std::vector lattice_bytes; + FAIL_IF_ERR(result_ptr->StringData("RAW_LATTICE", &lattice_bytes), + "unable to get RAW_LATTICE output"); + + { + double elapsed = end_timestamp - start_timestamp; + std::lock_guard lk(results_m_); + results_.insert( + {corr_id, {std::move(lattice_bytes[0]), elapsed}}); + } + + n_in_flight_.fetch_sub(1, std::memory_order_relaxed); + }, + false), + "unable to establish a streaming connection to server"); } -void TRTISASRClient::SendChunk(ni::CorrelationID corr_id, - bool start_of_sequence, bool end_of_sequence, - float* chunk, int chunk_byte_size) { - ClientContext* client = &contextes_[corr_id % ncontextes_]; - nic::InferContext& context = *client->trtis_context; - if (start_of_sequence) n_in_flight_.fetch_add(1, std::memory_order_consume); - +void TritonASRClient::SendChunk(uint64_t corr_id, bool start_of_sequence, + bool end_of_sequence, float* chunk, + int chunk_byte_size, const uint64_t index) { // Setting options - std::unique_ptr options; - FAIL_IF_ERR(nic::InferContext::Options::Create(&options), - "unable to create inference options"); - options->SetBatchSize(1); - options->SetFlags(0); - options->SetCorrelationId(corr_id); - if (start_of_sequence) - options->SetFlag(ni::InferRequestHeader::FLAG_SEQUENCE_START, - start_of_sequence); - if (end_of_sequence) { - options->SetFlag(ni::InferRequestHeader::FLAG_SEQUENCE_END, - end_of_sequence); - for (const auto& output : context.Outputs()) { - if (output->Name() == "TEXT" && !print_results_) - continue; // no need for text output if not printing - options->AddRawResult(output); - } - } + nic::InferOptions options(model_name_); + options.sequence_id_ = corr_id; + options.sequence_start_ = start_of_sequence; + options.sequence_end_ = end_of_sequence; + options.request_id_ = std::to_string(corr_id) + "_" + std::to_string(index) + + "_" + (start_of_sequence ? "1" : "0") + "_" + + (end_of_sequence ? "1" : "0"); - FAIL_IF_ERR(context.SetRunOptions(*options), "unable to set context options"); - std::shared_ptr in_wave_data, in_wave_data_dim; - FAIL_IF_ERR(context.GetInput("WAV_DATA", &in_wave_data), - "unable to get WAV_DATA"); - FAIL_IF_ERR(context.GetInput("WAV_DATA_DIM", &in_wave_data_dim), - "unable to get WAV_DATA_DIM"); - - // Wave data input - FAIL_IF_ERR(in_wave_data->Reset(), "unable to reset WAVE_DATA"); + // Initialize the inputs with the data. + nic::InferInput* wave_data_ptr; + std::vector wav_shape{1, samps_per_chunk_}; + FAIL_IF_ERR( + nic::InferInput::Create(&wave_data_ptr, "WAV_DATA", wav_shape, "FP32"), + "unable to create 'WAV_DATA'"); + std::shared_ptr wave_data_in(wave_data_ptr); + FAIL_IF_ERR(wave_data_in->Reset(), "unable to reset 'WAV_DATA'"); uint8_t* wave_data = reinterpret_cast(chunk); if (chunk_byte_size < max_chunk_byte_size_) { std::memcpy(&chunk_buf_[0], chunk, chunk_byte_size); wave_data = &chunk_buf_[0]; } - FAIL_IF_ERR(in_wave_data->SetRaw(wave_data, max_chunk_byte_size_), - "unable to set data for WAVE_DATA"); + FAIL_IF_ERR(wave_data_in->AppendRaw(wave_data, max_chunk_byte_size_), + "unable to set data for 'WAV_DATA'"); + // Dim - FAIL_IF_ERR(in_wave_data_dim->Reset(), "unable to reset WAVE_DATA_DIM"); + nic::InferInput* dim_ptr; + std::vector shape{1, 1}; + FAIL_IF_ERR(nic::InferInput::Create(&dim_ptr, "WAV_DATA_DIM", shape, "INT32"), + "unable to create 'WAV_DATA_DIM'"); + std::shared_ptr dim_in(dim_ptr); + FAIL_IF_ERR(dim_in->Reset(), "unable to reset WAVE_DATA_DIM"); int nsamples = chunk_byte_size / sizeof(float); - FAIL_IF_ERR(in_wave_data_dim->SetRaw(reinterpret_cast(&nsamples), - sizeof(int32_t)), - "unable to set data for WAVE_DATA_DIM"); + FAIL_IF_ERR( + dim_in->AppendRaw(reinterpret_cast(&nsamples), sizeof(int32_t)), + "unable to set data for WAVE_DATA_DIM"); - total_audio_ += (static_cast(nsamples) / 16000.); // TODO freq - double start = gettime_monotonic(); - FAIL_IF_ERR(context.AsyncRun([corr_id, end_of_sequence, start, this]( - nic::InferContext* ctx, - const std::shared_ptr< - nic::InferContext::Request>& request) { - if (end_of_sequence) { - double elapsed = gettime_monotonic() - start; - std::map> results; - ctx->GetAsyncRunResults(request, &results); + std::vector inputs = {wave_data_in.get(), dim_in.get()}; - if (results.empty()) { - std::cerr << "Warning: Could not read " - "output for corr_id " - << corr_id << std::endl; - } else { - if (print_results_) { - std::string text; - FAIL_IF_ERR(results["TEXT"]->GetRawAtCursor(0, &text), - "unable to get TEXT output"); - std::lock_guard lk(stdout_m_); - std::cout << "CORR_ID " << corr_id << "\t\t" << text << std::endl; - } + std::vector outputs; + std::shared_ptr raw_lattice, text; + outputs.reserve(2); + if (end_of_sequence) { + nic::InferRequestedOutput* raw_lattice_ptr; + FAIL_IF_ERR( + nic::InferRequestedOutput::Create(&raw_lattice_ptr, "RAW_LATTICE"), + "unable to get 'RAW_LATTICE'"); + raw_lattice.reset(raw_lattice_ptr); + outputs.push_back(raw_lattice.get()); - std::string lattice_bytes; - FAIL_IF_ERR(results["RAW_LATTICE"]->GetRawAtCursor(0, &lattice_bytes), - "unable to get RAW_LATTICE output"); - { - std::lock_guard lk(results_m_); - results_.insert({corr_id, {std::move(lattice_bytes), elapsed}}); - } - } - n_in_flight_.fetch_sub(1, std::memory_order_relaxed); + // Request the TEXT results only when required for printing + if (print_results_) { + nic::InferRequestedOutput* text_ptr; + FAIL_IF_ERR( + nic::InferRequestedOutput::Create(&text_ptr, ctm_ ? "CTM" : "TEXT"), + "unable to get 'TEXT' or 'CTM'"); + text.reset(text_ptr); + outputs.push_back(text.get()); } - }), + } else if (print_partial_results_) { + nic::InferRequestedOutput* text_ptr; + FAIL_IF_ERR(nic::InferRequestedOutput::Create(&text_ptr, "TEXT"), + "unable to get 'TEXT'"); + text.reset(text_ptr); + outputs.push_back(text.get()); + } + + total_audio_ += (static_cast(nsamples) / samp_freq_); + + if (start_of_sequence) { + n_in_flight_.fetch_add(1, std::memory_order_consume); + } + + // Record the timestamp when the last chunk was made available. + if (end_of_sequence) { + std::lock_guard lk(start_timestamps_m_); + start_timestamps_[corr_id] = gettime_monotonic(); + } + + TritonClient* client = &clients_[corr_id % nclients_]; + // nic::InferenceServerGrpcClient& triton_client = *client->triton_client; + FAIL_IF_ERR(client->triton_client->AsyncStreamInfer(options, inputs, outputs), "unable to run model"); } -void TRTISASRClient::WaitForCallbacks() { - int n; - while ((n = n_in_flight_.load(std::memory_order_consume))) { +void TritonASRClient::WaitForCallbacks() { + while (n_in_flight_.load(std::memory_order_consume)) { usleep(1000); } } -void TRTISASRClient::PrintStats(bool print_latency_stats) { +void TritonASRClient::PrintStats(bool print_latency_stats, + bool print_throughput) { double now = gettime_monotonic(); double diff = now - started_at_; double rtf = total_audio_ / diff; - std::cout << "Throughput:\t" << rtf << " RTFX" << std::endl; + if (print_throughput) + std::cout << "Throughput:\t" << rtf << " RTFX" << std::endl; std::vector latencies; { std::lock_guard lk(results_m_); @@ -176,20 +240,33 @@ void TRTISASRClient::PrintStats(bool print_latency_stats) { } } -TRTISASRClient::TRTISASRClient(const std::string& url, - const std::string& model_name, - const int ncontextes, bool print_results) +TritonASRClient::TritonASRClient(const std::string& url, + const std::string& model_name, + const int nclients, bool print_results, + bool print_partial_results, bool ctm, + float samp_freq) : url_(url), model_name_(model_name), - ncontextes_(ncontextes), - print_results_(print_results) { - ncontextes_ = std::max(ncontextes_, 1); - for (int i = 0; i < ncontextes_; ++i) CreateClientContext(); + nclients_(nclients), + print_results_(print_results), + print_partial_results_(print_partial_results), + ctm_(ctm), + samp_freq_(samp_freq) { + nclients_ = std::max(nclients_, 1); + for (int i = 0; i < nclients_; ++i) CreateClientContext(); - std::shared_ptr in_wave_data; - FAIL_IF_ERR(contextes_[0].trtis_context->GetInput("WAV_DATA", &in_wave_data), - "unable to get WAV_DATA"); - max_chunk_byte_size_ = in_wave_data->ByteSize(); + inference::ModelMetadataResponse model_metadata; + FAIL_IF_ERR( + clients_[0].triton_client->ModelMetadata(&model_metadata, model_name), + "unable to get model metadata"); + + for (const auto& in_tensor : model_metadata.inputs()) { + if (in_tensor.name().compare("WAV_DATA") == 0) { + samps_per_chunk_ = in_tensor.shape()[1]; + } + } + + max_chunk_byte_size_ = samps_per_chunk_ * sizeof(float); chunk_buf_.resize(max_chunk_byte_size_); shape_ = {max_chunk_byte_size_}; n_in_flight_.store(0); @@ -197,20 +274,24 @@ TRTISASRClient::TRTISASRClient(const std::string& url, total_audio_ = 0; } -void TRTISASRClient::WriteLatticesToFile( +void TritonASRClient::WriteLatticesToFile( const std::string& clat_wspecifier, - const std::unordered_map& - corr_id_and_keys) { + const std::unordered_map& corr_id_and_keys) { kaldi::CompactLatticeWriter clat_writer; clat_writer.Open(clat_wspecifier); + std::unordered_map key_count; std::lock_guard lk(results_m_); for (auto& p : corr_id_and_keys) { - ni::CorrelationID corr_id = p.first; - const std::string& key = p.second; + uint64_t corr_id = p.first; + std::string key = p.second; + const auto iter = key_count[key]++; + if (iter > 0) { + key += std::to_string(iter); + } auto it = results_.find(corr_id); - if(it == results_.end()) { - std::cerr << "Cannot find lattice for corr_id " << corr_id << std::endl; - continue; + if (it == results_.end()) { + std::cerr << "Cannot find lattice for corr_id " << corr_id << std::endl; + continue; } const std::string& raw_lattice = it->second.raw_lattice; // We could in theory write directly the binary hold in raw_lattice (it is diff --git a/Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.h b/Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.h index ef8994b1..30cc9aba 100644 --- a/Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.h +++ b/Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.h @@ -1,4 +1,4 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include #include -#include #include +#include -#include "request_grpc.h" +#ifndef TRITON_KALDI_ASR_CLIENT_H_ +#define TRITON_KALDI_ASR_CLIENT_H_ -#ifndef TRTIS_KALDI_ASR_CLIENT_H_ -#define TRTIS_KALDI_ASR_CLIENT_H_ namespace ni = nvidia::inferenceserver; namespace nic = nvidia::inferenceserver::client; @@ -33,16 +34,16 @@ double inline gettime_monotonic() { return time; } -class TRTISASRClient { - struct ClientContext { - std::unique_ptr trtis_context; +class TritonASRClient { + struct TritonClient { + std::unique_ptr triton_client; }; std::string url_; std::string model_name_; - std::vector contextes_; - int ncontextes_; + std::vector clients_; + int nclients_; std::vector chunk_buf_; std::vector shape_; int max_chunk_byte_size_; @@ -50,26 +51,36 @@ class TRTISASRClient { double started_at_; double total_audio_; bool print_results_; + bool print_partial_results_; + bool ctm_; std::mutex stdout_m_; + int samps_per_chunk_; + float samp_freq_; struct Result { std::string raw_lattice; double latency; }; - std::unordered_map results_; + std::unordered_map start_timestamps_; + std::mutex start_timestamps_m_; + + std::unordered_map results_; std::mutex results_m_; public: + TritonASRClient(const std::string& url, const std::string& model_name, + const int ncontextes, bool print_results, + bool print_partial_results, bool ctm, float samp_freq); + void CreateClientContext(); void SendChunk(uint64_t corr_id, bool start_of_sequence, bool end_of_sequence, - float* chunk, int chunk_byte_size); + float* chunk, int chunk_byte_size, uint64_t index); void WaitForCallbacks(); - void PrintStats(bool print_latency_stats); - void WriteLatticesToFile(const std::string &clat_wspecifier, const std::unordered_map &corr_id_and_keys); - - TRTISASRClient(const std::string& url, const std::string& model_name, - const int ncontextes, bool print_results); + void PrintStats(bool print_latency_stats, bool print_throughput); + void WriteLatticesToFile( + const std::string& clat_wspecifier, + const std::unordered_map& corr_id_and_keys); }; -#endif // TRTIS_KALDI_ASR_CLIENT_H_ +#endif // TRITON_KALDI_ASR_CLIENT_H_ diff --git a/Kaldi/SpeechRecognition/kaldi-asr-client/kaldi_asr_parallel_client.cc b/Kaldi/SpeechRecognition/kaldi-asr-client/kaldi_asr_parallel_client.cc index d1c0ceb5..cbb247a1 100644 --- a/Kaldi/SpeechRecognition/kaldi-asr-client/kaldi_asr_parallel_client.cc +++ b/Kaldi/SpeechRecognition/kaldi-asr-client/kaldi_asr_parallel_client.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ // limitations under the License. #include + #include +#include #include #include + #include "asr_client_imp.h" #include "feat/wave-reader.h" // to read the wav.scp #include "util/kaldi-table.h" @@ -41,6 +44,8 @@ void Usage(char** argv, const std::string& msg = std::string()) { "online clients." << std::endl; std::cerr << "\t-p : Print text outputs" << std::endl; + std::cerr << "\t-b : Print partial (best path) text outputs" << std::endl; + //std::cerr << "\t-t : Print text with timings (CTM)" << std::endl; std::cerr << std::endl; exit(1); @@ -49,7 +54,7 @@ void Usage(char** argv, const std::string& msg = std::string()) { int main(int argc, char** argv) { std::cout << "\n"; std::cout << "==================================================\n" - << "============= TRTIS Kaldi ASR Client =============\n" + << "============= Triton Kaldi ASR Client ============\n" << "==================================================\n" << std::endl; @@ -65,14 +70,15 @@ int main(int argc, char** argv) { size_t nchannels = 1000; int niterations = 5; bool verbose = false; - float samp_freq = 16000; - int ncontextes = 10; + int nclients = 10; bool online = false; bool print_results = false; + bool print_partial_results = false; + bool ctm = false; // Parse commandline... int opt; - while ((opt = getopt(argc, argv, "va:u:i:c:ophl:")) != -1) { + while ((opt = getopt(argc, argv, "va:u:i:c:otpbhl:")) != -1) { switch (opt) { case 'i': niterations = std::atoi(optarg); @@ -95,6 +101,13 @@ int main(int argc, char** argv) { case 'p': print_results = true; break; + case 'b': + print_partial_results = true; + break; + case 't': + ctm = true; + print_results = true; + break; case 'l': chunk_length = std::atoi(optarg); break; @@ -116,19 +129,21 @@ int main(int argc, char** argv) { std::cout << "Server URL\t\t\t: " << url << std::endl; std::cout << "Print text outputs\t\t: " << (print_results ? "Yes" : "No") << std::endl; + std::cout << "Print partial text outputs\t: " + << (print_partial_results ? "Yes" : "No") << std::endl; std::cout << "Online - Realtime I/O\t\t: " << (online ? "Yes" : "No") << std::endl; std::cout << std::endl; - float chunk_seconds = (double)chunk_length / samp_freq; - // need to read wav files - SequentialTableReader wav_reader(wav_rspecifier); - + float samp_freq = 0; double total_audio = 0; // pre-loading data // we don't want to measure I/O std::vector> all_wav; std::vector all_wav_keys; + + // need to read wav files + SequentialTableReader wav_reader(wav_rspecifier); { std::cout << "Loading eval dataset..." << std::flush; for (; !wav_reader.Done(); wav_reader.Next()) { @@ -138,90 +153,119 @@ int main(int argc, char** argv) { all_wav.push_back(wave_data); all_wav_keys.push_back(utt); total_audio += wave_data->Duration(); + samp_freq = wave_data->SampFreq(); } std::cout << "done" << std::endl; } + if (all_wav.empty()) { + std::cerr << "Empty dataset"; + exit(0); + } + + std::cout << "Loaded dataset with " << all_wav.size() + << " utterances, frequency " << samp_freq << "hz, total audio " + << total_audio << " seconds" << std::endl; + + double chunk_seconds = (double)chunk_length / samp_freq; + double seconds_per_sample = chunk_seconds / chunk_length; + struct Stream { std::shared_ptr wav; - ni::CorrelationID corr_id; + uint64_t corr_id; int offset; - float send_next_chunk_at; - std::atomic received_output; + double send_next_chunk_at; - Stream(const std::shared_ptr& _wav, ni::CorrelationID _corr_id) - : wav(_wav), corr_id(_corr_id), offset(0), received_output(true) { - send_next_chunk_at = gettime_monotonic(); + Stream(const std::shared_ptr& _wav, uint64_t _corr_id, + double _send_next_chunk_at) + : wav(_wav), + corr_id(_corr_id), + offset(0), + send_next_chunk_at(_send_next_chunk_at) {} + + bool operator<(const Stream& other) const { + return (send_next_chunk_at > other.send_next_chunk_at); } }; + std::cout << "Opening GRPC contextes..." << std::flush; - TRTISASRClient asr_client(url, model_name, ncontextes, print_results); + std::unordered_map corr_id_and_keys; + TritonASRClient asr_client(url, model_name, nclients, print_results, + print_partial_results, ctm, samp_freq); std::cout << "done" << std::endl; - std::cout << "Streaming utterances..." << std::flush; - std::vector> curr_tasks, next_tasks; - curr_tasks.reserve(nchannels); - next_tasks.reserve(nchannels); + std::cout << "Streaming utterances..." << std::endl; + std::priority_queue streams; size_t all_wav_i = 0; size_t all_wav_max = all_wav.size() * niterations; + uint64_t index = 0; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0.0, 1.0); + bool add_random_offset = true; while (true) { - while (curr_tasks.size() < nchannels && all_wav_i < all_wav_max) { + while (streams.size() < nchannels && all_wav_i < all_wav_max) { // Creating new tasks uint64_t corr_id = static_cast(all_wav_i) + 1; + auto all_wav_i_modulo = all_wav_i % (all_wav.size()); + double stream_will_start_at = gettime_monotonic(); + if (add_random_offset) stream_will_start_at += dis(gen); + double first_chunk_available_at = + stream_will_start_at + + std::min(static_cast(all_wav[all_wav_i_modulo]->Duration()), + chunk_seconds); - std::unique_ptr ptr( - new Stream(all_wav[all_wav_i % (all_wav.size())], corr_id)); - curr_tasks.emplace_back(std::move(ptr)); + corr_id_and_keys.insert({corr_id, all_wav_keys[all_wav_i_modulo]}); + streams.emplace(all_wav[all_wav_i_modulo], corr_id, + first_chunk_available_at); ++all_wav_i; } // If still empty, done - if (curr_tasks.empty()) break; + if (streams.empty()) break; - for (size_t itask = 0; itask < curr_tasks.size(); ++itask) { - Stream& task = *(curr_tasks[itask]); - - SubVector data(task.wav->Data(), 0); - int32 samp_offset = task.offset; - int32 nsamp = data.Dim(); - int32 samp_remaining = nsamp - samp_offset; - int32 num_samp = - chunk_length < samp_remaining ? chunk_length : samp_remaining; - bool is_last_chunk = (chunk_length >= samp_remaining); - SubVector wave_part(data, samp_offset, num_samp); - bool is_first_chunk = (samp_offset == 0); - if (online) { - double now = gettime_monotonic(); - double wait_for = task.send_next_chunk_at - now; - if (wait_for > 0) usleep(wait_for * 1e6); - } - asr_client.SendChunk(task.corr_id, is_first_chunk, is_last_chunk, - wave_part.Data(), wave_part.SizeInBytes()); - task.send_next_chunk_at += chunk_seconds; - if (verbose) - std::cout << "Sending correlation_id=" << task.corr_id - << " chunk offset=" << num_samp << std::endl; - - task.offset += num_samp; - if (!is_last_chunk) next_tasks.push_back(std::move(curr_tasks[itask])); + auto task = streams.top(); + streams.pop(); + if (online) { + double wait_for = task.send_next_chunk_at - gettime_monotonic(); + if (wait_for > 0) usleep(wait_for * 1e6); } + add_random_offset = false; + + SubVector data(task.wav->Data(), 0); + int32 samp_offset = task.offset; + int32 nsamp = data.Dim(); + int32 samp_remaining = nsamp - samp_offset; + int32 num_samp = + chunk_length < samp_remaining ? chunk_length : samp_remaining; + bool is_last_chunk = (chunk_length >= samp_remaining); + SubVector wave_part(data, samp_offset, num_samp); + bool is_first_chunk = (samp_offset == 0); + asr_client.SendChunk(task.corr_id, is_first_chunk, is_last_chunk, + wave_part.Data(), wave_part.SizeInBytes(), index++); + + if (verbose) + std::cout << "Sending correlation_id=" << task.corr_id + << " chunk offset=" << num_samp << std::endl; + + task.offset += num_samp; + int32 next_chunk_num_samp = std::min(nsamp - task.offset, chunk_length); + task.send_next_chunk_at += next_chunk_num_samp * seconds_per_sample; + + if (!is_last_chunk) streams.push(task); - curr_tasks.swap(next_tasks); - next_tasks.clear(); // Showing activity if necessary - if (!print_results && !verbose) std::cout << "." << std::flush; + if (!print_results && !print_partial_results && !verbose && + index % nchannels == 0) + std::cout << "." << std::flush; } std::cout << "done" << std::endl; std::cout << "Waiting for all results..." << std::flush; asr_client.WaitForCallbacks(); std::cout << "done" << std::endl; - asr_client.PrintStats(online); - - std::unordered_map corr_id_and_keys; - for (size_t all_wav_i = 0; all_wav_i < all_wav.size(); ++all_wav_i) { - ni::CorrelationID corr_id = static_cast(all_wav_i) + 1; - corr_id_and_keys.insert({corr_id, all_wav_keys[all_wav_i]}); - } + asr_client.PrintStats( + online, + !online); // Print latency if online, do not print throughput if online asr_client.WriteLatticesToFile("ark:|gzip -c > /data/results/lat.cuda-asr.gz", corr_id_and_keys); diff --git a/Kaldi/SpeechRecognition/model-repo/kaldi_online/config.pbtxt b/Kaldi/SpeechRecognition/model-repo/kaldi_online/config.pbtxt index 96628bc0..a7a0eeba 100644 --- a/Kaldi/SpeechRecognition/model-repo/kaldi_online/config.pbtxt +++ b/Kaldi/SpeechRecognition/model-repo/kaldi_online/config.pbtxt @@ -1,80 +1,103 @@ name: "kaldi_online" -platform: "custom" -default_model_filename: "libkaldi-trtisbackend.so" -max_batch_size: 2200 -parameters: { -key: "mfcc_filename" -value: { -string_value:"/data/models/LibriSpeech/conf/mfcc.conf" -} -} -parameters: { -key: "ivector_filename" -value: { -string_value:"/data/models/LibriSpeech/conf/ivector_extractor.conf" -} -} -parameters: { -key: "nnet3_rxfilename" -value: { -string_value: "/data/models/LibriSpeech/final.mdl" -} -} -parameters: { -key: "fst_rxfilename" -value: { -string_value: "/data/models/LibriSpeech/HCLG.fst" -} -} -parameters: { -key: "word_syms_rxfilename" -value: { -string_value:"/data/models/LibriSpeech/words.txt" -} -} -parameters: [{ -key: "beam" -value: { -string_value:"10" -} -},{ -key: "num_worker_threads" -value: { -string_value:"40" -} -}, -{ -key: "max_execution_batch_size" -value: { -string_value:"400" -} -}] -parameters: { -key: "lattice_beam" -value: { -string_value:"7" -} -} -parameters: { -key: "max_active" -value: { -string_value:"10000" -} -} -parameters: { -key: "frame_subsampling_factor" -value: { -string_value:"3" -} -} -parameters: { -key: "acoustic_scale" -value: { -string_value:"1.0" -} +backend: "kaldi" +max_batch_size: 600 +model_transaction_policy { + decoupled: True } +parameters [ + { + key: "config_filename" + value { + string_value: "/data/models/LibriSpeech/conf/online.conf" + } + }, + { + key: "nnet3_rxfilename" + value { + string_value: "/data/models/LibriSpeech/final.mdl" + } + }, + { + key: "fst_rxfilename" + value { + string_value: "/data/models/LibriSpeech/HCLG.fst" + } + }, + { + key: "word_syms_rxfilename" + value { + string_value: "/data/models/LibriSpeech/words.txt" + } + }, + { + key: "lattice_postprocessor_rxfilename" + value { + string_value: "" + } + }, + { + key: "use_tensor_cores" + value { + string_value: "1" + } + }, + { + key: "main_q_capacity" + value { + string_value: "30000" + } + }, + { + key: "aux_q_capacity" + value { + string_value: "400000" + } + }, + { + key: "beam" + value { + string_value: "10" + } + }, + { + key: "num_worker_threads" + value { + string_value: "40" + } + }, + { + key: "num_channels" + value { + string_value: "4000" + } + }, + { + key: "lattice_beam" + value { + string_value: "7" + } + }, + { + key: "max_active" + value { + string_value: "10000" + } + }, + { + key: "frame_subsampling_factor" + value { + string_value: "3" + } + }, + { + key: "acoustic_scale" + value { + string_value: "1.0" + } + } +] sequence_batching { -max_sequence_idle_microseconds:5000000 + max_sequence_idle_microseconds: 5000000 control_input [ { name: "START" @@ -108,16 +131,16 @@ max_sequence_idle_microseconds:5000000 control [ { kind: CONTROL_SEQUENCE_CORRID - data_type: TYPE_UINT64 + data_type: TYPE_UINT64 } ] } ] -oldest { -max_candidate_sequences:2200 -preferred_batch_size:[400] -max_queue_delay_microseconds:1000 -} + oldest { + max_candidate_sequences: 4000 + preferred_batch_size: [ 600 ] + max_queue_delay_microseconds: 1000 + } }, input [ @@ -142,13 +165,17 @@ output [ name: "TEXT" data_type: TYPE_STRING dims: [ 1 ] + }, + { + name: "CTM" + data_type: TYPE_STRING + dims: [ 1 ] } + ] instance_group [ { - count: 2 + count: 1 kind: KIND_GPU } ] - - diff --git a/Kaldi/SpeechRecognition/scripts/compute_wer.sh b/Kaldi/SpeechRecognition/scripts/compute_wer.sh index 0f57c3d0..57d0792f 100755 --- a/Kaldi/SpeechRecognition/scripts/compute_wer.sh +++ b/Kaldi/SpeechRecognition/scripts/compute_wer.sh @@ -15,10 +15,10 @@ oovtok=$(cat $result_path/words.txt | grep "" | awk '{print $2}') # convert lattice to transcript /opt/kaldi/src/latbin/lattice-best-path \ "ark:gunzip -c $result_path/lat.cuda-asr.gz |"\ - "ark,t:|gzip -c > $result_path/trans.cuda-asr.gz" 2> /dev/null + "ark,t:$result_path/trans.cuda-asr" 2> /dev/null # calculate wer /opt/kaldi/src/bin/compute-wer --mode=present \ "ark:$result_path/text_ints" \ - "ark:gunzip -c $result_path/trans.cuda-asr.gz |" 2>&1 + "ark:$result_path/trans.cuda-asr" 2> /dev/null diff --git a/Kaldi/SpeechRecognition/scripts/docker/build.sh b/Kaldi/SpeechRecognition/scripts/docker/build.sh index ed60abe4..c3016b6b 100755 --- a/Kaldi/SpeechRecognition/scripts/docker/build.sh +++ b/Kaldi/SpeechRecognition/scripts/docker/build.sh @@ -13,5 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -docker build . -f Dockerfile --rm -t trtis_kaldi_server -docker build . -f Dockerfile.client --rm -t trtis_kaldi_client +set -eu + +# Use development branch of Kaldi for latest feature support +docker build . -f Dockerfile \ + --rm -t triton_kaldi_server +docker build . -f Dockerfile.client \ + --rm -t triton_kaldi_client diff --git a/Kaldi/SpeechRecognition/scripts/docker/launch_client.sh b/Kaldi/SpeechRecognition/scripts/docker/launch_client.sh index 375e636a..c5812955 100755 --- a/Kaldi/SpeechRecognition/scripts/docker/launch_client.sh +++ b/Kaldi/SpeechRecognition/scripts/docker/launch_client.sh @@ -19,4 +19,5 @@ docker run --rm -it \ --ulimit memlock=-1 \ --ulimit stack=67108864 \ -v $PWD/data:/data \ - trtis_kaldi_client /workspace/scripts/docker/run_client.sh $@ + --entrypoint /bin/bash \ + triton_kaldi_client /workspace/scripts/docker/run_client.sh $@ diff --git a/Kaldi/SpeechRecognition/scripts/docker/launch_download.sh b/Kaldi/SpeechRecognition/scripts/docker/launch_download.sh index 6cb67a43..006de82d 100755 --- a/Kaldi/SpeechRecognition/scripts/docker/launch_download.sh +++ b/Kaldi/SpeechRecognition/scripts/docker/launch_download.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Start TRTIS server container for download - need some kaldi tools -nvidia-docker run --rm \ +# Start Triton server container for download - need some kaldi tools +docker run --rm \ --shm-size=1g \ --ulimit memlock=-1 \ --ulimit stack=67108864 \ -v $PWD/data:/mnt/data \ - trtis_kaldi_server /workspace/scripts/docker/dataset_setup.sh $(id -u) $(id -g) + triton_kaldi_server /workspace/scripts/docker/dataset_setup.sh $(id -u) $(id -g) # --user $(id -u):$(id -g) \ diff --git a/Kaldi/SpeechRecognition/scripts/docker/launch_server.sh b/Kaldi/SpeechRecognition/scripts/docker/launch_server.sh index c75fec74..39a4389b 100755 --- a/Kaldi/SpeechRecognition/scripts/docker/launch_server.sh +++ b/Kaldi/SpeechRecognition/scripts/docker/launch_server.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,8 +15,9 @@ NV_VISIBLE_DEVICES=${NVIDIA_VISIBLE_DEVICES:-"0"} -# Start TRTIS server -nvidia-docker run --rm -it \ +# Start Triton server +docker run --rm -it \ + --gpus $NV_VISIBLE_DEVICES \ --shm-size=1g \ --ulimit memlock=-1 \ --ulimit stack=67108864 \ @@ -24,7 +25,6 @@ nvidia-docker run --rm -it \ -p8001:8001 \ -p8002:8002 \ --name trt_server_asr \ - -e NVIDIA_VISIBLE_DEVICES=$NV_VISIBLE_DEVICES \ -v $PWD/data:/data \ -v $PWD/model-repo:/mnt/model-repo \ - trtis_kaldi_server trtserver --model-repo=/workspace/model-repo/ + triton_kaldi_server diff --git a/Kaldi/SpeechRecognition/scripts/docker/run_client.sh b/Kaldi/SpeechRecognition/scripts/docker/run_client.sh index 381f8424..dc6e0aa8 100755 --- a/Kaldi/SpeechRecognition/scripts/docker/run_client.sh +++ b/Kaldi/SpeechRecognition/scripts/docker/run_client.sh @@ -7,7 +7,7 @@ then rm -rf $results_dir fi mkdir $results_dir -install/bin/kaldi_asr_parallel_client $@ +kaldi-asr-parallel-client $@ echo "Computing WER..." /workspace/scripts/compute_wer.sh rm -rf $results_dir diff --git a/Kaldi/SpeechRecognition/scripts/nvidia_kaldi_trtis_entrypoint.sh b/Kaldi/SpeechRecognition/scripts/nvidia_kaldi_triton_entrypoint.sh similarity index 94% rename from Kaldi/SpeechRecognition/scripts/nvidia_kaldi_trtis_entrypoint.sh rename to Kaldi/SpeechRecognition/scripts/nvidia_kaldi_triton_entrypoint.sh index 7c157bbb..ace6b92c 100755 --- a/Kaldi/SpeechRecognition/scripts/nvidia_kaldi_trtis_entrypoint.sh +++ b/Kaldi/SpeechRecognition/scripts/nvidia_kaldi_triton_entrypoint.sh @@ -19,4 +19,4 @@ if [ -d "/mnt/model-repo/kaldi_online" ]; then ln -s /mnt/model-repo/kaldi_online/config.pbtxt /workspace/model-repo/kaldi_online/ fi -/opt/tensorrtserver/nvidia_entrypoint.sh $@ +/opt/tritonserver/nvidia_entrypoint.sh $@ diff --git a/Kaldi/SpeechRecognition/trtis-kaldi-backend/libkaldi_online.ldscript b/Kaldi/SpeechRecognition/scripts/run_inference_all_a100.sh old mode 100644 new mode 100755 similarity index 55% rename from Kaldi/SpeechRecognition/trtis-kaldi-backend/libkaldi_online.ldscript rename to Kaldi/SpeechRecognition/scripts/run_inference_all_a100.sh index 18e2c008..92a0e783 --- a/Kaldi/SpeechRecognition/trtis-kaldi-backend/libkaldi_online.ldscript +++ b/Kaldi/SpeechRecognition/scripts/run_inference_all_a100.sh @@ -1,4 +1,6 @@ -# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. +#!/bin/bash + +# Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,11 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -{ - global: - CustomErrorString; - CustomExecute; - CustomFinalize; - CustomInitialize; - local: *; -}; +set -e + +if [[ "$(docker ps | grep triton_kaldi_server | wc -l)" == "0" ]]; then + printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n" + exit 1 +fi + +printf "\nOffline benchmarks:\n" + +scripts/docker/launch_client.sh -i 5 -c 4000 + +printf "\nOnline benchmarks:\n" + +scripts/docker/launch_client.sh -i 10 -c 2000 -o diff --git a/Kaldi/SpeechRecognition/scripts/run_inference_all_t4.sh b/Kaldi/SpeechRecognition/scripts/run_inference_all_t4.sh index 83180a01..96767f53 100755 --- a/Kaldi/SpeechRecognition/scripts/run_inference_all_t4.sh +++ b/Kaldi/SpeechRecognition/scripts/run_inference_all_t4.sh @@ -15,7 +15,7 @@ set -e -if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then +if [[ "$(docker ps | grep triton_kaldi_server | wc -l)" == "0" ]]; then printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n" exit 1 fi @@ -26,5 +26,5 @@ scripts/docker/launch_client.sh -i 5 -c 1000 printf "\nOnline benchmarks:\n" -scripts/docker/launch_client.sh -i 10 -c 700 -o +scripts/docker/launch_client.sh -i 10 -c 600 -o scripts/docker/launch_client.sh -i 10 -c 400 -o diff --git a/Kaldi/SpeechRecognition/scripts/run_inference_all_v100.sh b/Kaldi/SpeechRecognition/scripts/run_inference_all_v100.sh index cf7357fe..b0b6d456 100755 --- a/Kaldi/SpeechRecognition/scripts/run_inference_all_v100.sh +++ b/Kaldi/SpeechRecognition/scripts/run_inference_all_v100.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,7 @@ set -e -if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then +if [[ "$(docker ps | grep triton_kaldi_server | wc -l)" == "0" ]]; then printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n" exit 1 fi @@ -26,6 +26,5 @@ scripts/docker/launch_client.sh -i 5 -c 2000 printf "\nOnline benchmarks:\n" -scripts/docker/launch_client.sh -i 10 -c 1500 -o +scripts/docker/launch_client.sh -i 10 -c 2000 -o scripts/docker/launch_client.sh -i 10 -c 1000 -o -scripts/docker/launch_client.sh -i 5 -c 800 -o diff --git a/Kaldi/SpeechRecognition/trtis-kaldi-backend/Makefile b/Kaldi/SpeechRecognition/trtis-kaldi-backend/Makefile deleted file mode 100644 index 3159d325..00000000 --- a/Kaldi/SpeechRecognition/trtis-kaldi-backend/Makefile +++ /dev/null @@ -1,5 +0,0 @@ -.PHONY: all -all: kaldibackend - -kaldibackend: kaldi-backend.cc kaldi-backend-utils.cc - g++ -fpic -shared -std=c++11 -o libkaldi-trtisbackend.so kaldi-backend.cc kaldi-backend-utils.cc -Icustom-backend-sdk/include custom-backend-sdk/lib/libcustombackend.a -I/opt/kaldi/src/ -I/usr/local/cuda/include -I/opt/kaldi/tools/openfst/include/ -L/opt/kaldi/src/lib/ -lkaldi-cudadecoder diff --git a/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend-utils.cc b/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend-utils.cc deleted file mode 100644 index c9c86a41..00000000 --- a/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend-utils.cc +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "kaldi-backend-utils.h" - -namespace nvidia { -namespace inferenceserver { -namespace custom { -namespace kaldi_cbe { - -int GetInputTensor(CustomGetNextInputFn_t input_fn, void* input_context, - const char* name, const size_t expected_byte_size, - std::vector* input, const void** out) { - input->clear(); // reset buffer - // The values for an input tensor are not necessarily in one - // contiguous chunk, so we might copy the chunks into 'input' vector. - // If possible, we use the data in place - uint64_t total_content_byte_size = 0; - while (true) { - const void* content; - uint64_t content_byte_size = expected_byte_size - total_content_byte_size; - if (!input_fn(input_context, name, &content, &content_byte_size)) { - return kInputContents; - } - - // If 'content' returns nullptr we have all the input. - if (content == nullptr) break; - - // If the total amount of content received exceeds what we expect - // then something is wrong. - total_content_byte_size += content_byte_size; - if (total_content_byte_size > expected_byte_size) - return kInputSize; - - if (content_byte_size == expected_byte_size) { - *out = content; - return kSuccess; - } - - input->insert(input->end(), static_cast(content), - static_cast(content) + content_byte_size); - } - - // Make sure we end up with exactly the amount of input we expect. - if (total_content_byte_size != expected_byte_size) { - return kInputSize; - } - *out = &input[0]; - - return kSuccess; -} - -void LatticeToString(fst::SymbolTable& word_syms, - const kaldi::CompactLattice& dlat, std::string* out_str) { - kaldi::CompactLattice best_path_clat; - kaldi::CompactLatticeShortestPath(dlat, &best_path_clat); - - kaldi::Lattice best_path_lat; - fst::ConvertLattice(best_path_clat, &best_path_lat); - - std::vector alignment; - std::vector words; - kaldi::LatticeWeight weight; - fst::GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); - std::ostringstream oss; - for (size_t i = 0; i < words.size(); i++) { - std::string s = word_syms.Find(words[i]); - if (s == "") std::cerr << "Word-id " << words[i] << " not in symbol table."; - oss << s << " "; - } - *out_str = std::move(oss.str()); -} - -int ReadParameter(const ModelConfig& model_config_, const std::string& key, - std::string* param) { - auto it = model_config_.parameters().find(key); - if (it == model_config_.parameters().end()) { - std::cerr << "Parameter \"" << key - << "\" missing from config file. Exiting." << std::endl; - return kInvalidModelConfig; - } - *param = it->second.string_value(); - return kSuccess; -} - -int ReadParameter(const ModelConfig& model_config_, const std::string& key, - int* param) { - std::string tmp; - int err = ReadParameter(model_config_, key, &tmp); - *param = std::stoi(tmp); - return err; -} - -int ReadParameter(const ModelConfig& model_config_, const std::string& key, - float* param) { - std::string tmp; - int err = ReadParameter(model_config_, key, &tmp); - *param = std::stof(tmp); - return err; -} - -const char* CustomErrorString(int errcode) { - switch (errcode) { - case kSuccess: - return "success"; - case kInvalidModelConfig: - return "invalid model configuration"; - case kGpuNotSupported: - return "execution on GPU not supported"; - case kSequenceBatcher: - return "model configuration must configure sequence batcher"; - case kModelControl: - return "'START' and 'READY' must be configured as the control inputs"; - case kInputOutput: - return "model must have four inputs and one output with shape [-1]"; - case kInputName: - return "names for input don't exist"; - case kOutputName: - return "model output must be named 'OUTPUT'"; - case kInputOutputDataType: - return "model inputs or outputs data_type cannot be specified"; - case kInputContents: - return "unable to get input tensor values"; - case kInputSize: - return "unexpected size for input tensor"; - case kOutputBuffer: - return "unable to get buffer for output tensor values"; - case kBatchTooBig: - return "unable to execute batch larger than max-batch-size"; - case kTimesteps: - return "unable to execute more than 1 timestep at a time"; - case kChunkTooBig: - return "a chunk cannot contain more samples than the WAV_DATA dimension"; - default: - break; - } - - return "unknown error"; -} - -} // kaldi -} // custom -} // inferenceserver -} // nvidia diff --git a/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend-utils.h b/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend-utils.h deleted file mode 100644 index 613e2713..00000000 --- a/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend-utils.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "lat/lattice-functions.h" -#include "src/core/model_config.h" -#include "src/core/model_config.pb.h" -#include "src/custom/sdk/custom_instance.h" - -namespace nvidia { -namespace inferenceserver { -namespace custom { -namespace kaldi_cbe { - -enum ErrorCodes { - kSuccess, - kUnknown, - kInvalidModelConfig, - kGpuNotSupported, - kSequenceBatcher, - kModelControl, - kInputOutput, - kInputName, - kOutputName, - kInputOutputDataType, - kInputContents, - kInputSize, - kOutputBuffer, - kBatchTooBig, - kTimesteps, - kChunkTooBig -}; - -int GetInputTensor(CustomGetNextInputFn_t input_fn, void* input_context, - const char* name, const size_t expected_byte_size, - std::vector* input, const void** out); - -void LatticeToString(fst::SymbolTable& word_syms, - const kaldi::CompactLattice& dlat, std::string* out_str); - -int ReadParameter(const ModelConfig& model_config_, const std::string& key, - std::string* param); - -int ReadParameter(const ModelConfig& model_config_, const std::string& key, - int* param); -int ReadParameter(const ModelConfig& model_config_, const std::string& key, - float* param); - -const char* CustomErrorString(int errcode); - -} // kaldi -} // custom -} // inferenceserver -} // nvidia diff --git a/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.cc b/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.cc deleted file mode 100644 index ed855519..00000000 --- a/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.cc +++ /dev/null @@ -1,423 +0,0 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "kaldi-backend.h" -#include "kaldi-backend-utils.h" - -namespace nvidia { -namespace inferenceserver { -namespace custom { -namespace kaldi_cbe { - -Context::Context(const std::string& instance_name, - const ModelConfig& model_config, const int gpu_device) - : instance_name_(instance_name), - model_config_(model_config), - gpu_device_(gpu_device), - num_channels_( - model_config_ - .max_batch_size()), // diff in def between kaldi and trtis - int32_byte_size_(GetDataTypeByteSize(TYPE_INT32)), - int64_byte_size_(GetDataTypeByteSize(TYPE_INT64)) {} - -Context::~Context() { delete word_syms_; } - -int Context::ReadModelParameters() { - // Reading config - float beam, lattice_beam; - int max_active; - int frame_subsampling_factor; - float acoustic_scale; - int num_worker_threads; - int err = - ReadParameter(model_config_, "mfcc_filename", - &batched_decoder_config_.feature_opts.mfcc_config) || - ReadParameter( - model_config_, "ivector_filename", - &batched_decoder_config_.feature_opts.ivector_extraction_config) || - ReadParameter(model_config_, "beam", &beam) || - ReadParameter(model_config_, "lattice_beam", &lattice_beam) || - ReadParameter(model_config_, "max_active", &max_active) || - ReadParameter(model_config_, "frame_subsampling_factor", - &frame_subsampling_factor) || - ReadParameter(model_config_, "acoustic_scale", &acoustic_scale) || - ReadParameter(model_config_, "nnet3_rxfilename", &nnet3_rxfilename_) || - ReadParameter(model_config_, "fst_rxfilename", &fst_rxfilename_) || - ReadParameter(model_config_, "word_syms_rxfilename", - &word_syms_rxfilename_) || - ReadParameter(model_config_, "num_worker_threads", &num_worker_threads) || - ReadParameter(model_config_, "max_execution_batch_size", - &max_batch_size_); - if (err) return err; - max_batch_size_ = std::max(max_batch_size_, 1); - num_channels_ = std::max(num_channels_, 1); - - // Sanity checks - if (beam <= 0) return kInvalidModelConfig; - if (lattice_beam <= 0) return kInvalidModelConfig; - if (max_active <= 0) return kInvalidModelConfig; - if (acoustic_scale <= 0) return kInvalidModelConfig; - if (num_worker_threads <= 0) return kInvalidModelConfig; - if (num_channels_ <= max_batch_size_) return kInvalidModelConfig; - - batched_decoder_config_.compute_opts.frame_subsampling_factor = - frame_subsampling_factor; - batched_decoder_config_.compute_opts.acoustic_scale = acoustic_scale; - batched_decoder_config_.decoder_opts.default_beam = beam; - batched_decoder_config_.decoder_opts.lattice_beam = lattice_beam; - batched_decoder_config_.decoder_opts.max_active = max_active; - batched_decoder_config_.num_worker_threads = num_worker_threads; - batched_decoder_config_.max_batch_size = max_batch_size_; - batched_decoder_config_.num_channels = num_channels_; - - auto feature_config = batched_decoder_config_.feature_opts; - kaldi::OnlineNnet2FeaturePipelineInfo feature_info(feature_config); - sample_freq_ = feature_info.mfcc_opts.frame_opts.samp_freq; - BaseFloat frame_shift = feature_info.FrameShiftInSeconds(); - seconds_per_chunk_ = chunk_num_samps_ / sample_freq_; - - int samp_per_frame = static_cast(sample_freq_ * frame_shift); - float n_input_framesf = chunk_num_samps_ / samp_per_frame; - bool is_integer = (n_input_framesf == std::floor(n_input_framesf)); - if (!is_integer) { - std::cerr << "WAVE_DATA dim must be a multiple fo samples per frame (" - << samp_per_frame << ")" << std::endl; - return kInvalidModelConfig; - } - int n_input_frames = static_cast(std::floor(n_input_framesf)); - batched_decoder_config_.compute_opts.frames_per_chunk = n_input_frames; - - return kSuccess; -} - -int Context::InitializeKaldiPipeline() { - batch_corr_ids_.reserve(max_batch_size_); - batch_wave_samples_.reserve(max_batch_size_); - batch_is_first_chunk_.reserve(max_batch_size_); - batch_is_last_chunk_.reserve(max_batch_size_); - wave_byte_buffers_.resize(max_batch_size_); - output_shape_ = {1, 1}; - kaldi::CuDevice::Instantiate() - .SelectAndInitializeGpuIdWithExistingCudaContext(gpu_device_); - kaldi::CuDevice::Instantiate().AllowMultithreading(); - - // Loading models - { - bool binary; - kaldi::Input ki(nnet3_rxfilename_, &binary); - trans_model_.Read(ki.Stream(), binary); - am_nnet_.Read(ki.Stream(), binary); - - kaldi::nnet3::SetBatchnormTestMode(true, &(am_nnet_.GetNnet())); - kaldi::nnet3::SetDropoutTestMode(true, &(am_nnet_.GetNnet())); - kaldi::nnet3::CollapseModel(kaldi::nnet3::CollapseModelConfig(), - &(am_nnet_.GetNnet())); - } - fst::Fst* decode_fst = fst::ReadFstKaldiGeneric(fst_rxfilename_); - cuda_pipeline_.reset( - new kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline( - batched_decoder_config_, *decode_fst, am_nnet_, trans_model_)); - delete decode_fst; - - // Loading word syms for text output - if (word_syms_rxfilename_ != "") { - if (!(word_syms_ = fst::SymbolTable::ReadText(word_syms_rxfilename_))) { - std::cerr << "Could not read symbol table from file " - << word_syms_rxfilename_; - return kInvalidModelConfig; - } - } - chunk_num_samps_ = cuda_pipeline_->GetNSampsPerChunk(); - chunk_num_bytes_ = chunk_num_samps_ * sizeof(BaseFloat); - return kSuccess; -} - -int Context::Init() { - return InputOutputSanityCheck() || ReadModelParameters() || - InitializeKaldiPipeline(); -} - -bool Context::CheckPayloadError(const CustomPayload& payload) { - int err = payload.error_code; - if (err) std::cerr << "Error: " << CustomErrorString(err) << std::endl; - return (err != 0); -} - -int Context::Execute(const uint32_t payload_cnt, CustomPayload* payloads, - CustomGetNextInputFn_t input_fn, - CustomGetOutputFn_t output_fn) { - // kaldi::Timer timer; - if (payload_cnt > num_channels_) return kBatchTooBig; - // Each payload is a chunk for one sequence - // Currently using dynamic batcher, not sequence batcher - for (uint32_t pidx = 0; pidx < payload_cnt; ++pidx) { - if (batch_corr_ids_.size() == max_batch_size_) FlushBatch(); - - CustomPayload& payload = payloads[pidx]; - if (payload.batch_size != 1) payload.error_code = kTimesteps; - if (CheckPayloadError(payload)) continue; - - // Get input tensors - int32_t start, dim, end, ready; - CorrelationID corr_id; - const BaseFloat* wave_buffer; - payload.error_code = GetSequenceInput( - input_fn, payload.input_context, &corr_id, &start, &ready, &dim, &end, - &wave_buffer, &wave_byte_buffers_[pidx]); - if (CheckPayloadError(payload)) continue; - if (!ready) continue; - if (dim > chunk_num_samps_) payload.error_code = kChunkTooBig; - if (CheckPayloadError(payload)) continue; - - kaldi::SubVector wave_part(wave_buffer, dim); - // Initialize corr_id if first chunk - if (start) { - if (!cuda_pipeline_->TryInitCorrID(corr_id)) { - printf("ERR %i \n", __LINE__); - // TODO add error code - continue; - } - } - // Add to batch - batch_corr_ids_.push_back(corr_id); - batch_wave_samples_.push_back(wave_part); - batch_is_first_chunk_.push_back(start); - batch_is_last_chunk_.push_back(end); - - if (end) { - // If last chunk, set the callback for that seq - cuda_pipeline_->SetLatticeCallback( - corr_id, [this, &output_fn, &payloads, pidx, - corr_id](kaldi::CompactLattice& clat) { - SetOutputs(clat, output_fn, payloads[pidx]); - }); - } - } - FlushBatch(); - cuda_pipeline_->WaitForLatticeCallbacks(); - return kSuccess; -} - -int Context::FlushBatch() { - if (!batch_corr_ids_.empty()) { - cuda_pipeline_->DecodeBatch(batch_corr_ids_, batch_wave_samples_, - batch_is_first_chunk_, batch_is_last_chunk_); - batch_corr_ids_.clear(); - batch_wave_samples_.clear(); - batch_is_first_chunk_.clear(); - batch_is_last_chunk_.clear(); - } -} - -int Context::InputOutputSanityCheck() { - if (!model_config_.has_sequence_batching()) { - return kSequenceBatcher; - } - - auto& batcher = model_config_.sequence_batching(); - if (batcher.control_input_size() != 4) { - return kModelControl; - } - - std::set control_input_names; - for (int i = 0; i < 4; ++i) - control_input_names.insert(batcher.control_input(i).name()); - if (!(control_input_names.erase("START") && - control_input_names.erase("END") && - control_input_names.erase("CORRID") && - control_input_names.erase("READY"))) { - return kModelControl; - } - - if (model_config_.input_size() != 2) { - return kInputOutput; - } - if ((model_config_.input(0).dims().size() != 1) || - (model_config_.input(0).dims(0) <= 0) || - (model_config_.input(1).dims().size() != 1) || - (model_config_.input(1).dims(0) != 1)) { - return kInputOutput; - } - chunk_num_samps_ = model_config_.input(0).dims(0); - chunk_num_bytes_ = chunk_num_samps_ * sizeof(float); - - if ((model_config_.input(0).data_type() != DataType::TYPE_FP32) || - (model_config_.input(1).data_type() != DataType::TYPE_INT32)) { - return kInputOutputDataType; - } - if ((model_config_.input(0).name() != "WAV_DATA") || - (model_config_.input(1).name() != "WAV_DATA_DIM")) { - return kInputName; - } - - if (model_config_.output_size() != 2) return kInputOutput; - - for (int ioutput = 0; ioutput < 2; ++ioutput) { - if ((model_config_.output(ioutput).dims().size() != 1) || - (model_config_.output(ioutput).dims(0) != 1)) { - return kInputOutput; - } - if (model_config_.output(ioutput).data_type() != DataType::TYPE_STRING) { - return kInputOutputDataType; - } - } - - if (model_config_.output(0).name() != "RAW_LATTICE") return kOutputName; - if (model_config_.output(1).name() != "TEXT") return kOutputName; - - return kSuccess; -} - -int Context::GetSequenceInput(CustomGetNextInputFn_t& input_fn, - void* input_context, CorrelationID* corr_id, - int32_t* start, int32_t* ready, int32_t* dim, - int32_t* end, const BaseFloat** wave_buffer, - std::vector* input_buffer) { - int err; - //&input_buffer[0]: char pointer -> alias with any types - // wave_data[0] will holds the struct - - // Get start of sequence tensor - const void* out; - err = GetInputTensor(input_fn, input_context, "WAV_DATA_DIM", - int32_byte_size_, &byte_buffer_, &out); - if (err != kSuccess) return err; - *dim = *reinterpret_cast(out); - - err = GetInputTensor(input_fn, input_context, "END", int32_byte_size_, - &byte_buffer_, &out); - if (err != kSuccess) return err; - *end = *reinterpret_cast(out); - - err = GetInputTensor(input_fn, input_context, "START", int32_byte_size_, - &byte_buffer_, &out); - if (err != kSuccess) return err; - *start = *reinterpret_cast(out); - - err = GetInputTensor(input_fn, input_context, "READY", int32_byte_size_, - &byte_buffer_, &out); - if (err != kSuccess) return err; - *ready = *reinterpret_cast(out); - - err = GetInputTensor(input_fn, input_context, "CORRID", int64_byte_size_, - &byte_buffer_, &out); - if (err != kSuccess) return err; - *corr_id = *reinterpret_cast(out); - - // Get pointer to speech tensor - err = GetInputTensor(input_fn, input_context, "WAV_DATA", chunk_num_bytes_, - input_buffer, &out); - if (err != kSuccess) return err; - *wave_buffer = reinterpret_cast(out); - - return kSuccess; -} - -int Context::SetOutputs(kaldi::CompactLattice& clat, - CustomGetOutputFn_t output_fn, CustomPayload payload) { - int status = kSuccess; - if (payload.error_code != kSuccess) return payload.error_code; - for (int ioutput = 0; ioutput < payload.output_cnt; ++ioutput) { - const char* output_name = payload.required_output_names[ioutput]; - if (!strcmp(output_name, "RAW_LATTICE")) { - std::ostringstream oss; - kaldi::WriteCompactLattice(oss, true, clat); - status = SetOutputByName(output_name, oss.str(), output_fn, payload); - if(status != kSuccess) return status; - } else if (!strcmp(output_name, "TEXT")) { - std::string output; - LatticeToString(*word_syms_, clat, &output); - status = SetOutputByName(output_name, output, output_fn, payload); - if(status != kSuccess) return status; - } - } - - return status; -} - -int Context::SetOutputByName(const char* output_name, - const std::string& out_bytes, - CustomGetOutputFn_t output_fn, - CustomPayload payload) { - uint32_t byte_size_with_size_int = out_bytes.size() + sizeof(int32); - void* obuffer; // output buffer - if (!output_fn(payload.output_context, output_name, output_shape_.size(), - &output_shape_[0], byte_size_with_size_int, &obuffer)) { - payload.error_code = kOutputBuffer; - return payload.error_code; - } - if (obuffer == nullptr) return kOutputBuffer; - - int32* buffer_as_int = reinterpret_cast(obuffer); - buffer_as_int[0] = out_bytes.size(); - memcpy(&buffer_as_int[1], out_bytes.data(), out_bytes.size()); - - return kSuccess; -} - -///////////// - -extern "C" { - -int CustomInitialize(const CustomInitializeData* data, void** custom_context) { - // Convert the serialized model config to a ModelConfig object. - ModelConfig model_config; - if (!model_config.ParseFromString(std::string( - data->serialized_model_config, data->serialized_model_config_size))) { - return kInvalidModelConfig; - } - - // Create the context and validate that the model configuration is - // something that we can handle. - Context* context = new Context(std::string(data->instance_name), model_config, - data->gpu_device_id); - int err = context->Init(); - if (err != kSuccess) { - return err; - } - - *custom_context = static_cast(context); - - return kSuccess; -} - -int CustomFinalize(void* custom_context) { - if (custom_context != nullptr) { - Context* context = static_cast(custom_context); - delete context; - } - - return kSuccess; -} - -const char* CustomErrorString(void* custom_context, int errcode) { - return CustomErrorString(errcode); -} - -int CustomExecute(void* custom_context, const uint32_t payload_cnt, - CustomPayload* payloads, CustomGetNextInputFn_t input_fn, - CustomGetOutputFn_t output_fn) { - if (custom_context == nullptr) { - return kUnknown; - } - - Context* context = static_cast(custom_context); - return context->Execute(payload_cnt, payloads, input_fn, output_fn); -} - -} // extern "C" -} // namespace kaldi_cbe -} // namespace custom -} // namespace inferenceserver -} // namespace nvidia diff --git a/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.h b/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.h deleted file mode 100644 index 19c40bd0..00000000 --- a/Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.h +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#define HAVE_CUDA 1 // Loading Kaldi headers with GPU - -#include -#include -#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h" -#include "fstext/fstext-lib.h" -#include "lat/kaldi-lattice.h" -#include "lat/lattice-functions.h" -#include "nnet3/am-nnet-simple.h" -#include "nnet3/nnet-utils.h" -#include "util/kaldi-thread.h" - -#include "src/core/model_config.h" -#include "src/core/model_config.pb.h" -#include "src/custom/sdk/custom_instance.h" - -using kaldi::BaseFloat; - -namespace nvidia { -namespace inferenceserver { -namespace custom { -namespace kaldi_cbe { - -// Context object. All state must be kept in this object. -class Context { - public: - Context(const std::string& instance_name, const ModelConfig& config, - const int gpu_device); - virtual ~Context(); - - // Initialize the context. Validate that the model configuration, - // etc. is something that we can handle. - int Init(); - - // Perform custom execution on the payloads. - int Execute(const uint32_t payload_cnt, CustomPayload* payloads, - CustomGetNextInputFn_t input_fn, CustomGetOutputFn_t output_fn); - - private: - // init kaldi pipeline - int InitializeKaldiPipeline(); - int InputOutputSanityCheck(); - int ReadModelParameters(); - int GetSequenceInput(CustomGetNextInputFn_t& input_fn, void* input_context, - CorrelationID* corr_id, int32_t* start, int32_t* ready, - int32_t* dim, int32_t* end, - const kaldi::BaseFloat** wave_buffer, - std::vector* input_buffer); - - int SetOutputs(kaldi::CompactLattice& clat, - CustomGetOutputFn_t output_fn, CustomPayload payload); - - int SetOutputByName(const char* output_name, - const std::string& out_bytes, - CustomGetOutputFn_t output_fn, - CustomPayload payload); - - bool CheckPayloadError(const CustomPayload& payload); - int FlushBatch(); - - // The name of this instance of the backend. - const std::string instance_name_; - - // The model configuration. - const ModelConfig model_config_; - - // The GPU device ID to execute on or CUSTOM_NO_GPU_DEVICE if should - // execute on CPU. - const int gpu_device_; - - // Models paths - std::string nnet3_rxfilename_, fst_rxfilename_; - std::string word_syms_rxfilename_; - - // batch_size - int max_batch_size_; - int num_channels_; - int num_worker_threads_; - std::vector batch_corr_ids_; - std::vector> batch_wave_samples_; - std::vector batch_is_first_chunk_; - std::vector batch_is_last_chunk_; - - BaseFloat sample_freq_, seconds_per_chunk_; - int chunk_num_bytes_, chunk_num_samps_; - - // feature_config includes configuration for the iVector adaptation, - // as well as the basic features. - kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipelineConfig - batched_decoder_config_; - std::unique_ptr - cuda_pipeline_; - // Maintain the state of some shared objects - kaldi::TransitionModel trans_model_; - - kaldi::nnet3::AmNnetSimple am_nnet_; - fst::SymbolTable* word_syms_; - - const uint64_t int32_byte_size_; - const uint64_t int64_byte_size_; - std::vector output_shape_; - - std::vector byte_buffer_; - std::vector> wave_byte_buffers_; -}; - -} // namespace kaldi_cbe -} // namespace custom -} // namespace inferenceserver -} // namespace nvidia