From 94ab45d07edaa07a4aeaa79687fef563d5323baf Mon Sep 17 00:00:00 2001
From: Yifei Yang <2744335995@qq.com>
Date: Wed, 2 Mar 2022 14:04:41 +0800
Subject: [PATCH] [Feature] Add empty cache hook (#58)

* [Feature]: Add Part3 of Hooks

* [Feature]: Add Hook

* [Fix]: Add docstring and type hint for base hook

* [Fix]: Add test case to not the last iter, inner_iter, epoch

* [Fix]: Add missing type hint

* [Feature]: Add Args and Returns in docstring

* [Fix]: Add missing colon

* [Fix]: Add optional to docstring

* [Fix]: Fix docstring problem

* [Fix]: Fix lint

* fix lint

* update typing and docs

* fix lint

* Update mmengine/hooks/empty_cache_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/hooks/empty_cache_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/hooks/empty_cache_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_hook/test_empty_cache_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix lint

* fix comments

* remove test condition

Co-authored-by: seuyou <3463423099@qq.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/hooks/__init__.py               |  3 +-
 mmengine/hooks/empty_cache_hook.py       | 65 ++++++++++++++++++++++++
 tests/test_hook/test_empty_cache_hook.py | 14 +++++
 3 files changed, 81 insertions(+), 1 deletion(-)
 create mode 100644 mmengine/hooks/empty_cache_hook.py
 create mode 100644 tests/test_hook/test_empty_cache_hook.py

diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index 0f27b237..4bb2b676 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
 from .optimizer_hook import OptimizerHook
@@ -7,5 +8,5 @@ from .sampler_seed_hook import DistSamplerSeedHook
 
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'OptimizerHook'
+    'OptimizerHook', 'EmptyCacheHook'
 ]
diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py
new file mode 100644
index 00000000..b457f2c0
--- /dev/null
+++ b/mmengine/hooks/empty_cache_hook.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence
+
+import torch
+
+from mmengine.data import BaseDataSample
+from mmengine.registry import HOOKS
+from .hook import Hook
+
+
+@HOOKS.register_module()
+class EmptyCacheHook(Hook):
+    """Releases all unoccupied cached GPU memory during the process of
+    training.
+
+    Args:
+        before_epoch (bool): Whether to release cache before an epoch. Defaults
+            to False.
+        after_epoch (bool): Whether to release cache after an epoch. Defaults
+            to True.
+        after_iter (bool): Whether to release cache after an iteration.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 before_epoch: bool = False,
+                 after_epoch: bool = True,
+                 after_iter: bool = False) -> None:
+        self._before_epoch = before_epoch
+        self._after_epoch = after_epoch
+        self._after_iter = after_iter
+
+    def after_iter(self,
+                   runner: object,
+                   data_batch: Optional[Sequence[BaseDataSample]] = None,
+                   outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """Empty cache after an iteration.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample]): Data from dataloader.
+                Defaults to None.
+            outputs (Sequence[BaseDataSample]): Outputs from model.
+                Defaults to None.
+        """
+        if self._after_iter:
+            torch.cuda.empty_cache()
+
+    def before_epoch(self, runner: object) -> None:
+        """Empty cache before an epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        if self._before_epoch:
+            torch.cuda.empty_cache()
+
+    def after_epoch(self, runner: object) -> None:
+        """Empty cache after an epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        if self._after_epoch:
+            torch.cuda.empty_cache()
diff --git a/tests/test_hook/test_empty_cache_hook.py b/tests/test_hook/test_empty_cache_hook.py
new file mode 100644
index 00000000..5d651070
--- /dev/null
+++ b/tests/test_hook/test_empty_cache_hook.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mock import Mock
+
+from mmengine.hooks import EmptyCacheHook
+
+
+class TestEmptyCacheHook:
+
+    def test_emtpy_cache_hook(self):
+        Hook = EmptyCacheHook(True, True, True)
+        Runner = Mock()
+        Hook.after_iter(Runner)
+        Hook.before_epoch(Runner)
+        Hook.after_epoch(Runner)
-- 
GitLab