From 1c67f9eb22c5528266c2748d146429c2c7336e19 Mon Sep 17 00:00:00 2001
From: Yinlei Sun <ginray0215@gmail.com>
Date: Mon, 10 Apr 2023 16:31:31 +0800
Subject: [PATCH] [Enhancement] Support BoolTensor and LongTensor on Ascend NPU
 (#1011)

---
 mmengine/structures/instance_data.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py
index 36c60fd1..1ceac9ad 100644
--- a/mmengine/structures/instance_data.py
+++ b/mmengine/structures/instance_data.py
@@ -1,16 +1,26 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import itertools
 from collections.abc import Sized
-from typing import List, Union
+from typing import Any, List, Union
 
 import numpy as np
 import torch
 
+from mmengine.device import get_device
 from .base_data_element import BaseDataElement
 
-IndexType = Union[str, slice, int, list, torch.LongTensor,
-                  torch.cuda.LongTensor, torch.BoolTensor,
-                  torch.cuda.BoolTensor, np.ndarray]
+BoolTypeTensor: Union[Any]
+LongTypeTensor: Union[Any]
+
+if get_device() == 'npu':
+    BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
+    LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
+else:
+    BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
+    LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
+
+IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
+                              BoolTypeTensor, np.ndarray]
 
 
 # Modified from
@@ -156,6 +166,7 @@ class InstanceData(BaseDataElement):
         Returns:
             :obj:`InstanceData`: Corresponding values.
         """
+        assert isinstance(item, IndexType.__args__)
         if isinstance(item, list):
             item = np.array(item)
         if isinstance(item, np.ndarray):
@@ -165,9 +176,6 @@ class InstanceData(BaseDataElement):
             # More details in https://github.com/numpy/numpy/issues/9464
             item = item.astype(np.int64) if item.dtype == np.int32 else item
             item = torch.from_numpy(item)
-        assert isinstance(
-            item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
-                   torch.BoolTensor, torch.cuda.BoolTensor))
 
         if isinstance(item, str):
             return getattr(self, item)
@@ -183,7 +191,7 @@ class InstanceData(BaseDataElement):
         if isinstance(item, torch.Tensor):
             assert item.dim() == 1, 'Only support to get the' \
                                     ' values along the first dimension.'
-            if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
+            if isinstance(item, BoolTypeTensor.__args__):
                 assert len(item) == len(self), 'The shape of the ' \
                                                'input(BoolTensor) ' \
                                                f'{len(item)} ' \
@@ -202,8 +210,7 @@ class InstanceData(BaseDataElement):
                         v, (str, list, tuple)) or (hasattr(v, '__getitem__')
                                                    and hasattr(v, 'cat')):
                     # convert to indexes from BoolTensor
-                    if isinstance(item,
-                                  (torch.BoolTensor, torch.cuda.BoolTensor)):
+                    if isinstance(item, BoolTypeTensor.__args__):
                         indexes = torch.nonzero(item).view(
                             -1).cpu().numpy().tolist()
                     else:
-- 
GitLab