From 792f481efe2f06e491bcb68625b3fb7a302650f5 Mon Sep 17 00:00:00 2001 From: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Date: Fri, 8 Jul 2022 14:51:51 +0800 Subject: [PATCH] [Fix]fix ClassBalancedDataset (#354) * fix cbds * fix --- mmengine/dataset/dataset_wrapper.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py index 2cf995e7..b91756f6 100644 --- a/mmengine/dataset/dataset_wrapper.py +++ b/mmengine/dataset/dataset_wrapper.py @@ -271,6 +271,7 @@ class RepeatDataset: 'dataset first and then use `RepeatDataset`.') +@DATASETS.register_module() class ClassBalancedDataset: """A wrapper of class balanced dataset. @@ -398,10 +399,11 @@ class ClassBalancedDataset: repeat_factors = [] for idx in range(num_images): cat_ids = set(self.dataset.get_cat_ids(idx)) - repeat_factor = max( - {category_repeat[cat_id] - for cat_id in cat_ids}) - repeat_factors.append(repeat_factor) + if len(cat_ids) != 0: + repeat_factor = max( + {category_repeat[cat_id] + for cat_id in cat_ids}) + repeat_factors.append(repeat_factor) return repeat_factors -- GitLab