diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py index 63518c23c306bef4c69388ce206d2dfa2c0861e0..fa3e9993436760f8d576e34989d95bb40b20487f 100644 --- a/tests/test_model/test_wrappers/test_data_parallel.py +++ b/tests/test_model/test_wrappers/test_data_parallel.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase from unittest.mock import MagicMock, patch import pytest @@ -57,7 +58,7 @@ def test_is_model_wrapper(): assert is_model_wrapper(model_wrapper) -class TestMMDataParallel: +class TestMMDataParallel(TestCase): def setUp(self): """Setup the demo image in every test method. @@ -70,7 +71,7 @@ class TestMMDataParallel: def __init__(self): super().__init__() - self.conv = nn.Conv2d(2, 2, 1) + self.conv = nn.Conv2d(1, 2, 1) def forward(self, x): return self.conv(x) @@ -101,7 +102,7 @@ class TestMMDataParallel: with pytest.raises(AssertionError): mmdp.train_step(torch.zeros([1, 1, 3, 3])) - out = self.model.train_step([torch.zeros([1, 1, 3, 3])]) + out = self.model.train_step(torch.zeros([1, 1, 3, 3])) assert out.shape == (1, 2, 3, 3) def test_val_step(self): @@ -122,5 +123,5 @@ class TestMMDataParallel: with pytest.raises(AssertionError): mmdp.val_step(torch.zeros([1, 1, 3, 3])) - out = self.model.val_step([torch.zeros([1, 1, 3, 3])]) + out = self.model.val_step(torch.zeros([1, 1, 3, 3])) assert out.shape == (1, 2, 3, 3)