# Source: DeepLearningExamples/PyTorch/LanguageModeling/Transformer-XL/pytorch/train.py

# coding: utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
import itertools
import logging
import math
import os
import sys
import time
import dllogger
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from apex import amp
from torch.nn.parallel import DistributedDataParallel
import lamb
import utils
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.data_parallel import BalancedDataParallel
from utils.exp_utils import AverageMeter
from utils.exp_utils import benchmark
from utils.exp_utils import create_exp_dir
from utils.exp_utils import log_env_info
def parse_args():
    """Build the command-line parser, merge YAML config defaults and parse.

    A first, lenient pass reads only ``--config`` / ``--config_file``. The
    section ``[config]['train']`` of that YAML file is then installed as
    parser defaults, so explicit command-line flags still override it.

    Returns:
        argparse.Namespace with every option plus the derived ``tied`` flag
        and a resolved ``d_embed`` (defaults to ``d_model`` when negative).

    Exits (via ``parser.error``) on invalid combinations: negative
    ``--ext_len``, non-positive ``--batch_chunk``, or a ``--batch_size`` that
    is not divisible by ``--batch_chunk``.
    """
    parent_parser = argparse.ArgumentParser(
        description='PyTorch Transformer-XL Language Model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False,
    )

    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=True)
    cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    cfg_parser.add_argument('--config', default='default')
    cfg_parser.add_argument('--config_file', default='config.yaml')

    config_args, _ = cfg_parser.parse_known_args()

    if config_args.config is not None and config_args.config_file is not None:
        with open(config_args.config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['train']
    else:
        config = {}

    general = parser.add_argument_group('general setup')
    general.add_argument('--work_dir', default='LM-TFM', type=str,
                         help='Directory for the results')
    general.add_argument('--append_dataset', action='store_true',
                         help='Automatically append dataset name to work_dir')
    general.add_argument('--append_time', action='store_true',
                         help='Automatically append current time to work_dir')
    general.add_argument('--cuda', action='store_true',
                         help='Run training on a GPU using CUDA')
    general.add_argument('--fp16', action='store_true',
                         help='Run training in fp16/mixed precision')
    general.add_argument('--restart', type=str, default='',
                         help='Restart training from the saved checkpoint')
    general.add_argument('--debug', action='store_true',
                         help='Run in debug mode (do not create exp dir)')
    general.add_argument('--log_all_ranks', action='store_true',
                         help='Enable logging from all distributed ranks')
    general.add_argument('--save-all', action='store_true',
                         help='Save all checkpoints')
    general.add_argument('--no_env', action='store_true',
                         help='Do not print info on execution env')
    general.add_argument('--log_interval', type=int, default=10,
                         help='Report interval')
    general.add_argument('--target_throughput', type=float, default=None,
                         help='Target training throughput (for benchmarking)')
    general.add_argument('--target_perplexity', type=float, default=None,
                         help='Target validation perplexity (for benchmarking)')

    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--data', type=str, default='../data/wikitext-103',
                         help='Location of the data corpus')
    dataset.add_argument('--dataset', type=str, default='wt103',
                         choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                         help='Dataset name')
    dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
                         help='Type of vocabulary')

    model = parser.add_argument_group('model setup')
    model.add_argument('--n_layer', type=int, default=16,
                       help='Number of total layers')
    model.add_argument('--n_head', type=int, default=8,
                       help='Number of heads')
    model.add_argument('--d_head', type=int, default=64,
                       help='Head dimension')
    model.add_argument('--d_embed', type=int, default=-1,
                       help='Embedding dimension')
    model.add_argument('--d_model', type=int, default=512,
                       help='Model dimension')
    model.add_argument('--d_inner', type=int, default=2048,
                       help='Inner dimension in feedforward layer')
    model.add_argument('--dropout', type=float, default=0.1,
                       help='Global dropout rate')
    model.add_argument('--dropatt', type=float, default=0.0,
                       help='Attention probability dropout rate')
    model.add_argument('--pre_lnorm', action='store_true',
                       help='Apply LayerNorm to the input instead of the output')
    model.add_argument('--attn_type', type=int, default=0,
                       help='Attention type. 0 for ours, 1 for Shaw et al, '
                       '2 for Vaswani et al, 3 for Al Rfou et al.')
    model.add_argument('--not_tied', action='store_true',
                       help='Do not tie the word embedding and softmax weights')
    model.add_argument('--clamp_len', type=int, default=-1,
                       help='Use the same pos embeddings after clamp_len')
    model.add_argument('--adaptive', action='store_true',
                       help='Use adaptive softmax')
    model.add_argument('--div_val', type=int, default=1,
                       help='Dividend value for adaptive input and softmax')
    model.add_argument('--sample_softmax', type=int, default=-1,
                       help='Number of samples in sampled softmax')
    model.add_argument('--init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--emb_init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--init_range', type=float, default=0.1,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--emb_init_range', type=float, default=0.01,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--init_std', type=float, default=0.02,
                       help='Parameters initialized by N(0, init_std)')
    model.add_argument('--proj_init_std', type=float, default=0.01,
                       help='Parameters initialized by N(0, init_std)')

    opt = parser.add_argument_group('optimizer setup')
    opt.add_argument('--optim', default='jitlamb', type=str,
                     choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
                     help='Optimizer to use')
    opt.add_argument('--lr', type=float, default=0.01,
                     help='Initial learning rate')
    opt.add_argument('--mom', type=float, default=0.0,
                     help='Momentum for sgd')
    opt.add_argument('--scheduler', default='cosine', type=str,
                     choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                     help='LR scheduler to use')
    opt.add_argument('--max_step_scheduler', type=int, default=None,
                     help='Max number of training steps for LR scheduler')
    opt.add_argument('--warmup_step', type=int, default=1000,
                     help='Number of iterations for LR warmup')
    opt.add_argument('--decay_rate', type=float, default=0.5,
                     help='Decay factor when ReduceLROnPlateau is used')
    opt.add_argument('--lr_min', type=float, default=0.0,
                     help='Minimum learning rate during annealing')
    opt.add_argument('--clip', type=float, default=0.25,
                     help='Gradient clipping')
    opt.add_argument('--weight_decay', type=float, default=0.0,
                     help='Weight decay for adam|lamb')
    opt.add_argument('--clip_nonemb', action='store_true',
                     help='Only clip the gradient of non-embedding params')
    opt.add_argument('--patience', type=int, default=0,
                     help='Patience')
    opt.add_argument('--eta_min', type=float, default=0.001,
                     help='Min learning rate for cosine scheduler')

    training = parser.add_argument_group('training setup')
    training.add_argument('--max_step', type=int, default=40000,
                          help='Max number of training steps')
    training.add_argument('--batch_size', type=int, default=256,
                          help='Global batch size')
    training.add_argument('--local_batch_size', type=int, default=None,
                          help='Local (per-device) batch size, this setting '
                          'overrides global --batch_size and sets batch_size '
                          'to local_batch_size * world_size')
    training.add_argument('--batch_chunk', type=int, default=1,
                          help='Split batch into chunks and train with '
                          'gradient accumulation')
    training.add_argument('--roll', action='store_true',
                          help='Enable random shifts within each data stream')
    training.add_argument('--tgt_len', type=int, default=192,
                          help='Number of tokens to predict')
    training.add_argument('--ext_len', type=int, default=0,
                          help='Length of the extended context')
    training.add_argument('--mem_len', type=int, default=192,
                          help='Length of the retained previous heads')
    training.add_argument('--seed', type=int, default=1111,
                          help='Random seed')
    training.add_argument('--multi_gpu', default=None, type=str,
                          choices=['ddp', 'dp'],
                          help='Use multiple GPU')
    training.add_argument('--gpu0_bsz', type=int, default=-1,
                          help='Batch size on gpu 0 (for "dp" backend)')
    training.add_argument('--same_length', action='store_true',
                          help='Use the same attn length for all tokens')
    training.add_argument('--varlen', action='store_true',
                          help='Use variable length')

    val = parser.add_argument_group('validation setup')
    val.add_argument('--eval_tgt_len', type=int, default=192,
                     help='Number of tokens to predict for evaluation')
    val.add_argument('--eval_batch_size', type=int, default=16,
                     help='Eval batch size')
    val.add_argument('--eval_max_steps', type=int, default=-1,
                     help='Max eval steps')
    val.add_argument('--eval_interval', type=int, default=5000,
                     help='Evaluation interval')

    dist = parser.add_argument_group('distributed setup')
    dist.add_argument('--local_rank', type=int,
                      default=os.getenv('LOCAL_RANK', 0),
                      help='Used for multi-process training.')

    parser.set_defaults(**config)
    args, _ = parser.parse_known_args()

    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    # Validate with parser.error rather than assert: asserts vanish under
    # `python -O`, and parser.error prints a usage message and exits cleanly.
    if args.ext_len < 0:
        parser.error('extended context length must be non-negative')
    if args.batch_chunk < 1:
        # 0 would divide by zero below; negatives would pass the modulo
        # check and then fail inside torch.chunk during training.
        parser.error('--batch_chunk must be a positive integer')
    if args.batch_size % args.batch_chunk != 0:
        parser.error('--batch_size must be divisible by --batch_chunk')

    return args
