Commit 93d22757 authored by ZwwWayne

Merge branch 'main' into adapt

parents 59cc08e3 66e52883
@@ -229,7 +229,7 @@ class BaseDataset(Dataset):
         self.test_mode = test_mode
         self.max_refetch = max_refetch
         self.data_list: List[dict] = []
-        self.date_bytes: np.ndarray
+        self.data_bytes: np.ndarray

         # Set meta information.
         self._metainfo = self._get_meta_info(copy.deepcopy(metainfo))
@@ -259,7 +259,7 @@
             start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
             end_addr = self.data_address[idx].item()
             bytes = memoryview(
-                self.date_bytes[start_addr:end_addr])  # type: ignore
+                self.data_bytes[start_addr:end_addr])  # type: ignore
             data_info = pickle.loads(bytes)  # type: ignore
         else:
             data_info = self.data_list[idx]
@@ -302,7 +302,7 @@
         # serialize data_list
         if self.serialize_data:
-            self.date_bytes, self.data_address = self._serialize_data()
+            self.data_bytes, self.data_address = self._serialize_data()

         self._fully_initialized = True
@@ -575,7 +575,7 @@
         # Get subset of data from serialized data or data information sequence
         # according to `self.serialize_data`.
         if self.serialize_data:
-            self.date_bytes, self.data_address = \
+            self.data_bytes, self.data_address = \
                 self._get_serialized_subset(indices)
         else:
             self.data_list = self._get_unserialized_subset(indices)
@@ -626,9 +626,9 @@
         sub_dataset = self._copy_without_annotation()
         # Get subset of dataset with serialize and unserialized data.
         if self.serialize_data:
-            date_bytes, data_address = \
+            data_bytes, data_address = \
                 self._get_serialized_subset(indices)
-            sub_dataset.date_bytes = date_bytes.copy()
+            sub_dataset.data_bytes = data_bytes.copy()
             sub_dataset.data_address = data_address.copy()
         else:
             data_list = self._get_unserialized_subset(indices)
@@ -650,7 +650,7 @@
             Tuple[np.ndarray, np.ndarray]: subset of serialized data
                 information.
         """
-        sub_date_bytes: Union[List, np.ndarray]
+        sub_data_bytes: Union[List, np.ndarray]
         sub_data_address: Union[List, np.ndarray]
         if isinstance(indices, int):
             if indices >= 0:
@@ -661,7 +661,7 @@
                     if indices > 0 else 0
                 # Slicing operation of `np.ndarray` does not trigger a memory
                 # copy.
-                sub_date_bytes = self.date_bytes[:end_addr]
+                sub_data_bytes = self.data_bytes[:end_addr]
                 # Since the buffer size of first few data information is not
                 # changed,
                 sub_data_address = self.data_address[:indices]
@@ -671,11 +671,11 @@
                 # Return the last few data information.
                 ignored_bytes_size = self.data_address[indices - 1]
                 start_addr = self.data_address[indices - 1].item()
-                sub_date_bytes = self.date_bytes[start_addr:]
+                sub_data_bytes = self.data_bytes[start_addr:]
                 sub_data_address = self.data_address[indices:]
                 sub_data_address = sub_data_address - ignored_bytes_size
         elif isinstance(indices, Sequence):
-            sub_date_bytes = []
+            sub_data_bytes = []
             sub_data_address = []
             for idx in indices:
                 assert len(self) > idx >= -len(self)
@@ -683,20 +683,20 @@
                     self.data_address[idx - 1].item()
                 end_addr = self.data_address[idx].item()
                 # Get data information by address.
-                sub_date_bytes.append(self.date_bytes[start_addr:end_addr])
+                sub_data_bytes.append(self.data_bytes[start_addr:end_addr])
                 # Get data information size.
                 sub_data_address.append(end_addr - start_addr)
             # Handle indices is an empty list.
-            if sub_date_bytes:
-                sub_date_bytes = np.concatenate(sub_date_bytes)
+            if sub_data_bytes:
+                sub_data_bytes = np.concatenate(sub_data_bytes)
                 sub_data_address = np.cumsum(sub_data_address)
             else:
-                sub_date_bytes = np.array([])
+                sub_data_bytes = np.array([])
                 sub_data_address = np.array([])
         else:
             raise TypeError('indices should be a int or sequence of int, '
                             f'but got {type(indices)}')
-        return sub_date_bytes, sub_data_address  # type: ignore
+        return sub_data_bytes, sub_data_address  # type: ignore

     def _get_unserialized_subset(self, indices: Union[Sequence[int],
                                                       int]) -> list:
@@ -795,7 +795,7 @@
     def _copy_without_annotation(self, memo=dict()) -> 'BaseDataset':
         """Deepcopy for all attributes other than ``data_list``,
-        ``data_address`` and ``date_bytes``.
+        ``data_address`` and ``data_bytes``.

         Args:
             memo: Memory dict which used to reconstruct complex object
@@ -806,7 +806,7 @@
         memo[id(self)] = other

         for key, value in self.__dict__.items():
-            if key in ['data_list', 'data_address', 'date_bytes']:
+            if key in ['data_list', 'data_address', 'data_bytes']:
                 continue
             super(BaseDataset, other).__setattr__(key,
                                                   copy.deepcopy(value, memo))
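The renamed attribute pair is the core of BaseDataset's serialization scheme: every sample dict in data_list is pickled into a byte buffer, the buffers are concatenated into one NumPy array (data_bytes), and data_address stores the cumulative end offset of each sample, so a sample is recovered by slicing its byte range and unpickling it. Below is a minimal, self-contained sketch of that scheme; the helper names serialize and get_data_info here are illustrative stand-ins, not mmengine's exact implementation.

import pickle

import numpy as np


def serialize(data_list):
    """Pickle each sample dict and pack all of them into one byte buffer.

    Returns the packed bytes and the cumulative end offset of every sample,
    mirroring the (data_bytes, data_address) pair used by BaseDataset.
    """
    buffers = [
        np.frombuffer(pickle.dumps(d), dtype=np.uint8) for d in data_list
    ]
    data_address = np.cumsum([b.size for b in buffers])
    data_bytes = np.concatenate(buffers)
    return data_bytes, data_address


def get_data_info(data_bytes, data_address, idx):
    """Recover the idx-th sample by slicing its byte range and unpickling."""
    start = 0 if idx == 0 else data_address[idx - 1].item()
    end = data_address[idx].item()
    return pickle.loads(memoryview(data_bytes[start:end]))


if __name__ == '__main__':
    data_bytes, data_address = serialize([{'img': 'a.jpg'}, {'img': 'b.jpg'}])
    assert get_data_info(data_bytes, data_address, 1) == {'img': 'b.jpg'}

Because slicing a NumPy array does not copy memory, subset operations such as _get_serialized_subset can hand out views of data_bytes and only adjust the address offsets.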
 # Copyright (c) OpenMMLab. All rights reserved.
 import os
+import unittest
+from unittest import TestCase

-import pytest
 import torch
 import torch.distributed as torch_dist
-import torch.multiprocessing as mp

 import mmengine.dist as dist
+from mmengine.testing._internal import MultiProcessTestCase


-def _test_get_backend_non_dist():
-    assert dist.get_backend() is None
-
-
-def _test_get_world_size_non_dist():
-    assert dist.get_world_size() == 1
-
-
-def _test_get_rank_non_dist():
-    assert dist.get_rank() == 0
-
-
-def _test_local_size_non_dist():
-    assert dist.get_local_size() == 1
-
-
-def _test_local_rank_non_dist():
-    assert dist.get_local_rank() == 0
-
-
-def _test_get_dist_info_non_dist():
-    assert dist.get_dist_info() == (0, 1)
-
-
-def _test_is_main_process_non_dist():
-    assert dist.is_main_process()
-
-
-def _test_master_only_non_dist():
-
-    @dist.master_only
-    def fun():
-        assert dist.get_rank() == 0
-
-    fun()
-
-
-def _test_barrier_non_dist():
-    dist.barrier()  # nothing is done
-
-
-def init_process(rank, world_size, functions, backend='gloo'):
-    """Initialize the distributed environment."""
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = '29501'
-    os.environ['RANK'] = str(rank)
-
-    if backend == 'nccl':
-        num_gpus = torch.cuda.device_count()
-        torch.cuda.set_device(rank % num_gpus)
-
-    torch_dist.init_process_group(
-        backend=backend, rank=rank, world_size=world_size)
-    dist.init_local_group(0, world_size)
-
-    for func in functions:
-        func()
-
-
-def main(functions, world_size=2, backend='gloo'):
-    try:
-        mp.spawn(
-            init_process,
-            args=(world_size, functions, backend),
-            nprocs=world_size)
-    except Exception:
-        pytest.fail('error')
-
-
-def _test_get_backend_dist():
-    assert dist.get_backend() == torch_dist.get_backend()
-
-
-def _test_get_world_size_dist():
-    assert dist.get_world_size() == 2
-
-
-def _test_get_rank_dist():
-    if torch_dist.get_rank() == 0:
-        assert dist.get_rank() == 0
-    else:
-        assert dist.get_rank() == 1
-
-
-def _test_local_size_dist():
-    assert dist.get_local_size() == 2
-
-
-def _test_local_rank_dist():
-    torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank()
-
-
-def _test_get_dist_info_dist():
-    if dist.get_rank() == 0:
-        assert dist.get_dist_info() == (0, 2)
-    else:
-        assert dist.get_dist_info() == (1, 2)
-
-
-def _test_is_main_process_dist():
-    if dist.get_rank() == 0:
-        assert dist.is_main_process()
-    else:
-        assert not dist.is_main_process()
-
-
-def _test_master_only_dist():
-
-    @dist.master_only
-    def fun():
-        assert dist.get_rank() == 0
-
-    fun()
-
-
-def test_non_distributed_env():
-    _test_get_backend_non_dist()
-    _test_get_world_size_non_dist()
-    _test_get_rank_non_dist()
-    _test_local_size_non_dist()
-    _test_local_rank_non_dist()
-    _test_get_dist_info_non_dist()
-    _test_is_main_process_non_dist()
-    _test_master_only_non_dist()
-    _test_barrier_non_dist()
-
-
-functions_to_test = [
-    _test_get_backend_dist,
-    _test_get_world_size_dist,
-    _test_get_rank_dist,
-    _test_local_size_dist,
-    _test_local_rank_dist,
-    _test_get_dist_info_dist,
-    _test_is_main_process_dist,
-    _test_master_only_dist,
-]
-
-
-def test_gloo_backend():
-    main(functions_to_test)
-
-
-@pytest.mark.skipif(
-    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
-def test_nccl_backend():
-    main(functions_to_test, backend='nccl')
+class TestUtils(TestCase):
+
+    def test_get_backend(self):
+        self.assertIsNone(dist.get_backend())
+
+    def test_get_world_size(self):
+        self.assertEqual(dist.get_world_size(), 1)
+
+    def test_get_rank(self):
+        self.assertEqual(dist.get_rank(), 0)
+
+    def test_local_size(self):
+        self.assertEqual(dist.get_local_size(), 1)
+
+    def test_local_rank(self):
+        self.assertEqual(dist.get_local_rank(), 0)
+
+    def test_get_dist_info(self):
+        self.assertEqual(dist.get_dist_info(), (0, 1))
+
+    def test_is_main_process(self):
+        self.assertTrue(dist.is_main_process())
+
+    def test_master_only(self):
+
+        @dist.master_only
+        def fun():
+            assert dist.get_rank() == 0
+
+        fun()
+
+    def test_barrier(self):
+        dist.barrier()  # nothing is done
+
+
+class TestUtilsWithGLOOBackend(MultiProcessTestCase):
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29505'
+        os.environ['RANK'] = str(rank)
+        torch_dist.init_process_group(
+            backend='gloo', rank=rank, world_size=world_size)
+        dist.init_local_group(0, world_size)
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
+
+    def test_get_backend(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_backend(), torch_dist.get_backend())
+
+    def test_get_world_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_world_size(), 2)
+
+    def test_get_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if torch_dist.get_rank() == 0:
+            self.assertEqual(dist.get_rank(), 0)
+        else:
+            self.assertEqual(dist.get_rank(), 1)
+
+    def test_local_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_local_size(), 2)
+
+    def test_local_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(
+            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())
+
+    def test_get_dist_info(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertEqual(dist.get_dist_info(), (0, 2))
+        else:
+            self.assertEqual(dist.get_dist_info(), (1, 2))
+
+    def test_is_main_process(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertTrue(dist.is_main_process())
+        else:
+            self.assertFalse(dist.is_main_process())
+
+    def test_master_only(self):
+        self._init_dist_env(self.rank, self.world_size)
+
+        @dist.master_only
+        def fun():
+            assert dist.get_rank() == 0
+
+        fun()
+
+
+@unittest.skipIf(
+    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
+class TestUtilsWithNCCLBackend(MultiProcessTestCase):
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29505'
+        os.environ['RANK'] = str(rank)
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        torch_dist.init_process_group(
+            backend='nccl', rank=rank, world_size=world_size)
+        dist.init_local_group(0, world_size)
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
+
+    def test_get_backend(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_backend(), torch_dist.get_backend())
+
+    def test_get_world_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_world_size(), 2)
+
+    def test_get_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if torch_dist.get_rank() == 0:
+            self.assertEqual(dist.get_rank(), 0)
+        else:
+            self.assertEqual(dist.get_rank(), 1)
+
+    def test_local_size(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(dist.get_local_size(), 2)
+
+    def test_local_rank(self):
+        self._init_dist_env(self.rank, self.world_size)
+        self.assertEqual(
+            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())
+
+    def test_get_dist_info(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertEqual(dist.get_dist_info(), (0, 2))
+        else:
+            self.assertEqual(dist.get_dist_info(), (1, 2))
+
+    def test_is_main_process(self):
+        self._init_dist_env(self.rank, self.world_size)
+        if dist.get_rank() == 0:
+            self.assertTrue(dist.is_main_process())
+        else:
+            self.assertFalse(dist.is_main_process())
+
+    def test_master_only(self):
+        self._init_dist_env(self.rank, self.world_size)
+
+        @dist.master_only
+        def fun():
+            assert dist.get_rank() == 0
+
+        fun()
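Both the old and the new test_master_only rely on dist.master_only turning the decorated function into a no-op on every rank except rank 0. The snippet below is a minimal sketch of how such a decorator can be built on top of dist.is_main_process(); the name master_only_sketch is illustrative and this is not mmengine's exact source.

import functools

import mmengine.dist as dist


def master_only_sketch(func):
    """Run ``func`` only on the main (rank 0) process; return None elsewhere."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if dist.is_main_process():
            return func(*args, **kwargs)

    return wrapper


@master_only_sketch
def log_once():
    # In a non-distributed environment is_main_process() is True,
    # so this prints; in a 2-process group only rank 0 prints.
    print('only rank 0 prints this')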