Skip to content
Snippets Groups Projects
Unverified Commit 6ebb6f83 authored by Luo Yihang's avatar Luo Yihang Committed by GitHub
Browse files

[Fix] Call SyncBufferHook before validation in IterBasedTrainLoop (#982)


* [Fix] Call SyncBufferHook before validation in IterBasedTrainLoop

* Add before_val_epoch in SyncBuffersHook

* Fix white space format

* Add comments for SyncBuffersHook

* Add comments for SyncBuffersHook

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Add comments for SyncBuffersHook

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix white space format

* Add before_test_epoch

* Remove before_test_epoch

---------

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 0e5f9da6
No related branches found
No related tags found
No related merge requests found
...@@ -13,6 +13,24 @@ class SyncBuffersHook(Hook): ...@@ -13,6 +13,24 @@ class SyncBuffersHook(Hook):
def __init__(self) -> None: def __init__(self) -> None:
self.distributed = is_distributed() self.distributed = is_distributed()
# A flag to mark whether synchronization has been done in
# after_train_epoch
self.called_in_train = False
def before_val_epoch(self, runner) -> None:
"""All-reduce model buffers before each validation epoch.
Synchronize the buffers before each validation if they have not been
synchronized at the end of the previous training epoch. This method
will be called when using IterBasedTrainLoop.
Args:
runner (Runner): The runner of the training process.
"""
if self.distributed:
if not self.called_in_train:
all_reduce_params(runner.model.buffers(), op='mean')
self.called_in_train = False
def after_train_epoch(self, runner) -> None: def after_train_epoch(self, runner) -> None:
"""All-reduce model buffers at the end of each epoch. """All-reduce model buffers at the end of each epoch.
...@@ -22,3 +40,4 @@ class SyncBuffersHook(Hook): ...@@ -22,3 +40,4 @@ class SyncBuffersHook(Hook):
""" """
if self.distributed: if self.distributed:
all_reduce_params(runner.model.buffers(), op='mean') all_reduce_params(runner.model.buffers(), op='mean')
self.called_in_train = True
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment