DeepLearningExamples/DGLPyTorch/DrugDiscovery/RoseTTAFold/network/utils/utils_logging.py

import os
import sys
import time
import datetime
import subprocess
import numpy as np
import torch
from utils.utils_data import to_np

# Maps a variable name to a list of gradient-norm arrays collected by
# log_gradient_norm(); read by get_average() and reset by clear_data().
_global_log = {}


def try_mkdir(path):
    """Create `path` (including parent directories) if it does not exist yet."""
    if not os.path.exists(path):
        os.makedirs(path)


# @profile
def make_logdir(checkpoint_dir, run_name=None):
    """Create and return a log directory named after `run_name`, or after the
    current timestamp if no run name is given."""
    if run_name is None:
        now = datetime.datetime.now().strftime("%Y_%m_%d_%H.%M.%S")
    else:
        assert isinstance(run_name, str)
        now = run_name
    log_dir = os.path.join(checkpoint_dir, now)
    try_mkdir(log_dir)
    return log_dir


def count_parameters(model):
    """
    Count the number of trainable parameters in a module.

    :param model: nn.Module instance
    :return: integer, e.g. 36 for torch.nn.Linear(8, 4) (32 weights + 4 biases)
    """
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    n_params = sum([np.prod(p.size()) for p in model_parameters])
    return n_params


def write_info_file(model, FLAGS, UNPARSED_ARGV, wandb_log_dir=None):
    """Write run metadata (git commit, flags, parameter counts, launch command)
    to a timestamped info file and dump `git diff` into a separate file.

    Returns the log directory and the checkpoint path inside it.
    """
    time_str = time.strftime("%m%d_%H%M%S")
    filename_log = "info_" + time_str + ".txt"
    filename_git_diff = "git_diff_" + time_str + ".txt"
    checkpoint_name = 'model'

    if wandb_log_dir:
        log_dir = wandb_log_dir
        os.mkdir(os.path.join(log_dir, 'checkpoints'))
        checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)
    elif FLAGS.restore:
        # set restore path
        assert FLAGS.run_name is not None
        log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.run_name)
        checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)
    else:
        # make a log directory with a time stamp
        log_dir = make_logdir(FLAGS.checkpoint_dir, FLAGS.run_name)
        os.mkdir(os.path.join(log_dir, 'checkpoints'))
        os.mkdir(os.path.join(log_dir, 'point_clouds'))
        # os.mkdir(os.path.join(log_dir, 'train_log'))
        # os.mkdir(os.path.join(log_dir, 'test_log'))
        checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)

    # write arguments and git hash to the info file
    with open(os.path.join(log_dir, filename_log), "w") as file:
        label = subprocess.check_output(["git", "describe", "--always"]).strip()
        file.write('latest git commit on this branch: ' + label.decode() + '\n')
        file.write('\nFLAGS: \n')
        for key in sorted(vars(FLAGS)):
            file.write(key + ': ' + str(vars(FLAGS)[key]) + '\n')

        # count the number of parameters
        if hasattr(model, 'parameters'):
            file.write('\nNumber of Model Parameters: ' + str(count_parameters(model)) + '\n')
        if hasattr(model, 'enc'):
            file.write('\nNumber of Encoder Parameters: ' + str(
                count_parameters(model.enc)) + '\n')
        if hasattr(model, 'dec'):
            file.write('\nNumber of Decoder Parameters: ' + str(
                count_parameters(model.dec)) + '\n')

        file.write('\nUNPARSED_ARGV:\n' + str(UNPARSED_ARGV))
        file.write('\n\nBASH COMMAND: \n')
        bash_command = 'python'
        for argument in sys.argv:
            bash_command += (' ' + argument)
        file.write(bash_command)

    # write the 'git diff' output into an extra file
    subprocess.call("git diff > " + os.path.join(log_dir, filename_git_diff), shell=True)

    return log_dir, checkpoint_path


def log_gradient_norm(tensor, variable_name):
    """Register a backward hook that records the gradient norm of `tensor`
    (taken along the last dimension) under `variable_name` in the global log."""
    if variable_name not in _global_log:
        _global_log[variable_name] = []

    def log_gradient_norm_inner(gradient):
        gradient_norm = torch.norm(gradient, dim=-1)
        _global_log[variable_name].append(to_np(gradient_norm))

    tensor.register_hook(log_gradient_norm_inner)


def get_average(variable_name):
    """Return the mean of all gradient norms logged under `variable_name`.

    Returns NaN if the variable was never registered, and 0 if it was
    registered but nothing has been logged yet.
    """
    if variable_name not in _global_log:
        return float('nan')
    elif _global_log[variable_name]:
        overall_tensor = np.concatenate(_global_log[variable_name])
        return np.mean(overall_tensor)
    else:
        return 0


def clear_data(variable_name):
    """Reset the logged values for `variable_name`."""
    _global_log[variable_name] = []
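

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): register a gradient-norm hook
    # on an intermediate tensor, run a backward pass, then read and reset the
    # accumulated statistics.
    x = torch.randn(4, 3, requires_grad=True)
    hidden = x * 2.0                     # any intermediate tensor with a grad_fn
    log_gradient_norm(hidden, 'hidden')  # records per-row gradient norms on backward
    hidden.sum().backward()
    print('mean gradient norm of hidden:', get_average('hidden'))
    clear_data('hidden')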