# File: test_naive_visualization_hook.py
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock

import torch

from mmengine.data import BaseDataElement
from mmengine.hooks import NaiveVisualizationHook


class TestNaiveVisualizationHook:
    """Smoke test for ``NaiveVisualizationHook`` driven by a mocked runner."""

    def test_after_train_iter(self):
        """Invoke the hook once per combination of preprocessing metainfo.

        Note that, despite the method name, the hook entry point exercised
        here is ``after_test_iter``. The test makes no assertions: it passes
        as long as none of the calls raises.
        """
        hook = NaiveVisualizationHook()

        # The runner is fully mocked; ``add_image`` is replaced explicitly.
        runner = Mock(iter=1)
        runner.visualizer.add_image = Mock()

        # One shared 1x3x15x15 input tensor is reused for every scenario.
        inputs = torch.randn(1, 3, 15, 15)
        batch_idx = 10

        # Keys common to every scenario whose original image is 5x5.
        small_origin = dict(ori_height=5, ori_width=5, img_path='tmp.jpg')

        metainfo_cases = (
            # normalize + resize + pad
            dict(
                img_norm_cfg=dict(
                    mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
                scale=(10, 10),
                pad_shape=(15, 15, 3),
                **small_origin),
            # resize + pad
            dict(scale=(10, 10), pad_shape=(15, 15, 3), **small_origin),
            # resize only
            dict(scale=(15, 15), **small_origin),
            # pad only
            dict(pad_shape=(15, 15, 3), **small_origin),
            # no transform at all
            dict(ori_height=15, ori_width=15, img_path='tmp.jpg'),
        )

        for metainfo in metainfo_cases:
            # Fresh data elements are built for every call, so no state is
            # carried over from one scenario to the next.
            gt_sample = BaseDataElement(metainfo=metainfo)
            outputs = [BaseDataElement()]
            batch = [dict(inputs=inputs, data_sample=gt_sample)]
            hook.after_test_iter(runner, batch_idx, batch, outputs)