# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import glob
import os

import torch

import nemo.collections.asr as nemo_asr
from nemo.utils import logging, model_utils

# setup AMP (optional)
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
    logging.info("AMP enabled!\n")
    autocast = torch.cuda.amp.autocast
else:

    @contextlib.contextmanager
    def autocast():
        yield
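
# NOTE: the no-op fallback above keeps the `with autocast():` calls below valid
# on CPU-only installs; on CUDA builds, autocast enables mixed-precision inference.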


MODEL_CACHE = {}

# Special tags for fallbacks / user notifications
TAG_ERROR_DURING_TRANSCRIPTION = "<ERROR_DURING_TRANSCRIPTION>"


def get_model_names():
    # Populate local copy of models
    local_model_paths = glob.glob(os.path.join('models', "**", "*.nemo"), recursive=True)
    local_model_names = list(sorted([os.path.basename(path) for path in local_model_paths]))

    # Populate with pretrained nemo checkpoint list
    nemo_model_names = set()
    for model_info in nemo_asr.models.ASRModel.list_available_models():
        for superclass in model_info.class_.mro():
            if 'CTC' in superclass.__name__ or 'RNNT' in superclass.__name__:
                if 'align' in model_info.pretrained_model_name:
                    continue

                nemo_model_names.add(model_info.pretrained_model_name)
    nemo_model_names = list(sorted(nemo_model_names))
    return nemo_model_names, local_model_names
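
# Example (hypothetical, not part of the original module): the two lists returned
# by get_model_names() could back a UI dropdown, e.g.
#
#   nemo_names, local_names = get_model_names()
#   dropdown_choices = nemo_names + local_names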


def initialize_model(model_name):
    # load model
    if model_name not in MODEL_CACHE:
        if '.nemo' in model_name:
            # use local model
            model_name_no_ext = os.path.splitext(model_name)[0]
            model_path = os.path.join('models', model_name_no_ext, model_name)

            # Extract config
            model_cfg = nemo_asr.models.ASRModel.restore_from(restore_path=model_path, return_config=True)
            classpath = model_cfg.target  # original class path
            imported_class = model_utils.import_class_by_path(classpath)  # type: ASRModel
            logging.info(f"Restoring local model: {imported_class.__name__}")

            # load model from checkpoint
            model = imported_class.restore_from(restore_path=model_path, map_location='cpu')  # type: ASRModel
        else:
            # use pretrained model
            model = nemo_asr.models.ASRModel.from_pretrained(model_name, map_location='cpu')

        model.freeze()

        # cache model
        MODEL_CACHE[model_name] = model

    model = MODEL_CACHE[model_name]
    return model
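
# Example (hypothetical): initialize_model() caches on first use, so repeated
# calls with the same name are cheap. The model name below is illustrative:
#
#   model = initialize_model('stt_en_conformer_ctc_small')  # download + cache
#   model = initialize_model('stt_en_conformer_ctc_small')  # cache hit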


def transcribe_all(filepaths, model_name, use_gpu_if_available=True):
    # instantiate model
    if model_name in MODEL_CACHE:
        model = MODEL_CACHE[model_name]
    else:
        model = initialize_model(model_name)

    if torch.cuda.is_available() and use_gpu_if_available:
        model = model.cuda()

    # transcribe audio
    logging.info("Begin transcribing audio...")
    try:
        with autocast():
            with torch.no_grad():
                transcriptions = model.transcribe(filepaths, batch_size=32)

    except RuntimeError:
        # Purge the cache to clear some memory
        MODEL_CACHE.clear()

        logging.info("Ran out of memory on device - performing inference on CPU for now")

        try:
            model = model.cpu()

            with torch.no_grad():
                transcriptions = model.transcribe(filepaths, batch_size=32)

        except Exception as e:
            logging.info(f"Exception {e} occurred while attempting to transcribe audio. Returning error message")
            return TAG_ERROR_DURING_TRANSCRIPTION

    logging.info(f"Finished transcribing {len(filepaths)} files!")

    # If RNNT models transcribe, they return a tuple (greedy, beam_scores)
    if isinstance(transcriptions[0], list) and len(transcriptions) == 2:
        # get greedy transcriptions only
        transcriptions = transcriptions[0]

    # Force onto CPU
    model = model.cpu()
    MODEL_CACHE[model_name] = model

    return transcriptions
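

# Minimal driver sketch (not part of the original module). The audio file names
# and model name below are illustrative assumptions, not shipped assets.
if __name__ == '__main__':
    nemo_names, local_names = get_model_names()
    logging.info(f"Found {len(nemo_names)} pretrained and {len(local_names)} local models")

    filepaths = ['sample1.wav', 'sample2.wav']  # hypothetical 16 kHz mono audio files
    results = transcribe_all(filepaths, model_name='stt_en_conformer_ctc_small')

    if results == TAG_ERROR_DURING_TRANSCRIPTION:
        logging.info("Transcription failed on both GPU and CPU")
    else:
        for path, text in zip(filepaths, results):
            logging.info(f"{path} -> {text}")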