VAD postprocessing - binarization, filtering (#2636)
* Add more parameters for binarization
* more
* optimize and update diarizer
* updates
* binarization and order of filtering for plot
* updates
* add dur offset for plotting
* fix and clean
* style fix and updates
* rename
* small updates
* typo
* fix lgtm
* add back support threshold
* update yaml to be of ch109
* style fix
* update Speaker_Diarization_Inference.ipynb
* updates Online_Offline_Microphone_VAD_Demo.ipynb
* update config in sd rst
* updated from feedback
* updates
* update yaml about dev
* add missing description in docstring
* typo and dir name
* style fix

Signed-off-by: fayejf <fayejf07@gmail.com>
parent
a1085018d3
commit
7f01d2a51a
|
@ -50,11 +50,19 @@ Diarizer Architecture Configurations
|
|||
model_path: null #.nemo local model path or pretrained model name or none
|
||||
window_length_in_sec: 0.15
|
||||
shift_length_in_sec: 0.01
|
||||
threshold: 0.5 # tune threshold on dev set. Check <NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py
|
||||
vad_decision_smoothing: True
|
||||
smoothing_params:
|
||||
method: "median"
|
||||
overlap: 0.875
|
||||
threshold: null # if not null, use this value and ignore the following onset and offset. Warning! it will be removed in release 1.5!
|
||||
postprocessing_params: # Tune parameter on dev set. Check <NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py
|
||||
onset: 0.4
|
||||
offset: 0.7
|
||||
pad_onset: 0.05
|
||||
pad_offset: -0.1
|
||||
min_duration_on: 0.2
|
||||
min_duration_off: 0.2
|
||||
filter_speech_first: True
|
||||
|
||||
speaker_embeddings:
|
||||
oracle_vad_manifest: null # leave it null to perform diarization with the above VAD model; otherwise set the path to a manifest file generated as shown in the Datasets section
|
||||
|
|
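The effect of the two padding values in the config above can be illustrated with a few lines of arithmetic. This is a minimal sketch, assuming padding is applied as `(start - pad_onset, end + pad_offset)` with the start clamped at 0, as the `binarization` hunk later in this commit suggests; `pad_segment` is a hypothetical helper name, not part of NeMo.

```python
def pad_segment(start, end, pad_onset=0.05, pad_offset=-0.1):
    """Apply onset/offset padding to one (start, end) speech segment, in seconds."""
    # The start moves earlier by pad_onset and is clamped so it never goes below 0.
    padded_start = max(0.0, start - pad_onset)
    # The end moves later by pad_offset; a negative value trims the segment instead.
    padded_end = end + pad_offset
    return padded_start, padded_end


if __name__ == "__main__":
    # With the defaults shown in the config, a 1.0 s to 2.0 s segment starts
    # slightly earlier and ends slightly earlier.
    print(pad_segment(1.0, 2.0))
```

With `pad_offset: -0.1`, every detected segment is shortened at its end, which is why that value can be negative while the thresholds cannot.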
|
@ -45,7 +45,7 @@ You may perform inference and transcribe a sample of speech after loading the mo
|
|||
Setting argument ``logprobs`` to True would return the log probabilities instead of transcriptions. You may find more details in `Modules <../api.html#modules>`__.
|
||||
|
||||
Learn how to fine tune on your own data or on subset classes in ``<NeMo_git_root>/tutorials/asr/Speech_Commands.ipynb``
|
||||
f
|
||||
|
||||
|
||||
**Run VAD inference:**
|
||||
|
||||
|
@ -53,7 +53,20 @@ f
|
|||
|
||||
python examples/speaker_recognition/vad_infer.py --vad_model="vad_marblenet" --dataset=<FULL PATH OF MANIFEST TO BE PERFORMED INFERENCE ON> --out_dir='frame/demo' --time_length=0.63
|
||||
|
||||
Have a look at scripts under ``<NeMo-git-root>/scripts/voice_activity_detection`` for posterior processing and threshold tuning.
|
||||
Have a look at scripts under ``<NeMo-git-root>/scripts/voice_activity_detection`` for posterior processing, postprocessing and threshold tuning.
|
||||
|
||||
Posterior processing includes generating predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
|
||||
|
||||
For VAD postprocessing we introduce:
|
||||
|
||||
Binarization:
|
||||
- ``onset`` and ``offset`` thresholds for detecting the beginning and end of a speech segment,
|
||||
- padding durations ``pad_onset`` before and ``pad_offset`` after each speech segment.
|
||||
|
||||
Filtering:
|
||||
- ``min_duration_on`` threshold for short speech segment deletion,
|
||||
- ``min_duration_off`` threshold for short silence deletion,
|
||||
- ``filter_speech_first`` to control whether to perform short speech segment deletion first.
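The binarization step described above can be sketched as a small two-state loop. This is a simplified illustration under stated assumptions, not the NeMo implementation: `hysteresis_segments` is a hypothetical name, padding and merging are left out, and the loop starts at frame 0 and closes a trailing segment at the end of the sequence.

```python
def hysteresis_segments(probs, onset=0.4, offset=0.7, shift_len=0.01):
    """Turn frame-level speech probabilities into (start, end) segments in seconds."""
    segments = []
    in_speech = False
    start = 0.0
    for i, p in enumerate(probs):
        if in_speech:
            # While in speech, leave the state once the score drops below `offset`.
            if p < offset:
                segments.append((start, i * shift_len))
                in_speech = False
        elif p > onset:
            # While in non-speech, enter speech once the score rises above `onset`.
            start = i * shift_len
            in_speech = True
    # Close a segment that is still open when the sequence ends.
    if in_speech:
        segments.append((start, len(probs) * shift_len))
    return segments


if __name__ == "__main__":
    frames = [0.1, 0.5, 0.8, 0.6, 0.9, 0.2]
    print(hysteresis_segments(frames, shift_len=1.0))
```

Using separate `onset` and `offset` values lets the two transitions be tuned independently; setting both to the same value reduces this to a single-threshold decision.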
|
||||
|
||||
|
||||
NGC Pretrained Checkpoints
|
||||
|
|
|
@ -15,11 +15,19 @@ diarizer:
|
|||
model_path: null #.nemo local model path or pretrained model name or none
|
||||
window_length_in_sec: 0.15
|
||||
shift_length_in_sec: 0.01
|
||||
threshold: 0.7 # tune threshold on dev set. Check <NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py
|
||||
vad_decision_smoothing: True
|
||||
smoothing_params:
|
||||
method: "median"
|
||||
overlap: 0.875
|
||||
threshold: null # if not null, use this value and ignore the following onset and offset. Warning! it will be removed in release 1.5!
|
||||
postprocessing_params: # Tuned parameters for CH109 (with 11 moved multi-speech sessions as dev set). You might need to choose reasonable thresholds or tune them on your dev set. Check <NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py
|
||||
onset: 0.4
|
||||
offset: 0.7
|
||||
pad_onset: 0.05
|
||||
pad_offset: -0.1
|
||||
min_duration_on: 0.2
|
||||
min_duration_off: 0.2
|
||||
filter_speech_first: True
|
||||
|
||||
speaker_embeddings:
|
||||
oracle_vad_manifest: null
|
||||
|
|
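The `filter_speech_first` flag in the config above changes the outcome when short speech segments and short gaps sit next to each other. The following is a minimal sketch of that ordering effect, assuming the semantics described in the `filtering` docstring of this commit; all function names here are hypothetical and lists are used instead of sets for readability.

```python
def merge(segments):
    """Merge overlapping or touching (start, end) segments."""
    merged = []
    for start, end in sorted(segments):
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return [tuple(s) for s in merged]


def drop_short(segments, min_dur):
    """Remove speech segments shorter than min_dur."""
    return [s for s in segments if s[1] - s[0] >= min_dur]


def fill_short_gaps(segments, min_dur):
    """Relabel gaps shorter than min_dur as speech, then merge."""
    segments = sorted(segments)
    gaps = [(a[1], b[0]) for a, b in zip(segments, segments[1:])]
    short_gaps = [g for g in gaps if g[1] - g[0] < min_dur]
    return merge(segments + short_gaps)


def filter_segments(segments, min_duration_on=0.2, min_duration_off=0.2, filter_speech_first=True):
    if filter_speech_first:
        return fill_short_gaps(drop_short(segments, min_duration_on), min_duration_off)
    return drop_short(fill_short_gaps(segments, min_duration_off), min_duration_on)


if __name__ == "__main__":
    segs = [(0.0, 1.0), (1.1, 1.2), (1.3, 2.0)]
    print(filter_segments(segs, filter_speech_first=True))
    print(filter_segments(segs, filter_speech_first=False))
```

In this example, deleting the short middle segment first leaves a gap that is too long to fill, so two segments remain; filling gaps first joins everything into one segment.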
|
@ -237,9 +237,16 @@ class ClusteringDiarizer(Model, DiarizationMixin):
|
|||
self.vad_pred_dir = smoothing_pred_dir
|
||||
|
||||
logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
|
||||
postprocessing_params = self._cfg.diarizer.vad.postprocessing_params
|
||||
if self._cfg.diarizer.vad.get("threshold", None):
|
||||
logging.info("threshold is not None. Use threshold and update onset=offset=threshold")
|
||||
logging.warning("Support for threshold will be deprecated in release 1.5")
|
||||
postprocessing_params['onset'] = self._cfg.diarizer.vad.threshold
|
||||
postprocessing_params['offset'] = self._cfg.diarizer.vad.threshold
|
||||
|
||||
table_out_dir = generate_vad_segment_table(
|
||||
vad_pred_dir=self.vad_pred_dir,
|
||||
threshold=self._cfg.diarizer.vad.threshold,
|
||||
postprocessing_params=postprocessing_params,
|
||||
shift_len=self._vad_shift_length_in_sec,
|
||||
num_workers=self._cfg.num_workers,
|
||||
)
|
||||
|
|
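The backward-compatibility logic in the hunk above can be sketched in isolation with plain dictionaries. This is an illustration under assumptions: `resolve_postprocessing_params` is a hypothetical helper, and a plain `dict` stands in for the config object used by the diarizer.

```python
import logging


def resolve_postprocessing_params(vad_cfg):
    """Return postprocessing params, letting a legacy `threshold` override onset/offset."""
    # Copy so the caller's config is left untouched.
    params = dict(vad_cfg.get("postprocessing_params", {}))
    threshold = vad_cfg.get("threshold")
    if threshold is not None:
        logging.warning("threshold is set; using it for both onset and offset")
        params["onset"] = threshold
        params["offset"] = threshold
    return params


if __name__ == "__main__":
    cfg = {"threshold": 0.7, "postprocessing_params": {"onset": 0.4, "offset": 0.7, "pad_onset": 0.05}}
    print(resolve_postprocessing_params(cfg))
```

One design difference worth noting: the hunk tests the threshold for truthiness, so a value of `0` would be treated like `null`, whereas this sketch checks `is not None` explicitly.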
|
@ -14,6 +14,7 @@
|
|||
import glob
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from itertools import repeat
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
@ -23,12 +24,13 @@ import numpy as np
|
|||
import pandas as pd
|
||||
from pyannote.core import Annotation, Segment
|
||||
from pyannote.metrics import detection
|
||||
from sklearn.model_selection import ParameterGrid
|
||||
|
||||
from nemo.utils import logging
|
||||
|
||||
|
||||
"""
|
||||
This file contains all the utility functions required for speaker embeddings part in diarization scripts
|
||||
This file contains all the utility functions required for voice activity detection.
|
||||
"""
|
||||
|
||||
|
||||
|
@ -279,12 +281,13 @@ def generate_overlap_vad_seq_per_file(frame_filepath, per_args):
|
|||
|
||||
round_final = np.round(preds, 4)
|
||||
np.savetxt(overlap_filepath, round_final, delimiter='\n')
|
||||
return overlap_filepath
|
||||
|
||||
except Exception as e:
|
||||
raise (e)
|
||||
|
||||
|
||||
def generate_vad_segment_table(vad_pred_dir, threshold, shift_len, num_workers):
|
||||
def generate_vad_segment_table(vad_pred_dir, postprocessing_params, shift_len, num_workers, threshold=None):
|
||||
"""
|
||||
Convert frame level prediction to speech segment in start and end times format.
|
||||
And save to csv file in rttm-like format
|
||||
|
@ -292,7 +295,7 @@ def generate_vad_segment_table(vad_pred_dir, threshold, shift_len, num_workers):
|
|||
17,18, speech
|
||||
Args:
|
||||
vad_pred_dir (str): directory of prediction files to be processed.
|
||||
threshold (float): threshold for prediction score (from 0 to 1).
|
||||
postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering.
|
||||
shift_len (float): amount of shift of window for generating the frame.
|
||||
out_dir (str): output dir of generated table/csv file.
|
||||
num_workers(float): number of process for multiprocessing
|
||||
|
@ -304,15 +307,18 @@ def generate_vad_segment_table(vad_pred_dir, threshold, shift_len, num_workers):
|
|||
suffixes = ("frame", "mean", "median")
|
||||
vad_pred_filepath_list = [os.path.join(vad_pred_dir, x) for x in os.listdir(vad_pred_dir) if x.endswith(suffixes)]
|
||||
|
||||
table_out_dir = os.path.join(vad_pred_dir, "table_output_" + str(threshold))
|
||||
table_out_dir_name = "table_output_tmp_"
|
||||
for key in postprocessing_params:
|
||||
table_out_dir_name = table_out_dir_name + str(key) + str(postprocessing_params[key]) + "_"
|
||||
|
||||
table_out_dir = os.path.join(vad_pred_dir, table_out_dir_name)
|
||||
|
||||
if not os.path.exists(table_out_dir):
|
||||
os.mkdir(table_out_dir)
|
||||
|
||||
per_args = {
|
||||
"threshold": threshold,
|
||||
"shift_len": shift_len,
|
||||
"out_dir": table_out_dir,
|
||||
}
|
||||
per_args = {"shift_len": shift_len, "out_dir": table_out_dir}
|
||||
|
||||
per_args = {**per_args, **postprocessing_params}
|
||||
|
||||
p.starmap(generate_vad_segment_table_per_file, zip(vad_pred_filepath_list, repeat(per_args)))
|
||||
p.close()
|
||||
|
@ -325,41 +331,199 @@ def generate_vad_segment_table_per_file(pred_filepath, per_args):
|
|||
"""
|
||||
See description in generate_overlap_vad_seq.
|
||||
"""
|
||||
threshold = per_args['threshold']
|
||||
shift_len = per_args['shift_len']
|
||||
out_dir = per_args['out_dir']
|
||||
|
||||
name = pred_filepath.split("/")[-1].rsplit(".", 1)[0]
|
||||
|
||||
sequence = np.loadtxt(pred_filepath)
|
||||
start = 0
|
||||
start_list = [0]
|
||||
dur_list = []
|
||||
state_list = []
|
||||
current_state = "non-speech"
|
||||
for i in range(len(sequence) - 1):
|
||||
current_state = "non-speech" if sequence[i] <= threshold else "speech"
|
||||
next_state = "non-speech" if sequence[i + 1] <= threshold else "speech"
|
||||
if next_state != current_state:
|
||||
dur = i * shift_len + shift_len - start # shift_len for handling joint
|
||||
state_list.append(current_state)
|
||||
dur_list.append(dur)
|
||||
|
||||
start = (i + 1) * shift_len
|
||||
start_list.append(start)
|
||||
speech_segments = binarization(sequence, per_args)
|
||||
speech_segments = filtering(speech_segments, per_args)
|
||||
|
||||
dur_list.append((i + 1) * shift_len + shift_len - start)
|
||||
state_list.append(current_state)
|
||||
|
||||
seg_table = pd.DataFrame({'start': start_list, 'dur': dur_list, 'vad': state_list})
|
||||
seg_speech_table = seg_table[seg_table['vad'] == 'speech']
|
||||
seg_speech_table = pd.DataFrame(speech_segments, columns=['start', 'end'])
|
||||
seg_speech_table = seg_speech_table.sort_values('start', ascending=True)
|
||||
seg_speech_table['dur'] = seg_speech_table['end'] - seg_speech_table['start'] + shift_len
|
||||
seg_speech_table['vad'] = 'speech'
|
||||
|
||||
save_name = name + ".txt"
|
||||
save_path = os.path.join(out_dir, save_name)
|
||||
seg_speech_table.to_csv(save_path, sep='\t', index=False, header=False)
|
||||
seg_speech_table.to_csv(save_path, columns=['start', 'dur', 'vad'], sep='\t', index=False, header=False)
|
||||
return save_path
|
||||
|
||||
|
||||
def binarization(sequence, per_args):
|
||||
"""
|
||||
Binarize predictions to speech and non-speech
|
||||
|
||||
Reference
|
||||
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
|
||||
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
|
||||
|
||||
Args:
|
||||
sequence (list) : A list of frame level predictions.
|
||||
per_args:
|
||||
onset (float): onset threshold for detecting the beginning of a speech segment.
|
||||
offset (float): offset threshold for detecting the end of a speech.
|
||||
pad_onset (float): duration added before each speech segment.
|
||||
pad_offset (float): duration added after each speech segment.
|
||||
shift_len (float): amount of shift of window for generating the frame.
|
||||
|
||||
Returns:
|
||||
speech_segments(set): Set of speech segment in (start, end) format.
|
||||
"""
|
||||
shift_len = per_args.get('shift_len', 0.01)
|
||||
|
||||
onset = per_args.get('onset', 0.5)
|
||||
offset = per_args.get('offset', 0.5)
|
||||
pad_onset = per_args.get('pad_onset', 0) # 0.0
|
||||
pad_offset = per_args.get('pad_offset', 0)
|
||||
|
||||
onset, offset = cal_vad_onset_offset(per_args.get('scale', 'absolute'), onset, offset, sequence)
|
||||
|
||||
speech = False
|
||||
speech_segments = set() # {(start1, end1), (start2, end2)}
|
||||
for i in range(1, len(sequence)):
|
||||
# Current frame is speech
|
||||
if speech:
|
||||
# Switch from speech to non-speech
|
||||
if sequence[i] < offset:
|
||||
if start - pad_onset >= 0:
|
||||
speech_segments.add((start - pad_onset, i * shift_len + pad_offset))
|
||||
else:
|
||||
speech_segments.add((0, i * shift_len + pad_offset))
|
||||
start = i * shift_len
|
||||
speech = False
|
||||
|
||||
# Current frame is non-speech
|
||||
else:
|
||||
# Switch from non-speech to speech
|
||||
if sequence[i] > onset:
|
||||
start = i * shift_len
|
||||
speech = True
|
||||
|
||||
# if it's speech at the end, add final segment
|
||||
if speech:
|
||||
speech_segments.add((max(0, start - pad_onset), i * shift_len + pad_offset))  # clamp start at 0, as in the loop above
|
||||
|
||||
# Merge the overlapped speech segments due to padding
|
||||
speech_segments = merge_overlap_segment(speech_segments) # not sorted
|
||||
|
||||
return speech_segments
|
||||
|
||||
|
||||
def filtering(speech_segments, per_args):
|
||||
"""
|
||||
Filter out short non_speech and speech segments.
|
||||
|
||||
Reference
|
||||
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
|
||||
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
|
||||
Args:
|
||||
speech_segments (set): A set of speech segment in (start, end) format. {(start1, end1), (start2, end2)}
|
||||
per_args:
|
||||
min_duration_on (float): threshold for short speech segment deletion
|
||||
min_duration_off (float): threshold for short non_speech segment deletion
|
||||
filter_speech_first (boolean): Whether to perform short speech segment deletion first.
|
||||
|
||||
Returns:
|
||||
speech_segments(set): Filtered set of speech segment in (start, end) format. {(start1, end1), (start2, end2)}
|
||||
"""
|
||||
min_duration_on = per_args.get('min_duration_on', 0.0)
|
||||
min_duration_off = per_args.get('min_duration_off', 0.0)
|
||||
filter_speech_first = per_args.get('filter_speech_first', True)
|
||||
|
||||
if filter_speech_first:
|
||||
# Filter out the shorter speech segments
|
||||
if min_duration_on > 0.0:
|
||||
speech_segments = filter_short_segments(speech_segments, min_duration_on)
|
||||
# Filter out the shorter non-speech segments and return to be as speech segments
|
||||
if min_duration_off > 0.0:
|
||||
# Find non-speech segments
|
||||
non_speech_segments = get_gap_segments(speech_segments)
|
||||
# Find shorter non-speech segments
|
||||
short_non_speech_segments = non_speech_segments - filter_short_segments(
|
||||
non_speech_segments, min_duration_off
|
||||
)
|
||||
# Return shorter non-speech segments to be as speech segments
|
||||
speech_segments.update(short_non_speech_segments)
|
||||
# Merge the overlapped speech segments
|
||||
speech_segments = merge_overlap_segment(speech_segments)
|
||||
else:
|
||||
if min_duration_off > 0.0:
|
||||
# Find non-speech segments
|
||||
non_speech_segments = get_gap_segments(speech_segments)
|
||||
# Find shorter non-speech segments
|
||||
short_non_speech_segments = non_speech_segments - filter_short_segments(
|
||||
non_speech_segments, min_duration_off
|
||||
)
|
||||
# Return shorter non-speech segments to be as speech segments
|
||||
speech_segments.update(short_non_speech_segments)
|
||||
# Merge the overlapped speech segments
|
||||
speech_segments = merge_overlap_segment(speech_segments)
|
||||
if min_duration_on > 0.0:
|
||||
speech_segments = filter_short_segments(speech_segments, min_duration_on)
|
||||
return speech_segments
|
||||
|
||||
|
||||
def filter_short_segments(segments, threshold):
|
||||
"""
|
||||
Remove segments whose duration is shorter than a threshold.
|
||||
"""
|
||||
res = set()
|
||||
for seg in segments:
|
||||
if seg[1] - seg[0] >= threshold:
|
||||
res.add(seg)
|
||||
return res
|
||||
|
||||
|
||||
def get_gap_segments(segments):
|
||||
"""
|
||||
Get the gap segments. {(start1, end1), (start2, end2), (start3, end3)} -> {(end1, start2), (end2, start3)}
|
||||
"""
|
||||
segments = [list(i) for i in segments]
|
||||
segments.sort(key=lambda x: x[0])
|
||||
gap_segments = set()
|
||||
for i in range(len(segments) - 1):
|
||||
gap_segments.add((segments[i][1], segments[i + 1][0]))
|
||||
return gap_segments
|
||||
|
||||
|
||||
def merge_overlap_segment(segments):
|
||||
"""
|
||||
Merge the overlapping segments {(0, 1.5), (1, 3.5), } -> {(0, 3.5), }
|
||||
"""
|
||||
segments = [list(i) for i in segments]
|
||||
segments.sort(key=lambda x: x[0])
|
||||
merged = []
|
||||
for segment in segments:
|
||||
if not merged or merged[-1][1] < segment[0]:
|
||||
merged.append(segment)
|
||||
else:
|
||||
merged[-1][1] = max(merged[-1][1], segment[1])
|
||||
|
||||
merged_set = set([tuple(t) for t in merged])
|
||||
return merged_set
|
||||
|
||||
|
||||
def cal_vad_onset_offset(scale, onset, offset, sequence=None):
|
||||
"""
|
||||
Calculate onset and offset threshold given different scale.
|
||||
"""
|
||||
if scale == "absolute":
|
||||
mini = 0
|
||||
maxi = 1
|
||||
elif scale == "relative":
|
||||
mini = np.nanmin(sequence)
|
||||
maxi = np.nanmax(sequence)
|
||||
elif scale == "percentile":
|
||||
mini = np.nanpercentile(sequence, 1)
|
||||
maxi = np.nanpercentile(sequence, 99)
|
||||
|
||||
onset = mini + onset * (maxi - mini)
|
||||
offset = mini + offset * (maxi - mini)
|
||||
return onset, offset
|
||||
|
||||
|
||||
def vad_construct_pyannote_object_per_file(vad_table_filepath, groundtruth_RTTM_file):
|
||||
"""
|
||||
Construct pyannote object for evaluation.
|
||||
|
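The scale handling in `cal_vad_onset_offset` above maps the configured `onset` and `offset` onto a score range before binarization. Below is a minimal standard-library sketch of the "absolute" and "relative" cases only; the "percentile" case is omitted, `scale_thresholds` is a hypothetical name, and raising on an unknown scale is an addition of this sketch (the hunk has no such branch).

```python
import math


def scale_thresholds(scale, onset, offset, sequence=None):
    """Map onset/offset in [0, 1] onto the chosen score range."""
    if scale == "absolute":
        lo, hi = 0.0, 1.0
    elif scale == "relative":
        # Ignore NaN scores, mirroring the nan-aware min/max in the hunk.
        finite = [x for x in sequence if not math.isnan(x)]
        lo, hi = min(finite), max(finite)
    else:
        # Assumption of this sketch: fail loudly instead of leaving lo/hi undefined.
        raise ValueError(f"unsupported scale: {scale!r}")
    return lo + onset * (hi - lo), lo + offset * (hi - lo)


if __name__ == "__main__":
    print(scale_thresholds("absolute", 0.4, 0.7))
    print(scale_thresholds("relative", 0.5, 0.5, [0.2, 0.6, float("nan")]))
```

With the "relative" scale, a threshold of 0.5 means halfway between the lowest and highest score seen in that file rather than a probability of 0.5.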
@ -387,11 +551,31 @@ def vad_construct_pyannote_object_per_file(vad_table_filepath, groundtruth_RTTM_
|
|||
return reference, hypothesis
|
||||
|
||||
|
||||
def vad_tune_threshold_on_dev(thresholds, vad_pred, groundtruth_RTTM, vad_pred_method="frame", focus_metric="DetER"):
|
||||
def get_parameter_grid(params):
|
||||
"""
|
||||
Tune threshold on dev set. Return best threshold which gives the lowest detection error rate (DetER) in thresholds.
|
||||
Get the parameter grid given a dictionary of parameters.
|
||||
"""
|
||||
has_filter_speech_first = False
|
||||
if 'filter_speech_first' in params:
|
||||
filter_speech_first = params['filter_speech_first']
|
||||
has_filter_speech_first = True
|
||||
params.pop("filter_speech_first")
|
||||
|
||||
params_grid = list(ParameterGrid(params))
|
||||
|
||||
if has_filter_speech_first:
|
||||
for i in params_grid:
|
||||
i['filter_speech_first'] = filter_speech_first
|
||||
return params_grid
|
||||
|
||||
|
||||
def vad_tune_threshold_on_dev(
|
||||
params, vad_pred, groundtruth_RTTM, result_file="res", vad_pred_method="frame", focus_metric="DetER"
|
||||
):
|
||||
"""
|
||||
Tune thresholds on dev set. Return the best combination of thresholds, i.e. the one that gives the lowest detection error rate (DetER) among the given parameters.
|
||||
Args:
|
||||
thresholds (list): list of thresholds.
|
||||
params (dict): dictionary of parameter to be tuned on.
|
||||
vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median".
|
||||
vad_pred_dir (str): directory of vad predictions or a file containing the paths of them
|
||||
groundtruth_RTTM_dir (str): directory of groundtruth rttm files or a file containing the paths of them.
|
||||
|
@ -399,61 +583,82 @@ def vad_tune_threshold_on_dev(thresholds, vad_pred, groundtruth_RTTM, vad_pred_m
|
|||
Returns:
|
||||
best_threhsold (float): threshold that gives lowest DetER.
|
||||
"""
|
||||
threshold_perf = {}
|
||||
best_threhsold = thresholds[0]
|
||||
min_score = 100
|
||||
|
||||
all_perf = {}
|
||||
try:
|
||||
thresholds[0] >= 0 and thresholds[-1] <= 1
|
||||
check_if_param_valid(params)
|
||||
except:
|
||||
raise ValueError("Invalid threshold! Should be in [0, 1]")
|
||||
raise ValueError("Please check if the parameters are valid")
|
||||
|
||||
for threshold in thresholds:
|
||||
metric = detection.DetectionErrorRate()
|
||||
paired_filenames, groundtruth_RTTM_dict, vad_pred_dict = pred_rttm_map(
|
||||
vad_pred, groundtruth_RTTM, vad_pred_method
|
||||
)
|
||||
print(paired_filenames)
|
||||
paired_filenames, groundtruth_RTTM_dict, vad_pred_dict = pred_rttm_map(vad_pred, groundtruth_RTTM, vad_pred_method)
|
||||
metric = detection.DetectionErrorRate()
|
||||
params_grid = get_parameter_grid(params)
|
||||
|
||||
for param in params_grid:
|
||||
# perform binarization and filtering according to param and write to rttm-like table
|
||||
vad_table_dir = generate_vad_segment_table(vad_pred, param, shift_len=0.01, num_workers=20)
|
||||
|
||||
# add reference and hypothesis to metrics
|
||||
for filename in paired_filenames:
|
||||
vad_pred_filepath = vad_pred_dict[filename]
|
||||
groundtruth_RTTM_file = groundtruth_RTTM_dict[filename]
|
||||
|
||||
if os.path.isdir(vad_pred):
|
||||
table_out_dir = os.path.join(vad_pred, "table_output_" + str(threshold))
|
||||
else:
|
||||
table_out_dir = os.path.join("tmp_table_outputs", "table_output_" + str(threshold))
|
||||
|
||||
if not os.path.exists(table_out_dir):
|
||||
os.makedirs(table_out_dir)
|
||||
|
||||
per_args = {"threshold": threshold, "shift_len": 0.01, "out_dir": table_out_dir}
|
||||
vad_table_filepath = generate_vad_segment_table_per_file(vad_pred_filepath, per_args)
|
||||
|
||||
vad_table_filepath = os.path.join(vad_table_dir, filename + ".txt")
|
||||
reference, hypothesis = vad_construct_pyannote_object_per_file(vad_table_filepath, groundtruth_RTTM_file)
|
||||
metric(reference, hypothesis) # accumulation
|
||||
|
||||
# delete tmp table files
|
||||
shutil.rmtree(vad_table_dir, ignore_errors=True)
|
||||
|
||||
report = metric.report(display=False)
|
||||
DetER = report.iloc[[-1]][('detection error rate', '%')].item()
|
||||
FA = report.iloc[[-1]][('false alarm', '%')].item()
|
||||
MISS = report.iloc[[-1]][('miss', '%')].item()
|
||||
|
||||
if focus_metric == "DetER":
|
||||
score = DetER
|
||||
elif focus_metric == "FA":
|
||||
score = FA
|
||||
elif focus_metric == "MISS":
|
||||
score = MISS
|
||||
else:
|
||||
raise ValueError("Metric we care most should be only in 'DetER', 'FA'or 'MISS'!")
|
||||
assert (
|
||||
focus_metric == "DetER" or focus_metric == "FA" or focus_metric == "MISS"
|
||||
), "Metric we care most should be only in 'DetER', 'FA' or 'MISS'!"
|
||||
all_perf[str(param)] = {'DetER (%)': DetER, 'FA (%)': FA, 'MISS (%)': MISS}
|
||||
logging.info(f"parameter {param}, {all_perf[str(param)] }")
|
||||
|
||||
score = all_perf[str(param)][focus_metric + ' (%)']
|
||||
|
||||
threshold_perf[threshold] = {'DetER (%)': DetER, 'FA (%)': FA, 'MISS (%)': MISS}
|
||||
logging.info(f"threshold {threshold}, {threshold_perf[threshold]}")
|
||||
del report
|
||||
metric.reset() # reset internal accumulator
|
||||
|
||||
# save results for analysis
|
||||
with open(result_file + ".txt", "a") as fp:
|
||||
fp.write(f"{param}, {all_perf[str(param)] }\n")
|
||||
|
||||
if score < min_score:
|
||||
best_threhsold = param
|
||||
optimal_scores = all_perf[str(param)]
|
||||
min_score = score
|
||||
best_threhsold = threshold
|
||||
return best_threhsold
|
||||
|
||||
return best_threhsold, optimal_scores
|
||||
|
||||
|
||||
def check_if_param_valid(params):
|
||||
"""
|
||||
Check if the parameters are valid.
|
||||
"""
|
||||
for i in params:
|
||||
if i == "filter_speech_first":
|
||||
if not type(params["filter_speech_first"]) == bool:
|
||||
raise ValueError("Invalid inputs! filter_speech_first should be either True or False!")
|
||||
elif i == "pad_onset":
|
||||
continue
|
||||
elif i == "pad_offset":
|
||||
continue
|
||||
else:
|
||||
for j in params[i]:
|
||||
if not j >= 0:
|
||||
raise ValueError(
|
||||
"Invalid inputs! All float parameters except pad_onset and pad_offset should be non-negative!"
|
||||
)
|
||||
|
||||
if not (all(i <= 1 for i in params['onset']) and all(i <= 1 for i in params['offset'])):
|
||||
raise ValueError("Invalid inputs! The onset and offset thresholds should be in range [0, 1]!")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def pred_rttm_map(vad_pred, groundtruth_RTTM, vad_pred_method="frame"):
|
||||
|
@ -492,7 +697,15 @@ def pred_rttm_map(vad_pred, groundtruth_RTTM, vad_pred_method="frame"):
|
|||
return paired_filenames, groundtruth_RTTM_dict, vad_pred_dict
|
||||
|
||||
|
||||
def plot(path2audio_file, path2_vad_pred, path2ground_truth_label=None, threshold=0.85):
|
||||
def plot(
|
||||
path2audio_file,
|
||||
path2_vad_pred,
|
||||
path2ground_truth_label=None,
|
||||
offset=0,
|
||||
duration=None,
|
||||
threshold=None,
|
||||
per_args=None,
|
||||
):
|
||||
"""
|
||||
Plot VAD outputs for demonstration in tutorial
|
||||
Args:
|
||||
|
@ -503,10 +716,14 @@ def plot(path2audio_file, path2_vad_pred, path2ground_truth_label=None, threshol
|
|||
"""
|
||||
plt.figure(figsize=[20, 2])
|
||||
FRAME_LEN = 0.01
|
||||
audio, sample_rate = librosa.load(path=path2audio_file, sr=16000, mono=True)
|
||||
|
||||
audio, sample_rate = librosa.load(path=path2audio_file, sr=16000, mono=True, offset=offset, duration=duration)
|
||||
dur = librosa.get_duration(audio, sr=sample_rate)
|
||||
time = np.arange(0, dur, FRAME_LEN)
|
||||
|
||||
time = np.arange(offset, offset + dur, FRAME_LEN)
|
||||
frame = np.loadtxt(path2_vad_pred)
|
||||
frame = frame[int(offset / FRAME_LEN) : int((offset + dur) / FRAME_LEN)]
|
||||
|
||||
len_pred = len(frame)
|
||||
ax1 = plt.subplot()
|
||||
ax1.plot(np.arange(audio.size) / sample_rate, audio, 'gray')
|
||||
|
@ -515,11 +732,24 @@ def plot(path2audio_file, path2_vad_pred, path2ground_truth_label=None, threshol
|
|||
ax1.set_ylabel('Signal')
|
||||
ax1.set_ylim([-1, 1])
|
||||
ax2 = ax1.twinx()
|
||||
|
||||
prob = frame
|
||||
pred = np.where(prob >= threshold, 1, 0)
|
||||
if threshold and per_args:
|
||||
raise ValueError("threshold and per_args cannot be used at same time!")
|
||||
if not threshold and not per_args:
|
||||
raise ValueError("One and only one of threshold and per_args must have been used!")
|
||||
|
||||
if threshold:
|
||||
pred = np.where(prob >= threshold, 1, 0)
|
||||
if per_args:
|
||||
speech_segments = binarization(prob, per_args)
|
||||
speech_segments = filtering(speech_segments, per_args)
|
||||
pred = gen_pred_from_speech_segments(speech_segments, prob)
|
||||
|
||||
if path2ground_truth_label:
|
||||
label = extract_labels(path2ground_truth_label, time)
|
||||
ax2.plot(np.arange(len_pred) * FRAME_LEN, label, 'r', label='label')
|
||||
|
||||
ax2.plot(np.arange(len_pred) * FRAME_LEN, pred, 'b', label='pred')
|
||||
ax2.plot(np.arange(len_pred) * FRAME_LEN, prob, 'g--', label='speech prob')
|
||||
ax2.tick_params(axis='y', labelcolor='r')
|
||||
|
@ -529,6 +759,18 @@ def plot(path2audio_file, path2_vad_pred, path2ground_truth_label=None, threshol
|
|||
return None
|
||||
|
||||
|
||||
def gen_pred_from_speech_segments(speech_segments, prob, shift_len=0.01):
|
||||
pred = np.zeros(prob.shape)
|
||||
speech_segments = [list(i) for i in speech_segments]
|
||||
speech_segments.sort(key=lambda x: x[0])
|
||||
|
||||
for seg in speech_segments:
|
||||
start = int(seg[0] / shift_len)
|
||||
end = int(seg[1] / shift_len)
|
||||
pred[start:end] = 1
|
||||
return pred
|
||||
|
||||
|
||||
def extract_labels(path2ground_truth_label, time):
|
||||
"""
|
||||
Extract groundtruth label for given time period.
|
||||
|
|
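The grid construction in `get_parameter_grid` above can be approximated without scikit-learn. This sketch swaps in `itertools.product` for `ParameterGrid` purely for illustration; `parameter_grid` is a hypothetical name, and unlike the hunk it copies the input dictionary instead of popping the key from the caller's object.

```python
from itertools import product


def parameter_grid(params):
    """Expand a dict of value lists into a list of dicts, keeping filter_speech_first fixed."""
    params = dict(params)  # work on a copy so the caller's dict is not mutated
    fixed = {}
    if "filter_speech_first" in params:
        # A single boolean, not a list, so it is attached to every combination.
        fixed["filter_speech_first"] = params.pop("filter_speech_first")
    keys = sorted(params)
    grid = []
    for values in product(*(params[k] for k in keys)):
        combo = dict(zip(keys, values))
        combo.update(fixed)
        grid.append(combo)
    return grid


if __name__ == "__main__":
    search = {"onset": [0.4, 0.5], "offset": [0.6, 0.7], "filter_speech_first": False}
    for combo in parameter_grid(search):
        print(combo)
```

The grid size is the product of the list lengths, so wide ranges on several parameters quickly multiply the number of tuning runs.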
|
@ -19,11 +19,45 @@ import numpy as np
|
|||
from nemo.collections.asr.parts.utils.vad_utils import vad_tune_threshold_on_dev
|
||||
from nemo.utils import logging
|
||||
|
||||
"""
|
||||
This script is designed for tuning the thresholds used in VAD postprocessing.
|
||||
See details about it in nemo/collections/asr/parts/utils/vad_utils/binarization and filtering
|
||||
|
||||
Usage:
|
||||
python vad_tune_threshold.py \
|
||||
--onset_range="0,1,0.2" --offset_range="0,1,0.2" --min_duration_on_range="0.1,0.8,0.05" --min_duration_off_range="0.1,0.8,0.05" --not_filter_speech_first \
|
||||
--vad_pred=<FULL PATH OF FOLDER OF FRAME LEVEL PREDICTION FILES> \
|
||||
--groundtruth_RTTM=<DIRECTORY OF GROUNDTRUTH RTTM FILES OR A FILE CONTAINING THE PATHS OF THEM> \
|
||||
--vad_pred_method="median"
|
||||
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--onset_range", help="range of onset in list 'START,END,STEP' to be tuned on", type=str)
|
||||
parser.add_argument("--offset_range", help="range of offset in list 'START,END,STEP' to be tuned on", type=str)
|
||||
parser.add_argument(
|
||||
"--threshold_range", help="range of threshold in list 'START,END,STEP' to be tuned on", required=True
|
||||
"--pad_onset_range",
|
||||
help="range of pad_onset in list 'START,END,STEP' to be tuned on. pad_onset could be negative float",
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pad_offset_range",
|
||||
help="range of pad_offset in list 'START,END,STEP' to be tuned on. pad_offset could be negative float",
|
||||
type=str,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min_duration_on_range", help="range of min_duration_on in list 'START,END,STEP' to be tuned on", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_duration_off_range", help="range of min_duration_off in list 'START,END,STEP' to be tuned on", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--not_filter_speech_first",
|
||||
help="If set, do not filter short speech segments first during filtering (sets filter_speech_first to False).",
|
||||
action='store_true',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vad_pred", help="Directory of vad predictions or a file contains the paths of them.", required=True
|
||||
)
|
||||
|
@ -33,6 +67,9 @@ if __name__ == "__main__":
|
|||
type=str,
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result_file", help="Filename of txt to store results", default="res",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vad_pred_method",
|
||||
help="suffix of prediction file. Should be either in 'frame', 'mean' or 'median'",
|
||||
|
@ -46,13 +83,50 @@ if __name__ == "__main__":
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
params = {}
|
||||
try:
|
||||
start, stop, step = [float(i) for i in args.threshold_range.split(",")]
|
||||
thresholds = np.arange(start, stop, step)
|
||||
except:
|
||||
raise ValueError("Theshold input is invalid! Please enter it as a 'START,STOP,STEP' ")
|
||||
# if not input range for values of parameters, use default value defined in function binarization and filtering in nemo/collections/asr/parts/utils/vad_utils.py
|
||||
if args.onset_range:
|
||||
start, stop, step = [float(i) for i in args.onset_range.split(",")]
|
||||
onset = np.arange(start, stop, step)
|
||||
params['onset'] = onset
|
||||
|
||||
best_threhsold = vad_tune_threshold_on_dev(
|
||||
thresholds, args.vad_pred, args.groundtruth_RTTM, args.vad_pred_method, args.focus_metric
|
||||
if args.offset_range:
|
||||
start, stop, step = [float(i) for i in args.offset_range.split(",")]
|
||||
offset = np.arange(start, stop, step)
|
||||
params['offset'] = offset
|
||||
|
||||
if args.pad_onset_range:
|
||||
start, stop, step = [float(i) for i in args.pad_onset_range.split(",")]
|
||||
pad_onset = np.arange(start, stop, step)
|
||||
params['pad_onset'] = pad_onset
|
||||
|
||||
if args.pad_offset_range:
|
||||
start, stop, step = [float(i) for i in args.pad_offset_range.split(",")]
|
||||
pad_offset = np.arange(start, stop, step)
|
||||
params['pad_offset'] = pad_offset
|
||||
|
||||
if args.min_duration_on_range:
|
||||
start, stop, step = [float(i) for i in args.min_duration_on_range.split(",")]
|
||||
min_duration_on = np.arange(start, stop, step)
|
||||
params['min_duration_on'] = min_duration_on
|
||||
|
||||
if args.min_duration_off_range:
|
||||
start, stop, step = [float(i) for i in args.min_duration_off_range.split(",")]
|
||||
min_duration_off = np.arange(start, stop, step)
|
||||
params['min_duration_off'] = min_duration_off
|
||||
|
||||
if args.not_filter_speech_first:
|
||||
params['filter_speech_first'] = False
|
||||
|
||||
except:
|
||||
raise ValueError(
|
||||
"Threshold input is invalid! Please enter each range as 'START,STOP,STEP' for onset, offset, min_duration_on and min_duration_off, and enter True/False for filter_speech_first"
|
||||
)
|
||||
|
||||
best_threhsold, optimal_scores = vad_tune_threshold_on_dev(
|
||||
params, args.vad_pred, args.groundtruth_RTTM, args.result_file, args.vad_pred_method, args.focus_metric
|
||||
)
|
||||
logging.info(
|
||||
f"Best combination of thresholds for binarization selected from input ranges is {best_threhsold}, and the optimal score is {optimal_scores}"
|
||||
)
|
||||
logging.info(f"Best threshold selected from {thresholds} is {best_threhsold}!")
|
||||
|
|
|
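The range handling in the script above repeats the same 'START,STOP,STEP' parsing for every parameter. A minimal sketch of how that pattern could be factored out is shown below. The names `parse_range`, `build_param_grid` and `iter_combinations` are hypothetical and not part of NeMo. The sketch uses plain Python instead of `np.arange` and adds a small rounding guard that `np.arange` does not have. The diff does not show how `vad_tune_threshold_on_dev` searches the ranges, so the Cartesian product here is only an assumption about how combinations might be enumerated.

```python
import math
from itertools import product


def parse_range(spec):
    """Parse 'START,STOP,STEP' into a list of floats; STOP is excluded."""
    parts = spec.split(",")
    if len(parts) != 3:
        raise ValueError(f"Expected 'START,STOP,STEP', got {spec!r}")
    start, stop, step = (float(p) for p in parts)
    if step == 0:
        raise ValueError("STEP must be non-zero")
    # Rounding guard against float noise such as 6.000000000000001.
    count = max(0, math.ceil(round((stop - start) / step, 9)))
    return [round(start + i * step, 9) for i in range(count)]


def build_param_grid(range_specs):
    """Keep only the parameters for which a range string was supplied."""
    return {name: parse_range(spec) for name, spec in range_specs.items() if spec}


def iter_combinations(grid):
    """Yield one dict per point of the Cartesian product of all ranges."""
    names = list(grid)
    for values in product(*(grid[name] for name in names)):
        yield dict(zip(names, values))


# Hypothetical inputs: pad_onset has no range, so it is left out of the grid
# and the downstream default would apply.
grid = build_param_grid({"onset": "0.5,1.0,0.25", "offset": "0.25,0.75,0.25", "pad_onset": None})
combos = list(iter_combinations(grid))
```

Parameters without a range never enter the grid, which matches the script's behaviour of falling back to the defaults in `vad_utils.py` when an argument is omitted.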
@ -43,7 +43,7 @@
|
|||
"This notebook demonstrates how to perform\n",
|
||||
"1. [offline streaming inference on audio files (offline VAD)](#Offline-streaming-inference);\n",
|
||||
"2. [finetuning](#Finetune) and use [posterior](#Posterior);\n",
|
||||
"2. [threshold tuning](#Tuning-threshold);\n",
|
||||
"2. [VAD postprocessing and threshold tuning](#VAD-postprocessing-and-Tuning-threshold);\n",
|
||||
"4. [online streaming inference](#Online-streaming-inference);\n",
|
||||
"3. [online streaming inference from a microphone's stream](#Online-streaming-inference-through-microphone).\n"
|
||||
]
|
||||
|
@ -236,13 +236,23 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tuning threshold"
|
||||
"## VAD postprocessing and Tuning threshold"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can use a single **threshold** to binarize predictions or use typical VAD postprocessing, including\n",
|
||||
"\n",
|
||||
"### Binarization:\n",
|
||||
"1. **onset** and **offset** thresholds for detecting the beginning and end of a speech segment;\n",
|
||||
"2. padding durations before (**pad_onset**) and after (**pad_offset**) each speech segment.\n",
|
||||
"\n",
|
||||
"### Filtering:\n",
|
||||
"1. threshold for short speech segment deletion (**min_duration_on**);\n",
|
||||
"2. threshold for small silence deletion (**min_duration_off**);\n",
|
||||
"3. whether to perform short speech segment deletion first (**filter_speech_first**).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Of course you can do threshold tuning on frame level prediction. We also provide a script \n",
|
||||
|
@ -250,7 +260,7 @@
|
|||
"<NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"to help you find best threshold if you have ground truth label file in RTTM format. "
|
||||
"to help you find the best thresholds if you have a ground-truth label file in RTTM format. "
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -354,6 +364,7 @@
|
|||
"# 1) use reset() method to reset FrameVAD's state\n",
|
||||
"# 2) call transcribe(frame) to do VAD on\n",
|
||||
"# contiguous signal's frames\n",
|
||||
"# To simplify the flow, we use single threshold to binarize predictions.\n",
|
||||
"class FrameVAD:\n",
|
||||
" \n",
|
||||
" def __init__(self, model_definition,\n",
|
||||
|
@ -579,9 +590,9 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The above prediction is based on `threshold=0.4`.\n",
|
||||
"To simplify the flow, the above prediction is based on a single threshold, `threshold=0.4`.\n",
|
||||
"\n",
|
||||
"You can play with other [threshold](#Tuning-threshold) and see how they would impact performance. \n",
|
||||
"You can try other [thresholds](#VAD-postprocessing-and-Tuning-threshold) or use postprocessing and see how they affect performance. \n",
|
||||
"\n",
|
||||
"**Note** if you want better performance, [finetune](#Finetune) on your data and use posteriors such as [overlapped prediction](#Posterior). \n",
|
||||
"\n",
|
||||
|
|
|
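The binarization and filtering parameters listed in the notebook text above (onset, offset, pad_onset, pad_offset, min_duration_on, min_duration_off, filter_speech_first) can be illustrated with a small standalone sketch. This is not the implementation in `nemo/collections/asr/parts/utils/vad_utils.py`; the function names, the `>=`/`<` comparisons and the merging of overlapping padded segments are assumptions made for illustration only.

```python
def merge_overlaps(segments):
    """Merge segments that overlap or touch after padding."""
    merged = []
    for start, end in sorted(segments):
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


def binarize(probs, frame_sec, onset=0.5, offset=0.5, pad_onset=0.0, pad_offset=0.0):
    """Turn frame-level speech probabilities into (start, end) segments in seconds."""
    segments, in_speech, start = [], False, 0.0
    for i, p in enumerate(probs):
        t = i * frame_sec
        if not in_speech and p >= onset:    # speech begins once onset is reached
            in_speech, start = True, t
        elif in_speech and p < offset:      # speech ends once we fall below offset
            in_speech = False
            segments.append((max(0.0, start - pad_onset), t + pad_offset))
    if in_speech:                           # close a segment still open at the end
        segments.append((max(0.0, start - pad_onset), len(probs) * frame_sec + pad_offset))
    return merge_overlaps(segments)


def drop_short_speech(segments, min_duration_on):
    """Delete speech segments shorter than min_duration_on."""
    return [(s, e) for s, e in segments if e - s >= min_duration_on]


def fill_short_silence(segments, min_duration_off):
    """Delete silences shorter than min_duration_off by joining their neighbours."""
    out = []
    for s, e in segments:
        if out and s - out[-1][1] < min_duration_off:
            out[-1] = (out[-1][0], e)
        else:
            out.append((s, e))
    return out


def filter_segments(segments, min_duration_on=0.0, min_duration_off=0.0, filter_speech_first=True):
    """Apply both filters; the order is controlled by filter_speech_first."""
    if filter_speech_first:
        return fill_short_silence(drop_short_speech(segments, min_duration_on), min_duration_off)
    return drop_short_speech(fill_short_silence(segments, min_duration_off), min_duration_on)


# Made-up frame probabilities with 0.5 s frames, chosen to keep the arithmetic exact.
probs = [0.1, 0.9, 0.1, 0.1, 0.9, 0.75, 0.9, 0.1]
segments = binarize(probs, frame_sec=0.5, onset=0.8, offset=0.7)
single_threshold = binarize(probs, frame_sec=0.5, onset=0.8, offset=0.8)
speech_first = filter_segments(segments, 1.0, 1.5, filter_speech_first=True)
silence_first = filter_segments(segments, 1.0, 1.5, filter_speech_first=False)
```

With these inputs the lower offset keeps the dip to 0.75 inside one segment, whereas a single threshold of 0.8 splits it in two. The two filtering orders also differ: deleting short speech first removes the 0.5 s segment, while deleting short silences first joins it to its neighbour.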
@ -255,7 +255,7 @@
|
|||
"from omegaconf import OmegaConf\n",
|
||||
"MODEL_CONFIG = os.path.join(data_dir,'speaker_diarization.yaml')\n",
|
||||
"if not os.path.exists(MODEL_CONFIG):\n",
|
||||
" config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/stable/examples/speaker_recognition/conf/speaker_diarization.yaml\"\n",
|
||||
" config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_recognition/conf/speaker_diarization.yaml\"\n",
|
||||
" MODEL_CONFIG = wget.download(config_url,data_dir)\n",
|
||||
"config = OmegaConf.load(MODEL_CONFIG)\n",
|
||||
"print(OmegaConf.to_yaml(config))"
|
||||
|
@ -422,7 +422,11 @@
|
|||
"config.diarizer.vad.model_path = pretrained_vad\n",
|
||||
"config.diarizer.vad.window_length_in_sec = 0.15\n",
|
||||
"config.diarizer.vad.shift_length_in_sec = 0.01\n",
|
||||
"config.diarizer.vad.threshold = 0.8"
|
||||
"# config.diarizer.vad.threshold = 0.8  # threshold would be deprecated in release 1.5\n",
|
||||
"config.diarizer.vad.postprocessing_params.onset = 0.8 \n",
|
||||
"config.diarizer.vad.postprocessing_params.offset = 0.7\n",
|
||||
"config.diarizer.vad.postprocessing_params.min_duration_on = 0.1\n",
|
||||
"config.diarizer.vad.postprocessing_params.min_duration_off = 0.3"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -479,19 +483,20 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VAD predicted time stamps\n",
|
||||
"from nemo.collections.asr.parts.utils.vad_utils import extract_labels, plot\n",
|
||||
"# You can also use a single threshold (= onset = offset) for binarization and plotting here\n",
|
||||
"from nemo.collections.asr.parts.utils.vad_utils import plot\n",
|
||||
"plot(\n",
|
||||
" paths2audio_files[0],\n",
|
||||
" 'outputs/vad_outputs/overlap_smoothing_output_median_0.875/an4_diarize_test.median', \n",
|
||||
" path2groundtruth_rttm_files[0],\n",
|
||||
" per_args = config.diarizer.vad.postprocessing_params, #threshold\n",
|
||||
" ) \n",
|
||||
"\n",
|
||||
"plot(paths2audio_files[0],\n",
|
||||
" 'outputs/vad_outputs/overlap_smoothing_output_median_0.875/an4_diarize_test.median', \n",
|
||||
" path2groundtruth_rttm_files[0],\n",
|
||||
" threshold=config.diarizer.vad.threshold)\n",
|
||||
"print(f\"threshold: {config.diarizer.vad.threshold}\")"
|
||||
"print(f\"postprocessing_params: {config.diarizer.vad.postprocessing_params}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -629,4 +634,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|