Adding bucketing for ASR models with tarred datasets (#2990)

Vahid Noroozi 2021-10-13 21:40:50 -07:00 committed by GitHub
parent 0e5a1f84c4
commit a07dfa9d78
8 changed files with 217 additions and 114 deletions


@@ -960,7 +960,7 @@ class _TarredAudioToTextDataset(IterableDataset):
def __init__(
self,
audio_tar_filepaths: Union[str, List[str]],
manifest_filepath: str,
manifest_filepath: Union[str, List[str]],
parser: Callable,
sample_rate: int,
int_values: bool = False,
@@ -978,7 +978,7 @@ class _TarredAudioToTextDataset(IterableDataset):
world_size: int = 0,
):
self.collection = collections.ASRAudioText(
manifests_files=manifest_filepath.split(','),
manifests_files=manifest_filepath,
parser=parser,
min_duration=min_duration,
max_duration=max_duration,
@@ -1336,7 +1336,7 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset):
def __init__(
self,
audio_tar_filepaths: Union[str, List[str]],
manifest_filepath: str,
manifest_filepath: Union[str, List[str]],
tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec',
sample_rate: int,
int_values: bool = False,
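
In the hunks above, `manifest_filepath` is now typed `Union[str, List[str]]`, and the former `manifest_filepath.split(',')` is gone, so splitting (or bucket grouping) happens in the caller rather than inside the dataset. A minimal sketch of the difference, with illustrative paths:

    # Before this commit: the dataset split a comma-separated string itself.
    manifest_filepath = "train_part1.json,train_part2.json"
    manifests_files = manifest_filepath.split(',')   # ['train_part1.json', 'train_part2.json']

    # After: a single path or an already-split list is passed through unchanged.
    manifests_files = ["train_part1.json", "train_part2.json"]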


@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Union
import torch
from omegaconf import DictConfig, open_dict
from omegaconf.listconfig import ListConfig
from torch.utils.data import ChainDataset
from nemo.collections.asr.data import audio_to_text, audio_to_text_dali
from nemo.utils import logging
@@ -122,60 +124,21 @@ def get_bpe_dataset(
return dataset
def get_tarred_char_dataset(
config: dict, shuffle_n: int, global_rank: int, world_size: int, augmentor: Optional['AudioAugmentor'] = None
) -> audio_to_text.TarredAudioToCharDataset:
"""
Instantiates a Character Encoding based TarredAudioToCharDataset.
Args:
config: Config of the TarredAudioToCharDataset.
shuffle_n: How many samples to look ahead and load to be shuffled.
See WebDataset documentation for more details.
global_rank: Global rank of this device.
world_size: Global world size in the training method.
augmentor: Optional AudioAugmentor object for augmentations on audio data.
Returns:
An instance of TarredAudioToCharDataset.
"""
dataset = audio_to_text.TarredAudioToCharDataset(
audio_tar_filepaths=config['tarred_audio_filepaths'],
manifest_filepath=config['manifest_filepath'],
labels=config['labels'],
sample_rate=config['sample_rate'],
int_values=config.get('int_values', False),
augmentor=augmentor,
shuffle_n=shuffle_n,
max_duration=config.get('max_duration', None),
min_duration=config.get('min_duration', None),
max_utts=config.get('max_utts', 0),
blank_index=config.get('blank_index', -1),
unk_index=config.get('unk_index', -1),
normalize=config.get('normalize_transcripts', False),
trim=config.get('trim_silence', False),
parser=config.get('parser', 'en'),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
global_rank=global_rank,
world_size=world_size,
)
return dataset
def get_tarred_bpe_dataset(
def get_tarred_dataset(
config: dict,
tokenizer: 'TokenizerSpec',
shuffle_n: int,
global_rank: int,
world_size: int,
tokenizer: Optional['TokenizerSpec'] = None,
augmentor: Optional['AudioAugmentor'] = None,
) -> audio_to_text.TarredAudioToBPEDataset:
) -> Union[audio_to_text.TarredAudioToBPEDataset, audio_to_text.TarredAudioToCharDataset]:
"""
Instantiates a Byte Pair Encoding / Word Piece Encoding based TarredAudioToBPEDataset.
Instantiates a Word Piece/BPE Encoding based TarredAudioToBPEDataset or a char based TarredAudioToCharDataset.
Args:
config: Config of the TarredAudioToBPEDataset.
tokenizer: An instance of a TokenizerSpec object.
config: Config of the TarredAudioToBPEDataset or TarredAudioToCharDataset.
tokenizer: An instance of a TokenizerSpec object if a BPE dataset is needed.
Passing None returns a char-based dataset.
shuffle_n: How many samples to look ahead and load to be shuffled.
See WebDataset documentation for more details.
global_rank: Global rank of this device.
@@ -183,26 +146,68 @@ def get_tarred_bpe_dataset(
augmentor: Optional AudioAugmentor object for augmentations on audio data.
Returns:
An instance of TarredAudioToBPEDataset.
An instance of TarredAudioToBPEDataset or TarredAudioToCharDataset.
"""
dataset = audio_to_text.TarredAudioToBPEDataset(
audio_tar_filepaths=config['tarred_audio_filepaths'],
manifest_filepath=config['manifest_filepath'],
tokenizer=tokenizer,
sample_rate=config['sample_rate'],
int_values=config.get('int_values', False),
augmentor=augmentor,
shuffle_n=shuffle_n,
max_duration=config.get('max_duration', None),
min_duration=config.get('min_duration', None),
max_utts=config.get('max_utts', 0),
trim=config.get('trim_silence', False),
use_start_end_token=config.get('use_start_end_token', True),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
global_rank=global_rank,
world_size=world_size,
)
return dataset
tarred_audio_filepaths = config['tarred_audio_filepaths']
manifest_filepaths = config['manifest_filepath']
datasets = []
tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
manifest_filepaths = convert_to_config_list(manifest_filepaths)
if len(manifest_filepaths) != len(tarred_audio_filepaths):
raise ValueError("manifest_filepaths and tarred_audio_filepaths need to have the same number of buckets.")
for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
zip(tarred_audio_filepaths, manifest_filepaths)
):
if len(tarred_audio_filepath) == 1:
tarred_audio_filepath = tarred_audio_filepath[0]
if tokenizer is None:
dataset = audio_to_text.TarredAudioToCharDataset(
audio_tar_filepaths=tarred_audio_filepath,
manifest_filepath=manifest_filepath,
labels=config['labels'],
sample_rate=config['sample_rate'],
int_values=config.get('int_values', False),
augmentor=augmentor,
shuffle_n=shuffle_n,
max_duration=config.get('max_duration', None),
min_duration=config.get('min_duration', None),
max_utts=config.get('max_utts', 0),
blank_index=config.get('blank_index', -1),
unk_index=config.get('unk_index', -1),
normalize=config.get('normalize_transcripts', False),
trim=config.get('trim_silence', False),
parser=config.get('parser', 'en'),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
global_rank=global_rank,
world_size=world_size,
)
else:
dataset = audio_to_text.TarredAudioToBPEDataset(
audio_tar_filepaths=tarred_audio_filepath,
manifest_filepath=manifest_filepath,
tokenizer=tokenizer,
sample_rate=config['sample_rate'],
int_values=config.get('int_values', False),
augmentor=augmentor,
shuffle_n=shuffle_n,
max_duration=config.get('max_duration', None),
min_duration=config.get('min_duration', None),
max_utts=config.get('max_utts', 0),
trim=config.get('trim_silence', False),
use_start_end_token=config.get('use_start_end_token', True),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
global_rank=global_rank,
world_size=world_size,
)
datasets.append(dataset)
if len(datasets) > 1:
return ChainDataset(datasets)
else:
return datasets[0]
def get_dali_char_dataset(
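
For orientation, a hypothetical call to the new `get_tarred_dataset` with two duration buckets; the paths, label set, and sample rate below are illustrative, and the tar/manifest files would have to exist for the dataset to actually iterate:

    from omegaconf import ListConfig
    from nemo.collections.asr.data.audio_to_text_dataset import get_tarred_dataset

    config = {
        # One inner list per bucket; convert_to_config_list (below) normalizes the shape.
        'manifest_filepath': ListConfig([['bucket1/tarred_audio_manifest.json'],
                                         ['bucket2/tarred_audio_manifest.json']]),
        'tarred_audio_filepaths': ListConfig([['bucket1/audio__OP_0..63_CL_.tar'],
                                              ['bucket2/audio__OP_0..63_CL_.tar']]),
        'labels': [' ', 'a', 'b', 'c'],
        'sample_rate': 16000,
    }

    # tokenizer=None selects the char branch (TarredAudioToCharDataset); with more
    # than one bucket the per-bucket datasets are wrapped in a torch ChainDataset.
    dataset = get_tarred_dataset(config=config, shuffle_n=0, global_rank=0, world_size=1)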
@@ -292,3 +297,19 @@ def get_dali_bpe_dataset(
preprocessor_cfg=preprocessor_cfg,
)
return dataset
def convert_to_config_list(initial_list):
if initial_list is None or initial_list == []:
raise ValueError("manifest_filepaths and tarred_audio_filepaths must not be empty.")
if not isinstance(initial_list, ListConfig):
initial_list = ListConfig([initial_list])
for list_idx, list_val in enumerate(initial_list):
if type(list_val) != type(initial_list[0]):
raise ValueError(
"manifest_filepaths and tarred_audio_filepaths need to be a list of lists for bucketing or just a list of strings"
)
if type(initial_list[0]) is not ListConfig:
initial_list = ListConfig([initial_list])
return initial_list
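
The shape normalization performed by `convert_to_config_list` can be summarized with a small sketch (assuming the function above is in scope; outputs shown as comments):

    from omegaconf import ListConfig

    print(convert_to_config_list('manifest.json'))
    # [['manifest.json']] -> one bucket containing one manifest

    print(convert_to_config_list(ListConfig(['a.json', 'b.json'])))
    # [['a.json', 'b.json']] -> still one bucket, with two manifests

    print(convert_to_config_list(ListConfig([['a.json'], ['b.json']])))
    # [['a.json'], ['b.json']] -> two buckets, one manifest each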


@@ -18,6 +18,7 @@ from typing import Dict, Optional
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from torch.utils.data import ChainDataset
from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.losses.ctc import CTCLoss
@@ -221,7 +222,7 @@ class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin):
return None
shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
dataset = audio_to_text_dataset.get_tarred_bpe_dataset(
dataset = audio_to_text_dataset.get_tarred_dataset(
config=config,
tokenizer=self.tokenizer,
shuffle_n=shuffle_n,
@@ -238,11 +239,15 @@ class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin):
dataset = audio_to_text_dataset.get_bpe_dataset(
config=config, tokenizer=self.tokenizer, augmentor=augmentor
)
if type(dataset) is ChainDataset:
collate_fn = dataset.datasets[0].collate_fn
else:
collate_fn = dataset.collate_fn
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=config['batch_size'],
collate_fn=dataset.collate_fn,
collate_fn=collate_fn,
drop_last=config.get('drop_last', False),
shuffle=shuffle,
num_workers=config.get('num_workers', 0),
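
The `collate_fn` selection above (repeated in each model below) is needed because `torch.utils.data.ChainDataset` does not expose the `collate_fn` of its children, so the DataLoader has to be handed the collate function of one of the chained datasets explicitly. A self-contained illustration with dummy iterable datasets (all names here are invented for the demo):

    import torch
    from torch.utils.data import ChainDataset, DataLoader, IterableDataset

    class _Bucket(IterableDataset):
        """Stand-in for one tarred-bucket dataset."""
        def __init__(self, values):
            self.values = values
        def __iter__(self):
            return iter(self.values)
        def collate_fn(self, batch):
            # The real NeMo datasets pad variable-length audio here.
            return torch.tensor(batch)

    buckets = [_Bucket([1, 2, 3]), _Bucket([4, 5, 6])]
    dataset = ChainDataset(buckets) if len(buckets) > 1 else buckets[0]

    # ChainDataset keeps its children in .datasets; borrow the first child's collate_fn.
    collate_fn = dataset.datasets[0].collate_fn if type(dataset) is ChainDataset else dataset.collate_fn

    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
    print(next(iter(loader)))  # tensor([1, 2])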


@@ -21,6 +21,7 @@ from typing import Dict, List, Optional, Union
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import ChainDataset
from tqdm.auto import tqdm
from nemo.collections.asr.data import audio_to_text_dataset
@@ -382,7 +383,7 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
return None
shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
dataset = audio_to_text_dataset.get_tarred_char_dataset(
dataset = audio_to_text_dataset.get_tarred_dataset(
config=config,
shuffle_n=shuffle_n,
global_rank=self.global_rank,
@@ -397,10 +398,15 @@ class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)
if type(dataset) is ChainDataset:
collate_fn = dataset.datasets[0].collate_fn
else:
collate_fn = dataset.collate_fn
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=config['batch_size'],
collate_fn=dataset.collate_fn,
collate_fn=collate_fn,
drop_last=config.get('drop_last', False),
shuffle=shuffle,
num_workers=config.get('num_workers', 0),


@@ -19,6 +19,7 @@ from typing import Dict, List, Optional
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import ChainDataset
from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.losses.rnnt import RNNTLoss
@@ -299,7 +300,7 @@ class EncDecRNNTBPEModel(EncDecRNNTModel, ASRBPEMixin):
return None
shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
dataset = audio_to_text_dataset.get_tarred_bpe_dataset(
dataset = audio_to_text_dataset.get_tarred_dataset(
config=config,
tokenizer=self.tokenizer,
shuffle_n=shuffle_n,
@@ -317,10 +318,15 @@ class EncDecRNNTBPEModel(EncDecRNNTModel, ASRBPEMixin):
config=config, tokenizer=self.tokenizer, augmentor=augmentor
)
if type(dataset) is ChainDataset:
collate_fn = dataset.datasets[0].collate_fn
else:
collate_fn = dataset.collate_fn
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=config['batch_size'],
collate_fn=dataset.collate_fn,
collate_fn=collate_fn,
drop_last=config.get('drop_last', False),
shuffle=shuffle,
num_workers=config.get('num_workers', 0),


@@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Union
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import ChainDataset
from tqdm.auto import tqdm
from nemo.collections.asr.data import audio_to_text_dataset
@@ -437,7 +438,7 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel):
return None
shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
dataset = audio_to_text_dataset.get_tarred_char_dataset(
dataset = audio_to_text_dataset.get_tarred_dataset(
config=config,
shuffle_n=shuffle_n,
global_rank=self.global_rank,
@@ -452,10 +453,15 @@ class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecJointModel):
dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)
if type(dataset) is ChainDataset:
collate_fn = dataset.datasets[0].collate_fn
else:
collate_fn = dataset.collate_fn
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=config['batch_size'],
collate_fn=dataset.collate_fn,
collate_fn=collate_fn,
drop_last=config.get('drop_last', False),
shuffle=shuffle,
num_workers=config.get('num_workers', 0),


@@ -26,6 +26,7 @@ import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from torch import nn
from torch.utils.data import ChainDataset
from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.losses.wav2vecloss import Wav2VecLoss
@@ -478,7 +479,7 @@ class Wav2VecEncoderModel(ModelPT):
return None
shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
dataset = audio_to_text_dataset.get_tarred_char_dataset(
dataset = audio_to_text_dataset.get_tarred_dataset(
config=config,
shuffle_n=shuffle_n,
global_rank=self.global_rank,
@@ -493,10 +494,15 @@ class Wav2VecEncoderModel(ModelPT):
dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)
if type(dataset) is ChainDataset:
collate_fn = dataset.datasets[0].collate_fn
else:
collate_fn = dataset.collate_fn
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=config['batch_size'],
collate_fn=dataset.collate_fn,
collate_fn=collate_fn,
drop_last=config.get('drop_last', False),
shuffle=shuffle,
num_workers=config.get('num_workers', 0),


@@ -20,6 +20,12 @@
# Because we will use it to handle files which have duplicate filenames but with different offsets
# (see function create_shard for details)
# Recommend to use --sort_in_shards to speedup the training by reducing the paddings in the batches
# Bucketing can also help to improve the training speed. You may use --buckets_num to specify the number of buckets.
# It creates multiple tarred datasets, one per bucket, based on the audio durations.
# The range of [min_duration, max_duration) is split into equal-sized buckets.
# Usage:
1) Creating a new tarfile dataset
@@ -30,6 +36,7 @@ python convert_to_tarred_audio_dataset.py \
--max_duration=<float representing maximum duration of audio samples> \
--min_duration=<float representing minimum duration of audio samples> \
--shuffle --shuffle_seed=1 \
--sort_in_shards
2) Concatenating more tarfiles to a pre-existing tarred dataset
@@ -41,8 +48,9 @@ python convert_to_tarred_audio_dataset.py \
--max_duration=<float representing maximum duration of audio samples> \
--min_duration=<float representing minimum duration of audio samples> \
--shuffle --shuffle_seed=1 \
--sort_in_shards \
--concat_manifest_paths \
<space seperated paths to 1 or more manifest files to concatenate into the original tarred dataset>
<space separated paths to 1 or more manifest files to concatenate into the original tarred dataset>
3) Writing an empty metadata file
@@ -53,6 +61,7 @@ python convert_to_tarred_audio_dataset.py \
--max_duration=16.7 \
--min_duration=0.01 \
--shuffle \
--sort_in_shards \
--shuffle_seed=1 \
--write_metadata
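
Putting the new flags together, a bucketed run of this script might look as follows (an illustrative invocation; only --buckets_num and --sort_in_shards are new in this commit):

    python convert_to_tarred_audio_dataset.py \
        --manifest_path=<path to the manifest file> \
        --target_dir=<path to output directory> \
        --num_shards=<number of tarfiles that will contain the audio> \
        --max_duration=16.7 \
        --min_duration=0.01 \
        --buckets_num=4 \
        --sort_in_shards \
        --shuffle --shuffle_seed=1

With --buckets_num=4 the script writes four self-contained tarred datasets under <target_dir>/bucket1 through <target_dir>/bucket4, one per duration range (see main() below).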
@@ -122,7 +131,18 @@ parser.add_argument(
action='store_true',
help="Whether or not to randomly shuffle the samples in the manifest before tarring/sharding.",
)
parser.add_argument("--shuffle_seed", type=int, help="Random seed for use if shuffling is enabled.")
parser.add_argument(
"--sort_in_shards",
action='store_true',
help="Whether or not to sort samples inside the shards based on their duration.",
)
parser.add_argument(
"--buckets_num", type=int, default=1, help="Number of buckets to create based on duration.",
)
parser.add_argument("--shuffle_seed", type=int, default=None, help="Random seed for use if shuffling is enabled.")
parser.add_argument(
'--write_metadata',
action='store_true',
@@ -143,6 +163,7 @@ class ASRTarredDatasetConfig:
max_duration: Optional[float] = None
min_duration: Optional[float] = None
shuffle_seed: Optional[int] = None
sort_in_shards: bool = True
@dataclass
@@ -225,6 +246,9 @@ class ASRTarredDatasetBuilder:
if len(filtered_entries) > 0:
print(f"Filtered {len(filtered_entries)} files which amounts to {filtered_duration} seconds of audio.")
if len(entries) == 0:
print("No tarred dataset was created as there were 0 valid samples after filtering!")
return
if config.shuffle:
random.seed(config.shuffle_seed)
print("Shuffling...")
@@ -248,10 +272,12 @@ class ASRTarredDatasetBuilder:
start_indices.append(start_idx)
end_indices.append(end_idx)
manifest_folder, _ = os.path.split(manifest_path)
with Parallel(n_jobs=num_workers, verbose=config.num_shards) as parallel:
# Call parallel tarfile construction
new_entries_list = parallel(
delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i)
delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i, manifest_folder)
for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices))
)
@@ -348,6 +374,10 @@ class ASRTarredDatasetBuilder:
entries.extend(new_entries)
if len(entries) == 0:
print("No tarred dataset was created as there were 0 valid samples after filtering!")
return
if config.shuffle:
random.seed(config.shuffle_seed)
print("Shuffling...")
@@ -384,10 +414,12 @@ class ASRTarredDatasetBuilder:
end_indices.append(end_idx)
shard_indices.append(shard_idx)
manifest_folder, _ = os.path.split(base_manifest_path)
with Parallel(n_jobs=num_workers, verbose=num_added_shards) as parallel:
# Call parallel tarfile construction
new_entries_list = parallel(
delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, shard_idx)
delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, shard_idx, manifest_folder)
for i, (start_idx, end_idx, shard_idx) in enumerate(zip(start_indices, end_indices, shard_indices))
)
@@ -450,7 +482,7 @@ class ASRTarredDatasetBuilder:
for line in m:
entry = json.loads(line)
if (config.max_duration is None or entry['duration'] < config.max_duration) and (
config.min_duration is None or entry['duration'] > config.min_duration
config.min_duration is None or entry['duration'] >= config.min_duration
):
entries.append(entry)
else:
@@ -459,22 +491,32 @@ class ASRTarredDatasetBuilder:
return entries, filtered_entries, filtered_duration
def _create_shard(self, entries, target_dir, shard_id):
def _create_shard(self, entries, target_dir, shard_id, manifest_folder):
"""Creates a tarball containing the audio files from `entries`.
"""
if self.config.sort_in_shards:
entries.sort(key=lambda x: x["duration"], reverse=False)
new_entries = []
tar = tarfile.open(os.path.join(target_dir, f'audio_{shard_id}.tar'), mode='w', dereference=True)
count = dict()
for entry in entries:
# We squash the filename since we do not preserve directory structure of audio files in the tarball.
base, ext = os.path.splitext(entry['audio_filepath'])
if os.path.exists(entry["audio_filepath"]):
audio_filepath = entry["audio_filepath"]
else:
audio_filepath = os.path.join(manifest_folder, entry["audio_filepath"])
if not os.path.exists(audio_filepath):
raise FileNotFoundError(f"Could not find {entry['audio_filepath']}!")
base, ext = os.path.splitext(audio_filepath)
base = base.replace('/', '_')
# Need the following replacement as long as WebDataset splits on first period
base = base.replace('.', '_')
squashed_filename = f'{base}{ext}'
if squashed_filename not in count:
tar.add(entry['audio_filepath'], arcname=squashed_filename)
tar.add(audio_filepath, arcname=squashed_filename)
if 'label' in entry:
base, ext = os.path.splitext(squashed_filename)
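
Two details of `_create_shard` worth spelling out. First, audio paths that are not found as given are retried relative to the manifest's folder, which lets manifests carry paths relative to their own location. Second, the pre-existing squashing (context lines above) flattens directory separators and extra periods because WebDataset splits keys on the first period. A small sketch of the squashing, with an illustrative path:

    import os

    audio_filepath = '/data/train/wavs/sample.v2_001.wav'
    base, ext = os.path.splitext(audio_filepath)
    base = base.replace('/', '_').replace('.', '_')
    print(f'{base}{ext}')  # _data_train_wavs_sample_v2_001.wav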
@@ -522,28 +564,35 @@ class ASRTarredDatasetBuilder:
def main():
manifest_path = args.manifest_path
concat_manifest_paths = args.concat_manifest_paths
target_dir = args.target_dir
metadata_path = args.metadata_path
num_shards = args.num_shards
max_duration = args.max_duration
min_duration = args.min_duration
shuffle = args.shuffle
seed = args.shuffle_seed if args.shuffle_seed else None
write_metadata = args.write_metadata
num_workers = args.workers
if args.buckets_num > 1:
bucket_length = (args.max_duration - args.min_duration) / float(args.buckets_num)
for i in range(args.buckets_num):
min_duration = args.min_duration + i * bucket_length
max_duration = min_duration + bucket_length
if i == args.buckets_num - 1:
# add a small epsilon so that samples whose duration is exactly max_duration fall into the last bucket.
max_duration += 1e-5
target_dir = os.path.join(args.target_dir, f"bucket{i+1}")
print(f"Creating bucket {i+1} with min_duration={min_duration} and max_duration={max_duration} ...")
print(f"Results are being saved at: {target_dir}.")
create_tar_datasets(min_duration=min_duration, max_duration=max_duration, target_dir=target_dir)
print(f"Bucket {i+1} is created.")
else:
create_tar_datasets(min_duration=args.min_duration, max_duration=args.max_duration, target_dir=args.target_dir)
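
A worked example of the bucket split above, using the durations from usage example 3:

    min_d, max_d, buckets_num = 0.01, 16.7, 4
    bucket_length = (max_d - min_d) / float(buckets_num)   # 4.1725 seconds per bucket
    for i in range(buckets_num):
        lo = min_d + i * bucket_length
        hi = lo + bucket_length
        if i == buckets_num - 1:
            hi += 1e-5   # include samples whose duration is exactly max_duration
        print(f'bucket{i + 1}: [{lo:.4f}, {hi:.4f})')
    # bucket1: [0.0100, 4.1825)
    # bucket2: [4.1825, 8.3550)
    # bucket3: [8.3550, 12.5275)
    # bucket4: [12.5275, 16.7000)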
def create_tar_datasets(min_duration: float, max_duration: float, target_dir: str):
builder = ASRTarredDatasetBuilder()
if write_metadata:
if args.write_metadata:
metadata = ASRTarredDatasetMetadata()
dataset_cfg = ASRTarredDatasetConfig(
num_shards=num_shards,
shuffle=shuffle,
num_shards=args.num_shards,
shuffle=args.shuffle,
max_duration=max_duration,
min_duration=min_duration,
shuffle_seed=seed,
shuffle_seed=args.shuffle_seed,
sort_in_shards=args.sort_in_shards,
)
metadata.dataset_config = dataset_cfg
@@ -552,26 +601,29 @@ def main():
print(f"Default metadata written to {output_path}")
exit(0)
if concat_manifest_paths is None or len(concat_manifest_paths) == 0:
if args.concat_manifest_paths is None or len(args.concat_manifest_paths) == 0:
print("Creating new tarred dataset ...")
# Create a tarred dataset from scratch
config = ASRTarredDatasetConfig(
num_shards=num_shards,
shuffle=shuffle,
num_shards=args.num_shards,
shuffle=args.shuffle,
max_duration=max_duration,
min_duration=min_duration,
shuffle_seed=seed,
shuffle_seed=args.shuffle_seed,
sort_in_shards=args.sort_in_shards,
)
builder.configure(config)
builder.create_new_dataset(manifest_path=manifest_path, target_dir=target_dir, num_workers=num_workers)
builder.create_new_dataset(manifest_path=args.manifest_path, target_dir=target_dir, num_workers=args.workers)
else:
if args.buckets_num > 1:
raise ValueError("Concatenation feature does not support buckets_num > 1.")
print("Concatenating multiple tarred datasets ...")
# Implicitly update config from base details
if metadata_path is not None:
metadata = ASRTarredDatasetMetadata.from_file(metadata_path)
if args.metadata_path is not None:
metadata = ASRTarredDatasetMetadata.from_file(args.metadata_path)
else:
raise ValueError("`metadata` yaml file path must be provided!")
@@ -583,18 +635,19 @@ def main():
# Add command line overrides (everything other than num_shards)
metadata.dataset_config.max_duration = max_duration
metadata.dataset_config.min_duration = min_duration
metadata.dataset_config.shuffle = shuffle
metadata.dataset_config.shuffle_seed = seed
metadata.dataset_config.shuffle = args.shuffle
metadata.dataset_config.shuffle_seed = args.shuffle_seed
metadata.dataset_config.sort_in_shards = args.sort_in_shards
builder.configure(metadata.dataset_config)
# Concatenate a tarred dataset onto a previous one
builder.create_concatenated_dataset(
base_manifest_path=manifest_path,
manifest_paths=concat_manifest_paths,
base_manifest_path=args.manifest_path,
manifest_paths=args.concat_manifest_paths,
metadata=metadata,
target_dir=target_dir,
num_workers=num_workers,
num_workers=args.workers,
)