2020-07-04 02:28:25 +02:00
|
|
|
#!/usr/bin/env python3 -u
|
2019-01-23 17:03:45 +01:00
|
|
|
# Copyright (c) 2017-present, Facebook, Inc.
|
|
|
|
# All rights reserved.
|
|
|
|
#
|
|
|
|
# This source code is licensed under the license found in the LICENSE file in
|
|
|
|
# the root directory of this source tree. An additional grant of patent rights
|
|
|
|
# can be found in the PATENTS file in the same directory.
|
|
|
|
#
|
|
|
|
#-------------------------------------------------------------------------
|
|
|
|
#
|
|
|
|
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import collections
|
|
|
|
import os
|
|
|
|
import math
|
|
|
|
import time
|
|
|
|
import ctypes
|
|
|
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
2021-06-30 13:50:08 +02:00
|
|
|
import torch
|
|
|
|
import sacrebleu
|
|
|
|
import dllogger as DLLogger
|
|
|
|
|
2020-07-04 02:28:25 +02:00
|
|
|
from fairseq import data, distributed_utils, options, utils, tokenizer
|
|
|
|
from fairseq.ddp_trainer import DDPTrainer
|
2021-06-30 13:50:08 +02:00
|
|
|
from fairseq.meters import StopwatchMeter
|
2019-01-23 17:03:45 +01:00
|
|
|
from fairseq.sequence_generator import SequenceGenerator
|
2021-06-30 13:50:08 +02:00
|
|
|
from fairseq.data import data_utils, load_dataset_splits
|
2020-07-04 02:28:25 +02:00
|
|
|
from fairseq.models import build_model
|
2021-06-30 13:50:08 +02:00
|
|
|
from fairseq.log_helper import setup_logger, reset_perf_meters
|
2019-01-23 17:03:45 +01:00
|
|
|
|
|
|
|
def main(args):
    """Run the full training job: device/distributed setup, data and model
    construction, then epoch loop with validation, online BLEU evaluation,
    LR scheduling and checkpointing.

    Args:
        args: parsed fairseq training namespace (see options.get_training_parser).
    """
    print(args)
    setup_logger(args)

    # GPU-only trainer: fail fast on CPU-only hosts.
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    if args.distributed_world_size > 1:
        # Process group must already be initialized (done in __main__).
        assert torch.distributed.is_initialized()
        # Dummy broadcast + sync acts as a barrier so all ranks start together.
        torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
        torch.cuda.synchronize()
    # 0x05 is presumably cudaLimitMaxL2FetchGranularity; set to 128 bytes and
    # read back into pValue — TODO(review) confirm against the CUDA runtime enum.
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    ctypes.CDLL('libcudart.so').cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    ctypes.CDLL('libcudart.so').cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    torch.manual_seed(args.seed)

    # Dictionaries are stashed in every checkpoint so inference can reload them.
    src_dict, tgt_dict = data_utils.load_dictionaries(args)
    add_extra_items_to_checkpoint({'src_dict': src_dict, 'tgt_dict': tgt_dict})
    datasets = load_dataset_splits(args, ['train', 'valid', 'test'], src_dict, tgt_dict)

    model = build_model(args)
    print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))

    # Build trainer
    if torch.cuda.get_device_capability(0)[0] >= 7 and not args.amp:
        print('| NOTICE: your device may support faster training with --amp')
    trainer = DDPTrainer(args, model)
    print('| model {}, criterion {}'.format(args.arch, trainer.criterion.__class__.__name__))
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # NOTE(review): the TRAIN iterator is capped by max_sentences_valid, not
    # max_sentences (which the banner above prints) — confirm this is intentional.
    epoch_itr = data.EpochBatchIterator(
        dataset=datasets[args.train_subset],
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=args.max_positions,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
    # Load the latest checkpoint if one is available
    load_checkpoint(args, trainer, epoch_itr)

    # Send a dummy batch to warm the caching allocator
    dummy_batch = data_utils.get_dummy_batch(args.max_tokens, src_dict, tgt_dict)
    trainer.dummy_train_step(dummy_batch)

    # Sanity check: score the test set once before training starts.
    if args.do_sanity_check:
        print('Performing sanity check...')
        sanity_score = score(args, trainer, datasets['test'], src_dict, tgt_dict, 'test.raw.de')
        DLLogger.log(step='SANITY_CHECK', data={'sanity_check_score': sanity_score}, verbosity=1)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    best_bleu = -1.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    run_summary = {'loss': float('inf'),
                   'val_loss': float('inf'),
                   'speed': 0,
                   'accuracy': 0}

    while lr >= args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update and current_bleu < tgt_bleu:
        DLLogger.log(step=trainer.get_num_updates()+1, data={'epoch': epoch_itr.epoch}, verbosity=0)
        # train for one epoch
        train(args, trainer, epoch_itr)
        DLLogger.log(step=trainer.get_num_updates(), data={'walltime': train_meter.sum}, verbosity=1)
        DLLogger.log(step=trainer.get_num_updates(),
                     data={'avg_epoch_loss': trainer.avg_loss_meter.avg}, verbosity=1)

        # NOTE(review): if validate_interval > 1 this branch may be skipped on the
        # first epoch, leaving valid_bleu unbound for the run_summary line below
        # and valid_losses[0] as None — confirm the configured intervals avoid this.
        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, datasets, valid_subsets)
            valid_bleu = score(args, trainer, datasets[valid_subsets[0]], src_dict, tgt_dict, 'valid.raw.de')
            DLLogger.log(step=trainer.get_num_updates(),
                         data={'val_loss': valid_losses[0], 'val_bleu': valid_bleu}, verbosity=1)

        # Eval BLEU score
        if args.online_eval or (tgt_bleu is not math.inf):
            current_bleu = score(args, trainer, datasets[args.gen_subset], src_dict, tgt_dict, 'test.raw.de')
            DLLogger.log(step=trainer.get_num_updates(), data={'test_bleu': current_bleu}, verbosity=1)
            best_bleu = max(best_bleu, current_bleu)

        run_summary['val_loss'] = min(run_summary['val_loss'], valid_losses[0])
        run_summary['accuracy'] = best_bleu if best_bleu >= 0 else valid_bleu
        run_summary['loss'] = valid_losses[0]
        run_summary['speed'] = trainer.throughput_meter.u_avg

        # Only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    run_summary['walltime'] = train_meter.sum
    DLLogger.log(step=(), data=run_summary, verbosity=0)
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
|
|
|
|
|
2021-06-30 13:50:08 +02:00
|
|
|
def train(args, trainer, epoch_itr):
    """Train the model for one epoch.

    Buffers gradients across ``update_freq`` consecutive batches before each
    optimizer step, logs throughput via DLLogger, and stops mid-epoch once
    ``args.max_update`` global updates are reached.
    """

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()

    # update parameters every N batches; --update-freq is a per-epoch schedule
    # (1-based epoch indexes it), and the last entry is reused once exhausted
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    begin = time.time()

    # reset meters
    DLLogger.flush()
    trainer.get_throughput_meter().reset()

    for i, sample in enumerate(itr):
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq (gradients accumulate,
            # no optimizer step yet)
            trainer.train_step(sample, update_params=False, last_step=(i == len(itr)-1))
            continue
        else:
            # accumulation boundary or final batch of the epoch: apply the step
            trainer.train_step(sample, update_params=True, last_step=(i == len(itr)-1))

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_throughput_meter().reset()
            reset_perf_meters()

        if (i+1) % args.log_interval == 0:
            DLLogger.flush()

        # stop mid-epoch once the global update budget is spent
        if trainer.get_num_updates() >= max_update:
            break

    print('Epoch time:', time.time() - begin)

    # Print epoch stats and reset training meters
    DLLogger.log(step=trainer.get_num_updates(),
                 data={'speed': trainer.get_throughput_meter().avg}, verbosity=0)
    DLLogger.flush()
|
|
|
|
|
|
|
|
def validate(args, trainer, datasets, subsets):
    """Evaluate the model on the validation set(s) and return the losses.

    Returns one mean loss per subset, in the order the subsets were given.
    """

    valid_losses = []
    for subset in subsets:
        # Announce the subset only when several are being evaluated.
        if len(subsets) > 1:
            print('Validating on \'{}\' subset'.format(subset))

        # Build a fresh, unshuffled iterator over this subset.
        batch_iter = data.EpochBatchIterator(
            dataset=datasets[subset],
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=args.max_positions,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)

        # Flush pending log records before accumulating validation losses.
        DLLogger.flush()

        # One forward/loss evaluation per batch; average over the subset.
        per_batch_losses = [trainer.valid_step(batch) for batch in batch_iter]
        mean_loss = sum(per_batch_losses) / len(per_batch_losses)

        DLLogger.flush()

        valid_losses.append(mean_loss)
        print(f'Validation loss on subset {subset}: {mean_loss}')

    return valid_losses
|
|
|
|
|
2020-07-04 02:28:25 +02:00
|
|
|
def score(args, trainer, dataset, src_dict, tgt_dict, ref_file):
    """Translate ``dataset`` with the current model and return the corpus BLEU.

    Generates with beam search (SequenceGenerator), gathers predictions across
    distributed workers, detokenizes, and scores against the raw reference file
    ``ref_file`` under ``args.data`` using sacrebleu.
    """

    begin = time.time()

    src_dict = deepcopy(src_dict)  # This is necessary, generation of translations
    tgt_dict = deepcopy(tgt_dict)  # alters target dictionary messing up with the rest of training

    model = trainer.get_model()

    # Initialize data iterator; batch size is split evenly across workers,
    # clamped to [8, 128] sentences per batch
    itr = data.EpochBatchIterator(
        dataset=dataset,
        max_tokens=None,
        max_sentences=max(8, min(math.ceil(1024/args.distributed_world_size), 128)),
        max_positions=args.max_positions,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model],
        tgt_dict.get_metadata(),
        maxlen=args.max_target_positions - 1,  # do not include EOS token
        beam_size=args.beam,
        stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen, unk_penalty=args.unkpen,
        sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
        use_amp=args.amp,
    )
    # Generate and compute BLEU
    predictions = []
    translations = translator.generate_batched_itr(
        itr, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
        cuda=True, timer=gen_timer, prefix_size=args.prefix_size,
    )

    for sample_id, src_tokens, _, hypos in translations:
        # Process input and ground truth
        src_str = src_dict.string(src_tokens, args.remove_bpe)

        # Process top predictions
        for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
            _, hypo_str, _ = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                align_dict=None,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe
            )

            # Score only the top hypothesis; stored as "sample_id<TAB>text" so
            # the strings can be re-sorted into corpus order after gathering
            if i == 0:
                hypo_str = tokenizer.Tokenizer.detokenize(hypo_str, 'de')
                predictions.append('{}\t{}'.format(sample_id, hypo_str))

    # Each worker only translated its shard; collect everything on every rank.
    if args.distributed_world_size > 1:
        predictions = _all_gather_predictions(predictions)

    with open(os.path.join(args.data, ref_file), 'r') as reference:
        refs = [reference.readlines()]
    # reducing indexed predictions as strings is more memory efficient than reducing tuples
    predictions = [tuple(item.split('\t')) for item in predictions]
    predictions = [(int(item[0]), item[1]) for item in predictions]
    predictions.sort(key=lambda tup: tup[0])
    # Restore corpus order and guarantee a trailing newline on every line.
    predictions = [hypo[1] + ('\n' if hypo[1][-1] != '\n' else '') for hypo in predictions]
    sacrebleu_score = sacrebleu.corpus_bleu(predictions, refs, lowercase=not args.test_cased_bleu).score

    if args.save_predictions:
        os.makedirs(os.path.join(args.save_dir, 'predictions'), exist_ok=True)
        fname = ref_file + '.pred.update_{}'.format(trainer.get_num_updates())
        save_path = os.path.join(args.save_dir, 'predictions', fname)
        with open(save_path, 'w') as f:
            f.write(''.join(predictions))

    DLLogger.log(step=trainer.get_num_updates(),
                 data={'inference tokens/s': float(args.distributed_world_size) / gen_timer.avg},
                 verbosity=0)
    DLLogger.flush()
    if gen_timer.sum != 0:
        print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
            len(predictions),
            gen_timer.n,
            gen_timer.sum,
            len(predictions) / gen_timer.sum,
            float(args.distributed_world_size)/gen_timer.avg
        ))

    print('| Eval completed in: {:.2f}s | {}CASED BLEU {:.2f}'.format(
        time.time()-begin,
        '' if args.test_cased_bleu else 'UN',
        sacrebleu_score
    ))

    return sacrebleu_score
