372 lines
16 KiB
Python
372 lines
16 KiB
Python
|
#!/usr/bin/env python
|
||
|
import argparse
|
||
|
import os
|
||
|
import logging
|
||
|
from ast import literal_eval
|
||
|
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.parallel
|
||
|
import torch.utils.data.distributed
|
||
|
import torch.distributed as dist
|
||
|
import torch.optim
|
||
|
|
||
|
from seq2seq.models.gnmt import GNMT
|
||
|
from seq2seq.train.smoothing import LabelSmoothing
|
||
|
from seq2seq.data.dataset import TextDataset
|
||
|
from seq2seq.data.dataset import ParallelDataset
|
||
|
from seq2seq.data.tokenizer import Tokenizer
|
||
|
from seq2seq.utils import setup_logging
|
||
|
import seq2seq.data.config as config
|
||
|
import seq2seq.train.trainer as trainers
|
||
|
from seq2seq.inference.inference import Translator
|
||
|
|
||
|
|
||
|
def parse_args():
    """
    Parse commandline arguments.

    Returns:
        argparse.Namespace: parsed arguments covering dataset location,
        results directory, model/optimizer configuration, general runtime
        switches (cuda/cudnn/seed), training, validation, test/inference,
        checkpointing and distributed-training options.
    """
    parser = argparse.ArgumentParser(
        description='GNMT training',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # dataset
    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--dataset-dir', default=None, required=True,
                         help='path to directory with training/validation data')
    dataset.add_argument('--max-size', default=None, type=int,
                         help='use at most MAX_SIZE elements from training '
                              'dataset (useful for benchmarking), by default '
                              'uses entire dataset')

    # results
    results = parser.add_argument_group('results setup')
    results.add_argument('--results-dir', default='results',
                         help='path to directory with results, it will be '
                              'automatically created if it does not exist')
    results.add_argument('--save', default='gnmt_wmt16',
                         help='defines subdirectory within RESULTS_DIR for '
                              'results from this training run')
    results.add_argument('--print-freq', default=10, type=int,
                         help='print log every PRINT_FREQ batches')

    # model
    model = parser.add_argument_group('model setup')
    # NOTE: the config is parsed with ast.literal_eval later, hence the
    # dict-literal-in-a-string default.
    model.add_argument('--model-config',
                       default="{'hidden_size': 1024,'num_layers': 4, \
                       'dropout': 0.2, 'share_embedding': True}",
                       help='GNMT architecture configuration')
    model.add_argument('--smoothing', default=0.1, type=float,
                       help='label smoothing, if equal to zero model will use '
                            'CrossEntropyLoss, if not zero model will be '
                            'trained with label smoothing loss')

    # setup
    general = parser.add_argument_group('general setup')
    general.add_argument('--math', default='fp16', choices=['fp32', 'fp16'],
                         help='arithmetic type')
    general.add_argument('--seed', default=None, type=int,
                         help='set random number generator seed')
    general.add_argument('--disable-eval', action='store_true', default=False,
                         help='disables validation after every epoch')
    general.add_argument('--workers', default=0, type=int,
                         help='number of workers for data loading')

    # --cuda / --no-cuda toggle a single boolean flag (default: enabled)
    cuda_parser = general.add_mutually_exclusive_group(required=False)
    cuda_parser.add_argument('--cuda', dest='cuda', action='store_true',
                             help='enables cuda (use \'--no-cuda\' to disable)')
    cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false',
                             help=argparse.SUPPRESS)
    cuda_parser.set_defaults(cuda=True)

    # --cudnn / --no-cudnn toggle a single boolean flag (default: enabled)
    cudnn_parser = general.add_mutually_exclusive_group(required=False)
    cudnn_parser.add_argument('--cudnn', dest='cudnn', action='store_true',
                              help='enables cudnn (use \'--no-cudnn\' to disable)')
    cudnn_parser.add_argument('--no-cudnn', dest='cudnn', action='store_false',
                              help=argparse.SUPPRESS)
    cudnn_parser.set_defaults(cudnn=True)

    # training
    training = parser.add_argument_group('training setup')
    training.add_argument('--batch-size', default=128, type=int,
                          help='batch size for training')
    training.add_argument('--epochs', default=8, type=int,
                          help='number of total epochs to run')
    training.add_argument('--optimization-config',
                          default="{'optimizer': 'Adam', 'lr': 5e-4}", type=str,
                          help='optimizer config')
    # fixed typo: 'enabled' -> 'enables'
    training.add_argument('--grad-clip', default=5.0, type=float,
                          help='enables gradient clipping and sets maximum '
                               'gradient norm value')
    training.add_argument('--max-length-train', default=50, type=int,
                          help='maximum sequence length for training')
    training.add_argument('--min-length-train', default=0, type=int,
                          help='minimum sequence length for training')

    # --bucketing / --no-bucketing toggle a single boolean flag (default: on)
    bucketing_parser = training.add_mutually_exclusive_group(required=False)
    bucketing_parser.add_argument('--bucketing', dest='bucketing',
                                  action='store_true',
                                  help='enables bucketing (use \'--no-bucketing\' to disable)')
    bucketing_parser.add_argument('--no-bucketing', dest='bucketing',
                                  action='store_false',
                                  help=argparse.SUPPRESS)
    bucketing_parser.set_defaults(bucketing=True)

    # validation
    validation = parser.add_argument_group('validation setup')
    validation.add_argument('--val-batch-size', default=128, type=int,
                            help='batch size for validation')
    validation.add_argument('--max-length-val', default=80, type=int,
                            help='maximum sequence length for validation')
    validation.add_argument('--min-length-val', default=0, type=int,
                            help='minimum sequence length for validation')

    # test
    test = parser.add_argument_group('test setup')
    test.add_argument('--test-batch-size', default=128, type=int,
                      help='batch size for test')
    test.add_argument('--max-length-test', default=150, type=int,
                      help='maximum sequence length for test')
    test.add_argument('--min-length-test', default=0, type=int,
                      help='minimum sequence length for test')
    test.add_argument('--beam-size', default=5, type=int,
                      help='beam size')
    test.add_argument('--len-norm-factor', default=0.6, type=float,
                      help='length normalization factor')
    test.add_argument('--cov-penalty-factor', default=0.1, type=float,
                      help='coverage penalty factor')
    test.add_argument('--len-norm-const', default=5.0, type=float,
                      help='length normalization constant')
    test.add_argument('--target-bleu', default=None, type=float,
                      help='target accuracy')
    test.add_argument('--intra-epoch-eval', default=0, type=int,
                      help='evaluate within epoch')

    # checkpointing
    checkpoint = parser.add_argument_group('checkpointing setup')
    checkpoint.add_argument('--start-epoch', default=0, type=int,
                            help='manually set initial epoch counter')
    checkpoint.add_argument('--resume', default=None, type=str, metavar='PATH',
                            help='resumes training from checkpoint from PATH')
    checkpoint.add_argument('--save-all', action='store_true', default=False,
                            help='saves checkpoint after every epoch')
    checkpoint.add_argument('--save-freq', default=5000, type=int,
                            help='save checkpoint every SAVE_FREQ batches')
    checkpoint.add_argument('--keep-checkpoints', default=0, type=int,
                            help='keep only last KEEP_CHECKPOINTS checkpoints, '
                                 'affects only checkpoints controlled by '
                                 '--save-freq option')

    # distributed support
    distributed = parser.add_argument_group('distributed setup')
    distributed.add_argument('--rank', default=0, type=int,
                             help='rank of the process, do not set! Done by multiproc module')
    distributed.add_argument('--world-size', default=1, type=int,
                             help='number of processes, do not set! Done by multiproc module')
    distributed.add_argument('--dist-url', default='tcp://localhost:23456', type=str,
                             help='url used to set up distributed training')

    return parser.parse_args()
