Skip to content
Snippets Groups Projects
Unverified Commit 5170676a authored by Haian Huang(深度眸)'s avatar Haian Huang(深度眸) Committed by GitHub
Browse files

fix mmdp unittest (#60)

parent bc759e55
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
...@@ -57,7 +58,7 @@ def test_is_model_wrapper(): ...@@ -57,7 +58,7 @@ def test_is_model_wrapper():
assert is_model_wrapper(model_wrapper) assert is_model_wrapper(model_wrapper)
class TestMMDataParallel: class TestMMDataParallel(TestCase):
def setUp(self): def setUp(self):
"""Setup the demo image in every test method. """Setup the demo image in every test method.
...@@ -70,7 +71,7 @@ class TestMMDataParallel: ...@@ -70,7 +71,7 @@ class TestMMDataParallel:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv = nn.Conv2d(2, 2, 1) self.conv = nn.Conv2d(1, 2, 1)
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
...@@ -101,7 +102,7 @@ class TestMMDataParallel: ...@@ -101,7 +102,7 @@ class TestMMDataParallel:
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
mmdp.train_step(torch.zeros([1, 1, 3, 3])) mmdp.train_step(torch.zeros([1, 1, 3, 3]))
out = self.model.train_step([torch.zeros([1, 1, 3, 3])]) out = self.model.train_step(torch.zeros([1, 1, 3, 3]))
assert out.shape == (1, 2, 3, 3) assert out.shape == (1, 2, 3, 3)
def test_val_step(self): def test_val_step(self):
...@@ -122,5 +123,5 @@ class TestMMDataParallel: ...@@ -122,5 +123,5 @@ class TestMMDataParallel:
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
mmdp.val_step(torch.zeros([1, 1, 3, 3])) mmdp.val_step(torch.zeros([1, 1, 3, 3]))
out = self.model.val_step([torch.zeros([1, 1, 3, 3])]) out = self.model.val_step(torch.zeros([1, 1, 3, 3]))
assert out.shape == (1, 2, 3, 3) assert out.shape == (1, 2, 3, 3)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment