This commit is contained in:
Chang Lan 2021-11-04 15:01:04 +01:00 committed by GitHub
commit 24aa6c86f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 52 additions and 65 deletions

View file

@ -27,7 +27,7 @@ import modeling
import tokenization
import tensorflow as tf
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS
@ -416,4 +416,4 @@ if __name__ == "__main__":
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("init_checkpoint")
flags.mark_flag_as_required("output_file")
tf.app.run()
tf.compat.v1.app.run()

View file

@ -37,7 +37,7 @@ from utils.create_glue_data import *
import numpy as np
import tf_metrics
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS

View file

@ -25,7 +25,7 @@ import tokenization
import tensorflow as tf
import tensorflow_hub as hub
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS
@ -311,4 +311,4 @@ if __name__ == "__main__":
flags.mark_flag_as_required("task_name")
flags.mark_flag_as_required("bert_hub_module_handle")
flags.mark_flag_as_required("output_dir")
tf.app.run()
tf.compat.v1.app.run()

View file

@ -30,7 +30,7 @@ from utils.gpu_affinity import set_affinity
import utils.dllogger_class
from dllogger import Verbosity
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS
@ -122,7 +122,7 @@ flags.DEFINE_integer(
"iterations_per_loop", 1000,
"How many steps to make in each estimator call.")
tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs")
flags.DEFINE_bool("amp", True, "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.")

View file

@ -33,7 +33,7 @@ from dllogger import Verbosity
from tensorflow.core.protobuf import rewriter_config_pb2
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS

View file

