From b7aa4dd8856c74c349e895be08faacef8e1762c7 Mon Sep 17 00:00:00 2001 From: shenmishajing <shenmishajing@Gmail.com> Date: Mon, 21 Nov 2022 11:52:48 +0800 Subject: [PATCH] [Fix]: fix add graph function is not called bug in visualization hooks (#632) * fix add graph func is not called bug * move add graph call to NaiveVisualizationHook.before_train * Update mmengine/hooks/naive_visualization_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * adjust param sequence and add docstring * minor refine * Update mmengine/visualization/vis_backend.py * update version info Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: HAOCHENYE <21724054@zju.edu.cn> --- mmengine/hooks/naive_visualization_hook.py | 8 ++++++++ mmengine/visualization/vis_backend.py | 17 ++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py index 6a6c3f38..079a81ce 100644 --- a/mmengine/hooks/naive_visualization_hook.py +++ b/mmengine/hooks/naive_visualization_hook.py @@ -49,6 +49,14 @@ class NaiveVisualizationHook(Hook): unpad_image = input[:unpad_height, :unpad_width] return unpad_image + def before_train(self, runner) -> None: + """Call add_graph method of visualizer. + + Args: + runner (Runner): The runner of the training process. + """ + runner.visualizer.add_graph(runner.model, None) + def after_test_iter(self, runner, batch_idx: int, diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index 3f2f254c..50878f1e 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -360,6 +360,8 @@ class WandbVisBackend(BaseVisBackend): `wandb docs <https://docs.wandb.ai/ref/python/run#log_code>`_ for details. Defaults to None. New in version 0.3.0. + watch_kwargs (optional, dict): Agurments for ``wandb.watch``. + New in version 0.4.0. """ def __init__(self, @@ -367,12 +369,14 @@ class WandbVisBackend(BaseVisBackend): init_kwargs: Optional[dict] = None, define_metric_cfg: Optional[dict] = None, commit: Optional[bool] = True, - log_code_name: Optional[str] = None): + log_code_name: Optional[str] = None, + watch_kwargs: Optional[dict] = None): super().__init__(save_dir) self._init_kwargs = init_kwargs self._define_metric_cfg = define_metric_cfg self._commit = commit self._log_code_name = log_code_name + self._watch_kwargs = watch_kwargs if watch_kwargs is not None else {} def _init_env(self): """Setup env for wandb.""" @@ -415,6 +419,17 @@ class WandbVisBackend(BaseVisBackend): self._wandb.config.update(dict(config)) self._wandb.run.log_code(name=self._log_code_name) + @force_init_env + def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], + **kwargs) -> None: + """Record the model graph. + + Args: + model (torch.nn.Module): Model to draw. + data_batch (Sequence[dict]): Batch of data from dataloader. + """ + self._wandb.watch(model, **self._watch_kwargs) + @force_init_env def add_image(self, name: str, -- GitLab