Skip to content
Snippets Groups Projects
Commit a9a57586 authored by RangiLyu's avatar RangiLyu Committed by Zaida Zhou
Browse files

[Fix] Fix offline_evaluate index error (#630)

* [Fix] Fix offline eval dataset index error.

* update

* update
parent f2b0540f
No related branches found
No related tags found
No related merge requests found
......@@ -104,12 +104,6 @@ class Evaluator:
"""
# support chunking iterable objects
if data is not None:
assert len(data_samples) == len(data), (
'outputs and data should have the same length, but got '
f'outputs length: {len(data_samples)} '
f'data length: {len(data)}')
def get_chunks(seq: Iterator, chunk_size=1):
stop = False
while not stop:
......@@ -123,10 +117,17 @@ class Evaluator:
if chunk:
yield chunk
if data is not None:
assert len(data_samples) == len(data), (
'data_samples and data should have the same length, but got '
f'data_samples length: {len(data_samples)} '
f'data length: {len(data)}')
data = get_chunks(iter(data), chunk_size)
size = 0
for output_chunk in get_chunks(iter(data_samples), chunk_size):
if data is not None:
data_chunk = pseudo_collate(data[size:size + chunk_size])
data_chunk = pseudo_collate(next(data)) # type: ignore
else:
data_chunk = None
size += len(output_chunk)
......
......@@ -247,7 +247,7 @@ class TestEvaluator(TestCase):
all_data = [dict() for _ in range(9)]
with self.assertRaisesRegex(
AssertionError,
'outputs and data should have the same length'):
'data_samples and data should have the same length'):
evaluator.offline_evaluate(all_predictions, all_data)
@unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment