From bbe00274c8ce9e8c9fbdc28a148f9de09db89173 Mon Sep 17 00:00:00 2001 From: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com> Date: Mon, 27 Jun 2022 15:00:11 +0800 Subject: [PATCH] [Enhance] LR and Momentum Visualizer (#267) * impl lr and momentum visualizer * provide fakerun --- mmengine/utils/visualize.py | 63 +++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 mmengine/utils/visualize.py diff --git a/mmengine/utils/visualize.py b/mmengine/utils/visualize.py new file mode 100644 index 00000000..f3361e1d --- /dev/null +++ b/mmengine/utils/visualize.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import patch + +import torch +import torch.nn as nn + +from mmengine.model import BaseModel +from mmengine.registry import MODELS + + +@MODELS.register_module() +class ToyModel(BaseModel): + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = nn.Conv2d(1, 1, 1) + + def forward(self, *args, **kwargs): + return {'loss': torch.tensor(0.0)} + + +def update_params_step(self, loss): + pass + + +def runtimeinfo_step(self, runner, batch_idx, data_batch=None): + runner.message_hub.update_info('iter', runner.iter) + lr_dict = runner.optim_wrapper.get_lr() + for name, lr in lr_dict.items(): + runner.message_hub.update_scalar(f'train/{name}', lr[0]) + + momentum_dict = runner.optim_wrapper.get_momentum() + for name, momentum in momentum_dict.items(): + runner.message_hub.update_scalar(f'train/{name}', momentum[0]) + + +@patch('mmengine.optim.optimizer.OptimWrapper.update_params', + update_params_step) +@patch('mmengine.hooks.RuntimeInfoHook.before_train_iter', runtimeinfo_step) +def fake_run(cfg): + from mmengine.runner import Runner + cfg.pop('model') + cfg.pop('visualizer') + cfg.pop('val_dataloader') + cfg.pop('val_evaluator') + cfg.pop('val_cfg') + cfg.pop('test_dataloader') + cfg.pop('test_evaluator') + cfg.pop('test_cfg') + extra_cfg = dict( + model=dict(type='ToyModel'), + visualizer=dict( + type='Visualizer', + vis_backends=[ + dict(type='TensorboardVisBackend', save_dir='temp_dir') + ]), + ) + cfg.merge_from_dict(extra_cfg) + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start training + runner.train() -- GitLab