2018-08-07 16:27:43 +02:00
|
|
|
import logging.config
|
2019-02-14 12:40:30 +01:00
|
|
|
import os
|
|
|
|
import random
|
|
|
|
import sys
|
|
|
|
import time
|
|
|
|
from contextlib import contextmanager
|
2018-08-07 16:27:43 +02:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
2019-02-14 12:40:30 +01:00
|
|
|
import torch.nn.init as init
|
|
|
|
import torch.utils.collect_env
|
|
|
|
|
|
|
|
|
|
|
|
def init_lstm_(lstm, init_weight=0.1):
    """
    Initializes, in place, the parameters of the first layer of an LSTM.

    Both weight matrices and the input-hidden bias are drawn from
    uniform(-init_weight, init_weight); the hidden-hidden bias is zeroed.
    For a bidirectional LSTM the reverse-direction parameters of the first
    layer are initialized the same way. Only ``*_l0`` / ``*_l0_reverse``
    parameters are touched.

    :param lstm: instance of torch.nn.LSTM
    :param init_weight: range for the uniform initializer
    """
    bound = init_weight
    directions = ['_l0']
    if lstm.bidirectional:
        directions.append('_l0_reverse')

    for suffix in directions:
        init.uniform_(getattr(lstm, 'weight_hh' + suffix).data, -bound, bound)
        init.uniform_(getattr(lstm, 'weight_ih' + suffix).data, -bound, bound)
        # PyTorch LSTM has two biases (input-hidden and hidden-hidden GEMMs);
        # only the input-hidden one is randomized, the other starts at zero.
        init.uniform_(getattr(lstm, 'bias_ih' + suffix).data, -bound, bound)
        init.zeros_(getattr(lstm, 'bias_hh' + suffix).data)
|
|
|
|
|
|
|
|
|
|
|
|
def generate_seeds(rng, size):
    """
    Draws ``size`` seeds from ``rng``.

    Every seed is an integer in the closed interval [0, 2**32 - 1].

    :param rng: random number generator exposing ``randint`` (e.g. random.Random)
    :param size: length of the returned list
    """
    max_seed = 2**32 - 1
    return [rng.randint(0, max_seed) for _ in range(size)]
|
|
|
|
|
|
|
|
|
|
|
|
def broadcast_seeds(seeds, device):
    """
    Makes every distributed worker use the seeds held by rank 0.

    When torch.distributed is unavailable or not initialized, ``seeds`` is
    returned unchanged (the very same object). Otherwise the seeds are packed
    into a LongTensor on ``device``, broadcast from rank 0, and returned as a
    new Python list.

    :param seeds: list of seeds (integers)
    :param device: torch.device
    """
    distributed_ready = (torch.distributed.is_available()
                         and torch.distributed.is_initialized())
    if not distributed_ready:
        return seeds

    packed = torch.LongTensor(seeds).to(device)
    # rank 0 is the source; all other ranks have their values overwritten
    torch.distributed.broadcast(packed, 0)
    return packed.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
def setup_seeds(master_seed, epochs, device):
    """
    Derives per-worker and per-epoch seeds from a single master seed.

    A dedicated random.Random(master_seed) first draws one seed per
    distributed worker (worker_seeds) and then one seed per epoch
    (shuffling_seeds). Both lists are broadcast from rank 0, so every worker
    ends up with rank 0's values. Worker seeds are intended for per-worker
    RNGs (mostly dropout); shuffling seeds are for the RNGs responsible for
    reshuffling the dataset before each epoch.

    :param master_seed: master RNG seed used to initialize other generators;
        if None, a random one is drawn from random.SystemRandom()
    :param epochs: number of epochs
    :param device: torch.device (used for distributed.broadcast)
    :returns: tuple (worker_seeds, shuffling_seeds)
    """
    if master_seed is not None:
        # master seed was specified from command line
        logging.info(f'Using master seed from command line: {master_seed}')
    else:
        # random.SystemRandom() draws from the OS entropy source
        # (/dev/urandom on Unix)
        master_seed = random.SystemRandom().randint(0, 2**32 - 1)
        if get_rank() == 0:
            # only rank 0 reports its master seed: its derived seeds are the
            # ones broadcast to everybody, so other ranks' values are moot
            logging.info(f'Using random master seed: {master_seed}')

    seeding_rng = random.Random(master_seed)

    # draw order matters: worker seeds first, then shuffling seeds
    per_worker = generate_seeds(seeding_rng, get_world_size())
    per_epoch = generate_seeds(seeding_rng, epochs)

    return (broadcast_seeds(per_worker, device),
            broadcast_seeds(per_epoch, device))
