Skip to content
Snippets Groups Projects
Unverified Commit f548c818 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Enhancement] Handle tensor device type in sync_random_seed (#126)

parent 6d73b6cd
No related branches found
No related tags found
No related merge requests found
...@@ -305,10 +305,16 @@ def sync_random_seed(group: Optional[dist.ProcessGroup] = None) -> int: ...@@ -305,10 +305,16 @@ def sync_random_seed(group: Optional[dist.ProcessGroup] = None) -> int:
if group is None: if group is None:
group = get_default_group() group = get_default_group()
group_backend = get_backend(group)
is_nccl_backend = group_backend == dist.Backend.NCCL
current_device = torch.device('cpu')
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
if get_rank(group) == 0: if get_rank(group) == 0:
random_num = torch.tensor(seed, dtype=torch.int32) random_num = torch.tensor(seed, dtype=torch.int32).to(current_device)
else: else:
random_num = torch.tensor(0, dtype=torch.int32) random_num = torch.tensor(0, dtype=torch.int32).to(current_device)
dist.broadcast(random_num, src=0, group=group) dist.broadcast(random_num, src=0, group=group)
......
...@@ -190,8 +190,7 @@ def _test_broadcast_dist(device): ...@@ -190,8 +190,7 @@ def _test_broadcast_dist(device):
def _test_sync_random_seed_dist(device): def _test_sync_random_seed_dist(device):
with patch.object( with patch.object(
torch, 'tensor', torch, 'tensor', return_value=torch.tensor(1024)) as mock_tensor:
return_value=torch.tensor(1024).to(device)) as mock_tensor:
output = dist.sync_random_seed() output = dist.sync_random_seed()
assert output == 1024 assert output == 1024
mock_tensor.assert_called() mock_tensor.assert_called()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment