12 lines
240 B
Python
12 lines
240 B
Python
import torch
|
|
import torch.distributed as dist
|
|
|
|
def get_rank():
    """Return the rank reported by ``torch.distributed`` for this process.

    Returns:
        ``dist.get_rank()`` when the distributed package is available and
        initialized; otherwise ``0``.
    """
    # `and` short-circuits, so is_initialized() is only consulted when
    # is_available() is true -- the same check order as two separate guards.
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0
|
|
|
|
def is_main_process():
    """Return True when ``get_rank()`` reports rank 0, False otherwise."""
    current_rank = get_rank()
    return current_rank == 0