Modify speaker input (#3100)

* initial_commit

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* init diarizer

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* vad+speaker

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* vad update

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* speaker done

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* initial working version

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* compare outputs

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* added uem support

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* pyannote improvements

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* updated config and script name

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* style fix

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update Jenkins file

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* jenkins fix

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* jenkins fix

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update file path in jenkins

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update file path in jenkins

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update file path in jenkins

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* jenkins quote fix

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update offline speaker diarization notebook

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* initial working asr_with_diarization

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* almost done, revisit scoring part

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* fixed eval in offline diarization with asr

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update write2manifest to consider only up to max audio duration

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* asr with diarization notebook

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* Fixed ASR_with_diarization tutorial.ipynb and diarization_utils and edited config yaml file

Signed-off-by: Taejin Park <tango4j@gmail.com>

* Fixed VAD parameters in Speaker_Diarization_Inference.ipynb

Signed-off-by: Taejin Park <tango4j@gmail.com>

* Added Jenkins test, doc strings and updated README

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update jenkins test

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* Doc info in offline_diarization_with_asr

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* Review comments

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update outdir paths

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

Co-authored-by: Taejin Park <tango4j@gmail.com>
Nithin Rao authored 2021-11-06 07:55:32 -07:00, committed by GitHub
parent 9ec02280d0
commit dc9ed88f78
18 changed files with 1026 additions and 1081 deletions

Jenkinsfile

@ -315,10 +315,8 @@ pipeline {
stage('Speaker Diarization Inference') {
steps {
sh 'python examples/speaker_tasks/diarization/speaker_diarize.py \
diarizer.oracle_num_speakers=2 \
diarizer.paths2audio_files=/home/TestData/an4_diarizer/audio_files.scp \
diarizer.path2groundtruth_rttm_files=/home/TestData/an4_diarizer/rttm_files.scp \
sh 'python examples/speaker_tasks/diarization/offline_diarization.py \
diarizer.manifest_filepath=/home/TestData/an4_diarizer/an4_manifest.json \
diarizer.speaker_embeddings.model_path=/home/TestData/an4_diarizer/spkr.nemo \
diarizer.vad.model_path=/home/TestData/an4_diarizer/MatchboxNet_VAD_3x2.nemo \
diarizer.out_dir=examples/speaker_tasks/diarization/speaker_diarization_results'
@ -326,6 +324,18 @@ pipeline {
}
}
stage('Speaker Diarization with ASR Inference') {
steps {
sh 'python examples/speaker_tasks/diarization/offline_diarization_with_asr.py \
diarizer.manifest_filepath=/home/TestData/an4_diarizer/an4_manifest.json \
diarizer.speaker_embeddings.model_path=/home/TestData/an4_diarizer/spkr.nemo \
diarizer.asr.model_path=QuartzNet15x5Base-En \
diarizer.asr.parameters.asr_based_vad=True \
diarizer.out_dir=examples/speaker_tasks/diarization/speaker_diarization_asr_results'
sh 'rm -rf examples/speaker_tasks/diarization/speaker_diarization_asr_results'
}
}
stage('L2: Speech to Text WPE - CitriNet') {
steps {
sh 'python examples/asr/speech_to_text_bpe.py \


@ -51,7 +51,7 @@ Key Features
* `Speech Classification and Speech Command Recognition <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speech_classification/intro.html>`_: MatchboxNet (Command Recognition)
* `Voice activity Detection (VAD) <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/speech_classification/models.html#marblenet-vad>`_: MarbleNet
* `Speaker Recognition <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speaker_recognition/intro.html>`_: SpeakerNet, ECAPA_TDNN
* `Speaker Diarization <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speaker_diarization/intro.html>`_: SpeakerNet
* `Speaker Diarization <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speaker_diarization/intro.html>`_: SpeakerNet, ECAPA_TDNN
* `Pretrained models on different languages. <https://ngc.nvidia.com/catalog/collections/nvidia:nemo_asr>`_: English, Spanish, German, Russian, Chinese, French, Italian, Polish, ...
* `NGC collection of pre-trained speech processing models. <https://ngc.nvidia.com/catalog/collections/nvidia:nemo_asr>`_
* Natural Language Processing


@ -30,55 +30,38 @@ Diarization Error Rate (DER) table of `ecapa_tdnn.nemo` model on well known eval
#### Example script
```bash
python speaker_diarize.py \
diarizer.paths2audio_files='my_wav.list' \
diarizer.speaker_embeddings.model_path='ecapa_tdnn' \
diarizer.oracle_num_speakers=null \
diarizer.vad.model_path='vad_telephony_marblenet'
python offline_diarization.py \
diarizer.manifest_filepath=<path to manifest file> \
diarizer.out_dir='demo_output' \
diarizer.speaker_embeddings.model_path=<pretrained modelname or path to .nemo> \
diarizer.vad.model_path=<pretrained modelname or path to .nemo>
```
If you have oracle VAD files and groundtruth RTTM files for evaluation:
Provide rttm files in the input manifest file and enable oracle_vad as shown below.
```bash
python speaker_diarize.py \
diarizer.paths2audio_files='my_wav.list' \
diarizer.speaker_embeddings.model_path='path/to/ecapa_tdnn.nemo' \
diarizer.oracle_num_speakers= null \
diarizer.speaker_embeddings.oracle_vad_manifest='oracle_vad.manifest' \
diarizer.path2groundtruth_rttm_files='my_wav_rttm.list'
python offline_diarization.py \
diarizer.manifest_filepath=<path to manifest file> \
diarizer.out_dir='demo_output' \
diarizer.speaker_embeddings.model_path=<pretrained modelname or path to .nemo> \
diarizer.oracle_vad=True
```
#### Arguments
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; To run speaker diarization on your audio recordings, you need to prepare the following files.
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; To run speaker diarization on your audio recordings, you need to prepare the following file.
- **`diarizer.paths2audio_files`: audio file list**
- **`diarizer.manifest_filepath`: <manifest file>** Path to manifest file
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; Provide a list of full path names to the audio files you want to diarize.
Example: `my_wav.list`
Example: `manifest.json`
```bash
/path/to/my_audio1.wav
/path/to/my_audio2.wav
/path/to/my_audio3.wav
{"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, label: "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath"="/path/to/uem/filepath"}
```
Mandatory fields are `audio_filepath`, `offset`, `duration`, `label: "infer"`, and `text` (the ground-truth transcript, or `"-"` if none is available); the remaining keys are optional and can be supplied depending on the type of evaluation. A sketch of generating such a manifest is shown below.
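A minimal sketch (not part of NeMo) of generating such a manifest in Python; the audio paths are placeholders and the optional keys are left as null:
```python
# Minimal sketch: write a diarization inference manifest, one JSON object per line.
# Audio paths are placeholders; optional keys (num_speakers, rttm_filepath, uem_filepath)
# can be filled in when ground truth is available.
import json

def create_manifest(audio_files, manifest_path="manifest.json"):
    with open(manifest_path, "w") as fout:
        for audio_file in audio_files:
            entry = {
                "audio_filepath": audio_file,
                "offset": 0,
                "duration": None,      # null -> use the full recording
                "label": "infer",
                "text": "-",           # "-" when no ground-truth transcript is available
                "num_speakers": None,
                "rttm_filepath": None,
                "uem_filepath": None,
            }
            fout.write(json.dumps(entry) + "\n")

create_manifest(["/path/to/my_audio1.wav", "/path/to/my_audio2.wav"])
```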
- **`diarizer.oracle_num_speakers` : number of speakers in the audio recording**
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; If you know how many speakers are in the recordings, you can provide the number of speakers as follows.
Example: `number_of_speakers.list`
```
my_audio1 7
my_audio2 6
my_audio3 5
<session_name> <number of speakers>
```
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; `diarizer.oracle_num_speakers='number_of_speakers.list'`
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; If all sessions have the same number of speakers, you can specify the option as:
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; `diarizer.oracle_num_speakers=5`
Some of the important options in the config file:
- **`diarizer.speaker_embeddings.model_path`: speaker embedding model name**
@ -101,24 +84,11 @@ my_audio3 5
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; `diarizer.vad.model_path='path/to/vad_telephony_marblenet.nemo'`
- **`diarizer.path2groundtruth_rttm_files`: reference RTTM files for evaluating the diarization output**
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; You can provide a list of rttm files for evaluating the diarization output.
Example: `my_wav_rttm.list`
```
/path/to/my_audio1.rttm
/path/to/my_audio2.rttm
/path/to/my_audio3.rttm
<full path to *.rttm file>
```
<br/>
## Run Speech Recognition with Speaker Diarization
Using the script `asr_with_diarization.py`, you can transcribe your audio recording with speaker labels as shown below:
Using the script `offline_diarization_with_asr.py`, you can transcribe your audio recording with speaker labels as shown below:
```
[00:03.34 - 00:04.46] speaker_0: back from the gym oh good how's it going
@ -126,51 +96,49 @@ Using the script `asr_with_diarization.py`, you can transcribe your audio record
[00:12.10 - 00:13.97] speaker_0: well it's the middle of the week or whatever so
```
Currently, asr_with_diarization supports QuartzNet English model and ConformerCTC model (`QuartzNet15x5Base`, `stt_en_conformer_ctc_large`).
Currently, offline_diarization_with_asr supports the QuartzNet English model and the ConformerCTC model (`QuartzNet15x5Base-En`, `stt_en_conformer_ctc_large`).
#### Example script
```bash
python asr_with_diarization.py \
--asr_based_vad \
--pretrained_speaker_model='ecapa_tdnn' \
--audiofile_list_path='my_wav.list' \
python offline_diarization_with_asr.py \
diarizer.manifest_filepath=<path to manifest file> \
diarizer.out_dir='demo_asr_output' \
diarizer.speaker_embeddings.model_path=<pretrained modelname or path to .nemo> \
diarizer.asr.model_path=<pretrained modelname or path to .nemo> \
diarizer.asr.parameters.asr_based_vad=True
```
If you have reference rttm files or oracle number of speaker information, you can provide those files as in the following example.
If you have reference RTTM files or oracle number-of-speakers information, provide those file paths and speaker counts in the manifest file and pass `diarizer.clustering.parameters.oracle_num_speakers=True` as shown in the following example.
```bash
python asr_with_diarization.py \
--asr_based_vad \
--pretrained_speaker_model='ecapa_tdnn' \
--audiofile_list_path='my_wav.list' \
--reference_rttmfile_list_path='my_wav_rttm.list'\
--oracle_num_speakers=number_of_speakers.list
python offline_diarization_with_asr.py \
diarizer.manifest_filepath=<path to manifest file> \
diarizer.out_dir='demo_asr_output' \
diarizer.speaker_embeddings.model_path=<pretrained modelname or path to .nemo> \
diarizer.asr.model_path=<pretrained modelname or path to .nemo> \
diarizer.asr.parameters.asr_based_vad=True \
diarizer.clustering.parameters.oracle_num_speakers=True
```
#### Output folders
The above script will create a folder named `./asr_with_diar/`.
In `./asr_with_diar/`, you can check the results as below.
The above script will create a folder named `./demo_asr_output/`.
In `./demo_asr_output/`, you can check the results as below.
```bash
./asr_with_diar
├── json_result
│ └── my_audio1.json
│ └── my_audio2.json
│ └── my_audio3.json
└── transcript_with_speaker_labels
├── pred_rttms
└── my_audio1.json
└── my_audio1.txt
└── my_audio2.txt
└── my_audio3.txt
└── my_audio1.rttm
└── ...
```
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; `json_result` folder includes word-by-word json output with speaker label and time stamps.
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; `*.json` files contain word-by-word JSON output with speaker labels and time stamps.
Example: `./asr_with_diar/json_result/my_audio1.json`
Example: `./demo_asr_output/pred_rttms/my_audio1.json`
```bash
{
"status": "Success",
@ -199,9 +167,9 @@ Example: `./asr_with_diar/json_result/my_audio1.json`
},
```
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; `transcript_with_speaker_labels` folder includes transcription with speaker labels and corresponding time.
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; `*.txt` files contain the transcription with speaker labels and corresponding time stamps.
Example: `./asr_with_diar/transcript_with_speaker_labels/my_audio1.txt`
Example: `./demo_asr_output/pred_rttms/my_audio1.txt`
```
[00:03.34 - 00:04.46] speaker_0: back from the gym oh good how's it going
[00:04.46 - 00:09.96] speaker_1: pretty well it was really crowded today yeah i kind of assumed everylonewould be at the shore uhhuh
@ -209,4 +177,3 @@ Example: `./asr_with_diar/transcript_with_speaker_labels/my_audio1.txt`
[00:13.97 - 00:15.78] speaker_1: but it's the fourth of july mm
[00:16.90 - 00:21.80] speaker_0: so yeahg people still work tomorrow do you have to work tomorrow did you drive off yesterday
```
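As a rough, hedged illustration of consuming these outputs (the word-level schema is not reproduced above, so this only inspects the file rather than assuming field names):
```python
# Minimal sketch: load a per-session JSON result from <out_dir>/pred_rttms and inspect it.
# Only the "status" field is shown in the example above; other keys are inspected generically.
import json

with open("./demo_asr_output/pred_rttms/my_audio1.json") as f:
    result = json.load(f)

print(result.get("status"))   # e.g. "Success"
print(sorted(result.keys()))  # remaining word/speaker fields of the session result
```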


@ -1,152 +0,0 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from nemo.collections.asr.parts.utils.diarization_utils import ASR_DIAR_OFFLINE, read_file_paths
from nemo.utils import logging
"""
Currently Supported ASR models:
QuartzNet15x5Base-En
stt_en_conformer_ctc_large
"""
DEFAULT_URL = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/speaker_diarization.yaml"
parser = argparse.ArgumentParser()
parser.add_argument(
"--audiofile_list_path", type=str, help="Fullpath of a file contains the list of audio files", required=True
)
parser.add_argument(
"--reference_rttmfile_list_path", default=None, type=str, help="Fullpath of a file contains the list of rttm files"
)
parser.add_argument(
"--output_path", default=None, type=str, help="Path to the folder where output files are generated"
)
parser.add_argument("--pretrained_vad_model", default=None, type=str, help="Fullpath of the VAD model (*.nemo).")
parser.add_argument(
"--external_vad_manifest", default=None, type=str, help="External VAD output manifest for diarization"
)
parser.add_argument("--asr_based_vad", default=False, action='store_true', help="Use ASR-based VAD")
parser.add_argument(
'--generate_oracle_manifest', default=False, action='store_true', help="Use RTTM ground truth as VAD input"
)
parser.add_argument(
"--pretrained_speaker_model",
type=str,
help="Fullpath of the Speaker embedding extractor model (*.nemo).",
required=True,
)
parser.add_argument("--oracle_num_speakers", help="Either int or text file that contains number of speakers")
parser.add_argument("--threshold", default=10, type=int, help="Threshold for ASR based VAD")
parser.add_argument(
"--diar_config_url", default=DEFAULT_URL, type=str, help="Config yaml file for running speaker diarization"
)
parser.add_argument("--csv", default='result.csv', type=str, help="")
args = parser.parse_args()
vad_choices = [
args.asr_based_vad,
args.pretrained_vad_model,
args.external_vad_manifest,
args.generate_oracle_manifest,
]
if not sum(bool(c) for c in vad_choices) == 1:
raise ValueError(
"Please provide ONE and ONLY ONE method for VAD. \n \
(1) External manifest external_vad_manifest \n \
(2) Use a pretrained_vad_model \n \
(3) Use ASR-based VAD or \n \
(4) Provide reference_rttmfile_list_path so we can automatically generate oracle manifest for diarization"
)
if args.generate_oracle_manifest and not args.reference_rttmfile_list_path:
raise ValueError(
"Please provide reference_rttmfile_list_path so we can generate oracle manifest automatically"
)
if args.asr_based_vad:
oracle_manifest = 'asr_based_vad'
elif args.pretrained_vad_model:
oracle_manifest = 'system_vad'
elif args.external_vad_manifest:
oracle_manifest = args.external_vad_manifest
elif args.reference_rttmfile_list_path:
logging.info("Use the oracle manifest automatically generated by rttm")
oracle_manifest = asr_diar_offline.write_VAD_rttm(
asr_diar_offline.oracle_vad_dir, audio_file_list, args.reference_rttmfile_list_path
)
params = {
"round_float": 2,
"window_length_in_sec": 1.5,
"shift_length_in_sec": 0.75,
"print_transcript": False,
"lenient_overlap_WDER": True,
"VAD_threshold_for_word_ts": 0.7,
"max_word_ts_length_in_sec": 0.6,
"word_gap_in_sec": 0.01,
"minimum": True,
"threshold": args.threshold, # minimun width to consider non-speech activity
"diar_config_url": args.diar_config_url,
"ASR_model_name": 'stt_en_conformer_ctc_large',
}
asr_diar_offline = ASR_DIAR_OFFLINE(params)
asr_model = asr_diar_offline.set_asr_model(params['ASR_model_name'])
asr_diar_offline.create_directories(args.output_path)
audio_file_list = read_file_paths(args.audiofile_list_path)
word_list, word_ts_list = asr_diar_offline.run_ASR(audio_file_list, asr_model)
diar_labels = asr_diar_offline.run_diarization(
audio_file_list,
word_ts_list,
oracle_manifest=oracle_manifest,
oracle_num_speakers=args.oracle_num_speakers,
pretrained_speaker_model=args.pretrained_speaker_model,
pretrained_vad_model=args.pretrained_vad_model,
)
if args.reference_rttmfile_list_path:
ref_rttm_file_list = read_file_paths(args.reference_rttmfile_list_path)
ref_labels_list, DER_result_dict = asr_diar_offline.eval_diarization(
audio_file_list, diar_labels, ref_rttm_file_list
)
total_riva_dict = asr_diar_offline.write_json_and_transcript(audio_file_list, diar_labels, word_list, word_ts_list)
WDER_dict = asr_diar_offline.get_WDER(audio_file_list, total_riva_dict, DER_result_dict, ref_labels_list)
effective_wder = asr_diar_offline.get_effective_WDER(DER_result_dict, WDER_dict)
logging.info(
f"\nDER : {DER_result_dict['total']['DER']:.4f} \
\nFA : {DER_result_dict['total']['FA']:.4f} \
\nMISS : {DER_result_dict['total']['MISS']:.4f} \
\nCER : {DER_result_dict['total']['CER']:.4f} \
\nWDER : {WDER_dict['total']:.4f} \
\neffective WDER : {effective_wder:.4f} \
\nspk_counting_acc : {DER_result_dict['total']['spk_counting_acc']:.4f}"
)
else:
total_riva_dict = asr_diar_offline.write_json_and_transcript(audio_file_list, diar_labels, word_list, word_ts_list)


@ -0,0 +1,47 @@
name: &name "ClusterDiarizer"
num_workers: 4
sample_rate: 16000
batch_size: 64
diarizer:
manifest_filepath: ???
out_dir: ???
oracle_vad: False # if True, uses RTTM files provided in manifest file to get speech activity (VAD) timestamps
collar: 0.25 # collar value for scoring
ignore_overlap: True # consider or ignore overlap segments while scoring
vad:
model_path: null #.nemo local model path or pretrained model name or none
external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set
parameters: # Tuned parameter for CH109! (with 11 moved multi-speech sessions as dev set)
window_length_in_sec: 0.15 # window length in sec for VAD context input
shift_length_in_sec: 0.01 # shift length in sec for generating frame-level VAD predictions
smoothing: "median" # False or type of smoothing method (eg: median)
overlap: 0.875 # Overlap ratio for overlapped mean/median smoothing filter
onset: 0.4 # onset threshold for detecting the beginning of a speech segment
offset: 0.7 # offset threshold for detecting the end of a speech segment
pad_onset: 0.05 # adding durations before each speech segment
pad_offset: -0.1 # adding durations after each speech segment
min_duration_on: 0.2 # threshold for small non_speech deletion
min_duration_off: 0.2 # threshold for short speech segment deletion
filter_speech_first: True
speaker_embeddings:
model_path: ??? #.nemo local model path or pretrained model name (ecapa_tdnn or speakerverification_speakernet)
parameters:
window_length_in_sec: 1.5 # window length in sec for speaker embedding extraction
shift_length_in_sec: 0.75 # shift length in sec for speaker embedding extraction
save_embeddings: False # save embeddings as pickle file for each audio input
clustering:
parameters:
oracle_num_speakers: False # if True, uses num of speakers value provided in manifest file
max_num_speakers: 20 # max number of speakers for each recording. If oracle num speakers is passed this value is ignored.
enhanced_count_thres: 80
max_rp_threshold: 0.25
sparse_search_volume: 30
# json manifest line example
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, label: "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath"="/path/to/uem/filepath"}


@ -0,0 +1,58 @@
name: &name "ClusterDiarizer"
num_workers: 4
sample_rate: 16000
batch_size: 64
diarizer:
manifest_filepath: ???
out_dir: ???
oracle_vad: False # if True, uses RTTM files provided in manifest file to get speech activity (VAD) timestamps
collar: 0.25 # collar value for scoring
ignore_overlap: True # consider or ignore overlap segments while scoring
vad:
model_path: null #.nemo local model path or pretrained model name or none
external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set
parameters: # Tuned parameter for CH109! (with 11 moved multi-speech sessions as dev set)
window_length_in_sec: 0.15 # window length in sec for VAD context input
shift_length_in_sec: 0.01 # shift length in sec for generating frame-level VAD predictions
smoothing: "median" # False or type of smoothing method (eg: median)
overlap: 0.875 # Overlap ratio for overlapped mean/median smoothing filter
onset: 0.4 # onset threshold for detecting the beginning of a speech segment
offset: 0.7 # offset threshold for detecting the end of a speech segment
pad_onset: 0.05 # adding durations before each speech segment
pad_offset: -0.1 # adding durations after each speech segment
min_duration_on: 0.2 # threshold for small non_speech deletion
min_duration_off: 0.2 # threshold for short speech segment deletion
filter_speech_first: True
speaker_embeddings:
model_path: ??? #.nemo local model path or pretrained model name (ecapa_tdnn or speakerverification_speakernet)
parameters:
window_length_in_sec: 1.5 # window length in sec for speaker embedding extraction
shift_length_in_sec: 0.75 # shift length in sec for speaker embedding extraction
save_embeddings: False # save embeddings as pickle file for each audio input
clustering:
parameters:
oracle_num_speakers: False # if True, uses num of speakers value provided in manifest file
max_num_speakers: 20 # max number of speakers for each recording. If oracle num speakers is passed this value is ignored.
enhanced_count_thres: 80
max_rp_threshold: 0.25
sparse_search_volume: 30
asr:
model_path: ??? # Provide NGC cloud ASR model name
parameters:
asr_based_vad: False # if true, speech segmentation for diarization is based on word-timestamps from ASR inference.
asr_based_vad_threshold: 50 # threshold (multiple of 10ms) for ignoring the gap between two words when generating VAD timestamps using ASR based VAD.
print_transcript: False # if true, the transcript is printed after it is generated.
lenient_overlap_WDER: True # if true, a word that falls into a speaker-overlapped region is counted as correctly diarized.
vad_threshold_for_word_ts: 0.7 # threshold for getting VAD stamps that compensate the word timestamps generated from ASR decoder.
max_word_ts_length_in_sec: 0.6 # Maximum length of each word timestamps
word_gap_in_sec: 0.01 # The gap between the end of a word and the start of the following word.
# json manifest line example
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, label: "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath"="/path/to/uem/filepath"}


@ -1,36 +0,0 @@
name: &name "ClusterDiarizer"
num_workers: 4
sample_rate: 16000
batch_size: 64
diarizer:
oracle_num_speakers: null # oracle number of speakers or pass null if not known. [int or path to file containing uniq-id with num of speakers of that session]
max_num_speakers: 20 # max number of speakers for each recording. If oracle num speakers is passed this value is ignored.
out_dir: ???
paths2audio_files: ??? #file containing paths to audio files for inference
path2groundtruth_rttm_files: null #file containing paths to ground truth rttm files for score computation or null if no reference rttms are present
vad:
model_path: null #.nemo local model path or pretrained model name or none
window_length_in_sec: 0.15
shift_length_in_sec: 0.01
vad_decision_smoothing: True
smoothing_params:
method: "median"
overlap: 0.875
threshold: null # if not null, use this value and ignore the following onset and offset. Warning! it will be removed in release 1.5!
postprocessing_params: # Tuned parameter for CH109! (with 11 moved multi-speech sessions as dev set) You might need to choose resonable threshold or tune the threshold on your dev set. Check <NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py
onset: 0.4
offset: 0.7
pad_onset: 0.05
pad_offset: -0.1
min_duration_on: 0.2
min_duration_off: 0.2
filter_speech_first: True
speaker_embeddings:
oracle_vad_manifest: null
model_path: ??? #.nemo local model path or pretrained model name (ecapa_tdnn or speakerverification_speakernet)
window_length_in_sec: 1.5 # window length in sec for speaker embedding extraction
shift_length_in_sec: 0.75 # shift length in sec for speaker embedding extraction


@ -23,21 +23,19 @@ from nemo.utils import logging
This script demonstrates how to run speaker diarization.
Usage:
python speaker_diarize.py \
diarizer.paths2audio_files=<either list of audio file paths or file containing paths to audio files> \
diarizer.path2groundtruth_rttm_files=<(Optional) either list of rttm file paths or file containing paths to rttm files> \
diarizer.vad.model_path="vad_telephony_marblenet" \
diarizer.vad.threshold=0.7 \
diarizer.speaker_embeddings.model_path="speakerverification_speakernet" \
diarizer.out_dir="demo"
diarizer.manifest_filepath=<path to manifest file> \
diarizer.out_dir='demo_output' \
diarizer.speaker_embeddings.model_path=<pretrained modelname or path to .nemo> \
diarizer.vad.model_path='vad_marblenet'
Check out whole parameters in ./conf/speaker_diarization.yaml and their meanings.
See ./conf/offline_diarization.yaml for the full set of parameters and their meanings.
For details, have a look at <NeMo_git_root>/tutorials/speaker_tasks/Speaker_Diarization_Inference.ipynb
"""
seed_everything(42)
@hydra_runner(config_path="conf", config_name="speaker_diarization.yaml")
@hydra_runner(config_path="conf", config_name="offline_diarization.yaml")
def main(cfg):
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

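The same workflow can also be driven programmatically; a hedged sketch using the `ClusteringDiarizer` class introduced in this change (config path and model names are placeholders):
```python
# Minimal sketch: run offline diarization without the Hydra CLI wrapper.
from omegaconf import OmegaConf
from nemo.collections.asr.models import ClusteringDiarizer

cfg = OmegaConf.load("examples/speaker_tasks/diarization/conf/offline_diarization.yaml")
cfg.diarizer.manifest_filepath = "manifest.json"
cfg.diarizer.out_dir = "demo_output"
cfg.diarizer.speaker_embeddings.model_path = "ecapa_tdnn"
cfg.diarizer.vad.model_path = "vad_marblenet"

diarizer = ClusteringDiarizer(cfg=cfg)
score = diarizer.diarize()  # scoring results, or None when no reference RTTMs are given
```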

@ -0,0 +1,82 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from omegaconf import OmegaConf
from nemo.collections.asr.parts.utils.diarization_utils import ASR_DIAR_OFFLINE
from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map
from nemo.core.config import hydra_runner
from nemo.utils import logging
"""
This script demonstrates how to run offline speaker diarization with ASR.
Usage:
python offline_diarization_with_asr.py \
diarizer.manifest_filepath=<path to manifest file> \
diarizer.out_dir='demo_asr_output' \
diarizer.speaker_embeddings.model_path=<pretrained modelname or path to .nemo> \
diarizer.asr.model_path=<pretrained modelname or path to .nemo> \
diarizer.asr.parameters.asr_based_vad=True
See ./conf/offline_diarization_with_asr.yaml for the full set of parameters and their meanings.
For details, have a look at <NeMo_git_root>/tutorials/speaker_tasks/Speaker_Diarization_Inference.ipynb
Currently Supported ASR models:
QuartzNet15x5Base-En
stt_en_conformer_ctc_large
"""
@hydra_runner(config_path="conf", config_name="offline_diarization_with_asr.yaml")
def main(cfg):
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
asr_diar_offline = ASR_DIAR_OFFLINE(**cfg.diarizer.asr.parameters)
asr_diar_offline.root_path = cfg.diarizer.out_dir
shutil.rmtree(asr_diar_offline.root_path, ignore_errors=True)
AUDIO_RTTM_MAP = audio_rttm_map(cfg.diarizer.manifest_filepath)
asr_diar_offline.AUDIO_RTTM_MAP = AUDIO_RTTM_MAP
asr_model = asr_diar_offline.set_asr_model(cfg.diarizer.asr.model_path)
word_list, word_ts_list = asr_diar_offline.run_ASR(asr_model)
score = asr_diar_offline.run_diarization(cfg, word_ts_list,)
total_riva_dict = asr_diar_offline.write_json_and_transcript(word_list, word_ts_list)
if score is not None:
metric, mapping_dict = score
DER_result_dict = asr_diar_offline.gather_eval_results(metric, mapping_dict)
WDER_dict = asr_diar_offline.get_WDER(total_riva_dict, DER_result_dict)
effective_wder = asr_diar_offline.get_effective_WDER(DER_result_dict, WDER_dict)
logging.info(
f"\nDER : {DER_result_dict['total']['DER']:.4f} \
\nFA : {DER_result_dict['total']['FA']:.4f} \
\nMISS : {DER_result_dict['total']['MISS']:.4f} \
\nCER : {DER_result_dict['total']['CER']:.4f} \
\nWDER : {WDER_dict['total']:.4f} \
\neffective WDER : {effective_wder:.4f} \
\nspk_counting_acc : {DER_result_dict['total']['spk_counting_acc']:.4f}"
)
if __name__ == '__main__':
main()


@ -222,7 +222,7 @@ def _vad_frame_seq_collate_fn(self, batch):
sig_len += slice_length
if has_audio:
slices = (sig_len - slice_length) // shift
slices = torch.div(sig_len - slice_length, shift, rounding_mode='trunc')
for slice_id in range(slices):
start_idx = slice_id * shift
end_idx = start_idx + slice_length
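For context on the one-line change above, a tiny sketch (made-up values) of why `torch.div(..., rounding_mode='trunc')` is used in place of tensor floor division:
```python
# Minimal sketch: for the non-negative lengths used here, truncating division matches
# floor division, while torch.div(..., rounding_mode='trunc') avoids the floor-division
# deprecation warnings emitted by some PyTorch versions.
import torch

sig_len, slice_length, shift = torch.tensor(16000), torch.tensor(4000), torch.tensor(800)
slices_old = (sig_len - slice_length) // shift
slices_new = torch.div(sig_len - slice_length, shift, rounding_mode='trunc')
assert int(slices_old) == int(slices_new) == 15
```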


@ -19,20 +19,22 @@ import shutil
import tarfile
import tempfile
from collections import defaultdict
from copy import deepcopy
from typing import List, Optional
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf.listconfig import ListConfig
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm
from nemo.collections.asr.models.classification_models import EncDecClassificationModel
from nemo.collections.asr.models.label_models import ExtractSpeakerEmbeddingsModel
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.mixins.mixins import DiarizationMixin
from nemo.collections.asr.parts.utils.speaker_utils import (
audio_rttm_map,
perform_diarization,
get_uniqname_from_filepath,
perform_clustering,
score_labels,
segments_manifest_to_subsegments_manifest,
write_rttm2manifest,
)
@ -44,7 +46,6 @@ from nemo.collections.asr.parts.utils.vad_utils import (
)
from nemo.core.classes import Model
from nemo.utils import logging, model_utils
from nemo.utils.exp_manager import NotFoundError
try:
from torch.cuda.amp import autocast
@ -64,48 +65,44 @@ _SPEAKER_MODEL = "speaker_model.nemo"
def get_available_model_names(class_name):
"lists available pretrained model names from NGC"
available_models = class_name.list_available_models()
return list(map(lambda x: x.pretrained_model_name, available_models))
class ClusteringDiarizer(Model, DiarizationMixin):
"""
Inference model Class for offline speaker diarization.
This class handles required functionality for diarization : Speech Activity Detection, Segmentation,
Extract Embeddings, Clustering, Resegmentation and Scoring.
All the parameters are passed through config file
"""
def __init__(self, cfg: DictConfig):
cfg = model_utils.convert_model_config_to_dict_config(cfg)
# Convert config to support Hydra 1.0+ instantiation
cfg = model_utils.maybe_update_config_version(cfg)
self._cfg = cfg
self._out_dir = self._cfg.diarizer.out_dir
if not os.path.exists(self._out_dir):
os.mkdir(self._out_dir)
# Diarizer set up
self._diarizer_params = self._cfg.diarizer
# init vad model
self.has_vad_model = False
self.has_vad_model_to_save = False
self._speaker_manifest_path = self._cfg.diarizer.speaker_embeddings.oracle_vad_manifest
self.AUDIO_RTTM_MAP = None
self.paths2audio_files = self._cfg.diarizer.paths2audio_files
if self._cfg.diarizer.vad.model_path is not None:
self._init_vad_model()
self._vad_dir = os.path.join(self._out_dir, 'vad_outputs')
self._vad_out_file = os.path.join(self._vad_dir, "vad_out.json")
shutil.rmtree(self._vad_dir, ignore_errors=True)
os.makedirs(self._vad_dir)
if not self._diarizer_params.oracle_vad:
if self._cfg.diarizer.vad.model_path is not None:
self._vad_params = self._cfg.diarizer.vad.parameters
self._init_vad_model()
# init speaker model
self._init_speaker_model()
self._speaker_params = self._cfg.diarizer.speaker_embeddings.parameters
self._speaker_dir = os.path.join(self._diarizer_params.out_dir, 'speaker_outputs')
shutil.rmtree(self._speaker_dir, ignore_errors=True)
os.makedirs(self._speaker_dir)
if self._cfg.diarizer.get('num_speakers', None):
self._num_speakers = self._cfg.diarizer.num_speakers
logging.warning("in next release num_speakers will be changed to oracle_num_speakers")
else:
self._num_speakers = self._cfg.diarizer.oracle_num_speakers
if self._cfg.diarizer.get('max_num_speakers', None):
self.max_num_speakers = self._cfg.diarizer.max_num_speakers
else:
self.max_num_speakers = 8
# Clustering params
self._cluster_params = self._diarizer_params.clustering.parameters
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@ -113,28 +110,10 @@ class ClusteringDiarizer(Model, DiarizationMixin):
def list_available_models(cls):
pass
def _init_speaker_model(self):
model_path = self._cfg.diarizer.speaker_embeddings.model_path
if model_path is not None and model_path.endswith('.nemo'):
self._speaker_model = ExtractSpeakerEmbeddingsModel.restore_from(model_path)
logging.info("Speaker Model restored locally from {}".format(model_path))
else:
if model_path not in get_available_model_names(ExtractSpeakerEmbeddingsModel):
logging.warning(
"requested {} model name not available in pretrained models, instead".format(model_path)
)
model_path = "speakerdiarization_speakernet"
logging.info("Loading pretrained {} model from NGC".format(model_path))
self._speaker_model = ExtractSpeakerEmbeddingsModel.from_pretrained(model_name=model_path)
self._speaker_dir = os.path.join(self._out_dir, 'speaker_outputs')
def set_vad_model(self, vad_config):
with open_dict(self._cfg):
self._cfg.diarizer.vad = vad_config
self._init_vad_model()
def _init_vad_model(self):
"""
Initialize vad model with model name or path passed through config
"""
model_path = self._cfg.diarizer.vad.model_path
if model_path.endswith('.nemo'):
self._vad_model = EncDecClassificationModel.restore_from(model_path)
@ -148,15 +127,35 @@ class ClusteringDiarizer(Model, DiarizationMixin):
logging.info("Loading pretrained {} model from NGC".format(model_path))
self._vad_model = EncDecClassificationModel.from_pretrained(model_name=model_path)
self._vad_window_length_in_sec = self._cfg.diarizer.vad.window_length_in_sec
self._vad_shift_length_in_sec = self._cfg.diarizer.vad.shift_length_in_sec
self.has_vad_model_to_save = True
self._vad_window_length_in_sec = self._vad_params.window_length_in_sec
self._vad_shift_length_in_sec = self._vad_params.shift_length_in_sec
self.has_vad_model = True
def _init_speaker_model(self):
"""
Initialize speaker embedding model with model name or path passed through config
"""
model_path = self._cfg.diarizer.speaker_embeddings.model_path
if model_path is not None and model_path.endswith('.nemo'):
self._speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
logging.info("Speaker Model restored locally from {}".format(model_path))
elif model_path.endswith('.ckpt'):
self._speaker_model = EncDecSpeakerLabelModel.load_from_checkpoint(model_path)
logging.info("Speaker Model restored locally from {}".format(model_path))
else:
if model_path not in get_available_model_names(EncDecSpeakerLabelModel):
logging.warning(
"requested {} model name not available in pretrained models, instead".format(model_path)
)
model_path = "ecapa_tdnn"
logging.info("Loading pretrained {} model from NGC".format(model_path))
self._speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_name=model_path)
def _setup_vad_test_data(self, manifest_vad_input):
vad_dl_config = {
'manifest_filepath': manifest_vad_input,
'sample_rate': self._cfg.sample_rate,
'batch_size': self._cfg.get('batch_size'),
'vad_stream': True,
'labels': ['infer',],
'time_length': self._vad_window_length_in_sec,
@ -170,11 +169,10 @@ class ClusteringDiarizer(Model, DiarizationMixin):
spk_dl_config = {
'manifest_filepath': manifest_file,
'sample_rate': self._cfg.sample_rate,
'batch_size': self._cfg.get('batch_size', 32),
'time_length': self._cfg.diarizer.speaker_embeddings.window_length_in_sec,
'shift_length': self._cfg.diarizer.speaker_embeddings.shift_length_in_sec,
'batch_size': self._cfg.get('batch_size'),
'time_length': self._speaker_params.window_length_in_sec,
'shift_length': self._speaker_params.shift_length_in_sec,
'trim_silence': False,
'embedding_dir': self._speaker_dir,
'labels': None,
'task': "diarization",
'num_workers': self._cfg.num_workers,
@ -182,6 +180,18 @@ class ClusteringDiarizer(Model, DiarizationMixin):
self._speaker_model.setup_test_data(spk_dl_config)
def _run_vad(self, manifest_file):
"""
Run voice activity detection.
Computes frame-level log probabilities of voice activity and smooths them using the post-processing parameters.
The resulting frame-level predictions are used to generate a manifest file for later speaker embedding extraction.
input:
manifest_file (str) : Manifest file containing paths to audio files, each with label "infer"
"""
shutil.rmtree(self._vad_dir, ignore_errors=True)
os.makedirs(self._vad_dir)
self._vad_model = self._vad_model.to(self._device)
self._vad_model.eval()
@ -191,8 +201,8 @@ class ClusteringDiarizer(Model, DiarizationMixin):
all_len = 0
data = []
for line in open(manifest_file, 'r'):
file = os.path.basename(json.loads(line)['audio_filepath'])
data.append(os.path.splitext(file)[0])
file = json.loads(line)['audio_filepath']
data.append(get_uniqname_from_filepath(file))
status = get_vad_stream_status(data)
for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader())):
@ -218,18 +228,17 @@ class ClusteringDiarizer(Model, DiarizationMixin):
if status[i] == 'end' or status[i] == 'single':
all_len = 0
if not self._cfg.diarizer.vad.vad_decision_smoothing:
if not self._vad_params.smoothing:
# Shift the window by 10ms to generate the frame and use the prediction of the window to represent the label for the frame;
self.vad_pred_dir = self._vad_dir
else:
# Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
# smoothing_method would be either in majority vote (median) or average (mean)
logging.info("Generating predictions with overlapping input segments")
smoothing_pred_dir = generate_overlap_vad_seq(
frame_pred_dir=self._vad_dir,
smoothing_method=self._cfg.diarizer.vad.smoothing_params.method,
overlap=self._cfg.diarizer.vad.smoothing_params.overlap,
smoothing_method=self._vad_params.smoothing,
overlap=self._vad_params.overlap,
seg_len=self._vad_window_length_in_sec,
shift_len=self._vad_shift_length_in_sec,
num_workers=self._cfg.num_workers,
@ -237,30 +246,81 @@ class ClusteringDiarizer(Model, DiarizationMixin):
self.vad_pred_dir = smoothing_pred_dir
logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
postprocessing_params = self._cfg.diarizer.vad.postprocessing_params
if self._cfg.diarizer.vad.get("threshold", None):
logging.info("threshold is not None. Use threshold and update onset=offset=threshold")
logging.warning("Support for threshold will be deprecated in release 1.5")
postprocessing_params['onset'] = self._cfg.diarizer.vad.threshold
postprocessing_params['offset'] = self._cfg.diarizer.vad.threshold
table_out_dir = generate_vad_segment_table(
vad_pred_dir=self.vad_pred_dir,
postprocessing_params=postprocessing_params,
postprocessing_params=self._vad_params,
shift_len=self._vad_shift_length_in_sec,
num_workers=self._cfg.num_workers,
)
AUDIO_VAD_RTTM_MAP = deepcopy(self.AUDIO_RTTM_MAP.copy())
for key in AUDIO_VAD_RTTM_MAP:
AUDIO_VAD_RTTM_MAP[key]['rttm_filepath'] = os.path.join(table_out_dir, key + ".txt")
vad_table_list = [os.path.join(table_out_dir, key + ".txt") for key in self.AUDIO_RTTM_MAP]
write_rttm2manifest(self._cfg.diarizer.paths2audio_files, vad_table_list, self._vad_out_file)
write_rttm2manifest(AUDIO_VAD_RTTM_MAP, self._vad_out_file)
self._speaker_manifest_path = self._vad_out_file
def _run_segmentation(self):
self.subsegments_manifest_path = os.path.join(self._speaker_dir, 'subsegments.json')
self.subsegments_manifest_path = segments_manifest_to_subsegments_manifest(
segments_manifest_file=self._speaker_manifest_path,
subsegments_manifest_file=self.subsegments_manifest_path,
window=self._speaker_params.window_length_in_sec,
shift=self._speaker_params.shift_length_in_sec,
)
return None
def _perform_speech_activity_detection(self):
"""
Checks the type of speech activity detection specified in the config. Choices are NeMo VAD,
an external VAD manifest, or oracle VAD (speech activity labels generated from the provided RTTM files).
"""
if self.has_vad_model:
self._dont_auto_split = False
self._split_duration = 50
manifest_vad_input = self._diarizer_params.manifest_filepath
if not self._dont_auto_split:
logging.info("Split long audio file to avoid CUDA memory issue")
logging.debug("Try smaller split_duration if you still have CUDA memory issue")
config = {
'manifest_filepath': manifest_vad_input,
'time_length': self._vad_window_length_in_sec,
'split_duration': self._split_duration,
'num_workers': self._cfg.num_workers,
}
manifest_vad_input = prepare_manifest(config)
else:
logging.warning(
"If you encounter CUDA memory issue, try splitting manifest entry by split_duration to avoid it."
)
self._setup_vad_test_data(manifest_vad_input)
self._run_vad(manifest_vad_input)
elif self._diarizer_params.vad.external_vad_manifest is not None:
self._speaker_manifest_path = self._diarizer_params.vad.external_vad_manifest
elif self._diarizer_params.oracle_vad:
self._speaker_manifest_path = os.path.join(self._speaker_dir, 'oracle_vad_manifest.json')
self._speaker_manifest_path = write_rttm2manifest(self.AUDIO_RTTM_MAP, self._speaker_manifest_path)
else:
raise ValueError(
"Only one of diarizer.oracle_vad, vad.model_path or vad.external_vad_manifest must be passed"
)
def _extract_embeddings(self, manifest_file):
"""
This method extracts speaker embeddings from segments passed through manifest_file
Optionally you may save the intermediate speaker embeddings for debugging or any use.
"""
logging.info("Extracting embeddings for Diarization")
self._setup_spkr_test_data(manifest_file)
out_embeddings = defaultdict(list)
self.embeddings = defaultdict(list)
self._speaker_model = self._speaker_model.to(self._device)
self._speaker_model.eval()
self.time_stamps = {}
all_embs = []
for test_batch in tqdm(self._speaker_model.test_dataloader()):
@ -277,105 +337,96 @@ class ClusteringDiarizer(Model, DiarizationMixin):
for i, line in enumerate(manifest.readlines()):
line = line.strip()
dic = json.loads(line)
uniq_name = os.path.basename(dic['audio_filepath']).rsplit('.', 1)[0]
out_embeddings[uniq_name].extend([all_embs[i]])
uniq_name = get_uniqname_from_filepath(dic['audio_filepath'])
self.embeddings[uniq_name].extend([all_embs[i]])
if uniq_name not in self.time_stamps:
self.time_stamps[uniq_name] = []
start = dic['offset']
end = start + dic['duration']
stamp = '{:.3f} {:.3f} '.format(start, end)
self.time_stamps[uniq_name].append(stamp)
embedding_dir = os.path.join(self._speaker_dir, 'embeddings')
if not os.path.exists(embedding_dir):
os.makedirs(embedding_dir, exist_ok=True)
if self._speaker_params.save_embeddings:
embedding_dir = os.path.join(self._speaker_dir, 'embeddings')
if not os.path.exists(embedding_dir):
os.makedirs(embedding_dir, exist_ok=True)
prefix = manifest_file.split('/')[-1].rsplit('.', 1)[-2]
prefix = get_uniqname_from_filepath(manifest_file)
name = os.path.join(embedding_dir, prefix)
self._embeddings_file = name + '_embeddings.pkl'
pkl.dump(out_embeddings, open(self._embeddings_file, 'wb'))
logging.info("Saved embedding files to {}".format(embedding_dir))
name = os.path.join(embedding_dir, prefix)
self._embeddings_file = name + '_embeddings.pkl'
pkl.dump(self.embeddings, open(self._embeddings_file, 'wb'))
logging.info("Saved embedding files to {}".format(embedding_dir))
def path2audio_files_to_manifest(self, paths2audio_files):
mfst_file = os.path.join(self._out_dir, 'manifest.json')
with open(mfst_file, 'w') as fp:
def path2audio_files_to_manifest(self, paths2audio_files, manifest_filepath):
with open(manifest_filepath, 'w') as fp:
for audio_file in paths2audio_files:
audio_file = audio_file.strip()
entry = {'audio_filepath': audio_file, 'offset': 0.0, 'duration': None, 'text': '-', 'label': 'infer'}
fp.write(json.dumps(entry) + '\n')
return mfst_file
def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 0):
"""
Diarize files provided thorugh paths2audio_files or manifest file
input:
paths2audio_files (List[str]): list of paths to file containing audio file
batch_size (int): batch_size considered for extraction of speaker embeddings and VAD computation
"""
def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 1):
"""
"""
self._out_dir = self._diarizer_params.out_dir
if not os.path.exists(self._out_dir):
os.mkdir(self._out_dir)
self._vad_dir = os.path.join(self._out_dir, 'vad_outputs')
self._vad_out_file = os.path.join(self._vad_dir, "vad_out.json")
if batch_size:
self._cfg.batch_size = batch_size
if paths2audio_files:
self.paths2audio_files = paths2audio_files
else:
if self._cfg.diarizer.paths2audio_files is None:
raise ValueError("Pass path2audio files either through config or to diarize method")
if type(paths2audio_files) is list:
self._diarizer_params.manifest_filepath = os.path.join(self._out_dir, 'paths2audio_filepath.json')
self.path2audio_files_to_manifest(paths2audio_files, self._diarizer_params.manifest_filepath)
else:
self.paths2audio_files = self._cfg.diarizer.paths2audio_files
raise ValueError("paths2audio_files must be of type list of paths to file containing audio file")
if type(self.paths2audio_files) is str and os.path.isfile(self.paths2audio_files):
paths2audio_files = []
with open(self.paths2audio_files, 'r') as path2file:
for audiofile in path2file.readlines():
audiofile = audiofile.strip()
paths2audio_files.append(audiofile)
self.AUDIO_RTTM_MAP = audio_rttm_map(self._diarizer_params.manifest_filepath)
elif type(self.paths2audio_files) in [list, ListConfig]:
paths2audio_files = list(self.paths2audio_files)
# Speech Activity Detection
self._perform_speech_activity_detection()
else:
raise ValueError("paths2audio_files must be of type list or path to file containing audio files")
# Segmentation
self._run_segmentation()
self.AUDIO_RTTM_MAP = audio_rttm_map(paths2audio_files, self._cfg.diarizer.path2groundtruth_rttm_files)
if self.has_vad_model:
logging.info("Performing VAD")
mfst_file = self.path2audio_files_to_manifest(paths2audio_files)
self._dont_auto_split = False
self._split_duration = 50
manifest_vad_input = mfst_file
if not self._dont_auto_split:
logging.info("Split long audio file to avoid CUDA memory issue")
logging.debug("Try smaller split_duration if you still have CUDA memory issue")
config = {
'manifest_filepath': mfst_file,
'time_length': self._vad_window_length_in_sec,
'split_duration': self._split_duration,
'num_workers': self._cfg.num_workers,
}
manifest_vad_input = prepare_manifest(config)
else:
logging.warning(
"If you encounter CUDA memory issue, try splitting manifest entry by split_duration to avoid it."
)
self._setup_vad_test_data(manifest_vad_input)
self._run_vad(manifest_vad_input)
else:
if not os.path.exists(self._speaker_manifest_path):
raise NotFoundError("Oracle VAD based manifest file not found")
self.subsegments_manifest_path = os.path.join(self._out_dir, 'subsegments.json')
self.subsegments_manifest_path = segments_manifest_to_subsegments_manifest(
segments_manifest_file=self._speaker_manifest_path,
subsegments_manifest_file=self.subsegments_manifest_path,
window=self._cfg.diarizer.speaker_embeddings.window_length_in_sec,
shift=self._cfg.diarizer.speaker_embeddings.shift_length_in_sec,
)
# Embedding Extraction
self._extract_embeddings(self.subsegments_manifest_path)
out_rttm_dir = os.path.join(self._out_dir, 'pred_rttms')
os.makedirs(out_rttm_dir, exist_ok=True)
perform_diarization(
embeddings_file=self._embeddings_file,
reco2num=self._num_speakers,
manifest_path=self.subsegments_manifest_path,
audio_rttm_map=self.AUDIO_RTTM_MAP,
# Clustering
all_reference, all_hypothesis = perform_clustering(
embeddings=self.embeddings,
time_stamps=self.time_stamps,
AUDIO_RTTM_MAP=self.AUDIO_RTTM_MAP,
out_rttm_dir=out_rttm_dir,
max_num_speakers=self.max_num_speakers,
clustering_params=self._cluster_params,
)
# TODO Resegmentation -> Coming Soon
# Scoring
score = score_labels(
self.AUDIO_RTTM_MAP,
all_reference,
all_hypothesis,
collar=self._diarizer_params.collar,
ignore_overlap=self._diarizer_params.ignore_overlap,
)
logging.info("Outputs are saved in {} directory".format(os.path.abspath(self._diarizer_params.out_dir)))
return score
@staticmethod
def __make_nemo_file_from_folder(filename, source_dir):
with tarfile.open(filename, "w:gz") as tar:

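For intuition about the segmentation step used above (`_run_segmentation` with `window_length_in_sec=1.5` and `shift_length_in_sec=0.75`), a simplified sketch of window/shift subsegmentation; the actual helper is `segments_manifest_to_subsegments_manifest` and may handle edge cases differently:
```python
# Simplified sketch: split a speech segment into overlapping subsegments of `window`
# seconds every `shift` seconds; a speaker embedding is later extracted per subsegment.
def split_into_subsegments(start, duration, window=1.5, shift=0.75):
    subsegments = []
    offset, end = start, start + duration
    while offset < end:
        sub_dur = min(window, end - offset)
        subsegments.append((round(offset, 3), round(sub_dur, 3)))
        offset += shift
    return subsegments

print(split_into_subsegments(0.0, 3.2))
# -> [(0.0, 1.5), (0.75, 1.5), (1.5, 1.5), (2.25, 0.95), (3.0, 0.2)]
```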

@ -21,19 +21,17 @@ from collections import OrderedDict as od
from datetime import datetime
from typing import List
import librosa
import numpy as np
import soundfile as sf
import torch
import wget
from omegaconf import OmegaConf
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.metrics.wer_bpe import WERBPE
from nemo.collections.asr.models import ClusteringDiarizer, EncDecCTCModel, EncDecCTCModelBPE, EncDecRNNTBPEModel
from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map as get_audio_rttm_map
from nemo.collections.asr.parts.utils.speaker_utils import (
get_DER,
labels_to_pyannote_object,
get_uniqname_from_filepath,
labels_to_rttmfile,
rttm_to_labels,
write_rttm2manifest,
)
@ -61,26 +59,6 @@ def write_txt(w_path, val):
return None
def get_uniq_id_from_audio_path(audio_file_path):
"""Get the unique ID from the audio file path
"""
return '.'.join(os.path.basename(audio_file_path).split('.')[:-1])
def read_file_paths(file_list_path):
"""Read file paths from the given list
"""
out_path_list = []
if not file_list_path or (file_list_path in NONE_LIST):
raise ValueError("file_list_path is not provided.")
else:
with open(file_list_path, 'r') as path2file:
for _file in path2file.readlines():
out_path_list.append(_file.strip())
return out_path_list
class WERBPE_TS(WERBPE):
"""
This is WERBPE_TS class that is modified for generating word_timestamps with logits.
@ -306,15 +284,16 @@ class ASR_DIAR_OFFLINE(object):
A Class designed for performing ASR and diarization together.
"""
def __init__(self, params):
def __init__(self, **params):
self.params = params
self.nonspeech_threshold = self.params['threshold']
self.nonspeech_threshold = self.params['asr_based_vad_threshold']
self.root_path = None
self.fix_word_ts_with_VAD = True
self.run_ASR = None
self.frame_VAD = {}
self.AUDIO_RTTM_MAP = {}
def set_asr_model(self, ASR_model_name):
def set_asr_model(self, asr_model):
"""
Setup the parameters for the given ASR model
Currently, the following models are supported:
@ -324,79 +303,53 @@ class ASR_DIAR_OFFLINE(object):
QuartzNet15x5Base-En
"""
if 'QuartzNet' in ASR_model_name:
if 'QuartzNet' in asr_model:
self.run_ASR = self.run_ASR_QuartzNet_CTC
asr_model = EncDecCTCModel.from_pretrained(model_name=ASR_model_name, strict=False)
asr_model = EncDecCTCModel.from_pretrained(model_name=asr_model, strict=False)
self.params['offset'] = -0.18
self.model_stride_in_secs = 0.02
self.asr_delay_sec = -1 * self.params['offset']
elif 'conformer_ctc' in ASR_model_name:
elif 'conformer_ctc' in asr_model:
self.run_ASR = self.run_ASR_BPE_CTC
asr_model = EncDecCTCModelBPE.from_pretrained(model_name=ASR_model_name, strict=False)
asr_model = EncDecCTCModelBPE.from_pretrained(model_name=asr_model, strict=False)
self.model_stride_in_secs = 0.04
self.asr_delay_sec = 0.0
self.params['offset'] = 0
self.chunk_len_in_sec = 1.6
self.total_buffer_in_secs = 4
elif 'citrinet' in ASR_model_name:
elif 'citrinet' in asr_model:
self.run_ASR = self.run_ASR_BPE_CTC
asr_model = EncDecCTCModelBPE.from_pretrained(model_name=ASR_model_name, strict=False)
asr_model = EncDecCTCModelBPE.from_pretrained(model_name=asr_model, strict=False)
self.model_stride_in_secs = 0.08
self.asr_delay_sec = 0.0
self.params['offset'] = 0
self.chunk_len_in_sec = 1.6
self.total_buffer_in_secs = 4
elif 'conformer_transducer' in ASR_model_name or 'contextnet' in ASR_model_name:
elif 'conformer_transducer' in asr_model or 'contextnet' in asr_model:
self.run_ASR = self.run_ASR_BPE_RNNT
self.get_speech_labels_list = self.save_VAD_labels_list
asr_model = EncDecRNNTBPEModel.from_pretrained(model_name=ASR_model_name, strict=False)
asr_model = EncDecRNNTBPEModel.from_pretrained(model_name=asr_model, strict=False)
self.model_stride_in_secs = 0.04
self.asr_delay_sec = 0.0
self.params['offset'] = 0
self.chunk_len_in_sec = 1.6
self.total_buffer_in_secs = 4
else:
raise ValueError(f"ASR model name not found: {self.params['ASR_model_name']}")
raise ValueError(f"ASR model name not found: {asr_model}")
self.params['time_stride'] = self.model_stride_in_secs
self.asr_batch_size = 16
self.audio_file_list = [value['audio_filepath'] for _, value in self.AUDIO_RTTM_MAP.items()]
return asr_model
def create_directories(self, output_path=None):
"""Creates directories for transcribing with diarization.
"""
if output_path:
os.makedirs(output_path, exist_ok=True)
assert os.path.exists(output_path)
ROOT = output_path
else:
ROOT = os.path.join(os.getcwd(), 'asr_with_diar')
self.oracle_vad_dir = os.path.join(ROOT, 'oracle_vad')
self.json_result_dir = os.path.join(ROOT, 'json_result')
self.trans_with_spks_dir = os.path.join(ROOT, 'transcript_with_speaker_labels')
self.audacity_label_dir = os.path.join(ROOT, 'audacity_label')
self.root_path = ROOT
os.makedirs(self.root_path, exist_ok=True)
os.makedirs(self.oracle_vad_dir, exist_ok=True)
os.makedirs(self.json_result_dir, exist_ok=True)
os.makedirs(self.trans_with_spks_dir, exist_ok=True)
os.makedirs(self.audacity_label_dir, exist_ok=True)
data_dir = os.path.join(ROOT, 'data')
os.makedirs(data_dir, exist_ok=True)
def run_ASR_QuartzNet_CTC(self, audio_file_list, _asr_model):
def run_ASR_QuartzNet_CTC(self, _asr_model):
"""
Run a QuartzNet ASR model and collect logits, timestamps and text output
Args:
audio_file_list (list):
List of audio file paths.
_asr_model (class):
The loaded NeMo ASR model.
@ -418,7 +371,7 @@ class ASR_DIAR_OFFLINE(object):
)
with torch.cuda.amp.autocast():
transcript_logits_list = _asr_model.transcribe(audio_file_list, batch_size=1, logprobs=True)
transcript_logits_list = _asr_model.transcribe(self.audio_file_list, batch_size=1, logprobs=True)
for logit_np in transcript_logits_list:
log_prob = torch.from_numpy(logit_np)
logits_len = torch.from_numpy(np.array([log_prob.shape[0]]))
@ -438,16 +391,14 @@ class ASR_DIAR_OFFLINE(object):
word_ts_list.append(word_ts)
return words_list, word_ts_list
def run_ASR_BPE_RNNT(self, audio_file_list, _asr_model):
def run_ASR_BPE_RNNT(self, _asr_model):
raise NotImplementedError
def run_ASR_BPE_CTC(self, audio_file_list, _asr_model):
def run_ASR_BPE_CTC(self, _asr_model):
"""
Run a CTC-BPE based ASR model and collect logit, timestamps and text output
Args:
audio_file_list (list):
List of audio file paths.
_asr_model (class):
The loaded NeMo ASR model.
@ -486,9 +437,8 @@ class ASR_DIAR_OFFLINE(object):
)
with torch.cuda.amp.autocast():
logging.info(f"Running ASR model {self.params['ASR_model_name']}")
hyps, tokens_list, sample_list = get_wer_feat_logit(
audio_file_list,
self.audio_file_list,
frame_asr,
self.chunk_len_in_sec,
tokens_per_chunk,
@ -537,12 +487,17 @@ class ASR_DIAR_OFFLINE(object):
List that contains word timestamps.
audio_file_list (list):
List of audio file paths.
"""
self.VAD_RTTM_MAP = {}
for i, word_timestamps in enumerate(word_ts_list):
speech_labels_float = self._get_speech_labels_from_decoded_prediction(word_timestamps)
speech_labels = self.get_str_speech_labels(speech_labels_float)
self.write_VAD_rttm_from_speech_labels(self.root_path, audio_file_list[i], speech_labels)
uniq_id = get_uniqname_from_filepath(audio_file_list[i])
output_path = os.path.join(self.root_path, 'pred_rttms')
if not os.path.exists(output_path):
os.makedirs(output_path)
filename = labels_to_rttmfile(speech_labels, uniq_id, output_path)
self.VAD_RTTM_MAP[uniq_id] = {'audio_filepath': audio_file_list[i], 'rttm_filepath': filename}
def _get_speech_labels_from_decoded_prediction(self, input_word_ts):
"""
@ -604,80 +559,31 @@ class ASR_DIAR_OFFLINE(object):
return word_timestamps
def run_diarization(
self,
audio_file_list,
words_and_timestamps,
oracle_manifest=None,
oracle_num_speakers=None,
pretrained_speaker_model=None,
pretrained_vad_model=None,
self, diar_model_config, words_and_timestamps,
):
"""
Run the diarization process using the given VAD timestamp (oracle_manifest).
Args:
audio_file_list (list):
List of audio file paths.
word_and_timestamps (list):
List contains words and word-timestamps.
oracle_manifest (str):
A json file path which contains the timestamps of the VAD output.
if None, we use word-timestamps for VAD and segmentation.
oracle_num_speakers (int):
Oracle number of speakers. If None, the number of speakers is estimated.
pretrained_speaker_model (str):
NeMo model file path for speaker embedding extractor model.
Return:
diar_labels (list):
List that contains diarization result in the form of
speaker labels and time stamps.
List contains words and word-timestamps
"""
if oracle_num_speakers != None:
if oracle_num_speakers.isnumeric():
oracle_num_speakers = int(oracle_num_speakers)
elif oracle_num_speakers in NONE_LIST:
oracle_num_speakers = None
if diar_model_config.diarizer.asr.parameters.asr_based_vad:
self.save_VAD_labels_list(words_and_timestamps, self.audio_file_list)
oracle_manifest = os.path.join(self.root_path, 'asr_vad_manifest.json')
oracle_manifest = write_rttm2manifest(self.VAD_RTTM_MAP, oracle_manifest)
diar_model_config.diarizer.vad.model_path = None
diar_model_config.diarizer.vad.external_vad_manifest = oracle_manifest
data_dir = os.path.join(self.root_path, 'data')
oracle_model = ClusteringDiarizer(cfg=diar_model_config)
score = oracle_model.diarize()
if diar_model_config.diarizer.vad.model_path is not None:
self.get_frame_level_VAD(vad_processing_dir=oracle_model.vad_pred_dir)
MODEL_CONFIG = os.path.join(data_dir, 'speaker_diarization.yaml')
if not os.path.exists(MODEL_CONFIG):
MODEL_CONFIG = wget.download(self.params['diar_config_url'], data_dir)
return score
config = OmegaConf.load(MODEL_CONFIG)
if oracle_manifest == 'asr_based_vad':
# Use ASR-based VAD for diarization.
self.save_VAD_labels_list(words_and_timestamps, audio_file_list)
oracle_manifest = self.write_VAD_rttm(self.oracle_vad_dir, audio_file_list)
elif oracle_manifest == 'system_vad':
# Use System VAD for diarization.
logging.info(f"Using the provided system VAD model for diarization: {pretrained_vad_model}")
config.diarizer.vad.model_path = pretrained_vad_model
else:
config.diarizer.speaker_embeddings.oracle_vad_manifest = oracle_manifest
output_dir = os.path.join(self.root_path, 'oracle_vad')
config.diarizer.paths2audio_files = audio_file_list
config.diarizer.out_dir = output_dir # Directory to store intermediate files and prediction outputs
if pretrained_speaker_model:
config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
config.diarizer.speaker_embeddings.oracle_vad_manifest = oracle_manifest
config.diarizer.oracle_num_speakers = oracle_num_speakers
config.diarizer.speaker_embeddings.shift_length_in_sec = self.params['shift_length_in_sec']
config.diarizer.speaker_embeddings.window_length_in_sec = self.params['window_length_in_sec']
oracle_model = ClusteringDiarizer(cfg=config)
oracle_model.diarize()
if oracle_manifest == 'system_vad':
self.get_frame_level_VAD(oracle_model, audio_file_list)
diar_labels = self.get_diarization_labels(audio_file_list)
return diar_labels
def get_frame_level_VAD(self, oracle_model, audio_file_list):
def get_frame_level_VAD(self, vad_processing_dir):
"""
Read frame-level VAD.
Args:
@ -686,104 +592,68 @@ class ASR_DIAR_OFFLINE(object):
audio_file_path (List):
List contains file paths for audio files.
"""
for k, audio_file_path in enumerate(audio_file_list):
uniq_id = get_uniq_id_from_audio_path(audio_file_path)
vad_pred_diar = oracle_model.vad_pred_dir
frame_vad = os.path.join(vad_pred_diar, uniq_id + '.median')
frame_vad_list = read_file_paths(frame_vad)
frame_vad_float_list = [float(x) for x in frame_vad_list]
for uniq_id in self.AUDIO_RTTM_MAP:
frame_vad = os.path.join(vad_processing_dir, uniq_id + '.median')
frame_vad_float_list = []
with open(frame_vad, 'r') as fp:
for line in fp.readlines():
frame_vad_float_list.append(float(line.strip()))
self.frame_VAD[uniq_id] = frame_vad_float_list
def get_diarization_labels(self, audio_file_list):
def gather_eval_results(self, metric, mapping_dict):
"""
Save the diarization labels into a list.
Arg:
audio_file_list (list):
List of audio file paths.
Return:
diar_labels (list):
List of the speaker labels for each speech segment.
Gathers diarization evaluation results from pyannote DiarizationErrorRate metric object.
Inputs
metric (DiarizationErrorRate): pyannote DiarizationErrorRate metric object
mapping_dict (dict): Dictionary containing speaker-mapping labels for each audio file, keyed by unique file name
Returns
DER_result_dict (dict): Dictionary containing scores for each audio file along with aggregated results
"""
diar_labels = []
for k, audio_file_path in enumerate(audio_file_list):
uniq_id = get_uniq_id_from_audio_path(audio_file_path)
pred_rttm = os.path.join(self.oracle_vad_dir, 'pred_rttms', uniq_id + '.rttm')
pred_labels = rttm_to_labels(pred_rttm)
diar_labels.append(pred_labels)
est_n_spk = self.get_num_of_spk_from_labels(pred_labels)
logging.info(f"Estimated n_spk [{uniq_id}]: {est_n_spk}")
return diar_labels
def eval_diarization(self, audio_file_list, diar_labels, ref_rttm_file_list):
"""
Evaluate the predicted speaker labels (pred_rttm) using ref_rttm_file_list.
DER and speaker counting accuracy are calculated.
Args:
audio_file_list (list):
List of audio file paths.
ref_rttm_file_list (list):
List of reference rttm paths.
Returns:
ref_labels_list (list):
Return ref_labels_list for future use.
DER_result_dict (dict):
A dictionary that contains evaluation results.
"""
ref_labels_list = []
all_hypotheses, all_references = [], []
results = metric.results_
DER_result_dict = {}
count_correct_spk_counting = 0
audio_rttm_map = get_audio_rttm_map(audio_file_list, ref_rttm_file_list)
for k, audio_file_path in enumerate(audio_file_list):
uniq_id = get_uniq_id_from_audio_path(audio_file_path)
rttm_file = audio_rttm_map[uniq_id]['rttm_path']
if os.path.exists(rttm_file):
ref_labels = rttm_to_labels(rttm_file)
ref_labels_list.append(ref_labels)
reference = labels_to_pyannote_object(ref_labels)
all_references.append(reference)
else:
raise ValueError("No reference RTTM file provided.")
pred_labels = diar_labels[k]
for result in results:
key, score = result
pred_rttm = os.path.join(self.root_path, 'pred_rttms', key + '.rttm')
pred_labels = rttm_to_labels(pred_rttm)
est_n_spk = self.get_num_of_spk_from_labels(pred_labels)
ref_rttm = self.AUDIO_RTTM_MAP[key]['rttm_filepath']
ref_labels = rttm_to_labels(ref_rttm)
ref_n_spk = self.get_num_of_spk_from_labels(ref_labels)
hypothesis = labels_to_pyannote_object(pred_labels)
all_hypotheses.append(hypothesis)
DER, CER, FA, MISS, mapping = get_DER([reference], [hypothesis])
DER_result_dict[uniq_id] = {
DER, CER, FA, MISS = (
score['diarization error rate'],
score['confusion'],
score['false alarm'],
score['missed detection'],
)
DER_result_dict[key] = {
"DER": DER,
"CER": CER,
"FA": FA,
"MISS": MISS,
"n_spk": est_n_spk,
"mapping": mapping[0],
"mapping": mapping_dict[key],
"spk_counting": (est_n_spk == ref_n_spk),
}
logging.info("score for session {}: {}".format(key, DER_result_dict[key]))
count_correct_spk_counting += int(est_n_spk == ref_n_spk)
DER, CER, FA, MISS, mapping = get_DER(all_references, all_hypotheses)
logging.info(
"Cumulative results of all the files: \n FA: {:.4f}\t MISS {:.4f}\t\
Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format(
FA, MISS, DER, CER
)
DER, CER, FA, MISS = (
abs(metric),
metric['confusion'] / metric['total'],
metric['false alarm'] / metric['total'],
metric['missed detection'] / metric['total'],
)
DER_result_dict['total'] = {
DER_result_dict["total"] = {
"DER": DER,
"CER": CER,
"FA": FA,
"MISS": MISS,
"spk_counting_acc": count_correct_spk_counting / len(audio_file_list),
"spk_counting_acc": count_correct_spk_counting / len(metric.results_),
}
return ref_labels_list, DER_result_dict
return DER_result_dict
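For orientation, a minimal usage sketch of this new evaluation flow (illustrative only, not part of the diff; it assumes `ClusteringDiarizer.diarize()` propagates the `(metric, mapping_dict)` pair returned by `score_labels`, and that ground-truth RTTMs are listed in the manifest):
# Illustrative sketch, not part of this PR.
score = asr_diar_offline.run_diarization(cfg, word_ts_list)  # (metric, mapping_dict) when references exist
if score is not None:
    metric, mapping_dict = score
    DER_result_dict = asr_diar_offline.gather_eval_results(metric, mapping_dict)
    print(DER_result_dict["total"])  # aggregated DER/CER/FA/MISS and speaker counting accuracy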
@staticmethod
def closest_silence_start(vad_index_word_end, vad_frames, params, offset=10):
@ -807,7 +677,7 @@ class ASR_DIAR_OFFLINE(object):
c = vad_index_word_end + offset
limit = int(100 * params['max_word_ts_length_in_sec'] + vad_index_word_end)
while c < len(vad_frames):
if vad_frames[c] < params['VAD_threshold_for_word_ts']:
if vad_frames[c] < params['vad_threshold_for_word_ts']:
break
else:
c += 1
@ -839,7 +709,7 @@ class ASR_DIAR_OFFLINE(object):
"""
enhanced_word_ts_list = []
for idx, word_ts_seq_list in enumerate(word_ts_list):
uniq_id = get_uniq_id_from_audio_path(audio_file_list[idx])
uniq_id = get_uniqname_from_filepath(audio_file_list[idx])
N = len(word_ts_seq_list)
enhanced_word_ts_buffer = []
for k, word_ts in enumerate(word_ts_seq_list):
@ -857,11 +727,14 @@ class ASR_DIAR_OFFLINE(object):
min_candidate = min(vad_est_len, len_to_next_word)
fixed_word_len = max(min(params['max_word_ts_length_in_sec'], min_candidate), word_len)
enhanced_word_ts_buffer.append([word_ts[0], word_ts[0] + fixed_word_len])
else:
enhanced_word_ts_buffer.append([word_ts[0], word_ts[1]])
enhanced_word_ts_list.append(enhanced_word_ts_buffer)
return enhanced_word_ts_list
def write_json_and_transcript(
self, audio_file_list, diar_labels, word_list, word_ts_list,
self, word_list, word_ts_list,
):
"""
Matches the diarization result with the ASR output.
@ -869,8 +742,6 @@ class ASR_DIAR_OFFLINE(object):
in a for loop.
Args:
audio_file_list (list):
List that contains audio file paths.
diar_labels (list):
List of the Diarization output labels in str.
word_list (list):
@ -888,15 +759,16 @@ class ASR_DIAR_OFFLINE(object):
"""
total_riva_dict = {}
if self.fix_word_ts_with_VAD:
word_ts_list = self.compensate_word_ts_list(audio_file_list, word_ts_list, self.params)
word_ts_list = self.compensate_word_ts_list(self.audio_file_list, word_ts_list, self.params)
if self.frame_VAD == {}:
logging.info(
f"VAD timestamps are not provided and skipping word timestamp fix. Please check the VAD model."
)
for k, audio_file_path in enumerate(audio_file_list):
uniq_id = get_uniq_id_from_audio_path(audio_file_path)
labels = diar_labels[k]
for k, audio_file_path in enumerate(self.audio_file_list):
uniq_id = get_uniqname_from_filepath(audio_file_path)
pred_rttm = os.path.join(self.root_path, 'pred_rttms', uniq_id + '.rttm')
labels = rttm_to_labels(pred_rttm)
audacity_label_words = []
n_spk = self.get_num_of_spk_from_labels(labels)
string_out = ''
@ -915,7 +787,6 @@ class ASR_DIAR_OFFLINE(object):
logging.info(f"Creating results for Session: {uniq_id} n_spk: {n_spk} ")
string_out = self.print_time(string_out, speaker, start_point, end_point, self.params)
word_pos, idx = 0, 0
for j, word_ts_stt_end in enumerate(word_ts_list[k]):
@ -941,7 +812,7 @@ class ASR_DIAR_OFFLINE(object):
return total_riva_dict
def get_WDER(self, audio_file_list, total_riva_dict, DER_result_dict, ref_labels_list):
def get_WDER(self, total_riva_dict, DER_result_dict):
"""
Calculate word-level diarization error rate (WDER). WDER is calculated by
counting the wrongly diarized words and dividing by the total number of words
@ -953,8 +824,6 @@ class ASR_DIAR_OFFLINE(object):
DER_result_dict (dict):
The dictionary that stores DER, FA, Miss, CER, mapping, the estimated
number of speakers and speaker counting accuracy.
audio_file_list (list):
List that contains audio file paths.
ref_labels_list (list):
List that contains the ground truth speaker labels for each segment.
@ -964,10 +833,11 @@ class ASR_DIAR_OFFLINE(object):
"""
wder_dict = {}
grand_total_word_count, grand_correct_word_count = 0, 0
for k, audio_file_path in enumerate(audio_file_list):
for k, audio_file_path in enumerate(self.audio_file_list):
labels = ref_labels_list[k]
uniq_id = get_uniq_id_from_audio_path(audio_file_path)
uniq_id = get_uniqname_from_filepath(audio_file_path)
ref_rttm = self.AUDIO_RTTM_MAP[uniq_id]['rttm_filepath']
labels = rttm_to_labels(ref_rttm)
mapping_dict = DER_result_dict[uniq_id]['mapping']
words_list = total_riva_dict[uniq_id]['words']
@ -1027,7 +897,7 @@ class ASR_DIAR_OFFLINE(object):
Saves the diarization result into a csv file.
"""
row = [
args.threshold,
args.asr_based_vad_threshold,
WDER_dict['total'],
DER_result_dict['total']['DER'],
DER_result_dict['total']['FA'],
@ -1079,14 +949,14 @@ class ASR_DIAR_OFFLINE(object):
"""Writes output files and display logging messages.
"""
ROOT = self.root_path
logging.info(f"Writing {ROOT}/json_result/{uniq_id}.json")
dump_json_to_file(f'{ROOT}/json_result/{uniq_id}.json', riva_dict)
logging.info(f"Writing {ROOT}/pred_rttms/{uniq_id}.json")
dump_json_to_file(f'{ROOT}/pred_rttms/{uniq_id}.json', riva_dict)
logging.info(f"Writing {ROOT}/transcript_with_speaker_labels/{uniq_id}.txt")
write_txt(f'{ROOT}/transcript_with_speaker_labels/{uniq_id}.txt', string_out.strip())
logging.info(f"Writing {ROOT}/pred_rttms/{uniq_id}.txt")
write_txt(f'{ROOT}/pred_rttms/{uniq_id}.txt', string_out.strip())
logging.info(f"Writing {ROOT}/audacity_label/{uniq_id}.w.label")
write_txt(f'{ROOT}/audacity_label/{uniq_id}.w.label', '\n'.join(audacity_label_words))
logging.info(f"Writing {ROOT}/pred_rttms/{uniq_id}.w.label")
write_txt(f'{ROOT}/pred_rttms/{uniq_id}.w.label', '\n'.join(audacity_label_words))
@staticmethod
def clean_trans_and_TS(trans, char_ts):
@ -1119,51 +989,9 @@ class ASR_DIAR_OFFLINE(object):
char_ts = char_ts[: -1 * diff_R]
return trans, char_ts
@staticmethod
def write_VAD_rttm_from_speech_labels(ROOT, AUDIO_FILENAME, speech_labels):
"""Writes a VAD rttm file from speech_labels list.
"""
uniq_id = get_uniq_id_from_audio_path(AUDIO_FILENAME)
oracle_vad_dir = os.path.join(ROOT, 'oracle_vad')
with open(f'{oracle_vad_dir}/{uniq_id}.rttm', 'w') as f:
for spl in speech_labels:
start, end, speaker = spl.split()
start, end = float(start), float(end)
f.write("SPEAKER {} 1 {:.3f} {:.3f} <NA> <NA> speech <NA>\n".format(uniq_id, start, end - start))
@staticmethod
def write_VAD_rttm(oracle_vad_dir, audio_file_list, reference_rttmfile_list_path=None):
"""
Writes VAD files to the oracle_vad_dir folder.
Args:
oracle_vad_dir (str):
The path of oracle VAD folder.
audio_file_list (list):
List of audio file paths.
Return:
oracle_manifest (str):
Returns the full path of orcale_manifest.json file.
"""
if not reference_rttmfile_list_path:
reference_rttmfile_list_path = []
for path_name in audio_file_list:
uniq_id = get_uniq_id_from_audio_path(path_name)
reference_rttmfile_list_path.append(f'{oracle_vad_dir}/{uniq_id}.rttm')
oracle_manifest = os.path.join(oracle_vad_dir, 'oracle_manifest.json')
write_rttm2manifest(
paths2audio_files=audio_file_list,
paths2rttm_files=reference_rttmfile_list_path,
manifest_file=oracle_manifest,
)
return oracle_manifest
@staticmethod
def threshold_non_speech(source_list, params):
return list(filter(lambda x: x[1] - x[0] > params['threshold'], source_list))
return list(filter(lambda x: x[1] - x[0] > params['asr_based_vad_threshold'], source_list))
@staticmethod
def get_effective_WDER(DER_result_dict, WDER_dict):

@ -497,6 +497,8 @@ def COSclustering(
max_num_speaker=8,
min_samples_for_NMESC=6,
enhanced_count_thres=80,
max_rp_threshold=0.25,
sparse_search_volume=30,
fixed_thres=None,
cuda=False,
):
@ -527,6 +529,16 @@ def COSclustering(
Thus, getEnhancedSpeakerCount() employs anchor embeddings (dummy representations)
to mitigate the effect of cluster sparsity.
enhanced_count_thres = 80 is recommended.
max_rp_threshold: (float)
Limits the range of parameter search.
Clustering performance can vary depending on this range.
Default is 0.25.
sparse_search_volume: (int)
The number of p_values we search during NME analysis.
Default is 30. The lower the value, the faster NME-analysis becomes.
Values lower than 20 might cause poor parameter estimation.
Returns:
Y: (List[int])
@ -546,10 +558,10 @@ def COSclustering(
nmesc = NMESC(
mat,
max_num_speaker=max_num_speaker,
max_rp_threshold=0.25,
max_rp_threshold=max_rp_threshold,
sparse_search=True,
sparse_search_volume=30,
fixed_thres=None,
sparse_search_volume=sparse_search_volume,
fixed_thres=fixed_thres,
NME_mat_size=300,
cuda=cuda,
)

@ -14,13 +14,12 @@
import json
import math
import os
import pickle as pkl
from copy import deepcopy
import numpy as np
import soundfile as sf
import torch
from omegaconf import ListConfig
from pyannote.core import Annotation, Segment
from pyannote.core import Annotation, Segment, Timeline
from pyannote.metrics.diarization import DiarizationErrorRate
from tqdm import tqdm
@ -33,57 +32,49 @@ This file contains all the utility functions required for speaker embeddings par
"""
def audio_rttm_map(audio_file_list, rttm_file_list=None):
"""
Returns a AUDIO_TO_RTTM dictionary thats maps all the unique file names with
audio file paths and corresponding ground truth rttm files calculated from audio
file list and rttm file list
Args:
audio_file_list(list,str): either list of audio file paths or file containing paths to audio files (required)
rttm_file_list(lisr,str): either list of rttm file paths or file containing paths to rttm files (optional)
[Required if DER needs to be calculated]
Returns:
AUDIO_RTTM_MAP (dict): dictionary thats maps all the unique file names with
audio file paths and corresponding ground truth rttm files
"""
rttm_notfound = False
if type(audio_file_list) in [list, ListConfig]:
audio_files = audio_file_list
def get_uniqname_from_filepath(filepath):
"return base name from provided filepath"
if type(filepath) is str:
basename = os.path.basename(filepath).rsplit('.', 1)[0]
return basename
else:
audio_pointer = open(audio_file_list, 'r')
audio_files = audio_pointer.read().splitlines()
raise TypeError("input must be filepath string")
if rttm_file_list:
if type(rttm_file_list) in [list, ListConfig]:
rttm_files = rttm_file_list
else:
rttm_pointer = open(rttm_file_list, 'r')
rttm_files = rttm_pointer.read().splitlines()
else:
rttm_notfound = True
rttm_files = ['-'] * len(audio_files)
assert len(audio_files) == len(rttm_files)
def audio_rttm_map(manifest):
"""
This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings,
cluster and unify time stamps
input: manifest file containing keys audio_filepath, rttm_filepath (if it exists), text, num_speakers (if known) and uem_filepath (if it exists)
returns:
AUDIO_RTTM_MAP (dict) : Dictionary keyed by unique file id, used to map audio files to their corresponding rttm files
"""
AUDIO_RTTM_MAP = {}
rttm_dict = {}
audio_dict = {}
for audio_file, rttm_file in zip(audio_files, rttm_files):
uniq_audio_name = audio_file.split('/')[-1].rsplit('.', 1)[0]
uniq_rttm_name = rttm_file.split('/')[-1].rsplit('.', 1)[0]
with open(manifest, 'r') as inp_file:
lines = inp_file.readlines()
logging.info("Number of files to diarize: {}".format(len(lines)))
for line in lines:
line = line.strip()
dic = json.loads(line)
if rttm_notfound:
uniq_rttm_name = uniq_audio_name
meta = {
'audio_filepath': dic['audio_filepath'],
'rttm_filepath': dic.get('rttm_filepath', None),
'text': dic.get('text', '-'),
'num_speakers': dic.get('num_speakers', None),
'uem_filepath': dic.get('uem_filepath'),
}
audio_dict[uniq_audio_name] = audio_file
rttm_dict[uniq_rttm_name] = rttm_file
uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath'])
for key, value in audio_dict.items():
AUDIO_RTTM_MAP[key] = {'audio_path': audio_dict[key], 'rttm_path': rttm_dict[key]}
assert len(rttm_dict.items()) == len(audio_dict.items())
if uniqname not in AUDIO_RTTM_MAP:
AUDIO_RTTM_MAP[uniqname] = meta
else:
raise KeyError(
"file {} is already part AUDIO_RTTM_Map, it might be duplicated".format(meta['audio_filepath'])
)
return AUDIO_RTTM_MAP
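For reference, a minimal manifest line consumed by audio_rttm_map, mirroring the format described above (illustrative sketch, not part of this PR; the paths are placeholders):
# Illustrative sketch, not part of this PR; paths are placeholders.
import json
meta = {
    "audio_filepath": "/data/an4_diarize_test.wav",
    "offset": 0,
    "duration": None,
    "label": "infer",
    "text": "-",
    "num_speakers": 2,
    "rttm_filepath": "/data/an4_diarize_test.rttm",
    "uem_filepath": None,
}
with open("input_manifest.json", "w") as fp:
    json.dump(meta, fp)
    fp.write("\n")
audio_rttm_map_dict = audio_rttm_map("input_manifest.json")  # keyed by "an4_diarize_test"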
@ -128,11 +119,11 @@ def merge_stamps(lines):
return overlap_stamps
def labels_to_pyannote_object(labels):
def labels_to_pyannote_object(labels, uniq_name=''):
"""
converts labels to pyannote object to calculate DER and for visualization
"""
annotation = Annotation()
annotation = Annotation(uri=uniq_name)
for label in labels:
start, end, speaker = label.strip().split()
start, end = float(start), float(end)
@ -141,6 +132,24 @@ def labels_to_pyannote_object(labels):
return annotation
def uem_timeline_from_file(uem_file, uniq_name=''):
"""
outputs a pyannote Timeline of segments for the given uem file
<UEM> file format
UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME
"""
timeline = Timeline(uri=uniq_name)
with open(uem_file, 'r') as f:
lines = f.readlines()
for line in lines:
line = line.strip()
speaker_id, channel, start_time, end_time = line.split()
timeline.add(Segment(float(start_time), float(end_time)))
return timeline
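To make the UEM format above concrete, a small sketch (illustrative only, not part of this PR; the file name and times are placeholders):
# Illustrative sketch: a single scoring region from 0.0 s to 35.0 s.
with open("an4_diarize_test.uem", "w") as f:
    f.write("an4_diarize_test 1 0.00 35.00\n")
scoring_timeline = uem_timeline_from_file("an4_diarize_test.uem", uniq_name="an4_diarize_test")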
def labels_to_rttmfile(labels, uniq_id, out_rttm_dir):
"""
write an rttm file named <uniq_id>.rttm in out_rttm_dir using the time stamps in labels
@ -155,6 +164,8 @@ def labels_to_rttmfile(labels, uniq_id, out_rttm_dir):
log = 'SPEAKER {} 1 {:.3f} {:.3f} <NA> <NA> {} <NA> <NA>\n'.format(uniq_id, start, duration, speaker)
f.write(log)
return filename
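A small example of the label-to-RTTM conversion (illustrative only, not part of this PR; values are placeholders, the output directory is assumed to exist, and each RTTM line is assumed to carry duration = end - start):
# Illustrative sketch, not part of this PR.
labels = ["0.50 3.50 speaker_0", "3.50 5.00 speaker_1"]  # each label is "start end speaker"
rttm_path = labels_to_rttmfile(labels, "an4_diarize_test", "./pred_rttms")
# an4_diarize_test.rttm then contains lines such as:
#   SPEAKER an4_diarize_test 1 0.500 3.000 <NA> <NA> speaker_0 <NA> <NA>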
def rttm_to_labels(rttm_filename):
"""
@ -169,94 +180,52 @@ def rttm_to_labels(rttm_filename):
return labels
def get_time_stamps(embeddings_file, reco2num, manifest_path):
"""
Loads embedding file and generates time stamps based on window and shift for speaker embeddings
clustering.
Args:
embeddings_file (str): Path to embeddings pickle file
reco2num (int,str): common integer number of speakers for every recording or path to
file that states unique_file id with number of speakers
manifest_path (str): path to manifest for time stamps matching
sample_rate (int): sample rate
window (float): window length in sec for speaker embeddings extraction
shift (float): shift length in sec for speaker embeddings extraction window shift
Returns:
embeddings (dict): Embeddings with key as unique_id
time_stamps (dict): time stamps list for each audio recording
speakers (dict): number of speaker for each audio recording
"""
embeddings = pkl.load(open(embeddings_file, 'rb'))
all_uniq_files = list(embeddings.keys())
num_files = len(all_uniq_files)
logging.info("Number of files to diarize: {}".format(num_files))
sample = all_uniq_files[0]
logging.info("sample '{}' embeddings shape is {}\n".format(sample, embeddings[sample][0].shape))
speakers = {}
if isinstance(reco2num, int) or reco2num is None:
for key in embeddings.keys():
speakers[key] = reco2num
elif isinstance(reco2num, str) and os.path.exists(reco2num):
for key in open(reco2num).readlines():
key = key.strip()
wav_id, num = key.split()
speakers[wav_id] = int(num)
else:
raise TypeError("reco2num must be None, int or path to file containing number of speakers for each uniq id")
with open(manifest_path, 'r') as manifest:
time_stamps = {}
for line in manifest.readlines():
line = line.strip()
line = json.loads(line)
audio, offset, duration = line['audio_filepath'], line['offset'], line['duration']
audio = os.path.basename(audio).rsplit('.', 1)[0]
if audio not in time_stamps:
time_stamps[audio] = []
start = offset
end = start + duration
stamp = '{:.3f} {:.3f} '.format(start, end)
time_stamps[audio].append(stamp)
return embeddings, time_stamps, speakers
def perform_clustering(embeddings, time_stamps, speakers, audio_rttm_map, out_rttm_dir, max_num_speakers=8):
def perform_clustering(embeddings, time_stamps, AUDIO_RTTM_MAP, out_rttm_dir, clustering_params):
"""
performs spectral clustering on embeddings with time stamps generated from VAD output
Args:
embeddings (dict): Embeddings with key as unique_id
time_stamps (dict): time stamps list for each audio recording
speakers (dict): number of speaker for each audio recording
audio_rttm_map (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path
AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path
out_rttm_dir (str): Path to write predicted rttms
max_num_speakers (int): maximum number of speakers to consider for spectral clustering. Will be ignored if speakers['key'] is not None
clustering_params (dict): clustering parameters provided through config, containing max_num_speakers (int),
oracle_num_speakers (bool), max_rp_threshold (float), sparse_search_volume (int) and enhanced_count_thres (int)
Returns:
all_reference (list[Annotation]): reference annotations for score calculation
all_hypothesis (list[Annotation]): hypothesis annotations for score calculation
all_reference (list[uniq_name,Annotation]): reference annotations for score calculation
all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation
"""
all_hypothesis = []
all_reference = []
no_references = False
max_num_speakers = clustering_params['max_num_speakers']
cuda = True
if not torch.cuda.is_available():
logging.warning("cuda=False, using CPU for Eigen decompostion. This might slow down the clustering process.")
cuda = False
for uniq_key in tqdm(embeddings.keys()):
NUM_speakers = speakers[uniq_key]
for uniq_key, value in tqdm(AUDIO_RTTM_MAP.items()):
if clustering_params.oracle_num_speakers:
num_speakers = value.get('num_speakers', None)
if num_speakers is None:
raise ValueError("Provided option as oracle num of speakers but num_speakers in manifest is null")
else:
num_speakers = None
emb = embeddings[uniq_key]
emb = np.asarray(emb)
cluster_labels = COSclustering(
uniq_key, emb, oracle_num_speakers=NUM_speakers, max_num_speaker=max_num_speakers, cuda=cuda,
uniq_key,
emb,
oracle_num_speakers=num_speakers,
max_num_speaker=max_num_speakers,
enhanced_count_thres=clustering_params.enhanced_count_thres,
max_rp_threshold=clustering_params.max_rp_threshold,
sparse_search_volume=clustering_params.sparse_search_volume,
cuda=cuda,
)
lines = time_stamps[uniq_key]
@ -269,14 +238,14 @@ def perform_clustering(embeddings, time_stamps, speakers, audio_rttm_map, out_rt
labels = merge_stamps(a)
if out_rttm_dir:
labels_to_rttmfile(labels, uniq_key, out_rttm_dir)
hypothesis = labels_to_pyannote_object(labels)
all_hypothesis.append(hypothesis)
hypothesis = labels_to_pyannote_object(labels, uniq_name=uniq_key)
all_hypothesis.append([uniq_key, hypothesis])
rttm_file = audio_rttm_map[uniq_key]['rttm_path']
if os.path.exists(rttm_file) and not no_references:
rttm_file = value.get('rttm_filepath', None)
if rttm_file is not None and os.path.exists(rttm_file) and not no_references:
ref_labels = rttm_to_labels(rttm_file)
reference = labels_to_pyannote_object(ref_labels)
all_reference.append(reference)
reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_key)
all_reference.append([uniq_key, reference])
else:
no_references = True
all_reference = []
@ -284,19 +253,18 @@ def perform_clustering(embeddings, time_stamps, speakers, audio_rttm_map, out_rt
return all_reference, all_hypothesis
def get_DER(all_reference, all_hypothesis, collar=0.5, skip_overlap=True):
def score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True):
"""
calculates DER, CER, FA and MISS
Args:
all_reference (list[Annotation]): reference annotations for score calculation
all_hypothesis (list[Annotation]): hypothesis annotations for score calculation
AUDIO_RTTM_MAP : Dictionary containing information provided from the manifest file
all_reference (list[uniq_name,Annotation]): reference annotations for score calculation
all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation
Returns:
DER (float): Diarization Error Rate
CER (float): Confusion Error Rate
FA (float): False Alarm
Miss (float): Miss Detection
metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores for each audio file.
mapping (dict): Mapping dict containing the optimal speaker-label mapping for each audio input
< Caveat >
Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of
@ -304,93 +272,113 @@ def get_DER(all_reference, all_hypothesis, collar=0.5, skip_overlap=True):
collar in md-eval.pl, 0.5s should be applied for pyannote.metrics.
"""
metric = DiarizationErrorRate(collar=collar, skip_overlap=skip_overlap, uem=None)
metric = None
if len(all_reference) == len(all_hypothesis):
metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap)
mapping_dict = {}
for k, (reference, hypothesis) in enumerate(zip(all_reference, all_hypothesis)):
metric(reference, hypothesis, detailed=True)
mapping_dict[k] = metric.optimal_mapping(reference, hypothesis)
mapping_dict = {}
for (reference, hypothesis) in zip(all_reference, all_hypothesis):
ref_key, ref_labels = reference
_, hyp_labels = hypothesis
uem = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None)
if uem is not None:
uem = uem_timeline_from_file(uem_file=uem, uniq_name=ref_key)
metric(ref_labels, hyp_labels, uem=uem, detailed=True)
mapping_dict[ref_key] = metric.optimal_mapping(ref_labels, hyp_labels)
DER = abs(metric)
CER = metric['confusion'] / metric['total']
FA = metric['false alarm'] / metric['total']
MISS = metric['missed detection'] / metric['total']
DER = abs(metric)
CER = metric['confusion'] / metric['total']
FA = metric['false alarm'] / metric['total']
MISS = metric['missed detection'] / metric['total']
metric.reset()
return DER, CER, FA, MISS, mapping_dict
def perform_diarization(
embeddings_file=None, reco2num=2, manifest_path=None, audio_rttm_map=None, out_rttm_dir=None, max_num_speakers=8,
):
"""
Performs diarization with embeddings generated based on VAD time stamps with recording 2 num of speakers (reco2num)
for spectral clustering
"""
embeddings, time_stamps, speakers = get_time_stamps(embeddings_file, reco2num, manifest_path)
logging.info("Performing Clustering")
all_reference, all_hypothesis = perform_clustering(
embeddings, time_stamps, speakers, audio_rttm_map, out_rttm_dir, max_num_speakers
)
if len(all_reference) and len(all_hypothesis):
DER, CER, FA, MISS, _ = get_DER(all_reference, all_hypothesis)
logging.info(
"Cumulative results of all the files: \n FA: {:.4f}\t MISS {:.4f}\t \
"Cumulative Results for collar {} sec and ignore_overlap {}: \n FA: {:.4f}\t MISS {:.4f}\t \
Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format(
FA, MISS, DER, CER
collar, ignore_overlap, FA, MISS, DER, CER
)
)
return metric, mapping_dict
else:
logging.warning("Please check if each ground truth RTTMs was present in provided path2groundtruth_rttm_files")
logging.warning("Skipping calculation of Diariazation Error Rate")
logging.warning(
"check if each ground truth RTTMs were present in provided manifest file. Skipping calculation of Diariazation Error Rate"
)
return None
def write_rttm2manifest(paths2audio_files, paths2rttm_files, manifest_file):
def write_rttm2manifest(AUDIO_RTTM_MAP, manifest_file):
"""
writes a manifest file based on rttm files (or vad table output files). This manifest file is used by the
speaker diarizer to compute embeddings and cluster them. This function also takes care of overlapping time stamps
Args:
audio_file_list(list,str): either list of audio file paths or file containing paths to audio files (required)
rttm_file_list(lisr,str): either list of rttm file paths or file containing paths to rttm files (optional)
AUDIO_RTTM_MAP: dict keyed by unique file names; each entry contains an audio_filepath and rttm_filepath,
which are used to extract oracle vad timestamps.
manifest_file (str): path to write the manifest file
Returns:
manifest_file (str): path to the written manifest file
"""
AUDIO_RTTM_MAP = audio_rttm_map(paths2audio_files, paths2rttm_files)
with open(manifest_file, 'w') as outfile:
for key in AUDIO_RTTM_MAP:
f = open(AUDIO_RTTM_MAP[key]['rttm_path'], 'r')
audio_path = AUDIO_RTTM_MAP[key]['audio_path']
rttm_filename = AUDIO_RTTM_MAP[key]['rttm_filepath']
if rttm_filename and os.path.exists(rttm_filename):
f = open(rttm_filename, 'r')
else:
raise FileNotFoundError(
"Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format(
rttm_filename
)
)
audio_path = AUDIO_RTTM_MAP[key]['audio_filepath']
if AUDIO_RTTM_MAP[key].get('duration', None):
max_duration = AUDIO_RTTM_MAP[key]['duration']
else:
sound = sf.SoundFile(audio_path)
max_duration = sound.frames / sound.samplerate
lines = f.readlines()
time_tup = (-1, -1)
for line in lines:
vad_out = line.strip().split()
if len(vad_out) > 3:
start, dur, activity = float(vad_out[3]), float(vad_out[4]), vad_out[7]
start, dur, _ = float(vad_out[3]), float(vad_out[4]), vad_out[7]
else:
start, dur, activity = float(vad_out[0]), float(vad_out[1]), vad_out[2]
start, dur, _ = float(vad_out[0]), float(vad_out[1]), vad_out[2]
start, dur = float("{:.3f}".format(start)), float("{:.3f}".format(dur))
if time_tup[0] >= 0 and start > time_tup[1]:
dur2 = float("{:.3f}".format(time_tup[1] - time_tup[0]))
meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'}
json.dump(meta, outfile)
outfile.write("\n")
if time_tup[0] < max_duration and dur2 > 0:
meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'}
json.dump(meta, outfile)
outfile.write("\n")
else:
logging.warning(
"RTTM label has been truncated since start is greater than duration of audio file"
)
time_tup = (start, start + dur)
else:
if time_tup[0] == -1:
time_tup = (start, start + dur)
end_time = start + dur
if end_time > max_duration:
end_time = max_duration
time_tup = (start, end_time)
else:
time_tup = (min(time_tup[0], start), max(time_tup[1], start + dur))
end_time = max(time_tup[1], start + dur)
if end_time > max_duration:
end_time = max_duration
time_tup = (min(time_tup[0], start), end_time)
dur2 = float("{:.3f}".format(time_tup[1] - time_tup[0]))
meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'}
json.dump(meta, outfile)
outfile.write("\n")
if time_tup[0] < max_duration and dur2 > 0:
meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'}
json.dump(meta, outfile)
outfile.write("\n")
else:
logging.warning("RTTM label has been truncated since start is greater than duration of audio file")
f.close()
return manifest_file
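A usage sketch of the updated write_rttm2manifest signature (illustrative only, not part of this PR; paths are placeholders and the audio file must exist so its duration can be read):
# Illustrative sketch, not part of this PR.
AUDIO_RTTM_MAP = {
    "an4_diarize_test": {
        "audio_filepath": "/data/an4_diarize_test.wav",
        "rttm_filepath": "/data/oracle_vad/an4_diarize_test.rttm",
    }
}
oracle_manifest = write_rttm2manifest(AUDIO_RTTM_MAP, "/data/oracle_vad/oracle_manifest.json")
# Overlapping RTTM segments, e.g. (start 0.5 s, dur 2.0 s) and (start 1.5 s, dur 2.0 s),
# are merged into a single manifest entry:
#   {"audio_filepath": "...", "offset": 0.5, "duration": 3.0, "label": "UNK"}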

@ -0,0 +1,105 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
import random
from collections import Counter
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels
random.seed(42)
"""
This script creates a manifest file for speaker diarization inference purposes.
It is useful for generating a manifest when you have a list of audio files and optionally rttm and uem files for evaluation
Note: make sure the basename of each file is unique and that rttm files have the corresponding base name for mapping
"""
def write_file(name, lines, idx):
with open(name, 'w') as fout:
for i in idx:
dic = lines[i]
json.dump(dic, fout)
fout.write('\n')
logging.info("wrote", name)
def read_file(scp):
scp = open(scp, 'r').readlines()
return sorted(scp)
def main(wav_scp, rttm_scp=None, uem_scp=None, manifest_filepath=None):
if os.path.exists(manifest_filepath):
os.remove(manifest_filepath)
wav_scp = read_file(wav_scp)
if rttm_scp is not None:
rttm_scp = read_file(rttm_scp)
else:
rttm_scp = len(wav_scp) * [None]
if uem_scp is not None:
uem_scp = read_file(uem_scp)
else:
uem_scp = len(wav_scp) * [None]
lines = []
for wav, rttm, uem in zip(wav_scp, rttm_scp, uem_scp):
audio_line = wav.strip()
if rttm is not None:
rttm = rttm.strip()
labels = rttm_to_labels(rttm)
num_speakers = len(Counter([l.split()[-1] for l in labels]).keys())
else:
num_speakers = None
if uem is not None:
uem = uem.strip()
meta = [
{
"audio_filepath": audio_line,
"offset": 0,
"duration": None,
"label": "infer",
"text": "-",
"num_speakers": num_speakers,
"rttm_filepath": rttm,
"uem_filepath": uem,
}
]
lines.extend(meta)
write_file(manifest_filepath, lines, range(len(lines)))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--paths2audio_files", help="path to text file containing list of audio files", type=str, required=True
)
parser.add_argument("--paths2rttm_files", help="path to text file containing list of rttm files", type=str)
parser.add_argument("--paths2uem_files", help="path to uem files", type=str)
parser.add_argument("--manifest_filepath", help="scp file name", type=str, required=True)
args = parser.parse_args()
main(args.paths2audio_files, args.paths2rttm_files, args.paths2uem_files, args.manifest_filepath)
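A usage sketch for the new script (illustrative only, not part of this PR; the script name follows the `pathsfiles_to_manifest.py` reference in the tutorial below, and the .txt files are placeholder lists with one path per line):
python pathsfiles_to_manifest.py \
    --paths2audio_files audio_files.txt \
    --paths2rttm_files rttm_files.txt \
    --paths2uem_files uem_files.txt \
    --manifest_filepath input_manifest.json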

@ -1,40 +0,0 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from nemo.collections.asr.parts.utils.speaker_utils import write_rttm2manifest
from nemo.utils import logging
"""
This file converts vad outputs to manifest file for speaker diarization purposes
present in vad output directory.
every vad line consists of start_time, end_time , speech/non-speech
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--paths2rttm_files", help="path to vad output rttm-like files. Could be a list or a text file", required=True
)
parser.add_argument(
"--paths2audio_files",
help="path to audio files that vad was computed. Could be a list or a text file",
required=True,
)
parser.add_argument("--manifest_file", help="output manifest file name", type=str, required=True)
args = parser.parse_args()
write_rttm2manifest(args.paths2audio_files, args.paths2rttm_files, args.manifest_file)
logging.info("wrote {} file from vad output files present in {}".format(args.manifest_file, args.paths2rttm_files))

@ -99,18 +99,6 @@
"pp = pprint.PrettyPrinter(indent=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_file(path_to_file):\n",
" with open(path_to_file) as f:\n",
" contents = f.read().splitlines()\n",
" return contents"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -127,6 +115,7 @@
"ROOT = os.getcwd()\n",
"data_dir = os.path.join(ROOT,'data')\n",
"os.makedirs(data_dir, exist_ok=True)\n",
"\n",
"an4_audio_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav\"\n",
"if not os.path.exists(os.path.join(data_dir,'an4_diarize_test.wav')):\n",
" AUDIO_FILENAME = wget.download(an4_audio_url, data_dir)\n",
@ -213,21 +202,99 @@
"metadata": {},
"outputs": [],
"source": [
"CONFIG_URL = \"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/speaker_diarization.yaml\"\n",
"from omegaconf import OmegaConf\n",
"import shutil\n",
"CONFIG_URL = \"https://raw.githubusercontent.com/NVIDIA/NeMo/modify_speaker_input/examples/speaker_tasks/diarization/conf/offline_diarization_with_asr.yaml\"\n",
"\n",
"params = {\n",
" \"round_float\": 2,\n",
" \"window_length_in_sec\": 1.0,\n",
" \"shift_length_in_sec\": 0.25,\n",
" \"fix_word_ts_with_VAD\": False,\n",
" \"print_transcript\": False,\n",
" \"word_gap_in_sec\": 0.01,\n",
" \"max_word_ts_length_in_sec\": 0.6,\n",
" \"minimum\": True,\n",
" \"threshold\": 300, \n",
" \"diar_config_url\": CONFIG_URL,\n",
" \"ASR_model_name\": 'QuartzNet15x5Base-En',\n",
"}\n"
"if not os.path.exists(os.path.join(data_dir,'offline_diarization_with_asr.yaml')):\n",
" CONFIG = wget.download(CONFIG_URL, data_dir)\n",
"else:\n",
" CONFIG = os.path.join(data_dir,'offline_diarization_with_asr.yaml')\n",
" \n",
"cfg = OmegaConf.load(CONFIG)\n",
"print(OmegaConf.to_yaml(cfg))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Speaker Diarization scripts commonly expects following arguments:\n",
"1. manifest_filepath : Path to manifest file containing json lines of format: {\"audio_filepath\": \"/path/to/audio_file\", \"offset\": 0, \"duration\": null, \"label\": \"infer\", \"text\": \"-\", \"num_speakers\": null, \"rttm_filepath\": \"/path/to/rttm/file\", \"uem_filepath\"=\"/path/to/uem/filepath\"}\n",
"2. out_dir : directory where outputs and intermediate files are stored. \n",
"3. oracle_vad: If this is true then we extract speech activity labels from rttm files, if False then either \n",
"4. vad.model_path or external_manifestpath containing speech activity labels has to be passed. \n",
"\n",
"Mandatory fields are audio_filepath, offset, duration, label and text. For the rest if you would like to evaluate with known number of speakers pass the value else `null`. If you would like to score the system with known rttms then that should be passed as well, else `null`. uem file is used to score only part of your audio for evaluation purposes, hence pass if you would like to evaluate on it else `null`.\n",
"\n",
"\n",
"**Note:** we expect audio and corresponding RTTM have **same base name** and the name should be **unique**. \n",
"\n",
"For example: if audio file name is **test_an4**.wav, if provided we expect corresponding rttm file name to be **test_an4**.rttm (note the matching **test_an4** base name)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets create manifest with the an4 audio and rttm available. If you have more than one files you may also use the script `NeMo/scripts/speaker_tasks/rttm_to_manifest.py` to generate manifest file from list of audio files and optionally rttm files "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a manifest for input with below format. \n",
"# {\"audio_filepath\": \"/path/to/audio_file\", \"offset\": 0, \"duration\": null, \"label\": \"infer\", \"text\": \"-\", \n",
"# \"num_speakers\": null, \"rttm_filepath\": \"/path/to/rttm/file\", \"uem_filepath\"=\"/path/to/uem/filepath\"}\n",
"import json\n",
"meta = {\n",
" 'audio_filepath': AUDIO_FILENAME, \n",
" 'offset': 0, \n",
" 'duration':None, \n",
" 'label': 'infer', \n",
" 'text': '-', \n",
" 'num_speakers': 2, \n",
" 'rttm_filepath': None, \n",
" 'uem_filepath' : None\n",
"}\n",
"with open(os.path.join(data_dir,'input_manifest.json'),'w') as fp:\n",
" json.dump(meta,fp)\n",
" fp.write('\\n')\n",
"\n",
"\n",
"cfg.diarizer.manifest_filepath = os.path.join(data_dir,'input_manifest.json')\n",
"!cat {cfg.diarizer.manifest_filepath}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the parameters required for diarization, here we get voice activity labels from ASR, which is set through parameter `cfg.diarizer.asr.parameters.asr_based_vad`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pretrained_speaker_model='ecapa_tdnn'\n",
"cfg.diarizer.manifest_filepath = cfg.diarizer.manifest_filepath\n",
"cfg.diarizer.out_dir = data_dir #Directory to store intermediate files and prediction outputs\n",
"cfg.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
"cfg.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5\n",
"cfg.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75\n",
"cfg.diarizer.clustering.parameters.oracle_num_speakers=True\n",
"\n",
"# USE VAD generated from ASR timestamps\n",
"cfg.diarizer.asr.model_path = 'QuartzNet15x5Base-En'\n",
"cfg.diarizer.oracle_vad = False # ----> ORACLE VAD \n",
"cfg.diarizer.asr.parameters.asr_based_vad = True\n",
"cfg.diarizer.asr.parameters.threshold=300"
]
},
{
@ -243,35 +310,13 @@
"metadata": {},
"outputs": [],
"source": [
"asr_diar_offline = ASR_DIAR_OFFLINE(params)\n",
"asr_model = asr_diar_offline.set_asr_model(params['ASR_model_name'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will create folders that we need for storing VAD stamps and ASR/diarization results.\n",
"from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map\n",
"asr_diar_offline = ASR_DIAR_OFFLINE(**cfg.diarizer.asr.parameters)\n",
"asr_diar_offline.root_path = cfg.diarizer.out_dir\n",
"\n",
"Under the folder named ``asr_with_diar``, the following folders will be created.\n",
"\n",
"- ``oracle_vad``\n",
"- ``json_result``\n",
"- ``transcript_with_speaker_labels``\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"asr_diar_offline.create_directories()\n",
"\n",
"print(\"Folders are created as below.\")\n",
"print(\"VAD file path: \\n\", asr_diar_offline.oracle_vad_dir)\n",
"print(\"JSON result path: \\n\", asr_diar_offline.json_result_dir)\n",
"print(\"Transcript result path: \\n\", asr_diar_offline.trans_with_spks_dir)"
"AUDIO_RTTM_MAP = audio_rttm_map(cfg.diarizer.manifest_filepath)\n",
"asr_diar_offline.AUDIO_RTTM_MAP = AUDIO_RTTM_MAP\n",
"asr_model = asr_diar_offline.set_asr_model(cfg.diarizer.asr.model_path)"
]
},
{
@ -296,7 +341,7 @@
"metadata": {},
"outputs": [],
"source": [
"word_list, word_ts_list = asr_diar_offline.run_ASR(audio_file_list, asr_model)\n",
"word_list, word_ts_list = asr_diar_offline.run_ASR(asr_model)\n",
"\n",
"print(\"Decoded word output: \\n\", word_list[0])\n",
"print(\"Word-level timestamps \\n\", word_ts_list[0])"
@ -325,23 +370,14 @@
"metadata": {},
"outputs": [],
"source": [
"oracle_manifest = 'asr_based_vad'\n",
"# If we know the number of speakers, we can assign \"2\".\n",
"num_speakers = None\n",
"speaker_embedding_model = 'ecapa_tdnn'\n",
"\n",
"diar_labels = asr_diar_offline.run_diarization(audio_file_list, \n",
" word_ts_list, \n",
" oracle_manifest=oracle_manifest, \n",
" oracle_num_speakers=num_speakers,\n",
" pretrained_speaker_model=speaker_embedding_model)"
"score = asr_diar_offline.run_diarization(cfg, word_ts_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`run_diarization()` function creates `./asr_with_diar/oracle_vad/pred_rttm/an4_diarize_test.rttm` file. Let's see what is written in this `rttm` file."
"`run_diarization()` function creates `an4_diarize_test.rttm` file. Let's see what is written in this `rttm` file."
]
},
{
@ -350,29 +386,34 @@
"metadata": {},
"outputs": [],
"source": [
"predicted_speaker_label_rttm_path = f\"{ROOT}/asr_with_diar/oracle_vad/pred_rttms/an4_diarize_test.rttm\"\n",
"def read_file(path_to_file):\n",
" with open(path_to_file) as f:\n",
" contents = f.read().splitlines()\n",
" return contents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predicted_speaker_label_rttm_path = f\"{data_dir}/pred_rttms/an4_diarize_test.rttm\"\n",
"pred_rttm = read_file(predicted_speaker_label_rttm_path)\n",
"\n",
"pp.pprint(pred_rttm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`run_diarization()` also returns a variable named `diar_labels` which contains the estimated speaker label information with timestamps from the predicted rttm file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Diarization Labels:\")\n",
"pp.pprint(diar_labels[0])\n",
"from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels\n",
"pred_labels = rttm_to_labels(predicted_speaker_label_rttm_path)\n",
"\n",
"color = get_color(signal, diar_labels[0])\n",
"color = get_color(signal, pred_labels)\n",
"display_waveform(signal,'Audio with Speaker Labels', color)\n",
"display(Audio(signal,rate=16000))"
]
@ -391,14 +432,14 @@
"metadata": {},
"outputs": [],
"source": [
"asr_output_dict = asr_diar_offline.write_json_and_transcript(audio_file_list, diar_labels, word_list, word_ts_list)"
"asr_output_dict = asr_diar_offline.write_json_and_transcript(word_list, word_ts_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After running `write_json_and_transcript()` function, the transcription output will be located in `./asr_with_diar/transcript_with_speaker_labels` folder, which shows **start time to end time of the utterance, speaker ID, and words spoken** during the notified time."
"After running `write_json_and_transcript()` function, the transcription output will be located in `./pred_rttms` folder, which shows **start time to end time of the utterance, speaker ID, and words spoken** during the notified time."
]
},
{
@ -407,7 +448,7 @@
"metadata": {},
"outputs": [],
"source": [
"transcription_path_to_file = f\"{ROOT}/asr_with_diar/transcript_with_speaker_labels/an4_diarize_test.txt\"\n",
"transcription_path_to_file = f\"{data_dir}/pred_rttms/an4_diarize_test.txt\"\n",
"transcript = read_file(transcription_path_to_file)\n",
"pp.pprint(transcript)"
]
@ -416,7 +457,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Another output is transcription output in JSON format, which is saved in `./asr_with_diar/json_result`. \n",
"Another output is transcription output in JSON format, which is saved in `./pred_rttms/an4_diarize_test.json`. \n",
"\n",
"In the JSON format output, we include information such as **transcription, estimated number of speakers (variable named `speaker_count`), start and end time of each word and most importantly, speaker label for each word.**"
]
@ -427,7 +468,7 @@
"metadata": {},
"outputs": [],
"source": [
"transcription_path_to_file = f\"{ROOT}/asr_with_diar/json_result/an4_diarize_test.json\"\n",
"transcription_path_to_file = f\"{data_dir}/pred_rttms/an4_diarize_test.json\"\n",
"json_contents = read_file(transcription_path_to_file)\n",
"pp.pprint(json_contents)"
]
@ -449,7 +490,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
"version": "3.8.2"
}
},
"nbformat": 4,

@ -151,32 +151,25 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Speaker Diarization scripts commonly expects two files:\n",
"1. paths2audio_files : either list of audio file paths or file containing paths to audio files for which we need to perform diarization.\n",
"2. path2groundtruth_rttm_files (optional): either list of rttm file paths or file containing paths to rttm files (this can be passed if we need to calculate DER rate based on our ground truth rttm files).\n",
"Speaker Diarization scripts commonly expects following arguments:\n",
"1. manifest_filepath : Path to manifest file containing json lines of format: {'audio_filepath': /path/to/audio_file, 'offset': 0, 'duration':None, 'label': 'infer', 'text': '-', 'num_speakers': None, 'rttm_filepath': /path/to/rttm/file, 'uem_filepath'='/path/to/uem/filepath'}\n",
"2. out_dir : directory where outputs and intermediate files are stored. \n",
"3. oracle_vad: If this is true then we extract speech activity labels from rttm files, if False then either \n",
"4. vad.model_path or external_manifestpath containing speech activity labels has to be passed. \n",
"\n",
"Mandatory fields are audio_filepath, offset, duration, label and text. For the rest if you would like to evaluate with known number of speakers pass the value else None. If you would like to score the system with known rttms then that should be passed as well, else None. uem file is used to score only part of your audio for evaluation purposes, hence pass if you would like to evaluate on it else None.\n",
"\n",
"\n",
"**Note** we expect audio and corresponding RTTM have **same base name** and the name should be **unique**. \n",
"\n",
"For eg: if audio file name is **test_an4**.wav, if provided we expect corresponding rttm file name to be **test_an4**.rttm (note the matching **test_an4** base name)\n",
"\n",
"Now let's create paths2audio_files list (or file) for which we need to perform diarization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"paths2audio_files = [an4_audio]\n",
"print(paths2audio_files)"
"For eg: if audio file name is **test_an4**.wav, if provided we expect corresponding rttm file name to be **test_an4**.rttm (note the matching **test_an4** base name)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly create` path2groundtruth_rttm_files` list (this is optional, and needed for score calculation)"
"Lets create manifest with the an4 audio and rttm available. If you have more than one files you may also use the script `pathsfiles_to_manifest.py` to generate manifest file from list of audio files and optionally rttm files "
]
},
{
@ -185,8 +178,28 @@
"metadata": {},
"outputs": [],
"source": [
"path2groundtruth_rttm_files = [an4_rttm]\n",
"print(path2groundtruth_rttm_files)"
"# Create a manifest for input with below format. \n",
"# {'audio_filepath': /path/to/audio_file, 'offset': 0, 'duration':None, 'label': 'infer', 'text': '-', \n",
"# 'num_speakers': None, 'rttm_filepath': /path/to/rttm/file, 'uem_filepath'='/path/to/uem/filepath'}\n",
"import json\n",
"meta = {\n",
" 'audio_filepath': an4_audio, \n",
" 'offset': 0, \n",
" 'duration':None, \n",
" 'label': 'infer', \n",
" 'text': '-', \n",
" 'num_speakers': 2, \n",
" 'rttm_filepath': an4_rttm, \n",
" 'uem_filepath' : None\n",
"}\n",
"with open('data/input_manifest.json','w') as fp:\n",
" json.dump(meta,fp)\n",
" fp.write('\\n')\n",
"\n",
"!cat data/input_manifest.json\n",
"\n",
"output_dir = os.path.join(ROOT, 'oracle_vad')\n",
"os.makedirs(output_dir,exist_ok=True)"
]
},
{
@ -203,39 +216,8 @@
"Oracle-vad diarization is to compute speaker embeddings from known speech label timestamps rather than depending on VAD output. This step can also be used to run speaker diarization with rttms generated from any external VAD, not just VAD model from NeMo.\n",
"\n",
"For it, the first step is to start converting reference audio rttm(vad) time stamps to oracle manifest file. This manifest file would be sent to our speaker diarizer to extract embeddings.\n",
"For that let's use write_rttm2manifest function, that takes paths2audio_files and paths2rttm_files as arguments "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nemo.collections.asr.parts.utils.speaker_utils import write_rttm2manifest\n",
"output_dir = os.path.join(ROOT, 'oracle_vad')\n",
"os.makedirs(output_dir,exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"oracle_manifest = os.path.join(output_dir,'oracle_manifest.json')\n",
"write_rttm2manifest(paths2audio_files=paths2audio_files,\n",
" paths2rttm_files=path2groundtruth_rttm_files,\n",
" manifest_file=oracle_manifest)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!cat {oracle_manifest}"
"\n",
"This is just an argument in our config, and system automatically computes oracle manifest based on the rttms provided through input manifest file"
]
},
{
@ -253,10 +235,11 @@
"outputs": [],
"source": [
"from omegaconf import OmegaConf\n",
"MODEL_CONFIG = os.path.join(data_dir,'speaker_diarization.yaml')\n",
"MODEL_CONFIG = os.path.join(data_dir,'offline_diarization.yaml')\n",
"if not os.path.exists(MODEL_CONFIG):\n",
" config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/speaker_diarization.yaml\"\n",
" config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/modify_speaker_input/examples/speaker_tasks/diarization/conf/offline_diarization.yaml\"\n",
" MODEL_CONFIG = wget.download(config_url,data_dir)\n",
"\n",
"config = OmegaConf.load(MODEL_CONFIG)\n",
"print(OmegaConf.to_yaml(config))"
]
@ -275,14 +258,14 @@
"outputs": [],
"source": [
"pretrained_speaker_model='ecapa_tdnn'\n",
"config.diarizer.paths2audio_files = paths2audio_files\n",
"config.diarizer.path2groundtruth_rttm_files = path2groundtruth_rttm_files\n",
"config.diarizer.manifest_filepath = 'data/input_manifest.json'\n",
"config.diarizer.out_dir = output_dir #Directory to store intermediate files and prediction outputs\n",
"\n",
"config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
"# Ignoring vad we just need to pass the manifest file we created\n",
"config.diarizer.speaker_embeddings.oracle_vad_manifest = oracle_manifest\n",
"config.diarizer.oracle_num_speakers=2"
"config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5\n",
"config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75\n",
"config.diarizer.oracle_vad = True # ----> ORACLE VAD \n",
"config.diarizer.clustering.parameters.oracle_num_speakers = True"
]
},
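{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the config set, the oracle-VAD diarizer can be run. The cell below is a minimal sketch that assumes the `ClusteringDiarizer` class from `nemo.collections.asr.models`; predicted RTTMs and intermediate files are written to `config.diarizer.out_dir`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nemo.collections.asr.models import ClusteringDiarizer\n",
"# Minimal sketch: build the diarizer from the config above and run oracle-VAD diarization\n",
"oracle_vad_clusdiar_model = ClusteringDiarizer(cfg=config)\n",
"# diarize() writes predicted RTTMs and intermediate outputs to config.diarizer.out_dir\n",
"oracle_vad_clusdiar_model.diarize()"
]
},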
{
@ -352,7 +335,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this method we compute VAD time stamps using NeMo VAD model on `paths2audio_files` and then use these time stamps of speech label to find speaker embeddings followed by clustering them into num of speakers"
"In this method we compute VAD time stamps using NeMo VAD model on input manifest file and then use these time stamps of speech label to find speaker embeddings followed by clustering them into num of speakers"
]
},
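{
"cell_type": "markdown",
"metadata": {},
"source": [
"This method needs a pretrained NeMo VAD model. If `pretrained_vad` has not already been set earlier in the notebook, a commonly used checkpoint name is `vad_marblenet` (assumed in the sketch below)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Assumption: use the 'vad_marblenet' pretrained checkpoint if pretrained_vad is not defined yet\n",
"pretrained_vad = 'vad_marblenet'"
]
},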
{
@ -413,20 +396,23 @@
"outputs": [],
"source": [
"output_dir = os.path.join(ROOT,'outputs')\n",
"config.diarizer.paths2audio_files = paths2audio_files\n",
"config.diarizer.path2groundtruth_rttm_files = path2groundtruth_rttm_files\n",
"config.diarizer.out_dir = output_dir # Directory to store intermediate files and prediction outputs\n",
"config.diarizer.manifest_filepath = 'data/input_manifest.json'\n",
"config.diarizer.out_dir = output_dir #Directory to store intermediate files and prediction outputs\n",
"\n",
"config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
"config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5\n",
"config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75\n",
"config.diarizer.oracle_vad = False # compute VAD provided with model_path to vad config\n",
"config.diarizer.clustering.parameters.oracle_num_speakers=True\n",
"\n",
"#Here we use our inhouse pretrained NeMo VAD \n",
"config.diarizer.vad.model_path = pretrained_vad\n",
"config.diarizer.vad.window_length_in_sec = 0.15\n",
"config.diarizer.vad.shift_length_in_sec = 0.01\n",
"# config.diarizer.vad.threshold = 0.8 threshold would be deprecated in release 1.5\n",
"config.diarizer.vad.postprocessing_params.onset = 0.8 \n",
"config.diarizer.vad.postprocessing_params.offset = 0.7\n",
"config.diarizer.vad.postprocessing_params.min_duration_on = 0.1\n",
"config.diarizer.vad.postprocessing_params.min_duration_off = 0.3"
"config.diarizer.vad.parameters.onset = 0.8 \n",
"config.diarizer.vad.parameters.offset = 0.6\n",
"config.diarizer.vad.parameters.min_duration_on = 0.1\n",
"config.diarizer.vad.parameters.min_duration_off = 0.4"
]
},
{
@ -490,13 +476,13 @@
"# you can also use single threshold(=onset=offset) for binarization and plot here\n",
"from nemo.collections.asr.parts.utils.vad_utils import plot\n",
"plot(\n",
" paths2audio_files[0],\n",
" an4_audio,\n",
" 'outputs/vad_outputs/overlap_smoothing_output_median_0.875/an4_diarize_test.median', \n",
" path2groundtruth_rttm_files[0],\n",
" per_args = config.diarizer.vad.postprocessing_params, #threshold\n",
" an4_rttm,\n",
" per_args = config.diarizer.vad.parameters, #threshold\n",
" ) \n",
"\n",
"print(f\"postprocessing_params: {config.diarizer.vad.postprocessing_params}\")"
"print(f\"postprocessing_params: {config.diarizer.vad.parameters}\")"
]
},
{
@ -599,14 +585,14 @@
"outputs": [],
"source": [
"quartznet = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name=\"QuartzNet15x5Base-En\")\n",
"for fname, transcription in zip(paths2audio_files, quartznet.transcribe(paths2audio_files=paths2audio_files)):\n",
" print(f\"Audio in {fname} was recognized as: {transcription}\")"
"for fname, transcription in zip([an4_audio], quartznet.transcribe(paths2audio_files=[an4_audio])):\n",
" print(f\"Audio in {fname} was recognized as:\\n{transcription}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -620,7 +606,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.2"
},
"pycharm": {
"stem_cell": {