From 1e79b97444b693df49f4ecaaa820a21f05016feb Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Fri, 25 Feb 2022 15:24:27 +0800 Subject: [PATCH] Mock unimplemented modules and fix unit tests (#54) * Mock unimplemented modules and fix unit tests * add a comment --- mmengine/data/sampler.py | 7 ++++++- pytest.ini | 4 +++- tests/test_optim/test_optimizer/test_optimizer.py | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mmengine/data/sampler.py b/mmengine/data/sampler.py index 3d891909..83936a27 100644 --- a/mmengine/data/sampler.py +++ b/mmengine/data/sampler.py @@ -2,13 +2,18 @@ import itertools import math from typing import Iterator, Optional, Sized +# from mmengine.dist import get_dist_info, sync_random_seed +from unittest.mock import MagicMock import torch from torch.utils.data import Sampler -from mmengine.dist import get_dist_info, sync_random_seed from mmengine.registry import DATA_SAMPLERS +# TODO, need to remove those lines after implementing dist module +get_dist_info = MagicMock(return_value=(0, 1)) +sync_random_seed = MagicMock(return_value=0) + @DATA_SAMPLERS.register_module() class DefaultSampler(Sampler[int]): diff --git a/pytest.ini b/pytest.ini index 24a58355..826e4d13 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,5 @@ [pytest] testpaths = tests -norecursedirs = tests/data +norecursedirs = + tests/data + tests/test_visualizer diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index e205c54f..24890f87 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -178,7 +178,9 @@ class TestBuilder(TestCase): assert sub_gn_bias['lr'] == self.base_lr assert sub_gn_bias['weight_decay'] == self.base_wd * norm_decay_mult - if torch.cuda.is_available(): + # test dcn which requires cuda is available and + # mmcv-full has been installed + if torch.cuda.is_available() and MMCV_FULL_AVAILABLE: dcn_conv_weight = param_groups[11] assert dcn_conv_weight['lr'] == self.base_lr assert dcn_conv_weight['weight_decay'] == self.base_wd -- GitLab