Skip to content
Snippets Groups Projects
test_optimizer_wrapper.py 21.4 KiB
Newer Older
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest import TestCase
from unittest.mock import MagicMock

import torch
import torch.distributed as torch_dist
import torch.nn as nn
from parameterized import parameterized
from torch.cuda.amp import GradScaler
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD, Adam, Optimizer

from mmengine.dist import all_gather
from mmengine.logging import MessageHub, MMLogger
from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase

# Probe for NVIDIA apex once at import time; apex-dependent tests are skipped
# when the import fails.
try:
    import apex.amp as apex_amp
except ImportError:
    is_apex_available = False
else:
    is_apex_available = True

# dtype names passed to `AmpOptimWrapper(dtype=...)` by the parameterized
# tests below; `None` exercises the wrapper's default.
amp_valid_dtypes = ['float64', 'float32', 'float16', 'bfloat16', None]
# Expected autocast output dtype for each entry of `amp_valid_dtypes`
# (the `None` default is expected to cast to float16).
torch_dtypes = [
    getattr(torch, dtype_name) if dtype_name is not None else torch.float16
    for dtype_name in amp_valid_dtypes
]


def bf16_supported() -> bool:
    """Return whether bfloat16 can be used on the current CUDA device.

    Returns ``False`` when no CUDA device is available or when the installed
    PyTorch does not provide ``torch.cuda.is_bf16_supported``.
    """
    # Check for a GPU first: some PyTorch releases implement
    # `is_bf16_supported` by querying the current device's properties, which
    # raises on CPU-only machines, and ROCm builds may report True even with
    # no device present.
    return (torch.cuda.is_available()
            and hasattr(torch.cuda, 'is_bf16_supported')
            and torch.cuda.is_bf16_supported())


class ToyModel(nn.Module):
    """Three 1x1 single-channel convolutions, of which only two are used.

    ``conv3`` is registered, so its parameters are handed to any optimizer
    built from ``parameters()``. It is never called in :meth:`forward`,
    however, so backpropagation leaves it without a gradient.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes the parameter order seen by optimizers.
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.conv3 = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        for layer in (self.conv1, self.conv2):
            x = layer(x)
        return x


class ToyModel2(nn.Module):
    """Minimal model: a single 1x1 single-channel convolution."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)


