Skip to content
Snippets Groups Projects
Unverified Commit 1c67f9eb authored by Yinlei Sun's avatar Yinlei Sun Committed by GitHub
Browse files

[Enhancement] Support BoolTensor and LongTensor on Ascend NPU (#1011)

parent 8bf1ecad
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import itertools import itertools
from collections.abc import Sized from collections.abc import Sized
from typing import List, Union from typing import Any, List, Union
import numpy as np import numpy as np
import torch import torch
from mmengine.device import get_device
from .base_data_element import BaseDataElement from .base_data_element import BaseDataElement
IndexType = Union[str, slice, int, list, torch.LongTensor, BoolTypeTensor: Union[Any]
torch.cuda.LongTensor, torch.BoolTensor, LongTypeTensor: Union[Any]
torch.cuda.BoolTensor, np.ndarray]
if get_device() == 'npu':
BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
else:
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
BoolTypeTensor, np.ndarray]
# Modified from # Modified from
...@@ -156,6 +166,7 @@ class InstanceData(BaseDataElement): ...@@ -156,6 +166,7 @@ class InstanceData(BaseDataElement):
Returns: Returns:
:obj:`InstanceData`: Corresponding values. :obj:`InstanceData`: Corresponding values.
""" """
assert isinstance(item, IndexType.__args__)
if isinstance(item, list): if isinstance(item, list):
item = np.array(item) item = np.array(item)
if isinstance(item, np.ndarray): if isinstance(item, np.ndarray):
...@@ -165,9 +176,6 @@ class InstanceData(BaseDataElement): ...@@ -165,9 +176,6 @@ class InstanceData(BaseDataElement):
# More details in https://github.com/numpy/numpy/issues/9464 # More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item) item = torch.from_numpy(item)
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor))
if isinstance(item, str): if isinstance(item, str):
return getattr(self, item) return getattr(self, item)
...@@ -183,7 +191,7 @@ class InstanceData(BaseDataElement): ...@@ -183,7 +191,7 @@ class InstanceData(BaseDataElement):
if isinstance(item, torch.Tensor): if isinstance(item, torch.Tensor):
assert item.dim() == 1, 'Only support to get the' \ assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.' ' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)): if isinstance(item, BoolTypeTensor.__args__):
assert len(item) == len(self), 'The shape of the ' \ assert len(item) == len(self), 'The shape of the ' \
'input(BoolTensor) ' \ 'input(BoolTensor) ' \
f'{len(item)} ' \ f'{len(item)} ' \
...@@ -202,8 +210,7 @@ class InstanceData(BaseDataElement): ...@@ -202,8 +210,7 @@ class InstanceData(BaseDataElement):
v, (str, list, tuple)) or (hasattr(v, '__getitem__') v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')): and hasattr(v, 'cat')):
# convert to indexes from BoolTensor # convert to indexes from BoolTensor
if isinstance(item, if isinstance(item, BoolTypeTensor.__args__):
(torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view( indexes = torch.nonzero(item).view(
-1).cpu().numpy().tolist() -1).cpu().numpy().tolist()
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment