Fix rng_state restoration for backward compatibility with older checkpoints
This commit is contained in:
parent
4a64c5b03b
commit
9a6c5241d7
|
@ -250,7 +250,12 @@ def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_ra
|
|||
epoch[0] = checkpoint['epoch']+1
|
||||
device_id = local_rank % torch.cuda.device_count()
|
||||
torch.cuda.set_rng_state(checkpoint['cuda_rng_state_all'][device_id])
|
||||
torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
|
||||
if 'random_rng_states_all' in checkpoint:
|
||||
torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
|
||||
elif 'random_rng_state' in checkpoint:
|
||||
torch.random.set_rng_state(checkpoint['random_rng_state'])
|
||||
else:
|
||||
raise Exception("Model checkpoint must have either 'random_rng_state' or 'random_rng_states_all' key.")
|
||||
config = checkpoint['config']
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
|
|
Loading…
Reference in a new issue