class TestOptimWrapper(MultiProcessTestCase):
    """Tests for :class:`OptimWrapper`.

    The suite derives from ``MultiProcessTestCase`` because
    :meth:`test_optim_context` checks that ``OptimWrapper.optim_context``
    blocks gradient synchronization while gradients are being accumulated
    in distributed data parallel training, which needs several processes
    sharing one process group.
    """

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def run_test(self, test_name: str, parent_pipe) -> None:
        # Fixtures are built here rather than in `setUp`. `setUp` only spawns
        # the workers, while `run_test` is the hook that executes `test_name`
        # (presumably inside each spawned worker).
        self.model = ToyModel()
        self.optimizer = SGD(self.model.parameters(), lr=0.1)
        self.logger = MMLogger.get_instance('test_optim_wrapper')
        self.message_hub = MessageHub.get_instance('test_optim_wrapper_init')
        super().run_test(test_name, parent_pipe)

    def test_init(self):
        optim_wrapper = OptimWrapper(self.optimizer)
        self.assertIs(optim_wrapper.optimizer, self.optimizer)
        self.assertIsNone(optim_wrapper.clip_grad_kwargs)
        self.assertEqual(optim_wrapper._accumulative_counts, 1)
        self.assertIs(optim_wrapper.message_hub, self.message_hub)
        self.assertEqual(optim_wrapper._inner_count, 0)
        self.assertEqual(optim_wrapper._max_counts, -1)
        self.assertEqual(optim_wrapper._remainder_counts, -1)

        # A non-dict / empty `clip_grad` must be rejected.
        with self.assertRaisesRegex(AssertionError,
                                    'If `clip_grad` is not None'):
            OptimWrapper(self.optimizer, clip_grad=[])

    def test_update_params(self):
        loss = torch.tensor(1.)

        # Test update params every iteration.
        optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=1)
        self._mock_method(optim_wrapper)
        optim_wrapper.update_params(loss)
        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
        optim_wrapper.step.assert_called_with()
        optim_wrapper.zero_grad.assert_called_with()

        # Test gradient accumulation.
        optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
        self._mock_method(optim_wrapper)
        # `iter=0`, accumulate gradient and do not update params.
        optim_wrapper.update_params(loss)
        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3.)
        optim_wrapper.step.assert_not_called()
        optim_wrapper.zero_grad.assert_not_called()
        # `iter=1`, gradient keeps accumulating.
        optim_wrapper.update_params(loss)
        self.assertEqual(optim_wrapper._inner_count, 2.)

        # `iter=2`, update params.
        optim_wrapper.update_params(loss)
        optim_wrapper.step.assert_called()
        optim_wrapper.zero_grad.assert_called()
        self._mock_method(optim_wrapper)

        # Test end of training without calling `initialize_count_status`.
        optim_wrapper._inner_count = 99
        optim_wrapper.update_params(loss)
        optim_wrapper.step.assert_not_called()
        optim_wrapper.zero_grad.assert_not_called()
        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3.)
        self._mock_method(optim_wrapper)

        # After calling `initialize_count_status`, params will be updated at
        # the last iteration, and the loss factor will be adjusted.
        optim_wrapper.initialize_count_status(self.model, 99, 100)
        optim_wrapper.update_params(loss)
        optim_wrapper.step.assert_called()
        optim_wrapper.zero_grad.assert_called()
        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
        self._mock_method(optim_wrapper)

        # optim_wrapper.step should not be called at iteration 97 98, and the
        # loss factor should be 3 at those iterations.
        optim_wrapper.initialize_count_status(self.model, 96, 100)
        for _ in range(2):
            optim_wrapper.update_params(loss)
            optim_wrapper.step.assert_not_called()
            optim_wrapper.zero_grad.assert_not_called()
            self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3)

    def test_initialize_iter_status(self):
        optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
        optim_wrapper.initialize_count_status(self.model, 0, 100)
        self.assertEqual(optim_wrapper._remainder_counts, 1)

        # Indivisible cur_iter will output warning.
        optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
        with self.assertLogs(self.logger) as cm:
            optim_wrapper.initialize_count_status(self.model, 2, 100)
            self.assertEqual(len(cm.output), 1)
            self.assertRegex(cm.records[0].msg, 'Resumed iteration number')

        # Model with batch norm will output warning.
        optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
        model = nn.BatchNorm2d(1)
        with self.assertLogs(self.logger) as cm:
            optim_wrapper.initialize_count_status(model, 0, 99)
            self.assertEqual(len(cm.output), 1)
            self.assertRegex(cm.records[0].msg, 'Gradient accumulative')

    def test_ger_lr(self):
        model = ToyModel()
        optim = SGD(model.parameters(), lr=0.1)
        optim_wrapper = OptimWrapper(optim)
        self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1]))

    def test_get_momentum(self):
        # Get momentum from SGD
        model = ToyModel()
        optim = SGD(model.parameters(), lr=0., momentum=0.8)
        optim_wrapper = OptimWrapper(optim)
        self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.8]))
        # Get momentum from Adam (reported from the first beta)
        optim = Adam(model.parameters(), lr=0., betas=(0.9, 0.9))
        optim_wrapper = OptimWrapper(optim)
        self.assertEqual(optim_wrapper.get_momentum(), dict(momentum=[0.9]))

    def test_backward(self):
        loss = MagicMock()
        optim_wrapper = OptimWrapper(self.optimizer)
        optim_wrapper.backward(loss)
        loss.backward.assert_called()

    def test_zero_grad(self):
        optimizer = MagicMock(spec=Optimizer)
        optim_wrapper = OptimWrapper(optimizer)
        optim_wrapper.zero_grad()
        optimizer.zero_grad.assert_called()

    def test_step(self):
        optimizer = MagicMock(spec=Optimizer)
        optim_wrapper = OptimWrapper(optimizer)
        optim_wrapper.step()
        optimizer.step.assert_called()

    # TODO: This unit test could cause CI to fail with some probability, which
    #       is caused by MultiProcessTestCase. This problem should be solved
    #       in the future.
    @unittest.skipIf(True, reason='Solved in the future')
    def test_clip_grads(self):
        optim_wrapper = OptimWrapper(
            self.optimizer, clip_grad=dict(max_norm=35))
        loss = self.model(torch.Tensor(1, 1, 1, 1))
        loss.backward()
        optim_wrapper._clip_grad()
        log_scalars = self.message_hub.log_scalars
        self.assertIn('train/grad_norm', log_scalars)
        self.message_hub._log_scalars.clear()

        # Test `clip_grad` with `clip_value_`
        optim_wrapper = OptimWrapper(
            self.optimizer, clip_grad=dict(type='value', clip_value=0.5))
        loss = self.model(torch.Tensor(1, 1, 1, 1))
        loss.backward()
        optim_wrapper._clip_grad()
        self.assertNotIn('train/grad_norm', log_scalars)

    def test_state_dict(self):
        optim_wrapper = OptimWrapper(self.optimizer)
        self.assertEqual(optim_wrapper.state_dict(),
                         self.optimizer.state_dict())

    def test_load_state_dict(self):
        optim_wrapper = OptimWrapper(self.optimizer)
        model = ToyModel()
        optimizer = SGD(model.parameters(), lr=0.1)
        optim_wrapper.load_state_dict(optimizer.state_dict())

        self.assertEqual(optim_wrapper.state_dict(), optimizer.state_dict())

    def test_param_groups(self):
        optim_wrapper = OptimWrapper(self.optimizer)
        self.assertEqual(optim_wrapper.param_groups,
                         self.optimizer.param_groups)

    def test_optim_context(self):
        # Each rank feeds a different input (scaled by its rank), so local
        # gradients differ and only match across ranks after DDP syncs them.
        self._init_dist_env(self.rank, self.world_size)
        model = ToyModel2()
        ddp_model = DistributedDataParallel(model)
        optimizer = SGD(ddp_model.parameters(), lr=0.01)
        optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
        optim_wrapper.zero_grad()

        # Automatically sync grads if `accumulative_counts` = 1
        optim_wrapper.initialize_count_status(model, 0, 100)
        inputs = torch.randn(1, 1, 1, 1) * self.rank
        ddp_model(inputs).sum().backward()
        grad = model.conv.weight.grad
        all_grads = all_gather(grad)
        assert_allclose(all_grads[0], all_grads[1])

        # Do not sync grads when `optim_wrapper.cur_iter` cannot be
        # divided by `optim_wrapper._accumulative_counts`
        optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
        optim_wrapper.initialize_count_status(model, 0, 100)
        with optim_wrapper.optim_context(ddp_model):
            loss = ddp_model(inputs).sum()
        loss.backward()
        all_grads = all_gather(model.conv.weight.grad)
        with self.assertRaises(AssertionError):
            assert_allclose(all_grads[0], all_grads[1])

        # sync grads if `cur_iter == 2`
        optim_wrapper.initialize_count_status(model, 2, 100)
        with optim_wrapper.optim_context(ddp_model):
            loss = ddp_model(inputs).sum()
        loss.backward()
        all_grads = all_gather(model.conv.weight.grad)
        assert_allclose(all_grads[0], all_grads[1])

    def _init_dist_env(self, rank, world_size):
        """Initialize a gloo process group on localhost (fixed port 29515)."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29515'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)

    # TODO: Test the real interface once a testing tool exists that can check
    #  whether a function or method is really called.
    def _mock_method(self, optim_wrapper):
        """Replace ``backward``, ``step`` and ``zero_grad`` with test doubles.

        The fake ``backward`` advances ``_inner_count`` and stores the loss it
        receives as ``optim_wrapper.scaled_loss``, so tests can inspect the
        loss after ``update_params`` has scaled it. ``step`` and
        ``zero_grad`` become fresh ``MagicMock`` objects.
        """

        def mock_backward(loss):
            optim_wrapper._inner_count += 1
            optim_wrapper.scaled_loss = loss

        optim_wrapper.backward = mock_backward
        optim_wrapper.step = MagicMock()
        optim_wrapper.zero_grad = MagicMock()


@unittest.skipIf(not torch.cuda.is_available(), reason='need gpu to test Apex')
@unittest.skipIf(
    not is_apex_available,
    reason='`apex` is not available, Please install apex from '
    'https://www.github.com/nvidia/apex')
class TestApexOptimWrapper(TestCase):
    """Tests for :class:`ApexOptimWrapper`; each case needs CUDA and apex.

    The CUDA skip is the outermost decorator, so its reason is reported
    when both requirements are missing.
    """

    def setUp(self) -> None:
        self.model = ToyModel().cuda()
        self.optimizer = SGD(self.model.parameters(), lr=0.1)

    @staticmethod
    def _o1_wrapper(optimizer):
        """Build an ``ApexOptimWrapper`` at opt level O1 with loss scale 1."""
        return ApexOptimWrapper(
            optimizer=optimizer, opt_level='O1', loss_scale=1)

    def test_init(self):
        wrapper = self._o1_wrapper(self.optimizer)
        # Entering and leaving the context must succeed.
        with wrapper.optim_context(self.model):
            pass

    def test_step(self):
        wrapper = self._o1_wrapper(MagicMock(spec=Optimizer))
        with wrapper.optim_context(self.model):
            inputs = torch.Tensor(1, 1, 1, 1).cuda()
            wrapper.backward(self.model(inputs))
            wrapper.step()

    def test_backward(self):
        wrapper = self._o1_wrapper(self.optimizer)
        with wrapper.optim_context(self.model):
            inputs = torch.Tensor(1, 1, 1, 1).cuda()
            wrapper.backward(self.model(inputs))

    def test_state_dict(self):
        wrapper = self._o1_wrapper(self.optimizer)
        with wrapper.optim_context(self.model):
            inputs = torch.Tensor(1, 1, 1, 1).cuda()
            wrapper.update_params(self.model(inputs))
            state = wrapper.state_dict()
            amp_state = state.pop('apex_amp')
            # Once the apex entry is removed, the rest is the optimizer state.
            self.assertDictEqual(state, wrapper.optimizer.state_dict())
            self.assertDictEqual(amp_state, apex_amp.state_dict())

    def test_load_state_dict(self):
        wrapper = self._o1_wrapper(self.optimizer)
        with wrapper.optim_context(self.model):
            # Loading a plain optimizer state dict.
            plain_sgd = SGD(self.model.parameters(), lr=0.1)
            wrapper.load_state_dict(plain_sgd.state_dict())
            self.assertDictEqual(plain_sgd.state_dict(),
                                 wrapper.optimizer.state_dict())

            # Loading the state dict produced by another wrapper.
            source = ApexOptimWrapper(optimizer=self.optimizer)
            target = ApexOptimWrapper(
                optimizer=SGD(self.model.parameters(), lr=0.1))
            target.load_state_dict(source.state_dict())
            self.assertDictEqual(source.optimizer.state_dict(),
                                 target.optimizer.state_dict())

    def test_optim_context(self):
        wrapper = self._o1_wrapper(self.optimizer)
        with wrapper.optim_context(self.model):
            inputs = torch.randn(1, 1, 1, 1).cuda()
            output = nn.Conv2d(1, 1, 1).cuda()(inputs)
            # O1 casts the convolution to half precision.
            self.assertEqual(output.dtype, torch.float16)


@unittest.skipIf(
    not torch.cuda.is_available(),
    reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
class TestAmpOptimWrapper(TestCase):
    """Tests for :class:`AmpOptimWrapper`; every case requires a CUDA GPU."""

    def setUp(self) -> None:
        self.model = ToyModel()
        self.optimizer = SGD(self.model.parameters(), lr=0.1)

    @staticmethod
    def _skip_if_bf16_unsupported(dtype):
        """Skip the current case if it requests bfloat16 on a device that
        does not support it."""
        if dtype == 'bfloat16' and not bf16_supported():
            raise unittest.SkipTest('bfloat16 not supported by device')

    def test_init(self):
        # Default arguments.
        wrapper = AmpOptimWrapper(optimizer=self.optimizer)
        self.assertIsInstance(wrapper.loss_scaler, GradScaler)

        # Explicit 'dynamic' loss scale.
        wrapper = AmpOptimWrapper('dynamic', optimizer=self.optimizer)
        self.assertIsNone(wrapper._scale_update_param)
        self.assertIsInstance(wrapper.loss_scaler, GradScaler)

        # `dtype` strings resolve to the matching torch dtypes.
        for dtype_name, expected in (('float16', torch.float16),
                                     ('bfloat16', torch.bfloat16)):
            wrapper = AmpOptimWrapper(
                dtype=dtype_name, optimizer=self.optimizer)
            self.assertIs(wrapper.cast_dtype, expected)

        # A dict `loss_scale` also yields a GradScaler.
        wrapper = AmpOptimWrapper(
            dict(init_scale=1, growth_factor=2), optimizer=self.optimizer)
        self.assertIsInstance(wrapper.loss_scaler, GradScaler)
        self.assertIsNone(wrapper._scale_update_param)

        # Unsupported `loss_scale` values raise TypeError.
        with self.assertRaisesRegex(TypeError,
                                    'loss_scale must be of type float'):
            AmpOptimWrapper(optimizer=self.optimizer, loss_scale='unknown')

    @parameterized.expand(list(zip(amp_valid_dtypes)))
    def test_step(self, dtype):
        self._skip_if_bf16_unsupported(dtype)
        wrapper = AmpOptimWrapper(
            optimizer=MagicMock(spec=Optimizer), dtype=dtype)
        wrapper.loss_scaler = MagicMock()
        wrapper.step()
        wrapper.loss_scaler.step.assert_called_with(wrapper.optimizer)
        wrapper.loss_scaler.update.assert_called_with(
            wrapper._scale_update_param)

    @parameterized.expand(list(zip(amp_valid_dtypes)))
    def test_backward(self, dtype):
        self._skip_if_bf16_unsupported(dtype)
        wrapper = AmpOptimWrapper(optimizer=self.optimizer, dtype=dtype)
        scaled_loss = MagicMock()
        fake_scaler = MagicMock()
        fake_scaler.scale = MagicMock(return_value=scaled_loss)
        wrapper.loss_scaler = fake_scaler

        wrapper.backward(1)
        # The raw loss is scaled first, then backward runs on the result.
        fake_scaler.scale.assert_called_with(1)
        scaled_loss.backward.assert_called_with()

    def test_state_dict(self):
        self.model = self.model.cuda()
        wrapper = AmpOptimWrapper(optimizer=self.optimizer)
        inputs = torch.Tensor(1, 1, 1, 1).cuda()
        wrapper.update_params(self.model(inputs))

        state = wrapper.state_dict()
        scaler_state = state.pop('loss_scaler')
        # Once the scaler entry is removed, the rest is the optimizer state.
        self.assertDictEqual(state, wrapper.optimizer.state_dict())
        self.assertDictEqual(scaler_state, wrapper.loss_scaler.state_dict())

    def test_load_state_dict(self):
        wrapper = AmpOptimWrapper(optimizer=self.optimizer)
        self.model = self.model.cuda()

        # Loading a plain optimizer state dict.
        plain_sgd = SGD(self.model.parameters(), lr=0.1)
        wrapper.load_state_dict(plain_sgd.state_dict())
        self.assertDictEqual(plain_sgd.state_dict(),
                             wrapper.optimizer.state_dict())

        # Loading the state dict produced by another wrapper.
        source = AmpOptimWrapper(optimizer=self.optimizer)
        target = AmpOptimWrapper(
            optimizer=SGD(self.model.parameters(), lr=0.1))
        target.load_state_dict(source.state_dict())
        self.assertDictEqual(source.optimizer.state_dict(),
                             target.optimizer.state_dict())
        self.assertDictEqual(source.loss_scaler.state_dict(),
                             target.loss_scaler.state_dict())

    @parameterized.expand(list(zip(amp_valid_dtypes, torch_dtypes)))
    def test_optim_context(self, dtype, target_dtype):
        self._skip_if_bf16_unsupported(dtype)
        wrapper = AmpOptimWrapper(optimizer=self.optimizer, dtype=dtype)
        with wrapper.optim_context(self.model):
            inputs = torch.randn(1, 1, 1, 1).cuda()
            output = nn.Conv2d(1, 1, 1).cuda()(inputs)
            self.assertEqual(output.dtype, target_dtype)