Unverified commit a1adbff1, authored by Zaida Zhou, committed by GitHub

[Enhancement] Get local_rank in init_dist_mpi from env (#212)

parent 0c59eeab
@@ -82,10 +82,15 @@ def _init_dist_mpi(backend, **kwargs) -> None:
             'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
         **kwargs: keyword arguments are passed to ``init_process_group``.
     """
-    # TODO: use local_rank instead of rank % num_gpus
-    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
+    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+    torch.cuda.set_device(local_rank)
+    if 'MASTER_PORT' not in os.environ:
+        # 29500 is torch.distributed default port
+        os.environ['MASTER_PORT'] = '29500'
+    if 'MASTER_ADDR' not in os.environ:
+        raise KeyError('The environment variable MASTER_ADDR is not set')
+    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
+    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
     torch_dist.init_process_group(backend=backend, **kwargs)
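
The substance of the change: rank % num_gpus only assigns GPUs correctly when consecutive global ranks are packed onto each node, whereas OMPI_COMM_WORLD_LOCAL_RANK is the node-local index Open MPI computes for every process it launches. A minimal illustration (not part of the commit) of the failure mode, assuming two GPUs per node and round-robin rank placement such as mpirun --map-by node:

# Illustrative only: GPU index chosen by the old and new schemes for
# the two processes that land on node 0 under round-robin placement.
GPUS_PER_NODE = 2
ranks_on_node_0 = [0, 2]          # global ranks placed on node 0

old_scheme = [r % GPUS_PER_NODE for r in ranks_on_node_0]
new_scheme = [0, 1]               # OMPI_COMM_WORLD_LOCAL_RANK of each process

print(old_scheme)  # [0, 0] -> both processes bind to GPU 0
print(new_scheme)  # [0, 1] -> one process per GPU

The hunk also makes the MPI path self-sufficient for torch.distributed's env:// rendezvous: WORLD_SIZE, RANK and a default MASTER_PORT are filled in from the Open MPI environment, and a missing MASTER_ADDR now fails fast with an explicit KeyError rather than relying on the error raised inside init_process_group.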
@@ -99,8 +104,6 @@ def _init_dist_slurm(backend, port=None) -> None:
     Args:
         backend (str): Backend of torch.distributed.
         port (int, optional): Master port. Defaults to None.
-
-    TODO: https://github.com/open-mmlab/mmcv/pull/1682
     """
     proc_id = int(os.environ['SLURM_PROCID'])
     ntasks = int(os.environ['SLURM_NTASKS'])
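
For comparison, the Slurm path (truncated in this diff) derives the same quantities from Slurm's environment. A rough sketch, assuming the standard Slurm variables SLURM_PROCID, SLURM_NTASKS and SLURM_LOCALID; illustrative, not the commit's exact code:

import os

import torch

proc_id = int(os.environ['SLURM_PROCID'])   # global rank of this task
ntasks = int(os.environ['SLURM_NTASKS'])    # world size
# SLURM_LOCALID is a node-local task index, analogous to
# OMPI_COMM_WORLD_LOCAL_RANK; fall back to the modulo heuristic if unset.
num_gpus = torch.cuda.device_count()
local_rank = int(os.environ.get('SLURM_LOCALID', proc_id % num_gpus))
torch.cuda.set_device(local_rank)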