|
2019-01-23 17:03:45 +01:00
|
|
|
|
2019-07-16 21:13:08 +02:00
|
|
|
def _all_gather_predictions(predictions):
    """Gather prediction strings from every distributed worker onto all ranks.

    ``all_gather_list`` pickles its payload into a fixed-size buffer, so the
    local list is sent in chunks sized to stay under ``max_size`` bytes; the
    loop repeats until every rank reports its last chunk sent.

    Args:
        predictions: list of '\t'-joined "id\ttext" strings local to this rank.

    Returns:
        Flat list with the predictions of all ranks (order: rank-major per round).
    """
    ready = False
    all_ready = False
    reduced_predictions = []
    max_size = 65000
    while not all_ready:
        lst_len = len(predictions)
        size = 2000  # some extra space for python stuff
        n = 0
        # Take as many strings as fit in the gather buffer this round.
        # NOTE(review): a single string larger than max_size would make n stay 0
        # forever and loop indefinitely — assumed not to occur for MT outputs.
        while n < lst_len:
            str_len = len(predictions[n].encode('utf8')) + 8  # per string pickle overhead
            if size + str_len >= max_size:
                break
            size += str_len
            n += 1
        chunk = predictions[:n]
        predictions = predictions[n:]
        if not predictions:
            ready = True  # this rank has sent everything after this round
        chunk = (ready, chunk)
        torch.cuda.synchronize()
        # Fix: use the max_size variable instead of a second hard-coded 65000,
        # so buffer sizing and chunking can never drift apart.
        gathered = distributed_utils.all_gather_list(chunk, max_size=max_size)
        torch.cuda.synchronize()
        reduced_predictions += [t[1] for t in gathered]
        # Stop only when every rank flagged its final chunk.
        all_ready = all([t[0] for t in gathered])

    # Flatten the per-round, per-rank chunk lists into one list of strings.
    reduced_predictions = [item for sublist in reduced_predictions for item in sublist]

    return reduced_predictions
