diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index 0f27b2378b60c49f9f5292a375e67c4bd901ef09..4bb2b6761e0073f074c67ae58e7a5eb466550e17 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .empty_cache_hook import EmptyCacheHook from .hook import Hook from .iter_timer_hook import IterTimerHook from .optimizer_hook import OptimizerHook @@ -7,5 +8,5 @@ from .sampler_seed_hook import DistSamplerSeedHook __all__ = [ 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook', - 'OptimizerHook' + 'OptimizerHook', 'EmptyCacheHook' ] diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..b457f2c0409007b653def67a64f76ebc506aa5c6 --- /dev/null +++ b/mmengine/hooks/empty_cache_hook.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +import torch + +from mmengine.data import BaseDataSample +from mmengine.registry import HOOKS +from .hook import Hook + + +@HOOKS.register_module() +class EmptyCacheHook(Hook): + """Releases all unoccupied cached GPU memory during the process of + training. + + Args: + before_epoch (bool): Whether to release cache before an epoch. Defaults + to False. + after_epoch (bool): Whether to release cache after an epoch. Defaults + to True. + after_iter (bool): Whether to release cache after an iteration. + Defaults to False. + """ + + def __init__(self, + before_epoch: bool = False, + after_epoch: bool = True, + after_iter: bool = False) -> None: + self._before_epoch = before_epoch + self._after_epoch = after_epoch + self._after_iter = after_iter + + def after_iter(self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """Empty cache after an iteration. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. + Defaults to None. + outputs (Sequence[BaseDataSample]): Outputs from model. + Defaults to None. + """ + if self._after_iter: + torch.cuda.empty_cache() + + def before_epoch(self, runner: object) -> None: + """Empty cache before an epoch. + + Args: + runner (object): The runner of the training process. + """ + if self._before_epoch: + torch.cuda.empty_cache() + + def after_epoch(self, runner: object) -> None: + """Empty cache after an epoch. + + Args: + runner (object): The runner of the training process. + """ + if self._after_epoch: + torch.cuda.empty_cache() diff --git a/tests/test_hook/test_empty_cache_hook.py b/tests/test_hook/test_empty_cache_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..5d65107022c349c7d5b6b6a6023313af629fbbd7 --- /dev/null +++ b/tests/test_hook/test_empty_cache_hook.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mock import Mock + +from mmengine.hooks import EmptyCacheHook + + +class TestEmptyCacheHook: + + def test_emtpy_cache_hook(self): + Hook = EmptyCacheHook(True, True, True) + Runner = Mock() + Hook.after_iter(Runner) + Hook.before_epoch(Runner) + Hook.after_epoch(Runner)