diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py index 2cf995e74b447b0ec4e3d8117a2a2115ca6dd935..b91756f651115cabaed544de851c5df27b0ff65a 100644 --- a/mmengine/dataset/dataset_wrapper.py +++ b/mmengine/dataset/dataset_wrapper.py @@ -271,6 +271,7 @@ class RepeatDataset: 'dataset first and then use `RepeatDataset`.') +@DATASETS.register_module() class ClassBalancedDataset: """A wrapper of class balanced dataset. @@ -398,10 +399,11 @@ class ClassBalancedDataset: repeat_factors = [] for idx in range(num_images): cat_ids = set(self.dataset.get_cat_ids(idx)) - repeat_factor = max( - {category_repeat[cat_id] - for cat_id in cat_ids}) - repeat_factors.append(repeat_factor) + if len(cat_ids) != 0: + repeat_factor = max( + {category_repeat[cat_id] + for cat_id in cat_ids}) + repeat_factors.append(repeat_factor) return repeat_factors