From fd84c210e551ebd2ac149cfcc82d443550a08f65 Mon Sep 17 00:00:00 2001
From: BigDong <yudongwang@tju.edu.cn>
Date: Tue, 4 Apr 2023 14:27:27 +0800
Subject: [PATCH] [Fix] Update the ClassBalancedDataset logic to keep
 len(repeat_factors) = len(dataset) (#1048)

---
 mmengine/dataset/dataset_wrapper.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py
index 78cdd9a7..49b630d8 100644
--- a/mmengine/dataset/dataset_wrapper.py
+++ b/mmengine/dataset/dataset_wrapper.py
@@ -429,12 +429,16 @@ class ClassBalancedDataset:
         #    r(I) = max_{c in L(I)} r(c)
         repeat_factors = []
         for idx in range(num_images):
+            # the length of `repeat_factors` need equal to the length of
+            # dataset. Hence, if the `cat_ids` is empty,
+            # the repeat_factor should be 1.
+            repeat_factor: float = 1.
             cat_ids = set(self.dataset.get_cat_ids(idx))
             if len(cat_ids) != 0:
                 repeat_factor = max(
                     {category_repeat[cat_id]
                      for cat_id in cat_ids})
-                repeat_factors.append(repeat_factor)
+            repeat_factors.append(repeat_factor)
 
         return repeat_factors
 
-- 
GitLab