Unverified commit 6d73b6cd, authored by Zaida Zhou and committed by GitHub

[Refactor] Use mmengine distributed in evaluator (#123)

* [Refactor] Use mmengine distributed in evaluator

* remove 'TODO' comment
parent 4d49de7d
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-import pickle
-import shutil
-import tempfile
 import warnings
 from abc import ABCMeta, abstractmethod
 from typing import Any, List, Optional, Sequence, Tuple, Union

-import torch
-import torch.distributed as dist
 from mmengine.data import BaseDataSample
-from mmengine.utils import mkdir_or_exist
+from mmengine.dist import (broadcast_object_list, collect_results,
+                           is_main_process)


 class BaseEvaluator(metaclass=ABCMeta):
@@ -43,11 +37,6 @@ class BaseEvaluator(metaclass=ABCMeta):
         self._dataset_meta: Union[None, dict] = None
         self.collect_device = collect_device
         self.results: List[Any] = []
-        rank, world_size = get_dist_info()
-        self.rank = rank
-        self.world_size = world_size
         self.prefix = prefix or self.default_prefix
         if self.prefix is None:
             warnings.warn('The prefix is not set in evaluator class '
@@ -108,131 +97,22 @@ class BaseEvaluator(metaclass=ABCMeta):
                 'ensure that the processed results are properly added into '
                 '`self._results` in `process` method.')

-        if self.world_size == 1:
-            # non-distributed
-            results = self.results
-        else:
-            results = collect_results(self.results, size, self.collect_device)
+        results = collect_results(self.results, size, self.collect_device)

-        if self.rank == 0:
-            # TODO: replace with mmengine.dist.master_only
-            metrics = self.compute_metrics(results)
+        if is_main_process():
+            _metrics = self.compute_metrics(results)  # type: ignore
             # Add prefix to metric names
             if self.prefix:
-                metrics = {
+                _metrics = {
                     '/'.join((self.prefix, k)): v
-                    for k, v in metrics.items()
+                    for k, v in _metrics.items()
                 }
-            metrics = [metrics]  # type: ignore
+            metrics = [_metrics]
         else:
             metrics = [None]  # type: ignore

-        # TODO: replace with mmengine.dist.broadcast
-        if self.world_size > 1:
-            metrics = dist.broadcast_object_list(metrics)
+        broadcast_object_list(metrics)

         # reset the results list
         self.results.clear()
         return metrics[0]
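With the mmengine.dist helpers, evaluate() no longer branches on rank or world size: every rank collects its partial results, only the main process computes metrics, and the resulting dict is broadcast so all ranks return the same value. A minimal standalone sketch of that pattern (the function name and the dummy accuracy metric below are illustrative, not part of this commit):

    from mmengine.dist import (broadcast_object_list, collect_results,
                               is_main_process)

    def evaluate_pattern(partial_results, size, collect_device='cpu'):
        # Gather the per-rank lists; only the main process receives the
        # full ordered list, other ranks get None.
        results = collect_results(partial_results, size, collect_device)
        if is_main_process():
            # Hypothetical metric, computed on rank 0 only.
            metrics = [{'accuracy': sum(results) / len(results)}]
        else:
            metrics = [None]
        # In-place broadcast from rank 0 so every rank returns the same dict.
        broadcast_object_list(metrics)
        return metrics[0]

These helpers fall back to single-process defaults when no process group is initialized, which is why the explicit world_size checks above could be dropped.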
-# TODO: replace with mmengine.dist.get_dist_info
-def get_dist_info():
-    if dist.is_available() and dist.is_initialized():
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    else:
-        rank = 0
-        world_size = 1
-    return rank, world_size
-
-
-# TODO: replace with mmengine.dist.collect_results
-def collect_results(results, size, device='cpu'):
-    """Collect results in distributed environments."""
-    if device == 'gpu':
-        return collect_results_gpu(results, size)
-    elif device == 'cpu':
-        return collect_results_cpu(results, size)
-    else:
-        raise NotImplementedError(
-            f"device must be 'cpu' or 'gpu', but got {device}")
-
-
-# TODO: replace with mmengine.dist.collect_results
-def collect_results_cpu(result_part, size, tmpdir=None):
-    rank, world_size = get_dist_info()
-    # create a tmp dir if it is not specified
-    if tmpdir is None:
-        MAX_LEN = 512
-        # 32 is whitespace
-        dir_tensor = torch.full((MAX_LEN, ),
-                                32,
-                                dtype=torch.uint8,
-                                device='cuda')
-        if rank == 0:
-            mkdir_or_exist('.dist_test')
-            tmpdir = tempfile.mkdtemp(dir='.dist_test')
-            tmpdir = torch.tensor(
-                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
-            dir_tensor[:len(tmpdir)] = tmpdir
-        dist.broadcast(dir_tensor, 0)
-        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
-    else:
-        mkdir_or_exist(tmpdir)
-    # dump the part result to the dir
-    with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
-        pickle.dump(result_part, f, protocol=2)
-    dist.barrier()
-    # collect all parts
-    if rank != 0:
-        return None
-    else:
-        # load results of all parts from tmp dir
-        part_list = []
-        for i in range(world_size):
-            with open(osp.join(tmpdir, f'part_{i}.pkl'), 'rb') as f:
-                part_list.append(pickle.load(f))
-        # sort the results
-        ordered_results = []
-        for res in zip(*part_list):
-            ordered_results.extend(list(res))
-        # the dataloader may pad some samples
-        ordered_results = ordered_results[:size]
-        # remove tmp dir
-        shutil.rmtree(tmpdir)
-        return ordered_results
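The tmpdir rendezvous in collect_results_cpu is the subtle part: torch.distributed.broadcast only moves tensors, so rank 0 encodes the directory path into a fixed-length uint8 buffer padded with spaces, and every other rank decodes it back after the broadcast. A standalone sketch of the encode/decode round trip on CPU (the example path is made up):

    import torch

    MAX_LEN = 512
    path = '.dist_test/tmp1a2b3c'  # in real use, created by rank 0 only
    # Fill with 32 (ASCII space) so the decoded string can be rstrip()-ed.
    buf = torch.full((MAX_LEN, ), 32, dtype=torch.uint8)
    encoded = torch.tensor(bytearray(path.encode()), dtype=torch.uint8)
    buf[:len(encoded)] = encoded
    # A real run would call dist.broadcast(buf, 0) here to ship the buffer.
    decoded = buf.numpy().tobytes().decode().rstrip()
    assert decoded == path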
-# TODO: replace with mmengine.dist.collect_results
-def collect_results_gpu(result_part, size):
-    rank, world_size = get_dist_info()
-    # dump result part to tensor with pickle
-    part_tensor = torch.tensor(
-        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
-    # gather all result part tensor shape
-    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
-    shape_list = [shape_tensor.clone() for _ in range(world_size)]
-    dist.all_gather(shape_list, shape_tensor)
-    # padding result part tensor to max length
-    shape_max = torch.tensor(shape_list).max()
-    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
-    part_send[:shape_tensor[0]] = part_tensor
-    part_recv_list = [
-        part_tensor.new_zeros(shape_max) for _ in range(world_size)
-    ]
-    # gather all result part
-    dist.all_gather(part_recv_list, part_send)
-    if rank == 0:
-        part_list = []
-        for recv, shape in zip(part_recv_list, shape_list):
-            part_list.append(
-                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
-        # sort the results
-        ordered_results = []
-        for res in zip(*part_list):
-            ordered_results.extend(list(res))
-        # the dataloader may pad some samples
-        ordered_results = ordered_results[:size]
-        return ordered_results
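In both removed collectors, the "sort the results" step relies on how a round-robin distributed sampler shards the dataset: rank r holds samples r, r + world_size, r + 2 * world_size, ..., so interleaving the per-rank lists with zip(*part_list) restores the original order, and the final [:size] slice trims whatever the dataloader padded to even out the ranks. A small sketch of that reordering (the values are just sample indices, not real results):

    world_size = 3
    dataset = list(range(8))  # sample indices in original order
    # Round-robin sharding: rank r gets samples r, r + 3, r + 6, ...
    part_list = [dataset[r::world_size] for r in range(world_size)]
    # Pad short ranks the way a distributed sampler would, then interleave.
    max_len = max(len(p) for p in part_list)
    padded = [p + p[:1] * (max_len - len(p)) for p in part_list]
    ordered = [x for group in zip(*padded) for x in group][:len(dataset)]
    assert ordered == dataset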
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Any, Dict
...