2018-08-07 16:27:43 +02:00
|
|
|
import logging
|
2019-02-14 12:40:30 +01:00
|
|
|
from operator import itemgetter
|
2018-08-07 16:27:43 +02:00
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch.utils.data import DataLoader
|
2019-02-14 12:40:30 +01:00
|
|
|
from torch.utils.data import Dataset
|
2018-08-07 16:27:43 +02:00
|
|
|
|
|
|
|
import seq2seq.data.config as config
|
|
|
|
from seq2seq.data.sampler import BucketingSampler
|
2019-02-14 12:40:30 +01:00
|
|
|
from seq2seq.data.sampler import DistributedSampler
|
|
|
|
from seq2seq.data.sampler import ShardingSampler
|
|
|
|
from seq2seq.data.sampler import StaticDistributedSampler
|
2018-08-07 16:27:43 +02:00
|
|
|
|
|
|
|
|
|
|
|
def build_collate_fn(batch_first=False, parallel=True, sort=False):
    """
    Creates a collate_fn closure configured by the given flags.

    :param batch_first: if True the padded batch is transposed before it is
        returned, otherwise it keeps the (longest sequence, number of
        sequences) layout it was built with
    :param parallel: if True returns a collate function for (src, tgt)
        pairs, otherwise a collate function for single sequences
    :param sort: if True each batch is ordered by decreasing src sequence
        length
    """
    def order_by_length(sequences):
        """
        Orders sequences by decreasing length.

        :param sequences: sequences to be ordered
        :returns: (positions in the input, ordered sequences), both tuples
        """
        ranked = sorted(enumerate(sequences),
                        key=lambda pair: len(pair[1]),
                        reverse=True)
        positions, ordered = zip(*ranked)
        return positions, ordered

    def collate_seq(seq):
        """
        Pads a list of sequences with config.PAD into one int64 tensor.

        :param seq: list of sequences
        :returns: (padded tensor, list of sequence lengths)
        """
        lengths = list(map(len, seq))
        padded = torch.full((max(lengths), len(seq)), config.PAD,
                            dtype=torch.int64)

        # every sequence fills one column; the rest of the column stays PAD
        for column, (sample, n_tokens) in enumerate(zip(seq, lengths)):
            padded[:n_tokens, column].copy_(sample[:n_tokens])

        batch = padded.t() if batch_first else padded
        return (batch, lengths)

    def parallel_collate(seqs):
        """
        Collates (src, tgt) pairs into a pair of padded batches, optionally
        ordered by decreasing src sequence length.

        :param seqs: tuple of (src, tgt) sequences
        """
        src_seqs, tgt_seqs = zip(*seqs)
        if sort:
            positions, src_seqs = order_by_length(src_seqs)
            # keep every tgt aligned with its reordered src
            tgt_seqs = [tgt_seqs[pos] for pos in positions]

        return collate_seq(src_seqs), collate_seq(tgt_seqs)

    def single_collate(src_seqs):
        """
        Collates single sequences into a padded batch, optionally ordered by
        decreasing sequence length. The positions of the samples in the
        incoming batch are returned next to the batch.

        :param src_seqs: source sequences
        """
        if sort:
            positions, src_seqs = order_by_length(src_seqs)
        else:
            positions = range(len(src_seqs))

        return collate_seq(src_seqs), tuple(positions)

    return parallel_collate if parallel else single_collate
|
|
|
|
|
|
|
|
|
|
|
|
class TextDataset(Dataset):
    """
    Monolingual dataset; every line of the input file is tokenized when the
    dataset is constructed and kept in memory as a tensor.
    """

    def __init__(self, src_fname, tokenizer, min_len=None, max_len=None,
                 sort=False, max_size=None):
        """
        Constructor for the TextDataset. Builds monolingual dataset.

        :param src_fname: path to the file with data
        :param tokenizer: tokenizer, must provide a segment(line) method
            whose result can be converted with torch.tensor
        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        :param sort: sorts dataset by sequence length
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """
        self.min_len = min_len
        self.max_len = max_len
        self.parallel = False
        self.sorted = False

        self.src = self.process_data(src_fname, tokenizer, max_size)

        # filtering is applied only when both bounds are given
        if min_len is not None and max_len is not None:
            self.filter_data(min_len, max_len)

        lengths = [len(s) for s in self.src]
        self.lengths = torch.tensor(lengths)

        if sort:
            self.sort_by_length()

    def sort_by_length(self):
        """
        Sorts dataset by the sequence length (longest first) and remembers
        the applied permutation in self.indices so that unsort() can undo it.
        """
        self.lengths, indices = self.lengths.sort(descending=True)

        # convert the permutation to plain ints once, instead of indexing the
        # python list with a 0-d tensor for every element
        order = indices.tolist()
        self.src = [self.src[idx] for idx in order]
        self.indices = order
        self.sorted = True

    def unsort(self, array):
        """
        "Unsorts" given array (restores original order of elements before
        dataset was sorted by sequence length).

        :param array: array to be "unsorted"
        :returns: reordered list if the dataset was sorted, otherwise the
            input array unchanged
        """
        if self.sorted:
            # self.indices[k] is the original position of sorted element k;
            # ordering the pairs by original position gives the inverse
            inverse = sorted(enumerate(self.indices), key=itemgetter(1))
            array = [array[sorted_pos] for sorted_pos, _ in inverse]
        return array

    def filter_data(self, min_len, max_len):
        """
        Preserves only samples which satisfy the following inequality:
            min_len <= sample sequence length <= max_len

        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        """
        logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}')

        initial_len = len(self.src)
        self.src = [src for src in self.src
                    if min_len <= len(src) <= max_len]
        filtered_len = len(self.src)
        logging.info(f'Pairs before: {initial_len}, after: {filtered_len}')

    def process_data(self, fname, tokenizer, max_size):
        """
        Loads data from the input file.

        :param fname: input file name
        :param tokenizer: tokenizer
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        :returns: list of tensors, one per line read
        """
        logging.info(f'Processing data from {fname}')
        data = []
        with open(fname) as dfile:
            for idx, line in enumerate(dfile):
                # a falsy max_size (None or 0) means "no limit"
                if max_size and idx == max_size:
                    break
                entry = tokenizer.segment(line)
                entry = torch.tensor(entry)
                data.append(entry)
        return data

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.src)

    def __getitem__(self, idx):
        """Returns the sample stored at position idx."""
        return self.src[idx]

    def get_loader(self, batch_size=1, seeds=None, shuffle=False,
                   num_workers=0, batch_first=False, pad=False,
                   batching=None, batching_opt=None):
        """
        Builds a DataLoader over this dataset.

        The collate function is always created with sort=True.

        :param batch_size: batch size, passed to the sampler and DataLoader
        :param seeds: passed to the sampler used when shuffle is True
        :param shuffle: if True one of the samplers selected by 'batching' is
            used, otherwise StaticDistributedSampler is used
        :param num_workers: number of DataLoader workers
        :param batch_first: passed to build_collate_fn
        :param pad: passed to StaticDistributedSampler (shuffle=False only)
        :param batching: one of 'random', 'sharding' or 'bucketing', read
            only when shuffle is True
        :param batching_opt: dict with sampler options; 'shard_size' is read
            for 'sharding' and 'num_buckets' for 'bucketing'. None is treated
            as an empty dict
        :raises NotImplementedError: if shuffle is True and 'batching' is not
            one of the supported modes
        """
        # None sentinel instead of a mutable {} default argument
        if batching_opt is None:
            batching_opt = {}

        collate_fn = build_collate_fn(batch_first, parallel=self.parallel,
                                      sort=True)

        if shuffle:
            if batching == 'random':
                sampler = DistributedSampler(self, batch_size, seeds)
            elif batching == 'sharding':
                sampler = ShardingSampler(self, batch_size, seeds,
                                          batching_opt['shard_size'])
            elif batching == 'bucketing':
                sampler = BucketingSampler(self, batch_size, seeds,
                                           batching_opt['num_buckets'])
            else:
                raise NotImplementedError(
                    f'Unsupported batching mode: {batching!r}')
        else:
            sampler = StaticDistributedSampler(self, batch_size, pad)

        return DataLoader(self,
                          batch_size=batch_size,
                          collate_fn=collate_fn,
                          sampler=sampler,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=False)
|
2018-08-07 16:27:43 +02:00
|
|
|
|
|
|
|
|
|
|
|
class ParallelDataset(TextDataset):
    """
    Dataset of (src, tgt) sequence pairs; both files are tokenized when the
    dataset is constructed and kept in memory as tensors.
    """

    def __init__(self, src_fname, tgt_fname, tokenizer,
                 min_len, max_len, sort=False, max_size=None):
        """
        Constructor for the ParallelDataset.
        Tokenization is done when the data is loaded from the disk.

        :param src_fname: path to the file with src language data
        :param tgt_fname: path to the file with tgt language data
        :param tokenizer: tokenizer
        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        :param sort: sorts dataset by sequence length
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """
        # the base-class constructor is not invoked; the attributes it would
        # set are initialised here
        self.min_len = min_len
        self.max_len = max_len
        self.parallel = True
        self.sorted = False

        self.src = self.process_data(src_fname, tokenizer, max_size)
        self.tgt = self.process_data(tgt_fname, tokenizer, max_size)
        assert len(self.src) == len(self.tgt), \
            'src and tgt must contain the same number of samples'

        self.filter_data(min_len, max_len)
        assert len(self.src) == len(self.tgt), \
            'src and tgt must contain the same number of samples'

        src_lengths = [len(s) for s in self.src]
        tgt_lengths = [len(t) for t in self.tgt]
        self.src_lengths = torch.tensor(src_lengths)
        self.tgt_lengths = torch.tensor(tgt_lengths)
        # a pair's length is the sum of its src and tgt lengths
        self.lengths = self.src_lengths + self.tgt_lengths

        if sort:
            self.sort_by_length()

    def sort_by_length(self):
        """
        Sorts dataset by the sequence length (src length + tgt length,
        longest first) and remembers the applied permutation in self.indices.
        """
        self.lengths, indices = self.lengths.sort(descending=True)

        # plain ints for indexing the python lists
        order = indices.tolist()
        self.src = [self.src[idx] for idx in order]
        self.tgt = [self.tgt[idx] for idx in order]
        # index selection keeps both attributes tensors, the same type they
        # have when the dataset is not sorted
        self.src_lengths = self.src_lengths[indices]
        self.tgt_lengths = self.tgt_lengths[indices]
        self.indices = order
        self.sorted = True

    def filter_data(self, min_len, max_len):
        """
        Preserves only samples which satisfy the following inequality:
            min_len <= src sample sequence length <= max_len AND
            min_len <= tgt sample sequence length <= max_len

        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        """
        logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}')

        initial_len = len(self.src)
        filtered_src = []
        filtered_tgt = []
        for src, tgt in zip(self.src, self.tgt):
            # a pair is kept only when both of its sides are within bounds
            if min_len <= len(src) <= max_len and \
                    min_len <= len(tgt) <= max_len:
                filtered_src.append(src)
                filtered_tgt.append(tgt)

        self.src = filtered_src
        self.tgt = filtered_tgt
        filtered_len = len(self.src)
        logging.info(f'Pairs before: {initial_len}, after: {filtered_len}')

    def __getitem__(self, idx):
        """Returns the (src, tgt) pair stored at position idx."""
        return self.src[idx], self.tgt[idx]
|
2019-02-14 12:40:30 +01:00
|
|
|
|
|
|
|
|
|
|
|
class LazyParallelDataset(TextDataset):
    """
    Dataset of (src, tgt) pairs which keeps the raw text lines in memory and
    tokenizes a pair only when it is requested through __getitem__.
    """

    def __init__(self, src_fname, tgt_fname, tokenizer,
                 min_len, max_len, sort=False, max_size=None):
        """
        Constructor for the LazyParallelDataset.
        Tokenization is done on the fly.

        :param src_fname: path to the file with src language data
        :param tgt_fname: path to the file with tgt language data
        :param tokenizer: tokenizer
        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        :param sort: sorts dataset by sequence length
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        """
        # the base-class constructor is not invoked; the attributes it would
        # set are initialised here
        self.min_len = min_len
        self.max_len = max_len
        self.parallel = True
        self.sorted = False
        self.tokenizer = tokenizer

        self.raw_src = self.process_raw_data(src_fname, max_size)
        self.raw_tgt = self.process_raw_data(tgt_fname, max_size)
        assert len(self.raw_src) == len(self.raw_tgt), \
            'src and tgt must contain the same number of samples'

        logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}')
        # Subtracting 2 because EOS and BOS are added later during tokenization
        self.filter_raw_data(min_len - 2, max_len - 2)
        assert len(self.raw_src) == len(self.raw_tgt), \
            'src and tgt must contain the same number of samples'

        # Adding 2 because EOS and BOS are added later during tokenization
        src_lengths = [i + 2 for i in self.src_len]
        tgt_lengths = [i + 2 for i in self.tgt_len]
        self.src_lengths = torch.tensor(src_lengths)
        self.tgt_lengths = torch.tensor(tgt_lengths)
        self.lengths = self.src_lengths + self.tgt_lengths

        # honour the documented 'sort' flag
        if sort:
            self.sort_by_length()

    def sort_by_length(self):
        """
        Sorts dataset by the sequence length (src length + tgt length,
        longest first) and remembers the applied permutation in self.indices.

        Overrides the inherited implementation, which reorders self.src; this
        class stores its samples in self.raw_src / self.raw_tgt instead.
        """
        self.lengths, indices = self.lengths.sort(descending=True)

        # plain ints for indexing the python lists
        order = indices.tolist()
        self.raw_src = [self.raw_src[idx] for idx in order]
        self.raw_tgt = [self.raw_tgt[idx] for idx in order]
        self.src_len = [self.src_len[idx] for idx in order]
        self.tgt_len = [self.tgt_len[idx] for idx in order]
        self.src_lengths = self.src_lengths[indices]
        self.tgt_lengths = self.tgt_lengths[indices]
        self.indices = order
        self.sorted = True

    def process_raw_data(self, fname, max_size):
        """
        Loads data from the input file.

        :param fname: input file name
        :param max_size: loads at most 'max_size' samples from the input file,
            if None loads the entire dataset
        :returns: list of lines exactly as read from the file
        """
        logging.info(f'Processing data from {fname}')
        data = []
        with open(fname) as dfile:
            for idx, line in enumerate(dfile):
                # a falsy max_size (None or 0) means "no limit"
                if max_size and idx == max_size:
                    break
                data.append(line)
        return data

    def filter_raw_data(self, min_len, max_len):
        """
        Preserves only samples which satisfy the following inequality:
            min_len <= src sample sequence length <= max_len AND
            min_len <= tgt sample sequence length <= max_len

        The length of a raw line is computed as its number of spaces plus
        one. The lengths of the preserved samples are stored in self.src_len
        and self.tgt_len.

        :param min_len: minimum sequence length
        :param max_len: maximum sequence length
        """
        initial_len = len(self.raw_src)
        filtered_src = []
        filtered_tgt = []
        filtered_src_len = []
        filtered_tgt_len = []
        for src, tgt in zip(self.raw_src, self.raw_tgt):
            src_len = src.count(' ') + 1
            tgt_len = tgt.count(' ') + 1
            # a pair is kept only when both of its sides are within bounds
            if min_len <= src_len <= max_len and \
                    min_len <= tgt_len <= max_len:
                filtered_src.append(src)
                filtered_tgt.append(tgt)
                filtered_src_len.append(src_len)
                filtered_tgt_len.append(tgt_len)

        self.raw_src = filtered_src
        self.raw_tgt = filtered_tgt
        self.src_len = filtered_src_len
        self.tgt_len = filtered_tgt_len
        filtered_len = len(self.raw_src)
        logging.info(f'Pairs before: {initial_len}, after: {filtered_len}')

    def __getitem__(self, idx):
        """Tokenizes and returns the (src, tgt) pair stored at position idx."""
        src = torch.tensor(self.tokenizer.segment(self.raw_src[idx]))
        tgt = torch.tensor(self.tokenizer.segment(self.raw_tgt[idx]))
        return src, tgt

    def __len__(self):
        """Returns the number of (src, tgt) pairs in the dataset."""
        return len(self.raw_src)
|