#!/usr/bin/env python
import argparse
import logging
import warnings
from ast import literal_eval

import torch

from seq2seq.data.dataset import TextDataset
from seq2seq.inference.inference import Translator
from seq2seq.models.gnmt import GNMT


def parse_args():
    """
    Parse commandline arguments.
    """
    def add_on_off_flag(group, flag, dest, on_help):
        # Register a --X / --no-X boolean pair defaulting to True;
        # the --no-X form is hidden from --help output.
        toggle = group.add_mutually_exclusive_group(required=False)
        toggle.add_argument(flag, dest=dest, action='store_true',
                            help=on_help)
        toggle.add_argument('--no-' + flag.lstrip('-'), dest=dest,
                            action='store_false', help=argparse.SUPPRESS)
        toggle.set_defaults(**{dest: True})

    parser = argparse.ArgumentParser(
        description='GNMT Translate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # data
    data_group = parser.add_argument_group('data setup')
    data_group.add_argument('--dataset-dir', default='data/wmt16_de_en/',
                            help='path to directory with training/validation data')
    data_group.add_argument('-i', '--input', required=True,
                            help='full path to the input file (tokenized)')
    data_group.add_argument('-o', '--output', required=True,
                            help='full path to the output file (tokenized)')
    data_group.add_argument('-r', '--reference', default=None,
                            help='full path to the reference file (for sacrebleu)')
    data_group.add_argument('-m', '--model', required=True,
                            help='full path to the model checkpoint file')

    # inference parameters
    infer_group = parser.add_argument_group('inference setup')
    infer_group.add_argument('--batch-size', default=128, type=int,
                             help='batch size')
    infer_group.add_argument('--beam-size', default=5, type=int,
                             help='beam size')
    infer_group.add_argument('--max-seq-len', default=80, type=int,
                             help='maximum generated sequence length')
    infer_group.add_argument('--len-norm-factor', default=0.6, type=float,
                             help='length normalization factor')
    infer_group.add_argument('--cov-penalty-factor', default=0.1, type=float,
                             help='coverage penalty factor')
    infer_group.add_argument('--len-norm-const', default=5.0, type=float,
                             help='length normalization constant')

    # general setup
    general = parser.add_argument_group('general setup')
    general.add_argument('--math', default='fp16', choices=['fp32', 'fp16'],
                         help='arithmetic type')

    add_on_off_flag(general, '--bleu', 'bleu',
                    "compares with reference and computes BLEU "
                    "(use '--no-bleu' to disable)")

    # data layout toggle: both spellings are visible in --help,
    # so it does not fit the add_on_off_flag pattern
    layout = general.add_mutually_exclusive_group(required=False)
    layout.add_argument('--batch-first', dest='batch_first',
                        action='store_true',
                        help='uses (batch, seq, feature) data format for RNNs')
    layout.add_argument('--seq-first', dest='batch_first',
                        action='store_false',
                        help='uses (seq, batch, feature) data format for RNNs')
    layout.set_defaults(batch_first=True)

    add_on_off_flag(general, '--cuda', 'cuda',
                    "enables cuda (use '--no-cuda' to disable)")
    add_on_off_flag(general, '--cudnn', 'cudnn',
                    "enables cudnn (use '--no-cudnn' to disable)")

    general.add_argument('--print-freq', '-p', default=1, type=int,
                         help='print log every PRINT_FREQ batches')

    args = parser.parse_args()

    # BLEU computation needs a reference translation to compare against
    if args.bleu and args.reference is None:
        parser.error('--bleu requires --reference')

    return args


def checkpoint_from_distributed(state_dict):
    """
    Checks whether checkpoint was generated by DistributedDataParallel. DDP
    wraps model in additional "module.", it needs to be unwrapped for single
    GPU inference.

    :param state_dict: model's state dict
    :returns: True if any key carries the "module." marker, False otherwise
    """
    return any('module.' in name for name in state_dict)


def unwrap_distributed(state_dict):
    """
    Unwraps model from DistributedDataParallel.
    DDP wraps model in additional "module.", it needs to be removed for single
    GPU inference.

    Only the leading "module." prefix is stripped: the previous
    str.replace-based version removed the substring anywhere in the key,
    which would corrupt a parameter path that legitimately contains
    "module." in the middle.

    :param state_dict: model's state dict
    :returns: new dict with the DDP "module." prefix removed from the keys
    """
    prefix = 'module.'
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict


def _setup_logging():
    """Configure logging: full DEBUG trace to log.log, INFO to the console."""
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename='log.log',
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(message)s'))
    logging.getLogger('').addHandler(console)


def _build_model(args):
    """
    Load the checkpoint and build the GNMT model for inference.

    :param args: parsed commandline arguments
    :returns: (model, tokenizer) tuple; model is in eval mode, cast to the
        requested dtype and moved to GPU when args.cuda is set
    """
    # deserialize to CPU (to save GPU memory); map_location='cpu' remaps
    # tensors saved on ANY device, not only 'cuda:0' as the old dict did
    checkpoint = torch.load(args.model, map_location='cpu')

    tokenizer = checkpoint['tokenizer']
    model_config = dict(vocab_size=tokenizer.vocab_size,
                        math=checkpoint['config'].math,
                        **literal_eval(checkpoint['config'].model_config))
    model_config['batch_first'] = args.batch_first
    model = GNMT(**model_config)

    state_dict = checkpoint['state_dict']
    # strip the DistributedDataParallel "module." prefix, if present
    if checkpoint_from_distributed(state_dict):
        state_dict = unwrap_distributed(state_dict)
    model.load_state_dict(state_dict)

    # args.math is restricted to these two choices by argparse, so the
    # lookup is total (the old if-chain left dtype unbound on a bad value)
    dtype = {'fp32': torch.FloatTensor, 'fp16': torch.HalfTensor}[args.math]
    model.type(dtype)
    if args.cuda:
        model = model.cuda()
    model.eval()
    return model, tokenizer


def main():
    """
    Launches translation (inference).
    Inference is executed on a single GPU, implementation supports beam search
    with length normalization and coverage penalty.
    """
    args = parse_args()
    _setup_logging()
    logging.info(args)

    if args.cuda:
        torch.cuda.set_device(0)
    if not args.cuda and torch.cuda.is_available():
        warnings.warn('cuda is available but not enabled')
    if args.math == 'fp16' and not args.cuda:
        raise RuntimeError('fp16 requires cuda')
    if not args.cudnn:
        torch.backends.cudnn.enabled = False

    model, tokenizer = _build_model(args)

    # construct the dataset (unsorted: output order must match input order)
    test_data = TextDataset(src_fname=args.input,
                            tokenizer=tokenizer,
                            sort=False)

    # build the data loader
    test_loader = test_data.get_loader(batch_size=args.batch_size,
                                       batch_first=args.batch_first,
                                       shuffle=False,
                                       num_workers=0,
                                       drop_last=False)

    # build the translator object
    translator = Translator(model=model,
                            tokenizer=tokenizer,
                            loader=test_loader,
                            beam_size=args.beam_size,
                            max_seq_len=args.max_seq_len,
                            len_norm_factor=args.len_norm_factor,
                            len_norm_const=args.len_norm_const,
                            cov_penalty_factor=args.cov_penalty_factor,
                            cuda=args.cuda,
                            print_freq=args.print_freq,
                            dataset_dir=args.dataset_dir)

    # execute the inference
    translator.run(calc_bleu=args.bleu, eval_path=args.output,
                   reference_path=args.reference, summary=True)


if __name__ == '__main__':
    main()