diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py
index 3b05f06a71e295c3f3f5a8dabb9074f6774a92d5..b989c45a7ca7aa5c07c92d54a3e1be71d6e10efc 100644
--- a/mmengine/dist/dist.py
+++ b/mmengine/dist/dist.py
@@ -414,9 +414,13 @@ def _broadcast_object_list(object_list: List[Any],
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
     current_device = torch.device('cpu')
     is_hccl_backend = group_backend == 'hccl'
+    is_cncl_backend = group_backend == 'cncl'
     if is_hccl_backend:
         current_device = torch.npu.current_device()
         object_sizes_tensor = object_sizes_tensor.to(current_device)
+    elif is_cncl_backend:
+        current_device = torch.device('mlu', torch.mlu.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
     elif is_nccl_backend:
         # See note about using torch.cuda.current_device() here in
         # docstring. We cannot simply use my_rank since rank == device is
@@ -436,7 +440,7 @@ def _broadcast_object_list(object_list: List[Any],
             dtype=torch.uint8,
         )
 
-    if is_nccl_backend or is_hccl_backend:
+    if is_nccl_backend or is_hccl_backend or is_cncl_backend:
         object_tensor = object_tensor.to(current_device)
     torch_dist.broadcast(object_tensor, src=src, group=group)
     # Deserialize objects using their stored sizes.
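
To illustrate what the new `cncl` branch enables: once a process group is
initialized on Cambricon's CNCL backend, object broadcasting stages its
serialized buffers on the MLU before the collective runs, mirroring the
existing NCCL and HCCL paths. A minimal sketch, assuming the torch_mlu
extension is installed (it provides `torch.mlu` and registers the 'cncl'
backend) and that ranks are started by an external launcher which exports
the usual rendezvous variables; `run_rank` is a hypothetical helper:

    import torch
    import torch.distributed as dist

    from mmengine.dist import broadcast_object_list

    def run_rank(rank: int, world_size: int) -> None:
        # Bind this process to one MLU card; MASTER_ADDR/MASTER_PORT are
        # assumed to be set by the launcher.
        torch.mlu.set_device(rank)
        dist.init_process_group('cncl', rank=rank, world_size=world_size)
        # Rank 0 supplies the payload; other ranks receive it in place.
        data = [{'lr': 0.01}] if rank == 0 else [None]
        broadcast_object_list(data, src=0)
        assert data[0] == {'lr': 0.01}
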
diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py
index 3e2d72da0dd58bcdbb0e50a07a299fda584e5424..c682ff48ea7b5510e0258a4f711c32b5b951847c 100644
--- a/mmengine/model/base_model/base_model.py
+++ b/mmengine/model/base_model/base_model.py
@@ -216,6 +216,20 @@ class BaseModel(BaseModule):
         self._set_device(torch.device(device))
         return super().cuda(device)
 
+    def mlu(
+        self,
+        device: Union[int, str, torch.device, None] = None,
+    ) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.mlu`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        device = torch.device('mlu', torch.mlu.current_device())
+        self._set_device(device)
+        return super().mlu()
+
     def npu(
         self,
         device: Union[int, str, torch.device, None] = None,
diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 8f02a6c5a4c597cb1d41e0f6550f85a95139c757..7f8a54c1780b963a3f479feac3712ef3124f653b 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -122,6 +122,15 @@ class BaseDataPreprocessor(nn.Module):
         self._device = torch.device(torch.npu.current_device())
         return super().npu()
 
+    def mlu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self._device = torch.device(torch.mlu.current_device())
+        return super().mlu()
+
     def cpu(self, *args, **kwargs) -> nn.Module:
         """Overrides this method to set the :attr:`device`
 
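
Taken together, the two hunks above let one call move both the model
weights and the bundled data preprocessor to the MLU. A minimal sketch,
assuming torch_mlu is importable; `ToyModel` is a hypothetical subclass:

    import torch.nn as nn

    from mmengine.model import BaseModel

    class ToyModel(BaseModel):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)

        def forward(self, inputs, data_samples=None, mode='tensor'):
            return self.linear(inputs)

    model = ToyModel().mlu()
    # `_set_device` retargets the data preprocessor as well, so inputs it
    # casts during `forward` land on the same MLU as the parameters.
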
diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index fcc3b4c50fb82e92be87ca293cc4f95cfde6de57..861361369fbfa057e7fa9d06f18eee75a96ab79a 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -5,7 +5,8 @@ from typing import Union
 import torch
 import torch.nn as nn
 
-from mmengine.device import is_cuda_available, is_npu_available
+from mmengine.device import (is_cuda_available, is_mlu_available,
+                             is_npu_available)
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -13,6 +14,8 @@ from .optimizer_wrapper import OptimWrapper
 
 if is_npu_available():
     from torch.npu.amp import GradScaler
+elif is_mlu_available():
+    from torch.mlu.amp import GradScaler
 else:
     from torch.cuda.amp import GradScaler
 
@@ -65,8 +68,9 @@ class AmpOptimWrapper(OptimWrapper):
                  **kwargs):
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
-        assert is_cuda_available() or is_npu_available(), (
-            '``AmpOptimizerWrapper`` is only available training on gpu or npu')
+        assert is_cuda_available() or is_npu_available() or is_mlu_available(
+        ), ('``AmpOptimWrapper`` is only available for training '
+            'on gpu, npu or mlu')
         super().__init__(**kwargs)
         self._scale_update_param = None
         if loss_scale == 'dynamic':
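
With the import fallback above, `AmpOptimWrapper` picks up Cambricon's
`torch.mlu.amp.GradScaler` without any config change. A hedged
construction sketch on an MLU machine; the linear model and the SGD
settings are placeholders:

    import torch.nn as nn
    from torch.optim import SGD

    from mmengine.optim import AmpOptimWrapper

    model = nn.Linear(2, 2).mlu()  # .mlu() is patched in by torch_mlu
    optim_wrapper = AmpOptimWrapper(
        optimizer=SGD(model.parameters(), lr=0.01),
        loss_scale='dynamic')  # the scaler now comes from torch.mlu.amp
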
diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index 33ab6bd25d3801e2133e455b57b214af25330256..964518fc90ca9f949f74e1f38421cf9fefb5cfcd 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -5,7 +5,8 @@ from typing import Optional
 
 import torch
 
-from mmengine.device import get_device, is_cuda_available, is_npu_available
+from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
+                             is_npu_available)
 from mmengine.logging import print_log
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -75,9 +76,11 @@ def autocast(device_type: Optional[str] = None,
             digit_version('1.10.0')):
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
-        assert device_type == 'cuda' or device_type is None, (
-            'Pytorch version under 1.10.0 only supports running automatic '
-            'mixed training with cuda')
+        assert (
+            device_type == 'cuda' or device_type == 'mlu'
+            or device_type is None), (
+                'Pytorch version under 1.10.0 only supports running '
+                'automatic mixed precision training with cuda or mlu')
         if dtype is not None or cache_enabled is not None:
             print_log(
                 f'{dtype} and {device_type} will not work for '
@@ -89,6 +92,9 @@ def autocast(device_type: Optional[str] = None,
         if is_npu_available():
             with torch.npu.amp.autocast(enabled=enabled):
                 yield
+        elif is_mlu_available():
+            with torch.mlu.amp.autocast(enabled=enabled):
+                yield
         elif is_cuda_available():
             with torch.cuda.amp.autocast(enabled=enabled):
                 yield
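
Callers keep using the same `autocast` entry point; the context manager
dispatches to `torch.mlu.amp.autocast` when an MLU is present. A minimal
usage sketch, assuming a visible MLU device:

    import torch
    import torch.nn as nn

    from mmengine.runner import autocast

    layer = nn.Conv2d(1, 1, 1).mlu()
    with autocast(device_type='mlu'):
        res = layer(torch.randn(1, 1, 1, 1).mlu())
    # Under autocast the output drops to a reduced precision such as
    # torch.float16; outside it stays torch.float32.
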
diff --git a/mmengine/structures/base_data_element.py b/mmengine/structures/base_data_element.py
index 46c4c886e6633c69104327ea71278f8cc7ff650f..454a2243718b429bb2b5192c6f130c006dec7c7a 100644
--- a/mmengine/structures/base_data_element.py
+++ b/mmengine/structures/base_data_element.py
@@ -521,6 +521,16 @@ class BaseDataElement:
                 new_data.set_data(data)
         return new_data
 
+    def mlu(self) -> 'BaseDataElement':
+        """Convert all tensors to MLU in data."""
+        new_data = self.new()
+        for k, v in self.items():
+            if isinstance(v, (torch.Tensor, BaseDataElement)):
+                v = v.mlu()
+                data = {k: v}
+                new_data.set_data(data)
+        return new_data
+
     # Tensor-like methods
     def detach(self) -> 'BaseDataElement':
         """Detach all tensors in data."""
diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py
index 1ceac9ad24d79300aa4f7d40e52d6791f2af43bf..8df9727a00027a2c949e8bac37e3beaa49af68ba 100644
--- a/mmengine/structures/instance_data.py
+++ b/mmengine/structures/instance_data.py
@@ -15,6 +15,9 @@ LongTypeTensor: Union[Any]
 if get_device() == 'npu':
     BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
+elif get_device() == 'mlu':
+    BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
+    LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
 else:
     BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
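
The widened aliases matter because `InstanceData.__getitem__` validates
boolean and long tensor indices against these unions. A hedged sketch of
MLU-side indexing, assuming torch_mlu:

    import torch

    from mmengine.structures import InstanceData

    data = InstanceData(bboxes=torch.rand(4, 4).mlu())
    mask = (torch.rand(4) > 0.5).mlu()  # an mlu BoolTensor index
    subset = data[mask]  # accepted now that the union covers torch.mlu
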
diff --git a/tests/test_runner/test_amp.py b/tests/test_runner/test_amp.py
index 7ef605637832c361563ebf3da313abb75cef1e5e..89794f34144152c4ea8441a9359357b9743fc13d 100644
--- a/tests/test_runner/test_amp.py
+++ b/tests/test_runner/test_amp.py
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 
 import mmengine
-from mmengine.device import get_device
+from mmengine.device import get_device, is_mlu_available
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -14,7 +14,22 @@ from mmengine.utils.dl_utils import TORCH_VERSION
 class TestAmp(unittest.TestCase):
 
     def test_autocast(self):
-        if not torch.cuda.is_available():
+        if is_mlu_available():
+            device = 'mlu'
+            with autocast(device_type=device):
+                # torch.autocast supports mlu mode.
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
+                with autocast(enabled=False, device_type=device):
+                    res = layer(torch.randn(1, 1, 1, 1).to(device))
+                    self.assertEqual(res.dtype, torch.float32)
+            # Test running in fp32 with autocast disabled from the start.
+            with autocast(enabled=False, device_type=device):
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertEqual(res.dtype, torch.float32)
+        elif not torch.cuda.is_available():
             if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
                 # `torch.cuda.amp.autocast` is only support in gpu mode, if
                 # cuda is not available, it will return an empty context and
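
On an MLU host the new branch runs first and exercises both the enabled
and disabled autocast paths; other backends fall through to the existing
expectations below. Assuming a standard checkout, the test can be run in
isolation with:

    python -m pytest tests/test_runner/test_amp.py -k test_autocast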