DeepLearningExamples/PyTorch/Forecasting/TFT/tft_pyt/train.py

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
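"""Training entry point for the Temporal Fusion Transformer (TFT) reference model.

Builds train/valid/test loaders from the pre-processed binary splits, trains with
FusedAdam and a quantile loss (optionally with Apex AMP, distributed data parallel
and an EMA copy of the weights), validates once per epoch with early stopping, and
finally reports the normalized quantile (q-risk) losses of the best checkpoint on
the test split.

Example invocation (dataset name and paths are illustrative; multi-GPU runs should
be launched via torchrun / torch.distributed.launch so that WORLD_SIZE, LOCAL_RANK
and the rendezvous variables are set):

    python train.py --dataset electricity --data_path /data/processed/electricity_bin --results /results
"""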
import argparse
import time
import os
import pickle
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from apex import amp
from apex.optimizers import FusedAdam
#from torch.nn.parallel import DistributedDataParallel as DDP
from apex.parallel import DistributedDataParallel as DDP
import numpy as np
import dllogger
from modeling import TemporalFusionTransformer
from configuration import CONFIGS
from data_utils import TFTBinaryDataset, sample_data
from log_helper import setup_logger
from criterions import QuantileLoss
from inference import predict
from utils import PerformanceMeter
import gpu_affinity
from ema import ModelEma


def load_dataset(args, config):
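    """Build train/valid/test DataLoaders from the binary dataset files.

    In distributed runs each split is sharded with a DistributedSampler; otherwise
    the training split is shuffled with a RandomSampler.
    """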
    train_split = TFTBinaryDataset(os.path.join(args.data_path, 'train.bin'), config)
    train_split = sample_data(train_split, args.sample_data[0])
    if args.distributed_world_size > 1:
        data_sampler = DistributedSampler(train_split, args.distributed_world_size, args.distributed_rank, seed=args.seed + args.distributed_rank, drop_last=True)
    else:
        data_sampler = RandomSampler(train_split)
    train_loader = DataLoader(train_split, batch_size=args.batch_size, num_workers=4, sampler=data_sampler, pin_memory=True)

    valid_split = TFTBinaryDataset(os.path.join(args.data_path, 'valid.bin'), config)
    valid_split = sample_data(valid_split, args.sample_data[1])
    if args.distributed_world_size > 1:
        data_sampler = DistributedSampler(valid_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
    else:
        data_sampler = None
    valid_loader = DataLoader(valid_split, batch_size=args.batch_size, sampler=data_sampler, num_workers=4, pin_memory=True)

    test_split = TFTBinaryDataset(os.path.join(args.data_path, 'test.bin'), config)
    if args.distributed_world_size > 1:
        data_sampler = DistributedSampler(test_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
    else:
        data_sampler = None
    test_loader = DataLoader(test_split, batch_size=args.batch_size, sampler=data_sampler, num_workers=4, pin_memory=True)

    print_once(f'Train split length: {len(train_split)}')
    print_once(f'Valid split length: {len(valid_split)}')
    print_once(f'Test split length: {len(test_split)}')

    return train_loader, valid_loader, test_loader


def print_once(*args, **kwargs):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)


def main(args):
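    """Train the model with per-epoch validation, then evaluate the best checkpoint on the test split."""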
    nproc_per_node = torch.cuda.device_count()
    if args.affinity != 'disabled':
        affinity = gpu_affinity.set_affinity(
            args.local_rank,
            nproc_per_node,
            args.affinity
        )
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Enable cuDNN autotuner
    torch.backends.cudnn.benchmark = True
    ### INIT DISTRIBUTED
    if args.distributed_world_size > 1:
        args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        args.distributed_world_size = int(os.environ['WORLD_SIZE'])
        args.distributed_rank = dist.get_rank()
        print_once(f'Distributed training with {args.distributed_world_size} GPUs')
        torch.cuda.synchronize()

    if args.seed:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
    setup_logger(args)

    config = CONFIGS[args.dataset]()
    if args.overwrite_config:
        config.__dict__.update(json.loads(args.overwrite_config))

    dllogger.log(step='HPARAMS', data={**vars(args), **vars(config)}, verbosity=1)

    model = TemporalFusionTransformer(config).cuda()
    if args.ema_decay:
        model_ema = ModelEma(model, decay=args.ema_decay)

    print_once('Model params: {}'.format(sum(p.numel() for p in model.parameters())))
    criterion = QuantileLoss(config).cuda()
    optimizer = FusedAdam(model.parameters(), lr=args.lr)
    if args.use_amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
    if args.distributed_world_size > 1:
        # model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
        model = DDP(model)

    train_loader, valid_loader, test_loader = load_dataset(args, config)
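
    # Training loop: validate (and checkpoint on improvement) once per epoch,
    # stopping early when the validation loss has not improved for --early_stopping epochs.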
    global_step = 0
    perf_meter = PerformanceMeter()

    for epoch in range(args.epochs):
        start = time.time()
        dllogger.log(step=global_step, data={'epoch': epoch}, verbosity=1)

        model.train()
        for local_step, batch in enumerate(train_loader):
            perf_meter.reset_current_lap()
            batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in batch.items()}
            predictions = model(batch)
            targets = batch['target'][:,config.encoder_length:,:]
            p_losses = criterion(predictions, targets)
            loss = p_losses.sum()

            if args.use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if not args.grad_accumulation or (global_step+1) % args.grad_accumulation == 0:
                if args.clip_grad:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
                optimizer.step()
                optimizer.zero_grad()
                if args.ema_decay:
                    model_ema.update(model)

            if args.distributed_world_size > 1:
                dist.all_reduce(p_losses)
                p_losses /= args.distributed_world_size
                loss = p_losses.sum()

            torch.cuda.synchronize()
            ips = perf_meter.update(args.batch_size * args.distributed_world_size,
                                    exclude_from_total=local_step in [0, len(train_loader)-1])

            log_dict = {'P10':p_losses[0].item(), 'P50':p_losses[1].item(), 'P90':p_losses[2].item(), 'loss': loss.item(), 'items/s':ips}
            dllogger.log(step=global_step, data=log_dict, verbosity=1)
            global_step += 1

        validate(args, config, model_ema if args.ema_decay else model, criterion, valid_loader, global_step)

        if validate.early_stop_c >= args.early_stopping:
            print_once('Early stopping')
            break

    ### TEST PHASE ###
    state_dict = torch.load(os.path.join(args.results, 'checkpoint.pt'), map_location='cpu')
    if isinstance(model, DDP):
        model.module.load_state_dict(state_dict['model'])
    else:
        model.load_state_dict(state_dict['model'])
    model.cuda().eval()

    tgt_scalers = pickle.load(open(os.path.join(args.data_path, 'tgt_scalers.bin'), 'rb'))
    cat_encodings = pickle.load(open(os.path.join(args.data_path, 'cat_encodings.bin'), 'rb'))

    unscaled_predictions, unscaled_targets, _, _ = predict(args, config, model, test_loader, tgt_scalers, cat_encodings)

    losses = QuantileLoss(config)(unscaled_predictions, unscaled_targets)
    normalizer = unscaled_targets.abs().mean()
    quantiles = 2 * losses / normalizer

    if args.distributed_world_size > 1:
        quantiles = quantiles.cuda()
        dist.all_reduce(quantiles)
        quantiles /= args.distributed_world_size

    quantiles = {'test_p10': quantiles[0].item(), 'test_p50': quantiles[1].item(), 'test_p90': quantiles[2].item(), 'sum':sum(quantiles).item()}
    finish_log = {**quantiles, 'average_ips':perf_meter.avg, 'convergence_step':validate.conv_step}
    dllogger.log(step=(), data=finish_log, verbosity=1)


def validate(args, config, model, criterion, dataloader, global_step):
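    """Evaluate the model on the validation set and checkpoint it when the loss improves.

    The best loss, the early-stopping counter and the convergence step are kept as
    attributes on the function object so they persist between calls.
    """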
    if not hasattr(validate, 'best_valid_loss'):
        validate.best_valid_loss = float('inf')
    if not hasattr(validate, 'early_stop_c'):
        validate.early_stop_c = 0

    model.eval()
    losses = []
    validation_start = time.time()
    for batch in dataloader:
        with torch.no_grad():
            batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in batch.items()}
            predictions = model(batch)
            targets = batch['target'][:,config.encoder_length:,:]
            p_losses = criterion(predictions, targets)
            bs = next(t for t in batch.values() if t is not None).shape[0]
            losses.append((p_losses, bs))
    validation_end = time.time()
    p_losses = sum([l[0]*l[1] for l in losses])/sum([l[1] for l in losses])  # weighted mean; accounts for the last batch possibly not being full
    if args.distributed_world_size > 1:
        dist.all_reduce(p_losses)
        p_losses = p_losses/args.distributed_world_size

    ips = len(dataloader.dataset) / (validation_end - validation_start)

    log_dict = {'P10':p_losses[0].item(), 'P50':p_losses[1].item(), 'P90':p_losses[2].item(), 'loss': p_losses.sum().item(), 'items/s':ips}

    if log_dict['loss'] < validate.best_valid_loss:
        validate.best_valid_loss = log_dict['loss']
        validate.early_stop_c = 0
        validate.conv_step = global_step
        if not dist.is_initialized() or dist.get_rank() == 0:
            state_dict = model.module.state_dict() if isinstance(model, (DDP, ModelEma)) else model.state_dict()
            ckpt = {'args':args, 'config':config, 'model':state_dict}
            torch.save(ckpt, os.path.join(args.results, 'checkpoint.pt'))
        if args.distributed_world_size > 1:
            dist.barrier()
    else:
        validate.early_stop_c += 1

    log_dict = {'val_'+k:v for k,v in log_dict.items()}
    dllogger.log(step=global_step, data=log_dict, verbosity=1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to the dataset')
    parser.add_argument('--dataset', type=str, required=True, choices=CONFIGS.keys(),
                        help='Dataset name')
    parser.add_argument('--epochs', type=int, default=25,
                        help='Number of training epochs')
    parser.add_argument('--sample_data', type=lambda x: int(float(x)), nargs=2, default=[-1, -1],
                        help="""Subsample the dataset. Specify number of training and valid examples.
                        Values can be provided in scientific notation. Floats will be truncated.""")
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--use_amp', action='store_true', help='Enable automatic mixed precision')
    parser.add_argument('--clip_grad', type=float, default=0.0)
    parser.add_argument('--grad_accumulation', type=int, default=0)
    parser.add_argument('--early_stopping', type=int, default=1000,
                        help='Stop training if the validation loss does not improve for this many consecutive epochs.')
    parser.add_argument('--results', type=str, default='/results',
                        help='Directory in which results are stored')
    parser.add_argument('--log_file', type=str, default='dllogger.json',
                        help='Name of dllogger output file')
    parser.add_argument('--distributed_world_size', type=int, metavar='N',
                        default=torch.cuda.device_count(),
                        help='total number of GPUs across all nodes (default: all visible GPUs)')
    parser.add_argument('--distributed_rank', default=os.getenv('LOCAL_RANK', 0), type=int,
                        help='rank of the current worker')
    parser.add_argument('--local_rank', default=0, type=int,
                        help='rank of the current worker')
    parser.add_argument('--overwrite_config', type=str, default='',
                        help='JSON string used to override the config')
    parser.add_argument('--affinity', type=str,
                        default='socket_unique_interleaved',
                        choices=['socket', 'single', 'single_unique',
                                 'socket_unique_interleaved',
                                 'socket_unique_continuous',
                                 'disabled'],
                        help='type of CPU affinity')
    parser.add_argument('--ema_decay', type=float, default=0.0,
                        help='Exponential moving average decay; 0.0 disables EMA')

    ARGS = parser.parse_args()

    main(ARGS)