Skip to content
Snippets Groups Projects
Unverified Commit 2d3e9124 authored by Yuan Liu's avatar Yuan Liu Committed by GitHub
Browse files

[Feature]: Add sampler seed hook (#64)

* [Feature]: Add sampler seed hook

* [Fix]: Add call with to UT
parent 1244e486
No related branches found
No related tags found
No related merge requests found
......@@ -97,7 +97,7 @@ import numpy as np
@EVALUATORS.register_module()
class Accuracy(BaseEvaluator):
def process(self, data_samples: Dict, predictions: Dict):
"""Process one batch of data and predictions. The processed
Results should be stored in `self.results`, which will be used
......
......@@ -276,7 +276,7 @@ class ModuleA:
class ModuleB:
def __init__(self):
self.instance = GlobalAccessible.get_instance(current=True)
def run(self):
print(f'moduleB: {self.instance.instance_name} is called')
......
# Copyright (c) OpenMMLab. All rights reserved.
from .hook import Hook
from .iter_timer_hook import IterTimerHook
from .sampler_seed_hook import DistSamplerSeedHook
__all__ = ['Hook', 'IterTimerHook']
__all__ = ['Hook', 'IterTimerHook', 'DistSamplerSeedHook']
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import HOOKS
from .hook import Hook
@HOOKS.register_module()
class DistSamplerSeedHook(Hook):
"""Data-loading sampler for distributed training.
When distributed training, it is only useful in conjunction with
:obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
purpose with :obj:`IterLoader`.
"""
def before_epoch(self, runner: object) -> None:
"""Set the seed for sampler and batch_sampler.
Args:
runner (object): The runner of the training process.
"""
if hasattr(runner.data_loader.sampler, 'set_epoch'): # type: ignore
# in case the data loader uses `SequentialSampler` in Pytorch
runner.data_loader.sampler.set_epoch(runner.epoch) # type: ignore
elif hasattr(
runner.data_loader.batch_sampler.sampler, # type: ignore
'set_epoch'):
# batch sampler in pytorch warps the sampler as its attributes.
runner.data_loader.batch_sampler.sampler.set_epoch( # type: ignore
runner.epoch) # type: ignore
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
from mmengine.hooks import DistSamplerSeedHook
class TestDistSamplerSeedHook:
def test_before_epoch(self):
hook = DistSamplerSeedHook()
# Test dataset sampler
runner = Mock()
runner.epoch = 1
runner.data_loader = Mock()
runner.data_loader.sampler = Mock()
runner.data_loader.sampler.set_epoch = Mock()
hook.before_epoch(runner)
runner.data_loader.sampler.set_epoch.assert_called()
# Test batch sampler
runner = Mock()
runner.data_loader = Mock()
runner.data_loader.sampler = Mock(spec_set=True)
runner.data_loader.batch_sampler = Mock()
runner.data_loader.batch_sampler.sampler = Mock()
runner.data_loader.batch_sampler.sampler.set_epoch = Mock()
hook.before_epoch(runner)
runner.data_loader.batch_sampler.sampler.set_epoch.assert_called()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment