[TXL/PyT] update: (#989)

* updated torch.einsum calls to pass operands as separate tensor arguments instead of a packed list/tuple (see the sketch below)
* added export OMP_NUM_THREADS=1 to all launcher scripts
* added runtime checks to ensure that the launch configuration is valid
Szymon Migacz 2021-08-20 08:39:12 -07:00 committed by GitHub
parent 706ef498c9
commit ef98b2cef9
16 changed files with 61 additions and 27 deletions
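
The einsum hunks below all make the same change: operands are passed to torch.einsum as separate tensor arguments rather than packed into a list or tuple. A minimal sketch with made-up shapes, assuming the installed PyTorch still accepts the packed form for backward compatibility:

import torch

# Hypothetical shapes matching the 'ibnd,jbnd->bnij' pattern in mem_transformer.py.
head_q = torch.randn(4, 2, 3, 5)  # [qlen x bsz x n_head x d_head]
head_k = torch.randn(6, 2, 3, 5)  # [klen x bsz x n_head x d_head]

# Old call style: operands packed into a tuple.
old_style = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
# New call style: operands passed positionally, matching torch.einsum(equation, *operands).
new_style = torch.einsum('ibnd,jbnd->bnij', head_q, head_k)

assert torch.allclose(old_style, new_style)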

View file

@ -1113,7 +1113,11 @@ perplexity on the test dataset.
## Performance
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA's latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
The performance measurements in this document were conducted at the time of
publication and may not reflect the performance achieved from NVIDIA's latest
software release. For the most up-to-date performance measurements, go to
[NVIDIA Data Center Deep Learning Product
Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
### Benchmarking

View file

@ -122,7 +122,7 @@ class MultiHeadAttn(nn.Module):
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [bsz x n_head x qlen x klen]
attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
attn_score = torch.einsum('ibnd,jbnd->bnij', head_q, head_k)
attn_score.mul_(self.scale)
if attn_mask is not None:
if attn_mask.dim() == 2:
@ -135,7 +135,7 @@ class MultiHeadAttn(nn.Module):
attn_prob = self.dropatt(attn_prob)
# [bsz x n_head x qlen x klen] * [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, head_v))
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, head_v)
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
@ -262,13 +262,13 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
# compute attention score
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
# AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
# AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
rw_head_q = rw_head_q.view(qlen, bsz * self.n_head, self.d_head).permute(1, 0, 2)
w_head_k = w_head_k.reshape(klen, bsz * self.n_head, self.d_head).permute(1, 2, 0)
AC = torch.bmm(rw_head_q, w_head_k).view(bsz, self.n_head, qlen, klen)
rr_head_q = w_head_q + r_r_bias
# BD = torch.einsum('ibnd,jnd->bnij', (rr_head_q, r_head_k)) # bsz x n_head x qlen x klen
# BD = torch.einsum('ibnd,jnd->bnij', rr_head_q, r_head_k) # bsz x n_head x qlen x klen
rr_head_q = rr_head_q.permute(2, 1, 0, 3).reshape(self.n_head, bsz * qlen, self.d_head)
r_head_k = r_head_k.permute(1, 2, 0).view(self.n_head, self.d_head, klen)
BD = torch.bmm(rr_head_q, r_head_k).view(self.n_head, bsz, qlen, klen).permute(1, 0, 2, 3)
@ -290,7 +290,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
attn_prob = self.dropatt(attn_prob)
# compute attention vector
# attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
# attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
attn_prob = attn_prob.view(bsz * self.n_head, qlen, klen)
w_head_v = w_head_v.permute(1, 2, 0, 3).reshape(bsz * self.n_head, klen, self.d_head)
attn_vec = torch.bmm(attn_prob, w_head_v).permute(1, 0, 2).view(qlen, bsz, self.n_head, self.d_head)
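
In RelPartialLearnableMultiHeadAttn, the AC and BD einsum calls are kept only as comments; the surrounding code computes the same quantities with permute/reshape and torch.bmm. A minimal sketch with hypothetical shapes checking that the bmm formulation of the AC term matches the commented-out einsum:

import torch

qlen, klen, bsz, n_head, d_head = 4, 6, 2, 3, 5
rw_head_q = torch.randn(qlen, bsz, n_head, d_head)
w_head_k = torch.randn(klen, bsz, n_head, d_head)

# Reference: the commented-out einsum, giving [bsz x n_head x qlen x klen].
ac_einsum = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k)

# bmm formulation: fold (bsz, n_head) into the batch dimension, as in the code above.
q = rw_head_q.reshape(qlen, bsz * n_head, d_head).permute(1, 0, 2)  # [(bsz*n_head) x qlen x d_head]
k = w_head_k.reshape(klen, bsz * n_head, d_head).permute(1, 2, 0)   # [(bsz*n_head) x d_head x klen]
ac_bmm = torch.bmm(q, k).view(bsz, n_head, qlen, klen)

assert torch.allclose(ac_einsum, ac_bmm, atol=1e-6)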
@ -358,11 +358,11 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
r_bias = r_bias.t()
# compute attention score
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
B_ = torch.einsum('ibnd,jnd->bnij', (w_head_q, r_emb)) # bsz x n_head x qlen x klen
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
B_ = torch.einsum('ibnd,jnd->bnij', w_head_q, r_emb) # bsz x n_head x qlen x klen
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
BD = self._rel_shift(B_ + D_)
# [bsz x qlen x klen x n_head]
@ -381,7 +381,7 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
attn_prob = self.dropatt(attn_prob)
# compute attention vector
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(

View file

@ -146,7 +146,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if proj is None:
logit = F.linear(hidden, weight, bias=bias)
else:
logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
logit = torch.einsum('bd,de,ev->bv', hidden, proj, weight.t())
if bias is not None:
logit = logit + bias
return logit
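
The 'bd,de,ev->bv' einsum above applies the projection and the transposed embedding weight in a single call. A minimal sketch with hypothetical sizes checking it against plain matrix multiplications:

import torch

# Hypothetical sizes: batch, hidden dim, projection dim, vocabulary size.
b, d, e, v = 3, 8, 4, 10
hidden = torch.randn(b, d)
proj = torch.randn(d, e)
weight = torch.randn(v, e)  # embedding-style weight, [vocab x proj_dim]

logit_einsum = torch.einsum('bd,de,ev->bv', hidden, proj, weight.t())
logit_matmul = (hidden @ proj) @ weight.t()

assert torch.allclose(logit_einsum, logit_matmul, atol=1e-5)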

View file

@ -125,7 +125,7 @@ class MultiHeadAttn(nn.Module):
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [bsz x n_head x qlen x klen]
attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
attn_score = torch.einsum('ibnd,jbnd->bnij', head_q, head_k)
attn_score.mul_(self.scale)
if attn_mask is not None:
if attn_mask.dim() == 2:
@ -138,7 +138,7 @@ class MultiHeadAttn(nn.Module):
attn_prob = self.dropatt(attn_prob)
# [bsz x n_head x qlen x klen] * [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, head_v))
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, head_v)
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
@ -264,10 +264,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
# compute attention score
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
rr_head_q = w_head_q + r_r_bias
BD = torch.einsum('ibnd,jnd->bnij', (rr_head_q, r_head_k)) # bsz x n_head x qlen x klen
BD = torch.einsum('ibnd,jnd->bnij', rr_head_q, r_head_k) # bsz x n_head x qlen x klen
BD = self._rel_shift(BD)
# [bsz x n_head x qlen x klen]
@ -285,7 +285,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
attn_prob = self.dropatt(attn_prob)
# compute attention vector
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
@ -350,11 +350,11 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
r_bias = r_bias.t()
# compute attention score
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
B_ = torch.einsum('ibnd,jnd->bnij', (w_head_q, r_emb)) # bsz x n_head x qlen x klen
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
B_ = torch.einsum('ibnd,jnd->bnij', w_head_q, r_emb) # bsz x n_head x qlen x klen
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
BD = self._rel_shift(B_ + D_)
# [bsz x qlen x klen x n_head]
@ -372,7 +372,7 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
attn_prob = self.dropatt(attn_prob)
# compute attention vector
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(

View file

@ -1,5 +1,7 @@
#!/bin/bash
export OMP_NUM_THREADS=1
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
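
Each launcher script now exports OMP_NUM_THREADS=1 before starting training, so OpenMP thread pools on the CPU do not oversubscribe cores when several worker processes share a node. A minimal sketch of applying the same limit from inside Python, assuming it runs at the very top of the entry script:

import os

# Must be set before torch (and its OpenMP runtime) is imported and initialized.
os.environ.setdefault('OMP_NUM_THREADS', '1')

import torch

# Additionally cap intra-op CPU parallelism on the PyTorch side.
torch.set_num_threads(1)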

View file

@ -1,5 +1,7 @@
#!/bin/bash
export OMP_NUM_THREADS=1
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \

View file

@ -1,5 +1,7 @@
#!/bin/bash
export OMP_NUM_THREADS=1
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \

View file

@ -1,5 +1,7 @@
#!/bin/bash
export OMP_NUM_THREADS=1
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \

View file

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
export OMP_NUM_THREADS=1
if [[ $1 == 'train' ]] || [[ $1 == 'all' ]]; then
echo 'Run training...'
python train.py \

View file

@ -1,5 +1,7 @@
#!/bin/bash
export OMP_NUM_THREADS=1
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \

View file

@ -1,5 +1,7 @@
#!/bin/bash
export OMP_NUM_THREADS=1
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \

View file

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
export OMP_NUM_THREADS=1
if [[ "$1" == 'train' ]]; then
echo 'Run training...'
python -m torch.distributed.launch --nproc_per_node="$2" train.py \

View file

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
export OMP_NUM_THREADS=1
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python -m torch.distributed.launch --nproc_per_node="$2" train.py \

View file

@ -285,6 +285,13 @@ def parse_args():
if args.batch_size % args.batch_chunk != 0:
raise RuntimeError('Batch size needs to be divisible by batch chunk')
if (
args.local_batch_size is not None
and args.local_batch_size % args.batch_chunk != 0
):
raise RuntimeError('Local batch size needs to be divisible by '
'batch chunk')
if args.fp16 and args.amp == 'apex' and 'apex' not in sys.modules:
raise RuntimeError(
'APEX AMP unavailable, install APEX or switch to pytorch AMP'
@ -444,8 +451,10 @@ def evaluate(eval_iter, model, args):
for i, (data, target, seq_len, warm) in enumerate(eval_iter):
if args.eval_max_steps > 0 and i >= args.eval_max_steps:
break
loss, mems = model(data, target, mems)
loss = loss.float().mean()
enable_autocast = args.fp16 and args.amp == 'pytorch'
with torch.cuda.amp.autocast(enable_autocast):
loss, mems = model(data, target, mems)
loss = loss.float().mean().type_as(loss)
if warm:
# assert (mems is None) or mems.size(1) == model.mem_len
total_loss += seq_len * loss.item()
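
The evaluation loop above now runs the forward pass under torch.cuda.amp.autocast, enabled only when --fp16 is combined with the native PyTorch AMP backend. A minimal sketch of that pattern, with eval_step as a hypothetical helper standing in for the loop body:

import torch

def eval_step(model, data, target, mems, args):
    # Autocast only when --fp16 with the native ('pytorch') AMP backend is requested.
    enable_autocast = args.fp16 and args.amp == 'pytorch'
    with torch.cuda.amp.autocast(enable_autocast):
        loss, mems = model(data, target, mems)
    # Reduce in fp32 for numerical stability, then cast back to the loss dtype.
    loss = loss.float().mean().type_as(loss)
    return loss, mems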
@ -735,6 +744,9 @@ def main():
args.batch_size = world_size * args.local_batch_size
logging.info(f'--local_batch_size was set, adjusting global batch size'
f' to {args.batch_size} (local_batch_size * world_size)')
if args.batch_size % args.batch_chunk != 0:
raise RuntimeError('Batch size needs to be divisible by '
'batch chunk')
if args.profile:
try:

View file

@ -70,9 +70,9 @@ def sample_logits(embedding, bias, labels, inputs, sampler):
hit = (labels[:, :, None] == neg_samples).detach()
true_logits = torch.einsum('ijk,ijk->ij',
[true_w, inputs]) + true_b - true_log_probs
true_w, inputs) + true_b - true_log_probs
sample_logits = torch.einsum('lk,ijk->ijl',
[sample_w, inputs]) + sample_b - samp_log_probs
sample_w, inputs) + sample_b - samp_log_probs
sample_logits.masked_fill_(hit, -1e30)
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)

View file

@ -112,7 +112,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
if proj is None:
logit = F.linear(hidden, weight, bias=bias)
else:
logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
logit = torch.einsum('bd,de,ev->bv', hidden, proj, weight.t())
if bias is not None:
logit = logit + bias
return logit