DeepLearningExamples/PyTorch/Translation/GNMT/seq2seq/models/gnmt.py

import torch.nn as nn

import seq2seq.data.config as config
from seq2seq.models.decoder import ResidualRecurrentDecoder
from seq2seq.models.encoder import ResidualRecurrentEncoder
from seq2seq.models.seq2seq_base import Seq2Seq


class GNMT(Seq2Seq):
    """
    GNMT v2 model.
    """
    def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2,
                 batch_first=False, share_embedding=True):
        """
        Constructor for the GNMT v2 model.

        :param vocab_size: size of the vocabulary (number of tokens)
        :param hidden_size: internal hidden size of the model
        :param num_layers: number of layers, applies to both the encoder and
            the decoder
        :param dropout: probability of dropout (in the encoder and decoder)
        :param batch_first: if True the model uses (batch, seq, feature)
            tensors, if False the model uses (seq, batch, feature) tensors
        :param share_embedding: if True the embedding is shared between the
            encoder and the decoder
        """
        super(GNMT, self).__init__(batch_first=batch_first)

        # Optionally build a single embedding table shared by the encoder and
        # the decoder; otherwise each module constructs its own embedding.
        if share_embedding:
            embedder = nn.Embedding(vocab_size, hidden_size,
                                    padding_idx=config.PAD)
            nn.init.uniform_(embedder.weight.data, -0.1, 0.1)
        else:
            embedder = None

        self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size,
                                                num_layers, dropout,
                                                batch_first, embedder)

        self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size,
                                                num_layers, dropout,
                                                batch_first, embedder)

    def forward(self, input_encoder, input_enc_len, input_decoder):
        """Encodes the source sequence and decodes the target sequence."""
        context = self.encode(input_encoder, input_enc_len)
        # Pack the encoder output with the source lengths for the decoder;
        # the third element (initial decoder hidden state) starts as None.
        context = (context, input_enc_len, None)
        output, _, _ = self.decode(input_decoder, context)

        return output
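
# Minimal usage sketch (illustrative addition, not part of the original file):
# it assumes a toy vocabulary size, random padded LongTensor inputs in the
# default (seq, batch) layout (batch_first=False), and that the decoder output
# holds scores over the vocabulary for every target position.
#
#   import torch
#
#   model = GNMT(vocab_size=32000)
#   src = torch.randint(4, 32000, (10, 8))             # (src_len, batch)
#   src_len = torch.full((8,), 10, dtype=torch.int64)  # per-sequence lengths
#   tgt = torch.randint(4, 32000, (12, 8))             # (tgt_len, batch)
#   output = model(src, src_len, tgt)                  # (tgt_len, batch, vocab_size)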