# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest import TestCase
from unittest.mock import MagicMock
import torch
import torch.distributed as torch_dist
import torch.nn as nn
from parameterized import parameterized
from torch.cuda.amp import GradScaler
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD, Adam, Optimizer
from mmengine.dist import all_gather
from mmengine.logging import MessageHub, MMLogger
from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase
is_apex_available = False
try:
import apex.amp as apex_amp
is_apex_available = True
except ImportError:
pass
amp_valid_dtypes = ['float64', 'float32', 'float16', 'bfloat16', None]
torch_dtypes = [
torch.float16 if dtype is None else getattr(torch, dtype)
for dtype in amp_valid_dtypes
]
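# Note: `None` is kept in `amp_valid_dtypes` because `AmpOptimWrapper` accepts
# an unset dtype; the pairing above assumes that this case behaves like
# `float16`, which is what `test_optim_context` below checks via
# `zip(amp_valid_dtypes, torch_dtypes)`.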
def bf16_supported() -> bool:
return (hasattr(torch.cuda, 'is_bf16_supported')
and torch.cuda.is_bf16_supported())
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1)
self.conv2 = nn.Conv2d(1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class ToyModel2(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)
def forward(self, x):
x = self.conv(x)
return x
class TestOptimWrapper(MultiProcessTestCase):
    # Test that `OptimWrapper.optim_context` blocks gradient synchronization
    # when using the gradient accumulation strategy in distributed data
    # parallel training.
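    # A minimal sketch (not part of the tests) of the pattern exercised below,
    # assuming only the public API used in this file: on iterations where no
    # parameter update is due, `optim_context` is expected to wrap the forward
    # pass in DDP's `no_sync()`, so gradients are only all-reduced on the
    # updating iteration.
    #
    #   optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
    #   optim_wrapper.initialize_count_status(model, cur_iter, max_iters)
    #   with optim_wrapper.optim_context(ddp_model):
    #       loss = ddp_model(inputs).sum()
    #   optim_wrapper.update_params(loss)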
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def run_test(self, test_name: str, parent_pipe) -> None:
self.model = ToyModel()
self.optimizer = SGD(self.model.parameters(), lr=0.1)
self.logger = MMLogger.get_instance('test_optim_wrapper')
self.message_hub = MessageHub.get_instance('test_optim_wrapper_init')
super().run_test(test_name, parent_pipe)
def test_init(self):
optim_wrapper = OptimWrapper(self.optimizer)
self.assertIs(optim_wrapper.optimizer, self.optimizer)
self.assertIsNone(optim_wrapper.clip_grad_kwargs)
self.assertEqual(optim_wrapper._accumulative_counts, 1)
self.assertIs(optim_wrapper.message_hub, self.message_hub)
self.assertEqual(optim_wrapper._inner_count, 0)
self.assertEqual(optim_wrapper._max_counts, -1)
self.assertEqual(optim_wrapper._remainder_counts, -1)
with self.assertRaisesRegex(AssertionError,
'If `clip_grad` is not None'):
OptimWrapper(self.optimizer, clip_grad=[])
def test_update_params(self):
# Test update params every iteration.
        optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=1)
        self._mock_method(optim_wrapper)
        loss = torch.tensor(1.)
        optim_wrapper.update_params(loss)
        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
optim_wrapper.step.assert_called_with()
optim_wrapper.zero_grad.assert_called_with()
# Test gradient accumulation.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
self._mock_method(optim_wrapper)
# `iter=0`, accumulate gradient and do not update params.
loss = torch.tensor(1.)
optim_wrapper.update_params(loss)
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3.)
optim_wrapper.step.assert_not_called()
optim_wrapper.zero_grad.assert_not_called()
        # `iter=1`, continue accumulating gradients.
optim_wrapper.update_params(loss)
self.assertEqual(optim_wrapper._inner_count, 2.)
# `iter=2`, update params.
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called()
        # Test end of training without calling `initialize_iter_status`
        self._mock_method(optim_wrapper)
        optim_wrapper._inner_count = 99
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_not_called()
optim_wrapper.zero_grad.assert_not_called()
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3.)
# After calling `initialize_iter_status`, params will be updated at the
# last iteration, and the `loss_scaler` will be adjusted.
optim_wrapper.initialize_count_status(self.model, 99, 100)
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_called()
optim_wrapper.zero_grad.assert_called()
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
self._mock_method(optim_wrapper)
        # `optim_wrapper.step` should not be called at iterations 97 and 98,
        # and the loss factor should be 3 at iteration 99.
optim_wrapper.initialize_count_status(self.model, 96, 100)
for _ in range(2):
optim_wrapper.update_params(loss)
optim_wrapper.step.assert_not_called()
optim_wrapper.zero_grad.assert_not_called()
self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3)
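    # Worked numbers behind the assertions above (a sketch of the intent, not
    # an additional test): with `accumulative_counts=3` the loss is divided by
    # 3 on ordinary iterations, so `scaled_loss == 1 / 3`; after
    # `initialize_count_status(model, 99, 100)` only 100 % 3 == 1 iteration
    # remains, so the divisor drops to 1 and `scaled_loss == 1`.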
def test_initialize_iter_status(self):
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
optim_wrapper.initialize_count_status(self.model, 0, 100)
self.assertEqual(optim_wrapper._remainder_counts, 1)
# Indivisible cur_iter will output warning.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
with self.assertLogs(self.logger) as cm:
optim_wrapper.initialize_count_status(self.model, 2, 100)
self.assertRegex(cm.records[0].msg, 'Resumed iteration number')
# Model with batch norm will output warning.
optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
model = nn.BatchNorm2d(1)
with self.assertLogs(self.logger) as cm:
optim_wrapper.initialize_count_status(model, 0, 99)
self.assertEqual(len(cm.output), 1)
self.assertRegex(cm.records[0].msg, 'Gradient accumulative')
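    # The BatchNorm warning above exists because, presumably, running
    # statistics are still updated on every forward pass, so gradient
    # accumulation does not reproduce the statistics of a genuinely larger
    # batch.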
    def test_get_lr(self):
model = ToyModel()
optim = SGD(model.parameters(), lr=0.1)
optim_wrapper = OptimWrapper(optim)
self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1]))
def test_get_momentum(self):
# Get momentum from SGD
model = ToyModel()
optim = SGD(model.parameters(), lr=0., momentum=0.8)
optim_wrapper = OptimWrapper(optim)
self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.8]))
# Get momentum from Adam
optim = Adam(model.parameters(), lr=0., betas=(0.9, 0.9))
optim_wrapper = OptimWrapper(optim)
self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.9]))
def test_backward(self):
loss = MagicMock()
optim_wrapper = OptimWrapper(self.optimizer)
optim_wrapper.backward(loss)
loss.backward.assert_called()
def test_zero_grad(self):
optimizer = MagicMock(spec=Optimizer)
optim_wrapper = OptimWrapper(optimizer)
optim_wrapper.zero_grad()
optimizer.zero_grad.assert_called()
def test_step(self):
optimizer = MagicMock(spec=Optimizer)
optim_wrapper = OptimWrapper(optimizer)
optim_wrapper.step()
optimizer.step.assert_called()
    # TODO: This unit test could cause CI to fail with some probability, which
    # is caused by MultiProcessTestCase. This problem should be solved
    # in the future.
    @unittest.skipIf(True, reason='Solved in the future')
    def test_clip_grads(self):
        # Test `clip_grad` with `clip_grad_norm_`
        optim_wrapper = OptimWrapper(
            self.optimizer, clip_grad=dict(max_norm=35))
loss = self.model(torch.Tensor(1, 1, 1, 1))
loss.backward()
optim_wrapper._clip_grad()
log_scalars = self.message_hub.log_scalars
self.assertIn('train/grad_norm', log_scalars)
self.message_hub._log_scalars.clear()
        # Test `clip_grad` with `clip_grad_value_`
optim_wrapper = OptimWrapper(
self.optimizer, clip_grad=dict(type='value', clip_value=0.5))
loss = self.model(torch.Tensor(1, 1, 1, 1))
loss.backward()
optim_wrapper._clip_grad()
self.assertNotIn('train/grad_norm', log_scalars)
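    # The two configurations above are assumed to map onto
    # `torch.nn.utils.clip_grad_norm_` (the default, which also logs
    # 'train/grad_norm') and `torch.nn.utils.clip_grad_value_`
    # (`type='value'`, which logs nothing), matching the assertions on
    # `message_hub.log_scalars`.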
def test_state_dict(self):
optim_wrapper = OptimWrapper(self.optimizer)
self.assertEqual(optim_wrapper.state_dict(),
self.optimizer.state_dict())
def test_load_state_dict(self):
optim_wrapper = OptimWrapper(self.optimizer)
model = ToyModel()
optimizer = SGD(model.parameters(), lr=0.1)
optim_wrapper.load_state_dict(optimizer.state_dict())
self.assertEqual(optim_wrapper.state_dict(), optimizer.state_dict())
def test_param_groups(self):
optim_wrapper = OptimWrapper(self.optimizer)
self.assertEqual(optim_wrapper.param_groups,
self.optimizer.param_groups)
def test_optim_context(self):
self._init_dist_env(self.rank, self.world_size)
model = ToyModel2()
ddp_model = DistributedDataParallel(model)
optimizer = SGD(ddp_model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
# Automatically sync grads if `accumulative_counts` = 1
optim_wrapper.initialize_count_status(model, 0, 100)
inputs = torch.randn(1, 1, 1, 1) * self.rank
ddp_model(inputs).sum().backward()
grad = model.conv.weight.grad
all_grads = all_gather(grad)
assert_allclose(all_grads[0], all_grads[1])
        # Do not sync grads when `optim_wrapper.cur_iter` is not divisible by
        # `optim_wrapper._accumulative_counts`
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
optim_wrapper.initialize_count_status(model, 0, 100)
with optim_wrapper.optim_context(ddp_model):
loss = ddp_model(inputs).sum()
loss.backward()
all_grads = all_gather(model.conv.weight.grad)
with self.assertRaises(AssertionError):
assert_allclose(all_grads[0], all_grads[1])
# sync grads if `cur_iter == 2`
optim_wrapper.initialize_count_status(model, 2, 100)
with optim_wrapper.optim_context(ddp_model):
loss = ddp_model(inputs).sum()
loss.backward()
all_grads = all_gather(model.conv.weight.grad)
assert_allclose(all_grads[0], all_grads[1])
def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29515'
os.environ['RANK'] = str(rank)
torch_dist.init_process_group(
backend='gloo', rank=rank, world_size=world_size)
    # TODO: Test the real interface after adding a testing tool function which
    # can check whether a function or method is really called.
def _mock_method(self, optim_wrapper):
        def mock_method(loss):
            optim_wrapper._inner_count += 1
            optim_wrapper.scaled_loss = loss
        optim_wrapper.backward = mock_method
optim_wrapper.step = MagicMock()
optim_wrapper.zero_grad = MagicMock()
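# A minimal usage sketch (not collected as a test) of the OptimWrapper API
# exercised above. It assumes only the public methods used in this file; the
# toy data, loop length and learning rate are illustrative.
def _optim_wrapper_usage_sketch(max_iters: int = 6) -> None:
    model = ToyModel2()
    optimizer = SGD(model.parameters(), lr=0.1)
    optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
    # Tell the wrapper where training starts and ends so the trailing,
    # indivisible iterations still trigger an update.
    optim_wrapper.initialize_count_status(model, 0, max_iters)
    for _ in range(max_iters):
        inputs = torch.randn(1, 1, 1, 1)
        with optim_wrapper.optim_context(model):
            loss = model(inputs).sum()
        # Scales the loss, accumulates gradients, and only calls
        # `step()`/`zero_grad()` every `accumulative_counts` iterations.
        optim_wrapper.update_params(loss)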
@unittest.skipIf(not torch.cuda.is_available(), reason='need gpu to test Apex')
class TestApexOptimWrapper(TestCase):
def setUp(self) -> None:
self.model = ToyModel().cuda()
self.optimizer = SGD(self.model.parameters(), lr=0.1)
@unittest.skipIf(
not is_apex_available,
reason='`apex` is not available, Please install apex from '
'https://www.github.com/nvidia/apex')
def test_init(self):
apex_optim_wrapper = ApexOptimWrapper(
optimizer=self.optimizer, opt_level='O1', loss_scale=1)
with apex_optim_wrapper.optim_context(self.model):
pass
@unittest.skipIf(
not is_apex_available,
reason='`apex` is not available, Please install apex from '
'https://www.github.com/nvidia/apex')
def test_step(self):
optimizer = MagicMock(spec=Optimizer)
apex_optim_wrapper = ApexOptimWrapper(
optimizer=optimizer, opt_level='O1', loss_scale=1)
with apex_optim_wrapper.optim_context(self.model):
loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
apex_optim_wrapper.backward(loss)
apex_optim_wrapper.step()
@unittest.skipIf(
not is_apex_available,
reason='`apex` is not available, Please install apex from '
'https://www.github.com/nvidia/apex')
def test_backward(self):
apex_optim_wrapper = ApexOptimWrapper(
optimizer=self.optimizer, opt_level='O1', loss_scale=1)
with apex_optim_wrapper.optim_context(self.model):
loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
apex_optim_wrapper.backward(loss)
@unittest.skipIf(
not is_apex_available,
reason='`apex` is not available, Please install apex from '
'https://www.github.com/nvidia/apex')
def test_state_dict(self):
apex_optim_wrapper = ApexOptimWrapper(
optimizer=self.optimizer, opt_level='O1', loss_scale=1)
with apex_optim_wrapper.optim_context(self.model):
loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
apex_optim_wrapper.update_params(loss)
state_dict = apex_optim_wrapper.state_dict()
amp_state_dict = state_dict.pop('apex_amp')
optim_state_dict = state_dict
self.assertDictEqual(optim_state_dict,
apex_optim_wrapper.optimizer.state_dict())
self.assertDictEqual(amp_state_dict, apex_amp.state_dict())
@unittest.skipIf(
not is_apex_available,
reason='`apex` is not available, Please install apex from '
'https://www.github.com/nvidia/apex')
def test_load_state_dict(self):
apex_optim_wrapper = ApexOptimWrapper(
optimizer=self.optimizer, opt_level='O1', loss_scale=1)
with apex_optim_wrapper.optim_context(self.model):
# Test load from optimizer
optimizer = SGD(self.model.parameters(), lr=0.1)
apex_optim_wrapper.load_state_dict(optimizer.state_dict())
self.assertDictEqual(optimizer.state_dict(),
apex_optim_wrapper.optimizer.state_dict())
# Test load from optim_wrapper
apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer)
apex_optim_wrapper_ = ApexOptimWrapper(
optimizer=SGD(self.model.parameters(), lr=0.1))
apex_optim_wrapper_.load_state_dict(
apex_optim_wrapper.state_dict())
self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(),
apex_optim_wrapper_.optimizer.state_dict())
@unittest.skipIf(
not is_apex_available,
reason='`apex` is not available, Please install apex from '
'https://www.github.com/nvidia/apex')
def test_optim_context(self):
apex_optim_wrapper = ApexOptimWrapper(
optimizer=self.optimizer, opt_level='O1', loss_scale=1)
with apex_optim_wrapper.optim_context(self.model):
x = torch.randn(1, 1, 1, 1).cuda()
y = nn.Conv2d(1, 1, 1).cuda()(x)
self.assertEqual(y.dtype, torch.float16)
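    # The float16 activation checked above is the expected behaviour of Apex
    # `opt_level='O1'` (patched casts around whitelisted ops); the assertion
    # assumes this convention rather than configuring the cast explicitly.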
class TestAmpOptimWrapper(TestCase):
def setUp(self) -> None:
self.model = ToyModel()
self.optimizer = SGD(self.model.parameters(), lr=0.1)
@unittest.skipIf(
not torch.cuda.is_available(),
reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
def test_init(self):
# Test with default arguments.
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
# Test with dynamic.
amp_optim_wrapper = AmpOptimWrapper(
'dynamic', optimizer=self.optimizer)
self.assertIsNone(amp_optim_wrapper._scale_update_param)
self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
# Test with dtype float16
amp_optim_wrapper = AmpOptimWrapper(
dtype='float16', optimizer=self.optimizer)
self.assertIs(amp_optim_wrapper.cast_dtype, torch.float16)
# Test with dtype bfloat16
amp_optim_wrapper = AmpOptimWrapper(
dtype='bfloat16', optimizer=self.optimizer)
self.assertIs(amp_optim_wrapper.cast_dtype, torch.bfloat16)
# Test with dict loss_scale.
amp_optim_wrapper = AmpOptimWrapper(
dict(init_scale=1, growth_factor=2), optimizer=self.optimizer)
self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)
self.assertIsNone(amp_optim_wrapper._scale_update_param)
with self.assertRaisesRegex(TypeError,
'loss_scale must be of type float'):
AmpOptimWrapper(optimizer=self.optimizer, loss_scale='unknown')
    @parameterized.expand(list(zip(amp_valid_dtypes)))
    @unittest.skipIf(
        not torch.cuda.is_available(),
        reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
def test_step(self, dtype):
if dtype == 'bfloat16' and not bf16_supported():
raise unittest.SkipTest('bfloat16 not supported by device')
        optimizer = MagicMock(spec=Optimizer)
        amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer, dtype=dtype)
amp_optim_wrapper.loss_scaler = MagicMock()
amp_optim_wrapper.step()
amp_optim_wrapper.loss_scaler.step.assert_called_with(
amp_optim_wrapper.optimizer)
amp_optim_wrapper.loss_scaler.update.assert_called_with(
amp_optim_wrapper._scale_update_param)
    @parameterized.expand(list(zip(amp_valid_dtypes)))
    @unittest.skipIf(
        not torch.cuda.is_available(),
        reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
def test_backward(self, dtype):
if dtype == 'bfloat16' and not bf16_supported():
raise unittest.SkipTest('bfloat16 not supported by device')
amp_optim_wrapper = AmpOptimWrapper(
optimizer=self.optimizer, dtype=dtype)
loss_scaler = MagicMock()
scale_return = MagicMock()
scale_fn = MagicMock(return_value=scale_return)
loss_scaler.scale = scale_fn
amp_optim_wrapper.loss_scaler = loss_scaler
amp_optim_wrapper.backward(1)
loss_scaler.scale.assert_called_with(1)
scale_return.backward.assert_called_with()
@unittest.skipIf(
not torch.cuda.is_available(),
reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
def test_state_dict(self):
self.model = self.model.cuda()
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
loss = self.model(torch.Tensor(1, 1, 1, 1).cuda())
amp_optim_wrapper.update_params(loss)
state_dict = amp_optim_wrapper.state_dict()
scalar_state_dict = state_dict.pop('loss_scaler')
optim_state_dict = state_dict
self.assertDictEqual(optim_state_dict,
amp_optim_wrapper.optimizer.state_dict())
self.assertDictEqual(scalar_state_dict,
amp_optim_wrapper.loss_scaler.state_dict())
@unittest.skipIf(
not torch.cuda.is_available(),
reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
def test_load_state_dict(self):
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
self.model = self.model.cuda()
# Test load from optimizer
optimizer = SGD(self.model.parameters(), lr=0.1)
amp_optim_wrapper.load_state_dict(optimizer.state_dict())
self.assertDictEqual(optimizer.state_dict(),
amp_optim_wrapper.optimizer.state_dict())
# Test load from optim_wrapper
amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
amp_optim_wrapper_ = AmpOptimWrapper(
optimizer=SGD(self.model.parameters(), lr=0.1))
amp_optim_wrapper_.load_state_dict(amp_optim_wrapper.state_dict())
self.assertDictEqual(amp_optim_wrapper.optimizer.state_dict(),
amp_optim_wrapper_.optimizer.state_dict())
self.assertDictEqual(amp_optim_wrapper.loss_scaler.state_dict(),
amp_optim_wrapper_.loss_scaler.state_dict())
    @parameterized.expand(list(zip(amp_valid_dtypes, torch_dtypes)))
    @unittest.skipIf(
        not torch.cuda.is_available(),
        reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
def test_optim_context(self, dtype, target_dtype):
if dtype == 'bfloat16' and not bf16_supported():
raise unittest.SkipTest('bfloat16 not supported by device')
amp_optim_wrapper = AmpOptimWrapper(
optimizer=self.optimizer, dtype=dtype)
with amp_optim_wrapper.optim_context(self.model):
x = torch.randn(1, 1, 1, 1).cuda()
y = nn.Conv2d(1, 1, 1).cuda()(x)
self.assertEqual(y.dtype, target_dtype)
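# A minimal usage sketch (not collected as a test) of AmpOptimWrapper,
# mirroring the CUDA tests above; it assumes an available CUDA device and only
# the public API used in this file.
def _amp_optim_wrapper_usage_sketch() -> None:
    model = ToyModel().cuda()
    optimizer = SGD(model.parameters(), lr=0.1)
    amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer, dtype='float16')
    inputs = torch.randn(1, 1, 1, 1).cuda()
    # The forward pass runs under autocast inside `optim_context`;
    # `update_params` scales the loss, backpropagates, steps the optimizer
    # and zeroes the gradients.
    with amp_optim_wrapper.optim_context(model):
        loss = model(inputs).sum()
    amp_optim_wrapper.update_params(loss)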