# DeepLearningExamples/PyTorch/Translation/Transformer/scripts/draw_summary.py
# Snapshot: 2020-07-04 02:28:25 +02:00 — 134 lines, 5.3 KiB, Python.
import json
import argparse
from collections import defaultdict, OrderedDict
import matplotlib.pyplot as plt
import numpy as np
def smooth_moving_average(x, n):
    """Smooth *x* with a length-``n`` moving average.

    The first ``n - 1`` positions, where the window is not yet full,
    are kept as the raw input values so the output has the same length
    as the input.
    """
    window = np.ones(n) / n
    tail = np.convolve(x, window, mode='valid')
    return np.concatenate((x[:n - 1], tail), axis=0)
def moving_stdev(x, n):
    """Moving standard deviation of *x* over a length-``n`` window.

    Uses the identity Var = E[x^2] - E[x]^2 within each window.  The
    first ``n - 1`` positions, where the window is not yet full, are
    padded with zeros so the output has the same length as the input.
    """
    window = np.ones(n) / n
    avg_square = np.convolve(np.power(x, 2), window, mode='valid')
    squared_avg = np.power(np.convolve(x, window, mode='valid'), 2)
    # BUGFIX: floating-point cancellation can make the difference come
    # out slightly negative (e.g. on a near-constant signal), which made
    # np.sqrt emit NaNs in the original; clamp the variance at zero.
    var = np.maximum(avg_square - squared_avg, 0.0)
    stdev = np.sqrt(var)
    # Pad the warm-up region with zeros.
    stdev = np.concatenate(([0] * (n - 1), stdev), axis=0)
    return stdev
def get_plot(log):
    """Split ``(step, elapsedtime, value)`` records into parallel lists.

    Only entries whose step is an integer are kept (string steps such
    as 'PARAMETER' are dropped).  Returns ``(steps, values)``.
    """
    numeric = [entry for entry in log if isinstance(entry[0], int)]
    steps = [entry[0] for entry in numeric]
    values = [entry[2] for entry in numeric]
    return steps, values
def highlight_max_point(plot, color):
    """Mark and annotate the maximum of a ``(steps, values)`` pair on the
    current axes; return that ``(step, value)`` point."""
    steps, values = plot
    best = max(zip(steps, values), key=lambda sv: sv[1])
    plt.plot(best[0], best[1], 'bo-', color=color)
    plt.annotate("{:.2f}".format(best[1]), best)
    return best
def _parse_log(path):
    """Parse a DLLogger JSON-lines file into a dict of metric series.

    Every non-parameter LOG entry is stored as a
    ``(step, elapsedtime, value)`` tuple under its metric name.
    Hyperparameters accumulate under ``'parameters'``; the first entry
    logged with an empty step list becomes ``'training_summary'``.
    """
    jlog = defaultdict(list)
    jlog['parameters'] = {}
    with open(path, 'r') as f:
        for line in f:
            # Each line carries a 5-character prefix before the JSON
            # payload (DLLogger framing) — strip it before parsing.
            line_dict = json.loads(line[5:])
            if line_dict['type'] != 'LOG':
                continue
            if line_dict['step'] == 'PARAMETER':
                jlog['parameters'].update(line_dict['data'])
            elif line_dict['step'] == [] and 'training_summary' not in jlog:
                jlog['training_summary'] = line_dict['data']
            else:
                for k, v in line_dict['data'].items():
                    jlog[k].append((line_dict['step'], line_dict['elapsedtime'], v))
    return jlog


