From 94ab45d07edaa07a4aeaa79687fef563d5323baf Mon Sep 17 00:00:00 2001 From: Yifei Yang <2744335995@qq.com> Date: Wed, 2 Mar 2022 14:04:41 +0800 Subject: [PATCH] [Feature] Add empty cache hook (#58) * [Feature]: Add Part3 of Hooks * [Feature]: Add Hook * [Fix]: Add docstring and type hint for base hook * [Fix]: Add test case to not the last iter, inner_iter, epoch * [Fix]: Add missing type hint * [Feature]: Add Args and Returns in docstring * [Fix]: Add missing colon * [Fix]: Add optional to docstring * [Fix]: Fix docstring problem * [Fix]: Fix lint * fix lint * update typing and docs * fix lint * Update mmengine/hooks/empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/hooks/empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/hooks/empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_hook/test_empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * fix comments * remove test condition Co-authored-by: seuyou <3463423099@qq.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/hooks/__init__.py | 3 +- mmengine/hooks/empty_cache_hook.py | 65 ++++++++++++++++++++++++ tests/test_hook/test_empty_cache_hook.py | 14 +++++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 mmengine/hooks/empty_cache_hook.py create mode 100644 tests/test_hook/test_empty_cache_hook.py diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index 0f27b237..4bb2b676 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .empty_cache_hook import EmptyCacheHook from .hook import Hook from .iter_timer_hook import IterTimerHook from .optimizer_hook import OptimizerHook @@ -7,5 +8,5 @@ from .sampler_seed_hook import DistSamplerSeedHook __all__ = [ 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook', - 'OptimizerHook' + 'OptimizerHook', 'EmptyCacheHook' ] diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py new file mode 100644 index 00000000..b457f2c0 --- /dev/null +++ b/mmengine/hooks/empty_cache_hook.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +import torch + +from mmengine.data import BaseDataSample +from mmengine.registry import HOOKS +from .hook import Hook + + +@HOOKS.register_module() +class EmptyCacheHook(Hook): + """Releases all unoccupied cached GPU memory during the process of + training. + + Args: + before_epoch (bool): Whether to release cache before an epoch. Defaults + to False. + after_epoch (bool): Whether to release cache after an epoch. Defaults + to True. + after_iter (bool): Whether to release cache after an iteration. + Defaults to False. + """ + + def __init__(self, + before_epoch: bool = False, + after_epoch: bool = True, + after_iter: bool = False) -> None: + self._before_epoch = before_epoch + self._after_epoch = after_epoch + self._after_iter = after_iter + + def after_iter(self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """Empty cache after an iteration. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. + Defaults to None. + outputs (Sequence[BaseDataSample]): Outputs from model. + Defaults to None. + """ + if self._after_iter: + torch.cuda.empty_cache() + + def before_epoch(self, runner: object) -> None: + """Empty cache before an epoch. + + Args: + runner (object): The runner of the training process. + """ + if self._before_epoch: + torch.cuda.empty_cache() + + def after_epoch(self, runner: object) -> None: + """Empty cache after an epoch. + + Args: + runner (object): The runner of the training process. + """ + if self._after_epoch: + torch.cuda.empty_cache() diff --git a/tests/test_hook/test_empty_cache_hook.py b/tests/test_hook/test_empty_cache_hook.py new file mode 100644 index 00000000..5d651070 --- /dev/null +++ b/tests/test_hook/test_empty_cache_hook.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mock import Mock + +from mmengine.hooks import EmptyCacheHook + + +class TestEmptyCacheHook: + + def test_emtpy_cache_hook(self): + Hook = EmptyCacheHook(True, True, True) + Runner = Mock() + Hook.after_iter(Runner) + Hook.before_epoch(Runner) + Hook.after_epoch(Runner) -- GitLab