Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
I
ICV-mmengine_basecode
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Florian Schiffel
ICV-mmengine_basecode
Commits
438e8e74
Unverified
Commit
438e8e74
authored
2 years ago
by
Mashiro
Committed by
GitHub
2 years ago
Browse files
Options
Downloads
Patches
Plain Diff
BaseModel support recursively set the device of data_preprocessor (#387)
parent
f98ba606
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
mmengine/model/base_model/base_model.py
+33
-8
33 additions, 8 deletions
mmengine/model/base_model/base_model.py
tests/test_model/test_base_model/test_base_model.py
+33
-0
33 additions, 0 deletions
tests/test_model/test_base_model/test_base_model.py
with
66 additions
and
8 deletions
mmengine/model/base_model/base_model.py
+
33
−
8
View file @
438e8e74
...
@@ -11,6 +11,7 @@ from mmengine.optim import OptimWrapper
...
@@ -11,6 +11,7 @@ from mmengine.optim import OptimWrapper
from
mmengine.registry
import
MODELS
from
mmengine.registry
import
MODELS
from
mmengine.utils
import
is_list_of
from
mmengine.utils
import
is_list_of
from
..base_module
import
BaseModule
from
..base_module
import
BaseModule
from
.data_preprocessor
import
BaseDataPreprocessor
ForwardResults
=
Union
[
Dict
[
str
,
torch
.
Tensor
],
List
[
BaseDataElement
],
ForwardResults
=
Union
[
Dict
[
str
,
torch
.
Tensor
],
List
[
BaseDataElement
],
Tuple
[
torch
.
Tensor
],
torch
.
Tensor
]
Tuple
[
torch
.
Tensor
],
torch
.
Tensor
]
...
@@ -177,30 +178,38 @@ class BaseModel(BaseModule):
...
@@ -177,30 +178,38 @@ class BaseModel(BaseModule):
return
loss
,
log_vars
return
loss
,
log_vars
def
to
(
self
,
device
:
Optional
[
Union
[
int
,
torch
.
device
]],
*
args
,
def
to
(
self
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]]
=
None
,
*
args
,
**
kwargs
)
->
nn
.
Module
:
**
kwargs
)
->
nn
.
Module
:
"""
Overrides this method to call :meth:`BaseDataPreprocessor.to`
"""
Overrides this method to call :meth:`BaseDataPreprocessor.to`
additionally.
additionally.
Args:
Args:
device (int or torch.device, optional): the desired device
of the
device (int
, str
or torch.device, optional): the desired device
parameters and buffers in this module.
of the
parameters and buffers in this module.
Returns:
Returns:
nn.Module: The model itself.
nn.Module: The model itself.
"""
"""
self
.
data_preprocessor
.
to
(
device
)
if
device
is
not
None
:
self
.
_set_device
(
torch
.
device
(
device
))
return
super
().
to
(
device
)
return
super
().
to
(
device
)
def
cuda
(
self
,
*
args
,
**
kwargs
)
->
nn
.
Module
:
def
cuda
(
self
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]]
=
None
,
)
->
nn
.
Module
:
"""
Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
"""
Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
additionally.
additionally.
Returns:
Returns:
nn.Module: The model itself.
nn.Module: The model itself.
"""
"""
self
.
data_preprocessor
.
cuda
()
if
device
is
None
or
isinstance
(
device
,
int
):
return
super
().
cuda
()
device
=
torch
.
device
(
'
cuda
'
,
index
=
device
)
self
.
_set_device
(
torch
.
device
(
device
))
return
super
().
cuda
(
device
)
def
cpu
(
self
,
*
args
,
**
kwargs
)
->
nn
.
Module
:
def
cpu
(
self
,
*
args
,
**
kwargs
)
->
nn
.
Module
:
"""
Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
"""
Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
...
@@ -209,9 +218,25 @@ class BaseModel(BaseModule):
...
@@ -209,9 +218,25 @@ class BaseModel(BaseModule):
Returns:
Returns:
nn.Module: The model itself.
nn.Module: The model itself.
"""
"""
self
.
data_preprocessor
.
cpu
(
)
self
.
_set_device
(
torch
.
device
(
'
cpu
'
)
)
return
super
().
cpu
()
return
super
().
cpu
()
def
_set_device
(
self
,
device
:
torch
.
device
)
->
None
:
"""
Recursively set device for `BaseDataPreprocessor` instance.
Args:
device (torch.device): the desired device of the parameters and
buffers in this module.
"""
def
apply_fn
(
module
):
if
not
isinstance
(
module
,
BaseDataPreprocessor
):
return
if
device
is
not
None
:
module
.
_device
=
device
self
.
apply
(
apply_fn
)
@abstractmethod
@abstractmethod
def
forward
(
self
,
def
forward
(
self
,
batch_inputs
:
torch
.
Tensor
,
batch_inputs
:
torch
.
Tensor
,
...
...
This diff is collapsed.
Click to expand it.
tests/test_model/test_base_model/test_base_model.py
+
33
−
0
View file @
438e8e74
...
@@ -40,6 +40,16 @@ class ToyModel(BaseModel):
...
@@ -40,6 +40,16 @@ class ToyModel(BaseModel):
return
out
return
out
class
NestedModel
(
BaseModel
):
def
__init__
(
self
):
super
().
__init__
()
self
.
toy_model
=
ToyModel
()
def
forward
(
self
):
pass
class
TestBaseModel
(
TestCase
):
class
TestBaseModel
(
TestCase
):
def
test_init
(
self
):
def
test_init
(
self
):
...
@@ -118,6 +128,15 @@ class TestBaseModel(TestCase):
...
@@ -118,6 +128,15 @@ class TestBaseModel(TestCase):
out
=
model
.
val_step
([
data
])
out
=
model
.
val_step
([
data
])
self
.
assertEqual
(
out
.
device
.
type
,
'
cuda
'
)
self
.
assertEqual
(
out
.
device
.
type
,
'
cuda
'
)
model
=
NestedModel
()
self
.
assertEqual
(
model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cpu
'
))
self
.
assertEqual
(
model
.
toy_model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cpu
'
))
model
.
cuda
()
self
.
assertEqual
(
model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cuda
'
))
self
.
assertEqual
(
model
.
toy_model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cuda
'
))
@unittest.skipIf
(
not
torch
.
cuda
.
is_available
(),
'
cuda should be available
'
)
@unittest.skipIf
(
not
torch
.
cuda
.
is_available
(),
'
cuda should be available
'
)
def
test_to
(
self
):
def
test_to
(
self
):
inputs
=
torch
.
randn
(
3
,
1
,
1
).
to
(
'
cuda:0
'
)
inputs
=
torch
.
randn
(
3
,
1
,
1
).
to
(
'
cuda:0
'
)
...
@@ -125,3 +144,17 @@ class TestBaseModel(TestCase):
...
@@ -125,3 +144,17 @@ class TestBaseModel(TestCase):
model
=
ToyModel
().
to
(
torch
.
cuda
.
current_device
())
model
=
ToyModel
().
to
(
torch
.
cuda
.
current_device
())
out
=
model
.
val_step
([
data
])
out
=
model
.
val_step
([
data
])
self
.
assertEqual
(
out
.
device
.
type
,
'
cuda
'
)
self
.
assertEqual
(
out
.
device
.
type
,
'
cuda
'
)
model
=
NestedModel
()
self
.
assertEqual
(
model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cpu
'
))
self
.
assertEqual
(
model
.
toy_model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cpu
'
))
model
.
to
(
'
cuda
'
)
self
.
assertEqual
(
model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cuda
'
))
self
.
assertEqual
(
model
.
toy_model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cuda
'
))
model
.
to
()
self
.
assertEqual
(
model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cuda
'
))
self
.
assertEqual
(
model
.
toy_model
.
data_preprocessor
.
_device
,
torch
.
device
(
'
cuda
'
))
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment