From f548c8184614c48857fb5bed53f1e5072ab54d28 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sun, 13 Mar 2022 17:45:02 +0800 Subject: [PATCH] [Enhancement] Handle tensor device type in sync_random_seed (#126) --- mmengine/dist/dist.py | 10 ++++++++-- tests/test_dist/test_dist.py | 3 +-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py index 6569d901..2ca21ca6 100644 --- a/mmengine/dist/dist.py +++ b/mmengine/dist/dist.py @@ -305,10 +305,16 @@ def sync_random_seed(group: Optional[dist.ProcessGroup] = None) -> int: if group is None: group = get_default_group() + group_backend = get_backend(group) + is_nccl_backend = group_backend == dist.Backend.NCCL + current_device = torch.device('cpu') + if is_nccl_backend: + current_device = torch.device('cuda', torch.cuda.current_device()) + if get_rank(group) == 0: - random_num = torch.tensor(seed, dtype=torch.int32) + random_num = torch.tensor(seed, dtype=torch.int32).to(current_device) else: - random_num = torch.tensor(0, dtype=torch.int32) + random_num = torch.tensor(0, dtype=torch.int32).to(current_device) dist.broadcast(random_num, src=0, group=group) diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py index 78a55c54..3dccb075 100644 --- a/tests/test_dist/test_dist.py +++ b/tests/test_dist/test_dist.py @@ -190,8 +190,7 @@ def _test_broadcast_dist(device): def _test_sync_random_seed_dist(device): with patch.object( - torch, 'tensor', - return_value=torch.tensor(1024).to(device)) as mock_tensor: + torch, 'tensor', return_value=torch.tensor(1024)) as mock_tensor: output = dist.sync_random_seed() assert output == 1024 mock_tensor.assert_called() -- GitLab