diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py index 6a6c3f38b3be6f30d3dd0aa96d8c473909764f36..079a81ce05bbffeb658f8d670fe7f17b1c17617e 100644 --- a/mmengine/hooks/naive_visualization_hook.py +++ b/mmengine/hooks/naive_visualization_hook.py @@ -49,6 +49,14 @@ class NaiveVisualizationHook(Hook): unpad_image = input[:unpad_height, :unpad_width] return unpad_image + def before_train(self, runner) -> None: + """Call add_graph method of visualizer. + + Args: + runner (Runner): The runner of the training process. + """ + runner.visualizer.add_graph(runner.model, None) + def after_test_iter(self, runner, batch_idx: int, diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index 3f2f254ce256e7d4e6d1afde562913b693ab675a..50878f1e7044edaca6fe70d8ebea8c4220c37a46 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -360,6 +360,8 @@ class WandbVisBackend(BaseVisBackend): `wandb docs <https://docs.wandb.ai/ref/python/run#log_code>`_ for details. Defaults to None. New in version 0.3.0. + watch_kwargs (optional, dict): Agurments for ``wandb.watch``. + New in version 0.4.0. """ def __init__(self, @@ -367,12 +369,14 @@ class WandbVisBackend(BaseVisBackend): init_kwargs: Optional[dict] = None, define_metric_cfg: Optional[dict] = None, commit: Optional[bool] = True, - log_code_name: Optional[str] = None): + log_code_name: Optional[str] = None, + watch_kwargs: Optional[dict] = None): super().__init__(save_dir) self._init_kwargs = init_kwargs self._define_metric_cfg = define_metric_cfg self._commit = commit self._log_code_name = log_code_name + self._watch_kwargs = watch_kwargs if watch_kwargs is not None else {} def _init_env(self): """Setup env for wandb.""" @@ -415,6 +419,17 @@ class WandbVisBackend(BaseVisBackend): self._wandb.config.update(dict(config)) self._wandb.run.log_code(name=self._log_code_name) + @force_init_env + def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], + **kwargs) -> None: + """Record the model graph. + + Args: + model (torch.nn.Module): Model to draw. + data_batch (Sequence[dict]): Batch of data from dataloader. + """ + self._wandb.watch(model, **self._watch_kwargs) + @force_init_env def add_image(self, name: str,