diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py
index 179d003836a7a9c7a03f653106a66ca5452cddf5..17f707d3bb8cdc099b876082cb14bba6afecc625 100644
--- a/mmengine/dataset/base_dataset.py
+++ b/mmengine/dataset/base_dataset.py
@@ -229,7 +229,7 @@ class BaseDataset(Dataset):
         self.test_mode = test_mode
         self.max_refetch = max_refetch
         self.data_list: List[dict] = []
-        self.date_bytes: np.ndarray
+        self.data_bytes: np.ndarray
 
         # Set meta information.
         self._metainfo = self._get_meta_info(copy.deepcopy(metainfo))
@@ -259,7 +259,7 @@ class BaseDataset(Dataset):
             start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
             end_addr = self.data_address[idx].item()
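+            # `data_address` stores cumulative end offsets, so this slice is
+            # the pickled bytes of the data info at `idx`.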
             bytes = memoryview(
-                self.date_bytes[start_addr:end_addr])  # type: ignore
+                self.data_bytes[start_addr:end_addr])  # type: ignore
             data_info = pickle.loads(bytes)  # type: ignore
         else:
             data_info = self.data_list[idx]
@@ -302,7 +302,7 @@ class BaseDataset(Dataset):
 
         # serialize data_list
         if self.serialize_data:
-            self.date_bytes, self.data_address = self._serialize_data()
+            self.data_bytes, self.data_address = self._serialize_data()
 
         self._fully_initialized = True
 
@@ -575,7 +575,7 @@ class BaseDataset(Dataset):
         # Get subset of data from serialized data or data information sequence
         # according to `self.serialize_data`.
         if self.serialize_data:
-            self.date_bytes, self.data_address = \
+            self.data_bytes, self.data_address = \
                 self._get_serialized_subset(indices)
         else:
             self.data_list = self._get_unserialized_subset(indices)
@@ -626,9 +626,9 @@ class BaseDataset(Dataset):
         sub_dataset = self._copy_without_annotation()
         # Get subset of dataset with serialize and unserialized data.
         if self.serialize_data:
-            date_bytes, data_address = \
+            data_bytes, data_address = \
                 self._get_serialized_subset(indices)
-            sub_dataset.date_bytes = date_bytes.copy()
+            sub_dataset.data_bytes = data_bytes.copy()
             sub_dataset.data_address = data_address.copy()
         else:
             data_list = self._get_unserialized_subset(indices)
@@ -650,7 +650,7 @@ class BaseDataset(Dataset):
             Tuple[np.ndarray, np.ndarray]: subset of serialized data
             information.
         """
-        sub_date_bytes: Union[List, np.ndarray]
+        sub_data_bytes: Union[List, np.ndarray]
         sub_data_address: Union[List, np.ndarray]
         if isinstance(indices, int):
             if indices >= 0:
@@ -661,7 +661,7 @@ class BaseDataset(Dataset):
                     if indices > 0 else 0
                 # Slicing operation of `np.ndarray` does not trigger a memory
                 # copy.
-                sub_date_bytes = self.date_bytes[:end_addr]
+                sub_data_bytes = self.data_bytes[:end_addr]
                 # Since the buffer size of first few data information is not
                 # changed,
                 sub_data_address = self.data_address[:indices]
@@ -671,11 +671,11 @@ class BaseDataset(Dataset):
                 # Return the last few data information.
                 ignored_bytes_size = self.data_address[indices - 1]
                 start_addr = self.data_address[indices - 1].item()
-                sub_date_bytes = self.date_bytes[start_addr:]
+                sub_data_bytes = self.data_bytes[start_addr:]
                 sub_data_address = self.data_address[indices:]
                 sub_data_address = sub_data_address - ignored_bytes_size
         elif isinstance(indices, Sequence):
-            sub_date_bytes = []
+            sub_data_bytes = []
             sub_data_address = []
             for idx in indices:
                 assert len(self) > idx >= -len(self)
@@ -683,20 +683,20 @@ class BaseDataset(Dataset):
                     self.data_address[idx - 1].item()
                 end_addr = self.data_address[idx].item()
                 # Get data information by address.
-                sub_date_bytes.append(self.date_bytes[start_addr:end_addr])
+                sub_data_bytes.append(self.data_bytes[start_addr:end_addr])
                 # Get data information size.
                 sub_data_address.append(end_addr - start_addr)
             # Handle indices is an empty list.
-            if sub_date_bytes:
-                sub_date_bytes = np.concatenate(sub_date_bytes)
+            if sub_data_bytes:
+                sub_data_bytes = np.concatenate(sub_data_bytes)
                 sub_data_address = np.cumsum(sub_data_address)
             else:
-                sub_date_bytes = np.array([])
+                sub_data_bytes = np.array([])
                 sub_data_address = np.array([])
         else:
             raise TypeError('indices should be a int or sequence of int, '
                             f'but got {type(indices)}')
-        return sub_date_bytes, sub_data_address  # type: ignore
+        return sub_data_bytes, sub_data_address  # type: ignore
 
     def _get_unserialized_subset(self, indices: Union[Sequence[int],
                                                       int]) -> list:
@@ -795,7 +795,7 @@ class BaseDataset(Dataset):
 
     def _copy_without_annotation(self, memo=dict()) -> 'BaseDataset':
         """Deepcopy for all attributes other than ``data_list``,
-        ``data_address`` and ``date_bytes``.
+        ``data_address`` and ``data_bytes``.
 
         Args:
             memo: Memory dict which used to reconstruct complex object
@@ -806,7 +806,7 @@ class BaseDataset(Dataset):
         memo[id(self)] = other
 
         for key, value in self.__dict__.items():
-            if key in ['data_list', 'data_address', 'date_bytes']:
+            if key in ['data_list', 'data_address', 'data_bytes']:
                 continue
             super(BaseDataset, other).__setattr__(key,
                                                   copy.deepcopy(value, memo))
diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index 14d3dec40f8eb0e93deeae4a952e3794e541f022..6d3cb23a7cf4bb7d8b987660af920d8f1963c405 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -2,382 +2,460 @@
 import os
 import os.path as osp
 import tempfile
+import unittest
+from unittest import TestCase
 from unittest.mock import patch
 
-import pytest
 import torch
 import torch.distributed as torch_dist
-import torch.multiprocessing as mp
 
 import mmengine.dist as dist
 from mmengine.dist.dist import sync_random_seed
+from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils import TORCH_VERSION, digit_version
 
 
-def _test_all_reduce_non_dist():
-    data = torch.arange(2, dtype=torch.int64)
-    expected = torch.arange(2, dtype=torch.int64)
-    dist.all_reduce(data)
-    assert torch.allclose(data, expected)
+class TestDist(TestCase):
+    """Test dist module in non-distributed environment."""
 
+    def test_all_reduce(self):
+        data = torch.arange(2, dtype=torch.int64)
+        expected = torch.arange(2, dtype=torch.int64)
+        dist.all_reduce(data)
+        self.assertTrue(torch.allclose(data, expected))
 
-def _test_all_gather_non_dist():
-    data = torch.arange(2, dtype=torch.int64)
-    expected = torch.arange(2, dtype=torch.int64)
-    output = dist.all_gather(data)
-    assert torch.allclose(output[0], expected)
+    def test_all_gather(self):
+        data = torch.arange(2, dtype=torch.int64)
+        expected = torch.arange(2, dtype=torch.int64)
+        output = dist.all_gather(data)
+        self.assertTrue(torch.allclose(output[0], expected))
 
+    def test_gather(self):
+        data = torch.arange(2, dtype=torch.int64)
+        expected = torch.arange(2, dtype=torch.int64)
+        output = dist.gather(data)
+        self.assertTrue(torch.allclose(output[0], expected))
 
-def _test_gather_non_dist():
-    data = torch.arange(2, dtype=torch.int64)
-    expected = torch.arange(2, dtype=torch.int64)
-    output = dist.gather(data)
-    assert torch.allclose(output[0], expected)
+    def test_broadcast(self):
+        data = torch.arange(2, dtype=torch.int64)
+        expected = torch.arange(2, dtype=torch.int64)
+        dist.broadcast(data)
+        self.assertTrue(torch.allclose(data, expected))
 
+    @patch('numpy.random.randint', return_value=10)
+    def test_sync_random_seed(self, mock):
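+        # without a process group the locally generated (mocked) seed is
+        # returned directly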
+        self.assertEqual(sync_random_seed(), 10)
 
-def _test_broadcast_non_dist():
-    data = torch.arange(2, dtype=torch.int64)
-    expected = torch.arange(2, dtype=torch.int64)
-    dist.broadcast(data)
-    assert torch.allclose(data, expected)
-
-
-@patch('numpy.random.randint', return_value=10)
-def _test_sync_random_seed_no_dist(mock):
-    assert sync_random_seed() == 10
-
-
-def _test_broadcast_object_list_no_dist():
-    with pytest.raises(AssertionError):
-        # input should be list of object
-        dist.broadcast_object_list('foo')
-
-    data = ['foo', 12, {1: 2}]
-    expected = ['foo', 12, {1: 2}]
-    dist.broadcast_object_list(data)
-    assert data == expected
-
-
-def _test_all_reduce_dict_no_dist():
-    with pytest.raises(AssertionError):
-        # input should be dict
-        dist.all_reduce_dict('foo')
-
-    data = {
-        'key1': torch.arange(2, dtype=torch.int64),
-        'key2': torch.arange(3, dtype=torch.int64)
-    }
-    expected = {
-        'key1': torch.arange(2, dtype=torch.int64),
-        'key2': torch.arange(3, dtype=torch.int64)
-    }
-    dist.all_reduce_dict(data)
-    for key in data:
-        assert torch.allclose(data[key], expected[key])
-
-
-def _test_all_gather_object_no_dist():
-    data = 'foo'
-    expected = 'foo'
-    gather_objects = dist.all_gather_object(data)
-    assert gather_objects[0] == expected
-
-
-def _test_gather_object_no_dist():
-    data = 'foo'
-    expected = 'foo'
-    gather_objects = dist.gather_object(data)
-    assert gather_objects[0] == expected
-
-
-def _test_collect_results_non_dist():
-    data = ['foo', {1: 2}]
-    size = 2
-    expected = ['foo', {1: 2}]
-
-    # test `device=cpu`
-    output = dist.collect_results(data, size, device='cpu')
-    assert output == expected
-
-    # test `device=gpu`
-    output = dist.collect_results(data, size, device='cpu')
-    assert output == expected
+    def test_broadcast_object_list(self):
+        with self.assertRaises(AssertionError):
+            # input should be a list of objects
+            dist.broadcast_object_list('foo')
 
+        data = ['foo', 12, {1: 2}]
+        expected = ['foo', 12, {1: 2}]
+        dist.broadcast_object_list(data)
+        self.assertEqual(data, expected)
+
+    def test_all_reduce_dict(self):
+        with self.assertRaises(AssertionError):
+            # input should be a dict
+            dist.all_reduce_dict('foo')
+
+        data = {
+            'key1': torch.arange(2, dtype=torch.int64),
+            'key2': torch.arange(3, dtype=torch.int64)
+        }
+        expected = {
+            'key1': torch.arange(2, dtype=torch.int64),
+            'key2': torch.arange(3, dtype=torch.int64)
+        }
+        dist.all_reduce_dict(data)
+        for key in data:
+            self.assertTrue(torch.allclose(data[key], expected[key]))
 
-def init_process(rank, world_size, functions, backend='gloo'):
-    """Initialize the distributed environment."""
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = '29505'
-    os.environ['RANK'] = str(rank)
+    def test_all_gather_object(self):
+        data = 'foo'
+        expected = 'foo'
+        gather_objects = dist.all_gather_object(data)
+        self.assertEqual(gather_objects[0], expected)
 
-    if backend == 'nccl':
-        num_gpus = torch.cuda.device_count()
-        torch.cuda.set_device(rank % num_gpus)
-        device = 'cuda'
-    else:
-        device = 'cpu'
+    def test_gather_object(self):
+        data = 'foo'
+        expected = 'foo'
+        gather_objects = dist.gather_object(data)
+        self.assertEqual(gather_objects[0], expected)
 
-    torch_dist.init_process_group(
-        backend=backend, rank=rank, world_size=world_size)
+    def test_collect_results(self):
+        data = ['foo', {1: 2}]
+        size = 2
+        expected = ['foo', {1: 2}]
+
+        # test `device=cpu`
+        output = dist.collect_results(data, size, device='cpu')
+        self.assertEqual(output, expected)
+
+        # test `device=gpu`
+        output = dist.collect_results(data, size, device='gpu')
+        self.assertEqual(output, expected)
+
+
+class TestDistWithGLOOBackend(MultiProcessTestCase):
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29505'
+        os.environ['RANK'] = str(rank)
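+        # the default env:// rendezvous reads MASTER_ADDR and MASTER_PORT,
+        # which must be identical on every rank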
+        torch_dist.init_process_group(
+            backend='gloo', rank=rank, world_size=world_size)
+
+    def setUp(self):
+        super().setUp()
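+        # spawn one worker process per rank; each test method below then
+        # runs independently in every worker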
+        self._spawn_processes()
+
+    def test_all_reduce(self):
+        self._init_dist_env(self.rank, self.world_size)
+        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
+                                          ['sum', 'mean']):
+            if dist.get_rank() == 0:
+                data = torch.tensor([1, 2], dtype=tensor_type)
+            else:
+                data = torch.tensor([3, 4], dtype=tensor_type)
+
+            if reduce_op == 'sum':
+                expected = torch.tensor([4, 6], dtype=tensor_type)
+            else:
+                expected = torch.tensor([2, 3], dtype=tensor_type)
+
+            dist.all_reduce(data, reduce_op)
+            self.assertTrue(torch.allclose(data, expected))
+
+    def test_all_gather(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            data = torch.tensor([0, 1])
+        else:
+            data = torch.tensor([1, 2])
 
-    for func in functions:
-        func(device)
+        expected = [torch.tensor([0, 1]), torch.tensor([1, 2])]
 
+        output = dist.all_gather(data)
+        self.assertTrue(
+            torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))
 
-def main(functions, world_size=2, backend='gloo'):
-    try:
-        mp.spawn(
-            init_process,
-            args=(world_size, functions, backend),
-            nprocs=world_size)
-    except Exception:
-        pytest.fail(f'{backend} failed')
+    def test_gather(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            data = torch.tensor([0, 1])
+        else:
+            data = torch.tensor([1, 2])
 
+        output = dist.gather(data)
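+        # data is gathered only on the dst rank (0 by default); all other
+        # ranks receive an empty list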
 
-def _test_all_reduce_dist(device):
-    for tensor_type, reduce_op in zip([torch.int64, torch.float32],
-                                      ['sum', 'mean']):
         if dist.get_rank() == 0:
-            data = torch.tensor([1, 2], dtype=tensor_type).to(device)
+            expected = [torch.tensor([0, 1]), torch.tensor([1, 2])]
+            for i in range(2):
+                self.assertTrue(torch.allclose(output[i], expected[i]))
         else:
-            data = torch.tensor([3, 4], dtype=tensor_type).to(device)
+            self.assertEqual(output, [])
 
-        if reduce_op == 'sum':
-            expected = torch.tensor([4, 6], dtype=tensor_type).to(device)
+    def test_broadcast_dist(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            data = torch.tensor([0, 1])
         else:
-            expected = torch.tensor([2, 3], dtype=tensor_type).to(device)
+            data = torch.tensor([1, 2])
 
-        dist.all_reduce(data, reduce_op)
+        expected = torch.tensor([0, 1])
+        dist.broadcast(data, 0)
         assert torch.allclose(data, expected)
 
+    def test_sync_random_seed(self):
+        self._init_dist_env(self.rank, self.world_size)
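+        # patch torch.tensor so every rank "draws" the same seed; the value
+        # broadcast from rank 0 should then be returned on all ranks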
+        with patch.object(
+                torch, 'tensor',
+                return_value=torch.tensor(1024)) as mock_tensor:
+            output = dist.sync_random_seed()
+            self.assertEqual(output, 1024)
+        mock_tensor.assert_called()
+
+    def test_broadcast_object_list(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            data = ['foo', 12, {1: 2}]
+        else:
+            data = [None, None, None]
+
+        expected = ['foo', 12, {1: 2}]
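+        # the broadcast overwrites the placeholders on non-source ranks with
+        # the objects from rank 0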
+        dist.broadcast_object_list(data)
+        self.assertEqual(data, expected)
+
+    def test_all_reduce_dict(self):
+        self._init_dist_env(self.rank, self.world_size)
+        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
+                                          ['sum', 'mean']):
+            if dist.get_rank() == 0:
+                data = {
+                    'key1': torch.tensor([0, 1], dtype=tensor_type),
+                    'key2': torch.tensor([1, 2], dtype=tensor_type),
+                }
+            else:
+                data = {
+                    'key1': torch.tensor([2, 3], dtype=tensor_type),
+                    'key2': torch.tensor([3, 4], dtype=tensor_type),
+                }
+
+            if reduce_op == 'sum':
+                expected = {
+                    'key1': torch.tensor([2, 4], dtype=tensor_type),
+                    'key2': torch.tensor([4, 6], dtype=tensor_type),
+                }
+            else:
+                expected = {
+                    'key1': torch.tensor([1, 2], dtype=tensor_type),
+                    'key2': torch.tensor([2, 3], dtype=tensor_type),
+                }
+
+            dist.all_reduce_dict(data, reduce_op)
+
+            for key in data:
+                self.assertTrue(torch.allclose(data[key], expected[key]))
+
+        # `torch.cat` in torch 1.5 cannot concatenate tensors of different
+        # dtypes, so all values are converted to float before the reduction.
+        if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
+            if dist.get_rank() == 0:
+                data = {
+                    'key1': torch.tensor([0, 1], dtype=torch.float32),
+                    'key2': torch.tensor([1, 2], dtype=torch.int32)
+                }
+            else:
+                data = {
+                    'key1': torch.tensor([2, 3], dtype=torch.float32),
+                    'key2': torch.tensor([3, 4], dtype=torch.int32),
+                }
 
-def _test_all_gather_dist(device):
-    if dist.get_rank() == 0:
-        data = torch.tensor([0, 1]).to(device)
-    else:
-        data = torch.tensor([1, 2]).to(device)
-
-    expected = [
-        torch.tensor([0, 1]).to(device),
-        torch.tensor([1, 2]).to(device)
-    ]
+            expected = {
+                'key1': torch.tensor([2, 4], dtype=torch.float32),
+                'key2': torch.tensor([4, 6], dtype=torch.float32),
+            }
 
-    output = dist.all_gather(data)
-    assert torch.allclose(output[dist.get_rank()], expected[dist.get_rank()])
+            dist.all_reduce_dict(data, 'sum')
 
+            for key in data:
+                self.assertTrue(torch.allclose(data[key], expected[key]))
 
-def _test_gather_dist(device):
-    if dist.get_rank() == 0:
-        data = torch.tensor([0, 1]).to(device)
-    else:
-        data = torch.tensor([1, 2]).to(device)
+    def test_all_gather_object(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            data = 'foo'
+        else:
+            data = {1: 2}
 
-    output = dist.gather(data)
+        expected = ['foo', {1: 2}]
+        output = dist.all_gather_object(data)
 
-    if dist.get_rank() == 0:
-        expected = [
-            torch.tensor([0, 1]).to(device),
-            torch.tensor([1, 2]).to(device)
-        ]
-        for i in range(2):
-            assert torch.allclose(output[i], expected[i])
-    else:
-        assert output == []
+        self.assertEqual(output, expected)
 
+    def test_gather_object(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            data = 'foo'
+        else:
+            data = {1: 2}
 
-def _test_broadcast_dist(device):
-    if dist.get_rank() == 0:
-        data = torch.tensor([0, 1]).to(device)
-    else:
-        data = torch.tensor([1, 2]).to(device)
+        output = dist.gather_object(data, dst=0)
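+        # only the dst rank receives the gathered list; other ranks get None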
 
-    expected = torch.tensor([0, 1]).to(device)
-    dist.broadcast(data, 0)
-    assert torch.allclose(data, expected)
+        if dist.get_rank() == 0:
+            self.assertEqual(output, ['foo', {1: 2}])
+        else:
+            self.assertIsNone(output)
 
 
-def _test_sync_random_seed_dist(device):
-    with patch.object(
-            torch, 'tensor', return_value=torch.tensor(1024)) as mock_tensor:
-        output = dist.sync_random_seed()
-        assert output == 1024
-    mock_tensor.assert_called()
+@unittest.skipIf(
+    torch.cuda.device_count() < 2, reason='need at least 2 GPUs to test nccl')
+class TestDistWithNCCLBackend(MultiProcessTestCase):
 
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29505'
+        os.environ['RANK'] = str(rank)
 
-def _test_broadcast_object_list_dist(device):
-    if dist.get_rank() == 0:
-        data = ['foo', 12, {1: 2}]
-    else:
-        data = [None, None, None]
+        num_gpus = torch.cuda.device_count()
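+        # pin each rank to its own GPU before creating the NCCL process group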
+        torch.cuda.set_device(rank % num_gpus)
+        torch_dist.init_process_group(
+            backend='nccl', rank=rank, world_size=world_size)
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
+
+    def test_all_reduce(self):
+        self._init_dist_env(self.rank, self.world_size)
+        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
+                                          ['sum', 'mean']):
+            if dist.get_rank() == 0:
+                data = torch.tensor([1, 2], dtype=tensor_type).cuda()
+            else:
+                data = torch.tensor([3, 4], dtype=tensor_type).cuda()
+
+            if reduce_op == 'sum':
+                expected = torch.tensor([4, 6], dtype=tensor_type).cuda()
+            else:
+                expected = torch.tensor([2, 3], dtype=tensor_type).cuda()
+
+            dist.all_reduce(data, reduce_op)
+            self.assertTrue(torch.allclose(data, expected))
+
+    def test_all_gather(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            data = torch.tensor([0, 1]).cuda()
+        else:
+            data = torch.tensor([1, 2]).cuda()
 
-    expected = ['foo', 12, {1: 2}]
+        expected = [torch.tensor([0, 1]).cuda(), torch.tensor([1, 2]).cuda()]
 
-    dist.broadcast_object_list(data)
+        output = dist.all_gather(data)
+        self.assertTrue(
+            torch.allclose(output[dist.get_rank()], expected[dist.get_rank()]))
 
-    assert data == expected
+    def test_broadcast_dist(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            data = torch.tensor([0, 1]).cuda()
+        else:
+            data = torch.tensor([1, 2]).cuda()
 
+        expected = torch.tensor([0, 1]).cuda()
+        dist.broadcast(data, 0)
+        self.assertTrue(torch.allclose(data, expected))
 
-def _test_all_reduce_dict_dist(device):
-    for tensor_type, reduce_op in zip([torch.int64, torch.float32],
-                                      ['sum', 'mean']):
+    def test_sync_random_seed(self):
+        self._init_dist_env(self.rank, self.world_size)
+        with patch.object(
+                torch, 'tensor',
+                return_value=torch.tensor(1024)) as mock_tensor:
+            output = dist.sync_random_seed()
+            self.assertEqual(output, 1024)
+        mock_tensor.assert_called()
+
+    def test_broadcast_object_list(self):
+        self._init_dist_env(self.rank, self.world_size)
         if dist.get_rank() == 0:
-            data = {
-                'key1': torch.tensor([0, 1], dtype=tensor_type).to(device),
-                'key2': torch.tensor([1, 2], dtype=tensor_type).to(device)
-            }
+            data = ['foo', 12, {1: 2}]
         else:
-            data = {
-                'key1': torch.tensor([2, 3], dtype=tensor_type).to(device),
-                'key2': torch.tensor([3, 4], dtype=tensor_type).to(device)
-            }
+            data = [None, None, None]
+
+        expected = ['foo', 12, {1: 2}]
+        dist.broadcast_object_list(data)
+        self.assertEqual(data, expected)
+
+    def test_all_reduce_dict(self):
+        self._init_dist_env(self.rank, self.world_size)
+        for tensor_type, reduce_op in zip([torch.int64, torch.float32],
+                                          ['sum', 'mean']):
+            if dist.get_rank() == 0:
+                data = {
+                    'key1': torch.tensor([0, 1], dtype=tensor_type).cuda(),
+                    'key2': torch.tensor([1, 2], dtype=tensor_type).cuda(),
+                }
+            else:
+                data = {
+                    'key1': torch.tensor([2, 3], dtype=tensor_type).cuda(),
+                    'key2': torch.tensor([3, 4], dtype=tensor_type).cuda(),
+                }
+
+            if reduce_op == 'sum':
+                expected = {
+                    'key1': torch.tensor([2, 4], dtype=tensor_type).cuda(),
+                    'key2': torch.tensor([4, 6], dtype=tensor_type).cuda(),
+                }
+            else:
+                expected = {
+                    'key1': torch.tensor([1, 2], dtype=tensor_type).cuda(),
+                    'key2': torch.tensor([2, 3], dtype=tensor_type).cuda(),
+                }
+
+            dist.all_reduce_dict(data, reduce_op)
+
+            for key in data:
+                self.assertTrue(torch.allclose(data[key], expected[key]))
+
+        # `torch.cat` in torch 1.5 cannot concatenate tensors of different
+        # dtypes, so all values are converted to float before the reduction.
+        if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
+            if dist.get_rank() == 0:
+                data = {
+                    'key1': torch.tensor([0, 1], dtype=torch.float32).cuda(),
+                    'key2': torch.tensor([1, 2], dtype=torch.int32).cuda(),
+                }
+            else:
+                data = {
+                    'key1': torch.tensor([2, 3], dtype=torch.float32).cuda(),
+                    'key2': torch.tensor([3, 4], dtype=torch.int32).cuda(),
+                }
 
-        if reduce_op == 'sum':
-            expected = {
-                'key1': torch.tensor([2, 4], dtype=tensor_type).to(device),
-                'key2': torch.tensor([4, 6], dtype=tensor_type).to(device)
-            }
-        else:
             expected = {
-                'key1': torch.tensor([1, 2], dtype=tensor_type).to(device),
-                'key2': torch.tensor([2, 3], dtype=tensor_type).to(device)
+                'key1': torch.tensor([2, 4], dtype=torch.float32).cuda(),
+                'key2': torch.tensor([4, 6], dtype=torch.float32).cuda(),
             }
 
-        dist.all_reduce_dict(data, reduce_op)
+            dist.all_reduce_dict(data, 'sum')
 
-        for key in data:
-            assert torch.allclose(data[key], expected[key])
+            for key in data:
+                self.assertTrue(torch.allclose(data[key], expected[key]))
 
-    # `torch.cat` in torch1.5 can not concatenate different types so we
-    # fallback to convert them all to float type.
-    if digit_version(TORCH_VERSION) == digit_version('1.5.0'):
+    def test_all_gather_object(self):
+        self._init_dist_env(self.rank, self.world_size)
         if dist.get_rank() == 0:
-            data = {
-                'key1': torch.tensor([0, 1], dtype=torch.float32).to(device),
-                'key2': torch.tensor([1, 2], dtype=torch.int32).to(device)
-            }
+            data = 'foo'
         else:
-            data = {
-                'key1': torch.tensor([2, 3], dtype=torch.float32).to(device),
-                'key2': torch.tensor([3, 4], dtype=torch.int32).to(device)
-            }
-
-        expected = {
-            'key1': torch.tensor([2, 4], dtype=torch.float32).to(device),
-            'key2': torch.tensor([4, 6], dtype=torch.float32).to(device)
-        }
-
-        dist.all_reduce_dict(data, 'sum')
-
-        for key in data:
-            assert torch.allclose(data[key], expected[key])
+            data = {1: 2}
 
+        expected = ['foo', {1: 2}]
+        output = dist.all_gather_object(data)
 
-def _test_all_gather_object_dist(device):
-    if dist.get_rank() == 0:
-        data = 'foo'
-    else:
-        data = {1: 2}
-
-    expected = ['foo', {1: 2}]
-    output = dist.all_gather_object(data)
-
-    assert output == expected
+        self.assertEqual(output, expected)
 
+    def test_collect_results(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            data = ['foo', {1: 2}]
+        else:
+            data = [24, {'a': 'b'}]
 
-def _test_gather_object_dist(device):
-    if dist.get_rank() == 0:
-        data = 'foo'
-    else:
-        data = {1: 2}
+        size = 4
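+        # rank 0 should receive all samples restored to dataset order, i.e.
+        # the per-rank chunks interleaved as in `expected` below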
 
-    output = dist.gather_object(data, dst=0)
+        expected = ['foo', 24, {1: 2}, {'a': 'b'}]
 
-    if dist.get_rank() == 0:
-        assert output == ['foo', {1: 2}]
-    else:
-        assert output is None
+        # test `device=cpu`
+        output = dist.collect_results(data, size, device='cpu')
+        if dist.get_rank() == 0:
+            self.assertEqual(output, expected)
+        else:
+            self.assertIsNone(output)
+
+        # test `device=cpu` and `tmpdir is not None`
+        tmpdir = tempfile.mkdtemp()
+        # broadcast tmpdir to all ranks to make it consistent
+        object_list = [tmpdir]
+        dist.broadcast_object_list(object_list)
+        output = dist.collect_results(
+            data, size, device='cpu', tmpdir=object_list[0])
+        if dist.get_rank() == 0:
+            self.assertEqual(output, expected)
+        else:
+            self.assertIsNone(output)
 
+        if dist.get_rank() == 0:
+            # object_list[0] will be removed by `dist.collect_results`
+            self.assertFalse(osp.exists(object_list[0]))
 
-def _test_collect_results_dist(device):
-    if dist.get_rank() == 0:
-        data = ['foo', {1: 2}]
-    else:
-        data = [24, {'a': 'b'}]
-
-    size = 4
-
-    expected = ['foo', 24, {1: 2}, {'a': 'b'}]
-
-    # test `device=cpu`
-    output = dist.collect_results(data, size, device='cpu')
-    if dist.get_rank() == 0:
-        assert output == expected
-    else:
-        assert output is None
-
-    # test `device=cpu` and `tmpdir is not None`
-    tmpdir = tempfile.mkdtemp()
-    # broadcast tmpdir to all ranks to make it consistent
-    object_list = [tmpdir]
-    dist.broadcast_object_list(object_list)
-    output = dist.collect_results(
-        data, size, device='cpu', tmpdir=object_list[0])
-    if dist.get_rank() == 0:
-        assert output == expected
-    else:
-        assert output is None
-
-    if dist.get_rank() == 0:
-        # object_list[0] will be removed by `dist.collect_results`
-        assert not osp.exists(object_list[0])
-
-    # test `device=gpu`
-    output = dist.collect_results(data, size, device='gpu')
-    if dist.get_rank() == 0:
-        assert output == expected
-    else:
-        assert output is None
-
-
-def test_non_distributed_env():
-    _test_all_reduce_non_dist()
-    _test_all_gather_non_dist()
-    _test_gather_non_dist()
-    _test_broadcast_non_dist()
-    _test_sync_random_seed_no_dist()
-    _test_broadcast_object_list_no_dist()
-    _test_all_reduce_dict_no_dist()
-    _test_all_gather_object_no_dist()
-    _test_gather_object_no_dist()
-    _test_collect_results_non_dist()
-
-
-def test_gloo_backend():
-    functions_to_test = [
-        _test_all_reduce_dist,
-        _test_all_gather_dist,
-        _test_gather_dist,
-        _test_broadcast_dist,
-        _test_sync_random_seed_dist,
-        _test_broadcast_object_list_dist,
-        _test_all_reduce_dict_dist,
-        _test_all_gather_object_dist,
-        _test_gather_object_dist,
-    ]
-    main(functions_to_test, backend='gloo')
-
-
-@pytest.mark.skipif(
-    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
-def test_nccl_backend():
-    functions_to_test = [
-        _test_all_reduce_dist,
-        _test_all_gather_dist,
-        _test_broadcast_dist,
-        _test_sync_random_seed_dist,
-        _test_broadcast_object_list_dist,
-        _test_all_reduce_dict_dist,
-        _test_all_gather_object_dist,
-        _test_collect_results_dist,
-    ]
-    main(functions_to_test, backend='nccl')
+        # test `device=gpu`
+        output = dist.collect_results(data, size, device='gpu')
+        if dist.get_rank() == 0:
+            self.assertEqual(output, expected)
+        else:
+            self.assertIsNone(output)
diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py
index b4b74d4526d325c9620793a3ff916575d8b6e963..124551989a1d160e21ae4ec71c79d0c844500b78 100644
--- a/tests/test_dist/test_utils.py
+++ b/tests/test_dist/test_utils.py
@@ -1,158 +1,177 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os
+import unittest
+from unittest import TestCase
 
-import pytest
 import torch
 import torch.distributed as torch_dist
-import torch.multiprocessing as mp
 
 import mmengine.dist as dist
+from mmengine.testing._internal import MultiProcessTestCase
 
 
-def _test_get_backend_non_dist():
-    assert dist.get_backend() is None
+class TestUtils(TestCase):
 
+    def test_get_backend(self):
+        self.assertIsNone(dist.get_backend())
 
-def _test_get_world_size_non_dist():
-    assert dist.get_world_size() == 1
+    def test_get_world_size(self):
+        self.assertEqual(dist.get_world_size(), 1)
 
+    def test_get_rank(self):
+        self.assertEqual(dist.get_rank(), 0)
 
-def _test_get_rank_non_dist():
-    assert dist.get_rank() == 0
+    def test_local_size(self):
+        self.assertEqual(dist.get_local_size(), 1)
 
+    def test_local_rank(self):
+        self.assertEqual(dist.get_local_rank(), 0)
 
-def _test_local_size_non_dist():
-    assert dist.get_local_size() == 1
+    def test_get_dist_info(self):
+        self.assertEqual(dist.get_dist_info(), (0, 1))
 
+    def test_is_main_process(self):
+        self.assertTrue(dist.is_main_process())
 
-def _test_local_rank_non_dist():
-    assert dist.get_local_rank() == 0
+    def test_master_only(self):
 
+        @dist.master_only
+        def fun():
+            assert dist.get_rank() == 0
 
-def _test_get_dist_info_non_dist():
-    assert dist.get_dist_info() == (0, 1)
+        fun()
 
+    def test_barrier(self):
+        dist.barrier()  # should be a no-op without a process group
 
-def _test_is_main_process_non_dist():
-    assert dist.is_main_process()
 
+class TestUtilsWithGLOOBackend(MultiProcessTestCase):
 
-def _test_master_only_non_dist():
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29505'
+        os.environ['RANK'] = str(rank)
 
-    @dist.master_only
-    def fun():
-        assert dist.get_rank() == 0
+        torch_dist.init_process_group(
+            backend='gloo', rank=rank, world_size=world_size)
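+        # both ranks run on the same node here, so the local group spans the
+        # whole world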
+        dist.init_local_group(0, world_size)
 
-    fun()
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
 
+    def test_get_backend(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_backend(), torch_dist.get_backend())
 
-def _test_barrier_non_dist():
-    dist.barrier()  # nothing is done
+    def test_get_world_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_world_size(), 2)
 
+    def test_get_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if torch_dist.get_rank() == 0:
+            self.assertEqual(dist.get_rank(), 0)
+        else:
+            self.assertEqual(dist.get_rank(), 1)
 
-def init_process(rank, world_size, functions, backend='gloo'):
-    """Initialize the distributed environment."""
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = '29501'
-    os.environ['RANK'] = str(rank)
+    def test_local_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_local_size(), 2)
 
-    if backend == 'nccl':
-        num_gpus = torch.cuda.device_count()
-        torch.cuda.set_device(rank % num_gpus)
-
-    torch_dist.init_process_group(
-        backend=backend, rank=rank, world_size=world_size)
-    dist.init_local_group(0, world_size)
-
-    for func in functions:
-        func()
-
-
-def main(functions, world_size=2, backend='gloo'):
-    try:
-        mp.spawn(
-            init_process,
-            args=(world_size, functions, backend),
-            nprocs=world_size)
-    except Exception:
-        pytest.fail('error')
-
-
-def _test_get_backend_dist():
-    assert dist.get_backend() == torch_dist.get_backend()
-
-
-def _test_get_world_size_dist():
-    assert dist.get_world_size() == 2
-
-
-def _test_get_rank_dist():
-    if torch_dist.get_rank() == 0:
-        assert dist.get_rank() == 0
-    else:
-        assert dist.get_rank() == 1
-
-
-def _test_local_size_dist():
-    assert dist.get_local_size() == 2
+    def test_local_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(
+            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())
 
+    def test_get_dist_info(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertEqual(dist.get_dist_info(), (0, 2))
+        else:
+            self.assertEqual(dist.get_dist_info(), (1, 2))
 
-def _test_local_rank_dist():
-    torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank()
+    def test_is_main_process(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertTrue(dist.is_main_process())
+        else:
+            self.assertFalse(dist.is_main_process())
 
+    def test_master_only(self):
+        self._init_dist_env(self.rank, self.world_size)
 
-def _test_get_dist_info_dist():
-    if dist.get_rank() == 0:
-        assert dist.get_dist_info() == (0, 2)
-    else:
-        assert dist.get_dist_info() == (1, 2)
+        @dist.master_only
+        def fun():
+            assert dist.get_rank() == 0
 
+        fun()
 
-def _test_is_main_process_dist():
-    if dist.get_rank() == 0:
-        assert dist.is_main_process()
-    else:
-        assert not dist.is_main_process()
 
+@unittest.skipIf(
+    torch.cuda.device_count() < 2, reason='need at least 2 GPUs to test nccl')
+class TestUtilsWithNCCLBackend(MultiProcessTestCase):
 
-def _test_master_only_dist():
-
-    @dist.master_only
-    def fun():
-        assert dist.get_rank() == 0
-
-    fun()
-
-
-def test_non_distributed_env():
-    _test_get_backend_non_dist()
-    _test_get_world_size_non_dist()
-    _test_get_rank_non_dist()
-    _test_local_size_non_dist()
-    _test_local_rank_non_dist()
-    _test_get_dist_info_non_dist()
-    _test_is_main_process_non_dist()
-    _test_master_only_non_dist()
-    _test_barrier_non_dist()
-
-
-functions_to_test = [
-    _test_get_backend_dist,
-    _test_get_world_size_dist,
-    _test_get_rank_dist,
-    _test_local_size_dist,
-    _test_local_rank_dist,
-    _test_get_dist_info_dist,
-    _test_is_main_process_dist,
-    _test_master_only_dist,
-]
-
-
-def test_gloo_backend():
-    main(functions_to_test)
-
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29505'
+        os.environ['RANK'] = str(rank)
 
-@pytest.mark.skipif(
-    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
-def test_nccl_backend():
-    main(functions_to_test, backend='nccl')
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        torch_dist.init_process_group(
+            backend='nccl', rank=rank, world_size=world_size)
+        dist.init_local_group(0, world_size)
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
+
+    def test_get_backend(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_backend(), torch_dist.get_backend())
+
+    def test_get_world_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_world_size(), 2)
+
+    def test_get_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if torch_dist.get_rank() == 0:
+            self.assertEqual(dist.get_rank(), 0)
+        else:
+            self.assertEqual(dist.get_rank(), 1)
+
+    def test_local_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_local_size(), 2)
+
+    def test_local_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(
+            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())
+
+    def test_get_dist_info(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertEqual(dist.get_dist_info(), (0, 2))
+        else:
+            self.assertEqual(dist.get_dist_info(), (1, 2))
+
+    def test_is_main_process(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertTrue(dist.is_main_process())
+        else:
+            self.assertFalse(dist.is_main_process())
+
+    def test_master_only(self):
+        self._init_dist_env(self.rank, self.world_size)
+
+        @dist.master_only
+        def fun():
+            assert dist.get_rank() == 0
+
+        fun()