Skip to content
Snippets Groups Projects
Unverified Commit 2b8a32ec authored by Haian Huang(深度眸)'s avatar Haian Huang(深度眸) Committed by GitHub
Browse files

[Fix]: fix RuntimeError of SyncBuffersHook (#309)

* fix RuntimeError of SyncBuffersHook

* add UT
parent e18832f0
No related branches found
No related tags found
No related merge requests found
......@@ -94,7 +94,11 @@ def all_reduce(data: Tensor,
# it with 'sum' operation.
if op.lower() == 'mean':
torch_dist.all_reduce(data_on_device, _get_reduce_op('sum'), group)
data_on_device.div_(world_size) # type: ignore
# When the type of `data_on_device` is int64,
# `data_on_device.div_(world_size)` will appear RuntimeError:
# result type Float can't be cast to the desired output type Long.
data_on_device = data_on_device / world_size # type: ignore
else:
torch_dist.all_reduce(data_on_device, _get_reduce_op(op), group)
......
......@@ -132,8 +132,9 @@ class TestDistWithGLOOBackend(MultiProcessTestCase):
def test_all_reduce(self):
self._init_dist_env(self.rank, self.world_size)
for tensor_type, reduce_op in zip([torch.int64, torch.float32],
['sum', 'mean']):
tensor_types = [torch.int64, torch.float32, torch.int64]
reduce_ops = ['sum', 'mean', 'mean']
for tensor_type, reduce_op in zip(tensor_types, reduce_ops):
if dist.get_rank() == 0:
data = torch.tensor([1, 2], dtype=tensor_type)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment