# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import glob
import os
import re
import numpy as np

import torch
from torch.utils.tensorboard import SummaryWriter

import dllogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity


tb_loggers = {}


class TBLogger:
    """
    dummies: stretch the screen with empty 'aaa' and 'zzz' plots so that
    the legend always fits for the other plots
    """
    def __init__(self, enabled, log_dir, name, interval=1, dummies=True):
        self.enabled = enabled
        self.interval = interval
        self.cache = {}
        if self.enabled:
            self.summary_writer = SummaryWriter(
                log_dir=os.path.join(log_dir, name),
                flush_secs=120, max_queue=200)
            atexit.register(self.summary_writer.close)
            if dummies:
                for key in ('aaa', 'zzz'):
                    self.summary_writer.add_scalar(key, 0.0, 1)

    def log(self, step, data):
        for k, v in data.items():
            # .item() converts one-element tensors to plain Python numbers
            self.log_value(step, k, v.item() if isinstance(v, torch.Tensor) else v)

    def log_value(self, step, key, val, stat='mean'):
        if self.enabled:
            if key not in self.cache:
                self.cache[key] = []
            self.cache[key].append(val)
            if len(self.cache[key]) == self.interval:
                # reduce the cached values with np.mean / np.max / etc.
                agg_val = getattr(np, stat)(self.cache[key])
                self.summary_writer.add_scalar(key, agg_val, step)
                del self.cache[key]

    def log_grads(self, step, model):
        if self.enabled:
            norms = [p.grad.norm().item() for p in model.parameters()
                     if p.grad is not None]
            for stat in ('max', 'min', 'mean'):
                self.log_value(step, f'grad_{stat}', getattr(np, stat)(norms),
                               stat=stat)
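
# A minimal usage sketch (illustrative only; `step`, `loss` and `model` are
# assumed to come from the surrounding training loop):
#
#     tbl = TBLogger(enabled=True, log_dir='./logs', name='train', interval=10)
#     tbl.log_value(step, 'loss', loss)   # written as the mean of every 10 calls
#     tbl.log_grads(step, model)          # max/min/mean of parameter grad norms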


def unique_log_fpath(log_fpath):
    if not os.path.isfile(log_fpath):
        return log_fpath

    # Avoid overwriting old logs
    saved = sorted([int(re.search(r'\.(\d+)', f).group(1))
                    for f in glob.glob(f'{log_fpath}.*')])

    log_num = (saved[-1] if saved else 0) + 1
    return f'{log_fpath}.{log_num}'
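
# For example, if 'nvlog.json' and 'nvlog.json.1' already exist,
# unique_log_fpath('nvlog.json') returns 'nvlog.json.2'.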


def stdout_step_format(step):
    if isinstance(step, str):
        return step
    fields = []
    if len(step) > 0:
        fields.append("epoch {:>4}".format(step[0]))
    if len(step) > 1:
        fields.append("iter {:>4}".format(step[1]))
    if len(step) > 2:
        fields[-1] += "/{}".format(step[2])
    return " | ".join(fields)


def stdout_metric_format(metric, metadata, value):
    name = metadata.get("name", metric + " : ")
    unit = metadata.get("unit", None)
    fmt = f'{{{metadata.get("format", "")}}}'
    fields = [name, fmt.format(value) if value is not None else value, unit]
    fields = [f for f in fields if f is not None]
    return "| " + " ".join(fields)


def init_log(args):
    enabled = (args.local_rank == 0)
    if enabled:
        fpath = args.log_file or os.path.join(args.output_dir, 'nvlog.json')
        backends = [JSONStreamBackend(Verbosity.DEFAULT,
                                      unique_log_fpath(fpath)),
                    StdOutBackend(Verbosity.VERBOSE,
                                  step_format=stdout_step_format,
                                  metric_format=stdout_metric_format)]
    else:
        backends = []

    dllogger.init(backends=backends)
    dllogger.metadata("train_lrate", {"name": "lrate", "format": ":>3.2e"})

    for id_, pref in [('train', ''), ('train_avg', 'avg train '),
                      ('dev', ' avg dev '), ('dev_ema', ' EMA dev ')]:
        dllogger.metadata(f"{id_}_loss",
                          {"name": f"{pref}loss", "format": ":>7.2f"})
        dllogger.metadata(f"{id_}_wer",
                          {"name": f"{pref}wer", "format": ":>6.2f"})
        dllogger.metadata(f"{id_}_throughput",
                          {"name": f"{pref}utts/s", "format": ":>5.0f"})
        dllogger.metadata(f"{id_}_took",
                          {"name": "took", "unit": "s", "format": ":>5.2f"})

    tb_subsets = ['train', 'dev', 'dev_ema'] if args.ema else ['train', 'dev']
    global tb_loggers
    tb_loggers = {s: TBLogger(enabled, args.output_dir, name=s)
                  for s in tb_subsets}

    log_parameters(vars(args), tb_subset='train')


def log(step, tb_total_steps=None, subset='train', data={}):
    # note: the mutable default is safe here, `data` is rebound, never mutated
    if tb_total_steps is not None:
        tb_loggers[subset].log(tb_total_steps, data)

    if subset != '':
        data = {f'{subset}_{key}': val for key, val in data.items()}
    dllogger.log(step, data=data)


def log_grads_tb(tb_total_steps, model, tb_subset='train'):
    # `model` is the module whose parameter gradient norms get summarized
    tb_loggers[tb_subset].log_grads(tb_total_steps, model)


def log_parameters(data, verbosity=0, tb_subset=None):
    for k, v in data.items():
        dllogger.log(step="PARAMETER", data={k: v}, verbosity=verbosity)

    if tb_subset is not None and tb_loggers[tb_subset].enabled:
        # add_hparams accepts only simple scalar types
        tb_data = {k: v for k, v in data.items()
                   if isinstance(v, (str, bool, int, float))}
        tb_loggers[tb_subset].summary_writer.add_hparams(tb_data, {})


def flush_log():
    dllogger.flush()
    for tbl in tb_loggers.values():
        if tbl.enabled:
            tbl.summary_writer.flush()
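
# End-to-end sketch (illustrative; assumes an argparse.Namespace carrying
# local_rank, log_file, output_dir and ema, as parsed by the training script):
#
#     init_log(args)
#     log((epoch, step, num_steps), tb_total_steps=total_steps,
#         subset='train', data={'loss': loss, 'took': elapsed})
#     log_grads_tb(total_steps, model)
#     flush_log()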