[TXL/PyT] update: (#989)
* changed API calls to torch.einsum * added export OMP_NUM_THREADS=1 to all launcher scripts * additional runtime checks to ensure that launch configuration is valid
This commit is contained in:
parent
706ef498c9
commit
ef98b2cef9
|
@ -1113,7 +1113,11 @@ perplexity on the test dataset.
|
|||
|
||||
## Performance
|
||||
|
||||
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
|
||||
The performance measurements in this document were conducted at the time of
|
||||
publication and may not reflect the performance achieved from NVIDIA’s latest
|
||||
software release. For the most up-to-date performance measurements, go to
|
||||
[NVIDIA Data Center Deep Learning Product
|
||||
Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
|
||||
|
||||
### Benchmarking
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ class MultiHeadAttn(nn.Module):
|
|||
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
|
||||
|
||||
# [bsz x n_head x qlen x klen]
|
||||
attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
|
||||
attn_score = torch.einsum('ibnd,jbnd->bnij', head_q, head_k)
|
||||
attn_score.mul_(self.scale)
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dim() == 2:
|
||||
|
@ -135,7 +135,7 @@ class MultiHeadAttn(nn.Module):
|
|||
attn_prob = self.dropatt(attn_prob)
|
||||
|
||||
# [bsz x n_head x qlen x klen] * [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, head_v))
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, head_v)
|
||||
attn_vec = attn_vec.contiguous().view(
|
||||
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
|
||||
|
||||
|
@ -262,13 +262,13 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
|
||||
# compute attention score
|
||||
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
|
||||
# AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
|
||||
# AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
|
||||
rw_head_q = rw_head_q.view(qlen, bsz * self.n_head, self.d_head).permute(1, 0, 2)
|
||||
w_head_k = w_head_k.reshape(klen, bsz * self.n_head, self.d_head).permute(1, 2, 0)
|
||||
AC = torch.bmm(rw_head_q, w_head_k).view(bsz, self.n_head, qlen, klen)
|
||||
|
||||
rr_head_q = w_head_q + r_r_bias
|
||||
# BD = torch.einsum('ibnd,jnd->bnij', (rr_head_q, r_head_k)) # bsz x n_head x qlen x klen
|
||||
# BD = torch.einsum('ibnd,jnd->bnij', rr_head_q, r_head_k) # bsz x n_head x qlen x klen
|
||||
rr_head_q = rr_head_q.permute(2, 1, 0, 3).reshape(self.n_head, bsz * qlen, self.d_head)
|
||||
r_head_k = r_head_k.permute(1, 2, 0).view(self.n_head, self.d_head, klen)
|
||||
BD = torch.bmm(rr_head_q, r_head_k).view(self.n_head, bsz, qlen, klen).permute(1, 0, 2, 3)
|
||||
|
@ -290,7 +290,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
attn_prob = self.dropatt(attn_prob)
|
||||
|
||||
# compute attention vector
|
||||
# attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
|
||||
# attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
|
||||
attn_prob = attn_prob.view(bsz * self.n_head, qlen, klen)
|
||||
w_head_v = w_head_v.permute(1, 2, 0, 3).reshape(bsz * self.n_head, klen, self.d_head)
|
||||
attn_vec = torch.bmm(attn_prob, w_head_v).permute(1, 0, 2).view(qlen, bsz, self.n_head, self.d_head)
|
||||
|
@ -358,11 +358,11 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
r_bias = r_bias.t()
|
||||
|
||||
# compute attention score
|
||||
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
|
||||
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
|
||||
|
||||
AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
|
||||
B_ = torch.einsum('ibnd,jnd->bnij', (w_head_q, r_emb)) # bsz x n_head x qlen x klen
|
||||
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
|
||||
AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
|
||||
B_ = torch.einsum('ibnd,jnd->bnij', w_head_q, r_emb) # bsz x n_head x qlen x klen
|
||||
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
|
||||
BD = self._rel_shift(B_ + D_)
|
||||
|
||||
# [bsz x qlen x klen x n_head]
|
||||
|
@ -381,7 +381,7 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
attn_prob = self.dropatt(attn_prob)
|
||||
|
||||
# compute attention vector
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
|
||||
|
||||
# [qlen x bsz x n_head x d_head]
|
||||
attn_vec = attn_vec.contiguous().view(
|
||||
|
|
|
@ -146,7 +146,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
|
|||
if proj is None:
|
||||
logit = F.linear(hidden, weight, bias=bias)
|
||||
else:
|
||||
logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
|
||||
logit = torch.einsum('bd,de,ev->bv', hidden, proj, weight.t())
|
||||
if bias is not None:
|
||||
logit = logit + bias
|
||||
return logit
|
||||
|
|
|
@ -125,7 +125,7 @@ class MultiHeadAttn(nn.Module):
|
|||
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
|
||||
|
||||
# [bsz x n_head x qlen x klen]
|
||||
attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
|
||||
attn_score = torch.einsum('ibnd,jbnd->bnij', head_q, head_k)
|
||||
attn_score.mul_(self.scale)
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dim() == 2:
|
||||
|
@ -138,7 +138,7 @@ class MultiHeadAttn(nn.Module):
|
|||
attn_prob = self.dropatt(attn_prob)
|
||||
|
||||
# [bsz x n_head x qlen x klen] * [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, head_v))
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, head_v)
|
||||
attn_vec = attn_vec.contiguous().view(
|
||||
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
|
||||
|
||||
|
@ -264,10 +264,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
|
||||
# compute attention score
|
||||
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
|
||||
AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
|
||||
AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
|
||||
|
||||
rr_head_q = w_head_q + r_r_bias
|
||||
BD = torch.einsum('ibnd,jnd->bnij', (rr_head_q, r_head_k)) # bsz x n_head x qlen x klen
|
||||
BD = torch.einsum('ibnd,jnd->bnij', rr_head_q, r_head_k) # bsz x n_head x qlen x klen
|
||||
BD = self._rel_shift(BD)
|
||||
|
||||
# [bsz x n_head x qlen x klen]
|
||||
|
@ -285,7 +285,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
attn_prob = self.dropatt(attn_prob)
|
||||
|
||||
# compute attention vector
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
|
||||
|
||||
# [qlen x bsz x n_head x d_head]
|
||||
attn_vec = attn_vec.contiguous().view(
|
||||
|
@ -350,11 +350,11 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
r_bias = r_bias.t()
|
||||
|
||||
# compute attention score
|
||||
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
|
||||
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
|
||||
|
||||
AC = torch.einsum('ibnd,jbnd->bnij', (rw_head_q, w_head_k)) # bsz x n_head x qlen x klen
|
||||
B_ = torch.einsum('ibnd,jnd->bnij', (w_head_q, r_emb)) # bsz x n_head x qlen x klen
|
||||
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
|
||||
AC = torch.einsum('ibnd,jbnd->bnij', rw_head_q, w_head_k) # bsz x n_head x qlen x klen
|
||||
B_ = torch.einsum('ibnd,jnd->bnij', w_head_q, r_emb) # bsz x n_head x qlen x klen
|
||||
D_ = r_bias[None, :, None, :] # 1 x n_head x 1 x klen
|
||||
BD = self._rel_shift(B_ + D_)
|
||||
|
||||
# [bsz x qlen x klen x n_head]
|
||||
|
@ -372,7 +372,7 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
attn_prob = self.dropatt(attn_prob)
|
||||
|
||||
# compute attention vector
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', (attn_prob, w_head_v))
|
||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, w_head_v)
|
||||
|
||||
# [qlen x bsz x n_head x d_head]
|
||||
attn_vec = attn_vec.contiguous().view(
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
if [[ $1 == 'train' ]]; then
|
||||
echo 'Run training...'
|
||||
python train.py \
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
if [[ $1 == 'train' ]]; then
|
||||
echo 'Run training...'
|
||||
python train.py \
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
if [[ $1 == 'train' ]]; then
|
||||
echo 'Run training...'
|
||||
python train.py \
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
if [[ $1 == 'train' ]]; then
|
||||
echo 'Run training...'
|
||||
python train.py \
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
if [[ $1 == 'train' ]] || [[ $1 == 'all' ]]; then
|
||||
echo 'Run training...'
|
||||
python train.py \
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
if [[ $1 == 'train' ]]; then
|
||||
echo 'Run training...'
|
||||
python train.py \
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
if [[ $1 == 'train' ]]; then
|
||||
echo 'Run training...'
|
||||
python train.py \
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
if [[ "$1" == 'train' ]]; then
|
||||
echo 'Run training...'
|
||||
python -m torch.distributed.launch --nproc_per_node="$2" train.py \
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
export OMP_NUM_THREADS=1
|
||||
|
||||
if [[ $1 == 'train' ]]; then
|
||||
echo 'Run training...'
|
||||
python -m torch.distributed.launch --nproc_per_node="$2" train.py \
|
||||
|
|
|
@ -285,6 +285,13 @@ def parse_args():
|
|||
if args.batch_size % args.batch_chunk != 0:
|
||||
raise RuntimeError('Batch size needs to be divisible by batch chunk')
|
||||
|
||||
if (
|
||||
args.local_batch_size is not None
|
||||
and args.local_batch_size % args.batch_chunk != 0
|
||||
):
|
||||
raise RuntimeError('Local batch size needs to be divisible by '
|
||||
'batch chunk')
|
||||
|
||||
if args.fp16 and args.amp == 'apex' and 'apex' not in sys.modules:
|
||||
raise RuntimeError(
|
||||
'APEX AMP unavailable, install APEX or switch to pytorch AMP'
|
||||
|
@ -444,8 +451,10 @@ def evaluate(eval_iter, model, args):
|
|||
for i, (data, target, seq_len, warm) in enumerate(eval_iter):
|
||||
if args.eval_max_steps > 0 and i >= args.eval_max_steps:
|
||||
break
|
||||
loss, mems = model(data, target, mems)
|
||||
loss = loss.float().mean()
|
||||
enable_autocast = args.fp16 and args.amp == 'pytorch'
|
||||
with torch.cuda.amp.autocast(enable_autocast):
|
||||
loss, mems = model(data, target, mems)
|
||||
loss = loss.float().mean().type_as(loss)
|
||||
if warm:
|
||||
# assert (mems is None) or mems.size(1) == model.mem_len
|
||||
total_loss += seq_len * loss.item()
|
||||
|
@ -735,6 +744,9 @@ def main():
|
|||
args.batch_size = world_size * args.local_batch_size
|
||||
logging.info(f'--local_batch_size was set, adjusting global batch size'
|
||||
f' to {args.batch_size} (local_batch_size * world_size)')
|
||||
if args.batch_size % args.batch_chunk != 0:
|
||||
raise RuntimeError('Batch size needs to be divisible by '
|
||||
'batch chunk')
|
||||
|
||||
if args.profile:
|
||||
try:
|
||||
|
|
|
@ -70,9 +70,9 @@ def sample_logits(embedding, bias, labels, inputs, sampler):
|
|||
hit = (labels[:, :, None] == neg_samples).detach()
|
||||
|
||||
true_logits = torch.einsum('ijk,ijk->ij',
|
||||
[true_w, inputs]) + true_b - true_log_probs
|
||||
true_w, inputs) + true_b - true_log_probs
|
||||
sample_logits = torch.einsum('lk,ijk->ijl',
|
||||
[sample_w, inputs]) + sample_b - samp_log_probs
|
||||
sample_w, inputs) + sample_b - samp_log_probs
|
||||
sample_logits.masked_fill_(hit, -1e30)
|
||||
logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
|
||||
|
||||
|
|
|
@ -112,7 +112,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
|
|||
if proj is None:
|
||||
logit = F.linear(hidden, weight, bias=bias)
|
||||
else:
|
||||
logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
|
||||
logit = torch.einsum('bd,de,ev->bv', hidden, proj, weight.t())
|
||||
if bias is not None:
|
||||
logit = logit + bias
|
||||
return logit
|
||||
|
|
Loading…
Reference in a new issue