{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "Gwt7z7qdmTbW" }, "outputs": [], "source": [ "# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "i4NKCp2VmTbn" }, "source": [ "\n", "\n", "# Kaldi TRTIS Inference Offline Demo" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "fW0OKDzvmTbt" }, "source": [ "## Overview\n", "\n", "\n", "This repository provides a wrapper around the online GPU-accelerated ASR pipeline from the paper [GPU-Accelerated Viterbi Exact Lattice Decoder for Batched Online and Offline Speech Recognition](https://arxiv.org/abs/1910.10032). That work includes a high-performance implementation of a GPU HMM Decoder, a low-latency Neural Net driver, fast Feature Extraction for preprocessing, and new ASR pipelines tailored for GPUs. These different modules have been integrated into the Kaldi ASR framework.\n", "\n", "This repository contains a TensorRT Inference Server custom backend for the Kaldi ASR framework. This custom backend calls the high-performance online GPU pipeline from the Kaldi ASR framework. This TensorRT Inference Server integration provides ease-of-use to Kaldi ASR inference: gRPC streaming server, dynamic sequence batching, and multi-instances support. A client connects to the gRPC server, streams audio by sending chunks to the server, and gets back the inferred text as an answer. More information about the TensorRT Inference Server can be found [here](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/). \n", "\n", "\n", "\n", "### Learning objectives\n", "\n", "This notebook demonstrates the steps for carrying out inferencing with the Kaldi TRTIS backend server using a Python gRPC client in an offline context, that is, we will stream pre-recorded .wav files to the inference server and receive the results back.\n", "\n", "## Content\n", "1. [Pre-requisite](#1)\n", "1. [Setup](#2)\n", "1. [Audio helper classes](#3)\n", "1. [Inference](#4)\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "aDFrE4eqmTbv" }, "source": [ "\n", "## 1. Pre-requisite\n", "\n", "### 1.1 Docker containers\n", "Follow the steps in [README](README.md) to build Kaldi server and client containers.\n", "\n", "### 1.2 Hardware\n", "This notebook can be executed on any CUDA-enabled NVIDIA GPU, although for efficient mixed precision inference, a [Tensor Core NVIDIA GPU](https://www.nvidia.com/en-us/data-center/tensorcore/) is desired (Volta, Turing or newer architectures). 
" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "k7RLEcKhmTb0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Thu Mar 5 00:20:50 2020 \r\n", "+-----------------------------------------------------------------------------+\r\n", "| NVIDIA-SMI 440.48.02 Driver Version: 440.48.02 CUDA Version: 10.2 |\r\n", "|-------------------------------+----------------------+----------------------+\r\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n", "|===============================+======================+======================|\r\n", "| 0 Quadro GV100 Off | 00000000:05:00.0 Off | Off |\r\n", "| 33% 46C P2 37W / 250W | 17706MiB / 32506MiB | 1% Default |\r\n", "+-------------------------------+----------------------+----------------------+\r\n", " \r\n", "+-----------------------------------------------------------------------------+\r\n", "| Processes: GPU Memory |\r\n", "| GPU PID Type Process name Usage |\r\n", "|=============================================================================|\r\n", "+-----------------------------------------------------------------------------+\r\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "HqSUGePjmTb9" }, "source": [ "### 1.3 Data download and preprocessing\n", "\n", "The script `scripts/docker/launch_download.sh` will download the LibriSpeech test dataset along with Kaldi ASR models." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "S2PR7weWmTcK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BOOKS.TXT LICENSE.TXT SPEAKERS.TXT test-other\r\n", "CHAPTERS.TXT README.TXT test-clean\r\n" ] } ], "source": [ "!ls /Kaldi/data/data/LibriSpeech " ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EQAIszkxmTcT" }, "source": [ "Within the docker container, the final data and model directory should look like:\n", "\n", "```\n", "/Kaldi/data\n", " data\n", " datasets \n", " models\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 2 Setup \n", "### Import libraries and parameters" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import argparse\n", "import numpy as np\n", "import os\n", "from builtins import range\n", "from functools import partial\n", "import soundfile\n", "import pyaudio as pa\n", "import soundfile\n", "import subprocess\n", "\n", "import grpc\n", "from tensorrtserver.api import api_pb2\n", "from tensorrtserver.api import grpc_service_pb2\n", "from tensorrtserver.api import grpc_service_pb2_grpc\n", "import tensorrtserver.api.model_config_pb2 as model_config" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "parser = argparse.ArgumentParser()\n", "parser.add_argument('-f', '--file', help='Path for input file. 
First line should contain number of lines to search in')\n", "\n", "parser.add_argument('-v', '--verbose', action=\"store_true\", required=False, default=False,\n", " help='Enable verbose output')\n", "parser.add_argument('-a', '--async', dest=\"async_set\", action=\"store_true\", required=False,\n", " default=False, help='Use asynchronous inference API')\n", "parser.add_argument('--streaming', action=\"store_true\", required=False, default=False,\n", " help='Use streaming inference API')\n", "parser.add_argument('-m', '--model-name', type=str, required=False, default='kaldi_online' ,\n", " help='Name of model')\n", "parser.add_argument('-x', '--model-version', type=int, required=False, default=1,\n", " help='Version of model. Default is to use latest version.')\n", "parser.add_argument('-b', '--batch-size', type=int, required=False, default=1,\n", " help='Batch size. Default is 1.')\n", "parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8001',\n", " help='Inference server URL. Default is localhost:8001.')\n", "FLAGS = parser.parse_args()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Checking server status\n", "\n", "We first query the status of the server. The target model is 'kaldi_online'. A successful deployment of the server should result in output similar to the below.\n", "\n", "```\n", "request_status {\n", " code: SUCCESS\n", " server_id: \"inference:0\"\n", " request_id: 17514\n", "}\n", "server_status {\n", " id: \"inference:0\"\n", " version: \"1.9.0\"\n", " uptime_ns: 14179155408971\n", " model_status {\n", " key: \"kaldi_online\"\n", "...\n", "```" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "request_status {\n", " code: SUCCESS\n", " server_id: \"inference:0\"\n", " request_id: 5864\n", "}\n", "server_status {\n", " id: \"inference:0\"\n", " version: \"1.9.0\"\n", " uptime_ns: 3596863561558\n", " model_status {\n", " key: \"kaldi_online\"\n", " value {\n", " config {\n", " name: \"kaldi_online\"\n", " platform: \"custom\"\n", " version_policy {\n", " latest {\n", " num_versions: 1\n", " }\n", " }\n", " max_batch_size: 2200\n", " input {\n", " name: \"WAV_DATA\"\n", " data_type: TYPE_FP32\n", " dims: 8160\n", " }\n", " input {\n", " name: \"WAV_DATA_DIM\"\n", " data_type: TYPE_INT32\n", " dims: 1\n", " }\n", " output {\n", " name: \"TEXT\"\n", " data_type: TYPE_STRING\n", " dims: 1\n", " }\n", " instance_group {\n", " name: \"kaldi_online_0\"\n", " count: 2\n", " gpus: 0\n", " kind: KIND_GPU\n", " }\n", " default_model_filename: \"libkaldi-trtisbackend.so\"\n", " sequence_batching {\n", " max_sequence_idle_microseconds: 5000000\n", " control_input {\n", " name: \"START\"\n", " control {\n", " int32_false_true: 0\n", " int32_false_true: 1\n", " }\n", " }\n", " control_input {\n", " name: \"READY\"\n", " control {\n", " kind: CONTROL_SEQUENCE_READY\n", " int32_false_true: 0\n", " int32_false_true: 1\n", " }\n", " }\n", " control_input {\n", " name: \"END\"\n", " control {\n", " kind: CONTROL_SEQUENCE_END\n", " int32_false_true: 0\n", " int32_false_true: 1\n", " }\n", " }\n", " control_input {\n", " name: \"CORRID\"\n", " control {\n", " kind: CONTROL_SEQUENCE_CORRID\n", " data_type: TYPE_UINT64\n", " }\n", " }\n", " oldest {\n", " max_candidate_sequences: 2200\n", " preferred_batch_size: 256\n", " preferred_batch_size: 512\n", " max_queue_delay_microseconds: 1000\n", " }\n", " }\n", " parameters {\n", " key: \"acoustic_scale\"\n", " value 
{\n", " string_value: \"1.0\"\n", " }\n", " }\n", " parameters {\n", " key: \"beam\"\n", " value {\n", " string_value: \"10\"\n", " }\n", " }\n", " parameters {\n", " key: \"frame_subsampling_factor\"\n", " value {\n", " string_value: \"3\"\n", " }\n", " }\n", " parameters {\n", " key: \"fst_rxfilename\"\n", " value {\n", " string_value: \"/data/models/LibriSpeech/HCLG.fst\"\n", " }\n", " }\n", " parameters {\n", " key: \"ivector_filename\"\n", " value {\n", " string_value: \"/data/models/LibriSpeech/conf/ivector_extractor.conf\"\n", " }\n", " }\n", " parameters {\n", " key: \"lattice_beam\"\n", " value {\n", " string_value: \"7\"\n", " }\n", " }\n", " parameters {\n", " key: \"max_active\"\n", " value {\n", " string_value: \"10000\"\n", " }\n", " }\n", " parameters {\n", " key: \"max_execution_batch_size\"\n", " value {\n", " string_value: \"512\"\n", " }\n", " }\n", " parameters {\n", " key: \"mfcc_filename\"\n", " value {\n", " string_value: \"/data/models/LibriSpeech/conf/mfcc.conf\"\n", " }\n", " }\n", " parameters {\n", " key: \"nnet3_rxfilename\"\n", " value {\n", " string_value: \"/data/models/LibriSpeech/final.mdl\"\n", " }\n", " }\n", " parameters {\n", " key: \"num_worker_threads\"\n", " value {\n", " string_value: \"40\"\n", " }\n", " }\n", " parameters {\n", " key: \"word_syms_rxfilename\"\n", " value {\n", " string_value: \"/data/models/LibriSpeech/words.txt\"\n", " }\n", " }\n", " }\n", " version_status {\n", " key: 1\n", " value {\n", " ready_state: MODEL_READY\n", " infer_stats {\n", " key: 1\n", " value {\n", " success {\n", " count: 6215\n", " total_time_ns: 207922970400\n", " }\n", " compute {\n", " count: 6215\n", " total_time_ns: 201127749263\n", " }\n", " queue {\n", " count: 6215\n", " total_time_ns: 6651959011\n", " }\n", " }\n", " }\n", " model_execution_count: 6215\n", " model_inference_count: 6215\n", " ready_state_reason {\n", " }\n", " last_inference_timestamp_milliseconds: 13273904861447035456\n", " }\n", " }\n", " }\n", " }\n", " ready_state: SERVER_READY\n", "}\n", "\n" ] } ], "source": [ "# Create gRPC stub for communicating with the server\n", "channel = grpc.insecure_channel(FLAGS.url)\n", "grpc_stub = grpc_service_pb2_grpc.GRPCServiceStub(channel)\n", "\n", "# Prepare request for Status gRPC\n", "request = grpc_service_pb2.StatusRequest(model_name=FLAGS.model_name)\n", "# Call and receive response from Status gRPC\n", "response = grpc_stub.Status(request)\n", "\n", "print(response)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RL8d9IwzmTcV" }, "source": [ "\n", "## 3. Audio helper classes" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "o6wayGf1mTcX" }, "source": [ "Next, we define some helper classes for pre-processing audio from files. The below AudioSegment class reads audio data from .wav files and converts the sampling rate to that required by the Kaldi ASR model, which is 16000Hz by default.\n", "\n", "Note: For historical reasons, Kaldi expects waveforms in the range (2^15-1)x[-1, 1], not the usual default DSP range [-1, 1]. Therefore, we scale the audio signal by a factor of (2^15-1)." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "WAV_SCALE_FACTOR = 2**15-1\n", "\n", "class AudioSegment(object):\n", " \"\"\"Monaural audio segment abstraction.\n", " :param samples: Audio samples [num_samples x num_channels].\n", " :type samples: ndarray.float32\n", " :param sample_rate: Audio sample rate.\n", " :type sample_rate: int\n", " :raises TypeError: If the sample data type is not float or int.\n", " \"\"\"\n", "\n", " def __init__(self, samples, sample_rate, target_sr=16000, trim=False,\n", " trim_db=60):\n", " \"\"\"Create audio segment from samples.\n", " Samples are convert float32 internally, with int scaled to [-1, 1].\n", " \"\"\"\n", " samples = self._convert_samples_to_float32(samples)\n", " if target_sr is not None and target_sr != sample_rate:\n", " samples = librosa.core.resample(samples, sample_rate, target_sr)\n", " sample_rate = target_sr\n", " if trim:\n", " samples, _ = librosa.effects.trim(samples, trim_db)\n", " self._samples = samples\n", " self._sample_rate = sample_rate\n", " if self._samples.ndim >= 2:\n", " self._samples = np.mean(self._samples, 1)\n", "\n", " @staticmethod\n", " def _convert_samples_to_float32(samples):\n", " \"\"\"Convert sample type to float32.\n", " Audio sample type is usually integer or float-point.\n", " Integers will be scaled to [-1, 1] in float32.\n", " \"\"\"\n", " float32_samples = samples.astype('float32')\n", " if samples.dtype in np.sctypes['int']:\n", " bits = np.iinfo(samples.dtype).bits\n", " float32_samples *= (1. / ((2 ** (bits - 1)) - 1))\n", " elif samples.dtype in np.sctypes['float']:\n", " pass\n", " else:\n", " raise TypeError(\"Unsupported sample type: %s.\" % samples.dtype)\n", " return WAV_SCALE_FACTOR * float32_samples\n", "\n", " @classmethod\n", " def from_file(cls, filename, target_sr=16000, offset=0, duration=0,\n", " min_duration=0, trim=False):\n", " \"\"\"\n", " Load a file supported by librosa and return as an AudioSegment.\n", " :param filename: path of file to load\n", " :param target_sr: the desired sample rate\n", " :param int_values: if true, load samples as 32-bit integers\n", " :param offset: offset in seconds when loading audio\n", " :param duration: duration in seconds when loading audio\n", " :return: numpy array of samples\n", " \"\"\"\n", " with sf.SoundFile(filename, 'r') as f:\n", " dtype_options = {'PCM_16': 'int16', 'PCM_32': 'int32', 'FLOAT': 'float32'}\n", " dtype_file = f.subtype\n", " if dtype_file in dtype_options:\n", " dtype = dtype_options[dtype_file]\n", " else:\n", " dtype = 'float32'\n", " sample_rate = f.samplerate\n", " if offset > 0:\n", " f.seek(int(offset * sample_rate))\n", " if duration > 0:\n", " samples = f.read(int(duration * sample_rate), dtype=dtype)\n", " else:\n", " samples = f.read(dtype=dtype)\n", "\n", " num_zero_pad = int(target_sr * min_duration - samples.shape[0])\n", " if num_zero_pad > 0:\n", " samples = np.pad(samples, [0, num_zero_pad], mode='constant')\n", "\n", " samples = samples.transpose()\n", " return cls(samples, sample_rate, target_sr=target_sr, trim=trim)\n", "\n", " @property\n", " def samples(self):\n", " return self._samples.copy()\n", "\n", " @property\n", " def sample_rate(self):\n", " return self._sample_rate" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# read audio chunk from a file\n", "def get_audio_chunk_from_soundfile(sf, chunk_size):\n", "\n", " dtype_options = {'PCM_16': 'int16', 'PCM_32': 'int32', 'FLOAT': 'float32'}\n", " dtype_file = 
sf.subtype\n", " if dtype_file in dtype_options:\n", " dtype = dtype_options[dtype_file]\n", " else:\n", " dtype = 'float32'\n", " audio_signal = sf.read(chunk_size, dtype=dtype)\n", " end = False\n", " # pad to chunk size\n", " if len(audio_signal) < chunk_size:\n", " end = True\n", " audio_signal = np.pad(audio_signal, (0, chunk_size-len(\n", " audio_signal)), mode='constant')\n", " return audio_signal, end\n", "\n", "\n", "# generator that returns chunks of audio data from file\n", "def audio_generator_from_file(input_filename, target_sr, chunk_duration):\n", "\n", " sf = soundfile.SoundFile(input_filename, 'rb')\n", " chunk_size = int(chunk_duration*sf.samplerate)\n", " start = True\n", " end = False\n", "\n", " while not end:\n", "\n", " audio_signal, end = get_audio_chunk_from_soundfile(sf, chunk_size)\n", "\n", " audio_segment = AudioSegment(audio_signal, sf.samplerate, target_sr)\n", "\n", " yield audio_segment.samples, target_sr, start, end\n", " start = False\n", "\n", " sf.close()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loading data\n", "\n", "We load and play a wave file from the LibriSpeech data set. The LibriSpeech data set is organized into directories and subdirectories containing speech segments and transcripts for different speakers. " ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1089-134686-0000.flac 1089-134686-0013.flac 1089-134686-0026.flac\r\n", "1089-134686-0000.wav 1089-134686-0013.wav 1089-134686-0026.wav\r\n", "1089-134686-0001.flac 1089-134686-0014.flac 1089-134686-0027.flac\r\n", "1089-134686-0001.wav 1089-134686-0014.wav 1089-134686-0027.wav\r\n", "1089-134686-0002.flac 1089-134686-0015.flac 1089-134686-0028.flac\r\n", "1089-134686-0002.wav 1089-134686-0015.wav 1089-134686-0028.wav\r\n", "1089-134686-0003.flac 1089-134686-0016.flac 1089-134686-0029.flac\r\n", "1089-134686-0003.wav 1089-134686-0016.wav 1089-134686-0029.wav\r\n", "1089-134686-0004.flac 1089-134686-0017.flac 1089-134686-0030.flac\r\n", "1089-134686-0004.wav 1089-134686-0017.wav 1089-134686-0030.wav\r\n", "1089-134686-0005.flac 1089-134686-0018.flac 1089-134686-0031.flac\r\n", "1089-134686-0005.wav 1089-134686-0018.wav 1089-134686-0031.wav\r\n", "1089-134686-0006.flac 1089-134686-0019.flac 1089-134686-0032.flac\r\n", "1089-134686-0006.wav 1089-134686-0019.wav 1089-134686-0032.wav\r\n", "1089-134686-0007.flac 1089-134686-0020.flac 1089-134686-0033.flac\r\n", "1089-134686-0007.wav 1089-134686-0020.wav 1089-134686-0033.wav\r\n", "1089-134686-0008.flac 1089-134686-0021.flac 1089-134686-0034.flac\r\n", "1089-134686-0008.wav 1089-134686-0021.wav 1089-134686-0034.wav\r\n", "1089-134686-0009.flac 1089-134686-0022.flac 1089-134686-0035.flac\r\n", "1089-134686-0009.wav 1089-134686-0022.wav 1089-134686-0035.wav\r\n", "1089-134686-0010.flac 1089-134686-0023.flac 1089-134686-0036.flac\r\n", "1089-134686-0010.wav 1089-134686-0023.wav 1089-134686-0036.wav\r\n", "1089-134686-0011.flac 1089-134686-0024.flac 1089-134686-0037.flac\r\n", "1089-134686-0011.wav 1089-134686-0024.wav 1089-134686-0037.wav\r\n", "1089-134686-0012.flac 1089-134686-0025.flac 1089-134686.trans.txt\r\n", "1089-134686-0012.wav 1089-134686-0025.wav\r\n" ] } ], "source": [ "!ls /Kaldi/data/data/LibriSpeech/test-clean/1089/134686/" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "he hoped there would be stew for dinner turnips and carrots and 
bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce\n", "\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "DIR = \"1089\"\n", "SUBDIR = \"134686\"\n", "FILE_ID = \"0000\"\n", "FILE_NAME = \"/Kaldi/data/data/LibriSpeech/test-clean/%s/%s/%s-%s-%s.wav\"%(DIR, SUBDIR, DIR, SUBDIR, FILE_ID)\n", "TRANSRIPTION_FILE = \"/Kaldi/data/data/LibriSpeech/test-clean/%s/%s/%s-%s.trans.txt\"%(DIR, SUBDIR, DIR, SUBDIR)\n", "\n", "batcmd = \"cat %s|grep %s\"%(TRANSRIPTION_FILE, FILE_ID)\n", "res = subprocess.check_output(batcmd, shell=True)\n", "transcript = \" \".join(res.decode('utf-8').split(\" \")[1:]).lower()\n", "\n", "print(transcript)\n", "import IPython.display as ipd\n", "ipd.Audio(FILE_NAME)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define a helper function which generates pairs of file paths and transcripts from a LibriSpeech data directory." ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def libri_generator(DATASET_ROOT):\n", "    for subdir in os.listdir(DATASET_ROOT):\n", "        SUBDIR = os.path.join(DATASET_ROOT, subdir)\n", "        if os.path.isdir(os.path.join(DATASET_ROOT, subdir)):\n", "            for subsubdir in os.listdir(SUBDIR):\n", "                SUBSUBDIR = os.path.join(SUBDIR, subsubdir)\n", "                transcription_file = os.path.join(DATASET_ROOT, SUBDIR, SUBSUBDIR, \"%s-%s.trans.txt\"%(subdir, subsubdir))\n", "                transcriptions = {}\n", "                with open(transcription_file, \"r\") as f:\n", "                    for line in f:\n", "                        fields = line.split(\" \")\n", "                        transcriptions[fields[0]] = \" \".join(fields[1:])\n", "                for file_key, transcript in transcriptions.items():\n", "                    file_path = os.path.join(DATASET_ROOT, SUBDIR, SUBSUBDIR, file_key+'.wav')\n", "                    yield file_path, transcript.strip().lower()" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "datagen = libri_generator(\"/Kaldi/data/data/LibriSpeech/test-clean/\")\n", "filepath, transcript = next(datagen)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "natty harmon tried the kitchen pump secretly several times during the evening for the water had to run up hill all the way from the well to the kitchen sink and he believed this to be a continual miracle that might give out at any moment\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(transcript)\n", "import IPython.display as ipd\n", "ipd.Audio(filepath)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 4. Inference\n", "\n", "We first create an inference context object that connects to the Kaldi TRTIS server via a gRPC connection.\n", "\n", "The server expects chunks of audio, each containing up to input.WAV_DATA.dims samples (default: 8160). At the default 16000 Hz sampling rate, this corresponds to 510 ms of audio per chunk. The last chunk of a sequence may be smaller than this maximum size."
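, "\n", "\n", "As a quick sanity check on these numbers (a minimal sketch, assuming the default WAV_DATA dimension of 8160 reported in the server status above and the 16000 Hz sampling rate expected by the model):\n", "\n", "```python\n", "SAMPLE_RATE = 16000    # Hz, expected by the Kaldi model\n", "CHUNK_SAMPLES = 8160   # input.WAV_DATA.dims from the server status\n", "\n", "print(CHUNK_SAMPLES / SAMPLE_RATE)   # 0.51 -> seconds of audio per full chunk\n", "```"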
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from tensorrtserver.api import *\n", "protocol = ProtocolType.from_str(\"grpc\")\n", "\n", "CORRELATION_ID = 1101\n", "ctx = InferContext(FLAGS.url, protocol, FLAGS.model_name, FLAGS.model_version,\n", " correlation_id=CORRELATION_ID, verbose=True,\n", " streaming=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we take chunks from a selected audio file (each 510ms in duration, containing 8160 samples) and stream them sequentially to the Kaldi server. The server processes each chunk as soon as it is received. The transcription result is returned upon receiving the final chunk.\n", "\n", "In the following, you can either specify a specific .wav file or take a file via the Kaldi dataset generator." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Chunk 0 (8160,) 16000 True False\n", "Chunk 1 (8160,) 16000 False False\n", "Chunk 2 (8160,) 16000 False False\n", "Chunk 3 (8160,) 16000 False False\n", "Chunk 4 (8160,) 16000 False False\n", "Chunk 5 (8160,) 16000 False False\n", "Chunk 6 (8160,) 16000 False False\n", "Chunk 7 (8160,) 16000 False False\n", "Chunk 8 (8160,) 16000 False False\n", "Chunk 9 (8160,) 16000 False False\n", "Chunk 10 (8160,) 16000 False False\n", "Chunk 11 (8160,) 16000 False False\n", "Chunk 12 (8160,) 16000 False False\n", "Chunk 13 (8160,) 16000 False False\n", "Chunk 14 (8160,) 16000 False False\n", "Chunk 15 (8160,) 16000 False False\n", "Chunk 16 (8160,) 16000 False False\n", "Chunk 17 (8160,) 16000 False False\n", "Chunk 18 (8160,) 16000 False False\n", "Chunk 19 (8160,) 16000 False False\n", "Chunk 20 (8160,) 16000 False True\n", "ASR output: he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flower fatten sauce \n", "Ground truth: he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce\n", "\n" ] } ], "source": [ "## Take a specific file\n", "DIR = \"1089\"\n", "SUBDIR = \"134686\"\n", "FILE_ID = \"0000\"\n", "FILE_NAME = \"/Kaldi/data/data/LibriSpeech/test-clean/%s/%s/%s-%s-%s.wav\"%(DIR, SUBDIR, DIR, SUBDIR, FILE_ID)\n", "TRANSRIPTION_FILE = \"/Kaldi/data/data/LibriSpeech/test-clean/%s/%s/%s-%s.trans.txt\"%(DIR, SUBDIR, DIR, SUBDIR)\n", "batcmd = \"cat %s|grep %s\"%(TRANSRIPTION_FILE, FILE_ID)\n", "res = subprocess.check_output(batcmd, shell=True)\n", "transcript = \" \".join(res.decode('utf-8').split(\" \")[1:]).lower()\n", "\n", "\n", "## Alternatively, take a file from the data generator\n", "#FILE_NAME, transcript = next(datagen)\n", "\n", "cnt = 0\n", "for audio_chunk in audio_generator_from_file(FILE_NAME, 16000, 0.51):\n", " print(\"Chunk \", cnt, audio_chunk[0].shape, audio_chunk[1], audio_chunk[2], audio_chunk[3])\n", " cnt += 1\n", " flags = InferRequestHeader.FLAG_NONE\n", " if audio_chunk[2]:\n", " flags = flags | InferRequestHeader.FLAG_SEQUENCE_START\n", " if audio_chunk[3]:\n", " flags = flags | InferRequestHeader.FLAG_SEQUENCE_END\n", "\n", " if not audio_chunk[3]: # if not end of sequence\n", " ctx.run({'WAV_DATA' : (audio_chunk[0],),\n", " 'WAV_DATA_DIM' : (np.full(shape=1, fill_value=len(audio_chunk[0]), dtype=np.int32),)\n", " },\n", " {},\n", " batch_size=1, \n", " flags=flags,\n", " corr_id=CORRELATION_ID)\n", " else:\n", "\n", " result = 
ctx.run({'WAV_DATA' : (audio_chunk[0],),\n", " 'WAV_DATA_DIM' : (np.full(shape=1, fill_value=len(audio_chunk[0]), dtype=np.int32),)\n", " },\n", " { 'TEXT' : InferContext.ResultFormat.RAW},\n", " batch_size=1, \n", " flags=flags,\n", " corr_id=CORRELATION_ID)\n", "print(\"ASR output: %s\" % \"\".join([c.decode('utf-8') for c in result['TEXT'][0]]).lower())\n", "if transcript:\n", " print(\"Ground truth: %s\"%transcript.lower())" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "g8MxXY5GmTc8" }, "source": [ "# Conclusion\n", "\n", "In this notebook, we have walked through the complete process of preparing the audio data and carry out inference with the Kaldi ASR model.\n", "\n", "## What's next\n", "Now it's time to try the Kaldi ASR model on your own data.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "249yGNLmmTc_" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "include_colab_link": true, "name": "TensorFlow_UNet_Industrial_Colab_train_and_inference.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 1 }