DeepLearningExamples/PyTorch/Translation/GNMT/seq2seq/data/dataset.py
import logging
from operator import itemgetter

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import seq2seq.data.config as config
from seq2seq.data.sampler import BucketingSampler
from seq2seq.data.sampler import DistributedSampler
from seq2seq.data.sampler import ShardingSampler
from seq2seq.data.sampler import StaticDistributedSampler


def build_collate_fn(batch_first=False, parallel=True, sort=False):
    """
    Factory for collate_fn functions.

    :param batch_first: if True returns batches in (batch, seq) format, if
        False returns in (seq, batch) format
    :param parallel: if True builds batches from parallel corpus (src, tgt)
    :param sort: if True sorts by src sequence length within each batch
    """
    def collate_seq(seq):
        """
        Builds batches for training or inference.
        Batches are returned as pytorch tensors, with padding.

        :param seq: list of sequences
        """
        lengths = [len(s) for s in seq]
        batch_length = max(lengths)

        shape = (batch_length, len(seq))
        seq_tensor = torch.full(shape, config.PAD, dtype=torch.int64)

        for i, s in enumerate(seq):
            end_seq = lengths[i]
            seq_tensor[:end_seq, i].copy_(s[:end_seq])

        if batch_first:
            seq_tensor = seq_tensor.t()

        return (seq_tensor, lengths)

    def parallel_collate(seqs):
        """
        Builds batches from parallel dataset (src, tgt), optionally sorts batch
        by src sequence length.

        :param seqs: tuple of (src, tgt) sequences
        """
        src_seqs, tgt_seqs = zip(*seqs)
        if sort:
            indices, src_seqs = zip(*sorted(enumerate(src_seqs),
                                            key=lambda item: len(item[1]),
                                            reverse=True))
            tgt_seqs = [tgt_seqs[idx] for idx in indices]

        return tuple([collate_seq(s) for s in [src_seqs, tgt_seqs]])

    def single_collate(src_seqs):
        """
        Builds batches from text dataset, optionally sorts batch by src
        sequence length.

        :param src_seqs: source sequences
        """
        if sort:
            indices, src_seqs = zip(*sorted(enumerate(src_seqs),
                                            key=lambda item: len(item[1]),
                                            reverse=True))
        else:
            indices = range(len(src_seqs))

        return collate_seq(src_seqs), tuple(indices)

    if parallel:
        return parallel_collate
    else:
        return single_collate
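

# Illustrative sketch (not part of the original module): how a collate_fn
# produced by build_collate_fn pads a small monolingual batch. The values shown
# assume config.PAD == 0; the actual padding id comes from seq2seq.data.config.
#
#     collate = build_collate_fn(batch_first=True, parallel=False, sort=False)
#     batch = [torch.tensor([4, 5, 6]), torch.tensor([7, 8])]
#     (padded, lengths), indices = collate(batch)
#     # padded  -> tensor([[4, 5, 6],
#     #                    [7, 8, 0]])   (padded with config.PAD, assumed 0)
#     # lengths -> [3, 2]                (true, unpadded lengths)
#     # indices -> (0, 1)                (within-batch order; permuted if sort=True)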


class TextDataset(Dataset):
    def __init__(self, src_fname, tokenizer, min_len=None, max_len=None,
                 sort=False, max_size=None):
        """
        Constructor for the TextDataset. Builds monolingual dataset.

        :param src_fname: path to the file with data
        :param tokenizer: tokenizer
        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        :param sort: sorts dataset by sequence length
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """
        self.min_len = min_len
        self.max_len = max_len
        self.parallel = False
        self.sorted = False

        self.src = self.process_data(src_fname, tokenizer, max_size)

        if min_len is not None and max_len is not None:
            self.filter_data(min_len, max_len)

        lengths = [len(s) for s in self.src]
        self.lengths = torch.tensor(lengths)

        if sort:
            self.sort_by_length()
    def sort_by_length(self):
        """
        Sorts dataset by the sequence length.
        """
        self.lengths, indices = self.lengths.sort(descending=True)

        self.src = [self.src[idx] for idx in indices]
        self.indices = indices.tolist()
        self.sorted = True

    def unsort(self, array):
        """
        "Unsorts" given array (restores original order of elements before
        dataset was sorted by sequence length).

        :param array: array to be "unsorted"
        """
        if self.sorted:
            inverse = sorted(enumerate(self.indices), key=itemgetter(1))
            array = [array[i[0]] for i in inverse]
        return array
    def filter_data(self, min_len, max_len):
        """
        Preserves only samples which satisfy the following inequality:
            min_len <= sample sequence length <= max_len

        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        """
        logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}')

        initial_len = len(self.src)
        filtered_src = []
        for src in self.src:
            if min_len <= len(src) <= max_len:
                filtered_src.append(src)

        self.src = filtered_src
        filtered_len = len(self.src)
        logging.info(f'Pairs before: {initial_len}, after: {filtered_len}')
    def process_data(self, fname, tokenizer, max_size):
        """
        Loads data from the input file.

        :param fname: input file name
        :param tokenizer: tokenizer
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """
        logging.info(f'Processing data from {fname}')
        data = []
        with open(fname) as dfile:
            for idx, line in enumerate(dfile):
                if max_size and idx == max_size:
                    break
                entry = tokenizer.segment(line)
                entry = torch.tensor(entry)
                data.append(entry)
        return data

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return self.src[idx]
    def get_loader(self, batch_size=1, seeds=None, shuffle=False,
                   num_workers=0, batch_first=False, pad=False,
                   batching=None, batching_opt={}):

        collate_fn = build_collate_fn(batch_first, parallel=self.parallel,
                                      sort=True)

        if shuffle:
            if batching == 'random':
                sampler = DistributedSampler(self, batch_size, seeds)
            elif batching == 'sharding':
                sampler = ShardingSampler(self, batch_size, seeds,
                                          batching_opt['shard_size'])
            elif batching == 'bucketing':
                sampler = BucketingSampler(self, batch_size, seeds,
                                           batching_opt['num_buckets'])
            else:
                raise NotImplementedError
        else:
            sampler = StaticDistributedSampler(self, batch_size, pad)

        return DataLoader(self,
                          batch_size=batch_size,
                          collate_fn=collate_fn,
                          sampler=sampler,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=False)
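

# Illustrative sketch (not part of the original module): building a loader from
# a monolingual file. The file name and tokenizer are placeholders; the
# tokenizer is assumed to expose a segment(line) method returning token ids
# (see process_data above), and the samplers assume the (possibly
# single-process) distributed setup expected by seq2seq.data.sampler.
#
#     dataset = TextDataset('newstest2014.tok.bpe.en', tokenizer,
#                           min_len=0, max_len=150, sort=False)
#     loader = dataset.get_loader(batch_size=32, batch_first=True, pad=True)
#     for (src, src_lengths), indices in loader:
#         # src: (batch, seq) int64 tensor padded with config.PAD
#         # src_lengths: list of true lengths, indices: within-batch sort order
#         ...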


class ParallelDataset(TextDataset):
    def __init__(self, src_fname, tgt_fname, tokenizer,
                 min_len, max_len, sort=False, max_size=None):
        """
        Constructor for the ParallelDataset.
        Tokenization is done when the data is loaded from the disk.

        :param src_fname: path to the file with src language data
        :param tgt_fname: path to the file with tgt language data
        :param tokenizer: tokenizer
        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        :param sort: sorts dataset by sequence length
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """
        self.min_len = min_len
        self.max_len = max_len
        self.parallel = True
        self.sorted = False

        self.src = self.process_data(src_fname, tokenizer, max_size)
        self.tgt = self.process_data(tgt_fname, tokenizer, max_size)
        assert len(self.src) == len(self.tgt)

        self.filter_data(min_len, max_len)
        assert len(self.src) == len(self.tgt)

        src_lengths = [len(s) for s in self.src]
        tgt_lengths = [len(t) for t in self.tgt]
        self.src_lengths = torch.tensor(src_lengths)
        self.tgt_lengths = torch.tensor(tgt_lengths)
        self.lengths = self.src_lengths + self.tgt_lengths

        if sort:
            self.sort_by_length()

    def sort_by_length(self):
        """
        Sorts dataset by the sequence length.
        """
        self.lengths, indices = self.lengths.sort(descending=True)

        self.src = [self.src[idx] for idx in indices]
        self.tgt = [self.tgt[idx] for idx in indices]
        self.src_lengths = [self.src_lengths[idx] for idx in indices]
        self.tgt_lengths = [self.tgt_lengths[idx] for idx in indices]
        self.indices = indices.tolist()
        self.sorted = True
    def filter_data(self, min_len, max_len):
        """
        Preserves only samples which satisfy the following inequality:
            min_len <= src sample sequence length <= max_len AND
            min_len <= tgt sample sequence length <= max_len

        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        """
        logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}')

        initial_len = len(self.src)
        filtered_src = []
        filtered_tgt = []
        for src, tgt in zip(self.src, self.tgt):
            if min_len <= len(src) <= max_len and \
                    min_len <= len(tgt) <= max_len:
                filtered_src.append(src)
                filtered_tgt.append(tgt)

        self.src = filtered_src
        self.tgt = filtered_tgt
        filtered_len = len(self.src)
        logging.info(f'Pairs before: {initial_len}, after: {filtered_len}')
    def __getitem__(self, idx):
        return self.src[idx], self.tgt[idx]
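

# Illustrative sketch (not part of the original module): with parallel=True the
# collate_fn returns one padded (tensor, lengths) pair per language, so a batch
# drawn from a loader built on this dataset unpacks as:
#
#     (src, src_lengths), (tgt, tgt_lengths) = batch
#
# where src/tgt are int64 tensors padded with config.PAD and src_lengths /
# tgt_lengths are lists of the true (unpadded) sequence lengths.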


class LazyParallelDataset(TextDataset):
    def __init__(self, src_fname, tgt_fname, tokenizer,
                 min_len, max_len, sort=False, max_size=None):
        """
        Constructor for the LazyParallelDataset.
        Tokenization is done on the fly.

        :param src_fname: path to the file with src language data
        :param tgt_fname: path to the file with tgt language data
        :param tokenizer: tokenizer
        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        :param sort: sorts dataset by sequence length
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """
        self.min_len = min_len
        self.max_len = max_len
        self.parallel = True
        self.sorted = False
        self.tokenizer = tokenizer

        self.raw_src = self.process_raw_data(src_fname, max_size)
        self.raw_tgt = self.process_raw_data(tgt_fname, max_size)
        assert len(self.raw_src) == len(self.raw_tgt)

        logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}')
        # Subtracting 2 because EOS and BOS are added later during tokenization
        self.filter_raw_data(min_len - 2, max_len - 2)
        assert len(self.raw_src) == len(self.raw_tgt)

        # Adding 2 because EOS and BOS are added later during tokenization
        src_lengths = [i + 2 for i in self.src_len]
        tgt_lengths = [i + 2 for i in self.tgt_len]
        self.src_lengths = torch.tensor(src_lengths)
        self.tgt_lengths = torch.tensor(tgt_lengths)
        self.lengths = self.src_lengths + self.tgt_lengths
    def process_raw_data(self, fname, max_size):
        """
        Loads data from the input file.

        :param fname: input file name
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """
        logging.info(f'Processing data from {fname}')
        data = []
        with open(fname) as dfile:
            for idx, line in enumerate(dfile):
                if max_size and idx == max_size:
                    break
                data.append(line)
        return data
    def filter_raw_data(self, min_len, max_len):
        """
        Preserves only samples which satisfy the following inequality:
            min_len <= src sample sequence length <= max_len AND
            min_len <= tgt sample sequence length <= max_len

        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        """
        initial_len = len(self.raw_src)
        filtered_src = []
        filtered_tgt = []
        filtered_src_len = []
        filtered_tgt_len = []
        for src, tgt in zip(self.raw_src, self.raw_tgt):
            src_len = src.count(' ') + 1
            tgt_len = tgt.count(' ') + 1
            if min_len <= src_len <= max_len and \
                    min_len <= tgt_len <= max_len:
                filtered_src.append(src)
                filtered_tgt.append(tgt)
                filtered_src_len.append(src_len)
                filtered_tgt_len.append(tgt_len)

        self.raw_src = filtered_src
        self.raw_tgt = filtered_tgt
        self.src_len = filtered_src_len
        self.tgt_len = filtered_tgt_len
        filtered_len = len(self.raw_src)
        logging.info(f'Pairs before: {initial_len}, after: {filtered_len}')
    def __getitem__(self, idx):
        src = torch.tensor(self.tokenizer.segment(self.raw_src[idx]))
        tgt = torch.tensor(self.tokenizer.segment(self.raw_tgt[idx]))
        return src, tgt

    def __len__(self):
        return len(self.raw_src)
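

# Illustrative sketch (not part of the original module): LazyParallelDataset
# keeps raw text lines in memory and calls tokenizer.segment() inside
# __getitem__, trading per-batch CPU time for a smaller startup cost than
# ParallelDataset. The lengths used for filtering and bucketing are estimated
# from whitespace-separated token counts plus 2 for BOS/EOS rather than by
# running the tokenizer up front. File names and the tokenizer below are
# placeholders.
#
#     dataset = LazyParallelDataset('train.tok.bpe.en', 'train.tok.bpe.de',
#                                   tokenizer, min_len=0, max_len=50)
#     src, tgt = dataset[0]   # tokenized on the fly, returned as int64 tensors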