|
|
|
|
|
2019-01-23 17:03:45 +01:00
|
|
|
|
|
|
|
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Write checkpoint files for the current epoch (master rank only).

    Saves up to three files under <save_dir>/checkpoints: a per-epoch
    checkpoint (unless --no-epoch-checkpoints), checkpoint_best.pt when
    ``val_loss`` improves on the best seen so far, and checkpoint_last.pt.

    Args:
        args: training namespace (save_interval, no_save, save_dir, ...).
        trainer: DDPTrainer performing the actual serialization.
        epoch_itr: epoch iterator whose state is stored for exact resume.
        val_loss: latest validation loss, or None if validation didn't run.
    """
    # Respect --save-interval (epochs between checkpoint writes).
    if epoch_itr.epoch % args.save_interval != 0:
        return
    # Only the master rank writes; --no-save disables checkpointing entirely.
    if args.no_save or not distributed_utils.is_master(args):
        return
    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = end_of_epoch and not args.no_epoch_checkpoints
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    # Track the best validation loss seen so far as a function attribute.
    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        # Fix: fall back through getattr — previously this raised AttributeError
        # when val_loss was None and no best had ever been recorded.
        'best': getattr(save_checkpoint, 'best', val_loss),
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    # Fix: extra_items only exists once add_extra_items_to_checkpoint() has been
    # called; default to an empty dict instead of raising AttributeError.
    extra_state.update(getattr(save_checkpoint, 'extra_items', {}))

    checkpoints = [os.path.join(args.save_dir, 'checkpoints', fn)
                   for fn, cond in checkpoint_conds.items() if cond]
    if checkpoints:
        for cp in checkpoints:
            trainer.save_checkpoint(cp, extra_state)
|
|
|
|
|
|
|
|
|
2021-06-30 13:50:08 +02:00
|
|
|
def add_extra_items_to_checkpoint(items):
    """Register extra key/value pairs to be embedded in every saved checkpoint.

    The registry lives as an attribute on save_checkpoint and is merged into
    the checkpoint's extra_state when a checkpoint is written.
    """
    registry = getattr(save_checkpoint, 'extra_items', None)
    if registry is None:
        registry = {}
        save_checkpoint.extra_items = registry
    registry.update(items)
|
2020-07-04 02:28:25 +02:00
|
|
|
|
2019-01-23 17:03:45 +01:00
|
|
|
def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay dataloader to match.

    Looks for ``args.restore_file`` under <save_dir>/checkpoints; if found,
    restores model/optimizer state via the trainer, fast-forwards the epoch
    iterator, re-applies the LR schedule, and recovers the best val loss.
    """
    os.makedirs(os.path.join(args.save_dir, 'checkpoints'), exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, 'checkpoints', args.restore_file)
    if os.path.isfile(checkpoint_path):
        extra_state = trainer.load_checkpoint(checkpoint_path)
        if extra_state is not None:
            # replay train iterator to match checkpoint
            epoch_itr.load_state_dict(extra_state['train_iterator'])

            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))

            # Re-sync the LR scheduler with the restored epoch/update counters.
            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            # Restore the best-val-loss marker used by save_checkpoint.
            if 'best' in extra_state:
                save_checkpoint.best = extra_state['best']
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: parse CLI/architecture arguments, initialize the
# distributed process group when running on multiple GPUs, then train.
if __name__ == '__main__':
    parser = options.get_training_parser()
    ARGS = options.parse_args_and_arch(parser)

    # main() asserts torch.distributed.is_initialized() for world_size > 1,
    # so the process group must be set up here first.
    if ARGS.distributed_world_size > 1:
        distributed_utils.distributed_init(ARGS)

    main(ARGS)
|