# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import os
import time
import argparse
import numpy as np
from contextlib import contextmanager

import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import models
import loss_functions
import data_functions

from tacotron2_common.utils import ParseFromConfigFile

import dllogger as DLLogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity

from scipy.io.wavfile import write as write_wav
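
# Typical invocation (illustrative values; -o, -m, -lr, --epochs and -bs are
# required, everything else has defaults):
#   python train.py -m Tacotron2 -o ./output -lr 1e-3 --epochs 1501 -bs 128 --amp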


def parse_args(parser):
    """
    Parse commandline arguments.
    """
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str,
                        default='./', help='Path to dataset')
    parser.add_argument('-m', '--model-name', type=str, default='', required=True,
                        help='Model to train')
    parser.add_argument('--log-file', type=str, default='nvlog.json',
                        help='Filename for logging')
    parser.add_argument('--anneal-steps', nargs='*',
                        help='Epochs after which to decrease learning rate')
    parser.add_argument('--anneal-factor', type=float, choices=[0.1, 0.3], default=0.1,
                        help='Factor for annealing learning rate')
    parser.add_argument('--config-file', action=ParseFromConfigFile,
                        type=str, help='Path to configuration file')

    # training
    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', type=int, required=True,
                          help='Number of total epochs to run')
    training.add_argument('--epochs-per-checkpoint', type=int, default=50,
                          help='Number of epochs per checkpoint')
    training.add_argument('--checkpoint-path', type=str, default='',
                          help='Checkpoint path to resume training')
    training.add_argument('--resume-from-last', action='store_true',
                          help='Resumes training from the last checkpoint; uses the directory provided with \'--output\' option to search for the checkpoint "checkpoint_<model_name>_last.pt"')
    training.add_argument('--dynamic-loss-scaling', type=bool, default=True,
                          help='Enable dynamic loss scaling')
    training.add_argument('--amp', action='store_true',
                          help='Enable AMP')
    training.add_argument('--cudnn-enabled', action='store_true',
                          help='Enable cudnn')
    training.add_argument('--cudnn-benchmark', action='store_true',
                          help='Run cudnn benchmark')
    training.add_argument('--disable-uniform-initialize-bn-weight', action='store_true',
                          help='Disable uniform initialization of batchnorm layer weight')

    optimization = parser.add_argument_group('optimization setup')
    optimization.add_argument(
        '--use-saved-learning-rate', default=False, type=bool)
    optimization.add_argument('-lr', '--learning-rate', type=float, required=True,
                              help='Learning rate')
    optimization.add_argument('--weight-decay', default=1e-6, type=float,
                              help='Weight decay')
    optimization.add_argument('--grad-clip-thresh', default=1.0, type=float,
                              help='Clip threshold for gradients')
    optimization.add_argument('-bs', '--batch-size', type=int, required=True,
                              help='Batch size per GPU')
    optimization.add_argument('--grad-clip', default=5.0, type=float,
                              help='Enables gradient clipping and sets maximum gradient norm value')

    # dataset parameters
    dataset = parser.add_argument_group('dataset parameters')
    dataset.add_argument('--load-mel-from-disk', action='store_true',
                         help='Loads mel spectrograms from disk instead of computing them on the fly')
    dataset.add_argument('--training-files',
                         default='filelists/ljs_audio_text_train_filelist.txt',
                         type=str, help='Path to training filelist')
    dataset.add_argument('--validation-files',
                         default='filelists/ljs_audio_text_val_filelist.txt',
                         type=str, help='Path to validation filelist')
    dataset.add_argument('--text-cleaners', nargs='*',
                         default=['english_cleaners'], type=str,
                         help='Type of text cleaners for input text')

    # audio parameters
    audio = parser.add_argument_group('audio parameters')
    audio.add_argument('--max-wav-value', default=32768.0, type=float,
                       help='Maximum audiowave value')
    audio.add_argument('--sampling-rate', default=22050, type=int,
                       help='Sampling rate')
    audio.add_argument('--filter-length', default=1024, type=int,
                       help='Filter length')
    audio.add_argument('--hop-length', default=256, type=int,
                       help='Hop (stride) length')
    audio.add_argument('--win-length', default=1024, type=int,
                       help='Window length')
    audio.add_argument('--mel-fmin', default=0.0, type=float,
                       help='Minimum mel frequency')
    audio.add_argument('--mel-fmax', default=8000.0, type=float,
                       help='Maximum mel frequency')

    distributed = parser.add_argument_group('distributed setup')
    # distributed.add_argument('--distributed-run', default=True, type=bool,
    #                          help='enable distributed run')
    distributed.add_argument('--rank', default=0, type=int,
                             help='Rank of the process, do not set! Done by multiproc module')
    distributed.add_argument('--world-size', default=1, type=int,
                             help='Number of processes, do not set! Done by multiproc module')
    distributed.add_argument('--dist-url', type=str, default='tcp://localhost:23456',
                             help='Url used to set up distributed training')
    distributed.add_argument('--group-name', type=str, default='group_name',
                             required=False, help='Distributed group name')
    distributed.add_argument('--dist-backend', default='nccl', type=str, choices={'nccl'},
                             help='Distributed run backend')

    benchmark = parser.add_argument_group('benchmark')
    benchmark.add_argument('--bench-class', type=str, default='')

    return parser
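

# Sums `tensor` across all ranks with all_reduce, then divides by num_gpus:
# floating-point tensors are averaged, integer tensors use floor division so
# counts stay integral. Pass num_gpus=1 to obtain the plain sum.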
def reduce_tensor(tensor, num_gpus):
    rt = tensor.clone()
    # ReduceOp replaces the deprecated dist.reduce_op alias
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    if rt.is_floating_point():
        rt = rt / num_gpus
    else:
        rt = rt // num_gpus
    return rt
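

# One process per GPU: each rank binds to `rank % device_count()`, which also
# works when ranks are spread across multiple nodes.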
def init_distributed(args, world_size, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing Distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=world_size, rank=rank, group_name=group_name)

    print("Done initializing distributed")
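

# Checkpointing gathers every rank's CPU and CUDA RNG state so a resumed run
# can restore per-device randomness exactly. NCCL all_gather only operates on
# CUDA tensors, hence the .cuda() copies before gathering and .cpu() after.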
def save_checkpoint(model, optimizer, scaler, epoch, config, amp_run, output_dir,
                    model_name, local_rank, world_size):
    random_rng_state = torch.random.get_rng_state().cuda()
    cuda_rng_state = torch.cuda.get_rng_state(local_rank).cuda()

    random_rng_states_all = [torch.empty_like(random_rng_state) for _ in range(world_size)]
    cuda_rng_states_all = [torch.empty_like(cuda_rng_state) for _ in range(world_size)]

    if world_size > 1:
        dist.all_gather(random_rng_states_all, random_rng_state)
        dist.all_gather(cuda_rng_states_all, cuda_rng_state)
    else:
        random_rng_states_all = [random_rng_state]
        cuda_rng_states_all = [cuda_rng_state]

    random_rng_states_all = torch.stack(random_rng_states_all).cpu()
    cuda_rng_states_all = torch.stack(cuda_rng_states_all).cpu()

    if local_rank == 0:
        checkpoint = {'epoch': epoch,
                      'cuda_rng_state_all': cuda_rng_states_all,
                      'random_rng_states_all': random_rng_states_all,
                      'config': config,
                      'state_dict': model.state_dict(),
                      'optimizer': optimizer.state_dict()}
        if amp_run:
            # Persist the native-AMP GradScaler state; the old apex-style
            # `amp.state_dict()` is gone after the torch.cuda.amp migration.
            checkpoint['scaler'] = scaler.state_dict()
        checkpoint_filename = "checkpoint_{}_{}.pt".format(model_name, epoch)
        checkpoint_path = os.path.join(output_dir, checkpoint_filename)
        print("Saving model and optimizer state at epoch {} to {}".format(
            epoch, checkpoint_path))
        torch.save(checkpoint, checkpoint_path)

        symlink_src = checkpoint_filename
        symlink_dst = os.path.join(
            output_dir, "checkpoint_{}_last.pt".format(model_name))
        if os.path.exists(symlink_dst) and os.path.islink(symlink_dst):
            print("Updating symlink", symlink_dst, "to point to", symlink_src)
            os.remove(symlink_dst)

        os.symlink(symlink_src, symlink_dst)


def get_last_checkpoint_filename(output_dir, model_name):
    symlink = os.path.join(output_dir, "checkpoint_{}_last.pt".format(model_name))
    if os.path.exists(symlink):
        print("Loading checkpoint from symlink", symlink)
        return os.path.join(output_dir, os.readlink(symlink))
    else:
        print("No last checkpoint available - starting from epoch 0")
        return ""
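

# Restores model/optimizer (and, with AMP, GradScaler) state plus the RNG
# state recorded for this device; the 'random_rng_state' branch keeps
# compatibility with older checkpoints that stored a single CPU RNG state.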
def load_checkpoint(model, optimizer, scaler, epoch, config, amp_run, filepath,
                    local_rank):
    checkpoint = torch.load(filepath, map_location='cpu')

    epoch[0] = checkpoint['epoch'] + 1
    device_id = local_rank % torch.cuda.device_count()
    torch.cuda.set_rng_state(checkpoint['cuda_rng_state_all'][device_id])

    if 'random_rng_states_all' in checkpoint:
        torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
    elif 'random_rng_state' in checkpoint:
        torch.random.set_rng_state(checkpoint['random_rng_state'])
    else:
        raise Exception("Model checkpoint must have either 'random_rng_state' or 'random_rng_states_all' key.")

    config = checkpoint['config']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    if amp_run and 'scaler' in checkpoint:
        # Restore the native-AMP GradScaler state saved by save_checkpoint;
        # older apex checkpoints carry no compatible scaler state and are skipped.
        scaler.load_state_dict(checkpoint['scaler'])


# adapted from: https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
# Following snippet is licensed under MIT license

@contextmanager
def evaluating(model):
    '''Temporarily switch to evaluation mode.'''
    istrain = model.training
    try:
        model.eval()
        yield model
    finally:
        if istrain:
            model.train()
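

# Validation runs under evaluating() + no_grad(). torch.cuda.synchronize()
# brackets the timers because CUDA kernels are asynchronous and wall-clock
# timing would otherwise undercount the iteration time.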
def validate(model, criterion, valset, epoch, batch_iter, batch_size,
             world_size, collate_fn, distributed_run, rank, batch_to_gpu):
    """Handles all the validation scoring and printing"""
    with evaluating(model), torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, num_workers=1, shuffle=False,
                                sampler=val_sampler,
                                batch_size=batch_size, pin_memory=False,
                                collate_fn=collate_fn)

        val_loss = 0.0
        num_iters = 0
        val_items_per_sec = 0.0

        for i, batch in enumerate(val_loader):
            torch.cuda.synchronize()
            iter_start_time = time.perf_counter()

            x, y, num_items = batch_to_gpu(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_val_loss = loss.item()
                reduced_num_items = num_items.item()
            val_loss += reduced_val_loss

            torch.cuda.synchronize()
            iter_stop_time = time.perf_counter()
            iter_time = iter_stop_time - iter_start_time

            items_per_sec = reduced_num_items / iter_time
            DLLogger.log(step=(epoch, batch_iter, i),
                         data={'val_items_per_sec': items_per_sec})
            val_items_per_sec += items_per_sec
            num_iters += 1

        val_loss = val_loss / (i + 1)

        DLLogger.log(step=(epoch,), data={'val_loss': val_loss})
        DLLogger.log(step=(epoch,), data={'val_items_per_sec':
                     (val_items_per_sec / num_iters if num_iters > 0 else 0.0)})

        return val_loss, val_items_per_sec
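

# Learning-rate annealing: p counts how many of the --anneal-steps epochs have
# already passed. With anneal_factor == 0.3 the schedule alternates x0.3 and
# x(1/3) steps, so the rate drops 10x every two steps; otherwise
# lr = learning_rate * anneal_factor**p.
# Example: lr=1e-3, factor=0.3, steps 500 1000 1500 -> 1e-3, 3e-4, 1e-4, 3e-5.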
def adjust_learning_rate(iteration, epoch, optimizer, learning_rate,
                         anneal_steps, anneal_factor, rank):

    p = 0
    if anneal_steps is not None:
        for i, a_step in enumerate(anneal_steps):
            if epoch >= int(a_step):
                p = p + 1

    if anneal_factor == 0.3:
        lr = learning_rate * ((0.1 ** (p // 2)) * (1.0 if p % 2 == 0 else 0.3))
    else:
        lr = learning_rate * (anneal_factor ** p)

    if optimizer.param_groups[0]['lr'] != lr:
        DLLogger.log(step=(epoch, iteration),
                     data={'learning_rate changed': str(optimizer.param_groups[0]['lr']) + " -> " + str(lr)})

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def main():
    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size

    distributed_run = world_size > 1

    if local_rank == 0:
        log_file = os.path.join(args.output, args.log_file)
        DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, log_file),
                                StdOutBackend(Verbosity.VERBOSE)])
    else:
        DLLogger.init(backends=[])

    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'Tacotron2_PyT'})

    model_name = args.model_name
    parser = models.model_parser(model_name, parser)
    args, _ = parser.parse_known_args()

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank, args.group_name)

    torch.cuda.synchronize()
    run_start_time = time.perf_counter()
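
    # Build the model on GPU; under distributed training wrap it in DDP so
    # gradients are all-reduced during backward. The GradScaler is created
    # unconditionally but is a no-op unless --amp is set.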
    model_config = models.get_model_config(model_name, args)
    model = models.get_model(model_name, model_config,
                             cpu_run=False,
                             uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

    if distributed_run:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None
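
    # Optional resume: --resume-from-last follows the checkpoint_<model>_last.pt
    # symlink. start_epoch is a one-element list so load_checkpoint can update
    # it in place.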
    start_epoch = [0]

    if args.resume_from_last:
        args.checkpoint_path = get_last_checkpoint_filename(args.output, model_name)

    if args.checkpoint_path != "":
        load_checkpoint(model, optimizer, scaler, start_epoch, model_config,
                        args.amp, args.checkpoint_path, local_rank)

    start_epoch = start_epoch[0]
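
    # Loss, collate function and data loaders: with a DistributedSampler each
    # rank sees a disjoint shard, so DataLoader-level shuffling is disabled and
    # the sampler shuffles instead.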
    criterion = loss_functions.get_loss_function(model_name, sigma)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    collate_fn = data_functions.get_collate_function(
        model_name, n_frames_per_step)
    trainset = data_functions.get_data_loader(
        model_name, args.dataset_path, args.training_files, args)
    if distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)

    valset = data_functions.get_data_loader(
        model_name, args.dataset_path, args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    train_epoch_items_per_sec = 0.0
    val_loss = 0.0
    num_iters = 0

    model.train()
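
    # Training loop; sampler.set_epoch() reseeds the DistributedSampler so the
    # shards are reshuffled differently every epoch.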
    for epoch in range(start_epoch, args.epochs):
        torch.cuda.synchronize()
        epoch_start_time = time.perf_counter()

        # used to calculate avg items/sec over epoch
        reduced_num_items_epoch = 0

        train_epoch_items_per_sec = 0.0

        num_iters = 0
        reduced_loss = 0

        # if overflow at the last iteration then do not save checkpoint
        overflow = False

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            torch.cuda.synchronize()
            iter_start_time = time.perf_counter()
            DLLogger.log(step=(epoch, i),
                         data={'glob_iter/iters_per_epoch': str(iteration) + "/" + str(len(train_loader))})

            adjust_learning_rate(iteration, epoch, optimizer, args.learning_rate,
                                 args.anneal_steps, args.anneal_factor, local_rank)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            # AMP upstream autocast
            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss = criterion(y_pred, y)
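
            # Average the loss across ranks for logging; item counts are summed
            # (reduce_tensor with num_gpus=1 returns the plain sum).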
            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            DLLogger.log(step=(epoch, i), data={'train_loss': reduced_loss})

            num_iters += 1

            # accumulate number of items processed in this epoch
            reduced_num_items_epoch += reduced_num_items

            if args.amp:
                scaler.scale(loss).backward()
                # unscale gradients before clipping so the threshold applies
                # to the true gradient norm
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)
                optimizer.step()

            model.zero_grad(set_to_none=True)

            torch.cuda.synchronize()
            iter_stop_time = time.perf_counter()
            iter_time = iter_stop_time - iter_start_time
            items_per_sec = reduced_num_items / iter_time
            train_epoch_items_per_sec += items_per_sec

            DLLogger.log(step=(epoch, i), data={'train_items_per_sec': items_per_sec})
            DLLogger.log(step=(epoch, i), data={'train_iter_time': iter_time})
            iteration += 1

        torch.cuda.synchronize()
        epoch_stop_time = time.perf_counter()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log(step=(epoch,), data={'train_items_per_sec':
                     (train_epoch_items_per_sec / num_iters if num_iters > 0 else 0.0)})
        DLLogger.log(step=(epoch,), data={'train_loss': reduced_loss})
        DLLogger.log(step=(epoch,), data={'train_epoch_time': epoch_time})

        val_loss, val_items_per_sec = validate(model, criterion, valset, epoch,
                                               iteration, args.batch_size,
                                               world_size, collate_fn,
                                               distributed_run, local_rank,
                                               batch_to_gpu)

        if (epoch % args.epochs_per_checkpoint == 0) and args.bench_class == "":
            save_checkpoint(model, optimizer, scaler, epoch, model_config,
                            args.amp, args.output, args.model_name,
                            local_rank, world_size)
        if local_rank == 0:
            DLLogger.flush()

    torch.cuda.synchronize()
    run_stop_time = time.perf_counter()
    run_time = run_stop_time - run_start_time
    DLLogger.log(step=tuple(), data={'run_time': run_time})
    DLLogger.log(step=tuple(), data={'val_loss': val_loss})
    DLLogger.log(step=tuple(), data={'train_items_per_sec':
                 (train_epoch_items_per_sec / num_iters if num_iters > 0 else 0.0)})
    DLLogger.log(step=tuple(), data={'val_items_per_sec': val_items_per_sec})

    if local_rank == 0:
        DLLogger.flush()


if __name__ == '__main__':
    main()