Skip to content
Snippets Groups Projects
Unverified Commit fd84c210 authored by BigDong's avatar BigDong Committed by GitHub
Browse files

[Fix] Update the ClassBalancedDataset logic to keep len(repeat_factors) = len(dataset) (#1048)

parent 093068e4
No related branches found
No related tags found
No related merge requests found
...@@ -429,12 +429,16 @@ class ClassBalancedDataset: ...@@ -429,12 +429,16 @@ class ClassBalancedDataset:
# r(I) = max_{c in L(I)} r(c) # r(I) = max_{c in L(I)} r(c)
repeat_factors = [] repeat_factors = []
for idx in range(num_images): for idx in range(num_images):
# the length of `repeat_factors` need equal to the length of
# dataset. Hence, if the `cat_ids` is empty,
# the repeat_factor should be 1.
repeat_factor: float = 1.
cat_ids = set(self.dataset.get_cat_ids(idx)) cat_ids = set(self.dataset.get_cat_ids(idx))
if len(cat_ids) != 0: if len(cat_ids) != 0:
repeat_factor = max( repeat_factor = max(
{category_repeat[cat_id] {category_repeat[cat_id]
for cat_id in cat_ids}) for cat_id in cat_ids})
repeat_factors.append(repeat_factor) repeat_factors.append(repeat_factor)
return repeat_factors return repeat_factors
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment