From 5170676a2fa1b502315b4b9fbb9dfffaa0e34fd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Sun, 27 Feb 2022 20:34:52 +0800 Subject: [PATCH] fix mmdp unittest (#60) --- tests/test_model/test_wrappers/test_data_parallel.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py index 63518c23..fa3e9993 100644 --- a/tests/test_model/test_wrappers/test_data_parallel.py +++ b/tests/test_model/test_wrappers/test_data_parallel.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase from unittest.mock import MagicMock, patch import pytest @@ -57,7 +58,7 @@ def test_is_model_wrapper(): assert is_model_wrapper(model_wrapper) -class TestMMDataParallel: +class TestMMDataParallel(TestCase): def setUp(self): """Setup the demo image in every test method. @@ -70,7 +71,7 @@ class TestMMDataParallel: def __init__(self): super().__init__() - self.conv = nn.Conv2d(2, 2, 1) + self.conv = nn.Conv2d(1, 2, 1) def forward(self, x): return self.conv(x) @@ -101,7 +102,7 @@ class TestMMDataParallel: with pytest.raises(AssertionError): mmdp.train_step(torch.zeros([1, 1, 3, 3])) - out = self.model.train_step([torch.zeros([1, 1, 3, 3])]) + out = self.model.train_step(torch.zeros([1, 1, 3, 3])) assert out.shape == (1, 2, 3, 3) def test_val_step(self): @@ -122,5 +123,5 @@ class TestMMDataParallel: with pytest.raises(AssertionError): mmdp.val_step(torch.zeros([1, 1, 3, 3])) - out = self.model.val_step([torch.zeros([1, 1, 3, 3])]) + out = self.model.val_step(torch.zeros([1, 1, 3, 3])) assert out.shape == (1, 2, 3, 3) -- GitLab