diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py
index 63518c23c306bef4c69388ce206d2dfa2c0861e0..fa3e9993436760f8d576e34989d95bb40b20487f 100644
--- a/tests/test_model/test_wrappers/test_data_parallel.py
+++ b/tests/test_model/test_wrappers/test_data_parallel.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -57,7 +58,7 @@ def test_is_model_wrapper():
     assert is_model_wrapper(model_wrapper)
 
 
-class TestMMDataParallel:
+class TestMMDataParallel(TestCase):
 
     def setUp(self):
         """Setup the demo image in every test method.
@@ -70,7 +71,7 @@ class TestMMDataParallel:
 
             def __init__(self):
                 super().__init__()
-                self.conv = nn.Conv2d(2, 2, 1)
+                self.conv = nn.Conv2d(1, 2, 1)
 
             def forward(self, x):
                 return self.conv(x)
@@ -101,7 +102,7 @@ class TestMMDataParallel:
         with pytest.raises(AssertionError):
             mmdp.train_step(torch.zeros([1, 1, 3, 3]))
 
-        out = self.model.train_step([torch.zeros([1, 1, 3, 3])])
+        out = self.model.train_step(torch.zeros([1, 1, 3, 3]))
         assert out.shape == (1, 2, 3, 3)
 
     def test_val_step(self):
@@ -122,5 +123,5 @@ class TestMMDataParallel:
         with pytest.raises(AssertionError):
             mmdp.val_step(torch.zeros([1, 1, 3, 3]))
 
-        out = self.model.val_step([torch.zeros([1, 1, 3, 3])])
+        out = self.model.val_step(torch.zeros([1, 1, 3, 3]))
         assert out.shape == (1, 2, 3, 3)