|
2018-08-07 16:27:43 +02:00
|
|
|
|
|
|
|
|
|
|
|
def barrier():
    """
    Temporary distributed barrier.

    PyTorch did not provide a barrier for the NCCL backend, so this
    all-reduces a one-element CUDA tensor and then waits for the GPU.
    It is a no-op when torch.distributed is unavailable or uninitialized.
    """
    if not (torch.distributed.is_available()
            and torch.distributed.is_initialized()):
        return
    # contents are irrelevant: the collective itself is the synchronization
    dummy = torch.cuda.FloatTensor(1)
    torch.distributed.all_reduce(dummy)
    torch.cuda.synchronize()
|
|
|
|
|
|
|
|
|
|
|
|
def get_rank():
    """
    Returns this process's distributed rank, or 0 when torch.distributed is
    unavailable or not initialized.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
|
|
|
|
|
|
|
|
|
|
|
|
def get_world_size():
    """
    Returns the number of distributed workers, or 1 when torch.distributed
    is unavailable or not initialized.
    """
    initialized = (torch.distributed.is_available()
                   and torch.distributed.is_initialized())
    return torch.distributed.get_world_size() if initialized else 1
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def sync_workers():
    """
    Context manager that hands the body this worker's distributed rank.

    The ``as`` target receives get_rank() (0 when not distributed). When the
    body completes normally, barrier() is called so that all workers leave
    together. If the body raises, the barrier is skipped and the exception
    propagates.
    """
    yield get_rank()
    barrier()
|
|
|
|
|
|
|
|
|
2019-02-14 12:40:30 +01:00
|
|
|
@contextmanager
def timer(name, ndigits=2, sync_gpu=True):
    """
    Context manager that measures the duration of its body and logs it.

    On normal exit it logs ``TIMER <name> <seconds>`` at INFO level, with the
    elapsed time rounded to ``ndigits`` decimal places. If the body raises,
    nothing is logged and the exception propagates.

    :param name: label included in the log message
    :param ndigits: number of decimal digits kept in the reported time
    :param sync_gpu: if True, calls torch.cuda.synchronize() before starting
        and before stopping the clock, so queued asynchronous CUDA work is
        included in the measurement (requires CUDA)
    """
    if sync_gpu:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump (even backwards) when the system clock is
    # adjusted, producing wrong durations.
    start = time.perf_counter()
    yield
    if sync_gpu:
        torch.cuda.synchronize()
    stop = time.perf_counter()
    elapsed = round(stop - start, ndigits)
    logging.info(f'TIMER {name} {elapsed}')
|
|
|
|
|
|
|
|
|
|
|
|
def setup_logging(log_file=os.devnull):
    """
    Configures logging.

    By default logs from all workers are printed to the console, entries are
    prefixed with "N: " where N is the rank of the worker. Logs printed to the
    console don't include timestamps.
    Full logs with timestamps are saved to the log_file file.

    :param log_file: path of the file receiving DEBUG-level logs (defaults to
        os.devnull, i.e. discarded); opened with filemode 'w', so an existing
        file is truncated
    """
    class RankFilter(logging.Filter):
        """Stamps every record with the worker rank as ``record.rank``."""

        def __init__(self, rank):
            # initialize logging.Filter's own state before adding ours
            super().__init__()
            self.rank = rank

        def filter(self, record):
            record.rank = self.rank
            return True

    rank = get_rank()
    rank_filter = RankFilter(rank)

    logging_format = "%(asctime)s - %(levelname)s - %(rank)s - %(message)s"
    logging.basicConfig(level=logging.DEBUG,
                        format=logging_format,
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename=log_file,
                        filemode='w')
    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(rank)s: %(message)s')
    console.setFormatter(formatter)
    root_logger = logging.getLogger('')
    root_logger.addHandler(console)
    root_logger.addFilter(rank_filter)
    # A logger-level filter only runs for records created on that exact
    # logger; records propagated from child loggers (logging.getLogger(name))
    # bypass it and would then fail to format '%(rank)s'. Handler-level
    # filters see every record a handler emits, so attach the filter there.
    for handler in root_logger.handlers:
        handler.addFilter(rank_filter)
|
|
|
|
|
|
|
|
|
2019-02-14 12:40:30 +01:00
|
|
|
def set_device(cuda, local_rank):
    """
    Selects the compute device for this worker.

    With ``cuda`` true, makes GPU ``local_rank`` the current CUDA device and
    returns torch.device('cuda'); otherwise returns torch.device('cpu').

    :param cuda: if True: use cuda
    :param local_rank: local rank of the worker (CUDA device index)
    """
    if not cuda:
        return torch.device('cpu')
    torch.cuda.set_device(local_rank)
    return torch.device('cuda')
|
|
|
|
|
|
|
|
|
|
|
|
def init_distributed(cuda):
    """
    Initializes the distributed backend when more than one worker is used.

    The worker count is read from the WORLD_SIZE environment variable
    (default 1). If it exceeds 1, the default process group is created with
    init_method='env://'.

    :param cuda: (bool) if True initializes nccl backend, if False initializes
        gloo backend
    :returns: True if distributed mode was initialized, False otherwise
    """
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if world_size <= 1:
        return False

    dist.init_process_group(backend='nccl' if cuda else 'gloo',
                            init_method='env://')
    assert dist.is_initialized()
    return True
|
|
|
|
|
|
|
|
|
|
|
|
def log_env_info():
    """
    Logs a human-readable report about the execution environment.

    An announcement line is logged first, then the text produced by
    torch.utils.collect_env.get_pretty_env_info(), both at INFO level.
    """
    logging.info('Collecting environment information...')
    report = torch.utils.collect_env.get_pretty_env_info()
    logging.info(f'{report}')
|
|
|
|
|
|
|
|
|
|
|
|
def pad_vocabulary(math):
    """
    Returns the multiple to which the vocabulary size should be padded.

    NOTE(review): 8 for fp16 is presumably chosen for Tensor Core-friendly
    matrix shapes — confirm.

    :param math: arithmetic type, 'fp16' or 'fp32'
    :returns: 8 for 'fp16', 1 for 'fp32'
    :raises ValueError: if ``math`` is not a supported arithmetic type
        (previously this surfaced as an opaque UnboundLocalError)
    """
    pad_by_math = {'fp16': 8, 'fp32': 1}
    try:
        return pad_by_math[math]
    except KeyError:
        raise ValueError(
            f'Unsupported math type: {math!r}; '
            f'expected one of {sorted(pad_by_math)}') from None
|
|
|
|
|
|
|
|
|
|
|
|
def benchmark(test_acc, target_acc, test_perf, target_perf):
    """
    Checks achieved accuracy and performance against optional targets.

    A check whose achieved or target value is None is skipped and counts as
    passed. Otherwise both values are logged, and the check passes when
    achieved >= target. Both checks always run (each may log).

    :param test_acc: achieved accuracy (or None)
    :param target_acc: required accuracy (or None)
    :param test_perf: achieved performance (or None)
    :param target_perf: required performance (or None)
    :returns: True if every performed check passed, False otherwise
    """
    def check(achieved, target, label):
        if target is None or achieved is None:
            return True
        logging.info(f'{label} achieved: {achieved:.2f} '
                     f'target: {target:.2f}')
        if achieved >= target:
            logging.info(f'{label} test passed')
            return True
        logging.info(f'{label} test failed')
        return False

    outcomes = (check(test_acc, target_acc, 'Accuracy'),
                check(test_perf, target_perf, 'Performance'))
    return all(outcomes)
|
|
|
|
|
2018-08-07 16:27:43 +02:00
|
|
|
class AverageMeter:
    """
    Tracks the latest value together with a running sum, count and average.

    With ``skip_first`` set, the very first value passed to update() is stored
    in ``val`` but excluded from ``sum``/``count``/``avg``. reset() clears the
    statistics but does not re-arm the skip.
    """

    def __init__(self, skip_first=True):
        self.reset()
        self.skip = skip_first

    def reset(self):
        """Zeroes val, avg, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Records a new value.

        :param val: latest value
        :param n: weight of ``val`` in the running sum and count
        """
        self.val = val
        if not self.skip:
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count
            return
        # the first measurement is dropped; the skip flag is consumed once
        self.skip = False

    @staticmethod
    def _uses_nccl():
        """Returns True if the default process group uses the NCCL backend."""
        # Backward/forward compatibility around
        # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and
        # https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86
        # to accommodate changes in PyTorch's distributed API
        if not hasattr(dist, "get_backend"):
            return dist._backend == dist.dist_backend.NCCL
        if hasattr(dist, "DistBackend"):
            backend_enum_holder = dist.DistBackend
        else:
            backend_enum_holder = dist.Backend
        return dist.get_backend() == backend_enum_holder.NCCL

    def reduce(self, op):
        """
        Reduces average value over all workers.

        ``avg`` and ``sum`` are each all-reduced with SUM; for op='mean' both
        results are then divided by the world size. ``val`` and ``count`` are
        not touched. Nothing happens when there is a single worker.

        :param op: 'sum' or 'mean', reduction operator
        :raises NotImplementedError: for any other ``op``
        """
        if op not in ('sum', 'mean'):
            raise NotImplementedError

        if get_world_size() <= 1:
            return

        # NCCL only reduces CUDA tensors, other backends get CPU tensors
        if self._uses_nccl():
            tensor_type = torch.cuda.FloatTensor
        else:
            tensor_type = torch.FloatTensor
        avg_t = tensor_type([self.avg])
        sum_t = tensor_type([self.sum])

        try:
            reduce_ops = dist.ReduceOp
        except AttributeError:
            reduce_ops = dist.reduce_op

        dist.all_reduce(avg_t, op=reduce_ops.SUM)
        dist.all_reduce(sum_t, op=reduce_ops.SUM)
        self.avg = avg_t.item()
        self.sum = sum_t.item()

        if op == 'mean':
            world_size = get_world_size()
            self.avg /= world_size
            self.sum /= world_size
|
2018-08-07 16:27:43 +02:00
|
|
|
|
|
|
|
|
|
|
|
def debug_tensor(tensor, name):
    """
    Logs summary statistics of a tensor to help with debugging.

    Logs ``name`` on one line, then min, max, mean, standard deviation and the
    number of NaN and Inf entries. Statistics are computed on a detached
    float32 CPU copy, so the input tensor is left unmodified.

    :param tensor: torch tensor
    :param name: name of the tensor (only for logging)
    """
    logging.info(name)
    values = tensor.detach().float().cpu().numpy()
    nan_count = np.isnan(values).sum()
    inf_count = np.isinf(values).sum()
    logging.info(f'MIN: {values.min()} MAX: {values.max()} '
                 f'AVG: {values.mean()} STD: {values.std()} '
                 f'NAN: {nan_count} INF: {inf_count}')
|