#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import itertools
import os
import math
import torch
import time
import ctypes
import sys
from copy import deepcopy
from functools import reduce
from fairseq import data, distributed_utils, options, progress_bar, tasks, utils, bleu, tokenizer
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.data import dictionary
import sacrebleu


def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)

    if args.distributed_world_size > 1:
        assert torch.distributed.is_initialized()
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()

    if args.max_tokens is None:
        args.max_tokens = 6000

    print(args)
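
    # Cap the L2 cache fetch granularity at 128 bytes and read the limit back
    # as a sanity check; 0x05 is cudaLimitMaxL2FetchGranularity in the CUDA
    # runtime's cudaLimit enum, hence the raw cudart()/ctypes calls below.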
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    result = torch.cuda.cudart().cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    result = torch.cuda.cudart().cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))

    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16 and not args.amp:
        trainer = FP16Trainer(args, task, model, criterion)
    elif args.fp16 and args.amp:
        raise ValueError('Cannot use AMP and fp16 simultaneously')
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7 and not args.amp:
            print('| NOTICE: your device may support faster training with --fp16')
        trainer = Trainer(args, task, model, criterion)

    if (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '

    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))
    max_positions = trainer.get_model().max_positions()
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )

    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    best_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    while (
        lr >= args.min_lr
        and epoch_itr.epoch < max_epoch
        and trainer.get_num_updates() < max_update
        and current_bleu < tgt_bleu
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)
        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # Eval BLEU score
        if args.online_eval or (tgt_bleu is not math.inf):
            current_bleu, current_sc_bleu = score(args, trainer, task, epoch_itr, args.gen_subset)
            if current_bleu > best_bleu:
                best_bleu = current_bleu
                save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # Only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]
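
    # With --update-freq N, gradients from N consecutive batches are
    # accumulated (train_step with update_params=False in the loop below)
    # before a single optimizer step, emulating an N-times larger batch.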
    if args.enable_parallel_backward_allred_opt and update_freq > 1:
        raise RuntimeError('--enable-parallel-backward-allred-opt is incompatible with --update-freq > 1')

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    begin = time.time()
    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample, update_params=False, last_step=(i == len(itr) - 1))
            continue
        else:
            log_output = trainer.train_step(sample, update_params=True, last_step=(i == len(itr) - 1))

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        if args.profile is not None and i == args.profile:
            sys.exit()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    print('Epoch time:', time.time() - begin)

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()


def get_training_stats(trainer):
    stats = collections.OrderedDict()
    stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg)
    if trainer.get_meter('train_nll_loss').count > 0:
        nll_loss = trainer.get_meter('train_nll_loss').avg
        stats['nll_loss'] = '{:.3f}'.format(nll_loss)
    else:
        nll_loss = trainer.get_meter('train_loss').avg
    stats['ppl'] = get_perplexity(nll_loss)
    stats['wps'] = round(trainer.get_meter('wps').avg)
    stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg)
    stats['wpb'] = round(trainer.get_meter('wpb').avg)
    stats['bsz'] = round(trainer.get_meter('bsz').avg)
    stats['num_updates'] = trainer.get_num_updates()
    stats['lr'] = trainer.get_lr()
    stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
    stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
    stats['oom'] = trainer.get_meter('oom').avg
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = data.EpochBatchIterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=trainer.get_model().max_positions(),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()

        extra_meters = collections.defaultdict(lambda: AverageMeter())
        for sample in progress:
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats)

        valid_losses.append(stats['valid_loss'])
    return valid_losses


def score(args, trainer, task, epoch_itr, subset):
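    """Translate ``subset`` with beam search and score the hypotheses.

    Returns a tuple of (tokenized corpus BLEU, detokenized sacreBLEU score).
    """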
    begin = time.time()

    if subset not in task.datasets:
        task.load_dataset(subset)

    # Deep copies are necessary here: generating translations mutates the
    # dictionaries, which would otherwise corrupt the rest of training.
    src_dict = deepcopy(task.source_dictionary)
    tgt_dict = deepcopy(task.target_dictionary)

    model = trainer.get_model()

    # Initialize data iterator
    itr = data.EpochBatchIterator(
        dataset=task.dataset(subset),
        max_tokens=None,
        max_sentences=max(8, min(math.ceil(1024 / args.distributed_world_size), 128)),
        max_positions=model.max_positions(),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model], tgt_dict, beam_size=args.beam,
        stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen, unk_penalty=args.unkpen,
        sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
    )
    # Generate and compute BLEU
    bleu_dict = dictionary.Dictionary()
    scorer = bleu.Scorer(bleu_dict.pad(), bleu_dict.eos(), bleu_dict.unk())
    num_sentences = 0
    has_target = True
    predictions = []

    with progress_bar.build_progress_bar(args, itr) as progress:
        translations = translator.generate_batched_itr(
            progress, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
            cuda=True, timer=gen_timer, prefix_size=args.prefix_size,
        )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None
            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                # Score only the top hypothesis
                if has_target and i == 0:
                    if args.sentencepiece:
                        # Undo SentencePiece segmentation: drop intra-word spaces,
                        # then turn the '\u2581' word-boundary marker back into spaces.
                        hypo_str = hypo_str.replace(' ', '').replace('\u2581', ' ')
                        target_str = target_str.replace(' ', '').replace('\u2581', ' ')
                    sys_tok = tokenizer.Tokenizer.tokenize(
                        (hypo_str.lower() if args.ignore_case else hypo_str), bleu_dict)
                    ref_tok = tokenizer.Tokenizer.tokenize(
                        (target_str.lower() if args.ignore_case else target_str), bleu_dict)
                    scorer.add(ref_tok, sys_tok)
                    if not args.sentencepiece:
                        hypo_str = tokenizer.Tokenizer.detokenize(hypo_str, 'de')
                    predictions.append('{}\t{}'.format(sample_id, hypo_str))

            wps_meter.update(src_tokens.size(0))
            progress.log({'wps': round(wps_meter.avg)})
            num_sentences += 1
    if args.distributed_world_size > 1:
        _all_gather_bleu_scorer(scorer)
        predictions = _all_gather_predictions(predictions)

    with open(os.path.join(args.data, 'sacrebleu_reference.de'), 'r') as reference:
        refs = [reference.readlines()]

    # Reducing indexed predictions as strings is more memory-efficient than
    # reducing tuples.
    predictions = [tuple(item.split('\t')) for item in predictions]
    predictions = [(int(item[0]), item[1]) for item in predictions]
    predictions.sort(key=lambda tup: tup[0])
    predictions = [hypo[1] + ('\n' if hypo[1][-1] != '\n' else '') for hypo in predictions]
    sacrebleu_score = sacrebleu.corpus_bleu(predictions, refs, lowercase=args.ignore_case)

    print(f'| Detokenized {sacrebleu_score}')
    if gen_timer.sum != 0:
        print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
            num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(subset, args.beam, scorer.result_string()))
    print('| Eval completed in: {:.2f}s'.format(time.time() - begin))

    return scorer.score(order=4), sacrebleu_score.score


def _all_gather_predictions(predictions):
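    """Gather prediction strings from all workers onto every rank.

    distributed_utils.all_gather_list pickles into a fixed-size buffer, so
    each rank sends its list in chunks that fit under max_size and keeps
    looping until every rank reports that it has nothing left to send.
    """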
    ready = False
    all_ready = False
    reduced_predictions = []
    max_size = 65000
    while not all_ready:
        lst_len = len(predictions)
        size = 2000  # some extra space for python stuff
        n = 0
        while n < lst_len:
            str_len = len(predictions[n].encode('utf8')) + 8  # per-string pickle overhead
            if size + str_len >= max_size:
                break
            size += str_len
            n += 1
        chunk = predictions[:n]
        predictions = predictions[n:]
        if not predictions:
            ready = True
        chunk = (ready, chunk)
        torch.cuda.synchronize()
        gathered = distributed_utils.all_gather_list(chunk, max_size=65000)
        torch.cuda.synchronize()
        reduced_predictions += [t[1] for t in gathered]
        all_ready = all([t[0] for t in gathered])

    reduced_predictions = [item for sublist in reduced_predictions for item in sublist]

    return reduced_predictions


def _all_gather_bleu_scorer(scorer):
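    """Sum the BLEU sufficient statistics (n-gram matches, counts and
    lengths) across all workers so corpus BLEU reflects the full dataset."""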
    stats = distributed_utils.all_gather_list(scorer.stat)
    bleu_stat = bleu.BleuStat()
    bleu_stat.reflen = reduce(lambda x, y: x + y, [s.reflen for s in stats])
    bleu_stat.predlen = reduce(lambda x, y: x + y, [s.predlen for s in stats])
    bleu_stat.match1 = reduce(lambda x, y: x + y, [s.match1 for s in stats])
    bleu_stat.count1 = reduce(lambda x, y: x + y, [s.count1 for s in stats])
    bleu_stat.match2 = reduce(lambda x, y: x + y, [s.match2 for s in stats])
    bleu_stat.count2 = reduce(lambda x, y: x + y, [s.count2 for s in stats])
    bleu_stat.match3 = reduce(lambda x, y: x + y, [s.match3 for s in stats])
    bleu_stat.count3 = reduce(lambda x, y: x + y, [s.count3 for s in stats])
    bleu_stat.match4 = reduce(lambda x, y: x + y, [s.match4 for s in stats])
    bleu_stat.count4 = reduce(lambda x, y: x + y, [s.count4 for s in stats])
    scorer.stat = bleu_stat


def get_valid_stats(trainer):
    stats = collections.OrderedDict()
    stats['valid_loss'] = trainer.get_meter('valid_loss').avg
    if trainer.get_meter('valid_nll_loss').count > 0:
        nll_loss = trainer.get_meter('valid_nll_loss').avg
        stats['valid_nll_loss'] = nll_loss
    else:
        nll_loss = trainer.get_meter('valid_loss').avg
    stats['valid_ppl'] = get_perplexity(nll_loss)
    stats['num_updates'] = trainer.get_num_updates()
    if hasattr(save_checkpoint, 'best'):
        stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
    return stats


def get_perplexity(loss):
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        return float('inf')


def save_checkpoint(args, trainer, epoch_itr, val_loss):
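    """Save epoch, update-interval, best and last checkpoints as configured,
    on the master rank only, and prune old interval checkpoints beyond
    --keep-interval-updates."""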
    if args.no_save or not distributed_utils.is_master(args):
        return
    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'best': save_checkpoint.best,
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }

    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
    if len(checkpoints) > 0:
        for cp in checkpoints:
            trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in checkpoints[args.keep_interval_updates:]:
            os.remove(old_chk)


def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match."""
    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
        extra_state = trainer.load_checkpoint(checkpoint_path)
        if extra_state is not None:
            # replay train iterator to match checkpoint
            epoch_itr.load_state_dict(extra_state['train_iterator'])
            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            if 'best' in extra_state:
                save_checkpoint.best = extra_state['best']


def load_dataset_splits(task, splits):
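    """Load the requested splits. 'train' is loaded as one combined dataset;
    any other split is loaded as e.g. 'valid', 'valid1', 'valid2', ... until
    a numbered split is missing, so multiple validation sets can coexist."""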
    for split in splits:
        if split == 'train':
            task.load_dataset(split, combine=True)
        else:
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else '')
                try:
                    task.load_dataset(split_k, combine=False)
                except FileNotFoundError as e:
                    if k > 0:
                        break
                    raise e


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

    if args.distributed_port > 0 or args.distributed_init_method is not None:
        from distributed_train import main as distributed_main
        distributed_main(args)
    elif args.distributed_world_size > 1:
        from multiprocessing_train import main as multiprocessing_main
        multiprocessing_main(args)
    else:
        main(args)