||
|
|
||
|
|
||
|
def build_criterion(vocab_size, padding_idx, smoothing):
    """
    Build the training loss criterion.

    :param vocab_size: size of the target vocabulary (width of the weight
        vector for CrossEntropyLoss)
    :param padding_idx: index of the padding token; it receives zero loss
        weight so padding never contributes to the loss
    :param smoothing: label smoothing factor; 0.0 selects plain
        CrossEntropyLoss, anything else selects LabelSmoothing

    :returns: loss criterion module (summed, not averaged, over tokens)
    """
    if smoothing == 0.:
        logging.info('Building CrossEntropyLoss')
        loss_weight = torch.ones(vocab_size)
        # zero out the padding class so pad tokens are ignored by the loss
        loss_weight[padding_idx] = 0
        # reduction='sum' replaces the deprecated size_average=False;
        # both sum per-token losses instead of averaging them
        criterion = nn.CrossEntropyLoss(weight=loss_weight, reduction='sum')
    else:
        logging.info(f'Building LabelSmoothingLoss (smoothing: {smoothing})')
        criterion = LabelSmoothing(padding_idx, smoothing)

    return criterion
||
|
|
||
|
|
||
|
def main():
    """
    Launches data-parallel multi-gpu training.

    Pipeline: parse args -> (optional) init distributed backend -> create
    results dir + logging -> build tokenizer, datasets and GNMT model ->
    build criterion/optimizer config and data loaders -> build Translator
    and Seq2SeqTrainer -> (optional) resume from checkpoint -> run the
    epoch loop with per-epoch validation, test BLEU evaluation and
    checkpointing.
    """
    args = parse_args()

    # honor runtime switches before anything touches CUDA
    if not args.cudnn:
        torch.backends.cudnn.enabled = False
    if args.seed is not None:
        # offset by rank so each distributed worker gets a distinct stream
        torch.manual_seed(args.seed + args.rank)

    # initialize distributed backend
    distributed = args.world_size > 1
    if distributed:
        # nccl for GPU training, gloo as the CPU fallback
        backend = 'nccl' if args.cuda else 'gloo'
        dist.init_process_group(backend=backend, rank=args.rank,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create directory for results
    save_path = os.path.join(args.results_dir, args.save)
    # stored back on args so it travels inside checkpoint save_info below
    args.save_path = save_path
    os.makedirs(save_path, exist_ok=True)

    # setup logging (one log file per rank/GPU)
    log_filename = f'log_gpu_{args.rank}.log'
    setup_logging(os.path.join(save_path, log_filename))

    logging.info(f'Saving results to: {save_path}')
    logging.info(f'Run arguments: {args}')

    if args.cuda:
        # assumes a one-process-per-GPU layout where rank == device id
        torch.cuda.set_device(args.rank)

    # build tokenizer from the vocabulary shipped with the dataset
    tokenizer = Tokenizer(os.path.join(args.dataset_dir, config.VOCAB_FNAME))

    # build datasets; training data is left unsorted (bucketing handles
    # batching), validation is sorted, test has no target side
    train_data = ParallelDataset(
        src_fname=os.path.join(args.dataset_dir, config.SRC_TRAIN_FNAME),
        tgt_fname=os.path.join(args.dataset_dir, config.TGT_TRAIN_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=False,
        max_size=args.max_size)

    val_data = ParallelDataset(
        src_fname=os.path.join(args.dataset_dir, config.SRC_VAL_FNAME),
        tgt_fname=os.path.join(args.dataset_dir, config.TGT_VAL_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_val,
        max_len=args.max_length_val,
        sort=True)

    test_data = TextDataset(
        src_fname=os.path.join(args.dataset_dir, config.SRC_TEST_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_test,
        max_len=args.max_length_test,
        sort=False)

    vocab_size = tokenizer.vocab_size

    # build GNMT model; --model-config is a python dict literal, parsed with
    # literal_eval (safe: rejects arbitrary code)
    model_config = dict(vocab_size=vocab_size, math=args.math,
                        **literal_eval(args.model_config))
    model = GNMT(**model_config)
    logging.info(model)

    # layout (batch_first or not) is dictated by the model implementation
    batch_first = model.batch_first

    # define loss function (criterion) and optimizer config
    criterion = build_criterion(vocab_size, config.PAD, args.smoothing)
    opt_config = literal_eval(args.optimization_config)
    logging.info(f'Training optimizer: {opt_config}')

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info(f'Number of parameters: {num_parameters}')

    # get data loaders; only training shuffles/buckets and drops the last
    # incomplete batch
    train_loader = train_data.get_loader(batch_size=args.batch_size,
                                         batch_first=batch_first,
                                         shuffle=True,
                                         bucketing=args.bucketing,
                                         num_workers=args.workers,
                                         drop_last=True)

    val_loader = val_data.get_loader(batch_size=args.val_batch_size,
                                     batch_first=batch_first,
                                     shuffle=False,
                                     num_workers=args.workers,
                                     drop_last=False)

    test_loader = test_data.get_loader(batch_size=args.test_batch_size,
                                       batch_first=batch_first,
                                       shuffle=False,
                                       num_workers=args.workers,
                                       drop_last=False)

    # beam-search translator used for BLEU evaluation on the test set
    translator = Translator(model=model,
                            tokenizer=tokenizer,
                            loader=test_loader,
                            beam_size=args.beam_size,
                            max_seq_len=args.max_length_test,
                            len_norm_factor=args.len_norm_factor,
                            len_norm_const=args.len_norm_const,
                            cov_penalty_factor=args.cov_penalty_factor,
                            cuda=args.cuda,
                            print_freq=args.print_freq,
                            dataset_dir=args.dataset_dir,
                            target_bleu=args.target_bleu,
                            save_path=args.save_path)

    # create trainer; save_info embeds the full config and tokenizer into
    # every checkpoint so runs are self-describing
    trainer_options = dict(
        criterion=criterion,
        grad_clip=args.grad_clip,
        save_path=save_path,
        save_freq=args.save_freq,
        save_info={'config': args, 'tokenizer': tokenizer},
        opt_config=opt_config,
        batch_first=batch_first,
        keep_checkpoints=args.keep_checkpoints,
        math=args.math,
        print_freq=args.print_freq,
        cuda=args.cuda,
        distributed=distributed,
        intra_epoch_eval=args.intra_epoch_eval,
        translator=translator)

    trainer_options['model'] = model
    trainer = trainers.Seq2SeqTrainer(**trainer_options)

    # optionally resume from a checkpoint; a directory resolves to its
    # 'model_best.pth' file
    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth')
        if os.path.isfile(checkpoint_file):
            trainer.load(checkpoint_file)
        else:
            # best-effort: log and continue training from scratch
            logging.error(f'No checkpoint found at {args.resume}')

    # training loop
    best_loss = float('inf')
    for epoch in range(args.start_epoch, args.epochs):
        logging.info(f'Starting epoch {epoch}')

        if distributed:
            # reshuffle shards per epoch for the distributed sampler
            train_loader.sampler.set_epoch(epoch)

        trainer.epoch = epoch
        train_loss, train_perf = trainer.optimize(train_loader)

        # evaluate on validation set (rank 0 only; best_loss is therefore
        # tracked only on rank 0)
        if args.rank == 0 and not args.disable_eval:
            logging.info(f'Running validation on dev set')
            val_loss, val_perf = trainer.evaluate(val_loader)

            # remember best prec@1 and save checkpoint
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
            trainer.save(save_all=args.save_all, is_best=is_best)

        # test-set BLEU; translator.run may signal early stop when
        # --target-bleu is reached (presumably — confirm in Translator)
        break_training = False
        if not args.disable_eval:
            test_bleu, break_training = translator.run(calc_bleu=True,
                                                       epoch=epoch)

        # per-epoch summary; val_loss/test_bleu only exist on the eval path,
        # which matches this guard
        if args.rank == 0 and not args.disable_eval:
            logging.info(f'Summary: Epoch: {epoch}\t'
                         f'Training Loss: {train_loss:.4f}\t'
                         f'Validation Loss: {val_loss:.4f}\t'
                         f'Test BLEU: {test_bleu:.2f}')
            logging.info(f'Performance: Epoch: {epoch}\t'
                         f'Training: {train_perf:.0f} Tok/s\t'
                         f'Validation: {val_perf:.0f} Tok/s')
        else:
            logging.info(f'Summary: Epoch: {epoch}\t'
                         f'Training Loss {train_loss:.4f}')
            logging.info(f'Performance: Epoch: {epoch}\t'
                         f'Training: {train_perf:.0f} Tok/s')

        logging.info(f'Finished epoch {epoch}')
        if break_training:
            break
||
|
|
||
|
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
|