Commit 93d22757 authored by ZwwWayne

Merge branch 'main' into adapt

parents 59cc08e3 66e52883
@@ -229,7 +229,7 @@ class BaseDataset(Dataset):
         self.test_mode = test_mode
         self.max_refetch = max_refetch
         self.data_list: List[dict] = []
-        self.date_bytes: np.ndarray
+        self.data_bytes: np.ndarray

         # Set meta information.
         self._metainfo = self._get_meta_info(copy.deepcopy(metainfo))
@@ -259,7 +259,7 @@ class BaseDataset(Dataset):
             start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
             end_addr = self.data_address[idx].item()
             bytes = memoryview(
-                self.date_bytes[start_addr:end_addr])  # type: ignore
+                self.data_bytes[start_addr:end_addr])  # type: ignore
             data_info = pickle.loads(bytes)  # type: ignore
         else:
             data_info = self.data_list[idx]
@@ -302,7 +302,7 @@ class BaseDataset(Dataset):
         # serialize data_list
         if self.serialize_data:
-            self.date_bytes, self.data_address = self._serialize_data()
+            self.data_bytes, self.data_address = self._serialize_data()

         self._fully_initialized = True
@@ -575,7 +575,7 @@ class BaseDataset(Dataset):
         # Get subset of data from serialized data or data information sequence
         # according to `self.serialize_data`.
         if self.serialize_data:
-            self.date_bytes, self.data_address = \
+            self.data_bytes, self.data_address = \
                 self._get_serialized_subset(indices)
         else:
             self.data_list = self._get_unserialized_subset(indices)
@@ -626,9 +626,9 @@ class BaseDataset(Dataset):
         sub_dataset = self._copy_without_annotation()
         # Get subset of dataset with serialize and unserialized data.
         if self.serialize_data:
-            date_bytes, data_address = \
+            data_bytes, data_address = \
                 self._get_serialized_subset(indices)
-            sub_dataset.date_bytes = date_bytes.copy()
+            sub_dataset.data_bytes = data_bytes.copy()
             sub_dataset.data_address = data_address.copy()
         else:
             data_list = self._get_unserialized_subset(indices)
@@ -650,7 +650,7 @@ class BaseDataset(Dataset):
             Tuple[np.ndarray, np.ndarray]: subset of serialized data
                 information.
         """
-        sub_date_bytes: Union[List, np.ndarray]
+        sub_data_bytes: Union[List, np.ndarray]
         sub_data_address: Union[List, np.ndarray]
         if isinstance(indices, int):
             if indices >= 0:
@@ -661,7 +661,7 @@ class BaseDataset(Dataset):
                     if indices > 0 else 0
                 # Slicing operation of `np.ndarray` does not trigger a memory
                 # copy.
-                sub_date_bytes = self.date_bytes[:end_addr]
+                sub_data_bytes = self.data_bytes[:end_addr]
                 # Since the buffer size of first few data information is not
                 # changed,
                 sub_data_address = self.data_address[:indices]
@@ -671,11 +671,11 @@ class BaseDataset(Dataset):
                 # Return the last few data information.
                 ignored_bytes_size = self.data_address[indices - 1]
                 start_addr = self.data_address[indices - 1].item()
-                sub_date_bytes = self.date_bytes[start_addr:]
+                sub_data_bytes = self.data_bytes[start_addr:]
                 sub_data_address = self.data_address[indices:]
                 sub_data_address = sub_data_address - ignored_bytes_size
         elif isinstance(indices, Sequence):
-            sub_date_bytes = []
+            sub_data_bytes = []
             sub_data_address = []
             for idx in indices:
                 assert len(self) > idx >= -len(self)
@@ -683,20 +683,20 @@ class BaseDataset(Dataset):
                     self.data_address[idx - 1].item()
                 end_addr = self.data_address[idx].item()
                 # Get data information by address.
-                sub_date_bytes.append(self.date_bytes[start_addr:end_addr])
+                sub_data_bytes.append(self.data_bytes[start_addr:end_addr])
                 # Get data information size.
                 sub_data_address.append(end_addr - start_addr)
             # Handle indices is an empty list.
-            if sub_date_bytes:
-                sub_date_bytes = np.concatenate(sub_date_bytes)
+            if sub_data_bytes:
+                sub_data_bytes = np.concatenate(sub_data_bytes)
                 sub_data_address = np.cumsum(sub_data_address)
             else:
-                sub_date_bytes = np.array([])
+                sub_data_bytes = np.array([])
                 sub_data_address = np.array([])
         else:
             raise TypeError('indices should be a int or sequence of int, '
                             f'but got {type(indices)}')
-        return sub_date_bytes, sub_data_address  # type: ignore
+        return sub_data_bytes, sub_data_address  # type: ignore

     def _get_unserialized_subset(self, indices: Union[Sequence[int],
                                                       int]) -> list:
@@ -795,7 +795,7 @@ class BaseDataset(Dataset):
     def _copy_without_annotation(self, memo=dict()) -> 'BaseDataset':
         """Deepcopy for all attributes other than ``data_list``,
-        ``data_address`` and ``date_bytes``.
+        ``data_address`` and ``data_bytes``.

         Args:
             memo: Memory dict which used to reconstruct complex object
@@ -806,7 +806,7 @@ class BaseDataset(Dataset):
         memo[id(self)] = other
         for key, value in self.__dict__.items():
-            if key in ['data_list', 'data_address', 'date_bytes']:
+            if key in ['data_list', 'data_address', 'data_bytes']:
                 continue
             super(BaseDataset, other).__setattr__(key,
                                                   copy.deepcopy(value, memo))
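The rename above (``date_bytes`` to ``data_bytes``) touches every site of ``BaseDataset``'s flat-buffer serialization: when ``serialize_data`` is enabled, each entry of ``data_list`` is pickled into one contiguous ``np.uint8`` buffer, and ``data_address`` stores the cumulative end offset of each record, so record ``i`` occupies ``data_bytes[data_address[i - 1]:data_address[i]]``. A minimal standalone sketch of that scheme, with hypothetical helper names rather than the actual ``_serialize_data``/``get_data_info`` bodies:

    import pickle
    from typing import List, Tuple

    import numpy as np


    def serialize(data_list: List[dict]) -> Tuple[np.ndarray, np.ndarray]:
        """Pack pickled records into one flat buffer plus end offsets."""
        # Each record becomes a uint8 view of its pickle bytes; np.cumsum
        # of the sizes yields each record's end address, mirroring the
        # data_address convention in the diff above.
        buffers = [
            np.frombuffer(pickle.dumps(d), dtype=np.uint8) for d in data_list
        ]
        data_address = np.cumsum([b.size for b in buffers])
        data_bytes = np.concatenate(buffers)
        return data_bytes, data_address


    def get_record(data_bytes: np.ndarray, data_address: np.ndarray,
                   idx: int) -> dict:
        """Recover record ``idx`` by slicing the shared buffer."""
        start = 0 if idx == 0 else data_address[idx - 1].item()
        end = data_address[idx].item()
        # memoryview avoids copying the slice before unpickling.
        return pickle.loads(memoryview(data_bytes[start:end]))


    data_bytes, data_address = serialize([{'img': 'a.jpg'}, {'img': 'b.jpg'}])
    assert get_record(data_bytes, data_address, 1) == {'img': 'b.jpg'}

A common motivation for this layout is that DataLoader workers then share one numpy buffer instead of a large list of Python objects, and only unpickle the record they index. The second file in this commit is unrelated to datasets: it rewrites the ``mmengine.dist`` utility tests, replacing the ``pytest``/``mp.spawn`` helpers with ``unittest``-based ``MultiProcessTestCase`` classes: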
 # Copyright (c) OpenMMLab. All rights reserved.
 import os
+import unittest
+from unittest import TestCase

-import pytest
 import torch
 import torch.distributed as torch_dist
-import torch.multiprocessing as mp

 import mmengine.dist as dist
+from mmengine.testing._internal import MultiProcessTestCase


-def _test_get_backend_non_dist():
-    assert dist.get_backend() is None
-
-
-def _test_get_world_size_non_dist():
-    assert dist.get_world_size() == 1
-
-
-def _test_get_rank_non_dist():
-    assert dist.get_rank() == 0
-
-
-def _test_local_size_non_dist():
-    assert dist.get_local_size() == 1
-
-
-def _test_local_rank_non_dist():
-    assert dist.get_local_rank() == 0
-
-
-def _test_get_dist_info_non_dist():
-    assert dist.get_dist_info() == (0, 1)
-
-
-def _test_is_main_process_non_dist():
-    assert dist.is_main_process()
-
-
-def _test_master_only_non_dist():
-
-    @dist.master_only
-    def fun():
-        assert dist.get_rank() == 0
-
-    fun()
-
-
-def _test_barrier_non_dist():
-    dist.barrier()  # nothing is done
-
-
-def init_process(rank, world_size, functions, backend='gloo'):
-    """Initialize the distributed environment."""
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = '29501'
-    os.environ['RANK'] = str(rank)
-
-    if backend == 'nccl':
-        num_gpus = torch.cuda.device_count()
-        torch.cuda.set_device(rank % num_gpus)
-
-    torch_dist.init_process_group(
-        backend=backend, rank=rank, world_size=world_size)
-    dist.init_local_group(0, world_size)
-
-    for func in functions:
-        func()
-
-
-def main(functions, world_size=2, backend='gloo'):
-    try:
-        mp.spawn(
-            init_process,
-            args=(world_size, functions, backend),
-            nprocs=world_size)
-    except Exception:
-        pytest.fail('error')
-
-
-def _test_get_backend_dist():
-    assert dist.get_backend() == torch_dist.get_backend()
-
-
-def _test_get_world_size_dist():
-    assert dist.get_world_size() == 2
-
-
-def _test_get_rank_dist():
-    if torch_dist.get_rank() == 0:
-        assert dist.get_rank() == 0
-    else:
-        assert dist.get_rank() == 1
-
-
-def _test_local_size_dist():
-    assert dist.get_local_size() == 2
-
-
-def _test_local_rank_dist():
-    torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank()
-
-
-def _test_get_dist_info_dist():
-    if dist.get_rank() == 0:
-        assert dist.get_dist_info() == (0, 2)
-    else:
-        assert dist.get_dist_info() == (1, 2)
-
-
-def _test_is_main_process_dist():
-    if dist.get_rank() == 0:
-        assert dist.is_main_process()
-    else:
-        assert not dist.is_main_process()
-
-
-def _test_master_only_dist():
-
-    @dist.master_only
-    def fun():
-        assert dist.get_rank() == 0
-
-    fun()
-
-
-def test_non_distributed_env():
-    _test_get_backend_non_dist()
-    _test_get_world_size_non_dist()
-    _test_get_rank_non_dist()
-    _test_local_size_non_dist()
-    _test_local_rank_non_dist()
-    _test_get_dist_info_non_dist()
-    _test_is_main_process_non_dist()
-    _test_master_only_non_dist()
-    _test_barrier_non_dist()
-
-
-functions_to_test = [
-    _test_get_backend_dist,
-    _test_get_world_size_dist,
-    _test_get_rank_dist,
-    _test_local_size_dist,
-    _test_local_rank_dist,
-    _test_get_dist_info_dist,
-    _test_is_main_process_dist,
-    _test_master_only_dist,
-]
-
-
-def test_gloo_backend():
-    main(functions_to_test)
-
-
-@pytest.mark.skipif(
-    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
-def test_nccl_backend():
-    main(functions_to_test, backend='nccl')
+class TestUtils(TestCase):
+
+    def test_get_backend(self):
+        self.assertIsNone(dist.get_backend())
+
+    def test_get_world_size(self):
+        self.assertEqual(dist.get_world_size(), 1)
+
+    def test_get_rank(self):
+        self.assertEqual(dist.get_rank(), 0)
+
+    def test_local_size(self):
+        self.assertEqual(dist.get_local_size(), 1)
+
+    def test_local_rank(self):
+        self.assertEqual(dist.get_local_rank(), 0)
+
+    def test_get_dist_info(self):
+        self.assertEqual(dist.get_dist_info(), (0, 1))
+
+    def test_is_main_process(self):
+        self.assertTrue(dist.is_main_process())
+
+    def test_master_only(self):
+
+        @dist.master_only
+        def fun():
+            assert dist.get_rank() == 0
+
+        fun()
+
+    def test_barrier(self):
+        dist.barrier()  # nothing is done
+
+
+class TestUtilsWithGLOOBackend(MultiProcessTestCase):
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29505'
+        os.environ['RANK'] = str(rank)
+        torch_dist.init_process_group(
+            backend='gloo', rank=rank, world_size=world_size)
+        dist.init_local_group(0, world_size)
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
+
+    def test_get_backend(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_backend(), torch_dist.get_backend())
+
+    def test_get_world_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_world_size(), 2)
+
+    def test_get_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if torch_dist.get_rank() == 0:
+            self.assertEqual(dist.get_rank(), 0)
+        else:
+            self.assertEqual(dist.get_rank(), 1)
+
+    def test_local_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_local_size(), 2)
+
+    def test_local_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(
+            torch_dist.get_rank(dist.get_local_group()),
+            dist.get_local_rank())
+
+    def test_get_dist_info(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertEqual(dist.get_dist_info(), (0, 2))
+        else:
+            self.assertEqual(dist.get_dist_info(), (1, 2))
+
+    def test_is_main_process(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertTrue(dist.is_main_process())
+        else:
+            self.assertFalse(dist.is_main_process())
+
+    def test_master_only(self):
+        self._init_dist_env(self.rank, self.world_size)
+
+        @dist.master_only
+        def fun():
+            assert dist.get_rank() == 0
+
+        fun()
+
+
+@unittest.skipIf(
+    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
+class TestUtilsWithNCCLBackend(MultiProcessTestCase):
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29505'
+        os.environ['RANK'] = str(rank)
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        torch_dist.init_process_group(
+            backend='nccl', rank=rank, world_size=world_size)
+        dist.init_local_group(0, world_size)
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
+
+    def test_get_backend(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_backend(), torch_dist.get_backend())
+
+    def test_get_world_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_world_size(), 2)
+
+    def test_get_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if torch_dist.get_rank() == 0:
+            self.assertEqual(dist.get_rank(), 0)
+        else:
+            self.assertEqual(dist.get_rank(), 1)
+
+    def test_local_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_local_size(), 2)
+
+    def test_local_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(
+            torch_dist.get_rank(dist.get_local_group()),
+            dist.get_local_rank())
+
+    def test_get_dist_info(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertEqual(dist.get_dist_info(), (0, 2))
+        else:
+            self.assertEqual(dist.get_dist_info(), (1, 2))
+
+    def test_is_main_process(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertTrue(dist.is_main_process())
+        else:
+            self.assertFalse(dist.is_main_process())
+
+    def test_master_only(self):
+        self._init_dist_env(self.rank, self.world_size)
+
+        @dist.master_only
+        def fun():
+            assert dist.get_rank() == 0
+
+        fun()
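Both new classes inherit ``MultiProcessTestCase`` from ``mmengine.testing._internal``, whose ``setUp`` (via ``self._spawn_processes()``) launches ``world_size`` worker processes so that each ``test_*`` method runs once in every rank; that is why each method begins by calling ``self._init_dist_env(self.rank, self.world_size)``. Roughly, one such test corresponds to the following spawn-based check. This is a simplified sketch for illustration only: ``_worker`` and ``run_in_two_processes`` are hypothetical names, and port 29505 is assumed to be free, as in the tests above.

    import os

    import torch.distributed as torch_dist
    import torch.multiprocessing as mp

    import mmengine.dist as dist


    def _worker(rank: int, world_size: int) -> None:
        # Same environment setup as _init_dist_env above, executed in a
        # freshly spawned process for each rank.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)
        # Every rank runs the same assertions, mirroring how each test
        # method of TestUtilsWithGLOOBackend executes in both processes.
        assert dist.get_world_size() == world_size
        assert dist.get_rank() == torch_dist.get_rank()


    def run_in_two_processes() -> None:
        # mp.spawn calls _worker(rank, 2) for ranks 0 and 1 and raises if
        # any worker fails, much like the removed main()/init_process pair.
        mp.spawn(_worker, args=(2, ), nprocs=2)


    if __name__ == '__main__':
        run_in_two_processes()

Compared with the removed ``main()``/``init_process`` helpers, the ``MultiProcessTestCase`` version gives each check its own test method, so a failure is reported per test instead of collapsing into a single ``pytest.fail('error')``.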