Skip to content
Snippets Groups Projects
Unverified Commit b7aa4dd8 authored by shenmishajing's avatar shenmishajing Committed by GitHub
Browse files

[Fix]: fix add graph function is not called bug in visualization hooks (#632)


* fix add graph func is not called bug

* move add graph call to NaiveVisualizationHook.before_train

* Update mmengine/hooks/naive_visualization_hook.py

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* adjust param sequence and add docstring

* minor refine

* Update mmengine/visualization/vis_backend.py

* update version info

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: default avatarHAOCHENYE <21724054@zju.edu.cn>
parent 9d5b417f
No related branches found
No related tags found
No related merge requests found
......@@ -49,6 +49,14 @@ class NaiveVisualizationHook(Hook):
unpad_image = input[:unpad_height, :unpad_width]
return unpad_image
def before_train(self, runner) -> None:
"""Call add_graph method of visualizer.
Args:
runner (Runner): The runner of the training process.
"""
runner.visualizer.add_graph(runner.model, None)
def after_test_iter(self,
runner,
batch_idx: int,
......
......@@ -360,6 +360,8 @@ class WandbVisBackend(BaseVisBackend):
`wandb docs <https://docs.wandb.ai/ref/python/run#log_code>`_
for details. Defaults to None.
New in version 0.3.0.
watch_kwargs (optional, dict): Agurments for ``wandb.watch``.
New in version 0.4.0.
"""
def __init__(self,
......@@ -367,12 +369,14 @@ class WandbVisBackend(BaseVisBackend):
init_kwargs: Optional[dict] = None,
define_metric_cfg: Optional[dict] = None,
commit: Optional[bool] = True,
log_code_name: Optional[str] = None):
log_code_name: Optional[str] = None,
watch_kwargs: Optional[dict] = None):
super().__init__(save_dir)
self._init_kwargs = init_kwargs
self._define_metric_cfg = define_metric_cfg
self._commit = commit
self._log_code_name = log_code_name
self._watch_kwargs = watch_kwargs if watch_kwargs is not None else {}
def _init_env(self):
"""Setup env for wandb."""
......@@ -415,6 +419,17 @@ class WandbVisBackend(BaseVisBackend):
self._wandb.config.update(dict(config))
self._wandb.run.log_code(name=self._log_code_name)
@force_init_env
def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict],
**kwargs) -> None:
"""Record the model graph.
Args:
model (torch.nn.Module): Model to draw.
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self._wandb.watch(model, **self._watch_kwargs)
@force_init_env
def add_image(self,
name: str,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment