From 6d73b6cdf20992e6ccf5d1ad0d528c1c3a347094 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Sun, 13 Mar 2022 16:58:23 +0800 Subject: [PATCH] [Refactor] Use mmengine distributed in evaluator (#123) * [Refactor] Use mmengine distributed in evaluator * remove 'TODO' comment --- mmengine/evaluator/base.py | 138 +++--------------------------------- mmengine/evaluator/utils.py | 1 + 2 files changed, 10 insertions(+), 129 deletions(-) diff --git a/mmengine/evaluator/base.py b/mmengine/evaluator/base.py index dcf447be..bc90634e 100644 --- a/mmengine/evaluator/base.py +++ b/mmengine/evaluator/base.py @@ -1,17 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os.path as osp -import pickle -import shutil -import tempfile import warnings from abc import ABCMeta, abstractmethod from typing import Any, List, Optional, Sequence, Tuple, Union -import torch -import torch.distributed as dist - from mmengine.data import BaseDataSample -from mmengine.utils import mkdir_or_exist +from mmengine.dist import (broadcast_object_list, collect_results, + is_main_process) class BaseEvaluator(metaclass=ABCMeta): @@ -43,11 +37,6 @@ class BaseEvaluator(metaclass=ABCMeta): self._dataset_meta: Union[None, dict] = None self.collect_device = collect_device self.results: List[Any] = [] - - rank, world_size = get_dist_info() - self.rank = rank - self.world_size = world_size - self.prefix = prefix or self.default_prefix if self.prefix is None: warnings.warn('The prefix is not set in evaluator class ' @@ -108,131 +97,22 @@ class BaseEvaluator(metaclass=ABCMeta): 'ensure that the processed results are properly added into ' '`self._results` in `process` method.') - if self.world_size == 1: - # non-distributed - results = self.results - else: - results = collect_results(self.results, size, self.collect_device) + results = collect_results(self.results, size, self.collect_device) - if self.rank == 0: - # TODO: replace with mmengine.dist.master_only - metrics = self.compute_metrics(results) + if is_main_process(): + _metrics = self.compute_metrics(results) # type: ignore # Add prefix to metric names if self.prefix: - metrics = { + _metrics = { '/'.join((self.prefix, k)): v - for k, v in metrics.items() + for k, v in _metrics.items() } - metrics = [metrics] # type: ignore + metrics = [_metrics] else: metrics = [None] # type: ignore - # TODO: replace with mmengine.dist.broadcast - if self.world_size > 1: - metrics = dist.broadcast_object_list(metrics) + broadcast_object_list(metrics) # reset the results list self.results.clear() return metrics[0] - - -# TODO: replace with mmengine.dist.get_dist_info -def get_dist_info(): - if dist.is_available() and dist.is_initialized(): - rank = dist.get_rank() - world_size = dist.get_world_size() - else: - rank = 0 - world_size = 1 - return rank, world_size - - -# TODO: replace with mmengine.dist.collect_results -def collect_results(results, size, device='cpu'): - """Collected results in distributed environments.""" - # TODO: replace with mmengine.dist.collect_results - if device == 'gpu': - return collect_results_gpu(results, size) - elif device == 'cpu': - return collect_results_cpu(results, size) - else: - NotImplementedError(f"device must be 'cpu' or 'gpu', but got {device}") - - -# TODO: replace with mmengine.dist.collect_results -def collect_results_cpu(result_part, size, tmpdir=None): - rank, world_size = get_dist_info() - # create a tmp dir if it is not specified - if tmpdir is None: - MAX_LEN = 512 - # 32 is whitespace - dir_tensor = torch.full((MAX_LEN, ), - 32, - dtype=torch.uint8, - device='cuda') - if rank == 0: - mkdir_or_exist('.dist_test') - tmpdir = tempfile.mkdtemp(dir='.dist_test') - tmpdir = torch.tensor( - bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') - dir_tensor[:len(tmpdir)] = tmpdir - dist.broadcast(dir_tensor, 0) - tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() - else: - mkdir_or_exist(tmpdir) - # dump the part result to the dir - with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f: - pickle.dump(result_part, f, protocol=2) - dist.barrier() - # collect all parts - if rank != 0: - return None - else: - # load results of all parts from tmp dir - part_list = [] - for i in range(world_size): - with open(osp.join(tmpdir, f'part_{i}.pkl'), 'wb') as f: - part_list.append(pickle.load(f)) - # sort the results - ordered_results = [] - for res in zip(*part_list): - ordered_results.extend(list(res)) - # the dataloader may pad some samples - ordered_results = ordered_results[:size] - # remove tmp dir - shutil.rmtree(tmpdir) - return ordered_results - - -# TODO: replace with mmengine.dist.collect_results -def collect_results_gpu(result_part, size): - rank, world_size = get_dist_info() - # dump result part to tensor with pickle - part_tensor = torch.tensor( - bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') - # gather all result part tensor shape - shape_tensor = torch.tensor(part_tensor.shape, device='cuda') - shape_list = [shape_tensor.clone() for _ in range(world_size)] - dist.all_gather(shape_list, shape_tensor) - # padding result part tensor to max length - shape_max = torch.tensor(shape_list).max() - part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') - part_send[:shape_tensor[0]] = part_tensor - part_recv_list = [ - part_tensor.new_zeros(shape_max) for _ in range(world_size) - ] - # gather all result part - dist.all_gather(part_recv_list, part_send) - - if rank == 0: - part_list = [] - for recv, shape in zip(part_recv_list, shape_list): - part_list.append( - pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) - # sort the results - ordered_results = [] - for res in zip(*part_list): - ordered_results.extend(list(res)) - # the dataloader may pad some samples - ordered_results = ordered_results[:size] - return ordered_results diff --git a/mmengine/evaluator/utils.py b/mmengine/evaluator/utils.py index 17c3229b..6981c881 100644 --- a/mmengine/evaluator/utils.py +++ b/mmengine/evaluator/utils.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. from typing import Any, Dict -- GitLab