[DLRM/PyT] Fixed logging
This commit is contained in:
parent
72a90514f2
commit
99ab780448
|
@ -57,13 +57,12 @@ flags.DEFINE_boolean("Adam_MLP_optimizer", False, "Swaps MLP optimizer to Adam")
|
|||
def main(argv):
|
||||
torch.manual_seed(FLAGS.seed)
|
||||
|
||||
utils.init_logging(log_path=FLAGS.log_path)
|
||||
|
||||
use_gpu = "cpu" not in FLAGS.base_device.lower()
|
||||
rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu)
|
||||
device = FLAGS.base_device
|
||||
|
||||
if is_main_process():
|
||||
utils.init_logging(log_path=FLAGS.log_path)
|
||||
dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')
|
||||
|
||||
print("Command line flags:")
|
||||
|
@ -180,7 +179,8 @@ def main(argv):
|
|||
auc = dist_evaluate(model, data_loader_test)
|
||||
|
||||
results = {'auc': auc}
|
||||
dllogger.log(data=results, step=tuple())
|
||||
if is_main_process():
|
||||
dllogger.log(data=results, step=tuple())
|
||||
|
||||
if auc is not None:
|
||||
print(f"Finished testing. Test auc {auc:.4f}")
|
||||
|
@ -224,7 +224,6 @@ def main(argv):
|
|||
best_auc = 0
|
||||
best_epoch = 0
|
||||
start_time = time()
|
||||
stop_time = time()
|
||||
|
||||
for epoch in range(FLAGS.epochs):
|
||||
epoch_start_time = time()
|
||||
|
@ -299,7 +298,6 @@ def main(argv):
|
|||
step_time=timer.measured,
|
||||
loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.amp else 1),
|
||||
lr=mlp_optimizer.param_groups[0]["lr"] * (FLAGS.loss_scale if FLAGS.amp else 1))
|
||||
stop_time = time()
|
||||
|
||||
eta_str = datetime.timedelta(seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
|
||||
metric_logger.print(header=f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}] eta: {eta_str}")
|
||||
|
@ -329,19 +327,19 @@ def main(argv):
|
|||
|
||||
epoch_stop_time = time()
|
||||
epoch_time_s = epoch_stop_time - epoch_start_time
|
||||
print(f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
|
||||
f"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s.")
|
||||
print(f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. ")
|
||||
|
||||
avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg
|
||||
|
||||
if FLAGS.save_checkpoint_path:
|
||||
checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path, epoch, step)
|
||||
|
||||
results = {'best_auc': best_auc,
|
||||
'best_epoch': best_epoch,
|
||||
'average_train_throughput': avg_throughput}
|
||||
if is_main_process():
|
||||
results = {'best_auc': best_auc,
|
||||
'best_epoch': best_epoch,
|
||||
'average_train_throughput': avg_throughput}
|
||||
|
||||
dllogger.log(data=results, step=tuple())
|
||||
dllogger.log(data=results, step=tuple())
|
||||
|
||||
def scale_MLP_gradients(mlp_optimizer: torch.optim.Optimizer, world_size: int):
|
||||
for param_group in mlp_optimizer.param_groups[1:]: # Omitting top MLP
|
||||
|
|
Loading…
Reference in a new issue