Skip to content
Snippets Groups Projects
test_averaged_model.py 10.6 KiB
Newer Older
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from unittest import TestCase

import torch

from mmengine.model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                            StochasticWeightAverage)
from mmengine.testing import assert_allclose


class TestAveragedModel(TestCase):
    """Test the AveragedModel class.

    Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
    """  # noqa: E501

    def _test_swa_model(self, net_device, avg_device):
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.Linear(5, 10)).to(net_device)

        averaged_model = StochasticWeightAverage(model, device=avg_device)
        averaged_params = [
            torch.zeros_like(param) for param in model.parameters()
        ]
        n_updates = 2
        for i in range(n_updates):
            for p, p_avg in zip(model.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_model.update_parameters(model)

        for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()):
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_swa.device == avg_device)
            self.assertTrue(p.device == net_device)
            assert_allclose(p_avg, p_swa.to(p_avg.device))
        self.assertTrue(averaged_model.steps.device == avg_device)

    def test_averaged_model_all_devices(self):
        cpu = torch.device('cpu')
        self._test_swa_model(cpu, cpu)
        if torch.cuda.is_available():
            cuda = torch.device(0)
            self._test_swa_model(cuda, cpu)
            self._test_swa_model(cpu, cuda)
            self._test_swa_model(cuda, cuda)

    def test_swa_mixed_device(self):
        if not torch.cuda.is_available():
            return
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
        model[0].cuda()
        model[1].cpu()
        averaged_model = StochasticWeightAverage(model)
        averaged_params = [
            torch.zeros_like(param) for param in model.parameters()
        ]
        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(model.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_model.update_parameters(model)

        for p_avg, p_swa in zip(averaged_params, averaged_model.parameters()):
            assert_allclose(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_avg.device == p_swa.device)

    def test_swa_state_dict(self):
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
        averaged_model = StochasticWeightAverage(model)
        averaged_model2 = StochasticWeightAverage(model)
        n_updates = 10
        for i in range(n_updates):
            for p in model.parameters():
                p.detach().add_(torch.randn_like(p))
            averaged_model.update_parameters(model)
        averaged_model2.load_state_dict(averaged_model.state_dict())
        for p_swa, p_swa2 in zip(averaged_model.parameters(),
                                 averaged_model2.parameters()):
            assert_allclose(p_swa, p_swa2)
        self.assertTrue(averaged_model.steps == averaged_model2.steps)

    def test_ema(self):
        # test invalid momentum
        with self.assertRaisesRegex(AssertionError,
                                    'momentum must be in range'):
            model = torch.nn.Sequential(
                torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
            ExponentialMovingAverage(model, momentum=3)

        with self.assertWarnsRegex(
                Warning,
                'The value of momentum in EMA is usually a small number'):
            model = torch.nn.Sequential(
                torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
            ExponentialMovingAverage(model, momentum=0.9)
        # test EMA
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
        momentum = 0.1

        ema_model = ExponentialMovingAverage(model, momentum=momentum)
        averaged_params = [
            torch.zeros_like(param) for param in model.parameters()
        ]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(model.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * (1 - momentum) + p * momentum).clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        for p_target, p_ema in zip(averaged_params, ema_model.parameters()):
            assert_allclose(p_target, p_ema)

    def test_ema_update_buffers(self):
        # Test EMA and update_buffers as True.
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10))
        momentum = 0.1

        ema_model = ExponentialMovingAverage(
            model, momentum=momentum, update_buffers=True)
        averaged_params = [
            torch.zeros_like(param)
            for param in itertools.chain(model.parameters(), model.buffers())
            if param.size() != torch.Size([])
        ]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            params = [
                param for param in itertools.chain(model.parameters(),
                                                   model.buffers())
                if param.size() != torch.Size([])
            ]
            for p, p_avg in zip(params, averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * (1 - momentum) + p * momentum).clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        ema_params = [
            param for param in itertools.chain(ema_model.module.parameters(),
                                               ema_model.module.buffers())
            if param.size() != torch.Size([])
        ]
        for p_target, p_ema in zip(averaged_params, ema_params):
            assert_allclose(p_target, p_ema)

    def test_momentum_annealing_ema(self):
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10))
        # Test invalid gamma
        with self.assertRaisesRegex(AssertionError,
                                    'gamma must be greater than 0'):
            MomentumAnnealingEMA(model, gamma=-1)

        # Test EMA with momentum annealing.
        momentum = 0.1
        gamma = 4

        ema_model = MomentumAnnealingEMA(
            model, gamma=gamma, momentum=momentum, update_buffers=True)
        averaged_params = [
            torch.zeros_like(param)
            for param in itertools.chain(model.parameters(), model.buffers())
            if param.size() != torch.Size([])
        ]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            params = [
                param for param in itertools.chain(model.parameters(),
                                                   model.buffers())
                if param.size() != torch.Size([])
            ]
            for p, p_avg in zip(params, averaged_params):
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    m = max(momentum, gamma / (gamma + i))
                    updated_averaged_params.append(
                        (p_avg * (1 - m) + p * m).clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        ema_params = [
            param for param in itertools.chain(ema_model.module.parameters(),
                                               ema_model.module.buffers())
            if param.size() != torch.Size([])
        ]
        for p_target, p_ema in zip(averaged_params, ema_params):
            assert_allclose(p_target, p_ema)

    def test_momentum_annealing_ema_with_interval(self):
        # Test EMA with momentum annealing and interval
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10))
        momentum = 0.1
        gamma = 4
        interval = 3

        ema_model = MomentumAnnealingEMA(
            model,
            gamma=gamma,
            momentum=momentum,
            interval=interval,
            update_buffers=True)
        averaged_params = [
            torch.zeros_like(param)
            for param in itertools.chain(model.parameters(), model.buffers())
            if param.size() != torch.Size([])
        ]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            params = [
                param for param in itertools.chain(model.parameters(),
                                                   model.buffers())
                if param.size() != torch.Size([])
            ]
            for p, p_avg in zip(params, averaged_params):
                if i == 0:
                    updated_averaged_params.append(p.clone())
                elif i % interval == 0:
                    m = max(momentum, gamma / (gamma + i))
                    updated_averaged_params.append(
                        (p_avg * (1 - m) + p * m).clone())
                else:
                    updated_averaged_params.append(p_avg.clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        ema_params = [
            param for param in itertools.chain(ema_model.module.parameters(),
                                               ema_model.module.buffers())
            if param.size() != torch.Size([])
        ]
        for p_target, p_ema in zip(averaged_params, ema_params):
            assert_allclose(p_target, p_ema)