From b7aa4dd8856c74c349e895be08faacef8e1762c7 Mon Sep 17 00:00:00 2001
From: shenmishajing <shenmishajing@Gmail.com>
Date: Mon, 21 Nov 2022 11:52:48 +0800
Subject: [PATCH] [Fix]: fix add graph function is not called bug in
 visualization hooks (#632)

* fix add graph func is not called bug

* move add graph call to NaiveVisualizationHook.before_train

* Update mmengine/hooks/naive_visualization_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* adjust param sequence and add docstring

* minor refine

* Update mmengine/visualization/vis_backend.py

* update version info

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
---
 mmengine/hooks/naive_visualization_hook.py |  8 ++++++++
 mmengine/visualization/vis_backend.py      | 17 ++++++++++++++++-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py
index 6a6c3f38..079a81ce 100644
--- a/mmengine/hooks/naive_visualization_hook.py
+++ b/mmengine/hooks/naive_visualization_hook.py
@@ -49,6 +49,14 @@ class NaiveVisualizationHook(Hook):
         unpad_image = input[:unpad_height, :unpad_width]
         return unpad_image
 
+    def before_train(self, runner) -> None:
+        """Call add_graph method of visualizer.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        runner.visualizer.add_graph(runner.model, None)
+
     def after_test_iter(self,
                         runner,
                         batch_idx: int,
diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py
index 3f2f254c..50878f1e 100644
--- a/mmengine/visualization/vis_backend.py
+++ b/mmengine/visualization/vis_backend.py
@@ -360,6 +360,8 @@ class WandbVisBackend(BaseVisBackend):
             `wandb docs <https://docs.wandb.ai/ref/python/run#log_code>`_
             for details. Defaults to None.
             New in version 0.3.0.
+        watch_kwargs (optional, dict): Agurments for ``wandb.watch``.
+            New in version 0.4.0.
     """
 
     def __init__(self,
@@ -367,12 +369,14 @@ class WandbVisBackend(BaseVisBackend):
                  init_kwargs: Optional[dict] = None,
                  define_metric_cfg: Optional[dict] = None,
                  commit: Optional[bool] = True,
-                 log_code_name: Optional[str] = None):
+                 log_code_name: Optional[str] = None,
+                 watch_kwargs: Optional[dict] = None):
         super().__init__(save_dir)
         self._init_kwargs = init_kwargs
         self._define_metric_cfg = define_metric_cfg
         self._commit = commit
         self._log_code_name = log_code_name
+        self._watch_kwargs = watch_kwargs if watch_kwargs is not None else {}
 
     def _init_env(self):
         """Setup env for wandb."""
@@ -415,6 +419,17 @@ class WandbVisBackend(BaseVisBackend):
         self._wandb.config.update(dict(config))
         self._wandb.run.log_code(name=self._log_code_name)
 
+    @force_init_env
+    def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict],
+                  **kwargs) -> None:
+        """Record the model graph.
+
+        Args:
+            model (torch.nn.Module): Model to draw.
+            data_batch (Sequence[dict]): Batch of data from dataloader.
+        """
+        self._wandb.watch(model, **self._watch_kwargs)
+
     @force_init_env
     def add_image(self,
                   name: str,
-- 
GitLab