From fd84c210e551ebd2ac149cfcc82d443550a08f65 Mon Sep 17 00:00:00 2001 From: BigDong <yudongwang@tju.edu.cn> Date: Tue, 4 Apr 2023 14:27:27 +0800 Subject: [PATCH] [Fix] Update the ClassBalancedDataset logic to keep len(repeat_factors) = len(dataset) (#1048) --- mmengine/dataset/dataset_wrapper.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py index 78cdd9a7..49b630d8 100644 --- a/mmengine/dataset/dataset_wrapper.py +++ b/mmengine/dataset/dataset_wrapper.py @@ -429,12 +429,16 @@ class ClassBalancedDataset: # r(I) = max_{c in L(I)} r(c) repeat_factors = [] for idx in range(num_images): + # the length of `repeat_factors` must equal the length of + # the dataset. Hence, if `cat_ids` is empty, + # the repeat_factor should be 1. + repeat_factor: float = 1. cat_ids = set(self.dataset.get_cat_ids(idx)) if len(cat_ids) != 0: repeat_factor = max( {category_repeat[cat_id] for cat_id in cat_ids}) - repeat_factors.append(repeat_factor) + repeat_factors.append(repeat_factor) return repeat_factors -- GitLab