Skip to content
Snippets Groups Projects
test_amp.py 3.47 KiB
Newer Older
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch
import torch.nn as nn

import mmengine
from mmengine.device import get_device
from mmengine.runner import autocast
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


class TestAmp(unittest.TestCase):
    """Tests for :func:`mmengine.runner.autocast`."""

    def test_autocast(self):
        """Check output dtypes with autocast enabled and disabled.

        Which branches run depends on the environment: whether CUDA is
        available, whether the installed PyTorch is older than 1.10, and
        whether it is at least 1.12 (for the simulated ``mps``/``mlu``
        checks, which patch the ``get_device`` used by
        ``mmengine.runner.amp``).
        """
        if not torch.cuda.is_available():
            if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
                # `torch.cuda.amp.autocast` is only supported in GPU mode.
                # Without CUDA, enabling autocast must raise a RuntimeError,
                # while `enabled=False` still yields a usable (no-op)
                # context.
                with self.assertRaisesRegex(RuntimeError,
                                            'If pytorch versions is '):
                    with autocast():
                        pass

                with autocast(enabled=False):
                    layer = nn.Conv2d(1, 1, 1)
                    res = layer(torch.randn(1, 1, 1, 1))
                    self.assertEqual(res.dtype, torch.float32)

            else:
                with autocast(device_type='cpu'):
                    # torch.autocast supports CPU mode.
                    layer = nn.Conv2d(1, 1, 1)
                    res = layer(torch.randn(1, 1, 1, 1))
                    self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
                    with autocast(enabled=False):
                        res = layer(torch.randn(1, 1, 1, 1))
                        self.assertEqual(res.dtype, torch.float32)

        else:
            if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
                # Only CUDA autocast is exercised before PyTorch 1.10.
                devices = ['cuda']
            else:
                devices = ['cpu', 'cuda']
            for device in devices:
                with autocast(device_type=device):
                    # torch.autocast supports CPU and CUDA mode.
                    layer = nn.Conv2d(1, 1, 1).to(device)
                    res = layer(torch.randn(1, 1, 1, 1).to(device))
                    self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
                    with autocast(enabled=False, device_type=device):
                        res = layer(torch.randn(1, 1, 1, 1).to(device))
                        self.assertEqual(res.dtype, torch.float32)
                # With autocast disabled, computation stays in fp32.
                with autocast(enabled=False, device_type=device):
                    layer = nn.Conv2d(1, 1, 1).to(device)
                    res = layer(torch.randn(1, 1, 1, 1).to(device))
                    self.assertEqual(res.dtype, torch.float32)

        # Simulate other devices by patching the `get_device` looked up
        # inside `mmengine.runner.amp`.
        if digit_version(TORCH_VERSION) >= digit_version('1.12.0'):
            try:
                # Test mps
                mmengine.runner.amp.get_device = lambda: 'mps'
                with autocast(enabled=False):
                    layer = nn.Conv2d(1, 1, 1)
                    res = layer(torch.randn(1, 1, 1, 1))
                    self.assertEqual(res.dtype, torch.float32)

                with self.assertRaisesRegex(
                        ValueError, 'User specified autocast device_type'):
                    with autocast(enabled=True):
                        pass

                # Native PyTorch does not support mlu. Here we only check
                # that autocast delegates to `torch.autocast`, which an
                # mlu build of PyTorch would override.
                mmengine.runner.amp.get_device = lambda: 'mlu'
                with self.assertRaises(RuntimeError):
                    with autocast(enabled=False):
                        pass
            finally:
                # Always undo the patch, even if an assertion above fails,
                # so later tests do not inherit the fake device.
                mmengine.runner.amp.get_device = get_device