Merge pull request #901 from szmigacz/txl-update3
[TXL/PyT] Update for PyT Transformer-XL:
This commit is contained in:
commit
0a3a4d0916
|
@ -20,11 +20,16 @@ import os
|
|||
import pickle
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import dllogger
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
try:
|
||||
import pyprof
|
||||
except ModuleNotFoundError:
|
||||
warnings.warn('PyProf is unavailable')
|
||||
|
||||
import data_utils
|
||||
import utils
|
||||
|
@ -72,6 +77,15 @@ def parse_args():
|
|||
parser.add_argument('--split', type=str, default='all',
|
||||
choices=['all', 'valid', 'test'],
|
||||
help='which split to evaluate')
|
||||
parser.add_argument('--affinity', type=str,
|
||||
default='single_unique',
|
||||
choices=['socket', 'single', 'single_unique',
|
||||
'socket_unique_interleaved',
|
||||
'socket_unique_continuous',
|
||||
'disabled'],
|
||||
help='type of CPU affinity')
|
||||
parser.add_argument('--profile', action='store_true',
|
||||
help='Enable profiling with DLProf')
|
||||
parser.add_argument('--type', type=str, default='pytorch',
|
||||
choices=['pytorch', 'torchscript'],
|
||||
help='type of runtime to use')
|
||||
|
@ -134,7 +148,12 @@ def parse_args():
|
|||
if args.manual:
|
||||
args.batch_size = 1
|
||||
|
||||
assert args.ext_len >= 0, 'extended context length must be non-negative'
|
||||
if args.same_length and args.tgt_len > args.mem_len:
|
||||
warnings.warn('--same_length is intended to be used with large '
|
||||
'mem_len relative to tgt_len')
|
||||
|
||||
if args.ext_len < 0:
|
||||
raise RuntimeError('Extended context length must be non-negative')
|
||||
return args
|
||||
|
||||
|
||||
|
@ -182,7 +201,6 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
|
|||
loss = loss.float().mean()
|
||||
log_loss += loss.item()
|
||||
if warm:
|
||||
# assert all([m.size(0) == model.mem_len for m in mems])
|
||||
total_loss += seq_len * loss.item()
|
||||
total_len += seq_len
|
||||
|
||||
|
@ -251,7 +269,14 @@ def compile_model(model, device, args):
|
|||
|
||||
def main():
|
||||
args = parse_args()
|
||||
utils.gpu_affinity.set_affinity(args.local_rank)
|
||||
if args.affinity != 'disabled':
|
||||
nproc_per_node = torch.cuda.device_count()
|
||||
affinity = utils.gpu_affinity.set_affinity(
|
||||
args.local_rank,
|
||||
nproc_per_node,
|
||||
args.affinity
|
||||
)
|
||||
print(f'{args.local_rank}: thread affinity: {affinity}')
|
||||
|
||||
if args.type == 'pytorch':
|
||||
from mem_transformer import MemTransformerLM
|
||||
|
@ -286,6 +311,12 @@ def main():
|
|||
)
|
||||
utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)
|
||||
|
||||
if args.profile:
|
||||
try:
|
||||
pyprof.init(enable_function_stack=True)
|
||||
except NameError:
|
||||
warnings.warn('Called pyprof.init() but pyprof is not available')
|
||||
|
||||
logging.info(args)
|
||||
dllogger.log(step='PARAMETER', data=vars(args))
|
||||
|
||||
|
@ -423,7 +454,9 @@ def main():
|
|||
meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
|
||||
meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)
|
||||
|
||||
loss = evaluate(iter, model, meters, args.log_interval, args.max_size, args.repeat)
|
||||
with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
|
||||
loss = evaluate(iter, model, meters, args.log_interval, args.max_size,
|
||||
args.repeat)
|
||||
perplexity = math.exp(loss)
|
||||
log_str = format_log(loss, args.split, args)
|
||||
|
||||
|
|
|
@ -124,7 +124,7 @@ class MultiHeadAttn(nn.Module):
|
|||
# [bsz x n_head x qlen x klen]
|
||||
attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
|
||||
attn_score.mul_(self.scale)
|
||||
if attn_mask is not None and attn_mask.any():
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dim() == 2:
|
||||
attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
|
||||
elif attn_mask.dim() == 3:
|
||||
|
@ -279,7 +279,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
attn_score.mul_(self.scale)
|
||||
|
||||
# compute attention probability
|
||||
if attn_mask is not None and attn_mask.any():
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dim() == 2:
|
||||
attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
|
||||
elif attn_mask.dim() == 3:
|
||||
|
@ -370,7 +370,7 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
attn_score.mul_(self.scale)
|
||||
|
||||
# compute attention probability
|
||||
if attn_mask is not None and attn_mask.any():
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dim() == 2:
|
||||
attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
|
||||
elif attn_mask.dim() == 3:
|
||||
|
@ -521,8 +521,7 @@ class AdaptiveEmbedding(nn.Module):
|
|||
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
|
||||
dtype=self.dtype, device=torch.device('cuda'))
|
||||
|
||||
i = 0
|
||||
for emb_layer in self.emb_layers:
|
||||
for i, emb_layer in enumerate(self.emb_layers):
|
||||
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
|
||||
|
||||
mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
|
||||
|
@ -534,7 +533,6 @@ class AdaptiveEmbedding(nn.Module):
|
|||
emb_i = F.linear(emb_i, self.emb_projs[i])
|
||||
|
||||
emb_flat.index_copy_(0, indices_i, emb_i)
|
||||
i += 1
|
||||
|
||||
embed = emb_flat.view(inp.size(0), inp.size(1), self.d_proj)
|
||||
|
||||
|
@ -627,23 +625,15 @@ class MemTransformerLM(nn.Module):
|
|||
# default attention
|
||||
if self.attn_type == 0:
|
||||
self.pos_emb = PositionalEmbedding(self.d_model)
|
||||
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
|
||||
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
|
||||
|
||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||
self.tgt_len = tgt_len
|
||||
self.mem_len = mem_len
|
||||
self.ext_len = ext_len
|
||||
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head).zero_())
|
||||
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head).zero_())
|
||||
|
||||
def init_mems(self):
|
||||
mems = []
|
||||
for i in range(self.n_layer+1):
|
||||
empty = torch.empty(0, dtype=self.dtype, device=torch.device('cuda'))
|
||||
mems.append(empty)
|
||||
mems = torch.empty(self.n_layer, 0, dtype=self.dtype, device=torch.device('cuda'))
|
||||
|
||||
return mems
|
||||
|
||||
def _update_mems(self, hids: List[torch.Tensor], mems: List[torch.Tensor],
|
||||
def _update_mems(self, hids: List[torch.Tensor], mems: torch.Tensor,
|
||||
qlen: int, mlen: int):
|
||||
assert len(hids) == len(mems), 'len(hids) != len(mems)'
|
||||
|
||||
|
@ -652,16 +642,18 @@ class MemTransformerLM(nn.Module):
|
|||
# will be used as the extended context. Hence, we only cache
|
||||
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
|
||||
# to `mlen + qlen - self.ext_len`.
|
||||
new_mems = []
|
||||
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
|
||||
stacked = torch.stack(hids)
|
||||
end_idx = mlen + max(0, qlen - self.ext_len)
|
||||
beg_idx = max(0, end_idx - self.mem_len)
|
||||
for i in range(len(hids)):
|
||||
cat = torch.cat([mems[i], hids[i]], dim=0)
|
||||
new_mems.append(cat[beg_idx:end_idx].detach())
|
||||
if mems.numel():
|
||||
cat = torch.cat([mems, stacked], dim=1)
|
||||
else:
|
||||
cat = stacked
|
||||
new_mems = cat[:, beg_idx:end_idx].detach()
|
||||
|
||||
return new_mems
|
||||
|
||||
def _forward(self, dec_inp, mems: List[torch.Tensor]):
|
||||
def _forward(self, dec_inp, mems: torch.Tensor):
|
||||
qlen, bsz = dec_inp.size()
|
||||
|
||||
word_emb = self.word_emb(dec_inp)
|
||||
|
@ -671,7 +663,7 @@ class MemTransformerLM(nn.Module):
|
|||
all_ones = torch.ones((qlen, klen), device=torch.device('cuda'),
|
||||
dtype=self.dtype)
|
||||
if self.same_length:
|
||||
mask_len = klen - self.mem_len
|
||||
mask_len = klen - self.mem_len - 1
|
||||
if mask_len > 0:
|
||||
mask_shift_len = qlen - mask_len
|
||||
else:
|
||||
|
@ -681,7 +673,6 @@ class MemTransformerLM(nn.Module):
|
|||
else:
|
||||
dec_attn_mask = torch.triu(all_ones, diagonal=1+mlen).to(torch.bool)
|
||||
|
||||
hids = []
|
||||
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
|
||||
dtype=word_emb.dtype)
|
||||
if self.clamp_len > 0:
|
||||
|
@ -691,22 +682,20 @@ class MemTransformerLM(nn.Module):
|
|||
core_out = self.drop(word_emb)
|
||||
pos_emb = self.drop(pos_emb)
|
||||
|
||||
hids.append(core_out)
|
||||
i = 0
|
||||
for layer in self.layers:
|
||||
hids = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
hids.append(core_out)
|
||||
mems_i = None if mems is None else mems[i]
|
||||
core_out = layer(core_out, pos_emb, self.r_w_bias,
|
||||
self.r_r_bias, dec_attn_mask=dec_attn_mask,
|
||||
mems=mems_i)
|
||||
hids.append(core_out)
|
||||
i += 1
|
||||
core_out = self.drop(core_out)
|
||||
|
||||
new_mems = self._update_mems(hids, mems, qlen, mlen)
|
||||
|
||||
return core_out, new_mems
|
||||
|
||||
def forward(self, data, target, mems: Optional[List[torch.Tensor]]):
|
||||
def forward(self, data, target, mems: Optional[torch.Tensor]):
|
||||
# nn.DataParallel does not allow size(0) tensors to be broadcasted.
|
||||
# So, have to initialize size(0) mems inside the model forward.
|
||||
# Moreover, have to return new_mems to allow nn.DataParallel to piece
|
||||
|
|
|
@ -21,6 +21,11 @@ from utils.log_uniform_sampler import sample_logits
|
|||
from utils.proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
|
||||
|
||||
|
||||
@torch.jit.script
def add_and_scale(tensor1, tensor2, alpha: float):
    """Return ``alpha * (tensor1 + tensor2)``.

    Compiled with TorchScript; the attention modules in this file call it to
    combine the two score terms and apply the scale in one scripted function.
    """
    summed = tensor1 + tensor2
    return alpha * summed
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, demb):
|
||||
super(PositionalEmbedding, self).__init__()
|
||||
|
@ -122,7 +127,7 @@ class MultiHeadAttn(nn.Module):
|
|||
# [bsz x n_head x qlen x klen]
|
||||
attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
|
||||
attn_score.mul_(self.scale)
|
||||
if attn_mask is not None and attn_mask.any().item():
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dim() == 2:
|
||||
attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
|
||||
elif attn_mask.dim() == 3:
|
||||
|
@ -266,11 +271,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
BD = self._rel_shift(BD)
|
||||
|
||||
# [bsz x n_head x qlen x klen]
|
||||
attn_score = AC + BD
|
||||
attn_score.mul_(self.scale)
|
||||
attn_score = add_and_scale(AC, BD, self.scale)
|
||||
|
||||
# compute attention probability
|
||||
if attn_mask is not None and attn_mask.any().item():
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dim() == 2:
|
||||
attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
|
||||
elif attn_mask.dim() == 3:
|
||||
|
@ -354,11 +358,10 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
|
|||
BD = self._rel_shift(B_ + D_)
|
||||
|
||||
# [bsz x qlen x klen x n_head]
|
||||
attn_score = AC + BD
|
||||
attn_score.mul_(self.scale)
|
||||
attn_score = add_and_scale(AC, BD, self.scale)
|
||||
|
||||
# compute attention probability
|
||||
if attn_mask is not None and attn_mask.any().item():
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dim() == 2:
|
||||
attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
|
||||
elif attn_mask.dim() == 3:
|
||||
|
@ -627,6 +630,12 @@ class MemTransformerLM(nn.Module):
|
|||
self.n_layer, self.max_klen, self.d_model).zero_())
|
||||
|
||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||
if tgt_len < 1:
|
||||
raise RuntimeError(f'tgt_len should be >= 1, but got {tgt_len}')
|
||||
if ext_len < 0:
|
||||
raise RuntimeError(f'ext_len should be >= 0, but got {ext_len}')
|
||||
if mem_len < 0:
|
||||
raise RuntimeError(f'mem_len should be >= 0, but got {mem_len}')
|
||||
self.tgt_len = tgt_len
|
||||
self.mem_len = mem_len
|
||||
self.ext_len = ext_len
|
||||
|
@ -634,7 +643,7 @@ class MemTransformerLM(nn.Module):
|
|||
def init_mems(self):
|
||||
if self.mem_len > 0:
|
||||
param = next(self.parameters())
|
||||
mems = torch.empty(self.n_layer + 1, 0, dtype=param.dtype,
|
||||
mems = torch.empty(self.n_layer, 0, dtype=param.dtype,
|
||||
device=param.device)
|
||||
return mems
|
||||
else:
|
||||
|
@ -654,14 +663,21 @@ class MemTransformerLM(nn.Module):
|
|||
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
|
||||
# to `mlen + qlen - self.ext_len`.
|
||||
with torch.no_grad():
|
||||
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
|
||||
beg_idx = max(0, end_idx - self.mem_len)
|
||||
stacked = torch.stack(hids)
|
||||
if mems.numel():
|
||||
cat = torch.cat([mems, stacked], dim=1)
|
||||
if (
|
||||
self.mem_len == self.tgt_len
|
||||
and self.ext_len == 0
|
||||
and stacked.size(1) == self.mem_len
|
||||
):
|
||||
new_mems = stacked.detach()
|
||||
else:
|
||||
cat = stacked
|
||||
new_mems = cat[:, beg_idx:end_idx].detach()
|
||||
end_idx = mlen + max(0, qlen - self.ext_len)
|
||||
beg_idx = max(0, end_idx - self.mem_len)
|
||||
if mems.numel():
|
||||
cat = torch.cat([mems, stacked], dim=1)
|
||||
else:
|
||||
cat = stacked
|
||||
new_mems = cat[:, beg_idx:end_idx].detach()
|
||||
|
||||
return new_mems
|
||||
|
||||
|
@ -697,18 +713,17 @@ class MemTransformerLM(nn.Module):
|
|||
core_out = self.drop(word_emb)
|
||||
pos_emb = self.drop(pos_emb)
|
||||
|
||||
hids.append(core_out.detach())
|
||||
for i, layer in enumerate(self.layers):
|
||||
hids.append(core_out.detach())
|
||||
mems_i = None if mems is None else mems[i]
|
||||
core_out = layer(core_out, pos_emb, self.r_w_bias,
|
||||
self.r_r_bias, dec_attn_mask=dec_attn_mask,
|
||||
mems=mems_i)
|
||||
hids.append(core_out.detach())
|
||||
# learnable
|
||||
elif self.attn_type == 1:
|
||||
core_out = self.drop(word_emb)
|
||||
hids.append(core_out.detach())
|
||||
for i, layer in enumerate(self.layers):
|
||||
hids.append(core_out.detach())
|
||||
if self.clamp_len > 0:
|
||||
r_emb = self.r_emb[i][-self.clamp_len:]
|
||||
r_bias = self.r_bias[i][-self.clamp_len:]
|
||||
|
@ -718,7 +733,6 @@ class MemTransformerLM(nn.Module):
|
|||
mems_i = None if mems is None else mems[i]
|
||||
core_out = layer(core_out, r_emb, self.r_w_bias[i],
|
||||
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
|
||||
hids.append(core_out.detach())
|
||||
# absolute
|
||||
elif self.attn_type == 2:
|
||||
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
|
||||
|
@ -729,19 +743,18 @@ class MemTransformerLM(nn.Module):
|
|||
|
||||
core_out = self.drop(word_emb + pos_emb[-qlen:])
|
||||
|
||||
hids.append(core_out.detach())
|
||||
for i, layer in enumerate(self.layers):
|
||||
hids.append(core_out.detach())
|
||||
mems_i = None if mems is None else mems[i]
|
||||
if mems_i is not None and len(mems_i) and i == 0:
|
||||
mems_i += pos_emb[:mlen]
|
||||
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
|
||||
mems=mems_i)
|
||||
hids.append(core_out.detach())
|
||||
elif self.attn_type == 3:
|
||||
core_out = self.drop(word_emb)
|
||||
|
||||
hids.append(core_out.detach())
|
||||
for i, layer in enumerate(self.layers):
|
||||
hids.append(core_out.detach())
|
||||
mems_i = None if mems is None else mems[i]
|
||||
if mems_i is not None and len(mems_i) and mlen > 0:
|
||||
cur_emb = self.r_emb[i][:-qlen]
|
||||
|
@ -756,7 +769,6 @@ class MemTransformerLM(nn.Module):
|
|||
|
||||
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
|
||||
mems=mems_i)
|
||||
hids.append(core_out.detach())
|
||||
|
||||
core_out = self.drop(core_out)
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ for (( i = 0; i < ${#TYPES[@]}; i++ )); do
|
|||
DIR="LM-TFM/inference/${GPU}_${BATCH_SIZES[j]}_${MATHS_FULL[k]}_${TYPES[i]}"
|
||||
mkdir -p "${DIR}"
|
||||
|
||||
taskset -c 0 bash run_wt103_"${MODEL}".sh eval 1 \
|
||||
bash run_wt103_"${MODEL}".sh eval 1 \
|
||||
--work_dir "${DIR}" \
|
||||
--model "${CHECKPOINT}" \
|
||||
--type "${TYPES[i]}" \
|
||||
|
|
|
@ -35,6 +35,10 @@ try:
|
|||
from apex import amp
|
||||
except ModuleNotFoundError:
|
||||
warnings.warn('APEX AMP is unavailable')
|
||||
try:
|
||||
import pyprof
|
||||
except ModuleNotFoundError:
|
||||
warnings.warn('PyProf is unavailable')
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
|
@ -111,6 +115,15 @@ def parse_args():
|
|||
help='Optimization level for apex amp')
|
||||
general.add_argument('--amp', choices=['apex', 'pytorch'], default='apex',
|
||||
help='Implementation of automatic mixed precision')
|
||||
general.add_argument('--affinity', type=str,
|
||||
default='socket_unique_interleaved',
|
||||
choices=['socket', 'single', 'single_unique',
|
||||
'socket_unique_interleaved',
|
||||
'socket_unique_continuous',
|
||||
'disabled'],
|
||||
help='type of CPU affinity')
|
||||
general.add_argument('--profile', action='store_true',
|
||||
help='Enable profiling with DLProf')
|
||||
|
||||
dataset = parser.add_argument_group('dataset setup')
|
||||
dataset.add_argument('--data', type=str, default='../data/wikitext-103',
|
||||
|
@ -256,6 +269,19 @@ def parse_args():
|
|||
if args.ext_len < 0:
|
||||
raise RuntimeError('Extended context length must be non-negative')
|
||||
|
||||
if args.mem_len == 0:
|
||||
if args.eval_tgt_len > args.ext_len + args.tgt_len:
|
||||
raise RuntimeError('eval_tgt_len should be <= tgt_len + ext_len; '
|
||||
f'eval_tgt_len: {args.eval_tgt_len}, '
|
||||
f'tgt_len: {args.tgt_len}, '
|
||||
f'ext_len: {args.ext_len}')
|
||||
else:
|
||||
if args.eval_tgt_len > args.mem_len + args.tgt_len:
|
||||
raise RuntimeError('eval_tgt_len should be <= tgt_len + mem_len; '
|
||||
f'eval_tgt_len: {args.eval_tgt_len}, '
|
||||
f'tgt_len: {args.tgt_len}, '
|
||||
f'mem_len: {args.mem_len}')
|
||||
|
||||
if args.batch_size % args.batch_chunk != 0:
|
||||
raise RuntimeError('Batch size needs to be divisible by batch chunk')
|
||||
|
||||
|
@ -421,7 +447,7 @@ def evaluate(eval_iter, model, args):
|
|||
loss, mems = model(data, target, mems)
|
||||
loss = loss.float().mean()
|
||||
if warm:
|
||||
assert (mems is None) or mems.size(1) == model.mem_len
|
||||
# assert (mems is None) or mems.size(1) == model.mem_len
|
||||
total_loss += seq_len * loss.item()
|
||||
total_len += seq_len
|
||||
|
||||
|
@ -659,7 +685,14 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
|
|||
|
||||
def main():
|
||||
args = parse_args()
|
||||
utils.gpu_affinity.set_affinity(args.local_rank)
|
||||
if args.affinity != 'disabled':
|
||||
nproc_per_node = torch.cuda.device_count()
|
||||
affinity = utils.gpu_affinity.set_affinity(
|
||||
args.local_rank,
|
||||
nproc_per_node,
|
||||
args.affinity
|
||||
)
|
||||
print(f'{args.local_rank}: thread affinity: {affinity}')
|
||||
|
||||
# Initialize device and distributed backend
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
|
@ -703,6 +736,12 @@ def main():
|
|||
logging.info(f'--local_batch_size was set, adjusting global batch size'
|
||||
f' to {args.batch_size} (local_batch_size * world_size)')
|
||||
|
||||
if args.profile:
|
||||
try:
|
||||
pyprof.init(enable_function_stack=True)
|
||||
except NameError:
|
||||
warnings.warn('Called pyprof.init() but pyprof is not available')
|
||||
|
||||
logging.info(args)
|
||||
dllogger.log(step='PARAMETER', data=vars(args))
|
||||
|
||||
|
@ -956,28 +995,30 @@ def main():
|
|||
# Loop over epochs.
|
||||
# At any point you can hit Ctrl + C to break out of training early.
|
||||
start_time = time.time()
|
||||
with TimeoutHandler() as timeout_handler:
|
||||
try:
|
||||
for epoch in itertools.count(start=start_epoch):
|
||||
if args.roll:
|
||||
tr_iter.roll(seed=args.seed + epoch)
|
||||
train_step, best_val_loss = train(
|
||||
tr_iter, va_iter, model, para_model, model_config,
|
||||
optimizer, optimizer_sparse, scheduler, scheduler_sparse,
|
||||
scaler, vocab, epoch, last_batch, last_iter, train_step,
|
||||
best_val_loss, meters, timeout_handler, device, args
|
||||
)
|
||||
with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
|
||||
with TimeoutHandler() as timeout_handler:
|
||||
try:
|
||||
for epoch in itertools.count(start=start_epoch):
|
||||
if args.roll:
|
||||
tr_iter.roll(seed=args.seed + epoch)
|
||||
train_step, best_val_loss = train(
|
||||
tr_iter, va_iter, model, para_model, model_config,
|
||||
optimizer, optimizer_sparse, scheduler,
|
||||
scheduler_sparse, scaler, vocab, epoch, last_batch,
|
||||
last_iter, train_step, best_val_loss, meters,
|
||||
timeout_handler, device, args
|
||||
)
|
||||
|
||||
last_batch = 0
|
||||
last_iter = 0
|
||||
last_batch = 0
|
||||
last_iter = 0
|
||||
|
||||
if train_step == args.max_step:
|
||||
logging.info('-' * 100)
|
||||
logging.info('End of training')
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
logging.info('-' * 100)
|
||||
logging.info('Exiting from training early')
|
||||
if train_step == args.max_step:
|
||||
logging.info('-' * 100)
|
||||
logging.info('End of training')
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
logging.info('-' * 100)
|
||||
logging.info('Exiting from training early')
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
###########################################################################
|
||||
|
@ -992,8 +1033,9 @@ def main():
|
|||
|
||||
# Run on test data.
|
||||
test_start_time = time.time()
|
||||
test_loss = evaluate(te_iter, model, args)
|
||||
test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
|
||||
with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
|
||||
test_loss = evaluate(te_iter, model, args)
|
||||
test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
|
||||
test_elapsed = time.time() - test_start_time
|
||||
|
||||
logging.info('=' * 100)
|
||||
|
|
|
@ -155,6 +155,10 @@ def setup_logging(log_all_ranks=True, filename=os.devnull, filemode='w'):
|
|||
if rank != 0:
|
||||
filename = os.devnull
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
handler.close()
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format=logging_format,
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import collections
|
||||
import math
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
import pynvml
|
||||
|
||||
|
@ -35,15 +38,105 @@ class device:
|
|||
affinity_list = [int(x) for x in affinity_string]
|
||||
affinity_list.reverse() # so core 0 is in 0th element of list
|
||||
|
||||
return [i for i, e in enumerate(affinity_list) if e != 0]
|
||||
ret = [i for i, e in enumerate(affinity_list) if e != 0]
|
||||
return ret
|
||||
|
||||
|
||||
def set_affinity(gpu_id=None):
|
||||
if gpu_id is None:
|
||||
gpu_id = int(os.getenv('LOCAL_RANK', 0))
|
||||
|
||||
def set_socket_affinity(gpu_id):
|
||||
dev = device(gpu_id)
|
||||
os.sched_setaffinity(0, dev.getCpuAffinity())
|
||||
affinity = dev.getCpuAffinity()
|
||||
os.sched_setaffinity(0, affinity)
|
||||
|
||||
# list of ints representing the logical cores this process is now affinitied with
|
||||
return os.sched_getaffinity(0)
|
||||
|
||||
def set_single_affinity(gpu_id):
    """Pin the calling process to a single core associated with ``gpu_id``.

    The core chosen is the first entry of the list returned by
    ``device(gpu_id).getCpuAffinity()``.
    """
    cores = device(gpu_id).getCpuAffinity()
    # Slicing (rather than indexing) keeps the argument a list.
    first_core_only = cores[:1]
    os.sched_setaffinity(0, first_core_only)
|
||||
|
||||
|
||||
def set_single_unique_affinity(gpu_id, nproc_per_node):
    """Pin the calling process to one core that no other local GPU is given.

    For each of the ``nproc_per_node`` GPUs the CPU affinity list is queried,
    the cores listed as the second entry of a thread-siblings pair are
    removed, and the first core not already handed to a lower-numbered GPU is
    reserved for that GPU. The process is then pinned to the core reserved
    for ``gpu_id``.

    Args:
        gpu_id: index of the GPU served by this process.
        nproc_per_node: number of GPUs (processes) on this node.

    Raises:
        RuntimeError: if no unassigned core is left for ``gpu_id``.
        IndexError: if ``gpu_id`` is outside ``range(nproc_per_node)``.
    """
    devices = [device(i) for i in range(nproc_per_node)]
    socket_affinities = [dev.getCpuAffinity() for dev in devices]

    siblings_dict = dict(get_thread_siblings_list())
    sibling_cores = set(siblings_dict.values())

    # Remove siblings. Sorting makes the core order explicit instead of
    # relying on set iteration order, so every process on the node derives
    # the same assignment.
    socket_affinities = [
        sorted(set(socket_affinity) - sibling_cores)
        for socket_affinity in socket_affinities
    ]

    # One entry per GPU (None when no core is left) so that the list stays
    # aligned with GPU indices even if some device cannot be served.
    affinities = []
    assigned = set()
    for socket_affinity in socket_affinities:
        core = next((c for c in socket_affinity if c not in assigned), None)
        if core is not None:
            assigned.add(core)
        affinities.append(core)

    core = affinities[gpu_id]
    if core is None:
        raise RuntimeError(
            f'No unassigned CPU core is available for GPU {gpu_id}'
        )
    os.sched_setaffinity(0, [core])
|
||||
|
||||
|
||||
def set_socket_unique_affinity(gpu_id, nproc_per_node, mode):
    """Pin the calling process to a share of cores not given to other GPUs.

    GPUs reporting the same set of cores (after removing the cores listed as
    the second entry of a thread-siblings pair) form a group, and the group's
    cores are divided between its GPUs:

    * ``'interleaved'``: GPU ``k`` of a group of ``n`` gets every ``n``-th
      core starting at position ``k``.
    * ``'continuous'``: GPU ``k`` gets the ``k``-th contiguous chunk of
      ``len(cores) // n`` cores.

    The siblings of the selected cores are then added back and the process
    affinity is set to the result.

    Args:
        gpu_id: index of the GPU served by this process.
        nproc_per_node: number of GPUs (processes) on this node.
        mode: ``'interleaved'`` or ``'continuous'``.

    Raises:
        RuntimeError: if ``mode`` is not one of the supported values.
    """
    devices = [device(i) for i in range(nproc_per_node)]
    socket_affinities = [dev.getCpuAffinity() for dev in devices]

    siblings_dict = dict(get_thread_siblings_list())
    sibling_cores = set(siblings_dict.values())

    # Remove siblings. The result is sorted because both the grouping key
    # below and the interleaved/continuous slicing depend on core order,
    # which set iteration does not guarantee.
    socket_affinities = [
        sorted(set(socket_affinity) - sibling_cores)
        for socket_affinity in socket_affinities
    ]

    socket_affinities_to_device_ids = collections.defaultdict(list)
    for idx, socket_affinity in enumerate(socket_affinities):
        socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)

    for socket_affinity, group_device_ids in socket_affinities_to_device_ids.items():
        if gpu_id not in group_device_ids:
            continue

        group_id = group_device_ids.index(gpu_id)
        devices_per_group = len(group_device_ids)
        cores_per_device = len(socket_affinity) // devices_per_group

        if mode == 'interleaved':
            affinity = list(socket_affinity[group_id::devices_per_group])
        elif mode == 'continuous':
            start = group_id * cores_per_device
            affinity = list(socket_affinity[start:start + cores_per_device])
        else:
            raise RuntimeError('Unknown set_socket_unique_affinity mode')

        # Reintroduce siblings of the selected cores.
        affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
        os.sched_setaffinity(0, affinity)
        # Each GPU index belongs to exactly one group; nothing left to do.
        return
|
||||
|
||||
|
||||
def get_thread_siblings_list():
    """Collect thread-sibling pairs from the sysfs CPU topology files.

    Every file matching
    ``/sys/devices/system/cpu/cpu*/topology/thread_siblings_list`` is read and
    the first ``<number><separator><number>`` occurrence in it is converted
    to a tuple of two ints. Files without such an occurrence are ignored.

    Returns:
        A list of ``(int, int)`` tuples, one per matching file.
    """
    glob_spec = '/sys/devices/system/cpu/cpu*/topology/thread_siblings_list'
    sibling_re = re.compile(r'(\d+)\D(\d+)')

    # Path.glob() needs a relative pattern, so split off the leading '/'.
    root = pathlib.Path(glob_spec[0])
    relative_pattern = glob_spec[1:]

    pairs = []
    for sysfs_file in root.glob(relative_pattern):
        with open(sysfs_file) as handle:
            text = handle.read().strip()
        matches = sibling_re.findall(text)
        if not matches:
            continue
        first, second = matches[0]
        pairs.append((int(first), int(second)))
    return pairs
|
||||
|
||||
|
||||
def set_affinity(gpu_id, nproc_per_node, mode='socket'):
    """Apply the requested CPU affinity policy to the calling process.

    Args:
        gpu_id: index of the GPU served by this process.
        nproc_per_node: number of GPUs (processes) on this node.
        mode: one of ``'socket'``, ``'single'``, ``'single_unique'``,
            ``'socket_unique_interleaved'`` or ``'socket_unique_continuous'``.

    Returns:
        The affinity of the calling process after the policy was applied, as
        reported by ``os.sched_getaffinity(0)``.

    Raises:
        RuntimeError: if ``mode`` is not one of the supported values.
    """
    # (mode name, action) pairs, checked in order with ``==``.
    policies = (
        ('socket',
         lambda: set_socket_affinity(gpu_id)),
        ('single',
         lambda: set_single_affinity(gpu_id)),
        ('single_unique',
         lambda: set_single_unique_affinity(gpu_id, nproc_per_node)),
        ('socket_unique_interleaved',
         lambda: set_socket_unique_affinity(gpu_id, nproc_per_node, 'interleaved')),
        ('socket_unique_continuous',
         lambda: set_socket_unique_affinity(gpu_id, nproc_per_node, 'continuous')),
    )

    for policy_name, apply_policy in policies:
        if mode == policy_name:
            apply_policy()
            break
    else:
        raise RuntimeError('Unknown affinity mode')

    return os.sched_getaffinity(0)
|
||||
|
|
Loading…
Reference in a new issue