From 5b065b10fde2f625c945a52c1bbcbe6869baef2b Mon Sep 17 00:00:00 2001
From: Tong Gao <gaotongxiao@gmail.com>
Date: Wed, 20 Jul 2022 16:04:24 +0800
Subject: [PATCH] [Enhance] Support Compose(None) (#373)

* [Enhance] Allow Compose(None)

* add typehint

* fix
---
 mmengine/dataset/base_dataset.py     | 9 ++++++---
 tests/test_data/test_base_dataset.py | 7 +++++++
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py
index 6f3b7517..3aa56475 100644
--- a/mmengine/dataset/base_dataset.py
+++ b/mmengine/dataset/base_dataset.py
@@ -19,13 +19,16 @@ class Compose:
     """Compose multiple transforms sequentially.
 
     Args:
-        transforms (Sequence[dict, callable]): Sequence of transform object or
-            config dict to be composed.
+        transforms (Sequence[dict, callable], optional): Sequence of transform
+            object or config dict to be composed.
     """
 
-    def __init__(self, transforms: Sequence[Union[dict, Callable]]):
+    def __init__(self, transforms: Optional[Sequence[Union[dict, Callable]]]):
         self.transforms: List[Callable] = []
 
+        if transforms is None:
+            transforms = []
+
         for transform in transforms:
             # `Compose` can be built with config dict with type and
             # corresponding arguments.
diff --git a/tests/test_data/test_base_dataset.py b/tests/test_data/test_base_dataset.py
index 12e637d5..e540f527 100644
--- a/tests/test_data/test_base_dataset.py
+++ b/tests/test_data/test_base_dataset.py
@@ -318,6 +318,13 @@ class TestBaseDataset:
         with pytest.raises(TypeError):
             Compose([1])
 
+        # when the input transform is None, do nothing
+        compose = Compose(None)
+        assert (compose(dict(img=self.imgs))['img'] == self.imgs).all()
+
+        compose = Compose([])
+        assert (compose(dict(img=self.imgs))['img'] == self.imgs).all()
+
     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_getitem(self, lazy_init):
         dataset = BaseDataset(
-- 
GitLab