Adding bucketing for ASR models with tarred datasets (#2990)
Commit a07dfa9d78 (parent 0e5a1f84c4)
@@ -960,7 +960,7 @@ class _TarredAudioToTextDataset(IterableDataset):
     def __init__(
         self,
         audio_tar_filepaths: Union[str, List[str]],
-        manifest_filepath: str,
+        manifest_filepath: Union[str, List[str]],
         parser: Callable,
         sample_rate: int,
         int_values: bool = False,
@@ -978,7 +978,7 @@ class _TarredAudioToTextDataset(IterableDataset):
         world_size: int = 0,
     ):
         self.collection = collections.ASRAudioText(
-            manifests_files=manifest_filepath.split(','),
+            manifests_files=manifest_filepath,
             parser=parser,
             min_duration=min_duration,
             max_duration=max_duration,
@@ -1336,7 +1336,7 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
     def __init__(
         self,
         audio_tar_filepaths: Union[str, List[str]],
-        manifest_filepath: str,
+        manifest_filepath: Union[str, List[str]],
         tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
         sample_rate: int,
         int_values: bool = False,

@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import Optional, Union

 import torch
 from omegaconf import DictConfig, open_dict
+from omegaconf.listconfig import ListConfig
+from torch.utils.data import ChainDataset

 from nemo.collections.asr.data import audio_to_text, audio_to_text_dali
 from nemo.utils import logging
@@ -122,60 +124,21 @@ def get_bpe_dataset(
     return dataset


-def get_tarred_char_dataset(
-    config: dict, shuffle_n: int, global_rank: int, world_size: int, augmentor: Optional['AudioAugmentor'] = None
-) -> audio_to_text.TarredAudioToCharDataset:
-    """
-    Instantiates a Character Encoding based TarredAudioToCharDataset.
-
-    Args:
-        config: Config of the TarredAudioToCharDataset.
-        shuffle_n: How many samples to look ahead and load to be shuffled.
-            See WebDataset documentation for more details.
-        global_rank: Global rank of this device.
-        world_size: Global world size in the training method.
-        augmentor: Optional AudioAugmentor object for augmentations on audio data.
-
-    Returns:
-        An instance of TarredAudioToCharDataset.
-    """
-    dataset = audio_to_text.TarredAudioToCharDataset(
-        audio_tar_filepaths=config['tarred_audio_filepaths'],
-        manifest_filepath=config['manifest_filepath'],
-        labels=config['labels'],
-        sample_rate=config['sample_rate'],
-        int_values=config.get('int_values', False),
-        augmentor=augmentor,
-        shuffle_n=shuffle_n,
-        max_duration=config.get('max_duration', None),
-        min_duration=config.get('min_duration', None),
-        max_utts=config.get('max_utts', 0),
-        blank_index=config.get('blank_index', -1),
-        unk_index=config.get('unk_index', -1),
-        normalize=config.get('normalize_transcripts', False),
-        trim=config.get('trim_silence', False),
-        parser=config.get('parser', 'en'),
-        shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
-        global_rank=global_rank,
-        world_size=world_size,
-    )
-    return dataset
-
-
-def get_tarred_bpe_dataset(
+def get_tarred_dataset(
     config: dict,
-    tokenizer: 'TokenizerSpec',
     shuffle_n: int,
     global_rank: int,
     world_size: int,
+    tokenizer: Optional['TokenizerSpec'] = None,
     augmentor: Optional['AudioAugmentor'] = None,
-) -> audio_to_text.TarredAudioToBPEDataset:
+) -> Union[audio_to_text.TarredAudioToBPEDataset, audio_to_text.TarredAudioToCharDataset]:
     """
-    Instantiates a Byte Pair Encoding / Word Piece Encoding based TarredAudioToBPEDataset.
+    Instantiates a Word Piece/BPE Encoding based TarredAudioToBPEDataset or a char based TarredAudioToCharDataset.

     Args:
-        config: Config of the TarredAudioToBPEDataset.
-        tokenizer: An instance of a TokenizerSpec object.
+        config: Config of the TarredAudioToBPEDataset or TarredAudioToCharDataset.
+        tokenizer: An instance of a TokenizerSpec object if BPE dataset is needed.
+            Passing None would return a char-based dataset.
         shuffle_n: How many samples to look ahead and load to be shuffled.
             See WebDataset documentation for more details.
         global_rank: Global rank of this device.
@@ -183,26 +146,68 @@ def get_tarred_bpe_dataset(
         augmentor: Optional AudioAugmentor object for augmentations on audio data.

     Returns:
-        An instance of TarredAudioToBPEDataset.
+        An instance of TarredAudioToBPEDataset or TarredAudioToCharDataset.
     """
-    dataset = audio_to_text.TarredAudioToBPEDataset(
-        audio_tar_filepaths=config['tarred_audio_filepaths'],
-        manifest_filepath=config['manifest_filepath'],
-        tokenizer=tokenizer,
-        sample_rate=config['sample_rate'],
-        int_values=config.get('int_values', False),
-        augmentor=augmentor,
-        shuffle_n=shuffle_n,
-        max_duration=config.get('max_duration', None),
-        min_duration=config.get('min_duration', None),
-        max_utts=config.get('max_utts', 0),
-        trim=config.get('trim_silence', False),
-        use_start_end_token=config.get('use_start_end_token', True),
-        shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
-        global_rank=global_rank,
-        world_size=world_size,
-    )
-    return dataset
+    tarred_audio_filepaths = config['tarred_audio_filepaths']
+    manifest_filepaths = config['manifest_filepath']
+    datasets = []
+    tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
+    manifest_filepaths = convert_to_config_list(manifest_filepaths)
+
+    if len(manifest_filepaths) != len(tarred_audio_filepaths):
+        raise ValueError(f"manifest_filepaths and tarred_audio_filepaths need to have the same number of buckets.")
+
+    for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
+        zip(tarred_audio_filepaths, manifest_filepaths)
+    ):
+        if len(tarred_audio_filepath) == 1:
+            tarred_audio_filepath = tarred_audio_filepath[0]
+        if tokenizer is None:
+            dataset = audio_to_text.TarredAudioToCharDataset(
+                audio_tar_filepaths=tarred_audio_filepath,
+                manifest_filepath=manifest_filepath,
+                labels=config['labels'],
+                sample_rate=config['sample_rate'],
+                int_values=config.get('int_values', False),
+                augmentor=augmentor,
+                shuffle_n=shuffle_n,
+                max_duration=config.get('max_duration', None),
+                min_duration=config.get('min_duration', None),
+                max_utts=config.get('max_utts', 0),
+                blank_index=config.get('blank_index', -1),
+                unk_index=config.get('unk_index', -1),
+                normalize=config.get('normalize_transcripts', False),
+                trim=config.get('trim_silence', False),
+                parser=config.get('parser', 'en'),
+                shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
+                global_rank=global_rank,
+                world_size=world_size,
+            )
+        else:
+            dataset = audio_to_text.TarredAudioToBPEDataset(
+                audio_tar_filepaths=tarred_audio_filepath,
+                manifest_filepath=manifest_filepath,
+                tokenizer=tokenizer,
+                sample_rate=config['sample_rate'],
+                int_values=config.get('int_values', False),
+                augmentor=augmentor,
+                shuffle_n=shuffle_n,
+                max_duration=config.get('max_duration', None),
+                min_duration=config.get('min_duration', None),
+                max_utts=config.get('max_utts', 0),
+                trim=config.get('trim_silence', False),
+                use_start_end_token=config.get('use_start_end_token', True),
+                shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
+                global_rank=global_rank,
+                world_size=world_size,
+            )
+
+        datasets.append(dataset)
+
+    if len(datasets) > 1:
+        return ChainDataset(datasets)
+    else:
+        return datasets[0]


 def get_dali_char_dataset(
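
With bucketing, both `manifest_filepath` and `tarred_audio_filepaths` may be lists of lists, one inner list per duration bucket; `get_tarred_dataset` then builds one tarred dataset per bucket and chains them together. A minimal sketch of a bucketed call (the config keys match the hunk above, but the paths, label set, and values are illustrative, not taken from this commit):

    config = {
        'manifest_filepath': [
            ['tarred/bucket1/tarred_audio_manifest.json'],  # shortest utterances
            ['tarred/bucket2/tarred_audio_manifest.json'],  # longest utterances
        ],
        'tarred_audio_filepaths': [
            ['tarred/bucket1/audio_0.tar', 'tarred/bucket1/audio_1.tar'],
            ['tarred/bucket2/audio_0.tar', 'tarred/bucket2/audio_1.tar'],
        ],
        'labels': [' ', 'a', 'b', 'c'],  # truncated alphabet, for the sketch only
        'sample_rate': 16000,
    }
    # tokenizer=None selects TarredAudioToCharDataset; a TokenizerSpec selects the BPE dataset.
    dataset = get_tarred_dataset(config=config, shuffle_n=128, global_rank=0, world_size=1)
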
@@ -292,3 +297,19 @@ def get_dali_bpe_dataset(
         preprocessor_cfg=preprocessor_cfg,
     )
     return dataset
+
+
+def convert_to_config_list(initial_list):
+    if initial_list is None or initial_list == []:
+        raise ValueError("manifest_filepaths and tarred_audio_filepaths must not be empty.")
+    if not isinstance(initial_list, ListConfig):
+        initial_list = ListConfig([initial_list])
+
+    for list_idx, list_val in enumerate(initial_list):
+        if type(list_val) != type(initial_list[0]):
+            raise ValueError(
+                "manifest_filepaths and tarred_audio_filepaths need to be a list of lists for bucketing or just a list of strings"
+            )
+    if type(initial_list[0]) is not ListConfig:
+        initial_list = ListConfig([initial_list])
+    return initial_list
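
For intuition, the normalization applied by `convert_to_config_list` above, traced by hand (return values are ListConfigs, shown here as plain lists; the file names are illustrative and `ListConfig` comes from omegaconf):

    from omegaconf import ListConfig

    convert_to_config_list('a.json')                  # -> [['a.json']]               one bucket
    convert_to_config_list(['a.json', 'b.json'])      # -> [['a.json', 'b.json']]     still one bucket
    convert_to_config_list([['a.json'], ['b.json']])  # -> [['a.json'], ['b.json']]   two buckets
    convert_to_config_list(ListConfig(['a.json', ['b.json']]))  # -> ValueError (mixed element types)
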

@@ -18,6 +18,7 @@ from typing import Dict, Optional

 import torch
 from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
+from torch.utils.data import ChainDataset

 from nemo.collections.asr.data import audio_to_text_dataset
 from nemo.collections.asr.losses.ctc import CTCLoss
@@ -221,7 +222,7 @@ class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin):
             return None

         shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
-        dataset = audio_to_text_dataset.get_tarred_bpe_dataset(
+        dataset = audio_to_text_dataset.get_tarred_dataset(
             config=config,
             tokenizer=self.tokenizer,
             shuffle_n=shuffle_n,
@@ -238,11 +239,15 @@ class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin):
         dataset = audio_to_text_dataset.get_bpe_dataset(
             config=config, tokenizer=self.tokenizer, augmentor=augmentor
         )
+        if type(dataset) is ChainDataset:
+            collate_fn = dataset.datasets[0].collate_fn
+        else:
+            collate_fn = dataset.collate_fn

         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=config['batch_size'],
-            collate_fn=dataset.collate_fn,
+            collate_fn=collate_fn,
             drop_last=config.get('drop_last', False),
             shuffle=shuffle,
             num_workers=config.get('num_workers', 0),
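
The collate_fn indirection above recurs in each model below: `torch.utils.data.ChainDataset` merely concatenates iterable datasets and does not expose the `collate_fn` its members define, so the data loader borrows it from the first chained (bucket) dataset. A standalone toy sketch of the pattern:

    import torch
    from torch.utils.data import ChainDataset, IterableDataset

    class ToyDataset(IterableDataset):
        def __iter__(self):
            return iter(range(4))

        def collate_fn(self, batch):
            return torch.tensor(batch)

    chained = ChainDataset([ToyDataset(), ToyDataset()])
    # ChainDataset has no collate_fn of its own; borrow it from the first sub-dataset.
    collate_fn = chained.datasets[0].collate_fn
    loader = torch.utils.data.DataLoader(chained, batch_size=2, collate_fn=collate_fn)
    print([batch.tolist() for batch in loader])  # [[0, 1], [2, 3], [0, 1], [2, 3]]
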
@@ -21,6 +21,7 @@ from typing import Dict, List, Optional, Union
 import torch
 from omegaconf import DictConfig, OmegaConf, open_dict
 from pytorch_lightning import Trainer
+from torch.utils.data import ChainDataset
 from tqdm.auto import tqdm

 from nemo.collections.asr.data import audio_to_text_dataset
@@ -382,7 +383,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
             return None

         shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
-        dataset = audio_to_text_dataset.get_tarred_char_dataset(
+        dataset = audio_to_text_dataset.get_tarred_dataset(
             config=config,
             shuffle_n=shuffle_n,
             global_rank=self.global_rank,
@@ -397,10 +398,15 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):

         dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)

+        if type(dataset) is ChainDataset:
+            collate_fn = dataset.datasets[0].collate_fn
+        else:
+            collate_fn = dataset.collate_fn
+
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=config['batch_size'],
-            collate_fn=dataset.collate_fn,
+            collate_fn=collate_fn,
             drop_last=config.get('drop_last', False),
             shuffle=shuffle,
             num_workers=config.get('num_workers', 0),

@@ -19,6 +19,7 @@ from typing import Dict, List, Optional
 import torch
 from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
 from pytorch_lightning import Trainer
+from torch.utils.data import ChainDataset

 from nemo.collections.asr.data import audio_to_text_dataset
 from nemo.collections.asr.losses.rnnt import RNNTLoss
@@ -299,7 +300,7 @@ class EncDecRNNTBPEModel(EncDecRNNTModel, ASRBPEMixin):
             return None

         shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
-        dataset = audio_to_text_dataset.get_tarred_bpe_dataset(
+        dataset = audio_to_text_dataset.get_tarred_dataset(
             config=config,
             tokenizer=self.tokenizer,
             shuffle_n=shuffle_n,
@@ -317,10 +318,15 @@ class EncDecRNNTBPEModel(EncDecRNNTModel, ASRBPEMixin):
             config=config, tokenizer=self.tokenizer, augmentor=augmentor
         )

+        if type(dataset) is ChainDataset:
+            collate_fn = dataset.datasets[0].collate_fn
+        else:
+            collate_fn = dataset.collate_fn
+
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=config['batch_size'],
-            collate_fn=dataset.collate_fn,
+            collate_fn=collate_fn,
             drop_last=config.get('drop_last', False),
             shuffle=shuffle,
             num_workers=config.get('num_workers', 0),

@@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Union
 import torch
 from omegaconf import DictConfig, OmegaConf, open_dict
 from pytorch_lightning import Trainer
+from torch.utils.data import ChainDataset
 from tqdm.auto import tqdm

 from nemo.collections.asr.data import audio_to_text_dataset
@@ -437,7 +438,7 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel):
             return None

         shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
-        dataset = audio_to_text_dataset.get_tarred_char_dataset(
+        dataset = audio_to_text_dataset.get_tarred_dataset(
             config=config,
             shuffle_n=shuffle_n,
             global_rank=self.global_rank,
@@ -452,10 +453,15 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel):

         dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)

+        if type(dataset) is ChainDataset:
+            collate_fn = dataset.datasets[0].collate_fn
+        else:
+            collate_fn = dataset.collate_fn
+
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=config['batch_size'],
-            collate_fn=dataset.collate_fn,
+            collate_fn=collate_fn,
             drop_last=config.get('drop_last', False),
             shuffle=shuffle,
             num_workers=config.get('num_workers', 0),

@@ -26,6 +26,7 @@ import torch
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import Trainer
 from torch import nn
+from torch.utils.data import ChainDataset

 from nemo.collections.asr.data import audio_to_text_dataset
 from nemo.collections.asr.losses.wav2vecloss import Wav2VecLoss
@@ -478,7 +479,7 @@ class Wav2VecEncoderModel(ModelPT):
             return None

         shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
-        dataset = audio_to_text_dataset.get_tarred_char_dataset(
+        dataset = audio_to_text_dataset.get_tarred_dataset(
             config=config,
             shuffle_n=shuffle_n,
             global_rank=self.global_rank,
@@ -493,10 +494,15 @@ class Wav2VecEncoderModel(ModelPT):

         dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)

+        if type(dataset) is ChainDataset:
+            collate_fn = dataset.datasets[0].collate_fn
+        else:
+            collate_fn = dataset.collate_fn
+
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=config['batch_size'],
-            collate_fn=dataset.collate_fn,
+            collate_fn=collate_fn,
             drop_last=config.get('drop_last', False),
             shuffle=shuffle,
             num_workers=config.get('num_workers', 0),

@@ -20,6 +20,12 @@
 # Because we will use it to handle files which have duplicate filenames but with different offsets
 # (see function create_shard for details)

+# It is recommended to use --sort_in_shards to speed up training by reducing padding within batches.
+
+# Bucketing can also help to improve the training speed. You may use --buckets_num to specify the number of buckets.
+# It creates multiple tarred datasets, one per bucket, based on the audio durations.
+# The range of [min_duration, max_duration) is split into equal-sized buckets.
+
 # Usage:
 1) Creating a new tarfile dataset
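
A worked example of the bucket boundaries described above (plain arithmetic that mirrors the loop added to main() further down in this diff; the numbers are illustrative):

    min_duration, max_duration, buckets_num = 0.1, 16.1, 4
    bucket_length = (max_duration - min_duration) / float(buckets_num)  # 4.0
    for i in range(buckets_num):
        lo = min_duration + i * bucket_length
        hi = lo + bucket_length
        if i == buckets_num - 1:
            hi += 1e-5  # keep samples whose duration is exactly max_duration
        print(f"bucket{i + 1}: [{lo:.1f}, {hi:.1f})")
    # bucket1: [0.1, 4.1)  bucket2: [4.1, 8.1)  bucket3: [8.1, 12.1)  bucket4: [12.1, 16.1)
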
@@ -30,6 +36,7 @@ python convert_to_tarred_audio_dataset.py \
     --max_duration=<float representing maximum duration of audio samples> \
     --min_duration=<float representing minimum duration of audio samples> \
     --shuffle --shuffle_seed=1 \
+    --sort_in_shards

 2) Concatenating more tarfiles to a pre-existing tarred dataset

@@ -41,8 +48,9 @@ python convert_to_tarred_audio_dataset.py \
     --max_duration=<float representing maximum duration of audio samples> \
     --min_duration=<float representing minimum duration of audio samples> \
     --shuffle --shuffle_seed=1 \
+    --sort_in_shards \
     --concat_manifest_paths \
-    <space seperated paths to 1 or more manifest files to concatenate into the original tarred dataset>
+    <space separated paths to 1 or more manifest files to concatenate into the original tarred dataset>

 3) Writing an empty metadata file
@@ -53,6 +61,7 @@ python convert_to_tarred_audio_dataset.py \
     --max_duration=16.7 \
     --min_duration=0.01 \
     --shuffle \
+    --sort_in_shards \
     --shuffle_seed=1 \
     --write_metadata
@@ -122,7 +131,18 @@ parser.add_argument(
     action='store_true',
     help="Whether or not to randomly shuffle the samples in the manifest before tarring/sharding.",
 )
-parser.add_argument("--shuffle_seed", type=int, help="Random seed for use if shuffling is enabled.")
+
+parser.add_argument(
+    "--sort_in_shards",
+    action='store_true',
+    help="Whether or not to sort samples inside the shards based on their duration.",
+)
+
+parser.add_argument(
+    "--buckets_num", type=int, default=1, help="Number of buckets to create based on duration.",
+)
+
+parser.add_argument("--shuffle_seed", type=int, default=None, help="Random seed for use if shuffling is enabled.")
 parser.add_argument(
     '--write_metadata',
     action='store_true',
@@ -143,6 +163,7 @@ class ASRTarredDatasetConfig:
     max_duration: Optional[float] = None
     min_duration: Optional[float] = None
     shuffle_seed: Optional[int] = None
+    sort_in_shards: bool = True


 @dataclass
@@ -225,6 +246,9 @@ class ASRTarredDatasetBuilder:
         if len(filtered_entries) > 0:
             print(f"Filtered {len(filtered_entries)} files which amounts to {filtered_duration} seconds of audio.")

+        if len(entries) == 0:
+            print("No tarred dataset was created as there were 0 valid samples after filtering!")
+            return
         if config.shuffle:
             random.seed(config.shuffle_seed)
             print("Shuffling...")
@@ -248,10 +272,12 @@ class ASRTarredDatasetBuilder:
             start_indices.append(start_idx)
             end_indices.append(end_idx)

+        manifest_folder, _ = os.path.split(manifest_path)
+
         with Parallel(n_jobs=num_workers, verbose=config.num_shards) as parallel:
             # Call parallel tarfile construction
             new_entries_list = parallel(
-                delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i)
+                delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i, manifest_folder)
                 for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices))
             )
@@ -348,6 +374,10 @@ class ASRTarredDatasetBuilder:

             entries.extend(new_entries)

+        if len(entries) == 0:
+            print("No tarred dataset was created as there were 0 valid samples after filtering!")
+            return
+
         if config.shuffle:
             random.seed(config.shuffle_seed)
             print("Shuffling...")
@@ -384,10 +414,12 @@ class ASRTarredDatasetBuilder:
             end_indices.append(end_idx)
             shard_indices.append(shard_idx)

+        manifest_folder, _ = os.path.split(base_manifest_path)
+
         with Parallel(n_jobs=num_workers, verbose=num_added_shards) as parallel:
             # Call parallel tarfile construction
             new_entries_list = parallel(
-                delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, shard_idx)
+                delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, shard_idx, manifest_folder)
                 for i, (start_idx, end_idx, shard_idx) in enumerate(zip(start_indices, end_indices, shard_indices))
             )
@@ -450,7 +482,7 @@ class ASRTarredDatasetBuilder:
             for line in m:
                 entry = json.loads(line)
                 if (config.max_duration is None or entry['duration'] < config.max_duration) and (
-                    config.min_duration is None or entry['duration'] > config.min_duration
+                    config.min_duration is None or entry['duration'] >= config.min_duration
                 ):
                     entries.append(entry)
                 else:
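
The comparison change above makes the duration filter half-open, [min_duration, max_duration), so a clip lasting exactly min_duration is now kept. A small check with illustrative values:

    min_duration, max_duration = 0.5, 16.7

    def keep(d):
        return (max_duration is None or d < max_duration) and (min_duration is None or d >= min_duration)

    print(keep(0.5))   # True  (was False with the old strict '>' comparison)
    print(keep(16.7))  # False (this is why the last bucket widens max_duration by 1e-5)
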
@@ -459,22 +491,32 @@ class ASRTarredDatasetBuilder:

         return entries, filtered_entries, filtered_duration

-    def _create_shard(self, entries, target_dir, shard_id):
+    def _create_shard(self, entries, target_dir, shard_id, manifest_folder):
         """Creates a tarball containing the audio files from `entries`.
         """
+        if self.config.sort_in_shards:
+            entries.sort(key=lambda x: x["duration"], reverse=False)
+
         new_entries = []
         tar = tarfile.open(os.path.join(target_dir, f'audio_{shard_id}.tar'), mode='w', dereference=True)

         count = dict()
         for entry in entries:
             # We squash the filename since we do not preserve directory structure of audio files in the tarball.
-            base, ext = os.path.splitext(entry['audio_filepath'])
+            if os.path.exists(entry["audio_filepath"]):
+                audio_filepath = entry["audio_filepath"]
+            else:
+                audio_filepath = os.path.join(manifest_folder, entry["audio_filepath"])
+                if not os.path.exists(audio_filepath):
+                    raise FileNotFoundError(f"Could not find {entry['audio_filepath']}!")
+
+            base, ext = os.path.splitext(audio_filepath)
             base = base.replace('/', '_')
             # Need the following replacement as long as WebDataset splits on first period
             base = base.replace('.', '_')
             squashed_filename = f'{base}{ext}'
             if squashed_filename not in count:
-                tar.add(entry['audio_filepath'], arcname=squashed_filename)
+                tar.add(audio_filepath, arcname=squashed_filename)

             if 'label' in entry:
                 base, ext = os.path.splitext(squashed_filename)
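
The new `manifest_folder` argument lets manifest entries reference audio by a path relative to the manifest's own folder rather than the working directory. A standalone sketch of the fallback added above (the function name is hypothetical, not part of this commit):

    import os

    def resolve_audio_path(audio_filepath: str, manifest_folder: str) -> str:
        # Use the path as-is if it resolves (absolute, or relative to the CWD);
        # otherwise interpret it relative to the folder containing the manifest.
        if os.path.exists(audio_filepath):
            return audio_filepath
        candidate = os.path.join(manifest_folder, audio_filepath)
        if not os.path.exists(candidate):
            raise FileNotFoundError(f"Could not find {audio_filepath}!")
        return candidate
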
@@ -522,28 +564,35 @@ class ASRTarredDatasetBuilder:


 def main():
-    manifest_path = args.manifest_path
-    concat_manifest_paths = args.concat_manifest_paths
-    target_dir = args.target_dir
-    metadata_path = args.metadata_path
-    num_shards = args.num_shards
-    max_duration = args.max_duration
-    min_duration = args.min_duration
-    shuffle = args.shuffle
-    seed = args.shuffle_seed if args.shuffle_seed else None
-    write_metadata = args.write_metadata
-    num_workers = args.workers
+    if args.buckets_num > 1:
+        bucket_length = (args.max_duration - args.min_duration) / float(args.buckets_num)
+        for i in range(args.buckets_num):
+            min_duration = args.min_duration + i * bucket_length
+            max_duration = min_duration + bucket_length
+            if i == args.buckets_num - 1:
+                # add a small number to cover the samples with exactly duration of max_duration in the last bucket.
+                max_duration += 1e-5
+            target_dir = os.path.join(args.target_dir, f"bucket{i+1}")
+            print(f"Creating bucket {i+1} with min_duration={min_duration} and max_duration={max_duration} ...")
+            print(f"Results are being saved at: {target_dir}.")
+            create_tar_datasets(min_duration=min_duration, max_duration=max_duration, target_dir=target_dir)
+            print(f"Bucket {i+1} is created.")
+    else:
+        create_tar_datasets(min_duration=args.min_duration, max_duration=args.max_duration, target_dir=args.target_dir)
+
+
+def create_tar_datasets(min_duration: float, max_duration: float, target_dir: str):
     builder = ASRTarredDatasetBuilder()

-    if write_metadata:
+    if args.write_metadata:
         metadata = ASRTarredDatasetMetadata()
         dataset_cfg = ASRTarredDatasetConfig(
-            num_shards=num_shards,
-            shuffle=shuffle,
+            num_shards=args.num_shards,
+            shuffle=args.shuffle,
             max_duration=max_duration,
             min_duration=min_duration,
-            shuffle_seed=seed,
+            shuffle_seed=args.shuffle_seed,
+            sort_in_shards=args.sort_in_shards,
         )
         metadata.dataset_config = dataset_cfg
@@ -552,26 +601,29 @@ def main():
         print(f"Default metadata written to {output_path}")
         exit(0)

-    if concat_manifest_paths is None or len(concat_manifest_paths) == 0:
+    if args.concat_manifest_paths is None or len(args.concat_manifest_paths) == 0:
         print("Creating new tarred dataset ...")

         # Create a tarred dataset from scratch
         config = ASRTarredDatasetConfig(
-            num_shards=num_shards,
-            shuffle=shuffle,
+            num_shards=args.num_shards,
+            shuffle=args.shuffle,
             max_duration=max_duration,
             min_duration=min_duration,
-            shuffle_seed=seed,
+            shuffle_seed=args.shuffle_seed,
+            sort_in_shards=args.sort_in_shards,
         )
         builder.configure(config)
-        builder.create_new_dataset(manifest_path=manifest_path, target_dir=target_dir, num_workers=num_workers)
+        builder.create_new_dataset(manifest_path=args.manifest_path, target_dir=target_dir, num_workers=args.workers)

     else:
+        if args.buckets_num > 1:
+            raise ValueError("Concatenation feature does not support buckets_num > 1.")
         print("Concatenating multiple tarred datasets ...")

         # Implicitly update config from base details
-        if metadata_path is not None:
-            metadata = ASRTarredDatasetMetadata.from_file(metadata_path)
+        if args.metadata_path is not None:
+            metadata = ASRTarredDatasetMetadata.from_file(args.metadata_path)
         else:
             raise ValueError("`metadata` yaml file path must be provided!")
@@ -583,18 +635,19 @@ def main():
     # Add command line overrides (everything other than num_shards)
     metadata.dataset_config.max_duration = max_duration
     metadata.dataset_config.min_duration = min_duration
-    metadata.dataset_config.shuffle = shuffle
-    metadata.dataset_config.shuffle_seed = seed
+    metadata.dataset_config.shuffle = args.shuffle
+    metadata.dataset_config.shuffle_seed = args.shuffle_seed
+    metadata.dataset_config.sort_in_shards = args.sort_in_shards

     builder.configure(metadata.dataset_config)

     # Concatenate a tarred dataset onto a previous one
     builder.create_concatenated_dataset(
-        base_manifest_path=manifest_path,
-        manifest_paths=concat_manifest_paths,
+        base_manifest_path=args.manifest_path,
+        manifest_paths=args.concat_manifest_paths,
         metadata=metadata,
         target_dir=target_dir,
-        num_workers=num_workers,
+        num_workers=args.workers,
     )