# DeepLearningExamples/PyTorch/Translation/GNMT/seq2seq/inference/beam_search.py

import torch
from seq2seq.data.config import BOS
from seq2seq.data.config import EOS


class SequenceGenerator:
"""
Generator for the autoregressive inference with beam search decoding.
"""
def __init__(self, model, beam_size=5, max_seq_len=100, cuda=False,
len_norm_factor=0.6, len_norm_const=5,
cov_penalty_factor=0.1):
"""
Constructor for the SequenceGenerator.
Beam search decoding supports coverage penalty and length
normalization. For details, refer to Section 7 of the GNMT paper
(https://arxiv.org/pdf/1609.08144.pdf).
        :param model: model which implements the generate method
:param beam_size: decoder beam size
:param max_seq_len: maximum decoder sequence length
:param cuda: whether to use cuda
:param len_norm_factor: length normalization factor
:param len_norm_const: length normalization constant
:param cov_penalty_factor: coverage penalty factor
"""
self.model = model
self.cuda = cuda
self.beam_size = beam_size
self.max_seq_len = max_seq_len
self.len_norm_factor = len_norm_factor
self.len_norm_const = len_norm_const
self.cov_penalty_factor = cov_penalty_factor
        self.batch_first = self.model.batch_first

    def greedy_search(self, batch_size, initial_input, initial_context=None):
"""
Greedy decoder.
:param batch_size: decoder batch size
:param initial_input: initial input, usually tensor of BOS tokens
:param initial_context: initial context, usually [encoder_context,
src_seq_lengths, None]
returns: (translation, lengths, counter)
translation: (batch_size, max_seq_len) - indices of target tokens
lengths: (batch_size) - lengths of generated translations
counter: number of iterations of the decoding loop
"""
max_seq_len = self.max_seq_len
translation = torch.zeros(batch_size, max_seq_len, dtype=torch.int64)
lengths = torch.ones(batch_size, dtype=torch.int64)
active = torch.arange(0, batch_size, dtype=torch.int64)
base_mask = torch.arange(0, batch_size, dtype=torch.int64)
if self.cuda:
translation = translation.cuda()
lengths = lengths.cuda()
active = active.cuda()
base_mask = base_mask.cuda()
translation[:, 0] = BOS
words, context = initial_input, initial_context
if self.batch_first:
word_view = (-1, 1)
ctx_batch_dim = 0
else:
word_view = (1, -1)
ctx_batch_dim = 1
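        # word_view reshapes the flat token tensor to (batch, 1) for
        # batch-first models or (1, batch) for time-first models;
        # ctx_batch_dim tracks where the batch dimension of the encoder
        # context lives in each layout.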
counter = 0
for idx in range(1, max_seq_len):
if not len(active):
break
counter += 1
words = words.view(word_view)
output = self.model.generate(words, context, 1)
words, logprobs, attn, context = output
words = words.view(-1)
translation[active, idx] = words
lengths[active] += 1
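            # sequences that just produced EOS are finished; compact all
            # per-sequence state down to the still-active entries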
terminating = (words == EOS)
if terminating.any():
not_terminating = ~terminating
mask = base_mask[:len(active)]
mask = mask.masked_select(not_terminating)
active = active.masked_select(not_terminating)
words = words[mask]
context[0] = context[0].index_select(ctx_batch_dim, mask)
context[1] = context[1].index_select(0, mask)
context[2] = context[2].index_select(1, mask)
        return translation, lengths, counter

    def beam_search(self, batch_size, initial_input, initial_context=None):
"""
Beam search decoder.
:param batch_size: decoder batch size
:param initial_input: initial input, usually tensor of BOS tokens
:param initial_context: initial context, usually [encoder_context,
src_seq_lengths, None]
returns: (translation, lengths, counter)
translation: (batch_size, max_seq_len) - indices of target tokens
lengths: (batch_size) - lengths of generated translations
counter: number of iterations of the decoding loop
"""
beam_size = self.beam_size
norm_const = self.len_norm_const
norm_factor = self.len_norm_factor
max_seq_len = self.max_seq_len
cov_penalty_factor = self.cov_penalty_factor
translation = torch.zeros(batch_size * beam_size, max_seq_len,
dtype=torch.int64)
lengths = torch.ones(batch_size * beam_size, dtype=torch.int64)
scores = torch.zeros(batch_size * beam_size, dtype=torch.float32)
active = torch.arange(0, batch_size * beam_size, dtype=torch.int64)
base_mask = torch.arange(0, batch_size * beam_size, dtype=torch.int64)
global_offset = torch.arange(0, batch_size * beam_size, beam_size,
dtype=torch.int64)
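        # Once a hypothesis emits EOS, its first expansion continues with
        # logprob 0 and every other expansion gets -inf, so the finished
        # hypothesis survives unchanged in the beam.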
eos_beam_fill = torch.tensor([0] + (beam_size - 1) * [float('-inf')])
if self.cuda:
translation = translation.cuda()
lengths = lengths.cuda()
active = active.cuda()
base_mask = base_mask.cuda()
scores = scores.cuda()
global_offset = global_offset.cuda()
eos_beam_fill = eos_beam_fill.cuda()
translation[:, 0] = BOS
words, context = initial_input, initial_context
if self.batch_first:
word_view = (-1, 1)
ctx_batch_dim = 0
attn_query_dim = 1
else:
word_view = (1, -1)
ctx_batch_dim = 1
attn_query_dim = 0
        # replicate context beam_size times so every hypothesis in the beam
        # owns its own copy
if self.batch_first:
# context[0] (encoder state): (batch, seq, feature)
_, seq, feature = context[0].shape
context[0].unsqueeze_(1)
context[0] = context[0].expand(-1, beam_size, -1, -1)
context[0] = context[0].contiguous().view(batch_size * beam_size,
seq, feature)
# context[0]: (batch * beam, seq, feature)
else:
# context[0] (encoder state): (seq, batch, feature)
seq, _, feature = context[0].shape
context[0].unsqueeze_(2)
context[0] = context[0].expand(-1, -1, beam_size, -1)
context[0] = context[0].contiguous().view(seq, batch_size *
beam_size, feature)
# context[0]: (seq, batch * beam, feature)
# context[1] (encoder seq length): (batch)
context[1].unsqueeze_(1)
context[1] = context[1].expand(-1, beam_size)
context[1] = context[1].contiguous().view(batch_size * beam_size)
# context[1]: (batch * beam)
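        # accumulated attention mass per source position, consumed by the
        # coverage penalty when a batch entry terminates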
accu_attn_scores = torch.zeros(batch_size * beam_size, seq)
if self.cuda:
accu_attn_scores = accu_attn_scores.cuda()
counter = 0
        for idx in range(1, max_seq_len):
if not len(active):
break
counter += 1
eos_mask = (words == EOS)
eos_mask = eos_mask.view(-1, beam_size)
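            # a batch entry is fully terminated only when every one of its
            # beams has emitted EOS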
            terminating = eos_mask.all(dim=1)
lengths[active[~eos_mask.view(-1)]] += 1
output = self.model.generate(words, context, beam_size)
words, logprobs, attn, context = output
attn = attn.float().squeeze(attn_query_dim)
attn = attn.masked_fill(eos_mask.view(-1).unsqueeze(1), 0)
accu_attn_scores[active] += attn
# words: (batch, beam, k)
words = words.view(-1, beam_size, beam_size)
words = words.masked_fill(eos_mask.unsqueeze(2), EOS)
# logprobs: (batch, beam, k)
logprobs = logprobs.float().view(-1, beam_size, beam_size)
if eos_mask.any():
logprobs[eos_mask] = eos_beam_fill
active_scores = scores[active].view(-1, beam_size)
# new_scores: (batch, beam, k)
new_scores = active_scores.unsqueeze(2) + logprobs
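            # on the first step every beam holds the same BOS hypothesis, so
            # keep only beam 0 to avoid expanding duplicates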
if idx == 1:
new_scores[:, 1:, :].fill_(float('-inf'))
new_scores = new_scores.view(-1, beam_size * beam_size)
# index: (batch, beam)
_, index = new_scores.topk(beam_size, dim=1)
            # integer division maps each flat candidate index to its parent beam
            source_beam = index // beam_size
best_scores = torch.gather(new_scores, 1, index)
scores[active] = best_scores.view(-1)
words = words.view(-1, beam_size * beam_size)
words = torch.gather(words, 1, index)
            # words: (batch * beam, 1) if batch_first else (1, batch * beam)
words = words.view(word_view)
offset = global_offset[:source_beam.shape[0]]
source_beam += offset.unsqueeze(1)
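            # reorder histories and lengths so each surviving hypothesis
            # inherits the prefix of the parent beam it was expanded from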
translation[active, :] = translation[active[source_beam.view(-1)], :]
translation[active, idx] = words.view(-1)
lengths[active] = lengths[active[source_beam.view(-1)]]
context[2] = context[2].index_select(1, source_beam.view(-1))
if terminating.any():
not_terminating = ~terminating
not_terminating = not_terminating.unsqueeze(1)
not_terminating = not_terminating.expand(-1, beam_size).contiguous()
normalization_mask = active.view(-1, beam_size)[terminating]
# length normalization
norm = lengths[normalization_mask].float()
norm = (norm_const + norm) / (norm_const + 1.0)
norm = norm ** norm_factor
scores[normalization_mask] /= norm
# coverage penalty
penalty = accu_attn_scores[normalization_mask]
penalty = penalty.clamp(0, 1)
penalty = penalty.log()
penalty[penalty == float('-inf')] = 0
penalty = penalty.sum(dim=-1)
scores[normalization_mask] += cov_penalty_factor * penalty
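                # compact per-hypothesis state down to the surviving
                # (non-terminated) batch entries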
mask = base_mask[:len(active)]
mask = mask.masked_select(not_terminating.view(-1))
words = words.index_select(ctx_batch_dim, mask)
context[0] = context[0].index_select(ctx_batch_dim, mask)
context[1] = context[1].index_select(0, mask)
context[2] = context[2].index_select(1, mask)
active = active.masked_select(not_terminating.view(-1))
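
        # decoding finished: pick the highest-scoring beam per batch entry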
scores = scores.view(batch_size, beam_size)
_, idx = scores.max(dim=1)
translation = translation[idx + global_offset, :]
lengths = lengths[idx + global_offset]
return translation, lengths, counter
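

# A minimal usage sketch (hedged): `model` stands for any module exposing a
# `generate(words, context, k)` method and a `batch_first` attribute, as the
# constructor above assumes; `encoder_context` and `src_lengths` are
# hypothetical placeholders for the encoder outputs. Shapes assume a
# batch-first decoder.
#
#     generator = SequenceGenerator(model, beam_size=5, max_seq_len=80,
#                                   cuda=torch.cuda.is_available())
#     bos = torch.full((batch_size * generator.beam_size, 1), BOS,
#                      dtype=torch.int64)
#     translation, lengths, counter = generator.beam_search(
#         batch_size, bos, [encoder_context, src_lengths, None])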