def main(args):
    """Draw a training-summary figure (loss + BLEU curves) from a
    DLLogger log file, save it to ``args.output`` and optionally dump a
    JSON summary to ``args.dump_json``."""
    jlog = _parse_log(args.log_file)

    fig, ax1 = plt.subplots(figsize=(20, 5))
    fig.suptitle(args.title, fontsize=16)
    ax1.set_xlabel('steps')
    ax1.set_ylabel('loss')

    # Colors shared between curves and their annotations.
    VAL_LOSS_COLOR = 'blue'
    VAL_BLEU_COLOR = 'red'
    TEST_BLEU_COLOR = 'pink'

    # Smoothed training-loss curve with a +/- stdev band.
    steps, loss = get_plot(jlog['loss'])
    smoothed_loss = smooth_moving_average(loss, 150)
    stdev = moving_stdev(loss, 150)
    ax1.plot(steps, smoothed_loss, label='Training loss')
    ax1.plot(steps, smoothed_loss + stdev, '--', color='orange', linewidth=0.3, label='Stdev')
    ax1.plot(steps, smoothed_loss - stdev, '--', color='orange', linewidth=0.3)

    # Validation-loss curve and the step where it is minimal.
    val_steps, val_loss = get_plot(jlog['val_loss'])
    ax1.plot(val_steps, val_loss, color='blue', label='Validation loss')
    min_val_loss_step = val_steps[np.argmin(val_loss)]
    ax1.axvline(min_val_loss_step, linestyle='dashed', color=VAL_LOSS_COLOR,
                linewidth=0.5, label='Validation loss minimum')

    # BLEU curves share the x axis but get their own y axis.
    ax2 = ax1.twinx()
    ax2.set_ylabel('BLEU')
    val_steps, val_bleu = get_plot(jlog['val_bleu'])
    ax2.plot(val_steps, val_bleu, color=VAL_BLEU_COLOR, label='Validation BLEU')
    mvb_step, _ = highlight_max_point((val_steps, val_bleu), color=VAL_BLEU_COLOR)

    # Values to be labeled on the plot.
    max_val_bleu_step = val_steps[np.argmax(val_bleu)]
    min_loss_bleu = val_bleu[val_steps.index(min_val_loss_step)]

    if 'test_bleu' in jlog:
        test_steps, test_bleu = get_plot(jlog['test_bleu'])
        # BUGFIX: plot test BLEU against its own steps — the original
        # used val_steps here, which breaks when the lengths differ.
        ax2.plot(test_steps, test_bleu, color=TEST_BLEU_COLOR, label='Test BLEU')
        highlight_max_point((test_steps, test_bleu), color=TEST_BLEU_COLOR)
    ax2.tick_params(axis='y')

    # Annotate the BLEU values at the minimal-validation-loss step and
    # at the best-validation-BLEU step.
    ax2.plot(min_val_loss_step, min_loss_bleu, 'bo-', color=VAL_BLEU_COLOR)
    ax2.annotate("{:.2f}".format(min_loss_bleu), (min_val_loss_step, min_loss_bleu))
    if 'test_bleu' in jlog:
        # NOTE(review): indexing test_bleu by positions found in
        # val_steps assumes test and validation are logged at exactly
        # the same steps — confirm against the training loop.
        min_loss_test_bleu = test_bleu[val_steps.index(min_val_loss_step)]
        ax2.plot(min_val_loss_step, min_loss_test_bleu, 'bo-', color=TEST_BLEU_COLOR)
        ax2.annotate("{:.2f}".format(min_loss_test_bleu), (min_val_loss_step, min_loss_test_bleu))
        max_val_bleu_test = test_bleu[val_steps.index(max_val_bleu_step)]
        ax2.plot(mvb_step, max_val_bleu_test, 'bo-', color=TEST_BLEU_COLOR)
        ax2.annotate("{:.2f}".format(max_val_bleu_test), (max_val_bleu_step, max_val_bleu_test))

    ax1.legend(loc='lower left', bbox_to_anchor=(1, 0))
    ax2.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.grid()
    plt.savefig(args.output)

    # Produce a JSON file with the training summary.
    if args.dump_json:
        summary = OrderedDict()
        summary['args'] = OrderedDict(jlog['parameters'])
        summary['min_val_loss'] = min(val_loss)
        summary['max_val_bleu'] = max(val_bleu)
        # BUGFIX: the original referenced test_bleu unconditionally and
        # raised NameError when the log had no 'test_bleu' entries.
        if 'test_bleu' in jlog:
            summary['max_test_bleu'] = max(test_bleu)
        summary['final_values'] = jlog['training_summary']
        summary['avg_epoch_loss'] = [
            x.mean() for x in np.array_split(np.array(loss), jlog['parameters']['max_epoch'])
        ]
        summary['min_val_loss_step'] = min_val_loss_step
        # BUGFIX: close the output file deterministically instead of
        # leaking the handle from json.dump(summary, open(...)).
        with open(args.dump_json, 'w') as f:
            json.dump(summary, f)
if __name__ == '__main__':
    # Command-line entry point: parse options and hand off to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--title', type=str)
    parser.add_argument('--log-file', type=str)
    parser.add_argument('--output', '-o', type=str)
    parser.add_argument('--dump-json', '-j', type=str)
    main(parser.parse_args())