Commit 93d22757, authored 2 years ago by ZwwWayne

    Merge branch 'main' into adapt

Parents: 59cc08e3, 66e52883

Showing 3 changed files, with 544 additions and 447 deletions:

- mmengine/dataset/base_dataset.py (+17 −17)
- tests/test_dist/test_dist.py (+389 −311)
- tests/test_dist/test_utils.py (+138 −119)
mmengine/dataset/base_dataset.py  (+17 −17)

@@ -229,7 +229,7 @@ class BaseDataset(Dataset):
         self.test_mode = test_mode
         self.max_refetch = max_refetch
         self.data_list: List[dict] = []
-        self.date_bytes: np.ndarray
+        self.data_bytes: np.ndarray

         # Set meta information.
         self._metainfo = self._get_meta_info(copy.deepcopy(metainfo))
@@ -259,7 +259,7 @@ class BaseDataset(Dataset):
             start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
             end_addr = self.data_address[idx].item()
             bytes = memoryview(
-                self.date_bytes[start_addr:end_addr])  # type: ignore
+                self.data_bytes[start_addr:end_addr])  # type: ignore
             data_info = pickle.loads(bytes)  # type: ignore
         else:
             data_info = self.data_list[idx]
@@ -302,7 +302,7 @@ class BaseDataset(Dataset):
         # serialize data_list
         if self.serialize_data:
-            self.date_bytes, self.data_address = self._serialize_data()
+            self.data_bytes, self.data_address = self._serialize_data()

         self._fully_initialized = True
@@ -575,7 +575,7 @@ class BaseDataset(Dataset):
         # Get subset of data from serialized data or data information sequence
         # according to `self.serialize_data`.
         if self.serialize_data:
-            self.date_bytes, self.data_address = \
+            self.data_bytes, self.data_address = \
                 self._get_serialized_subset(indices)
         else:
             self.data_list = self._get_unserialized_subset(indices)
@@ -626,9 +626,9 @@ class BaseDataset(Dataset):
         sub_dataset = self._copy_without_annotation()
         # Get subset of dataset with serialize and unserialized data.
         if self.serialize_data:
-            date_bytes, data_address = \
+            data_bytes, data_address = \
                 self._get_serialized_subset(indices)
-            sub_dataset.date_bytes = date_bytes.copy()
+            sub_dataset.data_bytes = data_bytes.copy()
             sub_dataset.data_address = data_address.copy()
         else:
             data_list = self._get_unserialized_subset(indices)
@@ -650,7 +650,7 @@ class BaseDataset(Dataset):
             Tuple[np.ndarray, np.ndarray]: subset of serialized data
                 information.
         """
-        sub_date_bytes: Union[List, np.ndarray]
+        sub_data_bytes: Union[List, np.ndarray]
         sub_data_address: Union[List, np.ndarray]
         if isinstance(indices, int):
             if indices >= 0:
@@ -661,7 +661,7 @@ class BaseDataset(Dataset):
                     if indices > 0 else 0
                 # Slicing operation of `np.ndarray` does not trigger a memory
                 # copy.
-                sub_date_bytes = self.date_bytes[:end_addr]
+                sub_data_bytes = self.data_bytes[:end_addr]
                 # Since the buffer size of first few data information is not
                 # changed,
                 sub_data_address = self.data_address[:indices]
@@ -671,11 +671,11 @@ class BaseDataset(Dataset):
                 # Return the last few data information.
                 ignored_bytes_size = self.data_address[indices - 1]
                 start_addr = self.data_address[indices - 1].item()
-                sub_date_bytes = self.date_bytes[start_addr:]
+                sub_data_bytes = self.data_bytes[start_addr:]
                 sub_data_address = self.data_address[indices:]
                 sub_data_address = sub_data_address - ignored_bytes_size
         elif isinstance(indices, Sequence):
-            sub_date_bytes = []
+            sub_data_bytes = []
             sub_data_address = []
             for idx in indices:
                 assert len(self) > idx >= -len(self)
@@ -683,20 +683,20 @@ class BaseDataset(Dataset):
                     self.data_address[idx - 1].item()
                 end_addr = self.data_address[idx].item()
                 # Get data information by address.
-                sub_date_bytes.append(self.date_bytes[start_addr:end_addr])
+                sub_data_bytes.append(self.data_bytes[start_addr:end_addr])
                 # Get data information size.
                 sub_data_address.append(end_addr - start_addr)
             # Handle indices is an empty list.
-            if sub_date_bytes:
-                sub_date_bytes = np.concatenate(sub_date_bytes)
+            if sub_data_bytes:
+                sub_data_bytes = np.concatenate(sub_data_bytes)
                 sub_data_address = np.cumsum(sub_data_address)
             else:
-                sub_date_bytes = np.array([])
+                sub_data_bytes = np.array([])
                 sub_data_address = np.array([])
         else:
             raise TypeError('indices should be a int or sequence of int, '
                             f'but got {type(indices)}')
-        return sub_date_bytes, sub_data_address  # type: ignore
+        return sub_data_bytes, sub_data_address  # type: ignore

     def _get_unserialized_subset(self, indices: Union[Sequence[int],
                                                       int]) -> list:
@@ -795,7 +795,7 @@ class BaseDataset(Dataset):
     def _copy_without_annotation(self, memo=dict()) -> 'BaseDataset':
         """Deepcopy for all attributes other than ``data_list``,
-        ``data_address`` and ``date_bytes``.
+        ``data_address`` and ``data_bytes``.

         Args:
             memo: Memory dict which used to reconstruct complex object
@@ -806,7 +806,7 @@ class BaseDataset(Dataset):
         memo[id(self)] = other

         for key, value in self.__dict__.items():
-            if key in ['data_list', 'data_address', 'date_bytes']:
+            if key in ['data_list', 'data_address', 'data_bytes']:
                 continue
             super(BaseDataset, other).__setattr__(key,
                                                   copy.deepcopy(value, memo))
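The diff above is purely a spelling fix: `date_bytes` was a typo for `data_bytes`, the flat byte buffer that backs `BaseDataset` when `serialize_data=True`. For readers unfamiliar with the scheme the diff's comments reference, here is a minimal standalone sketch, not mmengine's actual implementation; the helper names `serialize`, `get_item`, and `get_subset` are invented for illustration. Each data_info dict is pickled into a uint8 buffer, the buffers are concatenated into a single `data_bytes` array, and `data_address` holds the cumulative end offsets, so both item lookup and the subset rebasing seen in the hunks are offset arithmetic over copy-free ndarray slices:

import pickle

import numpy as np


def serialize(data_list):
    # Pickle each data_info dict into a uint8 buffer.
    buffers = [
        np.frombuffer(pickle.dumps(d), dtype=np.uint8) for d in data_list
    ]
    # Cumulative end offsets: item i ends at data_address[i].
    data_address = np.cumsum([len(b) for b in buffers])
    data_bytes = np.concatenate(buffers)
    return data_bytes, data_address


def get_item(data_bytes, data_address, idx):
    # Mirrors the memoryview/pickle.loads lookup in the diff above.
    start = 0 if idx == 0 else data_address[idx - 1].item()
    end = data_address[idx].item()
    return pickle.loads(memoryview(data_bytes[start:end]))


def get_subset(data_bytes, data_address, indices):
    # Keep the tail starting at `indices`: drop the leading bytes and
    # rebase the offsets, like the `ignored_bytes_size` arithmetic above.
    start = data_address[indices - 1].item()
    return data_bytes[start:], data_address[indices:] - data_address[indices - 1]


data_bytes, data_address = serialize([{'img': 'a.jpg'}, {'img': 'b.jpg'},
                                      {'img': 'c.jpg'}])
assert get_item(data_bytes, data_address, 1) == {'img': 'b.jpg'}

tail_bytes, tail_address = get_subset(data_bytes, data_address, 1)
assert get_item(tail_bytes, tail_address, 0) == {'img': 'b.jpg'}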
tests/test_dist/test_dist.py  (+389 −311)

(diff collapsed, not shown on this page)
tests/test_dist/test_utils.py  (+138 −119)

Removed (old pytest-style implementation):

# Copyright (c) OpenMMLab. All rights reserved.
import os

import pytest
import torch
import torch.distributed as torch_dist
import torch.multiprocessing as mp

import mmengine.dist as dist


def _test_get_backend_non_dist():
    assert dist.get_backend() is None


def _test_get_world_size_non_dist():
    assert dist.get_world_size() == 1


def _test_get_rank_non_dist():
    assert dist.get_rank() == 0


def _test_local_size_non_dist():
    assert dist.get_local_size() == 1


def _test_local_rank_non_dist():
    assert dist.get_local_rank() == 0


def _test_get_dist_info_non_dist():
    assert dist.get_dist_info() == (0, 1)


def _test_is_main_process_non_dist():
    assert dist.is_main_process()


def _test_master_only_non_dist():

    @dist.master_only
    def fun():
        assert dist.get_rank() == 0

    fun()


def _test_barrier_non_dist():
    dist.barrier()  # nothing is done


def init_process(rank, world_size, functions, backend='gloo'):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    os.environ['RANK'] = str(rank)
    if backend == 'nccl':
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
    torch_dist.init_process_group(
        backend=backend, rank=rank, world_size=world_size)
    dist.init_local_group(0, world_size)
    for func in functions:
        func()


def main(functions, world_size=2, backend='gloo'):
    try:
        mp.spawn(
            init_process,
            args=(world_size, functions, backend),
            nprocs=world_size)
    except Exception:
        pytest.fail('error')


def _test_get_backend_dist():
    assert dist.get_backend() == torch_dist.get_backend()


def _test_get_world_size_dist():
    assert dist.get_world_size() == 2


def _test_get_rank_dist():
    if torch_dist.get_rank() == 0:
        assert dist.get_rank() == 0
    else:
        assert dist.get_rank() == 1


def _test_local_size_dist():
    assert dist.get_local_size() == 2


def _test_local_rank_dist():
    torch_dist.get_rank(dist.get_local_group()) == dist.get_local_rank()


def _test_get_dist_info_dist():
    if dist.get_rank() == 0:
        assert dist.get_dist_info() == (0, 2)
    else:
        assert dist.get_dist_info() == (1, 2)


def _test_is_main_process_dist():
    if dist.get_rank() == 0:
        assert dist.is_main_process()
    else:
        assert not dist.is_main_process()


def _test_master_only_dist():

    @dist.master_only
    def fun():
        assert dist.get_rank() == 0

    fun()


def test_non_distributed_env():
    _test_get_backend_non_dist()
    _test_get_world_size_non_dist()
    _test_get_rank_non_dist()
    _test_local_size_non_dist()
    _test_local_rank_non_dist()
    _test_get_dist_info_non_dist()
    _test_is_main_process_non_dist()
    _test_master_only_non_dist()
    _test_barrier_non_dist()


functions_to_test = [
    _test_get_backend_dist,
    _test_get_world_size_dist,
    _test_get_rank_dist,
    _test_local_size_dist,
    _test_local_rank_dist,
    _test_get_dist_info_dist,
    _test_is_main_process_dist,
    _test_master_only_dist,
]


def test_gloo_backend():
    main(functions_to_test)


@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
def test_nccl_backend():
    main(functions_to_test, backend='nccl')


Added (new unittest/MultiProcessTestCase implementation):

# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest import TestCase

import torch
import torch.distributed as torch_dist

import mmengine.dist as dist
from mmengine.testing._internal import MultiProcessTestCase


class TestUtils(TestCase):

    def test_get_backend(self):
        self.assertIsNone(dist.get_backend())

    def test_get_world_size(self):
        self.assertEqual(dist.get_world_size(), 1)

    def test_get_rank(self):
        self.assertEqual(dist.get_rank(), 0)

    def test_local_size(self):
        self.assertEqual(dist.get_local_size(), 1)

    def test_local_rank(self):
        self.assertEqual(dist.get_local_rank(), 0)

    def test_get_dist_info(self):
        self.assertEqual(dist.get_dist_info(), (0, 1))

    def test_is_main_process(self):
        self.assertTrue(dist.is_main_process())

    def test_master_only(self):

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()

    def test_barrier(self):
        dist.barrier()  # nothing is done


class TestUtilsWithGLOOBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()),
            dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()


@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
class TestUtilsWithNCCLBackend(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29505'
        os.environ['RANK'] = str(rank)
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)
        dist.init_local_group(0, world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def test_get_backend(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_backend(), torch_dist.get_backend())

    def test_get_world_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_world_size(), 2)

    def test_get_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        if torch_dist.get_rank() == 0:
            self.assertEqual(dist.get_rank(), 0)
        else:
            self.assertEqual(dist.get_rank(), 1)

    def test_local_size(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(dist.get_local_size(), 2)

    def test_local_rank(self):
        self._init_dist_env(self.rank, self.world_size)
        self.assertEqual(
            torch_dist.get_rank(dist.get_local_group()),
            dist.get_local_rank())

    def test_get_dist_info(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertEqual(dist.get_dist_info(), (0, 2))
        else:
            self.assertEqual(dist.get_dist_info(), (1, 2))

    def test_is_main_process(self):
        self._init_dist_env(self.rank, self.world_size)
        if dist.get_rank() == 0:
            self.assertTrue(dist.is_main_process())
        else:
            self.assertFalse(dist.is_main_process())

    def test_master_only(self):
        self._init_dist_env(self.rank, self.world_size)

        @dist.master_only
        def fun():
            assert dist.get_rank() == 0

        fun()
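For context on how the new classes execute: `MultiProcessTestCase` (vendored in `mmengine.testing._internal` from PyTorch's test internals) appears to spawn `world_size` worker processes and run each test method inside every worker, which is why each test first calls `_init_dist_env`. The standalone sketch below is an illustration under that assumption, not code from this commit; it reproduces the same flow with plain `torch.multiprocessing`, much as the removed `main`/`init_process` helpers did:

import os

import torch.distributed as torch_dist
import torch.multiprocessing as mp

import mmengine.dist as dist


def worker(rank, world_size):
    # Each spawned process sets up the rendezvous and joins the group.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29510'  # any free port
    torch_dist.init_process_group(
        backend='gloo', rank=rank, world_size=world_size)
    # mmengine.dist helpers now reflect the distributed state.
    assert dist.get_world_size() == world_size
    assert dist.get_rank() == rank
    assert dist.is_main_process() == (rank == 0)
    torch_dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, args=(2, ), nprocs=2)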