Unverified commit 89146a5b, authored by Mashiro, committed by GitHub

[Fix] Fix detect_anomalous_params (#588)

parent 1f63d243
```diff
@@ -119,10 +119,10 @@ class MMDistributedDataParallel(DistributedDataParallel):
         with optim_wrapper.optim_context(self):
             data = self.module.data_preprocessor(data, training=True)
             losses = self._run_forward(data, mode='loss')
+            if self.detect_anomalous_params:
+                detect_anomalous_params(losses, model=self)
         parsed_loss, log_vars = self.module.parse_losses(losses)
         optim_wrapper.update_params(parsed_loss)
-        if self.detect_anomalous_params:
-            detect_anomalous_params(parsed_loss, model=self)
         return log_vars

     def val_step(self, data: Union[dict, tuple, list]) -> list:
```
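The fix moves the check from after `optim_wrapper.update_params(parsed_loss)` to right after the forward pass, so `detect_anomalous_params` inspects the raw `losses` while the freshly built autograd graph still describes the current iteration, before `update_params` triggers the backward pass. The snippet below is a minimal sketch of the kind of traversal such a check performs; it is not mmengine's exact implementation, and `find_unused_params` and `Toy` are hypothetical names introduced only for illustration.

```python
import torch
import torch.nn as nn


def find_unused_params(loss: torch.Tensor, model: nn.Module) -> list:
    # Collect every parameter reachable from ``loss`` through the autograd
    # graph, then report parameters that require gradients but never
    # appear in it, i.e. parameters that did not contribute to the loss.
    in_graph, visited = set(), set()

    def traverse(grad_fn):
        if grad_fn is None or grad_fn in visited:
            return
        visited.add(grad_fn)
        if hasattr(grad_fn, 'variable'):  # AccumulateGrad nodes hold the leaf tensor
            in_graph.add(grad_fn.variable)
        for parent, _ in grad_fn.next_functions:
            traverse(parent)

    traverse(loss.grad_fn)
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p not in in_graph]


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(2, 1)
        self.unused = nn.Linear(2, 1)  # never called in forward()

    def forward(self, x):
        return self.used(x).sum()


model = Toy()
loss = model(torch.randn(4, 2))
print(find_unused_params(loss, model))  # ['unused.weight', 'unused.bias']
```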
```diff
@@ -109,6 +109,16 @@ class TestDistributedDataParallel(MultiProcessTestCase):
         assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
         assert_allclose(all_grads[1], torch.zeros_like(all_grads[0]))

+        # Test enable detect_anomalous_params.
+        ddp_model = MMDistributedDataParallel(
+            module=model, detect_anomalous_params=True)
+        optimizer = SGD(ddp_model.parameters(), lr=0)
+        optim_wrapper = AmpOptimWrapper(
+            optimizer=optimizer, accumulative_counts=3)
+        inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
+        data = dict(inputs=inputs, data_sample=None)
+        res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss']
+
     def test_val_step(self):
         self._init_dist_env(self.rank, self.world_size)
         model = ToyModel()
```
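For reference, enabling the check outside the test harness looks roughly like the sketch below. This is a hedged example, not part of the commit: it assumes a process group has already been initialised via `torch.distributed`, that `ToyModel` (taken from the test above) is available as a `BaseModel` subclass accepting `dict(inputs=..., data_sample=...)`, and it uses the plain `OptimWrapper` instead of the `AmpOptimWrapper` from the test.

```python
import torch
from torch.optim import SGD

from mmengine.model import MMDistributedDataParallel
from mmengine.optim import OptimWrapper

# The wrapped module must be a BaseModel subclass, e.g. the test's ToyModel.
model = ToyModel().cuda()
ddp_model = MMDistributedDataParallel(
    module=model, detect_anomalous_params=True)
optim_wrapper = OptimWrapper(optimizer=SGD(ddp_model.parameters(), lr=0.01))

data = dict(inputs=torch.randn(1, 3, 1, 1).cuda(), data_sample=None)
# With the fix applied, each train_step now reports every parameter that
# is absent from the loss's autograd graph before the backward pass runs.
log_vars = ddp_model.train_step(data, optim_wrapper=optim_wrapper)
```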