def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
                    train_step, best_val_loss, work_dir, name='checkpoint.pt'):
    """Serialize the complete training state to ``work_dir/name``.

    Every rank assembles the state dictionary (model, optimizer, scheduler,
    vocab, apex-amp state when ``args.fp16`` is set, step counter and best
    validation loss). Only the rank yielded as 0 by
    ``utils.distributed.sync_workers()`` actually writes the file.
    """
    # apex amp keeps its own loss-scaler state that must be resumed too.
    amp_state = amp.state_dict() if args.fp16 else None

    state = dict(
        args=args,
        model_config=model_config,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        scheduler_state=scheduler.state_dict(),
        vocab=vocab,
        amp_state=amp_state,
        train_step=train_step,
        best_val_loss=best_val_loss,
    )

    with utils.distributed.sync_workers() as rank:
        ckpt_path = os.path.join(work_dir, name)
        logging.info(f'Saving checkpoint to {ckpt_path}')
        if rank == 0:
            torch.save(state, ckpt_path)
def load_checkpoint(path):
    """Load a checkpoint, mapping its tensors onto the current CUDA device.

    ``path`` may name a checkpoint file directly or a directory, in which
    case ``checkpoint_last.pt`` inside that directory is loaded.
    """
    ckpt_path = os.path.join(path, 'checkpoint_last.pt') if os.path.isdir(path) else path
    target_device = f'cuda:{torch.cuda.current_device()}'
    logging.info(f'Loading checkpoint from {ckpt_path}')
    return torch.load(ckpt_path, map_location=target_device)
def init_weight(weight, args):
    """Initialize ``weight`` in place according to ``args.init``.

    'normal' draws from N(0, args.init_std) and 'uniform' from
    U(-args.init_range, args.init_range). Any other value leaves the tensor
    untouched.
    """
    scheme = args.init
    if scheme == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)
    elif scheme == 'uniform':
        nn.init.uniform_(weight, -args.init_range, args.init_range)
def init_bias(bias):
    """Fill the ``bias`` tensor with zeros in place."""
    nn.init.zeros_(bias)
def weights_init(m, args):
    """Initialize the parameters of a single module ``m`` in place.

    Meant to be used as ``model.apply(functools.partial(weights_init,
    args=args))``. The branch is chosen by substring match on the module's
    class name, so the order of the checks matters. For example,
    'AdaptiveEmbedding' must be tested before the more general 'Embedding'.

    Generic weights go through ``init_weight`` (args.init / init_range /
    init_std). Projection matrices use N(0, args.proj_init_std), LayerNorm
    gains use N(1, args.init_std), and biases are zeroed via ``init_bias``.
    A parameter that is missing or set to ``None`` (e.g. a LayerNorm built
    without affine parameters) is skipped rather than passed to ``nn.init``.
    """
    classname = m.__class__.__name__

    if classname.find('Linear') != -1:
        if getattr(m, 'weight', None) is not None:
            init_weight(m.weight, args)
        if getattr(m, 'bias', None) is not None:
            init_bias(m.bias)
    elif classname.find('AdaptiveEmbedding') != -1:
        if hasattr(m, 'emb_projs'):
            for proj in m.emb_projs:
                if proj is not None:
                    nn.init.normal_(proj, 0.0, args.proj_init_std)
    elif classname.find('Embedding') != -1:
        if getattr(m, 'weight', None) is not None:
            init_weight(m.weight, args)
    elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
        if getattr(m, 'cluster_weight', None) is not None:
            init_weight(m.cluster_weight, args)
        if getattr(m, 'cluster_bias', None) is not None:
            init_bias(m.cluster_bias)
        if hasattr(m, 'out_projs'):
            for proj in m.out_projs:
                if proj is not None:
                    nn.init.normal_(proj, 0.0, args.proj_init_std)
        if hasattr(m, 'out_layers_weights'):
            for weight in m.out_layers_weights:
                if weight is not None:
                    init_weight(weight, args)
    elif classname.find('LayerNorm') != -1:
        # LayerNorm(elementwise_affine=False) registers weight/bias as None.
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if getattr(m, 'bias', None) is not None:
            init_bias(m.bias)
    elif classname.find('TransformerLM') != -1:
        for attr in ('r_emb', 'r_w_bias', 'r_r_bias'):
            if getattr(m, attr, None) is not None:
                init_weight(getattr(m, attr), args)
        if getattr(m, 'r_bias', None) is not None:
            init_bias(m.r_bias)
def update_dropout(m, args):
    """Set the drop probability of a Dropout-like module to ``args.dropout``.

    Applies to any module whose class name contains 'Dropout' and that
    exposes a ``p`` attribute. All other modules are left unchanged.
    """
    is_dropout = 'Dropout' in m.__class__.__name__
    if is_dropout and hasattr(m, 'p'):
        m.p = args.dropout
def update_dropatt(m, args):
    """Set ``m.dropatt.p`` to ``args.dropatt`` when ``m`` has a ``dropatt`` attribute."""
    if not hasattr(m, 'dropatt'):
        return
    m.dropatt.p = args.dropatt
def evaluate(eval_iter, model, args):
    """Return the sequence-length weighted mean loss of ``model`` on ``eval_iter``.

    The model is put in eval mode and resized to ``args.eval_tgt_len``. The
    difference ``tgt_len - eval_tgt_len`` is added to ``mem_len`` when memory
    is used, otherwise to ``ext_len``, so tgt + ext + mem matches the training
    configuration. Training lengths and train mode are restored afterwards,
    even if evaluation raises.

    Args:
        eval_iter: iterable of ``(data, target, seq_len, warm)`` tuples.
        model: object exposing ``eval()``, ``train()``, ``reset_length()``,
            ``mem_len`` and ``model(data, target, mems) -> (loss, mems)``.
        args: namespace with tgt_len, eval_tgt_len, ext_len, mem_len and
            eval_max_steps (values <= 0 mean "no limit").

    Returns:
        float: sum(seq_len * mean_loss) / sum(seq_len) over evaluated batches.

    Raises:
        ValueError: if no batch was evaluated (e.g. an empty ``eval_iter``).
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()

    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if args.mem_len == 0:
        model.reset_length(tgt_len=args.eval_tgt_len,
                           ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
                           mem_len=args.mem_len,
                           )
    else:
        model.reset_length(tgt_len=args.eval_tgt_len,
                           ext_len=args.ext_len,
                           mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len,
                           )

    total_len, total_loss = 0, 0.
    try:
        with torch.no_grad():
            mems = None
            for i, (data, target, seq_len, warm) in enumerate(eval_iter):
                if args.eval_max_steps > 0 and i >= args.eval_max_steps:
                    break
                loss, mems = model(data, target, mems)
                loss = loss.float().mean()
                if warm:
                    # After warm-up, every memory tensor must hold exactly
                    # model.mem_len entries along dim 0.
                    assert (not mems) or all(m.size(0) == model.mem_len for m in mems)
                total_loss += seq_len * loss.item()
                total_len += seq_len
    finally:
        # Switch back to the training configuration, even on error.
        model.reset_length(tgt_len=args.tgt_len,
                           ext_len=args.ext_len,
                           mem_len=args.mem_len,
                           )
        model.train()

    if total_len == 0:
        raise ValueError('evaluate() received no batches; cannot compute a mean loss')
    return total_loss / total_len
def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
          optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch, train_step,
          best_val_loss, meters, args):
    """Run one pass over the training data, stopping early at ``args.max_step``.

    For each batch the inputs are split into ``args.batch_chunk`` chunks
    along dim 1 and gradients are accumulated over them. Gradients are then
    clipped to ``args.clip``, and the optimizer(s) and the step-wise LR
    schedule are stepped. Every ``args.log_interval`` steps training metrics
    are logged. Every ``args.eval_interval`` steps the model is evaluated on
    ``va_iter``, checkpoints are written (unless ``args.debug``), and the
    'dev_perf' scheduler is stepped with the validation loss.

    Returns:
        tuple: ``(train_step, best_val_loss)`` after this pass.
        ``best_val_loss`` is ``None`` until the first evaluation.
    """
    # Turn on training mode which enables dropout.
    model.train()

    train_loss = 0
    target_tokens = 0
    log_step = 0
    log_start_time = time.time()

    # One recurrent memory per gradient-accumulation chunk.
    mems = [None] * args.batch_chunk

    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter

    for batch, (data, target, seq_len, _) in enumerate(train_iter):
        log_step += 1
        target_tokens += target.numel()

        model.zero_grad()

        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            data_i = data_chunks[i].contiguous()
            target_i = target_chunks[i].contiguous()
            loss, mems[i] = para_model(data_i, target_i, mems[i])
            loss = loss.float().mean().type_as(loss) / args.batch_chunk

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            train_loss += loss.float().item()

        if args.fp16:
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        if optimizer_sparse:
            optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if optimizer_sparse:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step - args.warmup_step)
                    if scheduler_sparse:
                        scheduler_sparse.step(train_step - args.warmup_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)
            if scheduler_sparse:
                scheduler_sparse.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            train_loss = 0

            elapsed = time.time() - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
            log_start_time = time.time()
            log_step = 0

            lr = optimizer.param_groups[0]['lr']
            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
            meters['train_throughput'].update(throughput)
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
                      '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                          epoch,
                          train_step,
                          batch + 1,
                          tr_iter.n_batch,
                          lr,
                          avg_elapsed * 1000,
                          throughput,
                          cur_loss,
                      )

            dllogger_data = {
                'epoch': epoch,
                'train_batch': batch + 1,
                'lr': lr,
                'train_time/batch': avg_elapsed * 1000,
                'train_throughput': throughput,
                'train_loss': cur_loss,
            }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
                dllogger_data['train_bits_per_character'] = cur_loss / math.log(2)
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
                dllogger_data['train_perplexity'] = math.exp(cur_loss)

            logging.info(log_str)
            dllogger.log(step=train_step, data=dllogger_data)

        if train_step % args.eval_interval == 0:
            eval_start_time = time.time()
            val_loss = evaluate(va_iter, model, args)
            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')
            # Measure once so the text log and dllogger report the same value.
            eval_elapsed = time.time() - eval_start_time

            logging.info('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                          train_step // args.eval_interval,
                          train_step,
                          eval_elapsed,
                          val_loss,
                      )

            dllogger_data = {
                'valid_elapsed': eval_elapsed,
                'valid_loss': val_loss,
            }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
                dllogger_data['valid_bits_per_character'] = val_loss / math.log(2)
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
                dllogger_data['valid_perplexity'] = math.exp(val_loss)
            logging.info(log_str)
            logging.info('-' * 100)

            dllogger.log(step=train_step, data=dllogger_data)

            # Save the model if the validation loss is the best we've seen so
            # far. Compare against None explicitly: a loss of 0.0 is falsy.
            if best_val_loss is None or val_loss < best_val_loss:
                best_val_loss = val_loss

                if not args.debug:
                    name = 'checkpoint_best.pt'
                    save_checkpoint(args, model, model_config, optimizer,
                                    scheduler, vocab, train_step,
                                    best_val_loss, args.work_dir, name)

            # Always save after eval if save_all is true and not debug
            if not args.debug and args.save_all:
                name = f'checkpoint_{train_step}.pt'
                save_checkpoint(args, model, model_config, optimizer,
                                scheduler, vocab, train_step, best_val_loss,
                                args.work_dir, name)

            # Save last checkpoint if not debug and not save_all
            if not args.debug and not args.save_all:
                name = 'checkpoint_last.pt'
                save_checkpoint(args, model, model_config, optimizer,
                                scheduler, vocab, train_step, best_val_loss,
                                args.work_dir, name)

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if scheduler_sparse:
                    scheduler_sparse.step(val_loss)

            # subtract eval time (including checkpointing) from training timers
            log_start_time += time.time() - eval_start_time

        # '>=' rather than '==' so a run restarted past max_step still stops.
        if train_step >= args.max_step:
            break

    return train_step, best_val_loss
def main():
    """Train and test a Transformer-XL language model end to end.

    Steps:
    1. Parse arguments and set up CUDA/distributed, the work dir and logging.
    2. Build the data iterators, model, optimizer and LR scheduler.
    3. Optionally resume from ``--restart``.
    4. Train until ``args.max_step`` or Ctrl+C.
    5. Evaluate ``checkpoint_best.pt`` on the test split and log a summary.

    Exits with status 1 when ``benchmark()`` reports that the throughput or
    perplexity targets were missed.
    """
    args = parse_args()

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
                                                        args.dataset,
                                                        args.append_dataset,
                                                        args.append_time,
                                                        )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'train_log.log'

    dllog_file = 'train_log.json'
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)
    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    if not args.no_env:
        log_env_info()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    # Must match the memory length evaluate() configures on the model.
    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        if args.dataset not in ['wt103', 'lm1b']:
            raise ValueError('--adaptive is only supported for the wt103 and lm1b datasets')
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
    }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum(p.nelement() for p in model.parameters())
    args.n_nonemb_param = sum(p.nelement() for p in model.layers.parameters())

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2',
        )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(model,
                                             device_ids=[args.local_rank],
                                             output_device=args.local_rank,
                                             broadcast_buffers=False,
                                             find_unused_parameters=True,
                                             )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(
                optimizer_sparse,
                lr_lambda=lr_lambda
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.decay_rate, patience=args.patience,
            min_lr=args.lr_min,
        )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse, factor=args.decay_rate, patience=args.patience,
                min_lr=args.lr_min,
            )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        # train() only touches the LR during warmup for 'constant', but it
        # and save_checkpoint()/restart still need a scheduler object with
        # state_dict()/load_state_dict(); a unit multiplier keeps LR as is.
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
        scheduler_sparse = None

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    best_val_loss = None

    if args.restart:
        checkpoint = load_checkpoint(args.restart)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        if args.fp16:
            amp.load_state_dict(checkpoint['amp_state'])
        train_step = checkpoint['train_step']
        best_val_loss = checkpoint['best_val_loss']

        # Apply the dropout rates from the current arguments.
        model.apply(functools.partial(update_dropout, args=args))
        model.apply(functools.partial(update_dropatt, args=args))

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)

    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    try:
        for epoch in itertools.count(start=1):
            if args.roll:
                tr_iter.roll()
            train_step, best_val_loss = train(
                tr_iter, va_iter, model, para_model, model_config, optimizer,
                optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
                train_step, best_val_loss, meters, args
            )

            # '>=' so a run restarted past max_step does not loop forever.
            if train_step >= args.max_step:
                logging.info('-' * 100)
                logging.info('End of training')
                break
    except KeyboardInterrupt:
        logging.info('-' * 100)
        logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, op='mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
                test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
                test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
        })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    # Explicit None check: a best loss of 0.0 is valid but falsy.
    if best_val_loss is not None:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
    })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=val_perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['train_throughput'].avg
    )
    if not passed:
        sys.exit(1)
if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()