From 792f481efe2f06e491bcb68625b3fb7a302650f5 Mon Sep 17 00:00:00 2001
From: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Date: Fri, 8 Jul 2022 14:51:51 +0800
Subject: [PATCH] [Fix]fix ClassBalancedDataset (#354)

* fix cbds

* fix
---
 mmengine/dataset/dataset_wrapper.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py
index 2cf995e7..b91756f6 100644
--- a/mmengine/dataset/dataset_wrapper.py
+++ b/mmengine/dataset/dataset_wrapper.py
@@ -271,6 +271,7 @@ class RepeatDataset:
             'dataset first and then use `RepeatDataset`.')
 
 
+@DATASETS.register_module()
 class ClassBalancedDataset:
     """A wrapper of class balanced dataset.
 
@@ -398,10 +399,11 @@ class ClassBalancedDataset:
         repeat_factors = []
         for idx in range(num_images):
             cat_ids = set(self.dataset.get_cat_ids(idx))
-            repeat_factor = max(
-                {category_repeat[cat_id]
-                 for cat_id in cat_ids})
-            repeat_factors.append(repeat_factor)
+            if len(cat_ids) != 0:
+                repeat_factor = max(
+                    {category_repeat[cat_id]
+                     for cat_id in cat_ids})
+                repeat_factors.append(repeat_factor)
 
         return repeat_factors
 
-- 
GitLab