DeepLearningExamples/PyTorch/Translation/GNMT/seq2seq/models/attention.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class BahdanauAttention(nn.Module):
    """
    Bahdanau Attention (https://arxiv.org/abs/1409.0473)
    Implementation is very similar to tf.contrib.seq2seq.BahdanauAttention
    """
    def __init__(self, query_size, key_size, num_units, normalize=False,
                 batch_first=False, init_weight=0.1):
        """
        Constructor for the BahdanauAttention.

        :param query_size: feature dimension for query
        :param key_size: feature dimension for keys
        :param num_units: internal feature dimension
        :param normalize: whether to normalize energy term
        :param batch_first: if True batch size is the 1st dimension, if False
            the sequence is first and batch size is second
        :param init_weight: range for uniform initializer used to initialize
            Linear key and query transform layers and linear_att vector
        """
        super(BahdanauAttention, self).__init__()

        self.normalize = normalize
        self.batch_first = batch_first
        self.num_units = num_units

        self.linear_q = nn.Linear(query_size, num_units, bias=False)
        self.linear_k = nn.Linear(key_size, num_units, bias=False)
        nn.init.uniform_(self.linear_q.weight.data, -init_weight, init_weight)
        nn.init.uniform_(self.linear_k.weight.data, -init_weight, init_weight)

        self.linear_att = Parameter(torch.Tensor(num_units))

        self.mask = None

        if self.normalize:
            self.normalize_scalar = Parameter(torch.Tensor(1))
            self.normalize_bias = Parameter(torch.Tensor(num_units))
        else:
            self.register_parameter('normalize_scalar', None)
            self.register_parameter('normalize_bias', None)

        self.reset_parameters(init_weight)

    def reset_parameters(self, init_weight):
        """
        Sets initial random values for trainable parameters.
        """
        stdv = 1. / math.sqrt(self.num_units)
        self.linear_att.data.uniform_(-init_weight, init_weight)

        if self.normalize:
            self.normalize_scalar.data.fill_(stdv)
            self.normalize_bias.data.zero_()

    def set_mask(self, context_len, context):
        """
        Sets self.mask, which is applied before the softmax:
        ones for inactive context fields, zeros for active context fields.

        :param context_len: lengths of the context sequences, shape (b)
        :param context: if batch_first: (b x t_k x n) else: (t_k x b x n)

        self.mask: (b x t_k)
        """
        if self.batch_first:
            max_len = context.size(1)
        else:
            max_len = context.size(0)

        indices = torch.arange(0, max_len, dtype=torch.int64,
                               device=context.device)
        self.mask = indices >= (context_len.unsqueeze(1))
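        # e.g. with context_len = [3, 1] and t_k = 4, self.mask becomes
        #     [[False, False, False, True],
        #      [False, True,  True,  True]]
        # (True marks padded positions, which are filled with a large
        # negative value before the softmax in forward())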

    def calc_score(self, att_query, att_keys):
        """
        Calculate Bahdanau score

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n

        :returns: b x t_q x t_k scores
        """
        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

        # broadcast query and keys to a common (b x t_q x t_k x n) shape
        # so they can be combined additively
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys

        if self.normalize:
            sum_qk = sum_qk + self.normalize_bias
            linear_att = self.linear_att / self.linear_att.norm()
            linear_att = linear_att * self.normalize_scalar
        else:
            linear_att = self.linear_att

        out = torch.tanh(sum_qk).matmul(linear_att)
        return out
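
    # For reference, calc_score computes the additive (Bahdanau) energy
    #     score(q, k) = linear_att^T tanh(W_q q + W_k k)
    # and, when normalize=True, a weight-normalized variant
    #     score(q, k) = normalize_scalar * (linear_att / ||linear_att||)^T
    #                   tanh(W_q q + W_k k + normalize_bias)
    # where q and k are the query and keys already projected by linear_q
    # and linear_k in forward().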

    def forward(self, query, keys):
        """
        :param query: if batch_first: (b x t_q x n) else: (t_q x b x n)
        :param keys: if batch_first: (b x t_k x n) else (t_k x b x n)

        :returns: (context, scores_normalized)
        context: if batch_first: (b x t_q x n) else (t_q x b x n)
        scores_normalized: if batch_first (b x t_q x t_k) else (t_q x b x t_k)
        """
        # first dim of keys and query has to be 'batch', it's needed for bmm
        if not self.batch_first:
            keys = keys.transpose(0, 1)
            if query.dim() == 3:
                query = query.transpose(0, 1)

        if query.dim() == 2:
            single_query = True
            query = query.unsqueeze(1)
        else:
            single_query = False

        b = query.size(0)
        t_k = keys.size(1)
        t_q = query.size(1)

        # FC layers to transform query and key
        processed_query = self.linear_q(query)
        processed_key = self.linear_k(keys)

        # scores: (b x t_q x t_k)
        scores = self.calc_score(processed_query, processed_key)

        if self.mask is not None:
            mask = self.mask.unsqueeze(1).expand(b, t_q, t_k)
            # -inf cannot be used here because of the overflow check in
            # PyTorch; -65504.0 is the lowest finite fp16 value
            scores.data.masked_fill_(mask, -65504.0)

        # Normalize the scores, softmax over t_k
        scores_normalized = F.softmax(scores, dim=-1)

        # Calculate the weighted average of the attention inputs according
        # to the scores
        # context: (b x t_q x n)
        context = torch.bmm(scores_normalized, keys)

        if single_query:
            context = context.squeeze(1)
            scores_normalized = scores_normalized.squeeze(1)
        elif not self.batch_first:
            context = context.transpose(0, 1)
            scores_normalized = scores_normalized.transpose(0, 1)

        return context, scores_normalized
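

# Minimal usage sketch: a single-step query attending over a batch of padded
# key sequences, assuming a batch_first layout; the tensor sizes and lengths
# below are made up purely for illustration.
if __name__ == '__main__':
    attn = BahdanauAttention(query_size=4, key_size=4, num_units=8,
                             normalize=True, batch_first=True)

    keys = torch.randn(2, 5, 4)             # (b x t_k x n)
    query = torch.randn(2, 4)               # single query step: (b x n)
    context_len = torch.tensor([5, 3])      # valid lengths per batch element

    # mask out the padded positions of the shorter sequence
    attn.set_mask(context_len, keys)

    context, scores = attn(query, keys)
    print(context.shape)   # torch.Size([2, 4])
    print(scores.shape)    # torch.Size([2, 5])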