@ -40,7 +40,7 @@ from utils.gpu_affinity import set_affinity
import utils.dllogger_class
from dllogger import Verbosity
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS
@ -114,7 +114,7 @@ flags.DEFINE_integer("save_checkpoints_steps", 1000,
flags.DEFINE_integer("iterations_per_loop", 1000,
"How many steps to make in each estimator call.")
tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs")
flags.DEFINE_bool("amp", True, "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.")

View file

@ -41,7 +41,7 @@ from utils.gpu_affinity import set_affinity
import utils.dllogger_class
from dllogger import Verbosity
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = None
def extract_run_squad_flags():
@ -426,9 +426,10 @@ def input_fn_builder(input_file, batch_size, seq_length, is_training, drop_remai
# For eval, we want no shuffling and parallel reading doesn't matter.
if is_training:
d = tf.data.TFRecordDataset(input_file, num_parallel_reads=4)
if hvd is not None: d = d.shard(hvd.size(), hvd.rank())
d = d.apply(tf.data.experimental.ignore_errors())
d = d.shuffle(buffer_size=100)
d = d.prefetch(tf.data.experimental.AUTOTUNE)
if hvd is not None:
d = d.shard(hvd.size(), hvd.rank())
d = d.shuffle(buffer_size=128)
d = d.repeat()
else:
d = tf.data.TFRecordDataset(input_file)
@ -438,7 +439,8 @@ def input_fn_builder(input_file, batch_size, seq_length, is_training, drop_remai
tf.contrib.data.map_and_batch(
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size,
drop_remainder=drop_remainder))
drop_remainder=drop_remainder,
num_parallel_calls=tf.data.experimental.AUTOTUNE))
return d
@ -992,11 +994,14 @@ def main(_):
if FLAGS.amp:
tf.enable_resource_variables()
save_checkpoints_steps = FLAGS.save_checkpoints_steps if master_process and FLAGS.save_checkpoints_steps else None
run_config = tf.estimator.RunConfig(
model_dir=FLAGS.output_dir if master_process else None,
session_config=config,
save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
save_summary_steps=FLAGS.save_checkpoints_steps if master_process else None,
save_checkpoints_steps=save_checkpoints_steps,
save_checkpoints_secs =None,
save_summary_steps=save_checkpoints_steps,
log_step_count_steps=FLAGS.display_loss_steps,
keep_checkpoint_max=1)
@ -1020,27 +1025,6 @@ def main(_):
len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
# Pre-shuffle the input to avoid having to make a very large shuffle
# buffer in the `input_fn`.
rng = random.Random(12345)
rng.shuffle(train_examples)
start_index = 0
end_index = len(train_examples)
tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]
if FLAGS.horovod:
tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
num_examples_per_rank = len(train_examples) // hvd.size()
remainder = len(train_examples) % hvd.size()
if hvd.rank() < remainder:
start_index = hvd.rank() * (num_examples_per_rank+1)
end_index = start_index + num_examples_per_rank + 1
else:
start_index = hvd.rank() * num_examples_per_rank + remainder
end_index = start_index + (num_examples_per_rank)
model_fn = model_fn_builder(
bert_config=bert_config,
init_checkpoint=FLAGS.init_checkpoint,
@ -1055,33 +1039,37 @@ def main(_):
config=run_config)
if FLAGS.do_train:
# We write to a temporary file to avoid storing very large constant tensors
# in memory.
train_writer = FeatureWriter(
filename=tmp_filenames[hvd_rank],
is_training=True)
convert_examples_to_features(
examples=train_examples[start_index:end_index],
tokenizer=tokenizer,
max_seq_length=FLAGS.max_seq_length,
doc_stride=FLAGS.doc_stride,
max_query_length=FLAGS.max_query_length,
is_training=True,
output_fn=train_writer.process_feature,
verbose_logging=FLAGS.verbose_logging)
train_writer.close()
tmp_filename = os.path.join(FLAGS.output_dir, "train.tf_record")
if hvd.local_rank() == 0:
train_writer = FeatureWriter(
filename=tmp_filename,
is_training=True)
convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
max_seq_length=FLAGS.max_seq_length,
doc_stride=FLAGS.doc_stride,
max_query_length=FLAGS.max_query_length,
is_training=True,
output_fn=train_writer.process_feature,
verbose_logging=FLAGS.verbose_logging)
train_writer.close()
tf.compat.v1.logging.info("***** Running training *****")
tf.compat.v1.logging.info(" Num orig examples = %d", len(train_examples))
tf.compat.v1.logging.info(" Num split examples = %d", train_writer.num_features)
tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size)
tf.compat.v1.logging.info(" Num steps = %d", num_train_steps)
tf.compat.v1.logging.info(" LR = %f", learning_rate)
tf.compat.v1.logging.info("***** Running training *****")
tf.compat.v1.logging.info(" Num orig examples = %d", end_index - start_index)
tf.compat.v1.logging.info(" Num split examples = %d", train_writer.num_features)
tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size)
tf.compat.v1.logging.info(" Num steps = %d", num_train_steps)
tf.compat.v1.logging.info(" LR = %f", learning_rate)
del train_examples
train_input_fn = input_fn_builder(
input_file=tmp_filenames,
input_file=tmp_filename,
batch_size=FLAGS.train_batch_size,
seq_length=FLAGS.max_seq_length,
is_training=True,
@ -1227,4 +1215,4 @@ def main(_):
if __name__ == "__main__":
FLAGS = extract_run_squad_flags()
tf.app.run()
tf.compat.v1.app.run()

View file

@ -28,7 +28,7 @@ else:
import Queue as queue
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = flags.FLAGS
## Required parameters

View file

@ -29,7 +29,7 @@ import horovod.tensorflow as hvd
import time
import csv
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = None
def extract_flags():

View file

@ -28,7 +28,7 @@ import tensorflow as tf
import horovod.tensorflow as hvd
import time
flags = tf.flags
flags = tf.compat.v1.flags
FLAGS = None
def extract_flags():
@ -509,7 +509,6 @@ class FeatureWriter(object):
self._writer.close()
def main():
FLAGS = extract_flags()
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

View file

@ -71,7 +71,7 @@ class LogTrainRunHook(tf.estimator.SessionRunHook):
self.count += 1
# Removing first 100 steps + first five steps after every checkpoint save
if (self.global_step - self.init_global_step) <= self.num_steps_ignore_xla or (self.global_step - self.init_global_step) % self.save_checkpoints_steps < 5:
if (self.global_step - self.init_global_step) <= self.num_steps_ignore_xla or (self.save_checkpoints_steps > 0 and (self.global_step - self.init_global_step) % self.save_checkpoints_steps < 5):
print("Skipping time record for ", self.global_step, " due to checkpoint-saving/warmup overhead")
self.skipped += 1
else: