Adding TransformerXL/PyT

Przemek Strzelczyk 2019-11-27 17:00:18 +01:00
parent 1164331725
commit 3d46067af9
45 changed files with 6572 additions and 0 deletions


@@ -0,0 +1,3 @@
**/.DS_Store
__pycache__/
data/


@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -0,0 +1,9 @@
Transformer-XL for PyTorch
This repository includes software from https://github.com/kimiyoung/transformer-xl licensed under the Apache License 2.0.
This repository includes software from https://github.com/salesforce/awd-lstm-lm licensed under the BSD-3-Clause license.
This repository includes software from https://github.com/cybertronai/transformer-xl licensed under the Apache License 2.0.
This repository includes software from https://github.com/cybertronai/pytorch-lamb licensed under the MIT license.

File diff suppressed because it is too large.


@@ -0,0 +1,120 @@
# BSD 3-Clause License
#
# Copyright (c) 2017,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
echo "=== Acquiring datasets ==="
echo "---"
mkdir -p data
cd data
if [[ ! -d 'wikitext-2' ]]; then
echo "- Downloading WikiText-2 (WT2)"
wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
unzip -q wikitext-2-v1.zip
cd wikitext-2
mv wiki.train.tokens train.txt
mv wiki.valid.tokens valid.txt
mv wiki.test.tokens test.txt
cd ..
fi
echo "- Downloading WikiText-103 (WT2)"
if [[ ! -d 'wikitext-103' ]]; then
wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
unzip -q wikitext-103-v1.zip
cd wikitext-103
mv wiki.train.tokens train.txt
mv wiki.valid.tokens valid.txt
mv wiki.test.tokens test.txt
cd ..
fi
echo "- Downloading enwik8 (Character)"
if [[ ! -d 'enwik8' ]]; then
mkdir -p enwik8
cd enwik8
wget --continue http://mattmahoney.net/dc/enwik8.zip
wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
python3 prep_enwik8.py
cd ..
fi
echo "- Downloading text8 (Character)"
if [[ ! -d 'text8' ]]; then
mkdir -p text8
cd text8
wget --continue http://mattmahoney.net/dc/text8.zip
python ../../prep_text8.py
cd ..
fi
echo "- Downloading Penn Treebank (PTB)"
if [[ ! -d 'penn' ]]; then
wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
tar -xzf simple-examples.tgz
mkdir -p penn
cd penn
mv ../simple-examples/data/ptb.train.txt train.txt
mv ../simple-examples/data/ptb.test.txt test.txt
mv ../simple-examples/data/ptb.valid.txt valid.txt
cd ..
echo "- Downloading Penn Treebank (Character)"
mkdir -p pennchar
cd pennchar
mv ../simple-examples/data/ptb.char.train.txt train.txt
mv ../simple-examples/data/ptb.char.test.txt test.txt
mv ../simple-examples/data/ptb.char.valid.txt valid.txt
cd ..
rm -rf simple-examples/
fi
echo "- Downloading 1B words"
if [[ ! -d 'one-billion-words' ]]; then
mkdir -p one-billion-words
cd one-billion-words
wget --no-proxy http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
path="1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/"
cat ${path}/news.en.heldout-00000-of-00050 > valid.txt
cat ${path}/news.en.heldout-00000-of-00050 > test.txt
wget https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt
cd ..
fi
echo "---"
echo "Happy language modeling :)"


@@ -0,0 +1,62 @@
#!/usr/bin/env python
# coding=utf-8
# BSD 3-Clause License
#
# Copyright (c) 2017,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import sys
import zipfile
from io import open
if os.path.exists('train.txt'):
print('Tokenized text8 already exists - skipping processing')
sys.exit()
data = zipfile.ZipFile('text8.zip').extractall()
data = open('text8', 'r', encoding='utf-8').read()
print('Length of text8: {}'.format(len(data)))
num_test_chars = 5000000
train_data = data[: -2 * num_test_chars]
valid_data = data[-2 * num_test_chars: -num_test_chars]
test_data = data[-num_test_chars:]
for fn, part in [('train.txt', train_data), ('valid.txt', valid_data), ('test.txt', test_data)]:
print('{} will have {} bytes'.format(fn, len(part)))
print('- Tokenizing...')
# Change space ' ' to underscore '_'
part_str = ' '.join(['_' if c == ' ' else c for c in part.strip()])
print('- Writing...')
f = open(fn, 'w').write(part_str)
f = open(fn + '.raw', 'w', encoding='utf-8').write(part)


@@ -0,0 +1,2 @@
LM-TFM*
internal/result*


@@ -0,0 +1,30 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:19.09-py3
FROM ${FROM_IMAGE_NAME}
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
WORKDIR /tmp/unique_for_apex
RUN git clone https://github.com/NVIDIA/apex.git && cd apex && git reset --hard 3ae89c754d945e407a6674aa2006d5a0e35d540e
RUN cd apex && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
WORKDIR /workspace/transformer-xl/pytorch
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
ADD . /workspace/transformer-xl/pytorch


@@ -0,0 +1,333 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import logging
import os
import re
import numpy as np
import sacremoses
import torch
import utils
from utils.vocabulary import OpenAIVocab
from utils.vocabulary import Vocab
class LMOrderedIterator(object):
def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
"""
data -- LongTensor -- the LongTensor is strictly ordered
"""
self.bsz = bsz
self.bptt = bptt
self.ext_len = ext_len if ext_len is not None else 0
self.device = device
# Work out how cleanly we can divide the dataset into bsz parts.
self.n_step = data.size(0) // bsz
# Trim off any extra elements that wouldn't cleanly fit (remainders).
data = data.narrow(0, 0, self.n_step * bsz)
# Evenly divide the data across the bsz batches.
self.data = data.view(bsz, -1).t().contiguous()
# Partition data for DistributedDataParallel
world_size = utils.distributed.get_world_size()
rank = utils.distributed.get_rank()
self.data = self.data.chunk(world_size, dim=1)[rank].to(device)
# Number of mini-batches
self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
def roll(self):
for i in range(self.data.size(1)):
row = self.data[:, i]
shift = torch.randint(0, self.data.size(0), (1,))
row = torch.cat((row[shift:], row[:shift]))
self.data[:, i] = row
def get_batch(self, i, bptt=None):
if bptt is None:
bptt = self.bptt
seq_len = min(bptt, self.data.size(0) - 1 - i)
end_idx = i + seq_len
beg_idx = max(0, i - self.ext_len)
data = self.data[beg_idx:end_idx]
target = self.data[i+1:i+1+seq_len]
return data, target, seq_len
def get_fixlen_iter(self, start=0):
for i in range(start, self.data.size(0) - 1, self.bptt):
yield self.get_batch(i)
def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
max_len = self.bptt + max_deviation * std
i = start
while True:
bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
data, target, seq_len = self.get_batch(i, bptt)
i += seq_len
yield data, target, seq_len
if i >= self.data.size(0) - 2:
break
def __iter__(self):
return self.get_fixlen_iter()
class LMShuffledIterator(object):
def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
"""
data -- list[LongTensor] -- there is no order among the LongTensors
"""
self.data = data
self.bsz = bsz
self.bptt = bptt
self.ext_len = ext_len if ext_len is not None else 0
self.device = device
self.shuffle = shuffle
def get_sent_stream(self):
# index iterator
epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
else np.array(range(len(self.data)))
# sentence iterator
for idx in epoch_indices:
yield self.data[idx]
def stream_iterator(self, sent_stream):
# streams for each data in the batch
streams = [None] * self.bsz
data = torch.LongTensor(self.bptt, self.bsz)
target = torch.LongTensor(self.bptt, self.bsz)
n_retain = 0
while True:
# data : [n_retain+bptt x bsz]
# target : [bptt x bsz]
data[n_retain:].fill_(-1)
target.fill_(-1)
valid_batch = True
for i in range(self.bsz):
n_filled = 0
try:
while n_filled < self.bptt:
if streams[i] is None or len(streams[i]) <= 1:
streams[i] = next(sent_stream)
# number of new tokens to fill in
n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
# first n_retain tokens are retained from last batch
data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
streams[i][:n_new]
target[n_filled:n_filled+n_new, i] = \
streams[i][1:n_new+1]
streams[i] = streams[i][n_new:]
n_filled += n_new
except StopIteration:
valid_batch = False
break
if not valid_batch:
return
data = data.to(self.device)
target = target.to(self.device)
yield data, target, self.bptt
n_retain = min(data.size(0), self.ext_len)
if n_retain > 0:
data[:n_retain] = data[-n_retain:]
data.resize_(n_retain + self.bptt, data.size(1))
def __iter__(self):
# sent_stream is an iterator
sent_stream = self.get_sent_stream()
for batch in self.stream_iterator(sent_stream):
yield batch
class LMMultiFileIterator(LMShuffledIterator):
def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
shuffle=False):
self.paths = paths
self.vocab = vocab
self.bsz = bsz
self.bptt = bptt
self.ext_len = ext_len if ext_len is not None else 0
self.device = device
self.shuffle = shuffle
def get_sent_stream(self, path):
sents = self.vocab.encode_file(path, add_double_eos=True)
if self.shuffle:
np.random.shuffle(sents)
sent_stream = iter(sents)
return sent_stream
def __iter__(self):
if self.shuffle:
np.random.shuffle(self.paths)
for path in self.paths:
# sent_stream is an iterator
sent_stream = self.get_sent_stream(path)
for batch in self.stream_iterator(sent_stream):
yield batch
class Corpus(object):
def __init__(self, path, dataset, vocab, *args, **kwargs):
self.dataset = dataset
if vocab == 'word':
self.vocab = Vocab(*args, **kwargs)
elif vocab == 'bpe':
self.vocab = OpenAIVocab()
else:
raise RuntimeError('Unsupported vocab')
if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
self.vocab.count_file(os.path.join(path, 'train.txt'))
self.vocab.count_file(os.path.join(path, 'valid.txt'))
self.vocab.count_file(os.path.join(path, 'test.txt'))
elif self.dataset == 'wt103':
self.vocab.count_file(os.path.join(path, 'train.txt'))
elif self.dataset == 'lm1b':
train_path_pattern = os.path.join(
path, '1-billion-word-language-modeling-benchmark-r13output',
'training-monolingual.tokenized.shuffled', 'news.en-*')
train_paths = glob.glob(train_path_pattern)
# the vocab will load from file when build_vocab() is called
self.vocab.build_vocab()
if self.dataset in ['ptb', 'wt2', 'wt103']:
self.train = self.vocab.encode_file(
os.path.join(path, 'train.txt'), ordered=True)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True)
elif self.dataset in ['enwik8', 'text8']:
self.train = self.vocab.encode_file(
os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
elif self.dataset == 'lm1b':
self.train = train_paths
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
def get_iterator(self, split, *args, **kwargs):
if split == 'train':
if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
data_iter = LMOrderedIterator(self.train, *args, **kwargs)
elif self.dataset == 'lm1b':
kwargs['shuffle'] = True
data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
elif split in ['valid', 'test']:
data = self.valid if split == 'valid' else self.test
if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
data_iter = LMOrderedIterator(data, *args, **kwargs)
elif self.dataset == 'lm1b':
data_iter = LMShuffledIterator(data, *args, **kwargs)
return data_iter
def get_lm_corpus(datadir, dataset, vocab):
if vocab == 'word':
fn = os.path.join(datadir, 'cache.pt')
elif vocab == 'bpe':
fn = os.path.join(datadir, 'cache.pt.bpe')
else:
raise RuntimeError('Unsupported vocab')
if os.path.exists(fn):
logging.info('Loading cached dataset...')
corpus = torch.load(fn)
else:
logging.info('Producing dataset {}...'.format(dataset))
kwargs = {}
if dataset in ['wt103', 'wt2']:
kwargs['special'] = ['<eos>']
kwargs['lower_case'] = False
elif dataset == 'ptb':
kwargs['special'] = ['<eos>']
kwargs['lower_case'] = True
elif dataset == 'lm1b':
kwargs['special'] = []
kwargs['lower_case'] = False
kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
elif dataset in ['enwik8', 'text8']:
pass
corpus = Corpus(datadir, dataset, vocab, **kwargs)
with utils.distributed.sync_workers() as rank:
if rank == 0:
torch.save(corpus, fn)
return corpus
def tokenize_raw(text, lang='en'):
mt = sacremoses.MosesTokenizer(lang)
text = mt.tokenize(text, return_str=True)
text = re.sub(r'&quot;', '"', text)
text = re.sub(r'&apos;', "'", text)
text = re.sub(r'(\d)\.(\d)', r'\1 @.@ \2', text)
text = re.sub(r'(\d),(\d)', r'\1 @,@ \2', text)
text = re.sub(r'(\w)-(\w)', r'\1 @-@ \2', text)
return text
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='unit test')
parser.add_argument('--datadir', type=str, default='../data/text8',
help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='text8',
choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'],
help='dataset name')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
corpus = get_lm_corpus(args.datadir, args.dataset, vocab='word')
logging.info('Vocab size : {}'.format(len(corpus.vocab.idx2sym)))
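For reference, LMOrderedIterator above lays the flat token stream out as bsz parallel columns and then slices bptt-long windows whose targets are the inputs shifted by one token. The following minimal, single-process sketch mirrors that layout on an arbitrary toy tensor (the toy sizes are illustrative assumptions, and the distributed chunking done through utils.distributed is intentionally left out):
import torch
# Toy sizes chosen only for illustration.
bsz, bptt = 4, 5
data = torch.arange(43)                        # flat, strictly ordered token stream
n_step = data.size(0) // bsz                   # full columns per stream
data = data.narrow(0, 0, n_step * bsz)         # drop the remainder
data = data.view(bsz, -1).t().contiguous()     # shape: [n_step x bsz]
def get_batch(i, ext_len=0):
    seq_len = min(bptt, data.size(0) - 1 - i)
    beg_idx = max(0, i - ext_len)
    inp = data[beg_idx:i + seq_len]            # inputs
    tgt = data[i + 1:i + 1 + seq_len]          # targets shifted by one token
    return inp, tgt, seq_len
for i in range(0, data.size(0) - 1, bptt):
    inp, tgt, seq_len = get_batch(i)
    print(i, tuple(inp.shape), tuple(tgt.shape), seq_len)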


@@ -0,0 +1,320 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import os
import pickle
import sys
import time
import numpy as np
import torch
import data_utils
import utils
from data_utils import get_lm_corpus
from data_utils import tokenize_raw
from utils.exp_utils import AverageMeter
from utils.exp_utils import benchmark
from utils.exp_utils import create_exp_dir
def parse_args():
parser = argparse.ArgumentParser(
description='PyTorch Transformer Language Model',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--work_dir', default='LM-TFM', type=str,
help='experiment directory')
parser.add_argument('--debug', action='store_true',
help='run in debug mode (do not create exp dir)')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
help='location of the data corpus')
parser.add_argument('--manual', type=str, default=None, nargs='+',
help='run model on raw input data')
parser.add_argument('--dataset', type=str, default='wt103',
choices=['wt103', 'lm1b', 'enwik8', 'text8'],
help='dataset name')
parser.add_argument('--split', type=str, default='all',
choices=['all', 'valid', 'test'],
help='which split to evaluate')
parser.add_argument('--type', type=str, default='pytorch',
choices=['pytorch', 'torchscript', 'onnx'],
help='type of runtime to use')
parser.add_argument('--batch_size', type=int, default=16,
help='batch size')
parser.add_argument('--tgt_len', type=int, default=64,
help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=640,
help='length of the retained previous heads')
parser.add_argument('--clamp_len', type=int, default=-1,
help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
help='use CUDA')
parser.add_argument('--model', type=str, default='',
help='path to the checkpoint')
parser.add_argument('--fp16', action='store_true',
help='Run training in fp16/mixed precision')
parser.add_argument('--log_all_ranks', action='store_true',
help='Enable logging for all distributed ranks')
parser.add_argument('--same_length', action='store_true',
help='set same length attention with masking')
parser.add_argument('--target_perplexity', type=float, default=None,
help='target perplexity')
parser.add_argument('--target_throughput', type=float, default=None,
help='target throughput')
parser.add_argument('--save_data', action='store_true',
help='save latency and throughput data to a file')
parser.add_argument('--repeat', type=int, default=1,
help='loop over the dataset REPEAT times')
parser.add_argument('--max_size', type=int, default=None,
help='run inference on up to MAX_SIZE batches')
parser.add_argument('--percentiles', nargs='+', default=[90, 95, 99],
help='percentiles for latency confidence intervals')
parser.add_argument('--save_torchscript', default=None, type=str,
help='save torchscript model to a file')
parser.add_argument('--load_torchscript', default=None, type=str,
help='load torchscript model from a file')
parser.add_argument('--local_rank', default=0, type=int,
help='Used for multi-process training. ' +
'Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'.')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'
return args
def load_checkpoint(path):
dst = f'cuda:{torch.cuda.current_device()}'
logging.info(f'Loading checkpoint from {path}')
checkpoint = torch.load(path, map_location=dst)
return checkpoint
def format_log(loss, split, args):
if args.dataset in ['enwik8', 'text8']:
log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
split, loss, loss / math.log(2))
else:
log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
split, loss, math.exp(loss))
return log_str
def evaluate(eval_iter, model, meters, max_size=None, repeat=1):
total_len, total_loss = 0, 0.
torch.cuda.synchronize()
start_time = time.time()
with torch.no_grad():
mems = None
for _ in range(repeat):
for idx, (data, target, seq_len) in enumerate(eval_iter):
if max_size and idx >= max_size:
break
torch.cuda.synchronize()
start_iter = time.time()
ret = model(data, target, mems)
torch.cuda.synchronize()
elapsed = time.time() - start_iter
loss, mems = ret[0], ret[1:]
loss = loss.mean()
total_loss += seq_len * loss.item()
total_len += seq_len
meters['eval_latency'].update(elapsed)
target_tokens = target.numel()
throughput = target_tokens / elapsed
throughput = utils.distributed.all_reduce_item(throughput, op='sum')
meters['eval_throughput'].update(throughput)
utils.distributed.barrier()
torch.cuda.synchronize()
total_time = time.time() - start_time
logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
total_time, 1000 * total_time / (idx+1)))
avg_loss = total_loss / total_len
avg_loss = utils.distributed.all_reduce_item(avg_loss, op='mean')
return avg_loss
def compile_model(model, device, args):
inp = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
tgt = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
start = time.time()
with torch.no_grad():
mems = None
for _ in range(2):
ret = model(inp, tgt, mems)
_, mems = ret[0], ret[1:]
torch.cuda.synchronize()
stop = time.time()
logging.info(f'Building the model took {stop - start:.2f} seconds')
def main():
args = parse_args()
if args.type == 'pytorch':
from mem_transformer import MemTransformerLM
else:
from inference.mem_transformer_base_jit import MemTransformerLM
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda' if args.cuda else 'cpu')
utils.distributed.init_distributed(args.cuda)
with utils.distributed.sync_workers() as rank:
if rank == 0:
create_exp_dir(args.work_dir, debug=args.debug)
# Setup logging
if args.log_all_ranks:
log_file = f'log_rank_{utils.distributed.get_rank()}.log'
else:
log_file = 'log.log'
log_file = os.path.join(args.work_dir, log_file)
if args.debug:
log_file = os.devnull
utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
filename=log_file,
filemode='a',
)
logging.info(args)
if args.model:
model_path = args.model
elif args.work_dir:
model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
else:
raise RuntimeError('Specify path to checkpoint using --model or --work_dir')
checkpoint = load_checkpoint(model_path)
if args.manual:
args.batch_size = 1
vocab = checkpoint['vocab']
if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
vocab.unk_idx = vocab.sym2idx['<unk>']
text = " ".join(args.manual)
tokenized = tokenize_raw(text)
symbols = vocab.tokenize(tokenized, add_eos=True)
tensor = vocab.convert_to_tensor(symbols)
iter = data_utils.LMOrderedIterator(tensor, bsz=args.batch_size,
bptt=args.tgt_len, device=device,
ext_len=args.ext_len)
else:
# Load dataset
corpus = get_lm_corpus(args.data, args.dataset, checkpoint['args'].vocab)
if args.split == 'valid':
iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
elif args.split == 'test':
iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
else:
raise RuntimeError('Unknown split')
if args.fp16:
dtype = torch.float16
math_str = 'fp16'
else:
dtype = torch.float32
math_str = 'fp32'
if args.load_torchscript:
model = torch.jit.load(args.load_torchscript)
else:
checkpoint['model_config']['tgt_len'] = args.tgt_len
checkpoint['model_config']['ext_len'] = args.ext_len
checkpoint['model_config']['mem_len'] = args.mem_len
checkpoint['model_config']['clamp_len'] = args.clamp_len
checkpoint['model_config']['same_length'] = args.same_length
checkpoint['model_config']['dtype'] = dtype
model = MemTransformerLM(**checkpoint['model_config'])
model.load_state_dict(checkpoint['model_state'])
model = model.eval()
model = model.to(device)
model = model.float()
if args.fp16:
model = model.half()
if args.type != 'pytorch':
compile_model(model, device, args)
if args.type == 'torchscript' and args.save_torchscript:
torch.jit.save(model, args.save_torchscript)
logging.info(f'Evaluating with: math {math_str} type {args.type} '
f'bsz {args.batch_size} tgt_len {args.tgt_len} '
f'ext_len {args.ext_len} mem_len {args.mem_len} '
f'clamp_len {args.clamp_len}')
meters = {}
warmup = args.mem_len // args.tgt_len + 1
meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)
loss = evaluate(iter, model, meters, args.max_size, args.repeat)
perplexity = math.exp(loss)
log_str = format_log(loss, args.split, args)
logging.info('=' * 100)
logging.info(log_str)
logging.info('=' * 100)
if args.save_data:
latency_data = np.array(meters['eval_latency'].vals)
throughput_data = np.array(meters['eval_throughput'].vals)
precision = 'fp16' if args.fp16 else 'fp32'
data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
data_path = os.path.join(args.work_dir, data_fname)
data = {
'args': args,
'throughput': throughput_data,
'latency': latency_data,
}
with open(data_path, 'wb') as f:
pickle.dump(data, f)
logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
for p in args.percentiles:
logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')
logging.info('=' * 100)
passed = benchmark(target_perplexity=args.target_perplexity,
test_perplexity=perplexity,
target_throughput=args.target_throughput,
test_throughput=meters['eval_throughput'].avg,
)
if not passed:
sys.exit(1)
if __name__ == "__main__":
main()
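As a quick sanity check on the metrics reported by format_log above: perplexity is the exponential of the average loss in nats per token, and bits per character is that loss divided by ln 2. A worked example with an arbitrary assumed loss value:
import math
loss = 3.18                        # assumed average loss (nats per token)
ppl = math.exp(loss)               # reported for word-level datasets
bpc = loss / math.log(2)           # reported for enwik8/text8
print(f'ppl {ppl:9.3f} | bpc {bpc:9.5f}')   # ppl ~ 24.047, bpc ~ 4.588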

Binary image files added (not shown in the diff): 694 KiB and 241 KiB.


@@ -0,0 +1,469 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from inference.proj_adaptive_softmax_jit import ProjectedAdaptiveLogSoftmax
class PositionalEmbedding(torch.jit.ScriptModule):
def __init__(self, demb):
super(PositionalEmbedding, self).__init__()
self.demb = demb
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
self.register_buffer('inv_freq', inv_freq)
@torch.jit.script_method
def forward(self, pos_seq, bsz: Optional[int] = None):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
if bsz is not None:
return pos_emb[:, None, :].expand(-1, bsz, -1)
else:
return pos_emb[:, None, :]
class PositionwiseFF(torch.jit.ScriptModule):
__constants__ = ['pre_lnorm']
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(PositionwiseFF, self).__init__()
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.CoreNet = nn.Sequential(
nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(d_inner, d_model),
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
@torch.jit.script_method
def forward(self, inp):
if self.pre_lnorm:
# layer normalization + positionwise feed-forward
core_out = self.CoreNet(self.layer_norm(inp))
# residual connection
output = core_out + inp
else:
# positionwise feed-forward
core_out = self.CoreNet(inp)
# residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
class RelMultiHeadAttn(torch.jit.ScriptModule):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
super(RelMultiHeadAttn, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
def _parallelogram_mask(self, h, w, left=False):
mask = torch.ones((h, w)).byte()
m = min(h, w)
mask[:m, :m] = torch.triu(mask[:m, :m])
mask[-m:, -m:] = torch.tril(mask[-m:, -m:])
if left:
return mask.bool()
else:
return mask.flip(0).bool()
def _shift(self, x, qlen, klen, mask, left=False):
if qlen > 1:
zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
else:
zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
if left:
mask = mask.flip(1)
x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
else:
x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
x = x_padded.masked_select(mask[:, :, None, None]) \
.view(qlen, klen, x.size(2), x.size(3))
return x
@torch.jit.script_method
def _rel_shift(self, x, zero_triu: bool = False):
zero_pad = torch.zeros((x.size(0), x.size(1), 1, x.size(3)),
device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=2)
x_padded = x_padded.view(x.size(0), x.size(2) + 1, x.size(1), x.size(3))
x = x_padded[:, 1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(0), x.size(1)))
x = x * torch.tril(ones, x.size(1) - x.size(0))[None, :, :, None]
return x
@torch.jit.script_method
def forward(self, w, r, attn_mask, mems=None):
raise NotImplementedError
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
__constants__ = ['pre_lnorm', 'n_head', 'd_head', 'scale']
def __init__(self, *args, **kwargs):
super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
@torch.jit.script_method
def forward(self, w, r, r_w_bias, r_r_bias, attn_mask,
mems: Optional[torch.Tensor] = None):
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
else:
w_heads = self.qkv_net(cat)
r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
w_head_q = w_head_q[-qlen:]
else:
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # rlen x n_head x d_head
# compute attention score
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
# AC = torch.einsum('ibnd,jbnd->bijn', (rw_head_q, w_head_k)) # bsz x qlen x klen x n_head
rw_head_q = rw_head_q.view(qlen, bsz * self.n_head, self.d_head).permute(1, 0, 2)
w_head_k = w_head_k.reshape(klen, bsz * self.n_head, self.d_head).permute(1, 2, 0)
AC = torch.bmm(rw_head_q, w_head_k).view(bsz, self.n_head, qlen, klen).permute(0, 2, 3, 1)
rr_head_q = w_head_q + r_r_bias
# BD = torch.einsum('ibnd,jnd->bijn', (rr_head_q, r_head_k)) # bsz x qlen x klen x n_head
rr_head_q = rr_head_q.permute(2, 1, 0, 3).reshape(self.n_head, bsz * qlen, self.d_head)
r_head_k = r_head_k.permute(1, 2, 0).view(self.n_head, self.d_head, klen)
BD = torch.bmm(rr_head_q, r_head_k).permute(1, 2, 0).view(bsz, qlen, klen, self.n_head)
BD = self._rel_shift(BD, False)
# [bsz x qlen x klen x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)
# compute attention probability
if attn_mask is not None and attn_mask.any():
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))
# [bsz x qlen x klen x n_head]
attn_prob = F.softmax(attn_score, dim=2)
attn_prob = self.dropatt(attn_prob)
# compute attention vector
# attn_vec = torch.einsum('bijn,jbnd->ibnd', (attn_prob, w_head_v))
attn_prob = attn_prob.permute(0, 3, 1, 2).reshape(bsz * self.n_head, qlen, klen)
w_head_v = w_head_v.permute(1, 2, 0, 3).reshape(bsz * self.n_head, klen, self.d_head)
attn_vec = torch.bmm(attn_prob, w_head_v).permute(1, 0, 2).view(qlen, bsz, self.n_head, self.d_head)
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.reshape(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
# residual connection
output = w + attn_out
else:
# residual connection + layer normalization
output = self.layer_norm(w + attn_out)
output = output.type_as(w)
return output
class RelPartialLearnableDecoderLayer(torch.jit.ScriptModule):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
**kwargs):
super(RelPartialLearnableDecoderLayer, self).__init__()
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout,
**kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
@torch.jit.script_method
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask,
mems: Optional[torch.Tensor] = None
):
output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class AdaptiveEmbedding(torch.jit.ScriptModule):
__constants__ = ['div_val', 'd_proj', 'd_embed', 'emb_scale']
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False):
super(AdaptiveEmbedding, self).__init__()
self.n_token = n_token
self.d_embed = d_embed
self.cutoffs = cutoffs + [n_token]
self.div_val = div_val
self.d_proj = d_proj
self.emb_scale = d_proj ** 0.5
self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
if div_val != 1:
raise RuntimeError('TorchScripted model supports only div_val == 1')
if d_proj != d_embed:
raise RuntimeError('TorchScripted model supports only d_proj == d_embed')
self.emb_layers.append(nn.Embedding(n_token, d_embed))
@torch.jit.script_method
def forward(self, x):
for emb_layer in self.emb_layers:
x = emb_layer(x)
x.mul_(self.emb_scale)
return x
class MemTransformerLM(torch.jit.ScriptModule):
__constants__ = ['same_length', 'mem_len', 'clamp_len', 'ext_len',
'n_layer', 'dtype']
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
dropout, dropatt, dtype, tie_weight=True, d_embed=None,
div_val=1, tie_projs=[False], pre_lnorm=False,
tgt_len=None, ext_len=None, mem_len=None,
cutoffs=[], adapt_inp=False,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1):
super(MemTransformerLM, self).__init__()
self.n_token = n_token
d_embed = d_model if d_embed is None else d_embed
self.d_embed = d_embed
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
div_val=div_val)
self.drop = nn.Dropout(dropout)
self.n_layer = n_layer
self.tgt_len = tgt_len
self.mem_len = mem_len
self.ext_len = ext_len
self.max_klen = tgt_len + ext_len + mem_len
self.dtype = dtype
self.attn_type = attn_type
if attn_type != 0:
raise RuntimeError('TorchScripted supports only attn_type == 0')
self.layers = nn.ModuleList()
# the default attention
for i in range(n_layer):
self.layers.append(
RelPartialLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm)
)
self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
cutoffs, div_val=div_val)
self.same_length = same_length
self.clamp_len = clamp_len
self._create_params()
def _create_params(self):
self.pos_emb = PositionalEmbedding(self.d_model)
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
@torch.jit.script_method
def init_mems(self):
mems = []
for i in range(self.n_layer+1):
empty = torch.empty(0, dtype=self.dtype, device=torch.device('cuda'))
mems.append(empty)
return mems
def _update_mems(self, hids: List[torch.Tensor], mems: List[torch.Tensor],
qlen: int, mlen: int):
assert len(hids) == len(mems), 'len(hids) != len(mems)'
# There are `mlen + qlen` steps that can be cached into mems
# For the next step, the last `ext_len` of the `qlen` tokens
# will be used as the extended context. Hence, we only cache
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
# to `mlen + qlen - self.ext_len`.
new_mems = []
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
beg_idx = max(0, end_idx - self.mem_len)
for i in range(len(hids)):
cat = torch.cat([mems[i], hids[i]], dim=0)
new_mems.append(cat[beg_idx:end_idx].detach())
return new_mems
@torch.jit.script_method
def _forward(self, dec_inp, mems: List[torch.Tensor]):
qlen, bsz = dec_inp.size()
word_emb = self.word_emb(dec_inp)
mlen = mems[0].size(0)
klen = mlen + qlen
if self.same_length:
# all_ones = word_emb.new_ones(qlen, klen)
all_ones = torch.ones((qlen, klen), device=torch.device('cuda'), dtype=torch.float32)
mask_len = klen - self.mem_len
if mask_len > 0:
mask_shift_len = qlen - mask_len
else:
mask_shift_len = qlen
dec_attn_mask = (torch.triu(all_ones, 1+mlen) + torch.tril(all_ones, -mask_shift_len)).to(torch.bool)
else:
all_ones = torch.ones((qlen, klen), device=torch.device('cuda'), dtype=torch.float32)
dec_attn_mask = torch.triu(
all_ones, diagonal=1+mlen).to(torch.bool)
hids = []
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb)
pos_emb = self.drop(pos_emb)
hids.append(core_out)
i = 0
for layer in self.layers:
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, pos_emb, self.r_w_bias,
self.r_r_bias, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
i += 1
core_out = self.drop(core_out)
new_mems = self._update_mems(hids, mems, qlen, mlen)
return core_out, new_mems
@torch.jit.script_method
def forward(self, data, target, mems: Optional[List[torch.Tensor]]):
# nn.DataParallel does not allow size(0) tensors to be broadcasted.
# So, have to initialize size(0) mems inside the model forward.
# Moreover, have to return new_mems to allow nn.DataParallel to piece
# them together.
if mems is None:
mems = self.init_mems()
tgt_len = target.size(0)
hidden, new_mems = self._forward(data, mems=mems)
pred_hid = hidden[-tgt_len:]
loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
loss = loss.view(tgt_len, -1)
if new_mems is None:
return [loss]
else:
return [loss] + new_mems
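The _rel_shift trick above converts attention scores indexed by relative distance into scores indexed by absolute key position using only a zero pad and a reshape. A small numeric sketch applying the same operations (bsz = 1, n_head = 1, mlen = 0 are arbitrary assumptions):
import torch
qlen, klen = 3, 3
x = torch.arange(1., 10.).view(1, qlen, klen, 1)    # [bsz x qlen x klen x n_head]
zero_pad = torch.zeros((x.size(0), x.size(1), 1, x.size(3)), dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=2)
x_padded = x_padded.view(x.size(0), x.size(2) + 1, x.size(1), x.size(3))
shifted = x_padded[:, 1:].view_as(x)
print(x[0, :, :, 0])        # scores indexed by relative distance (descending, as in pos_seq)
print(shifted[0, :, :, 0])  # row i, column j now holds the score for query i and key j
                            # (distance i - j); entries above the causal diagonal are
                            # leftovers that the attention mask discards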


@@ -0,0 +1,141 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class ProjectedAdaptiveLogSoftmax(torch.jit.ScriptModule):
__constants__ = ['n_clusters', 'cutoffs', 'cutoff_ends', 'keep_order']
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
keep_order=False):
super().__init__()
self.n_token = n_token
self.d_embed = d_embed
self.d_proj = d_proj
self.cutoffs = cutoffs + [n_token]
self.cutoff_ends = type(self.cutoffs)([0]) + self.cutoffs
self.div_val = div_val
self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters
if self.n_clusters > 0:
self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
self.out_layers = nn.ModuleList()
if d_proj != d_embed:
raise RuntimeError('TorchScripted module requires d_proj == d_embed')
if div_val != 1:
raise RuntimeError('TorchScripted module requires div_val == 1')
self.out_layers.append(nn.Linear(d_embed, n_token))
self.keep_order = keep_order
@torch.jit.script_method
def _compute_logit(self, hidden, weight, bias, proj: Optional[torch.Tensor]):
if proj is not None:
raise RuntimeError('TorchScripted module requires proj == None')
logit = F.linear(hidden, weight, bias=bias)
return logit
@torch.jit.script_method
def forward(self, hidden, target, keep_order: bool = False):
'''
hidden :: [len*bsz x d_proj]
target :: [len*bsz]
'''
if hidden.size(0) != target.size(0):
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')
if self.n_clusters == 0:
for out_layer in self.out_layers:
hidden = self._compute_logit(hidden, out_layer.weight,
out_layer.bias, None)
nll = -F.log_softmax(hidden, dim=-1) \
.gather(1, target.unsqueeze(1)).squeeze(1)
else:
# construct weights and biases
weights, biases = [], []
for i in range(len(self.cutoffs)):
for out_layer in self.out_layers:
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
weight_i = out_layer.weight[l_idx:r_idx]
bias_i = out_layer.bias[l_idx:r_idx]
if i == 0:
weight_i = torch.cat(
[weight_i, self.cluster_weight], dim=0)
bias_i = torch.cat(
[bias_i, self.cluster_bias], dim=0)
weights.append(weight_i)
biases.append(bias_i)
head_weight, head_bias, head_proj = weights[0], biases[0], None
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
head_logprob = F.log_softmax(head_logit, dim=1)
nll = torch.zeros_like(target, layout=torch.strided,
dtype=hidden.dtype, device=hidden.device)
offset = 0
cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):
l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
mask_i = (target >= l_idx) & (target < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
target_i = target.index_select(0, indices_i) - l_idx
head_logprob_i = head_logprob.index_select(0, indices_i)
if i == 0:
logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
else:
weight_i, bias_i, proj_i = weights[i], biases[i], None
hidden_i = hidden.index_select(0, indices_i)
tail_logit_i = self._compute_logit(hidden_i, weight_i,
bias_i, proj_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
logprob_i = head_logprob_i[:, -i] \
+ tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
if self.keep_order or keep_order:
nll.index_copy_(0, indices_i, -logprob_i)
else:
nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
offset += logprob_i.size(0)
return nll
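The cutoff bookkeeping above splits the vocabulary into a frequent-token shortlist handled by the head softmax plus tail clusters that are only evaluated for targets whose ids fall in their range. A small numeric sketch of that partitioning (vocabulary size and cutoffs are illustrative assumptions; real vocabularies use much larger values):
n_token = 1000
cutoffs = [200, 600]                      # assumed user-supplied cutoffs
cutoffs = cutoffs + [n_token]             # [200, 600, 1000]
cutoff_ends = [0] + cutoffs               # [0, 200, 600, 1000]
shortlist_size = cutoffs[0]               # 200 most frequent ids handled by the head
n_clusters = len(cutoffs) - 1             # 2 tail clusters
head_size = shortlist_size + n_clusters   # head softmax over 202 entries
# Tail cluster i covers token ids [cutoff_ends[i], cutoff_ends[i + 1]).
for i in range(1, len(cutoff_ends) - 1):
    print(f'cluster {i}: ids {cutoff_ends[i]}..{cutoff_ends[i + 1] - 1}')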


@@ -0,0 +1,247 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Lamb optimizer."""
import collections
import math
import torch
from torch.optim import Optimizer
class Lamb(Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
adam (bool, optional): always use trust ratio = 1, which turns this into
Adam. Useful for comparison purposes.
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
weight_decay=0, adam=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
self.adam = adam
super(Lamb, self).__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Lamb does not support sparse gradients.')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
# m_t
exp_avg.mul_(beta1).add_(1 - beta1, grad)
# v_t
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
# Paper v3 does not use debiasing.
# bias_correction1 = 1 - beta1 ** state['step']
# bias_correction2 = 1 - beta2 ** state['step']
# Apply bias to lr to avoid broadcast.
step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
weight_norm = p.data.norm(p=2).clamp_(0, 10)
adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
if group['weight_decay'] != 0:
adam_step.add_(group['weight_decay'], p.data)
adam_norm = adam_step.norm(p=2)
trust_ratio = weight_norm / (adam_norm + group['eps'])
state['weight_norm'] = weight_norm
state['adam_norm'] = adam_norm
state['trust_ratio'] = trust_ratio
if self.adam:
trust_ratio = 1
p.data.add_(-step_size * trust_ratio, adam_step)
return loss
@torch.jit.script
def lamb_kernel(param, grad, exp_avg, exp_avg_sq, beta1: float,
beta2: float, step_size: float, eps: float, weight_decay: float):
exp_avg = exp_avg * beta1 + (1 - beta1) * grad
exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2) * (grad * grad)
adam_step = exp_avg / (exp_avg_sq.sqrt() + eps)
adam_step = adam_step + weight_decay * param
weight_norm = param.norm(p=2).clamp_(0, 10)
adam_norm = adam_step.norm(p=2)
trust_ratio = weight_norm / (adam_norm + eps)
param = param - step_size * trust_ratio * adam_step
return param, exp_avg, exp_avg_sq
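# Worked example of the trust ratio above (illustrative numbers): with
# ||param|| = 5.0 (already inside the [0, 10] clamp) and ||adam_step|| = 0.25,
# trust_ratio = 5.0 / (0.25 + eps) ~= 20, so the step taken on this tensor is
# step_size * 20 * adam_step instead of step_size * adam_step; tensors whose
# Adam update is small relative to their weight norm take proportionally
# larger steps.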
class JITLamb(Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
adam (bool, optional): always use trust ratio = 1, which turns this into
Adam. Useful for comparison purposes.
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
weight_decay=0, adam=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
self.adam = adam
super().__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Lamb does not support sparse gradients.')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
step_size = group['lr']
param, exp_avg, exp_avg_sq = lamb_kernel(p.data, grad, exp_avg,
exp_avg_sq, beta1,
beta2, step_size,
group['eps'],
group['weight_decay'],
)
state['exp_avg'] = exp_avg
state['exp_avg_sq'] = exp_avg_sq
p.data = param
return loss
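# Minimal usage sketch (added for illustration; the model, data and
# hyperparameters below are made up and not part of the original file):
if __name__ == '__main__':
    import torch.nn as nn
    import torch.nn.functional as F
    model = nn.Linear(10, 2)
    optimizer = JITLamb(model.parameters(), lr=0.01, weight_decay=0.01)
    inputs, targets = torch.randn(4, 10), torch.randn(4, 2)
    for _ in range(3):
        optimizer.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        print(loss.item())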

View file

@@ -0,0 +1,842 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
from utils.log_uniform_sampler import LogUniformSampler
from utils.log_uniform_sampler import sample_logits
class PositionalEmbedding(nn.Module):
def __init__(self, demb):
super(PositionalEmbedding, self).__init__()
self.demb = demb
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
self.register_buffer('inv_freq', inv_freq)
def forward(self, pos_seq, bsz=None):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
if bsz is not None:
return pos_emb[:, None, :].expand(-1, bsz, -1)
else:
return pos_emb[:, None, :]
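# Shape note (illustrative): for pos_seq = torch.arange(klen - 1, -1, -1.0) the
# module above returns a [klen x 1 x demb] tensor (or [klen x bsz x demb] when
# bsz is given), with sin(pos / 10000^(2i/demb)) in the first demb/2 channels
# and the matching cosine terms in the second half.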
class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(PositionwiseFF, self).__init__()
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.CoreNet = nn.Sequential(
nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(d_inner, d_model),
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
def forward(self, inp):
if self.pre_lnorm:
# layer normalization + positionwise feed-forward
core_out = self.CoreNet(self.layer_norm(inp))
# residual connection
output = core_out + inp
else:
# positionwise feed-forward
core_out = self.CoreNet(inp)
# residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False):
super(MultiHeadAttn, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
def forward(self, h, attn_mask=None, mems=None):
# multihead attention
# [hlen x bsz x n_head x d_head]
if mems is not None:
c = torch.cat([mems, h], 0)
else:
c = h
if self.pre_lnorm:
# layer normalization
c = self.layer_norm(c)
head_q = self.q_net(h)
head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [bsz x qlen x klen x n_head]
attn_score = torch.einsum('ibnd,jbnd->bijn', (head_q, head_k))
attn_score.mul_(self.scale)
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))
# [bsz x qlen x klen x n_head]
attn_prob = F.softmax(attn_score, dim=2)
attn_prob = self.dropatt(attn_prob)
# [bsz x qlen x klen x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('bijn,jbnd->ibnd', (attn_prob, head_v))
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
# residual connection
output = h + attn_out
else:
# residual connection + layer normalization
output = self.layer_norm(h + attn_out)
return output
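# Layout note for the attention above (and the Rel* variants below): the einsum
# 'ibnd,jbnd->bijn' produces scores of shape [bsz x qlen x klen x n_head], so the
# softmax runs over dim=2 (the key/memory positions) and the attention masks are
# broadcast over the batch and head dimensions.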
class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
super(RelMultiHeadAttn, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
def _parallelogram_mask(self, h, w, left=False):
mask = torch.ones((h, w)).byte()
m = min(h, w)
mask[:m, :m] = torch.triu(mask[:m, :m])
mask[-m:, -m:] = torch.tril(mask[-m:, -m:])
if left:
return mask.bool()
else:
return mask.flip(0).bool()
def _shift(self, x, qlen, klen, mask, left=False):
if qlen > 1:
zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
else:
zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
if left:
mask = mask.flip(1)
x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
else:
x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
x = x_padded.masked_select(mask[:, :, None, None]) \
.view(qlen, klen, x.size(2), x.size(3))
return x
def _rel_shift(self, x, zero_triu=False):
zero_pad = torch.zeros((x.size(0), x.size(1), 1, x.size(3)),
device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=2)
x_padded = x_padded.view(x.size(0), x.size(2) + 1, x.size(1), x.size(3))
x = x_padded[:, 1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(0), x.size(1)))
x = x * torch.tril(ones, x.size(1) - x.size(0))[None, :, :, None]
return x
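    # Worked example of _rel_shift (illustrative): for a single batch/head slice
    # with qlen=2, klen=3 and rows [[a, b, c], [d, e, f]], padding a zero column,
    # viewing the result as (klen+1) x qlen, dropping the first row and viewing
    # back gives [[b, c, 0], [d, e, f]]: query row i is shifted left by
    # (qlen - 1 - i) positions, which aligns relative scores with key positions.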
def forward(self, w, r, attn_mask=None, mems=None):
raise NotImplementedError
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
else:
w_heads = self.qkv_net(cat)
r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
w_head_q = w_head_q[-qlen:]
else:
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head
        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # rlen x n_head x d_head
# compute attention score
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->bijn', (rw_head_q, w_head_k)) # bsz x qlen x klen x n_head
rr_head_q = w_head_q + r_r_bias
BD = torch.einsum('ibnd,jnd->bijn', (rr_head_q, r_head_k)) # bsz x qlen x klen x n_head
BD = self._rel_shift(BD)
# [bsz x qlen x klen x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)
# compute attention probability
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))
# [bsz x qlen x klen x n_head]
attn_prob = F.softmax(attn_score, dim=2)
attn_prob = self.dropatt(attn_prob)
# compute attention vector
attn_vec = torch.einsum('bijn,jbnd->ibnd', (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
# residual connection
output = w + attn_out
else:
# residual connection + layer normalization
output = self.layer_norm(w + attn_out)
return output
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
# r_emb: [klen, n_head, d_head], used for term B
# r_w_bias: [n_head, d_head], used for term C
# r_bias: [klen, n_head], used for term D
qlen, bsz = w.size(0), w.size(1)
if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
else:
w_heads = self.qkv_net(cat)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
w_head_q = w_head_q[-qlen:]
else:
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
if klen > r_emb.size(0):
r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
r_emb = torch.cat([r_emb_pad, r_emb], 0)
r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
r_bias = torch.cat([r_bias_pad, r_bias], 0)
else:
r_emb = r_emb[-klen:]
r_bias = r_bias[-klen:]
# compute attention score
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->bijn', (rw_head_q, w_head_k)) # bsz x qlen x klen x n_head
B_ = torch.einsum('ibnd,jnd->bijn', (w_head_q, r_emb)) # bsz x qlen x klen x n_head
D_ = r_bias[None, None, :, :] # 1 x 1 x klen x n_head
BD = self._rel_shift(B_ + D_)
# [bsz x qlen x klen x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)
# compute attention probability
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))
# [bsz x qlen x klen x n_head]
attn_prob = F.softmax(attn_score, dim=2)
attn_prob = self.dropatt(attn_prob)
# compute attention vector
attn_vec = torch.einsum('bijn,jbnd->ibnd', (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
# residual connection
output = w + attn_out
else:
# residual connection + layer normalization
output = self.layer_norm(w + attn_out)
return output
class DecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
super(DecoderLayer, self).__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
**kwargs):
super(RelLearnableDecoderLayer, self).__init__()
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head,
dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
**kwargs):
super(RelPartialLearnableDecoderLayer, self).__init__()
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout,
**kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class AdaptiveEmbedding(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False):
super(AdaptiveEmbedding, self).__init__()
self.n_token = n_token
self.d_embed = d_embed
self.cutoffs = cutoffs + [n_token]
self.div_val = div_val
self.d_proj = d_proj
self.emb_scale = d_proj ** 0.5
self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
if div_val == 1:
self.emb_layers.append(
nn.Embedding(n_token, d_embed, sparse=(sample_softmax > 0))
)
if d_proj != d_embed:
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
def forward(self, inp):
if self.div_val == 1:
embed = self.emb_layers[0](inp)
if self.d_proj != self.d_embed:
embed = F.linear(embed, self.emb_projs[0])
else:
param = next(self.parameters())
inp_flat = inp.view(-1)
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
dtype=param.dtype, device=param.device)
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
inp_i = inp_flat.index_select(0, indices_i) - l_idx
emb_i = self.emb_layers[i](inp_i)
emb_i = F.linear(emb_i, self.emb_projs[i])
emb_flat.index_copy_(0, indices_i, emb_i)
embed = emb_flat.view(*inp.size(), self.d_proj)
embed.mul_(self.emb_scale)
return embed
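# Illustrative configuration for the adaptive embedding above: with n_token=10000,
# cutoffs=[2000], d_embed=512, div_val=4 and d_proj=512, tokens 0..1999 are looked
# up in a 512-dim table and tokens 2000..9999 in a 128-dim table; both are
# projected back to d_proj=512 and then scaled by emb_scale = d_proj ** 0.5.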
class MemTransformerLM(nn.Module):
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
dropout, dropatt, dtype, tie_weight=True, d_embed=None,
div_val=1, tie_projs=[False], pre_lnorm=False,
tgt_len=None, ext_len=None, mem_len=None,
cutoffs=[], adapt_inp=False,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1):
super(MemTransformerLM, self).__init__()
self.n_token = n_token
d_embed = d_model if d_embed is None else d_embed
self.d_embed = d_embed
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
div_val=div_val)
self.drop = nn.Dropout(dropout)
self.n_layer = n_layer
self.tgt_len = tgt_len
self.mem_len = mem_len
self.ext_len = ext_len
self.max_klen = tgt_len + ext_len + mem_len
self.attn_type = attn_type
self.layers = nn.ModuleList()
# the default attention
if attn_type == 0:
for i in range(n_layer):
self.layers.append(
RelPartialLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm)
)
# learnable embeddings
elif attn_type == 1:
for i in range(n_layer):
self.layers.append(
RelLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm)
)
# absolute embeddings
elif attn_type in [2, 3]:
for i in range(n_layer):
self.layers.append(
DecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
dropatt=dropatt, pre_lnorm=pre_lnorm)
)
self.sample_softmax = sample_softmax
# use sampled softmax
if sample_softmax > 0:
self.out_layer = nn.Linear(d_model, n_token)
if tie_weight:
self.out_layer.weight = self.word_emb.weight
self.tie_weight = tie_weight
self.sampler = LogUniformSampler(n_token, sample_softmax)
# use adaptive softmax (including standard softmax)
else:
self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
cutoffs, div_val=div_val)
if tie_weight:
for i in range(len(self.crit.out_layers)):
self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
if tie_projs:
for i, tie_proj in enumerate(tie_projs):
if tie_proj and div_val == 1 and d_model != d_embed:
self.crit.out_projs[i] = self.word_emb.emb_projs[0]
elif tie_proj and div_val != 1:
self.crit.out_projs[i] = self.word_emb.emb_projs[i]
self.same_length = same_length
self.clamp_len = clamp_len
self._create_params()
def backward_compatible(self):
self.sample_softmax = -1
def _create_params(self):
# default attention
if self.attn_type == 0:
self.pos_emb = PositionalEmbedding(self.d_model)
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
# learnable
elif self.attn_type == 1:
self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.Tensor(
self.n_layer, self.n_head, self.d_head))
self.r_bias = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head))
# absolute standard
elif self.attn_type == 2:
self.pos_emb = PositionalEmbedding(self.d_model)
# absolute deeper SA
elif self.attn_type == 3:
self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.d_model))
def reset_length(self, tgt_len, ext_len, mem_len):
self.tgt_len = tgt_len
self.mem_len = mem_len
self.ext_len = ext_len
def init_mems(self):
if self.mem_len > 0:
mems = []
param = next(self.parameters())
for i in range(self.n_layer+1):
empty = torch.empty(0, dtype=param.dtype, device=param.device)
mems.append(empty)
return mems
else:
return None
def _update_mems(self, hids, mems, qlen, mlen):
# does not deal with None
if mems is None:
return None
# mems is not None
assert len(hids) == len(mems), 'len(hids) != len(mems)'
# There are `mlen + qlen` steps that can be cached into mems
# For the next step, the last `ext_len` of the `qlen` tokens
# will be used as the extended context. Hence, we only cache
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
# to `mlen + qlen - self.ext_len`.
with torch.no_grad():
new_mems = []
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
beg_idx = max(0, end_idx - self.mem_len)
for i in range(len(hids)):
cat = torch.cat([mems[i], hids[i]], dim=0)
new_mems.append(cat[beg_idx:end_idx].detach())
return new_mems
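    # Worked example for _update_mems (illustrative numbers): with mlen=4, qlen=3,
    # ext_len=0 and mem_len=4, end_idx = 4 + max(0, 3 - 0) = 7 and
    # beg_idx = max(0, 7 - 4) = 3, so each new memory keeps entries 3..6 of the
    # concatenated [old mems; new hidden states] tensor, i.e. the most recent
    # mem_len positions that are not reserved as extended context.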
def _forward(self, dec_inp, mems=None):
qlen, bsz = dec_inp.size()
word_emb = self.word_emb(dec_inp)
mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen
if self.same_length:
all_ones = word_emb.new_ones(qlen, klen)
mask_len = klen - self.mem_len
if mask_len > 0:
mask_shift_len = qlen - mask_len
else:
mask_shift_len = qlen
dec_attn_mask = (torch.triu(all_ones, 1+mlen)
+ torch.tril(all_ones, -mask_shift_len)).bool()
else:
dec_attn_mask = torch.triu(
word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()
hids = []
# default
if self.attn_type == 0:
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb)
pos_emb = self.drop(pos_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, pos_emb, self.r_w_bias,
self.r_r_bias, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
# learnable
elif self.attn_type == 1:
core_out = self.drop(word_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
if self.clamp_len > 0:
r_emb = self.r_emb[i][-self.clamp_len:]
r_bias = self.r_bias[i][-self.clamp_len:]
else:
r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out)
# absolute
elif self.attn_type == 2:
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb + pos_emb[-qlen:])
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
if mems_i is not None and len(mems_i) and i == 0:
mems_i += pos_emb[:mlen]
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
elif self.attn_type == 3:
core_out = self.drop(word_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
if mems_i is not None and len(mems_i) and mlen > 0:
cur_emb = self.r_emb[i][:-qlen]
cur_size = cur_emb.size(0)
if cur_size < mlen:
cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
else:
cur_emb = cur_emb[-mlen:]
mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
core_out = self.drop(core_out)
new_mems = self._update_mems(hids, mems, qlen, mlen)
return core_out, new_mems
def forward(self, data, target, mems):
# nn.DataParallel does not allow size(0) tensors to be broadcasted.
# So, have to initialize size(0) mems inside the model forward.
# Moreover, have to return new_mems to allow nn.DataParallel to piece
# them together.
if mems is None:
mems = self.init_mems()
tgt_len = target.size(0)
hidden, new_mems = self._forward(data, mems=mems)
pred_hid = hidden[-tgt_len:]
if self.sample_softmax > 0 and self.training:
assert self.tie_weight
logit = sample_logits(self.word_emb, self.out_layer.bias, target,
pred_hid, self.sampler)
loss = -F.log_softmax(logit, -1)[:, :, 0]
else:
loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
loss = loss.view(tgt_len, -1)
if new_mems is None:
return [loss]
else:
return [loss] + new_mems
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='unit test')
parser.add_argument('--n_layer', type=int, default=4, help='')
parser.add_argument('--n_rel_layer', type=int, default=4, help='')
parser.add_argument('--n_head', type=int, default=2, help='')
parser.add_argument('--d_head', type=int, default=2, help='')
parser.add_argument('--d_model', type=int, default=200, help='')
parser.add_argument('--d_embed', type=int, default=200, help='')
parser.add_argument('--d_inner', type=int, default=200, help='')
parser.add_argument('--dropout', type=float, default=0.0, help='')
parser.add_argument('--cuda', action='store_true', help='')
parser.add_argument('--seed', type=int, default=1111, help='')
parser.add_argument('--multi_gpu', action='store_true', help='')
args = parser.parse_args()
device = torch.device("cuda" if args.cuda else "cpu")
B = 4
tgt_len, mem_len, ext_len = 36, 36, 0
data_len = tgt_len * 20
args.n_token = 10000
import data_utils
data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)
cutoffs = [args.n_token // 2]
tie_projs = [False] + [True] * len(cutoffs)
for div_val in [1, 2]:
for d_embed in [200, 100]:
model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
args.d_model, args.d_head, args.d_inner, args.dropout,
dropatt=args.dropout, tie_weight=True,
d_embed=d_embed, div_val=div_val,
tie_projs=tie_projs, pre_lnorm=True,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
cutoffs=cutoffs, attn_type=0).to(device)
print(sum(p.numel() for p in model.parameters()))
mems = tuple()
for idx, (inp, tgt, seqlen) in enumerate(diter):
print('batch {}'.format(idx))
out = model(inp, tgt, *mems)
mems = out[1:]

View file

@@ -0,0 +1,2 @@
pytorch-transformers==1.1.0
sacremoses==0.0.35

View file

@@ -0,0 +1,41 @@
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--n_layer 12 \
--d_model 512 \
--n_head 8 \
--d_head 64 \
--d_inner 2048 \
--dropout 0.1 \
--dropatt 0.0 \
--optim adam \
--lr 0.00025 \
--warmup_step 0 \
--max_step 400000 \
--tgt_len 512 \
--mem_len 512 \
--eval_tgt_len 128 \
--batch_size 22 \
--multi_gpu \
--gpu0_bsz 4 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--tgt_len 80 \
--mem_len 2100 \
--clamp_len 820 \
--same_length \
--split test \
${@:2}
else
    echo 'unknown argument 1'
fi

View file

@@ -0,0 +1,41 @@
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--n_layer 24 \
--d_model 1024 \
--n_head 8 \
--d_head 128 \
--d_inner 3072 \
--dropout 0.15 \
--dropatt 0.15 \
--optim adam \
--lr 0.00025 \
--warmup_step 4000 \
--max_step 400000 \
--tgt_len 768 \
--mem_len 768 \
--eval_tgt_len 128 \
--batch_size 64 \
--multi_gpu \
--gpu0_bsz 0 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--tgt_len 128 \
--mem_len 3800 \
--clamp_len 1000 \
--same_length \
--split test \
${@:2}
else
    echo 'unknown argument 1'
fi

View file

@@ -0,0 +1,43 @@
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--adaptive \
--n_layer 18 \
--d_model 1024 \
--div_val 4 \
--n_head 8 \
--d_head 128 \
--d_inner 4096 \
--dropout 0.0 \
--dropatt 0.0 \
--optim adam \
--warmup_step 20000 \
--max_step 500000 \
--lr 0.00025 \
--tgt_len 32 \
--mem_len 32 \
--eval_tgt_len 32 \
--batch_size 224 \
--multi_gpu \
--gpu0_bsz 32 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--batch_size 64 \
--tgt_len 32 \
--mem_len 128 \
--split test \
--same_length \
${@:2}
else
    echo 'unknown argument 1'
fi

View file

@@ -0,0 +1,43 @@
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--adaptive \
--div_val 4 \
--n_layer 24 \
--d_model 1280 \
--n_head 16 \
--d_head 80 \
--d_inner 8192 \
--dropout 0.05 \
--dropatt 0.05 \
--optim adam \
--warmup_step 30000 \
--max_step 1200000 \
--lr 0.00025 \
--tgt_len 32 \
--mem_len 32 \
--eval_tgt_len 32 \
--batch_size 512 \
--multi_gpu \
--gpu0_bsz 0 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--batch_size 8 \
--tgt_len 32 \
--mem_len 128 \
--split test \
--same_length \
${@:2}
else
    echo 'unknown argument 1'
fi

View file

@@ -0,0 +1,41 @@
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/text8/ \
--dataset text8 \
--n_layer 12 \
--d_model 512 \
--n_head 8 \
--d_head 64 \
--d_inner 2048 \
--dropout 0.1 \
--dropatt 0.0 \
--optim adam \
--lr 0.00025 \
--warmup_step 0 \
--max_step 400000 \
--tgt_len 512 \
--mem_len 512 \
--eval_tgt_len 128 \
--batch_size 22 \
--multi_gpu \
--gpu0_bsz 4 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/text8/ \
--dataset text8 \
--tgt_len 80 \
--mem_len 2100 \
--clamp_len 820 \
--same_length \
--split test \
${@:2}
else
    echo 'unknown argument 1'
fi

View file

@@ -0,0 +1,38 @@
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/text8/ \
--dataset text8 \
--n_layer 24 \
--d_model 1024 \
--n_head 8 \
--d_head 128 \
--d_inner 3072 \
--dropout 0.15 \
--dropatt 0.15 \
--optim adam \
--lr 0.00025 \
--tgt_len 768 \
--mem_len 768 \
--eval_tgt_len 128 \
--batch_size 64 \
--max_step 400000 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/text8/ \
--dataset text8 \
--tgt_len 128 \
--mem_len 3800 \
--clamp_len 1000 \
--same_length \
--split test \
${@:2}
else
    echo 'unknown argument 1'
fi

View file

@@ -0,0 +1,58 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python -m torch.distributed.launch --nproc_per_node=$2 train.py \
--cuda \
--data ../data/wikitext-103/ \
--dataset wt103 \
--n_layer 16 \
--d_model 512 \
--n_head 8 \
--d_head 64 \
--d_inner 2048 \
--dropout 0.1 \
--dropatt 0.0 \
--optim jitlamb \
--lr 0.01 \
--eta_min 0.001 \
--roll \
--warmup_step 1000 \
--max_step 40000 \
--tgt_len 192 \
--mem_len 192 \
--eval_tgt_len 192 \
--batch_size 256 \
--multi_gpu ddp \
--log_interval 10 \
--eval_interval 5000 \
${@:3}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python -m torch.distributed.launch --nproc_per_node=$2 eval.py \
--cuda \
--data ../data/wikitext-103/ \
--dataset wt103 \
--tgt_len 64 \
--mem_len 640 \
--clamp_len 400 \
--same_length \
--split test \
${@:3}
else
    echo 'unknown argument 1'
fi
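# Example invocations (illustrative; any extra flags after the GPU count are
# forwarded to train.py / eval.py):
#   bash run_wt103_base.sh train 8 --fp16
#   bash run_wt103_base.sh eval 1 --model LM-TFM/checkpoint_best.pt --fp16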

View file

@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python -m torch.distributed.launch --nproc_per_node=$2 train.py \
--cuda \
--data ../data/wikitext-103/ \
--dataset wt103 \
--n_layer 18 \
--d_model 1024 \
--n_head 16 \
--d_head 64 \
--d_inner 4096 \
--dropout 0.2 \
--dropatt 0.2 \
--optim adam \
--lr 0.00025 \
--warmup_step 16000 \
--max_step 4000000 \
--tgt_len 256 \
--mem_len 256 \
--eval_tgt_len 128 \
--batch_size 128 \
--multi_gpu ddp \
${@:3}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python -m torch.distributed.launch --nproc_per_node=$2 eval.py \
--cuda \
--data ../data/wikitext-103/ \
--dataset wt103 \
--tgt_len 128 \
--mem_len 1600 \
--clamp_len 1000 \
--same_length \
--split test \
${@:3}
else
    echo 'unknown argument 1'
fi

View file

@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker build . --network=host --rm -t transformer-xl:latest

View file

@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
nvidia-docker run --init -it --rm --network=host --ipc=host -v $(dirname $PWD):/workspace/transformer-xl transformer-xl bash

View file

@@ -0,0 +1,39 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
MODEL=${MODEL:-"LM-TFM/checkpoint_best.pt"}
BATCH_SIZES=(1 2 4 8 16 32)
TYPES=("pytorch" "torchscript")
# "empty" MATH corresponds to fp32
MATHS=("" "--fp16")
for (( i = 0; i < ${#TYPES[@]}; i++ )); do
for (( j = 0; j < ${#BATCH_SIZES[@]}; j++ )); do
for (( k = 0; k < ${#MATHS[@]}; k++ )); do
echo type: ${TYPES[i]} batch size: ${BATCH_SIZES[j]} math: ${MATHS[k]}
taskset -c 0 bash run_wt103_base.sh eval 1 \
--model "${MODEL}" \
--type "${TYPES[i]}" \
--batch_size "${BATCH_SIZES[j]}" \
"${MATHS[k]}" \
--save_data \
"${@:1}"
done
done
done

View file

@@ -0,0 +1,71 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_inference_throughput
MATH=$1
if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
exit 1
fi
if [[ ${MATH} == 'fp16' ]]; then
MATH_OPT='--fp16'
elif [[ ${MATH} == 'fp32' ]]; then
MATH_OPT=''
fi
TYPE=$2
if [[ ${TYPE} != "pytorch" && ${TYPE} != "torchscript" ]]; then
echo "Unsupported option for TYPE, use either 'pytorch' or 'torchscript'"
exit 1
fi
PERF_TOLERANCE=0.9
BATCH_SIZE=16
GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
echo 'GPU_NAME:' "${GPU_NAME}"
GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
echo 'GPU_COUNT:' "${GPU_COUNT}"
GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
echo 'GPU_MEM:' "${GPU_MEM}"
REFERENCE_PERF=$(grep "${MATH},${BATCH_SIZE},${GPU_NAME}" \
${REFERENCE_FILE} | \cut -f 4 -d ',')
if [ -z "${REFERENCE_PERF}" ]; then
echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
TARGET_PERF=''
else
PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
fi
cd $REPO_DIR
export CUDA_VISIBLE_DEVICES=0
bash run_wt103_base.sh eval 1 \
--model checkpoint/checkpoint_best.pt \
--target_perplexity 23.4 \
--batch_size "${BATCH_SIZE}" \
--type "${TYPE}" \
"${MATH_OPT}" \
"${TARGET_PERF}"

View file

@@ -0,0 +1,6 @@
fp16,16,Tesla V100-SXM2-16GB,40000
fp32,16,Tesla V100-SXM2-16GB,18750
fp16,16,Tesla V100-SXM2-32GB,40000
fp32,16,Tesla V100-SXM2-32GB,18750
fp16,16,Tesla V100-SXM3-32GB,40000
fp32,16,Tesla V100-SXM3-32GB,18750

View file

@@ -0,0 +1,10 @@
fp16,4,Tesla V100-SXM2-16GB,126000
fp32,4,Tesla V100-SXM2-16GB,45000
fp16,4,Tesla V100-SXM2-32GB,126000
fp32,4,Tesla V100-SXM2-32GB,45000
fp16,8,Tesla V100-SXM2-16GB,233000
fp32,8,Tesla V100-SXM2-16GB,88000
fp16,8,Tesla V100-SXM2-32GB,233000
fp32,8,Tesla V100-SXM2-32GB,88000
fp16,16,Tesla V100-SXM3-32GB,356000
fp32,16,Tesla V100-SXM3-32GB,176000

View file

@@ -0,0 +1,78 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_training_throughput
MATH=$1
if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
exit 1
fi
if [[ ${MATH} == 'fp16' ]]; then
MATH_OPT='--fp16'
elif [[ ${MATH} == 'fp32' ]]; then
MATH_OPT=''
fi
PERF_TOLERANCE=0.9
GLOBAL_BATCH_SIZE=256
GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
echo 'GPU_NAME:' "${GPU_NAME}"
GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
echo 'GPU_COUNT:' "${GPU_COUNT}"
GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
echo 'GPU_MEM:' "${GPU_MEM}"
if (( GPU_MEM > 16500 )); then
LOCAL_BATCH_SIZE=32
else
if [[ ${MATH} == 'fp16' ]]; then
LOCAL_BATCH_SIZE=32
elif [[ ${MATH} == 'fp32' ]]; then
LOCAL_BATCH_SIZE=16
fi
fi
BATCH_CHUNK=$((GLOBAL_BATCH_SIZE / (GPU_COUNT * LOCAL_BATCH_SIZE)))
BATCH_CHUNK=$((BATCH_CHUNK < 1 ? 1 : BATCH_CHUNK))
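# Example (illustrative numbers): 8 GPUs with LOCAL_BATCH_SIZE=32 give
# BATCH_CHUNK = 256 / (8 * 32) = 1, while 4 GPUs in fp32 (LOCAL_BATCH_SIZE=16)
# give BATCH_CHUNK = 256 / (4 * 16) = 4, i.e. each per-GPU batch is processed in
# 4 chunks (--batch_chunk) so that the global batch size stays at 256.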
REFERENCE_PERF=$(grep "${MATH},${GPU_COUNT},${GPU_NAME}" \
${REFERENCE_FILE} | \cut -f 4 -d ',')
if [ -z "${REFERENCE_PERF}" ]; then
echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
TARGET_PERF=''
else
PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
fi
cd $REPO_DIR
bash run_wt103_base.sh train "${GPU_COUNT}" \
--debug \
--max_step $((256 / GPU_COUNT)) \
--batch_chunk "${BATCH_CHUNK}" \
--log_interval 1 \
--adaptive \
--vocab word \
"${MATH_OPT}" \
"${TARGET_PERF}"

View file

@@ -0,0 +1,79 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_training_throughput
MATH=$1
if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
exit 1
fi
if [[ ${MATH} == 'fp16' ]]; then
MATH_OPT='--fp16'
elif [[ ${MATH} == 'fp32' ]]; then
MATH_OPT=''
fi
PERF_TOLERANCE=0.9
GLOBAL_BATCH_SIZE=256
GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
echo 'GPU_NAME:' "${GPU_NAME}"
GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
echo 'GPU_COUNT:' "${GPU_COUNT}"
GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
echo 'GPU_MEM:' "${GPU_MEM}"
if (( GPU_MEM > 16500 )); then
LOCAL_BATCH_SIZE=32
else
if [[ ${MATH} == 'fp16' ]]; then
LOCAL_BATCH_SIZE=32
elif [[ ${MATH} == 'fp32' ]]; then
LOCAL_BATCH_SIZE=16
fi
fi
BATCH_CHUNK=$((GLOBAL_BATCH_SIZE / (GPU_COUNT * LOCAL_BATCH_SIZE)))
BATCH_CHUNK=$((BATCH_CHUNK < 1 ? 1 : BATCH_CHUNK))
REFERENCE_PERF=$(grep "${MATH},${GPU_COUNT},${GPU_NAME}" \
${REFERENCE_FILE} | \cut -f 4 -d ',')
if [ -z "${REFERENCE_PERF}" ]; then
echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
TARGET_PERF=''
else
PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
fi
cd $REPO_DIR
bash run_wt103_base.sh train "${GPU_COUNT}" \
--debug \
--max_step 40000 \
--target_perplexity 23.4 \
--batch_chunk "${BATCH_CHUNK}" \
--log_interval 1 \
--adaptive \
--vocab word \
"${MATH_OPT}" \
"${TARGET_PERF}"

View file

@@ -0,0 +1,80 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_training_throughput
MATH=$1
if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
exit 1
fi
if [[ ${MATH} == 'fp16' ]]; then
MATH_OPT='--fp16'
elif [[ ${MATH} == 'fp32' ]]; then
MATH_OPT=''
fi
PERF_TOLERANCE=0.9
GLOBAL_BATCH_SIZE=256
GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
echo 'GPU_NAME:' "${GPU_NAME}"
GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
echo 'GPU_COUNT:' "${GPU_COUNT}"
GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
echo 'GPU_MEM:' "${GPU_MEM}"
if (( GPU_MEM > 16500 )); then
LOCAL_BATCH_SIZE=32
else
if [[ ${MATH} == 'fp16' ]]; then
LOCAL_BATCH_SIZE=32
elif [[ ${MATH} == 'fp32' ]]; then
LOCAL_BATCH_SIZE=16
fi
fi
BATCH_CHUNK=$((GLOBAL_BATCH_SIZE / (GPU_COUNT * LOCAL_BATCH_SIZE)))
BATCH_CHUNK=$((BATCH_CHUNK < 1 ? 1 : BATCH_CHUNK))
REFERENCE_PERF=$(grep "${MATH},${GPU_COUNT},${GPU_NAME}" \
${REFERENCE_FILE} | \cut -f 4 -d ',')
if [ -z "${REFERENCE_PERF}" ]; then
echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
TARGET_PERF=''
else
PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
fi
cd $REPO_DIR
bash run_wt103_base.sh train "${GPU_COUNT}" \
--debug \
--max_step 30000 \
--max_step_scheduler 40000 \
--target_perplexity 24.2 \
--batch_chunk "${BATCH_CHUNK}" \
--log_interval 1 \
--adaptive \
--vocab word \
"${MATH_OPT}" \
"${TARGET_PERF}"

View file

@ -0,0 +1,80 @@
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_training_throughput
MATH=$1
if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
exit 1
fi
if [[ ${MATH} == 'fp16' ]]; then
MATH_OPT='--fp16'
elif [[ ${MATH} == 'fp32' ]]; then
MATH_OPT=''
fi
PERF_TOLERANCE=0.9
GLOBAL_BATCH_SIZE=256
GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
echo 'GPU_NAME:' "${GPU_NAME}"
GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
echo 'GPU_COUNT:' "${GPU_COUNT}"
GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
echo 'GPU_MEM:' "${GPU_MEM}"
if (( GPU_MEM > 16500 )); then
LOCAL_BATCH_SIZE=32
else
if [[ ${MATH} == 'fp16' ]]; then
LOCAL_BATCH_SIZE=32
elif [[ ${MATH} == 'fp32' ]]; then
LOCAL_BATCH_SIZE=16
fi
fi
BATCH_CHUNK=$((GLOBAL_BATCH_SIZE / (GPU_COUNT * LOCAL_BATCH_SIZE)))
BATCH_CHUNK=$((BATCH_CHUNK < 1 ? 1 : BATCH_CHUNK))
REFERENCE_PERF=$(grep "${MATH},${GPU_COUNT},${GPU_NAME}" \
${REFERENCE_FILE} | \cut -f 4 -d ',')
if [ -z "${REFERENCE_PERF}" ]; then
echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
TARGET_PERF=''
else
PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
fi
cd $REPO_DIR
bash run_wt103_base.sh train "${GPU_COUNT}" \
--debug \
--max_step 5000 \
--max_step_scheduler 40000 \
--target_perplexity 43.5 \
--batch_chunk "${BATCH_CHUNK}" \
--log_interval 1 \
--adaptive \
--vocab word \
"${MATH_OPT}" \
"${TARGET_PERF}"

View file

@@ -0,0 +1,810 @@
# coding: utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
import itertools
import logging
import math
import os
import time
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from apex.parallel import DistributedDataParallel
import lamb
import utils
from apex import amp
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.data_parallel import BalancedDataParallel
from utils.exp_utils import create_exp_dir
from utils.exp_utils import benchmark
from utils.exp_utils import AverageMeter
def parse_args():
parser = argparse.ArgumentParser(
description='PyTorch Transformer-XL Language Model',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
general = parser.add_argument_group('general setup')
general.add_argument('--work_dir', default='LM-TFM', type=str,
help='Directory for the results')
general.add_argument('--append_dataset', action='store_true',
help='Automatically append dataset name to work_dir')
general.add_argument('--append_time', action='store_true',
help='Automatically append current time to work_dir')
general.add_argument('--cuda', action='store_true',
help='Use CUDA')
general.add_argument('--fp16', action='store_true',
help='Run training in fp16/mixed precision')
general.add_argument('--restart', type=str, default='',
help='Restart training from the saved checkpoint')
general.add_argument('--debug', action='store_true',
help='Run in debug mode (do not create exp dir)')
general.add_argument('--log_all_ranks', action='store_true',
help='Enable logging from all distributed ranks')
general.add_argument('--save-all', action='store_true',
help='Save all checkpoints')
general.add_argument('--log_interval', type=int, default=10,
help='Report interval')
general.add_argument('--target_throughput', type=float, default=None,
help='Target training throughput (for benchmarking)')
general.add_argument('--target_perplexity', type=float, default=None,
help='Target validation perplexity (for benchmarking)')
dataset = parser.add_argument_group('dataset setup')
dataset.add_argument('--data', type=str, default='../data/wikitext-103',
help='Location of the data corpus')
dataset.add_argument('--dataset', type=str, default='wt103',
choices=['wt103', 'lm1b', 'enwik8', 'text8'],
help='Dataset name')
dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
help='Type of vocabulary')
model = parser.add_argument_group('model setup')
model.add_argument('--n_layer', type=int, default=16,
help='Number of total layers')
model.add_argument('--n_head', type=int, default=8,
help='Number of heads')
model.add_argument('--d_head', type=int, default=64,
help='Head dimension')
model.add_argument('--d_embed', type=int, default=-1,
help='Embedding dimension')
model.add_argument('--d_model', type=int, default=512,
help='Model dimension')
model.add_argument('--d_inner', type=int, default=2048,
help='Inner dimension in feedforward layer')
model.add_argument('--dropout', type=float, default=0.1,
help='Global dropout rate')
model.add_argument('--dropatt', type=float, default=0.0,
help='Attention probability dropout rate')
model.add_argument('--pre_lnorm', action='store_true',
help='Apply LayerNorm to the input instead of the output')
model.add_argument('--attn_type', type=int, default=0,
help='Attention type. 0 for ours, 1 for Shaw et al,'
                             ' 2 for Vaswani et al, 3 for Al Rfou et al.')
model.add_argument('--not_tied', action='store_true',
help='Do not tie the word embedding and softmax weights')
model.add_argument('--clamp_len', type=int, default=-1,
help='Use the same pos embeddings after clamp_len')
model.add_argument('--adaptive', action='store_true',
help='Use adaptive softmax')
model.add_argument('--div_val', type=int, default=1,
                       help='Divisor value for adaptive input and softmax')
model.add_argument('--sample_softmax', type=int, default=-1,
help='Number of samples in sampled softmax')
model.add_argument('--init', default='normal', type=str,
help='Parameter initializer to use')
model.add_argument('--emb_init', default='normal', type=str,
help='Parameter initializer to use')
model.add_argument('--init_range', type=float, default=0.1,
help='Parameters initialized by U(-init_range, init_range)')
model.add_argument('--emb_init_range', type=float, default=0.01,
                       help='Parameters initialized by U(-emb_init_range, emb_init_range)')
model.add_argument('--init_std', type=float, default=0.02,
help='Parameters initialized by N(0, init_std)')
model.add_argument('--proj_init_std', type=float, default=0.01,
                       help='Parameters initialized by N(0, proj_init_std)')
opt = parser.add_argument_group('optimizer setup')
opt.add_argument('--optim', default='jitlamb', type=str,
choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
help='Optimizer to use')
opt.add_argument('--lr', type=float, default=0.01,
help='Initial learning rate')
opt.add_argument('--mom', type=float, default=0.0,
help='Momentum for sgd')
opt.add_argument('--scheduler', default='cosine', type=str,
choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
help='LR scheduler to use')
opt.add_argument('--max_step_scheduler', type=int, default=None,
help='Max number of training steps for LR scheduler')
opt.add_argument('--warmup_step', type=int, default=1000,
help='Number of iterations for LR warmup')
opt.add_argument('--decay_rate', type=float, default=0.5,
help='Decay factor when ReduceLROnPlateau is used')
opt.add_argument('--lr_min', type=float, default=0.0,
help='Minimum learning rate during annealing')
opt.add_argument('--clip', type=float, default=0.25,
help='Gradient clipping')
opt.add_argument('--weight_decay', type=float, default=0.0,
help='Weight decay for adam|lamb')
opt.add_argument('--clip_nonemb', action='store_true',
help='Only clip the gradient of non-embedding params')
opt.add_argument('--patience', type=int, default=0,
                     help='Patience for the dev_perf (ReduceLROnPlateau) scheduler')
opt.add_argument('--eta_min', type=float, default=0.001,
help='Min learning rate for cosine scheduler')
training = parser.add_argument_group('training setup')
training.add_argument('--max_step', type=int, default=40000,
help='Max number of training steps')
training.add_argument('--batch_size', type=int, default=256,
help='Global batch size')
training.add_argument('--batch_chunk', type=int, default=1,
help='Split batch into chunks to save memory')
training.add_argument('--roll', action='store_true',
help='Enable random shifts within each data stream')
training.add_argument('--tgt_len', type=int, default=192,
help='Number of tokens to predict')
training.add_argument('--ext_len', type=int, default=0,
help='Length of the extended context')
training.add_argument('--mem_len', type=int, default=192,
                          help='Length of the retained previous hidden states (memory)')
training.add_argument('--seed', type=int, default=1111,
help='Random seed')
training.add_argument('--multi_gpu', default=None, type=str,
choices=['ddp', 'dp'],
help='Use multiple GPU')
training.add_argument('--gpu0_bsz', type=int, default=-1,
help='Batch size on gpu 0 (for "dp" backend)')
training.add_argument('--same_length', action='store_true',
help='Use the same attn length for all tokens')
training.add_argument('--varlen', action='store_true',
help='Use variable length')
val = parser.add_argument_group('validation setup')
val.add_argument('--eval_tgt_len', type=int, default=192,
help='Number of tokens to predict for evaluation')
val.add_argument('--eval_batch_size', type=int, default=16,
help='Eval batch size')
val.add_argument('--eval_max_steps', type=int, default=-1,
help='Max eval steps')
val.add_argument('--eval_interval', type=int, default=5000,
help='Evaluation interval')
dist = parser.add_argument_group('distributed setup')
dist.add_argument('--local_rank', default=0, type=int,
help='Used for multi-process training. ' +
'Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'')
args = parser.parse_args()
args.tied = not args.not_tied
if args.d_embed < 0:
args.d_embed = args.d_model
assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0
return args
def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
train_step, best_val_loss, work_dir, name='checkpoint.pt'):
if args.fp16:
amp_state = amp.state_dict()
else:
amp_state = None
state = {
'args': args,
'model_config': model_config,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'vocab': vocab,
'amp_state': amp_state,
'train_step': train_step,
'best_val_loss': best_val_loss,
}
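    # sync_workers() yields this worker's rank and barriers all workers on
    # exit, so only rank 0 writes the file while the others wait for it.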
with utils.distributed.sync_workers() as rank:
path = os.path.join(work_dir, name)
logging.info(f'Saving checkpoint to {path}')
if rank == 0:
torch.save(state, path)
def load_checkpoint(path):
if os.path.isdir(path):
path = os.path.join(path, 'checkpoint_last.pt')
dst = f'cuda:{torch.cuda.current_device()}'
logging.info(f'Loading checkpoint from {path}')
checkpoint = torch.load(path, map_location=dst)
return checkpoint
def init_weight(weight, args):
if args.init == 'uniform':
nn.init.uniform_(weight, -args.init_range, args.init_range)
elif args.init == 'normal':
nn.init.normal_(weight, 0.0, args.init_std)
def init_bias(bias):
nn.init.constant_(bias, 0.0)
def weights_init(m, args):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
init_weight(m.weight, args)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
for i in range(len(m.emb_projs)):
if m.emb_projs[i] is not None:
nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
init_weight(m.weight, args)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
init_weight(m.cluster_weight, args)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
elif classname.find('LayerNorm') != -1:
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, args.init_std)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('TransformerLM') != -1:
if hasattr(m, 'r_emb'):
init_weight(m.r_emb, args)
if hasattr(m, 'r_w_bias'):
init_weight(m.r_w_bias, args)
if hasattr(m, 'r_r_bias'):
init_weight(m.r_r_bias, args)
if hasattr(m, 'r_bias'):
init_bias(m.r_bias)
def update_dropout(m, args):
classname = m.__class__.__name__
if classname.find('Dropout') != -1:
if hasattr(m, 'p'):
m.p = args.dropout
def update_dropatt(m, args):
if hasattr(m, 'dropatt'):
m.dropatt.p = args.dropatt
def evaluate(eval_iter, model, args):
# Turn on evaluation mode which disables dropout.
model.eval()
# If the model does not use memory at all, make the ext_len longer.
# Otherwise, make the mem_len longer and keep the ext_len the same.
if args.mem_len == 0:
model.reset_length(tgt_len=args.eval_tgt_len,
ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
mem_len=args.mem_len
)
else:
model.reset_length(tgt_len=args.eval_tgt_len,
ext_len=args.ext_len,
mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len,
)
# Evaluation
total_len, total_loss = 0, 0.
with torch.no_grad():
mems = None
for i, (data, target, seq_len) in enumerate(eval_iter):
if args.eval_max_steps > 0 and i >= args.eval_max_steps:
break
ret = model(data, target, mems)
loss, mems = ret[0], ret[1:]
loss = loss.mean()
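            # Weight each segment's mean loss by its length so the final
            # result is an average per-token loss.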
total_loss += seq_len * loss.float().item()
total_len += seq_len
# Switch back to the training mode
model.reset_length(tgt_len=args.tgt_len,
ext_len=args.ext_len,
mem_len=args.mem_len
)
model.train()
return total_loss / total_len
def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch, train_step,
best_val_loss, meters, args):
# Turn on training mode which enables dropout.
model.train()
train_loss = 0
target_tokens = 0
log_step = 0
log_start_time = time.time()
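    # Keep one memory per gradient-accumulation chunk; each chunk forms an
    # independent stream of consecutive segments.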
mems = [None for _ in range(args.batch_chunk)]
train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
for batch, (data, target, seq_len) in enumerate(train_iter):
log_step += 1
target_tokens += target.numel()
model.zero_grad()
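        # Gradient accumulation: split the global batch along the batch
        # dimension (dim 1) into batch_chunk pieces and accumulate gradients
        # over them before the optimizer step.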
data_chunks = torch.chunk(data, args.batch_chunk, 1)
target_chunks = torch.chunk(target, args.batch_chunk, 1)
for i in range(args.batch_chunk):
data_i = data_chunks[i].contiguous()
target_i = target_chunks[i].contiguous()
ret = para_model(data_i, target_i, mems[i])
loss, mems[i] = ret[0], ret[1:]
loss = loss.float().mean().type_as(loss) / args.batch_chunk
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
train_loss += loss.float().item()
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
if optimizer_sparse:
optimizer_sparse.step()
# step-wise learning rate annealing
train_step += 1
if args.scheduler in ['cosine', 'constant', 'dev_perf']:
# linear warmup stage
if train_step < args.warmup_step:
curr_lr = args.lr * train_step / args.warmup_step
optimizer.param_groups[0]['lr'] = curr_lr
if optimizer_sparse:
optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
else:
if args.scheduler == 'cosine':
scheduler.step(train_step)
if scheduler_sparse:
scheduler_sparse.step(train_step)
elif args.scheduler == 'inv_sqrt':
scheduler.step(train_step)
if train_step % args.log_interval == 0:
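            # Reduce metrics across workers: average the loss, take the
            # slowest worker's elapsed time and sum the per-worker throughput.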
cur_loss = train_loss / log_step
cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
train_loss = 0
elapsed = time.time() - log_start_time
avg_elapsed = elapsed / log_step
avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
log_start_time = time.time()
log_step = 0
lr = optimizer.param_groups[0]['lr']
throughput = target_tokens / elapsed
throughput = utils.distributed.all_reduce_item(throughput, op='sum')
meters['train_throughput'].update(throughput)
target_tokens = 0
log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
'| ms/batch {:5.1f} | tok/s {:>7d} | loss {:5.2f}'.format(
epoch,
train_step,
batch+1,
tr_iter.n_batch,
lr,
avg_elapsed * 1000,
int(throughput),
cur_loss,
)
if args.dataset in ['enwik8', 'text8']:
log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
else:
log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
logging.info(log_str)
if train_step % args.eval_interval == 0:
eval_start_time = time.time()
val_loss = evaluate(va_iter, model, args)
val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')
logging.info('-' * 100)
log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
'| valid loss {:5.2f}'.format(
train_step // args.eval_interval,
train_step,
(time.time() - eval_start_time),
val_loss,
)
if args.dataset in ['enwik8', 'text8']:
log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
else:
log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
logging.info(log_str)
logging.info('-' * 100)
# Save the model if the validation loss is the best we've seen so far.
if not best_val_loss or val_loss < best_val_loss:
best_val_loss = val_loss
if not args.debug:
name = 'checkpoint_best.pt'
save_checkpoint(args, model, model_config, optimizer,
scheduler, vocab, train_step,
best_val_loss, args.work_dir, name)
# Always save after eval if save_all is true and not debug
if not args.debug and args.save_all:
name = f'checkpoint_{train_step}.pt'
save_checkpoint(args, model, model_config, optimizer,
scheduler, vocab, train_step, best_val_loss,
args.work_dir, name)
# Save last checkpoint if not debug and not save_all
if not args.debug and not args.save_all:
name = 'checkpoint_last.pt'
save_checkpoint(args, model, model_config, optimizer,
scheduler, vocab, train_step, best_val_loss,
args.work_dir, name)
# dev-performance based learning rate annealing
if args.scheduler == 'dev_perf':
scheduler.step(val_loss)
if scheduler_sparse:
scheduler_sparse.step(val_loss)
# subtract eval time from timers for training
log_start_time += time.time() - eval_start_time
if train_step == args.max_step:
break
return train_step, best_val_loss
def main():
args = parse_args()
# Initialize device and distributed backend
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda' if args.cuda else 'cpu')
utils.distributed.init_distributed(args.cuda)
args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
args.dataset,
args.append_dataset,
args.append_time,
)
with utils.distributed.sync_workers() as rank:
if rank == 0:
create_exp_dir(args.work_dir,
scripts_to_save=['train.py', 'mem_transformer.py'],
debug=args.debug)
# Setup logging
if args.log_all_ranks:
log_file = f'log_rank_{utils.distributed.get_rank()}.log'
else:
log_file = f'log.log'
log_file = os.path.join(args.work_dir, log_file)
if args.debug:
log_file = os.devnull
utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
filename=log_file,
)
logging.info(args)
# Set the random seed manually for reproducibility.
np.random.seed(args.seed + utils.distributed.get_rank())
torch.manual_seed(args.seed + utils.distributed.get_rank())
###########################################################################
# Load data
###########################################################################
corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
ntokens = len(corpus.vocab)
vocab = corpus.vocab
args.n_token = ntokens
tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
va_iter = corpus.get_iterator('valid', args.eval_batch_size, args.eval_tgt_len,
device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.eval_batch_size, args.eval_tgt_len,
device=device, ext_len=args.ext_len)
# adaptive softmax / embedding
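    # cutoffs split the frequency-ordered vocabulary into clusters for the
    # adaptive softmax/embedding; tie_projs marks which clusters share their
    # projection between input embedding and output softmax (the leading
    # False corresponds to the head cluster).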
cutoffs, tie_projs = [], [False]
if args.adaptive:
assert args.dataset in ['wt103', 'lm1b']
if args.dataset == 'wt103':
cutoffs = [19997, 39997, 199997]
tie_projs += [True] * len(cutoffs)
elif args.dataset == 'lm1b':
cutoffs = [59997, 99997, 639997]
tie_projs += [False] * len(cutoffs)
###########################################################################
# Build the model
###########################################################################
model_config = {
'n_token': ntokens,
'n_layer': args.n_layer,
'n_head': args.n_head,
'd_model': args.d_model,
'd_head': args.d_head,
'd_inner': args.d_inner,
'dropout': args.dropout,
'dropatt': args.dropatt,
'dtype': None,
'tie_weight': args.tied,
'd_embed': args.d_embed,
'div_val': args.div_val,
'tie_projs': tie_projs,
'pre_lnorm': args.pre_lnorm,
'tgt_len': args.tgt_len,
'ext_len': args.ext_len,
'mem_len': args.mem_len,
'cutoffs': cutoffs,
'same_length': args.same_length,
'attn_type': args.attn_type,
'clamp_len': args.clamp_len,
'sample_softmax': args.sample_softmax,
}
model = MemTransformerLM(**model_config)
model.apply(functools.partial(weights_init, args=args))
# ensure embedding init is not overridden by out_layer in case of weight sharing
model.word_emb.apply(functools.partial(weights_init, args=args))
args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
# optimizer
if args.optim.lower() == 'sgd':
if args.sample_softmax > 0:
dense_params, sparse_params = [], []
for param in model.parameters():
if param.size() == model.word_emb.weight.size():
sparse_params.append(param)
else:
dense_params.append(param)
optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
else:
optimizer = optim.SGD(model.parameters(), lr=args.lr,
momentum=args.mom)
optimizer_sparse = None
elif args.optim.lower() == 'adam':
if args.sample_softmax > 0:
dense_params, sparse_params = [], []
for param in model.parameters():
if param.size() == model.word_emb.weight.size():
sparse_params.append(param)
else:
dense_params.append(param)
optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
optimizer = optim.Adam(dense_params, lr=args.lr,
weight_decay=args.weight_decay)
else:
optimizer = optim.Adam(model.parameters(), lr=args.lr,
weight_decay=args.weight_decay)
optimizer_sparse = None
elif args.optim.lower() == 'adagrad':
optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
optimizer_sparse = None
elif args.optim.lower() == 'lamb':
optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
weight_decay=args.weight_decay)
optimizer_sparse = None
elif args.optim.lower() == 'jitlamb':
optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
weight_decay=args.weight_decay)
optimizer_sparse = None
model = model.to(device)
if args.fp16:
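        # Apex AMP opt_level 'O2' casts the model to FP16 and keeps FP32
        # master weights inside the optimizer.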
model, optimizer = amp.initialize(
model,
optimizer,
opt_level='O2',
)
if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
para_model = DistributedDataParallel(model,
delay_allreduce=True,
)
elif args.multi_gpu == 'dp':
if args.gpu0_bsz >= 0:
para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
model, dim=1).to(device)
else:
para_model = nn.DataParallel(model, dim=1).to(device)
else:
para_model = model
# scheduler
if args.scheduler == 'cosine':
if args.max_step_scheduler:
max_step = args.max_step_scheduler
else:
max_step = args.max_step
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, max_step, eta_min=args.eta_min
)
if args.sample_softmax > 0:
scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
optimizer_sparse, max_step, eta_min=args.eta_min
)
else:
scheduler_sparse = None
elif args.scheduler == 'inv_sqrt':
# originally used for Transformer (in Attention is all you need)
def lr_lambda(step):
# return a multiplier instead of a learning rate
if step == 0 and args.warmup_step == 0:
return 1.
else:
return 1. / (step ** 0.5) if step > args.warmup_step \
else step / (args.warmup_step ** 1.5)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
elif args.scheduler == 'dev_perf':
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, factor=args.decay_rate, patience=args.patience,
min_lr=args.lr_min,
)
if args.sample_softmax > 0:
scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
optimizer_sparse, factor=args.decay_rate, patience=args.patience,
min_lr=args.lr_min,
)
else:
scheduler_sparse = None
elif args.scheduler == 'constant':
pass
logging.info('=' * 100)
for k, v in args.__dict__.items():
logging.info(' - {} : {}'.format(k, v))
logging.info('=' * 100)
logging.info('#params = {}'.format(args.n_all_param))
logging.info('#non emb params = {}'.format(args.n_nonemb_param))
train_step = 0
best_val_loss = None
if args.restart:
checkpoint = load_checkpoint(args.restart)
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
scheduler.load_state_dict(checkpoint['scheduler_state'])
if args.fp16:
amp.load_state_dict(checkpoint['amp_state'])
train_step = checkpoint['train_step']
best_val_loss = checkpoint['best_val_loss']
model.apply(functools.partial(update_dropout, args=args))
model.apply(functools.partial(update_dropatt, args=args))
meters = {}
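    # Exclude the first few iterations from the throughput average; they are
    # slower while the recurrence memory is still being filled.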
warmup = args.mem_len // args.tgt_len + 1
meters['train_throughput'] = AverageMeter(warmup=warmup)
###########################################################################
# Train
###########################################################################
# Loop over epochs.
# At any point you can hit Ctrl + C to break out of training early.
start_time = time.time()
try:
for epoch in itertools.count(start=1):
if args.roll:
tr_iter.roll()
train_step, best_val_loss = train(
tr_iter, va_iter, model, para_model, model_config, optimizer,
optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
train_step, best_val_loss, meters, args
)
if train_step == args.max_step:
logging.info('-' * 100)
logging.info('End of training')
break
except KeyboardInterrupt:
logging.info('-' * 100)
logging.info('Exiting from training early')
elapsed = time.time() - start_time
###########################################################################
# Test
###########################################################################
test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
if not args.debug and os.path.exists(test_path):
# Load the best saved model.
checkpoint = load_checkpoint(test_path)
model.load_state_dict(checkpoint['model_state'])
# Run on test data.
test_start_time = time.time()
test_loss = evaluate(te_iter, model, args)
test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
logging.info('=' * 100)
if args.dataset in ['enwik8', 'text8']:
logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
time.time() - test_start_time, test_loss, test_loss / math.log(2)))
else:
logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
time.time() - test_start_time, test_loss, math.exp(test_loss)))
logging.info('=' * 100)
logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')
if best_val_loss:
val_perplexity = math.exp(best_val_loss)
else:
val_perplexity = None
passed = benchmark(
target_perplexity=args.target_perplexity,
test_perplexity=val_perplexity,
target_throughput=args.target_throughput,
test_throughput=meters['train_throughput'].avg
)
if not passed:
sys.exit(1)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,16 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import distributed
from . import exp_utils

View file

@ -0,0 +1,90 @@
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaptiveLogSoftmax(nn.Module):
def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
super(AdaptiveLogSoftmax, self).__init__()
cutoffs = list(cutoffs)
if (cutoffs != sorted(cutoffs)) \
or (min(cutoffs) <= 0) \
or (max(cutoffs) >= (n_classes - 1)) \
or (len(set(cutoffs)) != len(cutoffs)) \
or any([int(c) != c for c in cutoffs]):
raise ValueError("cutoffs should be a sequence of unique, positive "
"integers sorted in an increasing order, where "
"each value is between 1 and n_classes-1")
self.in_features = in_features
self.n_classes = n_classes
self.cutoffs = cutoffs + [n_classes]
self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters
self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
self.keep_order = keep_order
def forward(self, hidden, target, weight, bias, keep_order=False):
if hidden.size(0) != target.size(0):
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')
head_weight = torch.cat(
[weight[:self.shortlist_size], self.cluster_weight], dim=0)
head_bias = torch.cat(
[bias[:self.shortlist_size], self.cluster_bias], dim=0)
head_logit = F.linear(hidden, head_weight, bias=head_bias)
head_logprob = F.log_softmax(head_logit, dim=1)
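        # The head distribution covers the shortlist tokens plus one
        # "cluster token" per tail cluster.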
nll = torch.zeros_like(target,
dtype=hidden.dtype, device=hidden.device)
offset = 0
cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):
l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]
mask_i = (target >= l_idx) & (target < h_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
target_i = target.index_select(0, indices_i) - l_idx
head_logprob_i = head_logprob.index_select(0, indices_i)
if i == 0:
logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
else:
weight_i = weight[l_idx:h_idx]
bias_i = bias[l_idx:h_idx]
hidden_i = hidden.index_select(0, indices_i)
tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
logprob_i = head_logprob_i[:, -i] \
+ tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
nll.index_copy_(0, indices_i, -logprob_i)
else:
nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
offset += logprob_i.size(0)
return nll

View file

@ -0,0 +1,91 @@
from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply
def scatter(inputs, target_gpus, chunk_sizes, dim=0):
r"""
Slices tensors into approximately equal chunks and
distributes them across given GPUs. Duplicates
references to objects that are not tensors.
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
try:
return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
except:
print('obj', obj.size())
print('dim', dim)
print('chunk_sizes', chunk_sizes)
quit()
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return list(map(list, zip(*map(scatter_map, obj))))
if isinstance(obj, dict) and len(obj) > 0:
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return [obj for targets in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
r"""Scatter with support for kwargs dictionary"""
inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
class BalancedDataParallel(DataParallel):
def __init__(self, gpu0_bsz, *args, **kwargs):
self.gpu0_bsz = gpu0_bsz
super().__init__(*args, **kwargs)
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
if self.gpu0_bsz == 0:
device_ids = self.device_ids[1:]
else:
device_ids = self.device_ids
inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids)
if self.gpu0_bsz == 0:
replicas = replicas[1:]
outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
return self.gather(outputs, self.output_device)
def parallel_apply(self, replicas, device_ids, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, device_ids)
def scatter(self, inputs, kwargs, device_ids):
bsz = inputs[0].size(self.dim)
num_dev = len(self.device_ids)
gpu0_bsz = self.gpu0_bsz
bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
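        # GPU 0 gets gpu0_bsz examples; the rest of the batch is split evenly
        # across the remaining GPUs, with any remainder distributed one
        # example at a time.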
if gpu0_bsz < bsz_unit:
chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
delta = bsz - sum(chunk_sizes)
for i in range(delta):
chunk_sizes[i + 1] += 1
if gpu0_bsz == 0:
chunk_sizes = chunk_sizes[1:]
else:
return super().scatter(inputs, kwargs, device_ids)
return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

View file

@ -0,0 +1,110 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
import torch
def init_distributed(cuda):
"""
Initializes distributed backend.
:param cuda: (bool) if True initializes nccl backend, if False initializes
gloo backend
"""
world_size = int(os.environ.get('WORLD_SIZE', 1))
distributed = (world_size > 1)
if distributed:
backend = 'nccl' if cuda else 'gloo'
torch.distributed.init_process_group(backend=backend,
init_method='env://')
assert torch.distributed.is_initialized()
return distributed
def barrier():
"""
    Call torch.distributed.barrier() if distributed training is in use
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.barrier()
def get_rank():
"""
Gets distributed rank or returns zero if distributed is not initialized.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
else:
rank = 0
return rank
def get_world_size():
"""
Gets total number of distributed workers or returns one if distributed is
not initialized.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
else:
world_size = 1
return world_size
def all_reduce_item(value, op='sum'):
"""
All-reduces single scalar value if distributed is in use
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
if op == 'sum' or op == 'mean':
dop = torch.distributed.ReduceOp.SUM
elif op == 'min':
dop = torch.distributed.ReduceOp.MIN
elif op == 'max':
dop = torch.distributed.ReduceOp.MAX
elif op == 'product':
dop = torch.distributed.ReduceOp.PRODUCT
else:
raise RuntimeError('Unsupported reduce op')
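        # NCCL reduces tensors that live on the GPU, gloo reduces CPU tensors.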
backend = torch.distributed.get_backend()
if backend == torch.distributed.Backend.NCCL:
device = torch.device('cuda')
elif backend == torch.distributed.Backend.GLOO:
device = torch.device('cpu')
else:
raise RuntimeError('Unsupported distributed backend')
tensor = torch.tensor(value, device=device)
torch.distributed.all_reduce(tensor, dop)
if op == 'mean':
tensor /= get_world_size()
ret = tensor.item()
else:
ret = value
return ret
@contextmanager
def sync_workers():
"""
Yields distributed rank and synchronizes all workers on exit.
"""
rank = get_rank()
yield rank
barrier()

View file

@ -0,0 +1,147 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import logging
import os
import shutil
import sys
import time
import utils
class AverageMeter:
"""
Computes and stores the average and current value
"""
def __init__(self, warmup=0, keep=False):
self.reset()
self.warmup = warmup
self.keep = keep
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self.iters = 0
self.vals = []
def update(self, val, n=1):
self.iters += 1
self.val = val
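        # The first `warmup` updates only set the current value; they are
        # excluded from the running average.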
if self.iters > self.warmup:
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.keep:
self.vals.append(val)
def benchmark(test_perplexity=None, target_perplexity=None,
test_throughput=None, target_throughput=None):
def test(achieved, target, name, higher_better=True):
passed = True
if target is not None and achieved is not None:
logging.info(f'{name} achieved: {achieved:.2f} '
f'target: {target:.2f}')
if higher_better:
result = (achieved >= target)
else:
result = (achieved <= target)
if result:
logging.info(f'{name} test passed')
else:
logging.info(f'{name} test failed')
passed = False
return passed
passed = True
passed &= test(test_perplexity, target_perplexity, 'Perplexity', False)
passed &= test(test_throughput, target_throughput, 'Throughput')
return passed
def setup_logging(log_all_ranks=True, filename=os.devnull, filemode='w'):
"""
Configures logging.
    By default, logs from all workers are printed to the console; entries are
    prefixed with "N: " where N is the rank of the worker. Logs printed to the
    console don't include timestamps.
    Full logs with timestamps are saved to the file given by `filename`.
"""
class RankFilter(logging.Filter):
def __init__(self, rank, log_all_ranks):
self.rank = rank
self.log_all_ranks = log_all_ranks
def filter(self, record):
record.rank = self.rank
if self.log_all_ranks:
return True
else:
return (self.rank == 0)
rank = utils.distributed.get_rank()
rank_filter = RankFilter(rank, log_all_ranks)
if log_all_ranks:
logging_format = "%(asctime)s - %(levelname)s - %(rank)s - %(message)s"
else:
logging_format = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG,
format=logging_format,
datefmt="%Y-%m-%d %H:%M:%S",
filename=filename,
filemode=filemode)
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
if log_all_ranks:
formatter = logging.Formatter('%(rank)s: %(message)s')
else:
formatter = logging.Formatter('%(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)
logging.getLogger('').addFilter(rank_filter)
def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
if debug:
return
os.makedirs(dir_path, exist_ok=True)
print('Experiment dir : {}'.format(dir_path))
if scripts_to_save is not None:
script_path = os.path.join(dir_path, 'scripts')
os.makedirs(script_path, exist_ok=True)
for script in scripts_to_save:
dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
shutil.copyfile(script, dst_file)
def build_work_dir_name(work_dir, dataset, append_dataset, append_time):
if append_dataset:
work_dir = '{}-{}'.format(work_dir, dataset)
if append_time:
now = int(time.time())
now_max = utils.distributed.all_reduce_item(now, op='max')
now_str = datetime.datetime.fromtimestamp(now_max).strftime('%Y%m%d-%H%M%S')
work_dir = os.path.join(work_dir, now_str)
return work_dir

View file

@ -0,0 +1,147 @@
import torch
from torch import nn
import numpy as np
class LogUniformSampler(object):
def __init__(self, range_max, n_sample):
"""
Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
`P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
expected count can be approximated by 1 - (1 - p)^n
and we use a numerically stable version -expm1(num_tries * log1p(-p))
Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
"""
with torch.no_grad():
self.range_max = range_max
log_indices = torch.arange(1., range_max+2., 1.).log_()
self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
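            # dist[i] = P(i) = (log(i + 2) - log(i + 1)) / log(range_max + 1)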
# print('P', self.dist.numpy().tolist()[-30:])
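            # log_q[i] = log(1 - (1 - P(i))^(2 * n_sample)), the approximate
            # log expected count used to correct the sampled logits.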
self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
self.n_sample = n_sample
def sample(self, labels):
"""
labels: [b1, b2]
Return
true_log_probs: [b1, b2]
samp_log_probs: [n_sample]
neg_samples: [n_sample]
"""
# neg_samples = torch.empty(0).long()
n_sample = self.n_sample
n_tries = 2 * n_sample
with torch.no_grad():
neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
device = labels.device
neg_samples = neg_samples.to(device)
true_log_probs = self.log_q[labels].to(device)
samp_log_probs = self.log_q[neg_samples].to(device)
return true_log_probs, samp_log_probs, neg_samples
def sample_logits(embedding, bias, labels, inputs, sampler):
"""
embedding: an nn.Embedding layer
bias: [n_vocab]
labels: [b1, b2]
inputs: [b1, b2, n_emb]
sampler: you may use a LogUniformSampler
Return
logits: [b1, b2, 1 + n_sample]
"""
true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
n_sample = neg_samples.size(0)
b1, b2 = labels.size(0), labels.size(1)
all_ids = torch.cat([labels.view(-1), neg_samples])
all_w = embedding(all_ids)
true_w = all_w[: -n_sample].view(b1, b2, -1)
sample_w = all_w[- n_sample:].view(n_sample, -1)
all_b = bias[all_ids]
true_b = all_b[: -n_sample].view(b1, b2)
sample_b = all_b[- n_sample:]
hit = (labels[:, :, None] == neg_samples).detach()
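    # Mask sampled ids that collide with the true label so they cannot act
    # as negatives for that position.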
true_logits = torch.einsum('ijk,ijk->ij',
[true_w, inputs]) + true_b - true_log_probs
sample_logits = torch.einsum('lk,ijk->ijl',
[sample_w, inputs]) + sample_b - samp_log_probs
sample_logits.masked_fill_(hit, -1e30)
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
return logits
# class LogUniformSampler(object):
# def __init__(self, range_max, unique=False):
# """
# Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
# `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
# """
# self.range_max = range_max
# log_indices = torch.arange(1., range_max+2., 1.).log_()
# self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
# self.unique = unique
# if self.unique:
# self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
# def sample(self, n_sample, labels):
# pos_sample, new_labels = labels.unique(return_inverse=True)
# n_pos_sample = pos_sample.size(0)
# n_neg_sample = n_sample - n_pos_sample
# if self.unique:
# self.exclude_mask.index_fill_(0, pos_sample, 1)
# sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
# self.exclude_mask.index_fill_(0, pos_sample, 0)
# else:
# sample_dist = self.dist
# neg_sample = torch.multinomial(sample_dist, n_neg_sample)
# sample = torch.cat([pos_sample, neg_sample])
# sample_prob = self.dist[sample]
# return new_labels, sample, sample_prob
if __name__ == '__main__':
S, B = 3, 4
n_vocab = 10000
n_sample = 5
H = 32
labels = torch.LongTensor(S, B).random_(0, n_vocab)
# sampler = LogUniformSampler(n_vocab, unique=False)
# new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
    sampler = LogUniformSampler(n_vocab, n_sample)
# true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
# print('true_probs', true_probs.numpy().tolist())
# print('samp_probs', samp_probs.numpy().tolist())
# print('neg_samples', neg_samples.numpy().tolist())
# print('sum', torch.sum(sampler.dist).item())
# assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
embedding = nn.Embedding(n_vocab, H)
bias = torch.zeros(n_vocab)
inputs = torch.Tensor(S, B, H).normal_()
    logits = sample_logits(embedding, bias, labels, inputs, sampler)
    print('logits', logits.detach().numpy().tolist())
    print('logits shape', logits.size())

View file

@ -0,0 +1,154 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class ProjectedAdaptiveLogSoftmax(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
keep_order=False):
super().__init__()
self.n_token = n_token
self.d_embed = d_embed
self.d_proj = d_proj
self.cutoffs = cutoffs + [n_token]
self.cutoff_ends = [0] + self.cutoffs
self.div_val = div_val
self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters
if self.n_clusters > 0:
self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
self.out_layers = nn.ModuleList()
if div_val == 1:
if d_proj != d_embed:
self.out_projs = nn.ParameterList()
for i in range(len(self.cutoffs)):
self.out_projs.append(
nn.Parameter(torch.Tensor(d_proj, d_embed))
)
else:
self.out_projs = [None] * len(self.cutoffs)
self.out_layers.append(nn.Linear(d_embed, n_token))
else:
self.out_projs = nn.ParameterList()
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i)
self.out_projs.append(
nn.Parameter(torch.Tensor(d_proj, d_emb_i))
)
self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
self.keep_order = keep_order
def _compute_logit(self, hidden, weight, bias, proj):
if proj is None:
logit = F.linear(hidden, weight, bias=bias)
else:
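            # Project the hidden state into this cluster's embedding space,
            # then score it against the cluster's output embedding.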
logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
if bias is not None:
logit = logit + bias
return logit
def forward(self, hidden, target, keep_order=False):
'''
hidden :: [len*bsz x d_proj]
target :: [len*bsz]
'''
if hidden.size(0) != target.size(0):
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')
if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers[0].weight,
self.out_layers[0].bias, self.out_projs[0])
nll = -F.log_softmax(logit, dim=-1) \
.gather(1, target.unsqueeze(1)).squeeze(1)
else:
# construct weights and biases
weights, biases = [], []
for i in range(len(self.cutoffs)):
if self.div_val == 1:
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
weight_i = self.out_layers[0].weight[l_idx:r_idx]
bias_i = self.out_layers[0].bias[l_idx:r_idx]
else:
weight_i = self.out_layers[i].weight
bias_i = self.out_layers[i].bias
if i == 0:
weight_i = torch.cat(
[weight_i, self.cluster_weight], dim=0)
bias_i = torch.cat(
[bias_i, self.cluster_bias], dim=0)
weights.append(weight_i)
biases.append(bias_i)
head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
head_logprob = F.log_softmax(head_logit, dim=1)
nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)
offset = 0
cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):
l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
mask_i = (target >= l_idx) & (target < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
target_i = target.index_select(0, indices_i) - l_idx
head_logprob_i = head_logprob.index_select(0, indices_i)
if i == 0:
logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
else:
weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
hidden_i = hidden.index_select(0, indices_i)
tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
logprob_i = head_logprob_i[:, -i] \
+ tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
if self.keep_order or keep_order:
nll.index_copy_(0, indices_i, -logprob_i)
else:
nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
offset += logprob_i.size(0)
return nll

View file

@ -0,0 +1,229 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
from collections import Counter, OrderedDict
import utils
import torch
class Vocab(object):
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
delimiter=None, vocab_file=None):
self.counter = Counter()
self.special = special
self.min_freq = min_freq
self.max_size = max_size
self.lower_case = lower_case
self.delimiter = delimiter
self.vocab_file = vocab_file
def tokenize(self, line, add_eos=False, add_double_eos=False):
line = line.strip()
# convert to lower case
if self.lower_case:
line = line.lower()
# empty delimiter '' will evaluate False
if self.delimiter == '':
symbols = line
else:
symbols = line.split(self.delimiter)
if add_double_eos: # lm1b
return ['<S>'] + symbols + ['<S>']
elif add_eos:
return symbols + ['<eos>']
else:
return symbols
def count_file(self, path, verbose=False, add_eos=False):
if verbose: print('counting file {} ...'.format(path))
assert os.path.exists(path)
sents = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
symbols = self.tokenize(line, add_eos=add_eos)
self.counter.update(symbols)
sents.append(symbols)
return sents
def count_sents(self, sents, verbose=False):
"""
sents : a list of sentences, each a list of tokenized symbols
"""
if verbose: print('counting {} sents ...'.format(len(sents)))
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
self.counter.update(symbols)
def _build_from_file(self, vocab_file):
self.idx2sym = []
self.sym2idx = OrderedDict()
with open(vocab_file, 'r', encoding='utf-8') as f:
for line in f:
symb = line.strip().split()[0]
self.add_symbol(symb)
self.unk_idx = self.sym2idx['<UNK>']
def build_vocab(self):
if self.vocab_file:
print('building vocab from {}'.format(self.vocab_file))
self._build_from_file(self.vocab_file)
print('final vocab size {}'.format(len(self)))
else:
print('building vocab with min_freq={}, max_size={}'.format(
self.min_freq, self.max_size))
self.idx2sym = []
self.sym2idx = OrderedDict()
for sym in self.special:
self.add_special(sym)
for sym, cnt in self.counter.most_common(self.max_size):
if cnt < self.min_freq: break
self.add_symbol(sym)
print('final vocab size {} from {} unique tokens'.format(
len(self), len(self.counter)))
def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
add_double_eos=False):
if verbose: print('encoding file {} ...'.format(path))
assert os.path.exists(path)
encoded = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
symbols = self.tokenize(line, add_eos=add_eos,
add_double_eos=add_double_eos)
encoded.append(self.convert_to_tensor(symbols))
if ordered:
encoded = torch.cat(encoded)
return encoded
def encode_sents(self, sents, ordered=False, verbose=False):
if verbose: print('encoding {} sents ...'.format(len(sents)))
encoded = []
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
print(' line {}'.format(idx))
encoded.append(self.convert_to_tensor(symbols))
if ordered:
encoded = torch.cat(encoded)
return encoded
def add_special(self, sym):
if sym not in self.sym2idx:
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1
setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
def add_symbol(self, sym):
if sym not in self.sym2idx:
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1
def get_sym(self, idx):
assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
return self.idx2sym[idx]
def get_idx(self, sym):
if sym in self.sym2idx:
return self.sym2idx[sym]
else:
# print('encounter unk {}'.format(sym))
assert '<eos>' not in sym
assert hasattr(self, 'unk_idx')
return self.sym2idx.get(sym, self.unk_idx)
def get_symbols(self, indices):
return [self.get_sym(idx) for idx in indices]
def get_indices(self, symbols):
return [self.get_idx(sym) for sym in symbols]
def convert_to_tensor(self, symbols):
return torch.LongTensor(self.get_indices(symbols))
def convert_to_sent(self, indices, exclude=None):
if exclude is None:
return ' '.join([self.get_sym(idx) for idx in indices])
else:
return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
def __len__(self):
return len(self.idx2sym)
# Class OpenAIVocab has been adapted from
# https://github.com/cybertronai/transformer-xl/blob/master/utils/vocabulary.py
class OpenAIVocab(Vocab):
def __init__(self, max_size=None, vocab_file=None):
from pytorch_transformers import GPT2Tokenizer
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
self.EOT = self.tokenizer.encoder['<|endoftext|>']
self.max_size = max_size
self.vocab_file = vocab_file
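        # Pad the vocabulary with dummy tokens so its size is a multiple of 8;
        # this keeps the embedding/softmax dimension friendly to Tensor Core
        # kernels in mixed precision.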
pad = 8
vocab_size = len(self.tokenizer)
padded_vocab_size = (vocab_size + pad - 1) // pad * pad
for i in range(0, padded_vocab_size - vocab_size):
token = f'madeupword{i:09d}'
self.tokenizer.add_tokens([token])
def __len__(self):
return len(self.tokenizer)
def count_file(self, path, verbose=False, add_eos=False):
# TODO: train from scratch, respect self.max_size
pass
def build_vocab(self):
pass
def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False) -> torch.LongTensor:
cached = path + '.bpe'
if os.path.exists(cached):
return torch.load(cached)
print(f'encoding file {path} ...')
assert os.path.exists(path), f"{path} doesn't exist"
with open(path, encoding='utf-8') as f:
# Suppress warnings about length.
with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull):
out = torch.LongTensor(self.tokenizer.encode(f.read()) + [self.EOT])
with utils.distributed.sync_workers() as rank:
if rank == 0:
torch.save(out, cached)
return out
def tokenize(self, line, add_eos=False, add_double_eos=False):
return self.tokenizer.encode(line)
def convert_to_tensor(self, symbols):
return torch.LongTensor(symbols)

View file

@ -26,6 +26,8 @@ The examples are organized first by framework, such as TensorFlow, PyTorch, etc.
- __GNMT__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Translation/GNMT)] [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Translation/GNMT)]
- __Transformer__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Translation/Transformer)]
- __BERT__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT)][[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT)]
- __TransformerXL__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/TransformerXL)]
### Recommender Systems
- __NCF__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Recommendation/NCF)] [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Recommendation/NCF)]