# Copyright (c) 2018-2019, NVIDIA CORPORATION # Copyright (c) 2017- Facebook, Inc # # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # * Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # * Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # * Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import functools
import os
import shutil
import time

import numpy as np
import torch
import torch.distributed as dist


def _distributed_active():
    """Return True when torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def should_backup_checkpoint(args):
    """Build a predicate deciding whether an epoch's checkpoint is backed up.

    Args:
        args: Object exposing a ``gather_checkpoints`` attribute.

    Returns:
        A function ``f(epoch)`` that is truthy when ``args.gather_checkpoints``
        is truthy and ``epoch`` is below 10 or a multiple of 10.
    """

    def _sbc(epoch):
        return args.gather_checkpoints and (epoch < 10 or epoch % 10 == 0)

    return _sbc


def save_checkpoint(
    state,
    is_best,
    filename="checkpoint.pth.tar",
    checkpoint_dir="./",
    backup_filename=None,
):
    """Serialize ``state`` to disk, optionally copying it as best/backup.

    Only rank 0 writes when a distributed process group is initialized;
    every other rank returns without touching the filesystem.

    Args:
        state: Object passed to ``torch.save``.
        is_best: When truthy, the saved file is also copied to
            ``model_best.pth.tar`` inside ``checkpoint_dir``.
        filename: File name of the checkpoint inside ``checkpoint_dir``.
        checkpoint_dir: Directory that receives all written files.
        backup_filename: Optional extra file name; when not ``None`` the
            checkpoint is also copied there (inside ``checkpoint_dir``).
    """
    if _distributed_active() and dist.get_rank() != 0:
        return

    filename = os.path.join(checkpoint_dir, filename)
    print("SAVING {}".format(filename))
    torch.save(state, filename)

    if is_best:
        shutil.copyfile(filename, os.path.join(checkpoint_dir, "model_best.pth.tar"))
    if backup_filename is not None:
        shutil.copyfile(filename, os.path.join(checkpoint_dir, backup_filename))


def timed_generator(gen):
    """Yield ``(item, seconds)`` pairs for every item produced by ``gen``.

    ``seconds`` is the time spent waiting for ``gen`` to produce the item;
    time spent by the consumer between iterations is not counted because the
    clock restarts after each ``yield`` resumes.
    """
    # perf_counter is monotonic, so intervals cannot go negative when the
    # wall clock is adjusted.
    start = time.perf_counter()
    for item in gen:
        elapsed = time.perf_counter() - start
        yield item, elapsed
        start = time.perf_counter()


def timed_function(f):
    """Wrap ``f`` so that calling it returns ``(result, elapsed_seconds)``."""

    @functools.wraps(f)
    def _timed_function(*args, **kwargs):
        start = time.perf_counter()
        ret = f(*args, **kwargs)
        return ret, time.perf_counter() - start

    return _timed_function


def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    Args:
        output: Score tensor; the top-k is taken along dimension 1, so this
            is expected to be ``(batch, classes)``.
        target: Tensor of true class indices whose first dimension is the
            batch size.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k`` holding the
        percentage (0-100) of samples whose target is among the top-k scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed (non-contiguous) layout of
        # ``pred``, so ``view(-1)`` raises for k > 1; ``reshape`` copies
        # only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def reduce_tensor(tensor):
    """Return a copy of ``tensor`` averaged across all distributed ranks.

    When no process group is initialized there is nothing to reduce, and an
    unmodified clone of ``tensor`` is returned instead of raising.
    """
    rt = tensor.clone()
    if not _distributed_active():
        return rt
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def first_n(n, generator):
    """Yield at most the first ``n`` items of ``generator``."""
    # ``range`` is listed first in ``zip`` so that no extra item is pulled
    # from ``generator`` once ``n`` items have been produced.
    for _, item in zip(range(n), generator):
        yield item