diff --git a/mmengine/evaluator/base.py b/mmengine/evaluator/base.py
index dcf447bed1cdf74ff56d9d4480ecd9803e33633d..bc90634e51693d399f8a4c74a80c0628955a9cc6 100644
--- a/mmengine/evaluator/base.py
+++ b/mmengine/evaluator/base.py
@@ -1,17 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-import pickle
-import shutil
-import tempfile
 import warnings
 from abc import ABCMeta, abstractmethod
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
-import torch
-import torch.distributed as dist
-
 from mmengine.data import BaseDataSample
-from mmengine.utils import mkdir_or_exist
+from mmengine.dist import (broadcast_object_list, collect_results,
+                           is_main_process)
 
 
 class BaseEvaluator(metaclass=ABCMeta):
@@ -43,11 +37,6 @@ class BaseEvaluator(metaclass=ABCMeta):
         self._dataset_meta: Union[None, dict] = None
         self.collect_device = collect_device
         self.results: List[Any] = []
-
-        rank, world_size = get_dist_info()
-        self.rank = rank
-        self.world_size = world_size
-
         self.prefix = prefix or self.default_prefix
         if self.prefix is None:
             warnings.warn('The prefix is not set in evaluator class '
@@ -108,131 +97,22 @@ class BaseEvaluator(metaclass=ABCMeta):
                 'ensure that the processed results are properly added into '
                 '`self._results` in `process` method.')
 
-        if self.world_size == 1:
-            # non-distributed
-            results = self.results
-        else:
-            results = collect_results(self.results, size, self.collect_device)
+        results = collect_results(self.results, size, self.collect_device)
 
-        if self.rank == 0:
-            # TODO: replace with mmengine.dist.master_only
-            metrics = self.compute_metrics(results)
+        if is_main_process():
+            _metrics = self.compute_metrics(results)  # type: ignore
             # Add prefix to metric names
             if self.prefix:
-                metrics = {
+                _metrics = {
                     '/'.join((self.prefix, k)): v
-                    for k, v in metrics.items()
+                    for k, v in _metrics.items()
                 }
-            metrics = [metrics]  # type: ignore
+            metrics = [_metrics]
         else:
             metrics = [None]  # type: ignore
 
-        # TODO: replace with mmengine.dist.broadcast
-        if self.world_size > 1:
-            metrics = dist.broadcast_object_list(metrics)
+        broadcast_object_list(metrics)
 
         # reset the results list
         self.results.clear()
         return metrics[0]
-
-
-# TODO: replace with mmengine.dist.get_dist_info
-def get_dist_info():
-    if dist.is_available() and dist.is_initialized():
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    else:
-        rank = 0
-        world_size = 1
-    return rank, world_size
-
-
-# TODO: replace with mmengine.dist.collect_results
-def collect_results(results, size, device='cpu'):
-    """Collected results in distributed environments."""
-    # TODO: replace with mmengine.dist.collect_results
-    if device == 'gpu':
-        return collect_results_gpu(results, size)
-    elif device == 'cpu':
-        return collect_results_cpu(results, size)
-    else:
-        NotImplementedError(f"device must be 'cpu' or 'gpu', but got {device}")
-
-
-# TODO: replace with mmengine.dist.collect_results
-def collect_results_cpu(result_part, size, tmpdir=None):
-    rank, world_size = get_dist_info()
-    # create a tmp dir if it is not specified
-    if tmpdir is None:
-        MAX_LEN = 512
-        # 32 is whitespace
-        dir_tensor = torch.full((MAX_LEN, ),
-                                32,
-                                dtype=torch.uint8,
-                                device='cuda')
-        if rank == 0:
-            mkdir_or_exist('.dist_test')
-            tmpdir = tempfile.mkdtemp(dir='.dist_test')
-            tmpdir = torch.tensor(
-                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
-            dir_tensor[:len(tmpdir)] = tmpdir
-        dist.broadcast(dir_tensor, 0)
-        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
-    else:
-        mkdir_or_exist(tmpdir)
-    # dump the part result to the dir
-    with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
-        pickle.dump(result_part, f, protocol=2)
-    dist.barrier()
-    # collect all parts
-    if rank != 0:
-        return None
-    else:
-        # load results of all parts from tmp dir
-        part_list = []
-        for i in range(world_size):
-            with open(osp.join(tmpdir, f'part_{i}.pkl'), 'wb') as f:
-                part_list.append(pickle.load(f))
-        # sort the results
-        ordered_results = []
-        for res in zip(*part_list):
-            ordered_results.extend(list(res))
-        # the dataloader may pad some samples
-        ordered_results = ordered_results[:size]
-        # remove tmp dir
-        shutil.rmtree(tmpdir)
-        return ordered_results
-
-
-# TODO: replace with mmengine.dist.collect_results
-def collect_results_gpu(result_part, size):
-    rank, world_size = get_dist_info()
-    # dump result part to tensor with pickle
-    part_tensor = torch.tensor(
-        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
-    # gather all result part tensor shape
-    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
-    shape_list = [shape_tensor.clone() for _ in range(world_size)]
-    dist.all_gather(shape_list, shape_tensor)
-    # padding result part tensor to max length
-    shape_max = torch.tensor(shape_list).max()
-    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
-    part_send[:shape_tensor[0]] = part_tensor
-    part_recv_list = [
-        part_tensor.new_zeros(shape_max) for _ in range(world_size)
-    ]
-    # gather all result part
-    dist.all_gather(part_recv_list, part_send)
-
-    if rank == 0:
-        part_list = []
-        for recv, shape in zip(part_recv_list, shape_list):
-            part_list.append(
-                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
-        # sort the results
-        ordered_results = []
-        for res in zip(*part_list):
-            ordered_results.extend(list(res))
-        # the dataloader may pad some samples
-        ordered_results = ordered_results[:size]
-        return ordered_results
diff --git a/mmengine/evaluator/utils.py b/mmengine/evaluator/utils.py
index 17c3229bd6e77d90964c6522f04600c878933056..6981c881b95e1b6c6c859958a3b6f73049e2f2af 100644
--- a/mmengine/evaluator/utils.py
+++ b/mmengine/evaluator/utils.py
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 from typing import Any, Dict