Skip to content
Snippets Groups Projects
Unverified Commit 1c67f9eb authored by Yinlei Sun's avatar Yinlei Sun Committed by GitHub
Browse files

[Enhancement] Support BoolTensor and LongTensor on Ascend NPU (#1011)

parent 8bf1ecad
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from collections.abc import Sized
from typing import List, Union
from typing import Any, List, Union
import numpy as np
import torch
from mmengine.device import get_device
from .base_data_element import BaseDataElement
IndexType = Union[str, slice, int, list, torch.LongTensor,
torch.cuda.LongTensor, torch.BoolTensor,
torch.cuda.BoolTensor, np.ndarray]
BoolTypeTensor: Union[Any]
LongTypeTensor: Union[Any]
if get_device() == 'npu':
BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
else:
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
BoolTypeTensor, np.ndarray]
# Modified from
......@@ -156,6 +166,7 @@ class InstanceData(BaseDataElement):
Returns:
:obj:`InstanceData`: Corresponding values.
"""
assert isinstance(item, IndexType.__args__)
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
......@@ -165,9 +176,6 @@ class InstanceData(BaseDataElement):
# More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item)
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor))
if isinstance(item, str):
return getattr(self, item)
......@@ -183,7 +191,7 @@ class InstanceData(BaseDataElement):
if isinstance(item, torch.Tensor):
assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
if isinstance(item, BoolTypeTensor.__args__):
assert len(item) == len(self), 'The shape of the ' \
'input(BoolTensor) ' \
f'{len(item)} ' \
......@@ -202,8 +210,7 @@ class InstanceData(BaseDataElement):
v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')):
# convert to indexes from BoolTensor
if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)):
if isinstance(item, BoolTypeTensor.__args__):
indexes = torch.nonzero(item).view(
-1).cpu().numpy().tolist()
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment