From f548c8184614c48857fb5bed53f1e5072ab54d28 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Sun, 13 Mar 2022 17:45:02 +0800
Subject: [PATCH] [Enhancement] Handle tensor device type in sync_random_seed
 (#126)

---
 mmengine/dist/dist.py        | 10 ++++++++--
 tests/test_dist/test_dist.py |  3 +--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py
index 6569d901..2ca21ca6 100644
--- a/mmengine/dist/dist.py
+++ b/mmengine/dist/dist.py
@@ -305,10 +305,16 @@ def sync_random_seed(group: Optional[dist.ProcessGroup] = None) -> int:
     if group is None:
         group = get_default_group()
 
+    group_backend = get_backend(group)
+    is_nccl_backend = group_backend == dist.Backend.NCCL
+    current_device = torch.device('cpu')
+    if is_nccl_backend:
+        current_device = torch.device('cuda', torch.cuda.current_device())
+
     if get_rank(group) == 0:
-        random_num = torch.tensor(seed, dtype=torch.int32)
+        random_num = torch.tensor(seed, dtype=torch.int32).to(current_device)
     else:
-        random_num = torch.tensor(0, dtype=torch.int32)
+        random_num = torch.tensor(0, dtype=torch.int32).to(current_device)
 
     dist.broadcast(random_num, src=0, group=group)
 
diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index 78a55c54..3dccb075 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -190,8 +190,7 @@ def _test_broadcast_dist(device):
 
 def _test_sync_random_seed_dist(device):
     with patch.object(
-            torch, 'tensor',
-            return_value=torch.tensor(1024).to(device)) as mock_tensor:
+            torch, 'tensor', return_value=torch.tensor(1024)) as mock_tensor:
         output = dist.sync_random_seed()
         assert output == 1024
     mock_tensor.assert_called()
-- 
GitLab