diff --git a/mmengine/utils/visualize.py b/mmengine/utils/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..f3361e1d50a4dafb8518d6bbd66f9131b441bd80 --- /dev/null +++ b/mmengine/utils/visualize.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import patch + +import torch +import torch.nn as nn + +from mmengine.model import BaseModel +from mmengine.registry import MODELS + + +@MODELS.register_module() +class ToyModel(BaseModel): + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = nn.Conv2d(1, 1, 1) + + def forward(self, *args, **kwargs): + return {'loss': torch.tensor(0.0)} + + +def update_params_step(self, loss): + pass + + +def runtimeinfo_step(self, runner, batch_idx, data_batch=None): + runner.message_hub.update_info('iter', runner.iter) + lr_dict = runner.optim_wrapper.get_lr() + for name, lr in lr_dict.items(): + runner.message_hub.update_scalar(f'train/{name}', lr[0]) + + momentum_dict = runner.optim_wrapper.get_momentum() + for name, momentum in momentum_dict.items(): + runner.message_hub.update_scalar(f'train/{name}', momentum[0]) + + +@patch('mmengine.optim.optimizer.OptimWrapper.update_params', + update_params_step) +@patch('mmengine.hooks.RuntimeInfoHook.before_train_iter', runtimeinfo_step) +def fake_run(cfg): + from mmengine.runner import Runner + cfg.pop('model') + cfg.pop('visualizer') + cfg.pop('val_dataloader') + cfg.pop('val_evaluator') + cfg.pop('val_cfg') + cfg.pop('test_dataloader') + cfg.pop('test_evaluator') + cfg.pop('test_cfg') + extra_cfg = dict( + model=dict(type='ToyModel'), + visualizer=dict( + type='Visualizer', + vis_backends=[ + dict(type='TensorboardVisBackend', save_dir='temp_dir') + ]), + ) + cfg.merge_from_dict(extra_cfg) + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start training + runner.train()