Skip to content
Snippets Groups Projects
Unverified Commit 792f481e authored by VVsssssk's avatar VVsssssk Committed by GitHub
Browse files

[Fix]fix ClassBalancedDataset (#354)

* fix cbds

* fix
parent b2ee9f8b
No related branches found
No related tags found
No related merge requests found
...@@ -271,6 +271,7 @@ class RepeatDataset: ...@@ -271,6 +271,7 @@ class RepeatDataset:
'dataset first and then use `RepeatDataset`.') 'dataset first and then use `RepeatDataset`.')
@DATASETS.register_module()
class ClassBalancedDataset: class ClassBalancedDataset:
"""A wrapper of class balanced dataset. """A wrapper of class balanced dataset.
...@@ -398,10 +399,11 @@ class ClassBalancedDataset: ...@@ -398,10 +399,11 @@ class ClassBalancedDataset:
repeat_factors = [] repeat_factors = []
for idx in range(num_images): for idx in range(num_images):
cat_ids = set(self.dataset.get_cat_ids(idx)) cat_ids = set(self.dataset.get_cat_ids(idx))
repeat_factor = max( if len(cat_ids) != 0:
{category_repeat[cat_id] repeat_factor = max(
for cat_id in cat_ids}) {category_repeat[cat_id]
repeat_factors.append(repeat_factor) for cat_id in cat_ids})
repeat_factors.append(repeat_factor)
return repeat_factors return repeat_factors
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment