From dc9ed88f784c662022ad85224f8c5f6e8f4767e9 Mon Sep 17 00:00:00 2001 From: Nithin Rao Date: Sat, 6 Nov 2021 07:55:32 -0700 Subject: [PATCH] Modify speaker input (#3100) * initial_commit Signed-off-by: nithinraok * init diarizer Signed-off-by: nithinraok * vad+speaker Signed-off-by: nithinraok * vad update Signed-off-by: nithinraok * speaker done Signed-off-by: nithinraok * initial working version Signed-off-by: nithinraok * compare outputs Signed-off-by: nithinraok * added uem support Signed-off-by: nithinraok * pyannote improvements Signed-off-by: nithinraok * updated config and script name Signed-off-by: nithinraok * style fix Signed-off-by: nithinraok * update Jenkins file Signed-off-by: nithinraok * jenkins fix Signed-off-by: nithinraok * jenkins fix Signed-off-by: nithinraok * update file path in jenkins Signed-off-by: nithinraok * update file path in jenkins Signed-off-by: nithinraok * update file path in jenkins Signed-off-by: nithinraok * jenkins quote fix Signed-off-by: nithinraok * update offline speaker diarization notebook Signed-off-by: nithinraok * intial working asr_with_diarization Signed-off-by: nithinraok * almost done, revist scoring part Signed-off-by: nithinraok * fixed eval in offline diarization with asr Signed-off-by: nithinraok * update write2manifest to consider only up to max audio duration Signed-off-by: nithinraok * asr with diarization notebook Signed-off-by: nithinraok * Fixed ASR_with_diarization tutorial.ipynb and diarization_utils and edited config yaml file Signed-off-by: Taejin Park * Fixed VAD parameters in Speaker_Diarization_Inference.ipynb Signed-off-by: Taejin Park * Added Jenkins test, doc strings and updated README Signed-off-by: nithinraok * update jenkins test Signed-off-by: nithinraok * Doc info in offline_diarization_with_asr Signed-off-by: nithinraok * Review comments Signed-off-by: nithinraok * update outdir paths Signed-off-by: nithinraok Co-authored-by: Taejin Park --- Jenkinsfile | 18 +- README.rst | 2 +- examples/speaker_tasks/diarization/README.md | 119 ++---- .../diarization/asr_with_diarization.py | 152 ------- .../diarization/conf/offline_diarization.yaml | 47 +++ .../conf/offline_diarization_with_asr.yaml | 58 +++ .../diarization/conf/speaker_diarization.yaml | 36 -- ...aker_diarize.py => offline_diarization.py} | 14 +- .../offline_diarization_with_asr.py | 82 ++++ nemo/collections/asr/data/audio_to_label.py | 2 +- .../asr/models/clustering_diarizer.py | 349 +++++++++------- .../asr/parts/utils/diarization_utils.py | 384 +++++------------- .../asr/parts/utils/nmse_clustering.py | 18 +- .../asr/parts/utils/speaker_utils.py | 336 ++++++++------- .../speaker_tasks/pathsfiles_to_manifest.py | 105 +++++ scripts/speaker_tasks/rttm_to_manifest.py | 40 -- .../ASR_with_SpeakerDiarization.ipynb | 207 ++++++---- .../Speaker_Diarization_Inference.ipynb | 138 +++---- 18 files changed, 1026 insertions(+), 1081 deletions(-) delete mode 100644 examples/speaker_tasks/diarization/asr_with_diarization.py create mode 100644 examples/speaker_tasks/diarization/conf/offline_diarization.yaml create mode 100644 examples/speaker_tasks/diarization/conf/offline_diarization_with_asr.yaml delete mode 100644 examples/speaker_tasks/diarization/conf/speaker_diarization.yaml rename examples/speaker_tasks/diarization/{speaker_diarize.py => offline_diarization.py} (69%) create mode 100644 examples/speaker_tasks/diarization/offline_diarization_with_asr.py create mode 100644 scripts/speaker_tasks/pathsfiles_to_manifest.py delete mode 100644 
scripts/speaker_tasks/rttm_to_manifest.py diff --git a/Jenkinsfile b/Jenkinsfile index 75181b8f6..097c5c374 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -315,10 +315,8 @@ pipeline { stage('Speaker Diarization Inference') { steps { - sh 'python examples/speaker_tasks/diarization/speaker_diarize.py \ - diarizer.oracle_num_speakers=2 \ - diarizer.paths2audio_files=/home/TestData/an4_diarizer/audio_files.scp \ - diarizer.path2groundtruth_rttm_files=/home/TestData/an4_diarizer/rttm_files.scp \ + sh 'python examples/speaker_tasks/diarization/offline_diarization.py \ + diarizer.manifest_filepath=/home/TestData/an4_diarizer/an4_manifest.json \ diarizer.speaker_embeddings.model_path=/home/TestData/an4_diarizer/spkr.nemo \ diarizer.vad.model_path=/home/TestData/an4_diarizer/MatchboxNet_VAD_3x2.nemo \ diarizer.out_dir=examples/speaker_tasks/diarization/speaker_diarization_results' @@ -326,6 +324,18 @@ pipeline { } } + stage('Speaker Diarization with ASR Inference') { + steps { + sh 'python examples/speaker_tasks/diarization/offline_diarization_with_asr.py \ + diarizer.manifest_filepath=/home/TestData/an4_diarizer/an4_manifest.json \ + diarizer.speaker_embeddings.model_path=/home/TestData/an4_diarizer/spkr.nemo \ + diarizer.asr.model_path=QuartzNet15x5Base-En \ + diarizer.asr.parameters.asr_based_vad=True \ + diarizer.out_dir=examples/speaker_tasks/diarization/speaker_diarization_asr_results' + sh 'rm -rf examples/speaker_tasks/diarization/speaker_diarization_asr_results' + } + } + stage('L2: Speech to Text WPE - CitriNet') { steps { sh 'python examples/asr/speech_to_text_bpe.py \ diff --git a/README.rst b/README.rst index 1c6ab05d3..f92c089f8 100644 --- a/README.rst +++ b/README.rst @@ -51,7 +51,7 @@ Key Features * `Speech Classification and Speech Command Recognition `_: MatchboxNet (Command Recognition) * `Voice activity Detection (VAD) `_: MarbleNet * `Speaker Recognition `_: SpeakerNet, ECAPA_TDNN - * `Speaker Diarization `_: SpeakerNet + * `Speaker Diarization `_: SpeakerNet, ECAPA_TDNN * `Pretrained models on different languages. `_: English, Spanish, German, Russian, Chinese, French, Italian, Polish, ... * `NGC collection of pre-trained speech processing models. `_ * Natural Language Processing diff --git a/examples/speaker_tasks/diarization/README.md b/examples/speaker_tasks/diarization/README.md index 6c49b83fb..eef3a46d8 100644 --- a/examples/speaker_tasks/diarization/README.md +++ b/examples/speaker_tasks/diarization/README.md @@ -30,55 +30,38 @@ Diarization Error Rate (DER) table of `ecapa_tdnn.nemo` model on well known eval #### Example script ```bash -python speaker_diarize.py \ - diarizer.paths2audio_files='my_wav.list' \ - diarizer.speaker_embeddings.model_path='ecapa_tdnn' \ - diarizer.oracle_num_speakers=null \ - diarizer.vad.model_path='vad_telephony_marblenet' + python offline_diarization.py \ + diarizer.manifest_filepath= \ + diarizer.out_dir='demo_output' \ + diarizer.speaker_embeddings.model_path= \ + diarizer.vad.model_path= ``` If you have oracle VAD files and groundtruth RTTM files for evaluation: +Provide rttm files in the input manifest file and enable oracle_vad as shown below. 
```bash -python speaker_diarize.py \ - diarizer.paths2audio_files='my_wav.list' \ - diarizer.speaker_embeddings.model_path='path/to/ecapa_tdnn.nemo' \ - diarizer.oracle_num_speakers= null \ - diarizer.speaker_embeddings.oracle_vad_manifest='oracle_vad.manifest' \ - diarizer.path2groundtruth_rttm_files='my_wav_rttm.list' +python offline_diarization.py \ + diarizer.manifest_filepath= \ + diarizer.out_dir='demo_output' \ + diarizer.speaker_embeddings.model_path= \ + diarizer.oracle_vad=True ``` #### Arguments -      To run speaker diarization on your audio recordings, you need to prepare the following files. +      To run speaker diarization on your audio recordings, you need to prepare the following file. -- **`diarizer.paths2audio_files`: audio file list** +- **`diarizer.manifest_filepath`**: path to the input manifest file -      Provide a list of full path names to the audio files you want to diarize. - -Example: `my_wav.list` +Example: `manifest.json` ```bash -/path/to/my_audio1.wav -/path/to/my_audio2.wav -/path/to/my_audio3.wav +{"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/filepath"} ``` +Mandatory fields are `audio_filepath`, `offset`, `duration`, `label` (set to `"infer"`) and `text` (set to `"-"`); the remaining keys are optional and can be provided depending on the type of evaluation. -- **`diarizer.oracle_num_speakers` : number of speakers in the audio recording** - -      If you know how many speakers are in the recordings, you can provide the number of speakers as follows. - -Example: `number_of_speakers.list` -``` -my_audio1 7 -my_audio2 6 -my_audio3 5 - -``` -           `diarizer.oracle_num_speakers='number_of_speakers.list'` - -      If all sessions have the same number of speakers, you can specify the option as: - -           `diarizer.oracle_num_speakers=5` +Some important options in the config file: - **`diarizer.speaker_embeddings.model_path`: speaker embedding model name** @@ -101,24 +84,11 @@ my_audio3 5            `diarizer.vad.model_path='path/to/vad_telephony_marblenet.nemo'` - -- **`diarizer.path2groundtruth_rttm_files`: reference RTTM files for evaluating the diarization output** - -      You can provide a list of rttm files for evaluating the diarization output. - -Example: `my_wav_rttm.list` -``` -/path/to/my_audio1.rttm -/path/to/my_audio2.rttm -/path/to/my_audio3.rttm - -``` -
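The PR also adds `scripts/speaker_tasks/pathsfiles_to_manifest.py` to help generate this manifest from lists of paths. As an editorial illustration only (not part of this patch), a minimal Python sketch that writes one manifest entry per audio file could look like the following; the file names are placeholders and the optional keys may be left as `None` when unavailable:

```python
import json

# Hypothetical inputs; replace with your own audio (and optional reference RTTM) paths.
audio_files = ["/path/to/my_audio1.wav", "/path/to/my_audio2.wav"]
rttm_files = ["/path/to/my_audio1.rttm", None]  # None when no reference RTTM exists

with open("manifest.json", "w") as f:
    for audio, rttm in zip(audio_files, rttm_files):
        entry = {
            "audio_filepath": audio,
            "offset": 0,
            "duration": None,       # null in JSON; use the full recording
            "label": "infer",
            "text": "-",
            "num_speakers": None,   # set an int if the oracle speaker count is known
            "rttm_filepath": rttm,  # optional, used for oracle VAD and scoring
            "uem_filepath": None,   # optional evaluation map
        }
        f.write(json.dumps(entry) + "\n")
```

Each line of the resulting file is one JSON object, which is the format `diarizer.manifest_filepath` expects.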
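The same pipeline can also be driven from Python rather than the CLI. The sketch below is an illustration based on the `ClusteringDiarizer` class touched by this PR and the example config it adds; the model names and paths are placeholders, not values prescribed by the patch:

```python
from omegaconf import OmegaConf
from nemo.collections.asr.models import ClusteringDiarizer

# Load the example config added by this PR and point it at our data (placeholder paths).
cfg = OmegaConf.load("examples/speaker_tasks/diarization/conf/offline_diarization.yaml")
cfg.diarizer.manifest_filepath = "manifest.json"
cfg.diarizer.out_dir = "demo_output"
cfg.diarizer.speaker_embeddings.model_path = "ecapa_tdnn"
cfg.diarizer.vad.model_path = "vad_marblenet"

diarizer = ClusteringDiarizer(cfg=cfg)
score = diarizer.diarize()  # predicted RTTMs are written to <out_dir>/pred_rttms
```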
## Run Speech Recognition with Speaker Diarization -Using the script `asr_with_diarization.py`, you can transcribe your audio recording with speaker labels as shown below: +Using the script `offline_diarization_with_asr.py`, you can transcribe your audio recording with speaker labels as shown below: ``` [00:03.34 - 00:04.46] speaker_0: back from the gym oh good how's it going @@ -126,51 +96,49 @@ Using the script `asr_with_diarization.py`, you can transcribe your audio record [00:12.10 - 00:13.97] speaker_0: well it's the middle of the week or whatever so ``` -Currently, asr_with_diarization supports QuartzNet English model and ConformerCTC model (`QuartzNet15x5Base`, `stt_en_conformer_ctc_large`). +Currently, offline_diarization_with_asr supports the QuartzNet English model and the Conformer-CTC model (`QuartzNet15x5Base-En`, `stt_en_conformer_ctc_large`). #### Example script ```bash -python asr_with_diarization.py \ - --asr_based_vad \ - --pretrained_speaker_model='ecapa_tdnn' \ - --audiofile_list_path='my_wav.list' \ +python offline_diarization_with_asr.py \ + diarizer.manifest_filepath= \ + diarizer.out_dir='demo_asr_output' \ + diarizer.speaker_embeddings.model_path= \ + diarizer.asr.model_path= \ + diarizer.asr.parameters.asr_based_vad=True ``` -If you have reference rttm files or oracle number of speaker information, you can provide those files as in the following example. +If you have reference RTTM files or know the oracle number of speakers, provide the RTTM file paths and speaker counts in the input manifest file and pass `diarizer.clustering.parameters.oracle_num_speakers=True` as shown in the following example. ```bash -python asr_with_diarization.py \ - --asr_based_vad \ - --pretrained_speaker_model='ecapa_tdnn' \ - --audiofile_list_path='my_wav.list' \ - --reference_rttmfile_list_path='my_wav_rttm.list'\ - --oracle_num_speakers=number_of_speakers.list +python offline_diarization_with_asr.py \ + diarizer.manifest_filepath= \ + diarizer.out_dir='demo_asr_output' \ + diarizer.speaker_embeddings.model_path= \ + diarizer.asr.model_path= \ + diarizer.asr.parameters.asr_based_vad=True \ + diarizer.clustering.parameters.oracle_num_speakers=True ``` #### Output folders -The above script will create a folder named `./asr_with_diar/`. -In `./asr_with_diar/`, you can check the results as below. +The above script will create a folder named `./demo_asr_output/`. +In `./demo_asr_output/`, you can check the results as below. ```bash -./asr_with_diar +./demo_asr_output -├── json_result -│ └── my_audio1.json -│ └── my_audio2.json -│ └── my_audio3.json -│ -└── transcript_with_speaker_labels +├── pred_rttms + └── my_audio1.json └── my_audio1.txt - └── my_audio2.txt - └── my_audio3.txt + └── my_audio1.rttm │ └── ... ``` -      `json_result` folder includes word-by-word json output with speaker label and time stamps. +      `*.json` files contain word-by-word json output with speaker labels and time stamps. -Example: `./asr_with_diar/json_result/my_audio1.json` +Example: `./demo_asr_output/pred_rttms/my_audio1.json` ```bash { "status": "Success", @@ -199,9 +167,9 @@ Example: `./asr_with_diar/json_result/my_audio1.json` }, ``` -      `transcript_with_speaker_labels` folder includes transcription with speaker labels and corresponding time. +      `*.txt` files contain the transcription with speaker labels and corresponding time stamps. 
-Example: `./asr_with_diar/transcript_with_speaker_labels/my_audio1.txt` +Example: `./demo_asr_output/pred_rttms/my_audio1.txt` ``` [00:03.34 - 00:04.46] speaker_0: back from the gym oh good how's it going [00:04.46 - 00:09.96] speaker_1: pretty well it was really crowded today yeah i kind of assumed everylonewould be at the shore uhhuh @@ -209,4 +177,3 @@ Example: `./asr_with_diar/transcript_with_speaker_labels/my_audio1.txt` [00:13.97 - 00:15.78] speaker_1: but it's the fourth of july mm [00:16.90 - 00:21.80] speaker_0: so yeahg people still work tomorrow do you have to work tomorrow did you drive off yesterday ``` - diff --git a/examples/speaker_tasks/diarization/asr_with_diarization.py b/examples/speaker_tasks/diarization/asr_with_diarization.py deleted file mode 100644 index 3248bd1d5..000000000 --- a/examples/speaker_tasks/diarization/asr_with_diarization.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -from nemo.collections.asr.parts.utils.diarization_utils import ASR_DIAR_OFFLINE, read_file_paths -from nemo.utils import logging - -""" -Currently Supported ASR models: - -QuartzNet15x5Base-En -stt_en_conformer_ctc_large -""" - -DEFAULT_URL = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/speaker_diarization.yaml" - -parser = argparse.ArgumentParser() -parser.add_argument( - "--audiofile_list_path", type=str, help="Fullpath of a file contains the list of audio files", required=True -) -parser.add_argument( - "--reference_rttmfile_list_path", default=None, type=str, help="Fullpath of a file contains the list of rttm files" -) -parser.add_argument( - "--output_path", default=None, type=str, help="Path to the folder where output files are generated" -) -parser.add_argument("--pretrained_vad_model", default=None, type=str, help="Fullpath of the VAD model (*.nemo).") -parser.add_argument( - "--external_vad_manifest", default=None, type=str, help="External VAD output manifest for diarization" -) -parser.add_argument("--asr_based_vad", default=False, action='store_true', help="Use ASR-based VAD") -parser.add_argument( - '--generate_oracle_manifest', default=False, action='store_true', help="Use RTTM ground truth as VAD input" -) -parser.add_argument( - "--pretrained_speaker_model", - type=str, - help="Fullpath of the Speaker embedding extractor model (*.nemo).", - required=True, -) -parser.add_argument("--oracle_num_speakers", help="Either int or text file that contains number of speakers") -parser.add_argument("--threshold", default=10, type=int, help="Threshold for ASR based VAD") -parser.add_argument( - "--diar_config_url", default=DEFAULT_URL, type=str, help="Config yaml file for running speaker diarization" -) -parser.add_argument("--csv", default='result.csv', type=str, help="") -args = parser.parse_args() - - -vad_choices = [ - args.asr_based_vad, - args.pretrained_vad_model, - args.external_vad_manifest, - 
args.generate_oracle_manifest, -] -if not sum(bool(c) for c in vad_choices) == 1: - raise ValueError( - "Please provide ONE and ONLY ONE method for VAD. \n \ - (1) External manifest external_vad_manifest \n \ - (2) Use a pretrained_vad_model \n \ - (3) Use ASR-based VAD or \n \ - (4) Provide reference_rttmfile_list_path so we can automatically generate oracle manifest for diarization" - ) - if args.generate_oracle_manifest and not args.reference_rttmfile_list_path: - raise ValueError( - "Please provide reference_rttmfile_list_path so we can generate oracle manifest automatically" - ) - -if args.asr_based_vad: - oracle_manifest = 'asr_based_vad' -elif args.pretrained_vad_model: - oracle_manifest = 'system_vad' -elif args.external_vad_manifest: - oracle_manifest = args.external_vad_manifest -elif args.reference_rttmfile_list_path: - logging.info("Use the oracle manifest automatically generated by rttm") - oracle_manifest = asr_diar_offline.write_VAD_rttm( - asr_diar_offline.oracle_vad_dir, audio_file_list, args.reference_rttmfile_list_path - ) - -params = { - "round_float": 2, - "window_length_in_sec": 1.5, - "shift_length_in_sec": 0.75, - "print_transcript": False, - "lenient_overlap_WDER": True, - "VAD_threshold_for_word_ts": 0.7, - "max_word_ts_length_in_sec": 0.6, - "word_gap_in_sec": 0.01, - "minimum": True, - "threshold": args.threshold, # minimun width to consider non-speech activity - "diar_config_url": args.diar_config_url, - "ASR_model_name": 'stt_en_conformer_ctc_large', -} - -asr_diar_offline = ASR_DIAR_OFFLINE(params) - -asr_model = asr_diar_offline.set_asr_model(params['ASR_model_name']) - -asr_diar_offline.create_directories(args.output_path) - -audio_file_list = read_file_paths(args.audiofile_list_path) - -word_list, word_ts_list = asr_diar_offline.run_ASR(audio_file_list, asr_model) - -diar_labels = asr_diar_offline.run_diarization( - audio_file_list, - word_ts_list, - oracle_manifest=oracle_manifest, - oracle_num_speakers=args.oracle_num_speakers, - pretrained_speaker_model=args.pretrained_speaker_model, - pretrained_vad_model=args.pretrained_vad_model, -) - -if args.reference_rttmfile_list_path: - - ref_rttm_file_list = read_file_paths(args.reference_rttmfile_list_path) - - ref_labels_list, DER_result_dict = asr_diar_offline.eval_diarization( - audio_file_list, diar_labels, ref_rttm_file_list - ) - - total_riva_dict = asr_diar_offline.write_json_and_transcript(audio_file_list, diar_labels, word_list, word_ts_list) - - WDER_dict = asr_diar_offline.get_WDER(audio_file_list, total_riva_dict, DER_result_dict, ref_labels_list) - - effective_wder = asr_diar_offline.get_effective_WDER(DER_result_dict, WDER_dict) - - logging.info( - f"\nDER : {DER_result_dict['total']['DER']:.4f} \ - \nFA : {DER_result_dict['total']['FA']:.4f} \ - \nMISS : {DER_result_dict['total']['MISS']:.4f} \ - \nCER : {DER_result_dict['total']['CER']:.4f} \ - \nWDER : {WDER_dict['total']:.4f} \ - \neffective WDER : {effective_wder:.4f} \ - \nspk_counting_acc : {DER_result_dict['total']['spk_counting_acc']:.4f}" - ) - -else: - total_riva_dict = asr_diar_offline.write_json_and_transcript(audio_file_list, diar_labels, word_list, word_ts_list) diff --git a/examples/speaker_tasks/diarization/conf/offline_diarization.yaml b/examples/speaker_tasks/diarization/conf/offline_diarization.yaml new file mode 100644 index 000000000..99f3bc169 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/offline_diarization.yaml @@ -0,0 +1,47 @@ +name: &name "ClusterDiarizer" + +num_workers: 4 +sample_rate: 16000 +batch_size: 
64 + +diarizer: + manifest_filepath: ??? + out_dir: ??? + oracle_vad: False # if True, uses RTTM files provided in manifest file to get speech activity (VAD) timestamps + collar: 0.25 # collar value for scoring + ignore_overlap: True # consider or ignore overlap segments while scoring + + vad: + model_path: null #.nemo local model path or pretrained model name or none + external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set + + parameters: # Tuned parameter for CH109! (with 11 moved multi-speech sessions as dev set) + window_length_in_sec: 0.15 # window length in sec for VAD context input + shift_length_in_sec: 0.01 # shift length in sec for generate frame level VAD prediction + smoothing: "median" # False or type of smoothing method (eg: median) + overlap: 0.875 # Overlap ratio for overlapped mean/median smoothing filter + onset: 0.4 # onset threshold for detecting the beginning and end of a speech + offset: 0.7 # offset threshold for detecting the end of a speech. + pad_onset: 0.05 # adding durations before each speech segment + pad_offset: -0.1 # adding durations after each speech segment + min_duration_on: 0.2 # threshold for small non_speech deletion + min_duration_off: 0.2 # threshold for short speech segment deletion + filter_speech_first: True + + speaker_embeddings: + model_path: ??? #.nemo local model path or pretrained model name (ecapa_tdnn or speakerverification_speakernet) + parameters: + window_length_in_sec: 1.5 # window length in sec for speaker embedding extraction + shift_length_in_sec: 0.75 # shift length in sec for speaker embedding extraction + save_embeddings: False # save embeddings as pickle file for each audio input + + clustering: + parameters: + oracle_num_speakers: False # if True, uses num of speakers value provided in manifest file + max_num_speakers: 20 # max number of speakers for each recording. If oracle num speakers is passed this value is ignored. + enhanced_count_thres: 80 + max_rp_threshold: 0.25 + sparse_search_volume: 30 + +# json manifest line example +# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, label: "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath"="/path/to/uem/filepath"} diff --git a/examples/speaker_tasks/diarization/conf/offline_diarization_with_asr.yaml b/examples/speaker_tasks/diarization/conf/offline_diarization_with_asr.yaml new file mode 100644 index 000000000..b23dfd442 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/offline_diarization_with_asr.yaml @@ -0,0 +1,58 @@ +name: &name "ClusterDiarizer" + +num_workers: 4 +sample_rate: 16000 +batch_size: 64 + +diarizer: + manifest_filepath: ??? + out_dir: ??? + oracle_vad: False # if True, uses RTTM files provided in manifest file to get speech activity (VAD) timestamps + collar: 0.25 # collar value for scoring + ignore_overlap: True # consider or ignore overlap segments while scoring + + vad: + model_path: null #.nemo local model path or pretrained model name or none + external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set + + parameters: # Tuned parameter for CH109! 
(with 11 moved multi-speech sessions as dev set) + window_length_in_sec: 0.15 # window length in sec for VAD context input + shift_length_in_sec: 0.01 # shift length in sec for generate frame level VAD prediction + smoothing: "median" # False or type of smoothing method (eg: median) + overlap: 0.875 # Overlap ratio for overlapped mean/median smoothing filter + onset: 0.4 # onset threshold for detecting the beginning and end of a speech + offset: 0.7 # offset threshold for detecting the end of a speech. + pad_onset: 0.05 # adding durations before each speech segment + pad_offset: -0.1 # adding durations after each speech segment + min_duration_on: 0.2 # threshold for small non_speech deletion + min_duration_off: 0.2 # threshold for short speech segment deletion + filter_speech_first: True + + speaker_embeddings: + model_path: ??? #.nemo local model path or pretrained model name (ecapa_tdnn or speakerverification_speakernet) + parameters: + window_length_in_sec: 1.5 # window length in sec for speaker embedding extraction + shift_length_in_sec: 0.75 # shift length in sec for speaker embedding extraction + save_embeddings: False # save embeddings as pickle file for each audio input + + clustering: + parameters: + oracle_num_speakers: False # if True, uses num of speakers value provided in manifest file + max_num_speakers: 20 # max number of speakers for each recording. If oracle num speakers is passed this value is ignored. + enhanced_count_thres: 80 + max_rp_threshold: 0.25 + sparse_search_volume: 30 + + asr: + model_path: ??? # Provide NGC cloud ASR model name + parameters: + asr_based_vad: False # if true, speech segmentation for diarization is based on word-timestamps from ASR inference. + asr_based_vad_threshold: 50 # threshold (multiple of 10ms) for ignoring the gap between two words when generating VAD timestamps using ASR based VAD. + print_transcript: False # if true, the output transcript is displayed after the transcript is generated. + lenient_overlap_WDER: True # if true, when a word falls into speaker-overlapped regions, consider the word as a correctly diarized word. + vad_threshold_for_word_ts: 0.7 # threshold for getting VAD stamps that compensate the word timestamps generated from ASR decoder. + max_word_ts_length_in_sec: 0.6 # Maximum length of each word timestamp + word_gap_in_sec: 0.01 # The gap between the end of a word and the start of the following word. + +# json manifest line example +# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, label: "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath"="/path/to/uem/filepath"} diff --git a/examples/speaker_tasks/diarization/conf/speaker_diarization.yaml b/examples/speaker_tasks/diarization/conf/speaker_diarization.yaml deleted file mode 100644 index f9c3bfcc4..000000000 --- a/examples/speaker_tasks/diarization/conf/speaker_diarization.yaml +++ /dev/null @@ -1,36 +0,0 @@ -name: &name "ClusterDiarizer" - -num_workers: 4 -sample_rate: 16000 -batch_size: 64 - -diarizer: - oracle_num_speakers: null # oracle number of speakers or pass null if not known. [int or path to file containing uniq-id with num of speakers of that session] - max_num_speakers: 20 # max number of speakers for each recording. If oracle num speakers is passed this value is ignored. - out_dir: ??? - paths2audio_files: ???
#file containing paths to audio files for inference - path2groundtruth_rttm_files: null #file containing paths to ground truth rttm files for score computation or null if no reference rttms are present - - vad: - model_path: null #.nemo local model path or pretrained model name or none - window_length_in_sec: 0.15 - shift_length_in_sec: 0.01 - vad_decision_smoothing: True - smoothing_params: - method: "median" - overlap: 0.875 - threshold: null # if not null, use this value and ignore the following onset and offset. Warning! it will be removed in release 1.5! - postprocessing_params: # Tuned parameter for CH109! (with 11 moved multi-speech sessions as dev set) You might need to choose resonable threshold or tune the threshold on your dev set. Check /scripts/voice_activity_detection/vad_tune_threshold.py - onset: 0.4 - offset: 0.7 - pad_onset: 0.05 - pad_offset: -0.1 - min_duration_on: 0.2 - min_duration_off: 0.2 - filter_speech_first: True - - speaker_embeddings: - oracle_vad_manifest: null - model_path: ??? #.nemo local model path or pretrained model name (ecapa_tdnn or speakerverification_speakernet) - window_length_in_sec: 1.5 # window length in sec for speaker embedding extraction - shift_length_in_sec: 0.75 # shift length in sec for speaker embedding extraction diff --git a/examples/speaker_tasks/diarization/speaker_diarize.py b/examples/speaker_tasks/diarization/offline_diarization.py similarity index 69% rename from examples/speaker_tasks/diarization/speaker_diarize.py rename to examples/speaker_tasks/diarization/offline_diarization.py index 9a62245c6..316a4fd56 100644 --- a/examples/speaker_tasks/diarization/speaker_diarize.py +++ b/examples/speaker_tasks/diarization/offline_diarization.py @@ -23,21 +23,19 @@ from nemo.utils import logging This script demonstrates how to use run speaker diarization. Usage: python speaker_diarize.py \ - diarizer.paths2audio_files= \ - diarizer.path2groundtruth_rttm_files=<(Optional) either list of rttm file paths or file containing paths to rttm files> \ - diarizer.vad.model_path="vad_telephony_marblenet" \ - diarizer.vad.threshold=0.7 \ - diarizer.speaker_embeddings.model_path="speakerverification_speakernet" \ - diarizer.out_dir="demo" + diarizer.manifest_filepath= \ + diarizer.out_dir='demo_output' \ + diarizer.speaker_embeddings.model_path= \ + diarizer.vad.model_path='vad_marblenet' -Check out whole parameters in ./conf/speaker_diarization.yaml and their meanings. +Check out whole parameters in ./conf/offline_diarization.yaml and their meanings. For details, have a look at /tutorials/speaker_tasks/Speaker_Diarization_Inference.ipynb """ seed_everything(42) -@hydra_runner(config_path="conf", config_name="speaker_diarization.yaml") +@hydra_runner(config_path="conf", config_name="offline_diarization.yaml") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') diff --git a/examples/speaker_tasks/diarization/offline_diarization_with_asr.py b/examples/speaker_tasks/diarization/offline_diarization_with_asr.py new file mode 100644 index 000000000..5405c461b --- /dev/null +++ b/examples/speaker_tasks/diarization/offline_diarization_with_asr.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil + +from omegaconf import OmegaConf + +from nemo.collections.asr.parts.utils.diarization_utils import ASR_DIAR_OFFLINE +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +""" +This script demonstrates how to run offline speaker diarization with asr. +Usage: +python offline_diarization_with_asr.py \ + diarizer.manifest_filepath= \ + diarizer.out_dir='demo_asr_output' \ + diarizer.speaker_embeddings.model_path= \ + diarizer.asr.model_path= \ + diarizer.asr.parameters.asr_based_vad=True + +Check out whole parameters in ./conf/offline_diarization_with_asr.yaml and their meanings. +For details, have a look at /tutorials/speaker_tasks/Speaker_Diarization_Inference.ipynb + +Currently Supported ASR models: + +QuartzNet15x5Base-En +stt_en_conformer_ctc_large +""" + + +@hydra_runner(config_path="conf", config_name="offline_diarization_with_asr.yaml") +def main(cfg): + + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + asr_diar_offline = ASR_DIAR_OFFLINE(**cfg.diarizer.asr.parameters) + asr_diar_offline.root_path = cfg.diarizer.out_dir + shutil.rmtree(asr_diar_offline.root_path, ignore_errors=True) + + AUDIO_RTTM_MAP = audio_rttm_map(cfg.diarizer.manifest_filepath) + asr_diar_offline.AUDIO_RTTM_MAP = AUDIO_RTTM_MAP + asr_model = asr_diar_offline.set_asr_model(cfg.diarizer.asr.model_path) + + word_list, word_ts_list = asr_diar_offline.run_ASR(asr_model) + + score = asr_diar_offline.run_diarization(cfg, word_ts_list,) + total_riva_dict = asr_diar_offline.write_json_and_transcript(word_list, word_ts_list) + + if score is not None: + metric, mapping_dict = score + DER_result_dict = asr_diar_offline.gather_eval_results(metric, mapping_dict) + + WDER_dict = asr_diar_offline.get_WDER(total_riva_dict, DER_result_dict) + effective_wder = asr_diar_offline.get_effective_WDER(DER_result_dict, WDER_dict) + + logging.info( + f"\nDER : {DER_result_dict['total']['DER']:.4f} \ + \nFA : {DER_result_dict['total']['FA']:.4f} \ + \nMISS : {DER_result_dict['total']['MISS']:.4f} \ + \nCER : {DER_result_dict['total']['CER']:.4f} \ + \nWDER : {WDER_dict['total']:.4f} \ + \neffective WDER : {effective_wder:.4f} \ + \nspk_counting_acc : {DER_result_dict['total']['spk_counting_acc']:.4f}" + ) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/asr/data/audio_to_label.py b/nemo/collections/asr/data/audio_to_label.py index 00a297357..00b9cf06f 100644 --- a/nemo/collections/asr/data/audio_to_label.py +++ b/nemo/collections/asr/data/audio_to_label.py @@ -222,7 +222,7 @@ def _vad_frame_seq_collate_fn(self, batch): sig_len += slice_length if has_audio: - slices = (sig_len - slice_length) // shift + slices = torch.div(sig_len - slice_length, shift, rounding_mode='trunc') for slice_id in range(slices): start_idx = slice_id * shift end_idx = start_idx + slice_length diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py index 056753d3f..f78cec92b 100644 --- a/nemo/collections/asr/models/clustering_diarizer.py +++ 
b/nemo/collections/asr/models/clustering_diarizer.py @@ -19,20 +19,22 @@ import shutil import tarfile import tempfile from collections import defaultdict +from copy import deepcopy from typing import List, Optional import torch -from omegaconf import DictConfig, OmegaConf, open_dict -from omegaconf.listconfig import ListConfig +from omegaconf import DictConfig, OmegaConf from pytorch_lightning.utilities import rank_zero_only from tqdm import tqdm from nemo.collections.asr.models.classification_models import EncDecClassificationModel -from nemo.collections.asr.models.label_models import ExtractSpeakerEmbeddingsModel +from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel from nemo.collections.asr.parts.mixins.mixins import DiarizationMixin from nemo.collections.asr.parts.utils.speaker_utils import ( audio_rttm_map, - perform_diarization, + get_uniqname_from_filepath, + perform_clustering, + score_labels, segments_manifest_to_subsegments_manifest, write_rttm2manifest, ) @@ -44,7 +46,6 @@ from nemo.collections.asr.parts.utils.vad_utils import ( ) from nemo.core.classes import Model from nemo.utils import logging, model_utils -from nemo.utils.exp_manager import NotFoundError try: from torch.cuda.amp import autocast @@ -64,48 +65,44 @@ _SPEAKER_MODEL = "speaker_model.nemo" def get_available_model_names(class_name): + "lists available pretrained model names from NGC" available_models = class_name.list_available_models() return list(map(lambda x: x.pretrained_model_name, available_models)) class ClusteringDiarizer(Model, DiarizationMixin): + """ + Inference model Class for offline speaker diarization. + This class handles required functionality for diarization : Speech Activity Detection, Segmentation, + Extract Embeddings, Clustering, Resegmentation and Scoring. 
+ All the parameters are passed through config file + """ + def __init__(self, cfg: DictConfig): cfg = model_utils.convert_model_config_to_dict_config(cfg) # Convert config to support Hydra 1.0+ instantiation cfg = model_utils.maybe_update_config_version(cfg) self._cfg = cfg - self._out_dir = self._cfg.diarizer.out_dir - if not os.path.exists(self._out_dir): - os.mkdir(self._out_dir) + + # Diarizer set up + self._diarizer_params = self._cfg.diarizer # init vad model self.has_vad_model = False - self.has_vad_model_to_save = False - - self._speaker_manifest_path = self._cfg.diarizer.speaker_embeddings.oracle_vad_manifest - self.AUDIO_RTTM_MAP = None - self.paths2audio_files = self._cfg.diarizer.paths2audio_files - - if self._cfg.diarizer.vad.model_path is not None: - self._init_vad_model() - self._vad_dir = os.path.join(self._out_dir, 'vad_outputs') - self._vad_out_file = os.path.join(self._vad_dir, "vad_out.json") - shutil.rmtree(self._vad_dir, ignore_errors=True) - os.makedirs(self._vad_dir) + if not self._diarizer_params.oracle_vad: + if self._cfg.diarizer.vad.model_path is not None: + self._vad_params = self._cfg.diarizer.vad.parameters + self._init_vad_model() # init speaker model self._init_speaker_model() + self._speaker_params = self._cfg.diarizer.speaker_embeddings.parameters + self._speaker_dir = os.path.join(self._diarizer_params.out_dir, 'speaker_outputs') + shutil.rmtree(self._speaker_dir, ignore_errors=True) + os.makedirs(self._speaker_dir) - if self._cfg.diarizer.get('num_speakers', None): - self._num_speakers = self._cfg.diarizer.num_speakers - logging.warning("in next release num_speakers will be changed to oracle_num_speakers") - else: - self._num_speakers = self._cfg.diarizer.oracle_num_speakers - - if self._cfg.diarizer.get('max_num_speakers', None): - self.max_num_speakers = self._cfg.diarizer.max_num_speakers - else: - self.max_num_speakers = 8 + # Clustering params + self._cluster_params = self._diarizer_params.clustering.parameters self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -113,28 +110,10 @@ class ClusteringDiarizer(Model, DiarizationMixin): def list_available_models(cls): pass - def _init_speaker_model(self): - model_path = self._cfg.diarizer.speaker_embeddings.model_path - if model_path is not None and model_path.endswith('.nemo'): - self._speaker_model = ExtractSpeakerEmbeddingsModel.restore_from(model_path) - logging.info("Speaker Model restored locally from {}".format(model_path)) - else: - if model_path not in get_available_model_names(ExtractSpeakerEmbeddingsModel): - logging.warning( - "requested {} model name not available in pretrained models, instead".format(model_path) - ) - model_path = "speakerdiarization_speakernet" - logging.info("Loading pretrained {} model from NGC".format(model_path)) - self._speaker_model = ExtractSpeakerEmbeddingsModel.from_pretrained(model_name=model_path) - - self._speaker_dir = os.path.join(self._out_dir, 'speaker_outputs') - - def set_vad_model(self, vad_config): - with open_dict(self._cfg): - self._cfg.diarizer.vad = vad_config - self._init_vad_model() - def _init_vad_model(self): + """ + Initialize vad model with model name or path passed through config + """ model_path = self._cfg.diarizer.vad.model_path if model_path.endswith('.nemo'): self._vad_model = EncDecClassificationModel.restore_from(model_path) @@ -148,15 +127,35 @@ class ClusteringDiarizer(Model, DiarizationMixin): logging.info("Loading pretrained {} model from NGC".format(model_path)) self._vad_model = 
EncDecClassificationModel.from_pretrained(model_name=model_path) - self._vad_window_length_in_sec = self._cfg.diarizer.vad.window_length_in_sec - self._vad_shift_length_in_sec = self._cfg.diarizer.vad.shift_length_in_sec - self.has_vad_model_to_save = True + self._vad_window_length_in_sec = self._vad_params.window_length_in_sec + self._vad_shift_length_in_sec = self._vad_params.shift_length_in_sec self.has_vad_model = True + def _init_speaker_model(self): + """ + Initialize speaker embedding model with model name or path passed through config + """ + model_path = self._cfg.diarizer.speaker_embeddings.model_path + if model_path is not None and model_path.endswith('.nemo'): + self._speaker_model = EncDecSpeakerLabelModel.restore_from(model_path) + logging.info("Speaker Model restored locally from {}".format(model_path)) + elif model_path.endswith('.ckpt'): + self._speaker_model = EncDecSpeakerLabelModel.load_from_checkpoint(model_path) + logging.info("Speaker Model restored locally from {}".format(model_path)) + else: + if model_path not in get_available_model_names(EncDecSpeakerLabelModel): + logging.warning( + "requested {} model name not available in pretrained models, instead".format(model_path) + ) + model_path = "ecapa_tdnn" + logging.info("Loading pretrained {} model from NGC".format(model_path)) + self._speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_name=model_path) + def _setup_vad_test_data(self, manifest_vad_input): vad_dl_config = { 'manifest_filepath': manifest_vad_input, 'sample_rate': self._cfg.sample_rate, + 'batch_size': self._cfg.get('batch_size'), 'vad_stream': True, 'labels': ['infer',], 'time_length': self._vad_window_length_in_sec, @@ -170,11 +169,10 @@ class ClusteringDiarizer(Model, DiarizationMixin): spk_dl_config = { 'manifest_filepath': manifest_file, 'sample_rate': self._cfg.sample_rate, - 'batch_size': self._cfg.get('batch_size', 32), - 'time_length': self._cfg.diarizer.speaker_embeddings.window_length_in_sec, - 'shift_length': self._cfg.diarizer.speaker_embeddings.shift_length_in_sec, + 'batch_size': self._cfg.get('batch_size'), + 'time_length': self._speaker_params.window_length_in_sec, + 'shift_length': self._speaker_params.shift_length_in_sec, 'trim_silence': False, - 'embedding_dir': self._speaker_dir, 'labels': None, 'task': "diarization", 'num_workers': self._cfg.num_workers, @@ -182,6 +180,18 @@ class ClusteringDiarizer(Model, DiarizationMixin): self._speaker_model.setup_test_data(spk_dl_config) def _run_vad(self, manifest_file): + """ + Run voice activity detection. + Get log probability of voice activity detection and smoothes using the post processing parameters. + Using generated frame level predictions generated manifest file for later speaker embedding extraction. 
+ input: + manifest_file (str) : Manifest file containing path to audio file and label as infer + + """ + + shutil.rmtree(self._vad_dir, ignore_errors=True) + os.makedirs(self._vad_dir) + self._vad_model = self._vad_model.to(self._device) self._vad_model.eval() @@ -191,8 +201,8 @@ class ClusteringDiarizer(Model, DiarizationMixin): all_len = 0 data = [] for line in open(manifest_file, 'r'): - file = os.path.basename(json.loads(line)['audio_filepath']) - data.append(os.path.splitext(file)[0]) + file = json.loads(line)['audio_filepath'] + data.append(get_uniqname_from_filepath(file)) status = get_vad_stream_status(data) for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader())): @@ -218,18 +228,17 @@ class ClusteringDiarizer(Model, DiarizationMixin): if status[i] == 'end' or status[i] == 'single': all_len = 0 - if not self._cfg.diarizer.vad.vad_decision_smoothing: + if not self._vad_params.smoothing: # Shift the window by 10ms to generate the frame and use the prediction of the window to represent the label for the frame; self.vad_pred_dir = self._vad_dir - else: # Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments. # smoothing_method would be either in majority vote (median) or average (mean) logging.info("Generating predictions with overlapping input segments") smoothing_pred_dir = generate_overlap_vad_seq( frame_pred_dir=self._vad_dir, - smoothing_method=self._cfg.diarizer.vad.smoothing_params.method, - overlap=self._cfg.diarizer.vad.smoothing_params.overlap, + smoothing_method=self._vad_params.smoothing, + overlap=self._vad_params.overlap, seg_len=self._vad_window_length_in_sec, shift_len=self._vad_shift_length_in_sec, num_workers=self._cfg.num_workers, @@ -237,30 +246,81 @@ class ClusteringDiarizer(Model, DiarizationMixin): self.vad_pred_dir = smoothing_pred_dir logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.") - postprocessing_params = self._cfg.diarizer.vad.postprocessing_params - if self._cfg.diarizer.vad.get("threshold", None): - logging.info("threshold is not None. 
Use threshold and update onset=offset=threshold") - logging.warning("Support for threshold will be deprecated in release 1.5") - postprocessing_params['onset'] = self._cfg.diarizer.vad.threshold - postprocessing_params['offset'] = self._cfg.diarizer.vad.threshold table_out_dir = generate_vad_segment_table( vad_pred_dir=self.vad_pred_dir, - postprocessing_params=postprocessing_params, + postprocessing_params=self._vad_params, shift_len=self._vad_shift_length_in_sec, num_workers=self._cfg.num_workers, ) + AUDIO_VAD_RTTM_MAP = deepcopy(self.AUDIO_RTTM_MAP.copy()) + for key in AUDIO_VAD_RTTM_MAP: + AUDIO_VAD_RTTM_MAP[key]['rttm_filepath'] = os.path.join(table_out_dir, key + ".txt") - vad_table_list = [os.path.join(table_out_dir, key + ".txt") for key in self.AUDIO_RTTM_MAP] - write_rttm2manifest(self._cfg.diarizer.paths2audio_files, vad_table_list, self._vad_out_file) + write_rttm2manifest(AUDIO_VAD_RTTM_MAP, self._vad_out_file) self._speaker_manifest_path = self._vad_out_file + def _run_segmentation(self): + + self.subsegments_manifest_path = os.path.join(self._speaker_dir, 'subsegments.json') + self.subsegments_manifest_path = segments_manifest_to_subsegments_manifest( + segments_manifest_file=self._speaker_manifest_path, + subsegments_manifest_file=self.subsegments_manifest_path, + window=self._speaker_params.window_length_in_sec, + shift=self._speaker_params.shift_length_in_sec, + ) + + return None + + def _perform_speech_activity_detection(self): + """ + Checks for type of speech activity detection from config. Choices are NeMo VAD, + external vad manifest and oracle VAD (generates speech activity labels from provided RTTM files) + """ + if self.has_vad_model: + self._dont_auto_split = False + self._split_duration = 50 + manifest_vad_input = self._diarizer_params.manifest_filepath + + if not self._dont_auto_split: + logging.info("Split long audio file to avoid CUDA memory issue") + logging.debug("Try smaller split_duration if you still have CUDA memory issue") + config = { + 'manifest_filepath': manifest_vad_input, + 'time_length': self._vad_window_length_in_sec, + 'split_duration': self._split_duration, + 'num_workers': self._cfg.num_workers, + } + manifest_vad_input = prepare_manifest(config) + else: + logging.warning( + "If you encounter CUDA memory issue, try splitting manifest entry by split_duration to avoid it." + ) + + self._setup_vad_test_data(manifest_vad_input) + self._run_vad(manifest_vad_input) + + elif self._diarizer_params.vad.external_vad_manifest is not None: + self._speaker_manifest_path = self._diarizer_params.vad.external_vad_manifest + elif self._diarizer_params.oracle_vad: + self._speaker_manifest_path = os.path.join(self._speaker_dir, 'oracle_vad_manifest.json') + self._speaker_manifest_path = write_rttm2manifest(self.AUDIO_RTTM_MAP, self._speaker_manifest_path) + else: + raise ValueError( + "Only one of diarizer.oracle_vad, vad.model_path or vad.external_vad_manifest must be passed" + ) + def _extract_embeddings(self, manifest_file): + """ + This method extracts speaker embeddings from segments passed through manifest_file + Optionally you may save the intermediate speaker embeddings for debugging or any use. 
+ """ logging.info("Extracting embeddings for Diarization") self._setup_spkr_test_data(manifest_file) - out_embeddings = defaultdict(list) + self.embeddings = defaultdict(list) self._speaker_model = self._speaker_model.to(self._device) self._speaker_model.eval() + self.time_stamps = {} all_embs = [] for test_batch in tqdm(self._speaker_model.test_dataloader()): @@ -277,105 +337,96 @@ class ClusteringDiarizer(Model, DiarizationMixin): for i, line in enumerate(manifest.readlines()): line = line.strip() dic = json.loads(line) - uniq_name = os.path.basename(dic['audio_filepath']).rsplit('.', 1)[0] - out_embeddings[uniq_name].extend([all_embs[i]]) + uniq_name = get_uniqname_from_filepath(dic['audio_filepath']) + self.embeddings[uniq_name].extend([all_embs[i]]) + if uniq_name not in self.time_stamps: + self.time_stamps[uniq_name] = [] + start = dic['offset'] + end = start + dic['duration'] + stamp = '{:.3f} {:.3f} '.format(start, end) + self.time_stamps[uniq_name].append(stamp) - embedding_dir = os.path.join(self._speaker_dir, 'embeddings') - if not os.path.exists(embedding_dir): - os.makedirs(embedding_dir, exist_ok=True) + if self._speaker_params.save_embeddings: + embedding_dir = os.path.join(self._speaker_dir, 'embeddings') + if not os.path.exists(embedding_dir): + os.makedirs(embedding_dir, exist_ok=True) - prefix = manifest_file.split('/')[-1].rsplit('.', 1)[-2] + prefix = get_uniqname_from_filepath(manifest_file) - name = os.path.join(embedding_dir, prefix) - self._embeddings_file = name + '_embeddings.pkl' - pkl.dump(out_embeddings, open(self._embeddings_file, 'wb')) - logging.info("Saved embedding files to {}".format(embedding_dir)) + name = os.path.join(embedding_dir, prefix) + self._embeddings_file = name + '_embeddings.pkl' + pkl.dump(self.embeddings, open(self._embeddings_file, 'wb')) + logging.info("Saved embedding files to {}".format(embedding_dir)) - def path2audio_files_to_manifest(self, paths2audio_files): - mfst_file = os.path.join(self._out_dir, 'manifest.json') - with open(mfst_file, 'w') as fp: + def path2audio_files_to_manifest(self, paths2audio_files, manifest_filepath): + with open(manifest_filepath, 'w') as fp: for audio_file in paths2audio_files: audio_file = audio_file.strip() entry = {'audio_filepath': audio_file, 'offset': 0.0, 'duration': None, 'text': '-', 'label': 'infer'} fp.write(json.dumps(entry) + '\n') - return mfst_file + def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 0): + """ + Diarize files provided thorugh paths2audio_files or manifest file + input: + paths2audio_files (List[str]): list of paths to file containing audio file + batch_size (int): batch_size considered for extraction of speaker embeddings and VAD computation + """ - def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 1): - """ - """ + self._out_dir = self._diarizer_params.out_dir + if not os.path.exists(self._out_dir): + os.mkdir(self._out_dir) + + self._vad_dir = os.path.join(self._out_dir, 'vad_outputs') + self._vad_out_file = os.path.join(self._vad_dir, "vad_out.json") + + if batch_size: + self._cfg.batch_size = batch_size if paths2audio_files: - self.paths2audio_files = paths2audio_files - else: - if self._cfg.diarizer.paths2audio_files is None: - raise ValueError("Pass path2audio files either through config or to diarize method") + if type(paths2audio_files) is list: + self._diarizer_params.manifest_filepath = os.path.json(self._out_dir, 'paths2audio_filepath.json') + self.path2audio_files_to_manifest(paths2audio_files, 
self._diarizer_params.manifest_filepath) else: - self.paths2audio_files = self._cfg.diarizer.paths2audio_files + raise ValueError("paths2audio_files must be of type list of paths to file containing audio file") - if type(self.paths2audio_files) is str and os.path.isfile(self.paths2audio_files): - paths2audio_files = [] - with open(self.paths2audio_files, 'r') as path2file: - for audiofile in path2file.readlines(): - audiofile = audiofile.strip() - paths2audio_files.append(audiofile) + self.AUDIO_RTTM_MAP = audio_rttm_map(self._diarizer_params.manifest_filepath) - elif type(self.paths2audio_files) in [list, ListConfig]: - paths2audio_files = list(self.paths2audio_files) + # Speech Activity Detection + self._perform_speech_activity_detection() - else: - raise ValueError("paths2audio_files must be of type list or path to file containing audio files") + # Segmentation + self._run_segmentation() - self.AUDIO_RTTM_MAP = audio_rttm_map(paths2audio_files, self._cfg.diarizer.path2groundtruth_rttm_files) - - if self.has_vad_model: - logging.info("Performing VAD") - mfst_file = self.path2audio_files_to_manifest(paths2audio_files) - self._dont_auto_split = False - self._split_duration = 50 - manifest_vad_input = mfst_file - - if not self._dont_auto_split: - logging.info("Split long audio file to avoid CUDA memory issue") - logging.debug("Try smaller split_duration if you still have CUDA memory issue") - config = { - 'manifest_filepath': mfst_file, - 'time_length': self._vad_window_length_in_sec, - 'split_duration': self._split_duration, - 'num_workers': self._cfg.num_workers, - } - manifest_vad_input = prepare_manifest(config) - else: - logging.warning( - "If you encounter CUDA memory issue, try splitting manifest entry by split_duration to avoid it." - ) - - self._setup_vad_test_data(manifest_vad_input) - self._run_vad(manifest_vad_input) - else: - if not os.path.exists(self._speaker_manifest_path): - raise NotFoundError("Oracle VAD based manifest file not found") - - self.subsegments_manifest_path = os.path.join(self._out_dir, 'subsegments.json') - self.subsegments_manifest_path = segments_manifest_to_subsegments_manifest( - segments_manifest_file=self._speaker_manifest_path, - subsegments_manifest_file=self.subsegments_manifest_path, - window=self._cfg.diarizer.speaker_embeddings.window_length_in_sec, - shift=self._cfg.diarizer.speaker_embeddings.shift_length_in_sec, - ) + # Embedding Extraction self._extract_embeddings(self.subsegments_manifest_path) + out_rttm_dir = os.path.join(self._out_dir, 'pred_rttms') os.makedirs(out_rttm_dir, exist_ok=True) - perform_diarization( - embeddings_file=self._embeddings_file, - reco2num=self._num_speakers, - manifest_path=self.subsegments_manifest_path, - audio_rttm_map=self.AUDIO_RTTM_MAP, + # Clustering + all_reference, all_hypothesis = perform_clustering( + embeddings=self.embeddings, + time_stamps=self.time_stamps, + AUDIO_RTTM_MAP=self.AUDIO_RTTM_MAP, out_rttm_dir=out_rttm_dir, - max_num_speakers=self.max_num_speakers, + clustering_params=self._cluster_params, ) + # TODO Resegmentation -> Coming Soon + + # Scoring + score = score_labels( + self.AUDIO_RTTM_MAP, + all_reference, + all_hypothesis, + collar=self._diarizer_params.collar, + ignore_overlap=self._diarizer_params.ignore_overlap, + ) + + logging.info("Outputs are saved in {} directory".format(os.path.abspath(self._diarizer_params.out_dir))) + return score + @staticmethod def __make_nemo_file_from_folder(filename, source_dir): with tarfile.open(filename, "w:gz") as tar: diff --git 
a/nemo/collections/asr/parts/utils/diarization_utils.py b/nemo/collections/asr/parts/utils/diarization_utils.py index 0e296fa29..1b0e25f33 100644 --- a/nemo/collections/asr/parts/utils/diarization_utils.py +++ b/nemo/collections/asr/parts/utils/diarization_utils.py @@ -21,19 +21,17 @@ from collections import OrderedDict as od from datetime import datetime from typing import List +import librosa import numpy as np import soundfile as sf import torch -import wget -from omegaconf import OmegaConf from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.metrics.wer_bpe import WERBPE from nemo.collections.asr.models import ClusteringDiarizer, EncDecCTCModel, EncDecCTCModelBPE, EncDecRNNTBPEModel -from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map as get_audio_rttm_map from nemo.collections.asr.parts.utils.speaker_utils import ( - get_DER, - labels_to_pyannote_object, + get_uniqname_from_filepath, + labels_to_rttmfile, rttm_to_labels, write_rttm2manifest, ) @@ -61,26 +59,6 @@ def write_txt(w_path, val): return None -def get_uniq_id_from_audio_path(audio_file_path): - """Get the unique ID from the audio file path - """ - return '.'.join(os.path.basename(audio_file_path).split('.')[:-1]) - - -def read_file_paths(file_list_path): - """Read file paths from the given list - """ - out_path_list = [] - if not file_list_path or (file_list_path in NONE_LIST): - raise ValueError("file_list_path is not provided.") - else: - with open(file_list_path, 'r') as path2file: - for _file in path2file.readlines(): - out_path_list.append(_file.strip()) - - return out_path_list - - class WERBPE_TS(WERBPE): """ This is WERBPE_TS class that is modified for generating word_timestamps with logits. @@ -306,15 +284,16 @@ class ASR_DIAR_OFFLINE(object): A Class designed for performing ASR and diarization together. 
""" - def __init__(self, params): + def __init__(self, **params): self.params = params - self.nonspeech_threshold = self.params['threshold'] + self.nonspeech_threshold = self.params['asr_based_vad_threshold'] self.root_path = None self.fix_word_ts_with_VAD = True self.run_ASR = None self.frame_VAD = {} + self.AUDIO_RTTM_MAP = {} - def set_asr_model(self, ASR_model_name): + def set_asr_model(self, asr_model): """ Setup the parameters for the given ASR model Currently, the following models are supported: @@ -324,79 +303,53 @@ class ASR_DIAR_OFFLINE(object): QuartzNet15x5Base-En """ - if 'QuartzNet' in ASR_model_name: + if 'QuartzNet' in asr_model: self.run_ASR = self.run_ASR_QuartzNet_CTC - asr_model = EncDecCTCModel.from_pretrained(model_name=ASR_model_name, strict=False) + asr_model = EncDecCTCModel.from_pretrained(model_name=asr_model, strict=False) self.params['offset'] = -0.18 self.model_stride_in_secs = 0.02 self.asr_delay_sec = -1 * self.params['offset'] - elif 'conformer_ctc' in ASR_model_name: + elif 'conformer_ctc' in asr_model: self.run_ASR = self.run_ASR_BPE_CTC - asr_model = EncDecCTCModelBPE.from_pretrained(model_name=ASR_model_name, strict=False) + asr_model = EncDecCTCModelBPE.from_pretrained(model_name=asr_model, strict=False) self.model_stride_in_secs = 0.04 self.asr_delay_sec = 0.0 self.params['offset'] = 0 self.chunk_len_in_sec = 1.6 self.total_buffer_in_secs = 4 - elif 'citrinet' in ASR_model_name: + elif 'citrinet' in asr_model: self.run_ASR = self.run_ASR_BPE_CTC - asr_model = EncDecCTCModelBPE.from_pretrained(model_name=ASR_model_name, strict=False) + asr_model = EncDecCTCModelBPE.from_pretrained(model_name=asr_model, strict=False) self.model_stride_in_secs = 0.08 self.asr_delay_sec = 0.0 self.params['offset'] = 0 self.chunk_len_in_sec = 1.6 self.total_buffer_in_secs = 4 - elif 'conformer_transducer' in ASR_model_name or 'contextnet' in ASR_model_name: + elif 'conformer_transducer' in asr_model or 'contextnet' in asr_model: self.run_ASR = self.run_ASR_BPE_RNNT - self.get_speech_labels_list = self.save_VAD_labels_list - asr_model = EncDecRNNTBPEModel.from_pretrained(model_name=ASR_model_name, strict=False) + asr_model = EncDecRNNTBPEModel.from_pretrained(model_name=asr_model, strict=False) self.model_stride_in_secs = 0.04 self.asr_delay_sec = 0.0 self.params['offset'] = 0 self.chunk_len_in_sec = 1.6 self.total_buffer_in_secs = 4 - else: - raise ValueError(f"ASR model name not found: {self.params['ASR_model_name']}") + raise ValueError(f"ASR model name not found: {asr_model}") self.params['time_stride'] = self.model_stride_in_secs self.asr_batch_size = 16 + self.audio_file_list = [value['audio_filepath'] for _, value in self.AUDIO_RTTM_MAP.items()] + return asr_model - def create_directories(self, output_path=None): - """Creates directories for transcribing with diarization. 
- """ - if output_path: - os.makedirs(output_path, exist_ok=True) - assert os.path.exists(output_path) - ROOT = output_path - else: - ROOT = os.path.join(os.getcwd(), 'asr_with_diar') - self.oracle_vad_dir = os.path.join(ROOT, 'oracle_vad') - self.json_result_dir = os.path.join(ROOT, 'json_result') - self.trans_with_spks_dir = os.path.join(ROOT, 'transcript_with_speaker_labels') - self.audacity_label_dir = os.path.join(ROOT, 'audacity_label') - - self.root_path = ROOT - os.makedirs(self.root_path, exist_ok=True) - os.makedirs(self.oracle_vad_dir, exist_ok=True) - os.makedirs(self.json_result_dir, exist_ok=True) - os.makedirs(self.trans_with_spks_dir, exist_ok=True) - os.makedirs(self.audacity_label_dir, exist_ok=True) - - data_dir = os.path.join(ROOT, 'data') - os.makedirs(data_dir, exist_ok=True) - - def run_ASR_QuartzNet_CTC(self, audio_file_list, _asr_model): + def run_ASR_QuartzNet_CTC(self, _asr_model): """ Run an QuartzNet ASR model and collect logit, timestamps and text output Args: - audio_file_list (list): - List of audio file paths. _asr_model (class): The loaded NeMo ASR model. @@ -418,7 +371,7 @@ class ASR_DIAR_OFFLINE(object): ) with torch.cuda.amp.autocast(): - transcript_logits_list = _asr_model.transcribe(audio_file_list, batch_size=1, logprobs=True) + transcript_logits_list = _asr_model.transcribe(self.audio_file_list, batch_size=1, logprobs=True) for logit_np in transcript_logits_list: log_prob = torch.from_numpy(logit_np) logits_len = torch.from_numpy(np.array([log_prob.shape[0]])) @@ -438,16 +391,14 @@ class ASR_DIAR_OFFLINE(object): word_ts_list.append(word_ts) return words_list, word_ts_list - def run_ASR_BPE_RNNT(self, audio_file_list, _asr_model): + def run_ASR_BPE_RNNT(self, _asr_model): raise NotImplementedError - def run_ASR_BPE_CTC(self, audio_file_list, _asr_model): + def run_ASR_BPE_CTC(self, _asr_model): """ Run a CTC-BPE based ASR model and collect logit, timestamps and text output Args: - audio_file_list (list): - List of audio file paths. _asr_model (class): The loaded NeMo ASR model. @@ -486,9 +437,8 @@ class ASR_DIAR_OFFLINE(object): ) with torch.cuda.amp.autocast(): - logging.info(f"Running ASR model {self.params['ASR_model_name']}") hyps, tokens_list, sample_list = get_wer_feat_logit( - audio_file_list, + self.audio_file_list, frame_asr, self.chunk_len_in_sec, tokens_per_chunk, @@ -537,12 +487,17 @@ class ASR_DIAR_OFFLINE(object): List that contains word timestamps. audio_file_list (list): List of audio file paths. 
- """ + self.VAD_RTTM_MAP = {} for i, word_timestamps in enumerate(word_ts_list): speech_labels_float = self._get_speech_labels_from_decoded_prediction(word_timestamps) speech_labels = self.get_str_speech_labels(speech_labels_float) - self.write_VAD_rttm_from_speech_labels(self.root_path, audio_file_list[i], speech_labels) + uniq_id = get_uniqname_from_filepath(audio_file_list[i]) + output_path = os.path.join(self.root_path, 'pred_rttms') + if not os.path.exists(output_path): + os.makedirs(output_path) + filename = labels_to_rttmfile(speech_labels, uniq_id, output_path) + self.VAD_RTTM_MAP[uniq_id] = {'audio_filepath': audio_file_list[i], 'rttm_filepath': filename} def _get_speech_labels_from_decoded_prediction(self, input_word_ts): """ @@ -604,80 +559,31 @@ class ASR_DIAR_OFFLINE(object): return word_timestamps def run_diarization( - self, - audio_file_list, - words_and_timestamps, - oracle_manifest=None, - oracle_num_speakers=None, - pretrained_speaker_model=None, - pretrained_vad_model=None, + self, diar_model_config, words_and_timestamps, ): """ Run the diarization process using the given VAD timestamp (oracle_manifest). Args: - audio_file_list (list): - List of audio file paths. word_and_timestamps (list): - List contains words and word-timestamps. - oracle_manifest (str): - A json file path which contains the timestamps of the VAD output. - if None, we use word-timestamps for VAD and segmentation. - oracle_num_speakers (int): - Oracle number of speakers. If None, the number of speakers is estimated. - pretrained_speaker_model (str): - NeMo model file path for speaker embedding extractor model. - - Return: - diar_labels (list): - List that contains diarization result in the form of - speaker labels and time stamps. + List contains words and word-timestamps """ - if oracle_num_speakers != None: - if oracle_num_speakers.isnumeric(): - oracle_num_speakers = int(oracle_num_speakers) - elif oracle_num_speakers in NONE_LIST: - oracle_num_speakers = None + if diar_model_config.diarizer.asr.parameters.asr_based_vad: + self.save_VAD_labels_list(words_and_timestamps, self.audio_file_list) + oracle_manifest = os.path.join(self.root_path, 'asr_vad_manifest.json') + oracle_manifest = write_rttm2manifest(self.VAD_RTTM_MAP, oracle_manifest) + diar_model_config.diarizer.vad.model_path = None + diar_model_config.diarizer.vad.external_vad_manifest = oracle_manifest - data_dir = os.path.join(self.root_path, 'data') + oracle_model = ClusteringDiarizer(cfg=diar_model_config) + score = oracle_model.diarize() + if diar_model_config.diarizer.vad.model_path is not None: + self.get_frame_level_VAD(vad_processing_dir=oracle_model.vad_pred_dir) - MODEL_CONFIG = os.path.join(data_dir, 'speaker_diarization.yaml') - if not os.path.exists(MODEL_CONFIG): - MODEL_CONFIG = wget.download(self.params['diar_config_url'], data_dir) + return score - config = OmegaConf.load(MODEL_CONFIG) - if oracle_manifest == 'asr_based_vad': - # Use ASR-based VAD for diarization. - self.save_VAD_labels_list(words_and_timestamps, audio_file_list) - oracle_manifest = self.write_VAD_rttm(self.oracle_vad_dir, audio_file_list) - - elif oracle_manifest == 'system_vad': - # Use System VAD for diarization. 
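The reworked `run_diarization()` above takes the whole diarization config instead of a long positional argument list; when `asr.parameters.asr_based_vad` is enabled, the ASR word timestamps are written out as an RTTM-style manifest and handed to `ClusteringDiarizer` through `diarizer.vad.external_vad_manifest`. A minimal end-to-end sketch of the new calling convention, mirroring the updated tutorial in this PR (the manifest path, output directory, and config filename are placeholders):

```python
from omegaconf import OmegaConf
from nemo.collections.asr.parts.utils.diarization_utils import ASR_DIAR_OFFLINE
from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map

cfg = OmegaConf.load('conf/offline_diarization_with_asr.yaml')   # config added by this PR
cfg.diarizer.manifest_filepath = 'input_manifest.json'           # placeholder manifest
cfg.diarizer.out_dir = 'demo_output'                             # placeholder output dir
cfg.diarizer.asr.model_path = 'QuartzNet15x5Base-En'
cfg.diarizer.asr.parameters.asr_based_vad = True                 # derive VAD from ASR word timestamps

asr_diar = ASR_DIAR_OFFLINE(**cfg.diarizer.asr.parameters)       # parameters come from the yaml
asr_diar.root_path = cfg.diarizer.out_dir
asr_diar.AUDIO_RTTM_MAP = audio_rttm_map(cfg.diarizer.manifest_filepath)

asr_model = asr_diar.set_asr_model(cfg.diarizer.asr.model_path)  # now returns the loaded NeMo model
word_list, word_ts_list = asr_diar.run_ASR(asr_model)
score = asr_diar.run_diarization(cfg, word_ts_list)              # predictions land in <out_dir>/pred_rttms
```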
- logging.info(f"Using the provided system VAD model for diarization: {pretrained_vad_model}") - config.diarizer.vad.model_path = pretrained_vad_model - else: - config.diarizer.speaker_embeddings.oracle_vad_manifest = oracle_manifest - - output_dir = os.path.join(self.root_path, 'oracle_vad') - config.diarizer.paths2audio_files = audio_file_list - config.diarizer.out_dir = output_dir # Directory to store intermediate files and prediction outputs - if pretrained_speaker_model: - config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model - config.diarizer.speaker_embeddings.oracle_vad_manifest = oracle_manifest - config.diarizer.oracle_num_speakers = oracle_num_speakers - config.diarizer.speaker_embeddings.shift_length_in_sec = self.params['shift_length_in_sec'] - config.diarizer.speaker_embeddings.window_length_in_sec = self.params['window_length_in_sec'] - - oracle_model = ClusteringDiarizer(cfg=config) - oracle_model.diarize() - if oracle_manifest == 'system_vad': - self.get_frame_level_VAD(oracle_model, audio_file_list) - diar_labels = self.get_diarization_labels(audio_file_list) - - return diar_labels - - def get_frame_level_VAD(self, oracle_model, audio_file_list): + def get_frame_level_VAD(self, vad_processing_dir): """ Read frame-level VAD. Args: @@ -686,104 +592,68 @@ class ASR_DIAR_OFFLINE(object): audio_file_path (List): List contains file paths for audio files. """ - for k, audio_file_path in enumerate(audio_file_list): - uniq_id = get_uniq_id_from_audio_path(audio_file_path) - vad_pred_diar = oracle_model.vad_pred_dir - frame_vad = os.path.join(vad_pred_diar, uniq_id + '.median') - frame_vad_list = read_file_paths(frame_vad) - frame_vad_float_list = [float(x) for x in frame_vad_list] + for uniq_id in self.AUDIO_RTTM_MAP: + frame_vad = os.path.join(vad_processing_dir, uniq_id + '.median') + frame_vad_float_list = [] + with open(frame_vad, 'r') as fp: + for line in fp.readlines(): + frame_vad_float_list.append(float(line.strip())) self.frame_VAD[uniq_id] = frame_vad_float_list - def get_diarization_labels(self, audio_file_list): + def gather_eval_results(self, metric, mapping_dict): """ - Save the diarization labels into a list. - - Arg: - audio_file_list (list): - List of audio file paths. - - Return: - diar_labels (list): - List of the speaker labels for each speech segment. + Gathers diarization evaluation results from pyannote DiarizationErrorRate metric object. + Inputs + metric (DiarizationErrorRate metric): DiarizationErrorRate metric pyannote object + mapping_dict (dict): Dictionary containing speaker mapping labels for each audio file with key as uniq name + Returns + DER_result_dict (dict): Dictionary containing scores for each audio file along with aggreated results """ - diar_labels = [] - for k, audio_file_path in enumerate(audio_file_list): - uniq_id = get_uniq_id_from_audio_path(audio_file_path) - pred_rttm = os.path.join(self.oracle_vad_dir, 'pred_rttms', uniq_id + '.rttm') - pred_labels = rttm_to_labels(pred_rttm) - diar_labels.append(pred_labels) - est_n_spk = self.get_num_of_spk_from_labels(pred_labels) - logging.info(f"Estimated n_spk [{uniq_id}]: {est_n_spk}") - - return diar_labels - - def eval_diarization(self, audio_file_list, diar_labels, ref_rttm_file_list): - """ - Evaluate the predicted speaker labels (pred_rttm) using ref_rttm_file_list. - DER and speaker counting accuracy are calculated. - - Args: - audio_file_list (list): - List of audio file paths. - ref_rttm_file_list (list): - List of reference rttm paths. 
- - Returns: - ref_labels_list (list): - Return ref_labels_list for future use. - DER_result_dict (dict): - A dictionary that contains evaluation results. - """ - ref_labels_list = [] - all_hypotheses, all_references = [], [] + results = metric.results_ DER_result_dict = {} count_correct_spk_counting = 0 - - audio_rttm_map = get_audio_rttm_map(audio_file_list, ref_rttm_file_list) - for k, audio_file_path in enumerate(audio_file_list): - uniq_id = get_uniq_id_from_audio_path(audio_file_path) - rttm_file = audio_rttm_map[uniq_id]['rttm_path'] - if os.path.exists(rttm_file): - ref_labels = rttm_to_labels(rttm_file) - ref_labels_list.append(ref_labels) - reference = labels_to_pyannote_object(ref_labels) - all_references.append(reference) - else: - raise ValueError("No reference RTTM file provided.") - - pred_labels = diar_labels[k] + for result in results: + key, score = result + pred_rttm = os.path.join(self.root_path, 'pred_rttms', key + '.rttm') + pred_labels = rttm_to_labels(pred_rttm) est_n_spk = self.get_num_of_spk_from_labels(pred_labels) + ref_rttm = self.AUDIO_RTTM_MAP[key]['rttm_filepath'] + ref_labels = rttm_to_labels(ref_rttm) ref_n_spk = self.get_num_of_spk_from_labels(ref_labels) - hypothesis = labels_to_pyannote_object(pred_labels) - all_hypotheses.append(hypothesis) - DER, CER, FA, MISS, mapping = get_DER([reference], [hypothesis]) - DER_result_dict[uniq_id] = { + DER, CER, FA, MISS = ( + score['diarization error rate'], + score['confusion'], + score['false alarm'], + score['missed detection'], + ) + DER_result_dict[key] = { "DER": DER, "CER": CER, "FA": FA, "MISS": MISS, "n_spk": est_n_spk, - "mapping": mapping[0], + "mapping": mapping_dict[key], "spk_counting": (est_n_spk == ref_n_spk), } + logging.info("score for session {}: {}".format(key, DER_result_dict[key])) count_correct_spk_counting += int(est_n_spk == ref_n_spk) - DER, CER, FA, MISS, mapping = get_DER(all_references, all_hypotheses) - logging.info( - "Cumulative results of all the files: \n FA: {:.4f}\t MISS {:.4f}\t\ - Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format( - FA, MISS, DER, CER - ) + DER, CER, FA, MISS = ( + abs(metric), + metric['confusion'] / metric['total'], + metric['false alarm'] / metric['total'], + metric['missed detection'] / metric['total'], ) - DER_result_dict['total'] = { + DER_result_dict["total"] = { "DER": DER, "CER": CER, "FA": FA, "MISS": MISS, - "spk_counting_acc": count_correct_spk_counting / len(audio_file_list), + "spk_counting_acc": count_correct_spk_counting / len(metric.results_), } - return ref_labels_list, DER_result_dict + + return DER_result_dict @staticmethod def closest_silence_start(vad_index_word_end, vad_frames, params, offset=10): @@ -807,7 +677,7 @@ class ASR_DIAR_OFFLINE(object): c = vad_index_word_end + offset limit = int(100 * params['max_word_ts_length_in_sec'] + vad_index_word_end) while c < len(vad_frames): - if vad_frames[c] < params['VAD_threshold_for_word_ts']: + if vad_frames[c] < params['vad_threshold_for_word_ts']: break else: c += 1 @@ -839,7 +709,7 @@ class ASR_DIAR_OFFLINE(object): """ enhanced_word_ts_list = [] for idx, word_ts_seq_list in enumerate(word_ts_list): - uniq_id = get_uniq_id_from_audio_path(audio_file_list[idx]) + uniq_id = get_uniqname_from_filepath(audio_file_list[idx]) N = len(word_ts_seq_list) enhanced_word_ts_buffer = [] for k, word_ts in enumerate(word_ts_seq_list): @@ -857,11 +727,14 @@ class ASR_DIAR_OFFLINE(object): min_candidate = min(vad_est_len, len_to_next_word) fixed_word_len = max(min(params['max_word_ts_length_in_sec'], 
min_candidate), word_len) enhanced_word_ts_buffer.append([word_ts[0], word_ts[0] + fixed_word_len]) + else: + enhanced_word_ts_buffer.append([word_ts[0], word_ts[1]]) + enhanced_word_ts_list.append(enhanced_word_ts_buffer) return enhanced_word_ts_list def write_json_and_transcript( - self, audio_file_list, diar_labels, word_list, word_ts_list, + self, word_list, word_ts_list, ): """ Matches the diarization result with the ASR output. @@ -869,8 +742,6 @@ class ASR_DIAR_OFFLINE(object): in a for loop. Args: - audio_file_list (list): - List that contains audio file paths. diar_labels (list): List of the Diarization output labels in str. word_list (list): @@ -888,15 +759,16 @@ class ASR_DIAR_OFFLINE(object): """ total_riva_dict = {} if self.fix_word_ts_with_VAD: - word_ts_list = self.compensate_word_ts_list(audio_file_list, word_ts_list, self.params) + word_ts_list = self.compensate_word_ts_list(self.audio_file_list, word_ts_list, self.params) if self.frame_VAD == {}: logging.info( f"VAD timestamps are not provided and skipping word timestamp fix. Please check the VAD model." ) - for k, audio_file_path in enumerate(audio_file_list): - uniq_id = get_uniq_id_from_audio_path(audio_file_path) - labels = diar_labels[k] + for k, audio_file_path in enumerate(self.audio_file_list): + uniq_id = get_uniqname_from_filepath(audio_file_path) + pred_rttm = os.path.join(self.root_path, 'pred_rttms', uniq_id + '.rttm') + labels = rttm_to_labels(pred_rttm) audacity_label_words = [] n_spk = self.get_num_of_spk_from_labels(labels) string_out = '' @@ -915,7 +787,6 @@ class ASR_DIAR_OFFLINE(object): logging.info(f"Creating results for Session: {uniq_id} n_spk: {n_spk} ") string_out = self.print_time(string_out, speaker, start_point, end_point, self.params) - word_pos, idx = 0, 0 for j, word_ts_stt_end in enumerate(word_ts_list[k]): @@ -941,7 +812,7 @@ class ASR_DIAR_OFFLINE(object): return total_riva_dict - def get_WDER(self, audio_file_list, total_riva_dict, DER_result_dict, ref_labels_list): + def get_WDER(self, total_riva_dict, DER_result_dict): """ Calculate word-level diarization error rate (WDER). WDER is calculated by counting the the wrongly diarized words and divided by the total number of words @@ -953,8 +824,6 @@ class ASR_DIAR_OFFLINE(object): DER_result_dict (dict): The dictionary that stores DER, FA, Miss, CER, mapping, the estimated number of speakers and speaker counting accuracy. - audio_file_list (list): - List that contains audio file paths. ref_labels_list (list): List that contains the ground truth speaker labels for each segment. @@ -964,10 +833,11 @@ class ASR_DIAR_OFFLINE(object): """ wder_dict = {} grand_total_word_count, grand_correct_word_count = 0, 0 - for k, audio_file_path in enumerate(audio_file_list): + for k, audio_file_path in enumerate(self.audio_file_list): - labels = ref_labels_list[k] - uniq_id = get_uniq_id_from_audio_path(audio_file_path) + uniq_id = get_uniqname_from_filepath(audio_file_path) + ref_rttm = self.AUDIO_RTTM_MAP[uniq_id]['rttm_filepath'] + labels = rttm_to_labels(ref_rttm) mapping_dict = DER_result_dict[uniq_id]['mapping'] words_list = total_riva_dict[uniq_id]['words'] @@ -1027,7 +897,7 @@ class ASR_DIAR_OFFLINE(object): Saves the diarization result into a csv file. """ row = [ - args.threshold, + args.asr_based_vad_threshold, WDER_dict['total'], DER_result_dict['total']['DER'], DER_result_dict['total']['FA'], @@ -1079,14 +949,14 @@ class ASR_DIAR_OFFLINE(object): """Writes output files and display logging messages. 
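The per-file scores consumed by `gather_eval_results()` come straight from the pyannote `DiarizationErrorRate` accumulator (`metric.results_`), and `get_WDER()` now pulls reference labels from `AUDIO_RTTM_MAP` rather than taking file lists. A short sketch of how the two are chained after diarization; it assumes `asr_diar`, `word_list`, `word_ts_list`, and `score` come from the workflow sketched earlier, that reference RTTMs were provided in the manifest, and that `score` unpacks to the `(metric, mapping_dict)` pair produced by `score_labels()`:

```python
# Assumption: run_diarization() returned the (metric, mapping_dict) pair from score_labels().
metric, mapping_dict = score
der_results = asr_diar.gather_eval_results(metric, mapping_dict)

total_riva_dict = asr_diar.write_json_and_transcript(word_list, word_ts_list)
wder_dict = asr_diar.get_WDER(total_riva_dict, der_results)

print("DER:", der_results['total']['DER'])
print("WDER:", wder_dict['total'])
```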
""" ROOT = self.root_path - logging.info(f"Writing {ROOT}/json_result/{uniq_id}.json") - dump_json_to_file(f'{ROOT}/json_result/{uniq_id}.json', riva_dict) + logging.info(f"Writing {ROOT}/pred_rttms/{uniq_id}.json") + dump_json_to_file(f'{ROOT}/pred_rttms/{uniq_id}.json', riva_dict) - logging.info(f"Writing {ROOT}/transcript_with_speaker_labels/{uniq_id}.txt") - write_txt(f'{ROOT}/transcript_with_speaker_labels/{uniq_id}.txt', string_out.strip()) + logging.info(f"Writing {ROOT}/pred_rttms/{uniq_id}.txt") + write_txt(f'{ROOT}/pred_rttms/{uniq_id}.txt', string_out.strip()) - logging.info(f"Writing {ROOT}/audacity_label/{uniq_id}.w.label") - write_txt(f'{ROOT}/audacity_label/{uniq_id}.w.label', '\n'.join(audacity_label_words)) + logging.info(f"Writing {ROOT}/pred_rttms/{uniq_id}.w.label") + write_txt(f'{ROOT}/pred_rttms/{uniq_id}.w.label', '\n'.join(audacity_label_words)) @staticmethod def clean_trans_and_TS(trans, char_ts): @@ -1119,51 +989,9 @@ class ASR_DIAR_OFFLINE(object): char_ts = char_ts[: -1 * diff_R] return trans, char_ts - @staticmethod - def write_VAD_rttm_from_speech_labels(ROOT, AUDIO_FILENAME, speech_labels): - """Writes a VAD rttm file from speech_labels list. - """ - uniq_id = get_uniq_id_from_audio_path(AUDIO_FILENAME) - oracle_vad_dir = os.path.join(ROOT, 'oracle_vad') - with open(f'{oracle_vad_dir}/{uniq_id}.rttm', 'w') as f: - for spl in speech_labels: - start, end, speaker = spl.split() - start, end = float(start), float(end) - f.write("SPEAKER {} 1 {:.3f} {:.3f} speech \n".format(uniq_id, start, end - start)) - - @staticmethod - def write_VAD_rttm(oracle_vad_dir, audio_file_list, reference_rttmfile_list_path=None): - """ - Writes VAD files to the oracle_vad_dir folder. - - Args: - oracle_vad_dir (str): - The path of oracle VAD folder. - audio_file_list (list): - List of audio file paths. - - Return: - oracle_manifest (str): - Returns the full path of orcale_manifest.json file. - """ - if not reference_rttmfile_list_path: - reference_rttmfile_list_path = [] - for path_name in audio_file_list: - uniq_id = get_uniq_id_from_audio_path(path_name) - reference_rttmfile_list_path.append(f'{oracle_vad_dir}/{uniq_id}.rttm') - - oracle_manifest = os.path.join(oracle_vad_dir, 'oracle_manifest.json') - - write_rttm2manifest( - paths2audio_files=audio_file_list, - paths2rttm_files=reference_rttmfile_list_path, - manifest_file=oracle_manifest, - ) - return oracle_manifest - @staticmethod def threshold_non_speech(source_list, params): - return list(filter(lambda x: x[1] - x[0] > params['threshold'], source_list)) + return list(filter(lambda x: x[1] - x[0] > params['asr_based_vad_threshold'], source_list)) @staticmethod def get_effective_WDER(DER_result_dict, WDER_dict): diff --git a/nemo/collections/asr/parts/utils/nmse_clustering.py b/nemo/collections/asr/parts/utils/nmse_clustering.py index 708a0492f..461904eaf 100644 --- a/nemo/collections/asr/parts/utils/nmse_clustering.py +++ b/nemo/collections/asr/parts/utils/nmse_clustering.py @@ -497,6 +497,8 @@ def COSclustering( max_num_speaker=8, min_samples_for_NMESC=6, enhanced_count_thres=80, + max_rp_threshold=0.25, + sparse_search_volume=30, fixed_thres=None, cuda=False, ): @@ -527,6 +529,16 @@ def COSclustering( Thus, getEnhancedSpeakerCount() employs anchor embeddings (dummy representations) to mitigate the effect of cluster sparsity. enhanced_count_thres = 80 is recommended. + + max_rp_threshold: (float) + Limits the range of parameter search. + Clustering performance can vary depending on this range. + Default is 0.25. 
+ + sparse_search_volume: (int) + The number of p_values we search during NME analysis. + Default is 30. The lower the value, the faster NME-analysis becomes. + Lower than 20 might cause a poor parameter estimation. Returns: Y: (List[int]) @@ -546,10 +558,10 @@ def COSclustering( nmesc = NMESC( mat, max_num_speaker=max_num_speaker, - max_rp_threshold=0.25, + max_rp_threshold=max_rp_threshold, sparse_search=True, - sparse_search_volume=30, - fixed_thres=None, + sparse_search_volume=sparse_search_volume, + fixed_thres=fixed_thres, NME_mat_size=300, cuda=cuda, ) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 40e84728d..60b7b0793 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -14,13 +14,12 @@ import json import math import os -import pickle as pkl from copy import deepcopy import numpy as np +import soundfile as sf import torch -from omegaconf import ListConfig -from pyannote.core import Annotation, Segment +from pyannote.core import Annotation, Segment, Timeline from pyannote.metrics.diarization import DiarizationErrorRate from tqdm import tqdm @@ -33,57 +32,49 @@ This file contains all the utility functions required for speaker embeddings par """ -def audio_rttm_map(audio_file_list, rttm_file_list=None): - """ - Returns a AUDIO_TO_RTTM dictionary thats maps all the unique file names with - audio file paths and corresponding ground truth rttm files calculated from audio - file list and rttm file list - - Args: - audio_file_list(list,str): either list of audio file paths or file containing paths to audio files (required) - rttm_file_list(lisr,str): either list of rttm file paths or file containing paths to rttm files (optional) - [Required if DER needs to be calculated] - Returns: - AUDIO_RTTM_MAP (dict): dictionary thats maps all the unique file names with - audio file paths and corresponding ground truth rttm files - """ - rttm_notfound = False - if type(audio_file_list) in [list, ListConfig]: - audio_files = audio_file_list +def get_uniqname_from_filepath(filepath): + "return base name from provided filepath" + if type(filepath) is str: + basename = os.path.basename(filepath).rsplit('.', 1)[0] + return basename else: - audio_pointer = open(audio_file_list, 'r') - audio_files = audio_pointer.read().splitlines() + raise TypeError("input must be filepath string") - if rttm_file_list: - if type(rttm_file_list) in [list, ListConfig]: - rttm_files = rttm_file_list - else: - rttm_pointer = open(rttm_file_list, 'r') - rttm_files = rttm_pointer.read().splitlines() - else: - rttm_notfound = True - rttm_files = ['-'] * len(audio_files) - assert len(audio_files) == len(rttm_files) +def audio_rttm_map(manifest): + """ + This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, + cluster and unify time stamps + input: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists + + returns: + AUDIO_RTTM_MAP (dict) : Dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files + """ AUDIO_RTTM_MAP = {} - rttm_dict = {} - audio_dict = {} - for audio_file, rttm_file in zip(audio_files, rttm_files): - uniq_audio_name = audio_file.split('/')[-1].rsplit('.', 1)[0] - uniq_rttm_name = rttm_file.split('/')[-1].rsplit('.', 1)[0] + with open(manifest, 'r') as inp_file: + lines = inp_file.readlines() + 
logging.info("Number of files to diarize: {}".format(len(lines))) + for line in lines: + line = line.strip() + dic = json.loads(line) - if rttm_notfound: - uniq_rttm_name = uniq_audio_name + meta = { + 'audio_filepath': dic['audio_filepath'], + 'rttm_filepath': dic.get('rttm_filepath', None), + 'text': dic.get('text', '-'), + 'num_speakers': dic.get('num_speakers', None), + 'uem_filepath': dic.get('uem_filepath'), + } - audio_dict[uniq_audio_name] = audio_file - rttm_dict[uniq_rttm_name] = rttm_file + uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) - for key, value in audio_dict.items(): - - AUDIO_RTTM_MAP[key] = {'audio_path': audio_dict[key], 'rttm_path': rttm_dict[key]} - - assert len(rttm_dict.items()) == len(audio_dict.items()) + if uniqname not in AUDIO_RTTM_MAP: + AUDIO_RTTM_MAP[uniqname] = meta + else: + raise KeyError( + "file {} is already part AUDIO_RTTM_Map, it might be duplicated".format(meta['audio_filepath']) + ) return AUDIO_RTTM_MAP @@ -128,11 +119,11 @@ def merge_stamps(lines): return overlap_stamps -def labels_to_pyannote_object(labels): +def labels_to_pyannote_object(labels, uniq_name=''): """ converts labels to pyannote object to calculate DER and for visualization """ - annotation = Annotation() + annotation = Annotation(uri=uniq_name) for label in labels: start, end, speaker = label.strip().split() start, end = float(start), float(end) @@ -141,6 +132,24 @@ def labels_to_pyannote_object(labels): return annotation +def uem_timeline_from_file(uem_file, uniq_name=''): + """ + outputs pyannote timeline segments for uem file + + file format + UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME + """ + timeline = Timeline(uri=uniq_name) + with open(uem_file, 'r') as f: + lines = f.readlines() + for line in lines: + line = line.strip() + speaker_id, channel, start_time, end_time = line.split() + timeline.add(Segment(float(start_time), float(end_time))) + + return timeline + + def labels_to_rttmfile(labels, uniq_id, out_rttm_dir): """ write rttm file with uniq_id name in out_rttm_dir with time_stamps in labels @@ -155,6 +164,8 @@ def labels_to_rttmfile(labels, uniq_id, out_rttm_dir): log = 'SPEAKER {} 1 {:.3f} {:.3f} {} \n'.format(uniq_id, start, duration, speaker) f.write(log) + return filename + def rttm_to_labels(rttm_filename): """ @@ -169,94 +180,52 @@ def rttm_to_labels(rttm_filename): return labels -def get_time_stamps(embeddings_file, reco2num, manifest_path): - """ - Loads embedding file and generates time stamps based on window and shift for speaker embeddings - clustering. 
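`audio_rttm_map()` now builds `AUDIO_RTTM_MAP` from a manifest of JSON lines (one per audio file), keyed by the unique audio basename, instead of from separate audio/RTTM path lists. A minimal sketch of the expected entry and lookup, using the field layout from the updated tutorials (the wav path is a placeholder):

```python
import json
from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map

entry = {
    "audio_filepath": "data/an4_diarize_test.wav",   # placeholder path
    "offset": 0, "duration": None, "label": "infer", "text": "-",
    "num_speakers": 2, "rttm_filepath": None, "uem_filepath": None,
}
with open("input_manifest.json", "w") as fp:
    fp.write(json.dumps(entry) + "\n")

AUDIO_RTTM_MAP = audio_rttm_map("input_manifest.json")
print(AUDIO_RTTM_MAP["an4_diarize_test"]["audio_filepath"])
```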
- - Args: - embeddings_file (str): Path to embeddings pickle file - reco2num (int,str): common integer number of speakers for every recording or path to - file that states unique_file id with number of speakers - manifest_path (str): path to manifest for time stamps matching - sample_rate (int): sample rate - window (float): window length in sec for speaker embeddings extraction - shift (float): shift length in sec for speaker embeddings extraction window shift - - Returns: - embeddings (dict): Embeddings with key as unique_id - time_stamps (dict): time stamps list for each audio recording - speakers (dict): number of speaker for each audio recording - """ - embeddings = pkl.load(open(embeddings_file, 'rb')) - all_uniq_files = list(embeddings.keys()) - num_files = len(all_uniq_files) - logging.info("Number of files to diarize: {}".format(num_files)) - sample = all_uniq_files[0] - logging.info("sample '{}' embeddings shape is {}\n".format(sample, embeddings[sample][0].shape)) - - speakers = {} - if isinstance(reco2num, int) or reco2num is None: - for key in embeddings.keys(): - speakers[key] = reco2num - elif isinstance(reco2num, str) and os.path.exists(reco2num): - for key in open(reco2num).readlines(): - key = key.strip() - wav_id, num = key.split() - speakers[wav_id] = int(num) - else: - raise TypeError("reco2num must be None, int or path to file containing number of speakers for each uniq id") - - with open(manifest_path, 'r') as manifest: - time_stamps = {} - for line in manifest.readlines(): - line = line.strip() - line = json.loads(line) - audio, offset, duration = line['audio_filepath'], line['offset'], line['duration'] - audio = os.path.basename(audio).rsplit('.', 1)[0] - if audio not in time_stamps: - time_stamps[audio] = [] - start = offset - end = start + duration - stamp = '{:.3f} {:.3f} '.format(start, end) - time_stamps[audio].append(stamp) - - return embeddings, time_stamps, speakers - - -def perform_clustering(embeddings, time_stamps, speakers, audio_rttm_map, out_rttm_dir, max_num_speakers=8): +def perform_clustering(embeddings, time_stamps, AUDIO_RTTM_MAP, out_rttm_dir, clustering_params): """ performs spectral clustering on embeddings with time stamps generated from VAD output Args: embeddings (dict): Embeddings with key as unique_id time_stamps (dict): time stamps list for each audio recording - speakers (dict): number of speaker for each audio recording - audio_rttm_map (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path + AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path out_rttm_dir (str): Path to write predicted rttms - max_num_speakers (int): maximum number of speakers to consider for spectral clustering. 
Will be ignored if speakers['key'] is not None + clustering_params (dict): clustering parameters provided through config that contains max_num_speakers (int), + oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) and enhance_count_threshold (int) Returns: - all_reference (list[Annotation]): reference annotations for score calculation - all_hypothesis (list[Annotation]): hypothesis annotations for score calculation + all_reference (list[uniq_name,Annotation]): reference annotations for score calculation + all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation """ all_hypothesis = [] all_reference = [] no_references = False + max_num_speakers = clustering_params['max_num_speakers'] cuda = True if not torch.cuda.is_available(): logging.warning("cuda=False, using CPU for Eigen decompostion. This might slow down the clustering process.") cuda = False - for uniq_key in tqdm(embeddings.keys()): - NUM_speakers = speakers[uniq_key] + for uniq_key, value in tqdm(AUDIO_RTTM_MAP.items()): + if clustering_params.oracle_num_speakers: + num_speakers = value.get('num_speakers', None) + if num_speakers is None: + raise ValueError("Provided option as oracle num of speakers but num_speakers in manifest is null") + else: + num_speakers = None emb = embeddings[uniq_key] emb = np.asarray(emb) cluster_labels = COSclustering( - uniq_key, emb, oracle_num_speakers=NUM_speakers, max_num_speaker=max_num_speakers, cuda=cuda, + uniq_key, + emb, + oracle_num_speakers=num_speakers, + max_num_speaker=max_num_speakers, + enhanced_count_thres=clustering_params.enhanced_count_thres, + max_rp_threshold=clustering_params.max_rp_threshold, + sparse_search_volume=clustering_params.sparse_search_volume, + cuda=cuda, ) lines = time_stamps[uniq_key] @@ -269,14 +238,14 @@ def perform_clustering(embeddings, time_stamps, speakers, audio_rttm_map, out_rt labels = merge_stamps(a) if out_rttm_dir: labels_to_rttmfile(labels, uniq_key, out_rttm_dir) - hypothesis = labels_to_pyannote_object(labels) - all_hypothesis.append(hypothesis) + hypothesis = labels_to_pyannote_object(labels, uniq_name=uniq_key) + all_hypothesis.append([uniq_key, hypothesis]) - rttm_file = audio_rttm_map[uniq_key]['rttm_path'] - if os.path.exists(rttm_file) and not no_references: + rttm_file = value.get('rttm_filepath', None) + if rttm_file is not None and os.path.exists(rttm_file) and not no_references: ref_labels = rttm_to_labels(rttm_file) - reference = labels_to_pyannote_object(ref_labels) - all_reference.append(reference) + reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_key) + all_reference.append([uniq_key, reference]) else: no_references = True all_reference = [] @@ -284,19 +253,18 @@ def perform_clustering(embeddings, time_stamps, speakers, audio_rttm_map, out_rt return all_reference, all_hypothesis -def get_DER(all_reference, all_hypothesis, collar=0.5, skip_overlap=True): +def score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True): """ calculates DER, CER, FA and MISS Args: - all_reference (list[Annotation]): reference annotations for score calculation - all_hypothesis (list[Annotation]): hypothesis annotations for score calculation + AUDIO_RTTM_MAP : Dictionary containing information provided from manifestpath + all_reference (list[uniq_name,Annotation]): reference annotations for score calculation + all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation Returns: - DER (float): Diarization Error Rate - CER 
(float): Confusion Error Rate - FA (float): False Alarm - Miss (float): Miss Detection + metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. + mapping (dict): Mapping dict containing the mapping speaker label for each audio input < Caveat > Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of @@ -304,93 +272,113 @@ def get_DER(all_reference, all_hypothesis, collar=0.5, skip_overlap=True): collar in md-eval.pl, 0.5s should be applied for pyannote.metrics. """ - metric = DiarizationErrorRate(collar=collar, skip_overlap=skip_overlap, uem=None) + metric = None + if len(all_reference) == len(all_hypothesis): + metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap) - mapping_dict = {} - for k, (reference, hypothesis) in enumerate(zip(all_reference, all_hypothesis)): - metric(reference, hypothesis, detailed=True) - mapping_dict[k] = metric.optimal_mapping(reference, hypothesis) + mapping_dict = {} + for (reference, hypothesis) in zip(all_reference, all_hypothesis): + ref_key, ref_labels = reference + _, hyp_labels = hypothesis + uem = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) + if uem is not None: + uem = uem_timeline_from_file(uem_file=uem, uniq_name=ref_key) + metric(ref_labels, hyp_labels, uem=uem, detailed=True) + mapping_dict[ref_key] = metric.optimal_mapping(ref_labels, hyp_labels) - DER = abs(metric) - CER = metric['confusion'] / metric['total'] - FA = metric['false alarm'] / metric['total'] - MISS = metric['missed detection'] / metric['total'] + DER = abs(metric) + CER = metric['confusion'] / metric['total'] + FA = metric['false alarm'] / metric['total'] + MISS = metric['missed detection'] / metric['total'] - metric.reset() - - return DER, CER, FA, MISS, mapping_dict - - -def perform_diarization( - embeddings_file=None, reco2num=2, manifest_path=None, audio_rttm_map=None, out_rttm_dir=None, max_num_speakers=8, -): - """ - Performs diarization with embeddings generated based on VAD time stamps with recording 2 num of speakers (reco2num) - for spectral clustering - """ - embeddings, time_stamps, speakers = get_time_stamps(embeddings_file, reco2num, manifest_path) - logging.info("Performing Clustering") - all_reference, all_hypothesis = perform_clustering( - embeddings, time_stamps, speakers, audio_rttm_map, out_rttm_dir, max_num_speakers - ) - - if len(all_reference) and len(all_hypothesis): - DER, CER, FA, MISS, _ = get_DER(all_reference, all_hypothesis) logging.info( - "Cumulative results of all the files: \n FA: {:.4f}\t MISS {:.4f}\t \ + "Cumulative Results for collar {} sec and ignore_overlap {}: \n FA: {:.4f}\t MISS {:.4f}\t \ Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format( - FA, MISS, DER, CER + collar, ignore_overlap, FA, MISS, DER, CER ) ) + + return metric, mapping_dict else: - logging.warning("Please check if each ground truth RTTMs was present in provided path2groundtruth_rttm_files") - logging.warning("Skipping calculation of Diariazation Error Rate") + logging.warning( + "check if each ground truth RTTMs were present in provided manifest file. Skipping calculation of Diariazation Error Rate" + ) + + return None -def write_rttm2manifest(paths2audio_files, paths2rttm_files, manifest_file): +def write_rttm2manifest(AUDIO_RTTM_MAP, manifest_file): """ writes manifest file based on rttm files (or vad table out files). This manifest file would be used by speaker diarizer to compute embeddings and cluster them. 
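`score_labels()` replaces `get_DER()`: it takes `[uniq_name, Annotation]` pairs, applies any per-file UEM listed in the manifest, and returns the pyannote metric object together with the optimal speaker mapping. `perform_clustering()` already produces the hypothesis and reference lists in this shape; the sketch below rebuilds them from RTTM files only to show the expected pairing (paths are placeholders, and reference RTTMs are assumed to exist):

```python
from nemo.collections.asr.parts.utils.speaker_utils import (
    audio_rttm_map, labels_to_pyannote_object, rttm_to_labels, score_labels,
)

AUDIO_RTTM_MAP = audio_rttm_map("input_manifest.json")            # placeholder manifest
all_reference, all_hypothesis = [], []
for uniq_id, meta in AUDIO_RTTM_MAP.items():
    ref = labels_to_pyannote_object(rttm_to_labels(meta["rttm_filepath"]), uniq_name=uniq_id)
    hyp = labels_to_pyannote_object(
        rttm_to_labels(f"demo_output/pred_rttms/{uniq_id}.rttm"), uniq_name=uniq_id
    )
    all_reference.append([uniq_id, ref])
    all_hypothesis.append([uniq_id, hyp])

result = score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True)
if result is not None:
    metric, mapping_dict = result
    print("cumulative DER:", abs(metric))
```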
This function also takes care of overlap time stamps Args: - audio_file_list(list,str): either list of audio file paths or file containing paths to audio files (required) - rttm_file_list(lisr,str): either list of rttm file paths or file containing paths to rttm files (optional) + AUDIO_RTTM_MAP: dict containing keys to uniqnames, that contains audio filepath and rttm_filepath as its contents, + these are used to extract oracle vad timestamps. manifest (str): path to write manifest file Returns: manifest (str): path to write manifest file """ - AUDIO_RTTM_MAP = audio_rttm_map(paths2audio_files, paths2rttm_files) with open(manifest_file, 'w') as outfile: for key in AUDIO_RTTM_MAP: - f = open(AUDIO_RTTM_MAP[key]['rttm_path'], 'r') - audio_path = AUDIO_RTTM_MAP[key]['audio_path'] + rttm_filename = AUDIO_RTTM_MAP[key]['rttm_filepath'] + if rttm_filename and os.path.exists(rttm_filename): + f = open(rttm_filename, 'r') + else: + raise FileNotFoundError( + "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( + rttm_filename + ) + ) + + audio_path = AUDIO_RTTM_MAP[key]['audio_filepath'] + if AUDIO_RTTM_MAP[key].get('duration', None): + max_duration = AUDIO_RTTM_MAP[key]['duration'] + else: + sound = sf.SoundFile(audio_path) + max_duration = sound.frames / sound.samplerate + lines = f.readlines() time_tup = (-1, -1) for line in lines: vad_out = line.strip().split() if len(vad_out) > 3: - start, dur, activity = float(vad_out[3]), float(vad_out[4]), vad_out[7] + start, dur, _ = float(vad_out[3]), float(vad_out[4]), vad_out[7] else: - start, dur, activity = float(vad_out[0]), float(vad_out[1]), vad_out[2] + start, dur, _ = float(vad_out[0]), float(vad_out[1]), vad_out[2] start, dur = float("{:.3f}".format(start)), float("{:.3f}".format(dur)) if time_tup[0] >= 0 and start > time_tup[1]: dur2 = float("{:.3f}".format(time_tup[1] - time_tup[0])) - meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'} - json.dump(meta, outfile) - outfile.write("\n") + if time_tup[0] < max_duration and dur2 > 0: + meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'} + json.dump(meta, outfile) + outfile.write("\n") + else: + logging.warning( + "RTTM label has been truncated since start is greater than duration of audio file" + ) time_tup = (start, start + dur) else: if time_tup[0] == -1: - time_tup = (start, start + dur) + end_time = start + dur + if end_time > max_duration: + end_time = max_duration + time_tup = (start, end_time) else: - time_tup = (min(time_tup[0], start), max(time_tup[1], start + dur)) + end_time = max(time_tup[1], start + dur) + if end_time > max_duration: + end_time = max_duration + time_tup = (min(time_tup[0], start), end_time) dur2 = float("{:.3f}".format(time_tup[1] - time_tup[0])) - meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'} - json.dump(meta, outfile) - outfile.write("\n") + if time_tup[0] < max_duration and dur2 > 0: + meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'} + json.dump(meta, outfile) + outfile.write("\n") + else: + logging.warning("RTTM label has been truncated since start is greater than duration of audio file") f.close() return manifest_file diff --git a/scripts/speaker_tasks/pathsfiles_to_manifest.py b/scripts/speaker_tasks/pathsfiles_to_manifest.py new file mode 100644 index 000000000..5e0e14514 --- /dev/null +++ 
b/scripts/speaker_tasks/pathsfiles_to_manifest.py @@ -0,0 +1,105 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import os +import random +from collections import Counter + +from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels + +random.seed(42) + +""" +This script creates manifest file for speaker diarization inference purposes. +Useful to get manifest when you have list of audio files and optionally rttm and uem files for evaluation + +Note: make sure basename for each file is unique and rttm files also has the corresponding base name for mapping +""" + + +def write_file(name, lines, idx): + with open(name, 'w') as fout: + for i in idx: + dic = lines[i] + json.dump(dic, fout) + fout.write('\n') + logging.info("wrote", name) + + +def read_file(scp): + scp = open(scp, 'r').readlines() + return sorted(scp) + + +def main(wav_scp, rttm_scp=None, uem_scp=None, manifest_filepath=None): + if os.path.exists(manifest_filepath): + os.remove(manifest_filepath) + + wav_scp = read_file(wav_scp) + + if rttm_scp is not None: + rttm_scp = read_file(rttm_scp) + else: + rttm_scp = len(wav_scp) * [None] + + if uem_scp is not None: + uem_scp = read_file(uem_scp) + else: + uem_scp = len(wav_scp) * [None] + + lines = [] + for wav, rttm, uem in zip(wav_scp, rttm_scp, uem_scp): + audio_line = wav.strip() + if rttm is not None: + rttm = rttm.strip() + labels = rttm_to_labels(rttm) + num_speakers = Counter([l.split()[-1] for l in labels]).keys().__len__() + else: + num_speakers = None + + if uem is not None: + uem = uem.strip() + + meta = [ + { + "audio_filepath": audio_line, + "offset": 0, + "duration": None, + "label": "infer", + "text": "-", + "num_speakers": num_speakers, + "rttm_filepath": rttm, + "uem_filepath": uem, + } + ] + lines.extend(meta) + + write_file(manifest_filepath, lines, range(len(lines))) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--paths2audio_files", help="path to text file containing list of audio files", type=str, required=True + ) + parser.add_argument("--paths2rttm_files", help="path to text file containing list of rttm files", type=str) + parser.add_argument("--paths2uem_files", help="path to uem files", type=str) + parser.add_argument("--manifest_filepath", help="scp file name", type=str, required=True) + + args = parser.parse_args() + + main(args.paths2audio_files, args.paths2rttm_files, args.paths2uem_files, args.manifest_filepath) diff --git a/scripts/speaker_tasks/rttm_to_manifest.py b/scripts/speaker_tasks/rttm_to_manifest.py deleted file mode 100644 index 3d7b772b0..000000000 --- a/scripts/speaker_tasks/rttm_to_manifest.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
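The new `pathsfiles_to_manifest.py` replaces `rttm_to_manifest.py`: given text files listing audio, RTTM, and UEM paths (`--paths2audio_files`, `--paths2rttm_files`, `--paths2uem_files`, `--manifest_filepath`), it writes one JSON line per audio file and, when a reference RTTM is available, derives `num_speakers` from it. A small sketch of the per-file entry it emits (paths are placeholders):

```python
import json
from collections import Counter
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels

wav, rttm = "audio/test_an4.wav", "rttm/test_an4.rttm"           # placeholder paths
# Count distinct speaker labels in the reference RTTM, as the script does.
num_speakers = len(Counter(label.split()[-1] for label in rttm_to_labels(rttm)))

entry = {
    "audio_filepath": wav, "offset": 0, "duration": None, "label": "infer",
    "text": "-", "num_speakers": num_speakers, "rttm_filepath": rttm, "uem_filepath": None,
}
print(json.dumps(entry))
```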
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -from nemo.collections.asr.parts.utils.speaker_utils import write_rttm2manifest -from nemo.utils import logging - -""" -This file converts vad outputs to manifest file for speaker diarization purposes -present in vad output directory. -every vad line consists of start_time, end_time , speech/non-speech -""" - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--paths2rttm_files", help="path to vad output rttm-like files. Could be a list or a text file", required=True - ) - parser.add_argument( - "--paths2audio_files", - help="path to audio files that vad was computed. Could be a list or a text file", - required=True, - ) - parser.add_argument("--manifest_file", help="output manifest file name", type=str, required=True) - args = parser.parse_args() - - write_rttm2manifest(args.paths2audio_files, args.paths2rttm_files, args.manifest_file) - logging.info("wrote {} file from vad output files present in {}".format(args.manifest_file, args.paths2rttm_files)) diff --git a/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb b/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb index 0b6e9eaf3..fcaec6cfa 100644 --- a/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb +++ b/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb @@ -99,18 +99,6 @@ "pp = pprint.PrettyPrinter(indent=4)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def read_file(path_to_file):\n", - " with open(path_to_file) as f:\n", - " contents = f.read().splitlines()\n", - " return contents" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -127,6 +115,7 @@ "ROOT = os.getcwd()\n", "data_dir = os.path.join(ROOT,'data')\n", "os.makedirs(data_dir, exist_ok=True)\n", + "\n", "an4_audio_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav\"\n", "if not os.path.exists(os.path.join(data_dir,'an4_diarize_test.wav')):\n", " AUDIO_FILENAME = wget.download(an4_audio_url, data_dir)\n", @@ -213,21 +202,99 @@ "metadata": {}, "outputs": [], "source": [ - "CONFIG_URL = \"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/speaker_diarization.yaml\"\n", + "from omegaconf import OmegaConf\n", + "import shutil\n", + "CONFIG_URL = \"https://raw.githubusercontent.com/NVIDIA/NeMo/modify_speaker_input/examples/speaker_tasks/diarization/conf/offline_diarization_with_asr.yaml\"\n", "\n", - "params = {\n", - " \"round_float\": 2,\n", - " \"window_length_in_sec\": 1.0,\n", - " \"shift_length_in_sec\": 0.25,\n", - " \"fix_word_ts_with_VAD\": False,\n", - " \"print_transcript\": False,\n", - " \"word_gap_in_sec\": 0.01,\n", - " \"max_word_ts_length_in_sec\": 0.6,\n", - " \"minimum\": True,\n", - " \"threshold\": 300, \n", - " \"diar_config_url\": CONFIG_URL,\n", - " \"ASR_model_name\": 'QuartzNet15x5Base-En',\n", - "}\n" + "if not os.path.exists(os.path.join(data_dir,'offline_diarization_with_asr.yaml')):\n", + " CONFIG = wget.download(CONFIG_URL, data_dir)\n", + "else:\n", + " CONFIG = 
os.path.join(data_dir,'offline_diarization_with_asr.yaml')\n", + " \n", + "cfg = OmegaConf.load(CONFIG)\n", + "print(OmegaConf.to_yaml(cfg))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Speaker Diarization scripts commonly expects following arguments:\n", + "1. manifest_filepath : Path to manifest file containing json lines of format: {\"audio_filepath\": \"/path/to/audio_file\", \"offset\": 0, \"duration\": null, \"label\": \"infer\", \"text\": \"-\", \"num_speakers\": null, \"rttm_filepath\": \"/path/to/rttm/file\", \"uem_filepath\"=\"/path/to/uem/filepath\"}\n", + "2. out_dir : directory where outputs and intermediate files are stored. \n", + "3. oracle_vad: If this is true then we extract speech activity labels from rttm files, if False then either \n", + "4. vad.model_path or external_manifestpath containing speech activity labels has to be passed. \n", + "\n", + "Mandatory fields are audio_filepath, offset, duration, label and text. For the rest if you would like to evaluate with known number of speakers pass the value else `null`. If you would like to score the system with known rttms then that should be passed as well, else `null`. uem file is used to score only part of your audio for evaluation purposes, hence pass if you would like to evaluate on it else `null`.\n", + "\n", + "\n", + "**Note:** we expect audio and corresponding RTTM have **same base name** and the name should be **unique**. \n", + "\n", + "For example: if audio file name is **test_an4**.wav, if provided we expect corresponding rttm file name to be **test_an4**.rttm (note the matching **test_an4** base name)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets create manifest with the an4 audio and rttm available. If you have more than one files you may also use the script `NeMo/scripts/speaker_tasks/rttm_to_manifest.py` to generate manifest file from list of audio files and optionally rttm files " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a manifest for input with below format. 
\n", + "# {\"audio_filepath\": \"/path/to/audio_file\", \"offset\": 0, \"duration\": null, \"label\": \"infer\", \"text\": \"-\", \n", + "# \"num_speakers\": null, \"rttm_filepath\": \"/path/to/rttm/file\", \"uem_filepath\"=\"/path/to/uem/filepath\"}\n", + "import json\n", + "meta = {\n", + " 'audio_filepath': AUDIO_FILENAME, \n", + " 'offset': 0, \n", + " 'duration':None, \n", + " 'label': 'infer', \n", + " 'text': '-', \n", + " 'num_speakers': 2, \n", + " 'rttm_filepath': None, \n", + " 'uem_filepath' : None\n", + "}\n", + "with open(os.path.join(data_dir,'input_manifest.json'),'w') as fp:\n", + " json.dump(meta,fp)\n", + " fp.write('\\n')\n", + "\n", + "\n", + "cfg.diarizer.manifest_filepath = os.path.join(data_dir,'input_manifest.json')\n", + "!cat {cfg.diarizer.manifest_filepath}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set the parameters required for diarization, here we get voice activity labels from ASR, which is set through parameter `cfg.diarizer.asr.parameters.asr_based_vad`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pretrained_speaker_model='ecapa_tdnn'\n", + "cfg.diarizer.manifest_filepath = cfg.diarizer.manifest_filepath\n", + "cfg.diarizer.out_dir = data_dir #Directory to store intermediate files and prediction outputs\n", + "cfg.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n", + "cfg.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5\n", + "cfg.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75\n", + "cfg.diarizer.clustering.parameters.oracle_num_speakers=True\n", + "\n", + "# USE VAD generated from ASR timestamps\n", + "cfg.diarizer.asr.model_path = 'QuartzNet15x5Base-En'\n", + "cfg.diarizer.oracle_vad = False # ----> ORACLE VAD \n", + "cfg.diarizer.asr.parameters.asr_based_vad = True\n", + "cfg.diarizer.asr.parameters.threshold=300" ] }, { @@ -243,35 +310,13 @@ "metadata": {}, "outputs": [], "source": [ - "asr_diar_offline = ASR_DIAR_OFFLINE(params)\n", - "asr_model = asr_diar_offline.set_asr_model(params['ASR_model_name'])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will create folders that we need for storing VAD stamps and ASR/diarization results.\n", + "from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map\n", + "asr_diar_offline = ASR_DIAR_OFFLINE(**cfg.diarizer.asr.parameters)\n", + "asr_diar_offline.root_path = cfg.diarizer.out_dir\n", "\n", - "Under the folder named ``asr_with_diar``, the following folders will be created.\n", - "\n", - "- ``oracle_vad``\n", - "- ``json_result``\n", - "- ``transcript_with_speaker_labels``\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "asr_diar_offline.create_directories()\n", - "\n", - "print(\"Folders are created as below.\")\n", - "print(\"VAD file path: \\n\", asr_diar_offline.oracle_vad_dir)\n", - "print(\"JSON result path: \\n\", asr_diar_offline.json_result_dir)\n", - "print(\"Transcript result path: \\n\", asr_diar_offline.trans_with_spks_dir)" + "AUDIO_RTTM_MAP = audio_rttm_map(cfg.diarizer.manifest_filepath)\n", + "asr_diar_offline.AUDIO_RTTM_MAP = AUDIO_RTTM_MAP\n", + "asr_model = asr_diar_offline.set_asr_model(cfg.diarizer.asr.model_path)" ] }, { @@ -296,7 +341,7 @@ "metadata": {}, "outputs": [], "source": [ - "word_list, word_ts_list = asr_diar_offline.run_ASR(audio_file_list, asr_model)\n", + "word_list, word_ts_list = 
asr_diar_offline.run_ASR(asr_model)\n", "\n", "print(\"Decoded word output: \\n\", word_list[0])\n", "print(\"Word-level timestamps \\n\", word_ts_list[0])" @@ -325,23 +370,14 @@ "metadata": {}, "outputs": [], "source": [ - "oracle_manifest = 'asr_based_vad'\n", - "# If we know the number of speakers, we can assign \"2\".\n", - "num_speakers = None\n", - "speaker_embedding_model = 'ecapa_tdnn'\n", - "\n", - "diar_labels = asr_diar_offline.run_diarization(audio_file_list, \n", - " word_ts_list, \n", - " oracle_manifest=oracle_manifest, \n", - " oracle_num_speakers=num_speakers,\n", - " pretrained_speaker_model=speaker_embedding_model)" + "score = asr_diar_offline.run_diarization(cfg, word_ts_list)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "`run_diarization()` function creates `./asr_with_diar/oracle_vad/pred_rttm/an4_diarize_test.rttm` file. Let's see what is written in this `rttm` file." + "`run_diarization()` function creates `an4_diarize_test.rttm` file. Let's see what is written in this `rttm` file." ] }, { @@ -350,29 +386,34 @@ "metadata": {}, "outputs": [], "source": [ - "predicted_speaker_label_rttm_path = f\"{ROOT}/asr_with_diar/oracle_vad/pred_rttms/an4_diarize_test.rttm\"\n", + "def read_file(path_to_file):\n", + " with open(path_to_file) as f:\n", + " contents = f.read().splitlines()\n", + " return contents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predicted_speaker_label_rttm_path = f\"{data_dir}/pred_rttms/an4_diarize_test.rttm\"\n", "pred_rttm = read_file(predicted_speaker_label_rttm_path)\n", "\n", "pp.pprint(pred_rttm)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`run_diarization()` also returns a variable named `diar_labels` which contains the estimated speaker label information with timestamps from the predicted rttm file." - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "print(\"Diarization Labels:\")\n", - "pp.pprint(diar_labels[0])\n", + "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels\n", + "pred_labels = rttm_to_labels(predicted_speaker_label_rttm_path)\n", "\n", - "color = get_color(signal, diar_labels[0])\n", + "color = get_color(signal, pred_labels)\n", "display_waveform(signal,'Audio with Speaker Labels', color)\n", "display(Audio(signal,rate=16000))" ] @@ -391,14 +432,14 @@ "metadata": {}, "outputs": [], "source": [ - "asr_output_dict = asr_diar_offline.write_json_and_transcript(audio_file_list, diar_labels, word_list, word_ts_list)" + "asr_output_dict = asr_diar_offline.write_json_and_transcript(word_list, word_ts_list)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "After running `write_json_and_transcript()` function, the transcription output will be located in `./asr_with_diar/transcript_with_speaker_labels` folder, which shows **start time to end time of the utterance, speaker ID, and words spoken** during the notified time." + "After running `write_json_and_transcript()` function, the transcription output will be located in `./pred_rttms` folder, which shows **start time to end time of the utterance, speaker ID, and words spoken** during the notified time." 
] }, { @@ -407,7 +448,7 @@ "metadata": {}, "outputs": [], "source": [ - "transcription_path_to_file = f\"{ROOT}/asr_with_diar/transcript_with_speaker_labels/an4_diarize_test.txt\"\n", + "transcription_path_to_file = f\"{data_dir}/pred_rttms/an4_diarize_test.txt\"\n", "transcript = read_file(transcription_path_to_file)\n", "pp.pprint(transcript)" ] @@ -416,7 +457,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Another output is transcription output in JSON format, which is saved in `./asr_with_diar/json_result`. \n", + "Another output is transcription output in JSON format, which is saved in `./pred_rttms/an4_diarize_test.json`. \n", "\n", "In the JSON format output, we include information such as **transcription, estimated number of speakers (variable named `speaker_count`), start and end time of each word and most importantly, speaker label for each word.**" ] @@ -427,7 +468,7 @@ "metadata": {}, "outputs": [], "source": [ - "transcription_path_to_file = f\"{ROOT}/asr_with_diar/json_result/an4_diarize_test.json\"\n", + "transcription_path_to_file = f\"{data_dir}/pred_rttms/an4_diarize_test.json\"\n", "json_contents = read_file(transcription_path_to_file)\n", "pp.pprint(json_contents)" ] @@ -449,7 +490,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.6" + "version": "3.8.2" } }, "nbformat": 4, diff --git a/tutorials/speaker_tasks/Speaker_Diarization_Inference.ipynb b/tutorials/speaker_tasks/Speaker_Diarization_Inference.ipynb index 2c59e7d6a..95dd358b3 100644 --- a/tutorials/speaker_tasks/Speaker_Diarization_Inference.ipynb +++ b/tutorials/speaker_tasks/Speaker_Diarization_Inference.ipynb @@ -151,32 +151,25 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Speaker Diarization scripts commonly expects two files:\n", - "1. paths2audio_files : either list of audio file paths or file containing paths to audio files for which we need to perform diarization.\n", - "2. path2groundtruth_rttm_files (optional): either list of rttm file paths or file containing paths to rttm files (this can be passed if we need to calculate DER rate based on our ground truth rttm files).\n", + "Speaker Diarization scripts commonly expects following arguments:\n", + "1. manifest_filepath : Path to manifest file containing json lines of format: {'audio_filepath': /path/to/audio_file, 'offset': 0, 'duration':None, 'label': 'infer', 'text': '-', 'num_speakers': None, 'rttm_filepath': /path/to/rttm/file, 'uem_filepath'='/path/to/uem/filepath'}\n", + "2. out_dir : directory where outputs and intermediate files are stored. \n", + "3. oracle_vad: If this is true then we extract speech activity labels from rttm files, if False then either \n", + "4. vad.model_path or external_manifestpath containing speech activity labels has to be passed. \n", + "\n", + "Mandatory fields are audio_filepath, offset, duration, label and text. For the rest if you would like to evaluate with known number of speakers pass the value else None. If you would like to score the system with known rttms then that should be passed as well, else None. uem file is used to score only part of your audio for evaluation purposes, hence pass if you would like to evaluate on it else None.\n", + "\n", "\n", "**Note** we expect audio and corresponding RTTM have **same base name** and the name should be **unique**. 
\n", "\n", - "For eg: if audio file name is **test_an4**.wav, if provided we expect corresponding rttm file name to be **test_an4**.rttm (note the matching **test_an4** base name)\n", - "\n", - "Now let's create paths2audio_files list (or file) for which we need to perform diarization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "paths2audio_files = [an4_audio]\n", - "print(paths2audio_files)" + "For eg: if audio file name is **test_an4**.wav, if provided we expect corresponding rttm file name to be **test_an4**.rttm (note the matching **test_an4** base name)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Similarly create` path2groundtruth_rttm_files` list (this is optional, and needed for score calculation)" + "Lets create manifest with the an4 audio and rttm available. If you have more than one files you may also use the script `pathsfiles_to_manifest.py` to generate manifest file from list of audio files and optionally rttm files " ] }, { @@ -185,8 +178,28 @@ "metadata": {}, "outputs": [], "source": [ - "path2groundtruth_rttm_files = [an4_rttm]\n", - "print(path2groundtruth_rttm_files)" + "# Create a manifest for input with below format. \n", + "# {'audio_filepath': /path/to/audio_file, 'offset': 0, 'duration':None, 'label': 'infer', 'text': '-', \n", + "# 'num_speakers': None, 'rttm_filepath': /path/to/rttm/file, 'uem_filepath'='/path/to/uem/filepath'}\n", + "import json\n", + "meta = {\n", + " 'audio_filepath': an4_audio, \n", + " 'offset': 0, \n", + " 'duration':None, \n", + " 'label': 'infer', \n", + " 'text': '-', \n", + " 'num_speakers': 2, \n", + " 'rttm_filepath': an4_rttm, \n", + " 'uem_filepath' : None\n", + "}\n", + "with open('data/input_manifest.json','w') as fp:\n", + " json.dump(meta,fp)\n", + " fp.write('\\n')\n", + "\n", + "!cat data/input_manifest.json\n", + "\n", + "output_dir = os.path.join(ROOT, 'oracle_vad')\n", + "os.makedirs(output_dir,exist_ok=True)" ] }, { @@ -203,39 +216,8 @@ "Oracle-vad diarization is to compute speaker embeddings from known speech label timestamps rather than depending on VAD output. This step can also be used to run speaker diarization with rttms generated from any external VAD, not just VAD model from NeMo.\n", "\n", "For it, the first step is to start converting reference audio rttm(vad) time stamps to oracle manifest file. 
- "For that let's use write_rttm2manifest function, that takes paths2audio_files and paths2rttm_files as arguments " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from nemo.collections.asr.parts.utils.speaker_utils import write_rttm2manifest\n", - "output_dir = os.path.join(ROOT, 'oracle_vad')\n", - "os.makedirs(output_dir,exist_ok=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "oracle_manifest = os.path.join(output_dir,'oracle_manifest.json')\n", - "write_rttm2manifest(paths2audio_files=paths2audio_files,\n", - " paths2rttm_files=path2groundtruth_rttm_files,\n", - " manifest_file=oracle_manifest)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!cat {oracle_manifest}" + "\n", + "This is just an argument in our config; the system automatically computes the oracle manifest based on the rttms provided through the input manifest file" ] }, { @@ -253,10 +235,11 @@ "outputs": [], "source": [ "from omegaconf import OmegaConf\n", - "MODEL_CONFIG = os.path.join(data_dir,'speaker_diarization.yaml')\n", + "MODEL_CONFIG = os.path.join(data_dir,'offline_diarization.yaml')\n", "if not os.path.exists(MODEL_CONFIG):\n", - " config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/speaker_diarization.yaml\"\n", + " config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/modify_speaker_input/examples/speaker_tasks/diarization/conf/offline_diarization.yaml\"\n", " MODEL_CONFIG = wget.download(config_url,data_dir)\n", + "\n", "config = OmegaConf.load(MODEL_CONFIG)\n", "print(OmegaConf.to_yaml(config))" ] }, @@ -275,14 +258,14 @@ "outputs": [], "source": [ "pretrained_speaker_model='ecapa_tdnn'\n", - "config.diarizer.paths2audio_files = paths2audio_files\n", - "config.diarizer.path2groundtruth_rttm_files = path2groundtruth_rttm_files\n", + "config.diarizer.manifest_filepath = 'data/input_manifest.json'\n", "config.diarizer.out_dir = output_dir #Directory to store intermediate files and prediction outputs\n", "\n", "config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n", - "# Ignoring vad we just need to pass the manifest file we created\n", - "config.diarizer.speaker_embeddings.oracle_vad_manifest = oracle_manifest\n", - "config.diarizer.oracle_num_speakers=2" + "config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5\n", + "config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75\n", + "config.diarizer.oracle_vad = True # ----> ORACLE VAD \n", + "config.diarizer.clustering.parameters.oracle_num_speakers = True" ] }, { @@ -352,7 +335,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this method we compute VAD time stamps using NeMo VAD model on `paths2audio_files` and then use these time stamps of speech label to find speaker embeddings followed by clustering them into num of speakers" + "In this method we compute VAD time stamps using a NeMo VAD model on the input manifest file, use these speech time stamps to extract speaker embeddings, and then cluster the embeddings into the estimated number of speakers" ] }, { @@ -413,20 +396,23 @@ "outputs": [], "source": [ "output_dir = os.path.join(ROOT,'outputs')\n", - "config.diarizer.paths2audio_files = paths2audio_files\n", - "config.diarizer.path2groundtruth_rttm_files = 
path2groundtruth_rttm_files\n", - "config.diarizer.out_dir = output_dir # Directory to store intermediate files and prediction outputs\n", + "config.diarizer.manifest_filepath = 'data/input_manifest.json'\n", + "config.diarizer.out_dir = output_dir #Directory to store intermediate files and prediction outputs\n", + "\n", "config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n", + "config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5\n", + "config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75\n", + "config.diarizer.oracle_vad = False # compute VAD provided with model_path to vad config\n", + "config.diarizer.clustering.parameters.oracle_num_speakers=True\n", "\n", "#Here we use our inhouse pretrained NeMo VAD \n", "config.diarizer.vad.model_path = pretrained_vad\n", "config.diarizer.vad.window_length_in_sec = 0.15\n", "config.diarizer.vad.shift_length_in_sec = 0.01\n", - "# config.diarizer.vad.threshold = 0.8 threshold would be deprecated in release 1.5\n", - "config.diarizer.vad.postprocessing_params.onset = 0.8 \n", - "config.diarizer.vad.postprocessing_params.offset = 0.7\n", - "config.diarizer.vad.postprocessing_params.min_duration_on = 0.1\n", - "config.diarizer.vad.postprocessing_params.min_duration_off = 0.3" + "config.diarizer.vad.parameters.onset = 0.8 \n", + "config.diarizer.vad.parameters.offset = 0.6\n", + "config.diarizer.vad.parameters.min_duration_on = 0.1\n", + "config.diarizer.vad.parameters.min_duration_off = 0.4" ] }, { @@ -490,13 +476,13 @@ "# you can also use single threshold(=onset=offset) for binarization and plot here\n", "from nemo.collections.asr.parts.utils.vad_utils import plot\n", "plot(\n", - " paths2audio_files[0],\n", + " an4_audio,\n", " 'outputs/vad_outputs/overlap_smoothing_output_median_0.875/an4_diarize_test.median', \n", - " path2groundtruth_rttm_files[0],\n", - " per_args = config.diarizer.vad.postprocessing_params, #threshold\n", + " an4_rttm,\n", + " per_args = config.diarizer.vad.parameters, #threshold\n", " ) \n", "\n", - "print(f\"postprocessing_params: {config.diarizer.vad.postprocessing_params}\")" + "print(f\"postprocessing_params: {config.diarizer.vad.parameters}\")" ] }, { @@ -599,14 +585,14 @@ "outputs": [], "source": [ "quartznet = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name=\"QuartzNet15x5Base-En\")\n", - "for fname, transcription in zip(paths2audio_files, quartznet.transcribe(paths2audio_files=paths2audio_files)):\n", - " print(f\"Audio in {fname} was recognized as: {transcription}\")" + "for fname, transcription in zip([an4_audio], quartznet.transcribe(paths2audio_files=[an4_audio])):\n", + " print(f\"Audio in {fname} was recognized as:\\n{transcription}\")" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -620,7 +606,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.8.2" }, "pycharm": { "stem_cell": {