DeepLearningExamples/TensorFlow2/LanguageModeling/ELECTRA/build_pretraining_dataset.py
2020-07-20 20:39:35 +02:00

234 lines
No EOL
8.8 KiB
Python

# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Writes out text data as tfrecords that ELECTRA can be pre-trained on."""
import argparse
import multiprocessing
import os
import random
import time
import tensorflow.compat.v1 as tf
import utils
from tokenization import ElectraTokenizer
def create_int_feature(values):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return feature
class ExampleBuilder(object):
"""Given a stream of input text, creates pretraining examples."""
def __init__(self, tokenizer, max_length):
self._tokenizer = tokenizer
self._current_sentences = []
self._current_length = 0
self._max_length = max_length
self._target_length = max_length
def add_line(self, line):
"""Adds a line of text to the current example being built."""
line = line.strip().replace("\n", " ")
if (not line) and self._current_length != 0: # empty lines separate docs
return self._create_example()
bert_tokens = self._tokenizer.tokenize(line)
bert_tokids = self._tokenizer.convert_tokens_to_ids(bert_tokens)
self._current_sentences.append(bert_tokids)
self._current_length += len(bert_tokids)
if self._current_length >= self._target_length:
return self._create_example()
return None
def _create_example(self):
"""Creates a pre-training example from the current list of sentences."""
# small chance to only have one segment as in classification tasks
if random.random() < 0.1:
first_segment_target_length = 100000
else:
# -3 due to not yet having [CLS]/[SEP] tokens in the input text
first_segment_target_length = (self._target_length - 3) // 2
first_segment = []
second_segment = []
for sentence in self._current_sentences:
# the sentence goes to the first segment if (1) the first segment is
# empty, (2) the sentence doesn't put the first segment over length or
# (3) 50% of the time when it does put the first segment over length
if (len(first_segment) == 0 or
len(first_segment) + len(sentence) < first_segment_target_length or
(len(second_segment) == 0 and
len(first_segment) < first_segment_target_length and
random.random() < 0.5)):
first_segment += sentence
else:
second_segment += sentence
# trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens
first_segment = first_segment[:self._max_length - 2]
second_segment = second_segment[:max(0, self._max_length -
len(first_segment) - 3)]
# prepare to start building the next example
self._current_sentences = []
self._current_length = 0
# small chance for random-length instead of max_length-length example
if random.random() < 0.05:
self._target_length = random.randint(5, self._max_length)
else:
self._target_length = self._max_length
return self._make_tf_example(first_segment, second_segment)
def _make_tf_example(self, first_segment, second_segment):
"""Converts two "segments" of text into a tf.train.Example."""
vocab = self._tokenizer.vocab
input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]]
segment_ids = [0] * len(input_ids)
if second_segment:
input_ids += second_segment + [vocab["[SEP]"]]
segment_ids += [1] * (len(second_segment) + 1)
input_mask = [1] * len(input_ids)
input_ids += [0] * (self._max_length - len(input_ids))
input_mask += [0] * (self._max_length - len(input_mask))
segment_ids += [0] * (self._max_length - len(segment_ids))
tf_example = tf.train.Example(features=tf.train.Features(feature={
"input_ids": create_int_feature(input_ids),
"input_mask": create_int_feature(input_mask),
"segment_ids": create_int_feature(segment_ids)
}))
return tf_example
class ExampleWriter(object):
"""Writes pre-training examples to disk."""
def __init__(self, job_id, vocab_file, output_dir, max_seq_length,
num_jobs, blanks_separate_docs, do_lower_case,
num_out_files=1000):
self._blanks_separate_docs = blanks_separate_docs
tokenizer = ElectraTokenizer(
vocab_file=vocab_file,
do_lower_case=do_lower_case)
self._example_builder = ExampleBuilder(tokenizer, max_seq_length)
self._writers = []
for i in range(num_out_files):
if i % num_jobs == job_id:
output_fname = os.path.join(
output_dir, "pretrain_data.tfrecord-{:}-of-{:}".format(
i, num_out_files))
self._writers.append(tf.io.TFRecordWriter(output_fname))
self.n_written = 0
def write_examples(self, input_file):
"""Writes out examples from the provided input file."""
with tf.io.gfile.GFile(input_file) as f:
for line in f:
line = line.strip()
if line or self._blanks_separate_docs:
example = self._example_builder.add_line(line)
if example:
self._writers[self.n_written % len(self._writers)].write(
example.SerializeToString())
self.n_written += 1
example = self._example_builder.add_line("")
if example:
self._writers[self.n_written % len(self._writers)].write(
example.SerializeToString())
self.n_written += 1
def finish(self):
for writer in self._writers:
writer.close()
def write_examples(job_id, args):
"""A single process creating and writing out pre-processed examples."""
def log(*args):
msg = " ".join(map(str, args))
print("Job {}:".format(job_id), msg)
log("Creating example writer")
example_writer = ExampleWriter(
job_id=job_id,
vocab_file=args.vocab_file,
output_dir=args.output_dir,
max_seq_length=args.max_seq_length,
num_jobs=args.num_processes,
blanks_separate_docs=args.blanks_separate_docs,
do_lower_case=args.do_lower_case,
num_out_files=args.num_out_files,
)
log("Writing tf examples")
fnames = sorted(tf.io.gfile.listdir(args.corpus_dir))
fnames = [f for (i, f) in enumerate(fnames)
if i % args.num_processes == job_id]
random.shuffle(fnames)
start_time = time.time()
for file_no, fname in enumerate(fnames):
if file_no > 0:
elapsed = time.time() - start_time
log("processed {:}/{:} files ({:.1f}%), ELAPSED: {:}s, ETA: {:}s, "
"{:} examples written".format(
file_no, len(fnames), 100.0 * file_no / len(fnames), int(elapsed),
int((len(fnames) - file_no) / (file_no / elapsed)),
example_writer.n_written))
example_writer.write_examples(os.path.join(args.corpus_dir, fname))
example_writer.finish()
log("Done!")
# python build_pretraining_dataset --corpus-dir
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--corpus-dir", required=True,
help="Location of pre-training text files.")
parser.add_argument("--vocab-file", required=True,
help="Location of vocabulary file.")
parser.add_argument("--output-dir", required=True,
help="Where to write out the tfrecords.")
parser.add_argument("--max-seq-length", default=128, type=int,
help="Number of tokens per example.")
parser.add_argument("--num-processes", default=1, type=int,
help="Parallelize across multiple processes.")
parser.add_argument("--blanks-separate-docs", default=True, type=bool,
help="Whether blank lines indicate document boundaries.")
parser.add_argument("--do-lower-case", dest='do_lower_case',
action='store_true', help="Lower case input text.")
parser.add_argument("--no-lower-case", dest='do_lower_case',
action='store_false', help="Don't lower case input text.")
parser.add_argument("--num-out-files", default=1000, type=int,
help="Number of output files.")
parser.set_defaults(do_lower_case=True)
args = parser.parse_args()
utils.rmkdir(args.output_dir)
if args.num_processes == 1:
write_examples(0, args)
else:
jobs = []
for i in range(args.num_processes):
job = multiprocessing.Process(target=write_examples, args=(i, args))
jobs.append(job)
job.start()
for job in jobs:
job.join()
if __name__ == "__main__":
main()