Fix rng_state handling for backward compatibility

This commit is contained in:
gkarch 2020-11-19 14:21:05 +01:00
parent 4a64c5b03b
commit 9a6c5241d7

View file

@ -250,7 +250,12 @@ def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_ra
     epoch[0] = checkpoint['epoch']+1
     device_id = local_rank % torch.cuda.device_count()
     torch.cuda.set_rng_state(checkpoint['cuda_rng_state_all'][device_id])
-    torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
+    if 'random_rng_states_all' in checkpoint:
+        torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
+    elif 'random_rng_state' in checkpoint:
+        torch.random.set_rng_state(checkpoint['random_rng_state'])
+    else:
+        raise Exception("Model checkpoint must have either 'random_rng_state' or 'random_rng_states_all' key.")
     config = checkpoint['config']
     model.load_state_dict(checkpoint['state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer'])