VAD postprocessing - binarization, filtering (#2636)

* Add more parameters for binarization

Signed-off-by: fayejf <fayejf07@gmail.com>

* more

Signed-off-by: fayejf <fayejf07@gmail.com>

* optimize and update diarizer

Signed-off-by: fayejf <fayejf07@gmail.com>

* updates

Signed-off-by: fayejf <fayejf07@gmail.com>

* binarization and order of filtering for plot

Signed-off-by: fayejf <fayejf07@gmail.com>

* updates

Signed-off-by: fayejf <fayejf07@gmail.com>

* add dur offset for plotting

Signed-off-by: fayejf <fayejf07@gmail.com>

* fix and clean

Signed-off-by: fayejf <fayejf07@gmail.com>

* style fix and updates

Signed-off-by: fayejf <fayejf07@gmail.com>

* rename

Signed-off-by: fayejf <fayejf07@gmail.com>

* small updates

Signed-off-by: fayejf <fayejf07@gmail.com>

* typo

Signed-off-by: fayejf <fayejf07@gmail.com>

* fix lgtm

Signed-off-by: fayejf <fayejf07@gmail.com>

* add back support for threshold

Signed-off-by: fayejf <fayejf07@gmail.com>

* update yaml with CH109-tuned parameters

Signed-off-by: fayejf <fayejf07@gmail.com>

* style fix

Signed-off-by: fayejf <fayejf07@gmail.com>

* update Speaker_Diarization_Inference.ipynb

Signed-off-by: fayejf <fayejf07@gmail.com>

* updates Online_Offline_Microphone_VAD_Demo.ipynb

Signed-off-by: fayejf <fayejf07@gmail.com>

* update config in sd rst

Signed-off-by: fayejf <fayejf07@gmail.com>

* updated from feedback

Signed-off-by: fayejf <fayejf07@gmail.com>

* updates

Signed-off-by: fayejf <fayejf07@gmail.com>

* update yaml about dev

Signed-off-by: fayejf <fayejf07@gmail.com>

* add missing description in docstring

Signed-off-by: fayejf <fayejf07@gmail.com>

* typo and dir name

Signed-off-by: fayejf <fayejf07@gmail.com>

* style fix

Signed-off-by: fayejf <fayejf07@gmail.com>
fayejf 2021-08-12 21:04:03 -07:00 committed by GitHub
parent a1085018d3
commit 7f01d2a51a
8 changed files with 472 additions and 104 deletions


@@ -50,11 +50,19 @@ Diarizer Architecture Configurations
model_path: null #.nemo local model path or pretrained model name or none
window_length_in_sec: 0.15
shift_length_in_sec: 0.01
threshold: 0.5 # tune threshold on dev set. Check <NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py
vad_decision_smoothing: True
smoothing_params:
method: "median"
overlap: 0.875
threshold: null # if not null, use this value and ignore the following onset and offset. Warning! it will be removed in release 1.5!
postprocessing_params: # Tune parameter on dev set. Check <NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py
onset: 0.4
offset: 0.7
pad_onset: 0.05
pad_offset: -0.1
min_duration_on: 0.2
min_duration_off: 0.2
filter_speech_first: True
speaker_embeddings:
oracle_vad_manifest: null # leave null to perform diarization with the above VAD model; otherwise provide the path to a manifest file generated as shown in the Datasets section


@@ -45,7 +45,7 @@ You may perform inference and transcribe a sample of speech after loading the mo
Setting argument ``logprobs`` to True would return the log probabilities instead of transcriptions. You may find more details in `Modules <../api.html#modules>`__.
Learn how to fine tune on your own data or on subset classes in ``<NeMo_git_root>/tutorials/asr/Speech_Commands.ipynb``
**Run VAD inference:**
@@ -53,7 +53,20 @@
python examples/speaker_recognition/vad_infer.py --vad_model="vad_marblenet" --dataset=<FULL PATH OF MANIFEST TO BE PERFORMED INFERENCE ON> --out_dir='frame/demo' --time_length=0.63
Have a look at scripts under ``<NeMo-git-root>/scripts/voice_activity_detection`` for posterior processing and threshold tuning.
Have a look at scripts under ``<NeMo-git-root>/scripts/voice_activity_detection`` for posterior processing, postprocessing and threshold tuning.
Posterior processing includes generating predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
For VAD postprocessing we introduce (a usage sketch follows the list):
Binarization:
- ``onset`` and ``offset`` thresholds for detecting the beginning and end of speech;
- padding durations ``pad_onset`` before and ``pad_offset`` after each speech segment.
Filtering:
- ``min_duration_on`` threshold for short speech segment deletion;
- ``min_duration_off`` threshold for short silence deletion;
- ``filter_speech_first`` to control whether to perform short speech segment deletion first.
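As a quick illustration, here is a minimal sketch of applying these parameters with the helpers in ``<NeMo_git_root>/nemo/collections/asr/parts/utils/vad_utils.py`` (the prediction file name below is hypothetical):

import numpy as np
from nemo.collections.asr.parts.utils.vad_utils import binarization, filtering

probs = np.loadtxt("sample.median")  # hypothetical file of frame-level speech probabilities
per_args = {
    "onset": 0.4, "offset": 0.7,            # hysteresis thresholds on the probabilities
    "pad_onset": 0.05, "pad_offset": -0.1,  # pad (or trim, if negative) each segment
    "min_duration_on": 0.2, "min_duration_off": 0.2,
    "filter_speech_first": True,
    "shift_len": 0.01,                      # seconds per frame
}
speech_segments = filtering(binarization(probs, per_args), per_args)
# -> set of (start, end) times in seconds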
NGC Pretrained Checkpoints


@@ -15,11 +15,19 @@ diarizer:
model_path: null #.nemo local model path or pretrained model name or none
window_length_in_sec: 0.15
shift_length_in_sec: 0.01
threshold: 0.7 # tune threshold on dev set. Check <NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py
vad_decision_smoothing: True
smoothing_params:
method: "median"
overlap: 0.875
threshold: null # if not null, use this value and ignore the following onset and offset. Warning! it will be removed in release 1.5!
postprocessing_params: # Parameters tuned for CH109 (with 11 multi-speech sessions moved to the dev set). You might need to choose a reasonable threshold or tune the thresholds on your dev set. Check <NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py
onset: 0.4
offset: 0.7
pad_onset: 0.05
pad_offset: -0.1
min_duration_on: 0.2
min_duration_off: 0.2
filter_speech_first: True
speaker_embeddings:
oracle_vad_manifest: null


@@ -237,9 +237,16 @@ class ClusteringDiarizer(Model, DiarizationMixin):
self.vad_pred_dir = smoothing_pred_dir
logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
postprocessing_params = self._cfg.diarizer.vad.postprocessing_params
if self._cfg.diarizer.vad.get("threshold", None):
logging.info("threshold is not None. Use threshold and update onset=offset=threshold")
logging.warning("Support for threshold will be deprecated in release 1.5")
postprocessing_params['onset'] = self._cfg.diarizer.vad.threshold
postprocessing_params['offset'] = self._cfg.diarizer.vad.threshold
table_out_dir = generate_vad_segment_table(
vad_pred_dir=self.vad_pred_dir,
threshold=self._cfg.diarizer.vad.threshold,
postprocessing_params=postprocessing_params,
shift_len=self._vad_shift_length_in_sec,
num_workers=self._cfg.num_workers,
)


@@ -14,6 +14,7 @@
import glob
import json
import os
import shutil
from itertools import repeat
from multiprocessing import Pool
@@ -23,12 +24,13 @@ import numpy as np
import pandas as pd
from pyannote.core import Annotation, Segment
from pyannote.metrics import detection
from sklearn.model_selection import ParameterGrid
from nemo.utils import logging
"""
This file contains all the utility functions required for speaker embeddings part in diarization scripts
This file contains all the utility functions required for voice activity detection.
"""
@@ -279,12 +281,13 @@ def generate_overlap_vad_seq_per_file(frame_filepath, per_args):
round_final = np.round(preds, 4)
np.savetxt(overlap_filepath, round_final, delimiter='\n')
return overlap_filepath
except Exception as e:
raise (e)
def generate_vad_segment_table(vad_pred_dir, threshold, shift_len, num_workers):
def generate_vad_segment_table(vad_pred_dir, postprocessing_params, shift_len, num_workers, threshold=None):
"""
Convert frame level prediction to speech segment in start and end times format.
And save to csv file in rttm-like format
@@ -292,7 +295,7 @@ def generate_vad_segment_table(vad_pred_dir, threshold, shift_len, num_workers):
17,18, speech
Args:
vad_pred_dir (str): directory of prediction files to be processed.
threshold (float): threshold for prediction score (from 0 to 1).
postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering.
shift_len (float): amount of shift of window for generating the frame.
out_dir (str): output dir of generated table/csv file.
num_workers (int): number of workers for multiprocessing
@@ -304,15 +307,18 @@ def generate_vad_segment_table(vad_pred_dir, threshold, shift_len, num_workers):
suffixes = ("frame", "mean", "median")
vad_pred_filepath_list = [os.path.join(vad_pred_dir, x) for x in os.listdir(vad_pred_dir) if x.endswith(suffixes)]
table_out_dir = os.path.join(vad_pred_dir, "table_output_" + str(threshold))
table_out_dir_name = "table_output_tmp_"
for key in postprocessing_params:
table_out_dir_name = table_out_dir_name + str(key) + str(postprocessing_params[key]) + "_"
table_out_dir = os.path.join(vad_pred_dir, table_out_dir_name)
if not os.path.exists(table_out_dir):
os.mkdir(table_out_dir)
per_args = {
"threshold": threshold,
"shift_len": shift_len,
"out_dir": table_out_dir,
}
per_args = {"shift_len": shift_len, "out_dir": table_out_dir}
per_args = {**per_args, **postprocessing_params}
p.starmap(generate_vad_segment_table_per_file, zip(vad_pred_filepath_list, repeat(per_args)))
p.close()
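A hedged sketch of calling this function directly (the directory name and worker count below are illustrative):

from nemo.collections.asr.parts.utils.vad_utils import generate_vad_segment_table

table_out_dir = generate_vad_segment_table(
    vad_pred_dir="frame/demo",  # hypothetical dir of .frame/.mean/.median prediction files
    postprocessing_params={"onset": 0.4, "offset": 0.7, "min_duration_on": 0.2, "min_duration_off": 0.2},
    shift_len=0.01,
    num_workers=4,
)
# one rttm-like .txt per prediction file is written under table_out_dir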
@@ -325,41 +331,199 @@ def generate_vad_segment_table_per_file(pred_filepath, per_args):
"""
See description in generate_overlap_vad_seq.
"""
threshold = per_args['threshold']
shift_len = per_args['shift_len']
out_dir = per_args['out_dir']
name = pred_filepath.split("/")[-1].rsplit(".", 1)[0]
sequence = np.loadtxt(pred_filepath)
start = 0
start_list = [0]
dur_list = []
state_list = []
current_state = "non-speech"
for i in range(len(sequence) - 1):
current_state = "non-speech" if sequence[i] <= threshold else "speech"
next_state = "non-speech" if sequence[i + 1] <= threshold else "speech"
if next_state != current_state:
dur = i * shift_len + shift_len - start # shift_len for handling joint
state_list.append(current_state)
dur_list.append(dur)
start = (i + 1) * shift_len
start_list.append(start)
speech_segments = binarization(sequence, per_args)
speech_segments = filtering(speech_segments, per_args)
dur_list.append((i + 1) * shift_len + shift_len - start)
state_list.append(current_state)
seg_table = pd.DataFrame({'start': start_list, 'dur': dur_list, 'vad': state_list})
seg_speech_table = seg_table[seg_table['vad'] == 'speech']
seg_speech_table = pd.DataFrame(speech_segments, columns=['start', 'end'])
seg_speech_table = seg_speech_table.sort_values('start', ascending=True)
seg_speech_table['dur'] = seg_speech_table['end'] - seg_speech_table['start'] + shift_len
seg_speech_table['vad'] = 'speech'
save_name = name + ".txt"
save_path = os.path.join(out_dir, save_name)
seg_speech_table.to_csv(save_path, sep='\t', index=False, header=False)
seg_speech_table.to_csv(save_path, columns=['start', 'dur', 'vad'], sep='\t', index=False, header=False)
return save_path
def binarization(sequence, per_args):
"""
Binarize predictions to speech and non-speech
Reference
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
Args:
sequence (list): A list of frame-level predictions.
per_args:
onset (float): onset threshold for detecting the beginning of speech.
offset (float): offset threshold for detecting the end of speech.
pad_onset (float): duration to pad before each speech segment.
pad_offset (float): duration to pad after each speech segment.
shift_len (float): amount of shift of window for generating the frame.
Returns:
speech_segments (set): Set of speech segments in (start, end) format.
"""
shift_len = per_args.get('shift_len', 0.01)
onset = per_args.get('onset', 0.5)
offset = per_args.get('offset', 0.5)
pad_onset = per_args.get('pad_onset', 0)
pad_offset = per_args.get('pad_offset', 0)
onset, offset = cal_vad_onset_offset(per_args.get('scale', 'absolute'), onset, offset, sequence)
speech = False
speech_segments = set() # {(start1, end1), (start2, end2)}
for i in range(1, len(sequence)):
# Current frame is speech
if speech:
# Switch from speech to non-speech
if sequence[i] < offset:
if start - pad_onset >= 0:
speech_segments.add((start - pad_onset, i * shift_len + pad_offset))
else:
speech_segments.add((0, i * shift_len + pad_offset))
start = i * shift_len
speech = False
# Current frame is non-speech
else:
# Switch from non-speech to speech
if sequence[i] > onset:
start = i * shift_len
speech = True
# if it's speech at the end, add final segment
if speech:
speech_segments.add((start - pad_onset, i * shift_len + pad_offset))
# Merge the overlapped speech segments due to padding
speech_segments = merge_overlap_segment(speech_segments) # not sorted
return speech_segments
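A small sanity check of the hysteresis logic above (toy numbers; padding parameters left at their defaults):

probs = [0.1, 0.6, 0.7, 0.4, 0.2, 0.1]
segs = binarization(probs, {"onset": 0.5, "offset": 0.3, "shift_len": 0.01})
# frame 1 crosses onset (0.6 > 0.5); the state stays "speech" through the
# dip to 0.4 because only dropping below offset (0.3) ends a segment,
# which happens at frame 4 -> {(0.01, 0.04)}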
def filtering(speech_segments, per_args):
"""
Filter out short non-speech and speech segments.
Reference
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
Args:
speech_segments (set): A set of speech segments in (start, end) format: {(start1, end1), (start2, end2)}.
per_args:
min_duration_on (float): threshold for short speech segment deletion.
min_duration_off (float): threshold for small non-speech (silence) deletion.
filter_speech_first (boolean): Whether to perform short speech segment deletion first.
Returns:
speech_segments (set): Filtered set of speech segments in (start, end) format: {(start1, end1), (start2, end2)}.
"""
min_duration_on = per_args.get('min_duration_on', 0.0)
min_duration_off = per_args.get('min_duration_off', 0.0)
filter_speech_first = per_args.get('filter_speech_first', True)
if filter_speech_first:
# Filter out the shorter speech segments
if min_duration_on > 0.0:
speech_segments = filter_short_segments(speech_segments, min_duration_on)
# Absorb short non-speech gaps back into the speech segments
if min_duration_off > 0.0:
# Find non-speech segments
non_speech_segments = get_gap_segments(speech_segments)
# Find shorter non-speech segments
short_non_speech_segments = non_speech_segments - filter_short_segments(
non_speech_segments, min_duration_off
)
# Add the short non-speech segments back as speech
speech_segments.update(short_non_speech_segments)
# Merge the overlapped speech segments
speech_segments = merge_overlap_segment(speech_segments)
else:
if min_duration_off > 0.0:
# Find non-speech segments
non_speech_segments = get_gap_segments(speech_segments)
# Find shorter non-speech segments
short_non_speech_segments = non_speech_segments - filter_short_segments(
non_speech_segments, min_duration_off
)
# Add the short non-speech segments back as speech
speech_segments.update(short_non_speech_segments)
# Merge the overlapped speech segments
speech_segments = merge_overlap_segment(speech_segments)
if min_duration_on > 0.0:
speech_segments = filter_short_segments(speech_segments, min_duration_on)
return speech_segments
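A toy example of the two filtering steps (values chosen to trigger both):

segs = {(0.0, 0.15), (0.5, 1.0), (1.05, 2.0)}
out = filtering(segs, {"min_duration_on": 0.2, "min_duration_off": 0.1, "filter_speech_first": True})
# (0.0, 0.15) is dropped as too-short speech; the 0.05 s gap between the
# remaining segments is shorter than min_duration_off, so they merge:
# out == {(0.5, 2.0)}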
def filter_short_segments(segments, threshold):
"""
Remove segments whose duration is smaller than the threshold.
"""
res = set()
for seg in segments:
if seg[1] - seg[0] >= threshold:
res.add(seg)
return res
def get_gap_segments(segments):
"""
Get the gap segments. {(start1, end1), (start2, end2), (start3, end3)} -> {(end1, start2), (end2, start3)}
"""
segments = [list(i) for i in segments]
segments.sort(key=lambda x: x[0])
gap_segments = set()
for i in range(len(segments) - 1):
gap_segments.add((segments[i][1], segments[i + 1][0]))
return gap_segments
def merge_overlap_segment(segments):
"""
Merge overlapped segments: {(0, 1.5), (1, 3.5)} -> {(0, 3.5)}
"""
segments = [list(i) for i in segments]
segments.sort(key=lambda x: x[0])
merged = []
for segment in segments:
if not merged or merged[-1][1] < segment[0]:
merged.append(segment)
else:
merged[-1][1] = max(merged[-1][1], segment[1])
merged_set = set([tuple(t) for t in merged])
return merged_set
def cal_vad_onset_offset(scale, onset, offset, sequence=None):
"""
Calculate onset and offset threshold given different scale.
"""
if scale == "absolute":
mini = 0
maxi = 1
elif scale == "relative":
mini = np.nanmin(sequence)
maxi = np.nanmax(sequence)
elif scale == "percentile":
mini = np.nanpercentile(sequence, 1)
maxi = np.nanpercentile(sequence, 99)
onset = mini + onset * (maxi - mini)
offset = mini + offset * (maxi - mini)
return onset, offset
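For example, under the "relative" scale the raw thresholds are mapped into the observed probability range (toy values):

probs = np.array([0.2, 0.5, 0.8])  # nanmin = 0.2, nanmax = 0.8
onset, offset = cal_vad_onset_offset("relative", 0.4, 0.6, probs)
# onset  = 0.2 + 0.4 * (0.8 - 0.2) = 0.44
# offset = 0.2 + 0.6 * (0.8 - 0.2) = 0.56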
def vad_construct_pyannote_object_per_file(vad_table_filepath, groundtruth_RTTM_file):
"""
Construct pyannote object for evaluation.
@@ -387,11 +551,31 @@ def vad_construct_pyannote_object_per_file(vad_table_filepath, groundtruth_RTTM_
return reference, hypothesis
def vad_tune_threshold_on_dev(thresholds, vad_pred, groundtruth_RTTM, vad_pred_method="frame", focus_metric="DetER"):
def get_parameter_grid(params):
"""
Tune threshold on dev set. Return best threshold which gives the lowest detection error rate (DetER) in thresholds.
Get the parameter grid given a dictionary of parameters.
"""
has_filter_speech_first = False
if 'filter_speech_first' in params:
filter_speech_first = params['filter_speech_first']
has_filter_speech_first = True
params.pop("filter_speech_first")
params_grid = list(ParameterGrid(params))
if has_filter_speech_first:
for i in params_grid:
i['filter_speech_first'] = filter_speech_first
return params_grid
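For instance, list-valued parameters are expanded into all combinations while the scalar filter_speech_first is carried through unchanged:

params = {"onset": [0.3, 0.5], "offset": [0.6], "filter_speech_first": False}
get_parameter_grid(params)
# two combinations, (onset=0.3, offset=0.6) and (onset=0.5, offset=0.6),
# each also carrying filter_speech_first=False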
def vad_tune_threshold_on_dev(
params, vad_pred, groundtruth_RTTM, result_file="res", vad_pred_method="frame", focus_metric="DetER"
):
"""
Tune thresholds on dev set. Return the best thresholds which give the lowest detection error rate (DetER).
Args:
thresholds (list): list of thresholds.
params (dict): dictionary of parameters to be tuned.
vad_pred_method (str): suffix of prediction file, used to locate files. Should be one of "frame", "mean" or "median".
vad_pred_dir (str): directory of vad predictions or a file containing the paths to them
groundtruth_RTTM_dir (str): directory of groundtruth rttm files or a file containing the paths to them.
@@ -399,61 +583,82 @@ def vad_tune_threshold_on_dev(thresholds, vad_pred, groundtruth_RTTM, vad_pred_m
Returns:
best_threshold (float): threshold that gives the lowest DetER.
"""
threshold_perf = {}
best_threshold = thresholds[0]
min_score = 100
all_perf = {}
try:
thresholds[0] >= 0 and thresholds[-1] <= 1
check_if_param_valid(params)
except:
raise ValueError("Invalid threshold! Should be in [0, 1]")
raise ValueError("Please check if the parameters are valid")
for threshold in thresholds:
metric = detection.DetectionErrorRate()
paired_filenames, groundtruth_RTTM_dict, vad_pred_dict = pred_rttm_map(
vad_pred, groundtruth_RTTM, vad_pred_method
)
print(paired_filenames)
paired_filenames, groundtruth_RTTM_dict, vad_pred_dict = pred_rttm_map(vad_pred, groundtruth_RTTM, vad_pred_method)
metric = detection.DetectionErrorRate()
params_grid = get_parameter_grid(params)
for param in params_grid:
# perform binarization and filtering according to param and write to an rttm-like table
vad_table_dir = generate_vad_segment_table(vad_pred, param, shift_len=0.01, num_workers=20)
# add reference and hypothesis to metrics
for filename in paired_filenames:
vad_pred_filepath = vad_pred_dict[filename]
groundtruth_RTTM_file = groundtruth_RTTM_dict[filename]
if os.path.isdir(vad_pred):
table_out_dir = os.path.join(vad_pred, "table_output_" + str(threshold))
else:
table_out_dir = os.path.join("tmp_table_outputs", "table_output_" + str(threshold))
if not os.path.exists(table_out_dir):
os.makedirs(table_out_dir)
per_args = {"threshold": threshold, "shift_len": 0.01, "out_dir": table_out_dir}
vad_table_filepath = generate_vad_segment_table_per_file(vad_pred_filepath, per_args)
vad_table_filepath = os.path.join(vad_table_dir, filename + ".txt")
reference, hypothesis = vad_construct_pyannote_object_per_file(vad_table_filepath, groundtruth_RTTM_file)
metric(reference, hypothesis) # accumulation
# delete tmp table files
shutil.rmtree(vad_table_dir, ignore_errors=True)
report = metric.report(display=False)
DetER = report.iloc[[-1]][('detection error rate', '%')].item()
FA = report.iloc[[-1]][('false alarm', '%')].item()
MISS = report.iloc[[-1]][('miss', '%')].item()
if focus_metric == "DetER":
score = DetER
elif focus_metric == "FA":
score = FA
elif focus_metric == "MISS":
score = MISS
else:
raise ValueError("Metric we care most should be only in 'DetER', 'FA'or 'MISS'!")
assert (
focus_metric == "DetER" or focus_metric == "FA" or focus_metric == "MISS"
), "Metric we care most should be only in 'DetER', 'FA'or 'MISS'!"
all_perf[str(param)] = {'DetER (%)': DetER, 'FA (%)': FA, 'MISS (%)': MISS}
logging.info(f"parameter {param}, {all_perf[str(param)] }")
score = all_perf[str(param)][focus_metric + ' (%)']
threshold_perf[threshold] = {'DetER (%)': DetER, 'FA (%)': FA, 'MISS (%)': MISS}
logging.info(f"threshold {threshold}, {threshold_perf[threshold]}")
del report
metric.reset() # reset internal accumulator
# save results for analysis
with open(result_file + ".txt", "a") as fp:
fp.write(f"{param}, {all_perf[str(param)] }\n")
if score < min_score:
best_threshold = param
optimal_scores = all_perf[str(param)]
min_score = score
best_threshold = threshold
return best_threshold
return best_threshold, optimal_scores
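A hedged end-to-end usage sketch (the paths and value ranges below are hypothetical; results are also appended to "res.txt" by default):

import numpy as np
from nemo.collections.asr.parts.utils.vad_utils import vad_tune_threshold_on_dev

params = {
    "onset": np.arange(0.1, 1.0, 0.2),
    "offset": np.arange(0.1, 1.0, 0.2),
    "min_duration_on": [0.1, 0.2],
    "min_duration_off": [0.1, 0.2],
    "filter_speech_first": True,
}
best_params, scores = vad_tune_threshold_on_dev(
    params, vad_pred="frame_preds/", groundtruth_RTTM="rttms/", vad_pred_method="median"
)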
def check_if_param_valid(params):
"""
Check if the parameters are valid.
"""
for i in params:
if i == "filter_speech_first":
if not isinstance(params["filter_speech_first"], bool):
raise ValueError("Invalid inputs! filter_speech_first should be either True or False!")
elif i == "pad_onset":
continue
elif i == "pad_offset":
continue
else:
for j in params[i]:
if not j >= 0:
raise ValueError(
"Invalid inputs! All float parameters excpet pad_onset and pad_offset should be larger than 0!"
)
if not (all(i <= 1 for i in params['onset']) and all(i <= 1 for i in params['offset'])):
raise ValueError("Invalid inputs! The onset and offset thresholds should be in range [0, 1]!")
return True
def pred_rttm_map(vad_pred, groundtruth_RTTM, vad_pred_method="frame"):
@@ -492,7 +697,15 @@ def pred_rttm_map(vad_pred, groundtruth_RTTM, vad_pred_method="frame"):
return paired_filenames, groundtruth_RTTM_dict, vad_pred_dict
def plot(path2audio_file, path2_vad_pred, path2ground_truth_label=None, threshold=0.85):
def plot(
path2audio_file,
path2_vad_pred,
path2ground_truth_label=None,
offset=0,
duration=None,
threshold=None,
per_args=None,
):
"""
Plot VAD outputs for demonstration in tutorial
Args:
@@ -503,10 +716,14 @@ def plot(path2audio_file, path2_vad_pred, path2ground_truth_label=None, threshol
"""
plt.figure(figsize=[20, 2])
FRAME_LEN = 0.01
audio, sample_rate = librosa.load(path=path2audio_file, sr=16000, mono=True)
audio, sample_rate = librosa.load(path=path2audio_file, sr=16000, mono=True, offset=offset, duration=duration)
dur = librosa.get_duration(audio, sr=sample_rate)
time = np.arange(0, dur, FRAME_LEN)
time = np.arange(offset, offset + dur, FRAME_LEN)
frame = np.loadtxt(path2_vad_pred)
frame = frame[int(offset / FRAME_LEN) : int((offset + dur) / FRAME_LEN)]
len_pred = len(frame)
ax1 = plt.subplot()
ax1.plot(np.arange(audio.size) / sample_rate, audio, 'gray')
@ -515,11 +732,24 @@ def plot(path2audio_file, path2_vad_pred, path2ground_truth_label=None, threshol
ax1.set_ylabel('Signal')
ax1.set_ylim([-1, 1])
ax2 = ax1.twinx()
prob = frame
pred = np.where(prob >= threshold, 1, 0)
if threshold and per_args:
raise ValueError("threshold and per_args cannot be used at the same time!")
if not threshold and not per_args:
raise ValueError("Exactly one of threshold and per_args must be provided!")
if threshold:
pred = np.where(prob >= threshold, 1, 0)
if per_args:
speech_segments = binarization(prob, per_args)
speech_segments = filtering(speech_segments, per_args)
pred = gen_pred_from_speech_segments(speech_segments, prob)
if path2ground_truth_label:
label = extract_labels(path2ground_truth_label, time)
ax2.plot(np.arange(len_pred) * FRAME_LEN, label, 'r', label='label')
ax2.plot(np.arange(len_pred) * FRAME_LEN, pred, 'b', label='pred')
ax2.plot(np.arange(len_pred) * FRAME_LEN, prob, 'g--', label='speech prob')
ax2.tick_params(axis='y', labelcolor='r')
@@ -529,6 +759,18 @@ def plot(path2audio_file, path2_vad_pred, path2ground_truth_label=None, threshol
return None
def gen_pred_from_speech_segments(speech_segments, prob, shift_len=0.01):
"""
Generate a frame-level 0/1 prediction array from speech segments, for plotting.
"""
pred = np.zeros(prob.shape)
speech_segments = [list(i) for i in speech_segments]
speech_segments.sort(key=lambda x: x[0])
for seg in speech_segments:
start = int(seg[0] / shift_len)
end = int(seg[1] / shift_len)
pred[start:end] = 1
return pred
def extract_labels(path2ground_truth_label, time):
"""
Extract the groundtruth label for a given time period.


@@ -19,11 +19,45 @@ import numpy as np
from nemo.collections.asr.parts.utils.vad_utils import vad_tune_threshold_on_dev
from nemo.utils import logging
"""
This script is designed for threshold tuning for VAD postprocessing.
See details in the binarization and filtering functions in nemo/collections/asr/parts/utils/vad_utils.py
Usage:
python vad_tune_threshold.py \
--onset_range="0,1,0.2" --offset_range="0,1,0.2" --min_duration_on_range="0.1,0.8,0.05" --min_duration_off_range="0.1,0.8,0.05" --not_filter_speech_first \
--vad_pred=<FULL PATH OF FOLDER OF FRAME LEVEL PREDICTION FILES> \
--groundtruth_RTTM=<DIRECTORY OF VAD PREDICTIONS OR A FILE CONTAINS THE PATHS OF THEM> \
--vad_pred_method="median"
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--onset_range", help="range of onset in list 'START,END,STEP' to be tuned on", type=str)
parser.add_argument("--offset_range", help="range of offset in list 'START,END,STEP' to be tuned on", type=str)
parser.add_argument(
"--threshold_range", help="range of threshold in list 'START,END,STEP' to be tuned on", required=True
"--pad_onset_range",
help="range of pad_onset in list 'START,END,STEP' to be tuned on. pad_onset could be negative float",
type=str,
)
parser.add_argument(
"--pad_offset_range",
help="range of pad_offset in list 'START,END,STEP' to be tuned on. pad_offset could be negative float",
type=str,
)
parser.add_argument(
"--min_duration_on_range", help="range of min_duration_on in list 'START,END,STEP' to be tuned on", type=str
)
parser.add_argument(
"--min_duration_off_range", help="range of min_duration_off in list 'START,END,STEP' to be tuned on", type=str
)
parser.add_argument(
"--not_filter_speech_first",
help="Whether to filter short speech first during filtering, should be either True or False!",
action='store_true',
)
parser.add_argument(
"--vad_pred", help="Directory of vad predictions or a file contains the paths of them.", required=True
)
@ -33,6 +67,9 @@ if __name__ == "__main__":
type=str,
required=True,
)
parser.add_argument(
"--result_file", help="Filename of txt to store results", default="res",
)
parser.add_argument(
"--vad_pred_method",
help="suffix of prediction file. Should be either in 'frame', 'mean' or 'median'",
@@ -46,13 +83,50 @@
)
args = parser.parse_args()
params = {}
try:
start, stop, step = [float(i) for i in args.threshold_range.split(",")]
thresholds = np.arange(start, stop, step)
except:
raise ValueError("Theshold input is invalid! Please enter it as a 'START,STOP,STEP' ")
# if no range is given for a parameter, the default value defined in the binarization and filtering functions in nemo/collections/asr/parts/utils/vad_utils.py is used
if args.onset_range:
start, stop, step = [float(i) for i in args.onset_range.split(",")]
onset = np.arange(start, stop, step)
params['onset'] = onset
best_threshold = vad_tune_threshold_on_dev(
thresholds, args.vad_pred, args.groundtruth_RTTM, args.vad_pred_method, args.focus_metric
if args.offset_range:
start, stop, step = [float(i) for i in args.offset_range.split(",")]
offset = np.arange(start, stop, step)
params['offset'] = offset
if args.pad_onset_range:
start, stop, step = [float(i) for i in args.pad_onset_range.split(",")]
pad_onset = np.arange(start, stop, step)
params['pad_onset'] = pad_onset
if args.pad_offset_range:
start, stop, step = [float(i) for i in args.pad_offset_range.split(",")]
pad_offset = np.arange(start, stop, step)
params['pad_offset'] = pad_offset
if args.min_duration_on_range:
start, stop, step = [float(i) for i in args.min_duration_on_range.split(",")]
min_duration_on = np.arange(start, stop, step)
params['min_duration_on'] = min_duration_on
if args.min_duration_off_range:
start, stop, step = [float(i) for i in args.min_duration_off_range.split(",")]
min_duration_off = np.arange(start, stop, step)
params['min_duration_off'] = min_duration_off
if args.not_filter_speech_first:
params['filter_speech_first'] = False
except:
raise ValueError(
"Theshold input is invalid! Please enter it as a 'START,STOP,STEP' for onset, offset, min_duration_on and min_duration_off, and enter True/False for filter_speech_first"
)
best_threshold, optimal_scores = vad_tune_threshold_on_dev(
params, args.vad_pred, args.groundtruth_RTTM, args.result_file, args.vad_pred_method, args.focus_metric
)
logging.info(
f"Best combination of thresholds for binarization selected from input ranges is {best_threhsold}, and the optimal score is {optimal_scores}"
)
logging.info(f"Best threshold selected from {thresholds} is {best_threhsold}!")


@@ -43,7 +43,7 @@
"This notebook demonstrates how to perform\n",
"1. [offline streaming inference on audio files (offline VAD)](#Offline-streaming-inference);\n",
"2. [finetuning](#Finetune) and use [posterior](#Posterior);\n",
"2. [threshold tuning](#Tuning-threshold);\n",
"2. [vad postproceesing and threshold tuning](#VAD-postprocessing-and-Tuning-threshold);\n",
"4. [online streaming inference](#Online-streaming-inference);\n",
"3. [online streaming inference from a microphone's stream](#Online-streaming-inference-through-microphone).\n"
]
@@ -236,13 +236,23 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tuning threshold"
"## VAD postprocessing and Tuning threshold"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use a single **threshold** to binarize predictions or use typical VAD postpocessing including\n",
"\n",
"### Binarization:\n",
"1. **onset** and **offset** threshold for detecting the beginning and end of a speech;\n",
"2. padding durations before (**pad_onset**) and after (**pad_offset**) each speech segment.\n",
"\n",
"### Filtering:\n",
"1. threshold for short speech segment deletion (**min_duration_on**);\n",
"2. threshold for small silence deletion (**min_duration_off**);\n",
"3. Whether to perform short speech segment deletion first (**filter_speech_first**).\n",
"\n",
"\n",
"Of course you can do threshold tuning on frame level prediction. We also provide a script \n",
@@ -250,7 +260,7 @@
"<NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py\n",
"```\n",
"\n",
"to help you find best threshold if you have ground truth label file in RTTM format. "
"to help you find best thresholds if you have ground truth label file in RTTM format. "
]
},
{
@@ -354,6 +364,7 @@
"# 1) use reset() method to reset FrameVAD's state\n",
"# 2) call transcribe(frame) to do VAD on\n",
"# contiguous signal's frames\n",
"# To simplify the flow, we use single threshold to binarize predictions.\n",
"class FrameVAD:\n",
" \n",
" def __init__(self, model_definition,\n",
@@ -579,9 +590,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The above prediction is based on `threshold=0.4`.\n",
"To simplify the flow, the above prediction is based on single threshold and `threshold=0.4`.\n",
"\n",
"You can play with other [threshold](#Tuning-threshold) and see how they would impact performance. \n",
"You can play with other [threshold](#VAD-postprocessing-and-Tuning-threshold) or use postprocessing and see how they would impact performance. \n",
"\n",
"**Note** if you want better performance, [finetune](#Finetune) on your data and use posteriors such as [overlapped prediction](#Posterior). \n",
"\n",


@@ -255,7 +255,7 @@
"from omegaconf import OmegaConf\n",
"MODEL_CONFIG = os.path.join(data_dir,'speaker_diarization.yaml')\n",
"if not os.path.exists(MODEL_CONFIG):\n",
" config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/stable/examples/speaker_recognition/conf/speaker_diarization.yaml\"\n",
" config_url = \"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_recognition/conf/speaker_diarization.yaml\"\n",
" MODEL_CONFIG = wget.download(config_url,data_dir)\n",
"config = OmegaConf.load(MODEL_CONFIG)\n",
"print(OmegaConf.to_yaml(config))"
@@ -422,7 +422,11 @@
"config.diarizer.vad.model_path = pretrained_vad\n",
"config.diarizer.vad.window_length_in_sec = 0.15\n",
"config.diarizer.vad.shift_length_in_sec = 0.01\n",
"config.diarizer.vad.threshold = 0.8"
"# config.diarizer.vad.threshold = 0.8 threshold would be deprecated in release 1.5\n",
"config.diarizer.vad.postprocessing_params.onset = 0.8 \n",
"config.diarizer.vad.postprocessing_params.offset = 0.7\n",
"config.diarizer.vad.postprocessing_params.min_duration_on = 0.1\n",
"config.diarizer.vad.postprocessing_params.min_duration_off = 0.3"
]
},
{
@@ -479,19 +483,20 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"# VAD predicted time stamps\n",
"from nemo.collections.asr.parts.utils.vad_utils import extract_labels, plot\n",
"# you can also use single threshold(=onset=offset) for binarization and plot here\n",
"from nemo.collections.asr.parts.utils.vad_utils import plot\n",
"plot(\n",
" paths2audio_files[0],\n",
" 'outputs/vad_outputs/overlap_smoothing_output_median_0.875/an4_diarize_test.median', \n",
" path2groundtruth_rttm_files[0],\n",
" per_args = config.diarizer.vad.postprocessing_params, #threshold\n",
" ) \n",
"\n",
"plot(paths2audio_files[0],\n",
" 'outputs/vad_outputs/overlap_smoothing_output_median_0.875/an4_diarize_test.median', \n",
" path2groundtruth_rttm_files[0],\n",
" threshold=config.diarizer.vad.threshold)\n",
"print(f\"threshold: {config.diarizer.vad.threshold}\")"
"print(f\"postprocessing_params: {config.diarizer.vad.postprocessing_params}\")"
]
},
{
@@ -629,4 +634,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}