DeepLearningExamples/PyTorch/LanguageModeling/BERT/data/utils/create_mixed_dataset_ids.py

from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import logging
import os
import random
from io import open
import h5py
import numpy as np
from tqdm import tqdm, trange
import collections
import math
import multiprocessing as mp
import pickle
import json
"""
mixing hdf5 shards with each other
"""
def load_and_prepare(input_files, num_shards):
    """Scan the input hdf5 files, then split a random permutation of all
    instances into (at most) num_shards roughly equal chunks."""
    seq_len = None
    pred_len = None
    input_lengths = []
    for input_file in input_files:
        with h5py.File(input_file, 'r') as f:
            input_lengths.append(len(f['input_ids']))
            if seq_len is None:
                # Sequence length and max predictions per sequence are read
                # from the first shard and assumed identical across shards.
                seq_len = f['input_ids'].shape[1]
                pred_len = f['masked_lm_ids'].shape[1]
    assert isinstance(seq_len, int) and isinstance(pred_len, int)
    total_instances = sum(input_lengths)
    n_inst_per_file = math.ceil(total_instances * 1.0 / num_shards)
    # Shuffle the global instance indices, then cut the permutation into
    # consecutive chunks of n_inst_per_file (the last chunk may be shorter).
    permutation = np.random.permutation(total_instances)
    instance_indices = []
    for i in range(0, num_shards):
        start_pos = i * n_inst_per_file
        end_pos = min((i + 1) * n_inst_per_file, total_instances)
        instance_indices.append(permutation[start_pos:end_pos])
    return seq_len, pred_len, input_lengths, instance_indices
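
# A minimal worked example of the sharding arithmetic above (numbers are
# illustrative only): with total_instances = 10 and num_shards = 3,
# n_inst_per_file = ceil(10 / 3) = 4, so the shuffled indices are split into
# chunks of length 4, 4 and 2.
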
def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--input_files",
                        default=None,
                        type=str,
                        required=True,
                        help="comma separated list of paths; each path can be either an hdf5 file or a directory of hdf5 files")
    parser.add_argument("--num_output_shards",
                        default=None,
                        type=int,
                        required=True,
                        help="number of shards to be created; shards will be sized as evenly as possible")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="directory for meta files")
    parser.add_argument('--random_seed',
                        type=int,
                        default=12345,
                        help="random seed for initialization")
    args = parser.parse_args()

    rng = random.Random(args.random_seed)
    np.random.seed(args.random_seed)

    # Collect the input hdf5 files, expanding any directory into the hdf5
    # files it contains directly.
    input_paths = args.input_files.strip().split(',')
    input_paths = [f for f in input_paths if f]
    input_files = []
    for path in input_paths:
        if os.path.isfile(path):
            assert path.endswith('.hdf5'), "file must be hdf5 file"
            input_files.append(path)
        else:
            assert os.path.isdir(path)
            hdf5_files = [os.path.join(path, f) for f in os.listdir(path)
                          if os.path.isfile(os.path.join(path, f)) and f.endswith('.hdf5')]
            input_files.extend(hdf5_files)
    input_files.sort()

    assert os.path.isdir(args.output_dir)

    print("load and prepare")
    seq_len, pred_len, input_lengths, output_inst_indices = load_and_prepare(input_files, args.num_output_shards)

    print("preparing lookup table")
    # Map every global instance index to (input file index, offset within that file).
    total_num_instances = sum(input_lengths)
    out_2_in = dict()
    length_so_far = 0
    for i, l in enumerate(input_lengths):
        for j in range(l):
            out_2_in[length_so_far + j] = (i, j)
        length_so_far += input_lengths[i]

    output_files = [os.path.join(args.output_dir, "indices_" + str(i) + ".npy")
                    for i in range(args.num_output_shards)]

    print("save data")
    # Persist the lookup table, one index file per output shard, and the
    # sequence/prediction lengths needed to rebuild the mixed shards.
    with open(os.path.join(args.output_dir, 'lookup_table.pkl'), 'wb') as f:
        pickle.dump(out_2_in, f)
    for i, out_file in enumerate(output_files):
        np.save(out_file, output_inst_indices[i])
    meta = {'seq_len': seq_len, 'pred_len': pred_len}
    with open(os.path.join(args.output_dir, 'meta_data.pkl'), 'wb') as f:
        pickle.dump(meta, f)
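

# Illustrative sketch (not part of the original script): how a downstream step
# could consume the files written by main(). The function name is hypothetical
# and the actual mixing/writer step in this repo may differ.
def load_mixed_shard_indices(output_dir, shard_id):
    """Resolve one output shard's global instance indices back to
    (input_file_index, row_within_file) pairs via the saved lookup table."""
    with open(os.path.join(output_dir, 'lookup_table.pkl'), 'rb') as f:
        out_2_in = pickle.load(f)
    global_indices = np.load(os.path.join(output_dir, "indices_" + str(shard_id) + ".npy"))
    return [out_2_in[int(idx)] for idx in global_indices]
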

if __name__ == "__main__":
    main()