diff --git a/mmengine/data/base_data_element.py b/mmengine/data/base_data_element.py
index a6de2600e5f391b5d9c8164be959200df1488eb0..e98fcb43a0a7340aacfed3dc31944579a812cdd7 100644
--- a/mmengine/data/base_data_element.py
+++ b/mmengine/data/base_data_element.py
@@ -7,10 +7,11 @@ import torch
 
 
 class BaseDataElement:
-    """A base data structure interface of OpenMMlab.
+    """A base data interface that supports Tensor-like and dict-like
+    operations.
 
-    Data elements refer to predicted results or ground truth labels on a
-    task, such as predicted bboxes, instance masks, semantic
+    A typical data element refers to predicted results or ground truth labels
+    on a task, such as predicted bboxes, instance masks, semantic
     segmentation masks, etc. Because groundtruth labels and predicted results
     often have similar properties (for example, the predicted bboxes and the
     groundtruth bboxes), MMEngine uses the same abstract data interface to
@@ -23,7 +24,23 @@ class BaseDataElement:
     ``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and
     ``LabelData`` inheriting from ``BaseDataElement`` to represent different
     types of ground truth labels or predictions.
-    They are used as interfaces between different commopenets.
+
+    Another common data element is sample data. A sample data consists of input
+    data (such as an image) and its annotations and predictions. In general,
+    an image can have multiple types of annotations and/or predictions at the
+    same time (for example, both pixel-level semantic segmentation annotations
+    and instance-level detection bboxes annotations). All labels and
+    predictions of a training sample are often passed between Dataset, Model,
+    Visualizer, and Evaluator components. In order to simplify the interface
+    between components, we can treat them as a large data element and
+    encapsulate them. Such data elements are generally called XXDataSample in
+    OpenMMLab. Therefore, similar to `nn.Module`, a `BaseDataElement` allows
+    other `BaseDataElement` objects as its attributes. Such a class generally
+    encapsulates all the data of a sample in the algorithm library, and its
+    attributes are generally various types of data elements. For example,
+    MMDetection uses a BaseDataElement to encapsulate all the data elements
+    holding the annotations and predictions of a sample in the algorithm
+    library.
 
     The attributes in ``BaseDataElement`` are divided into two parts,
     the ``metainfo`` and the ``data`` respectively.
@@ -70,8 +87,8 @@ class BaseDataElement:
         >>> # new
         >>> gt_instances1 = gt_instance.new(
         ...                     metainfo=dict(img_id=1, img_shape=(640, 640)),
-        ...                     data=dict(bboxes=torch.rand((5, 4)),
-        ...                               scores=torch.rand((5,))))
+        ...                     bboxes=torch.rand((5, 4)),
+        ...                     scores=torch.rand((5,)))
         >>> gt_instances2 = gt_instances1.new()
 
         >>> # add and process property
@@ -241,8 +258,9 @@ class BaseDataElement:
         self.set_data(dict(instance.items()))
 
     def new(self,
-            metainfo: dict = None,
-            data: dict = None) -> 'BaseDataElement':
+            *,
+            metainfo: Optional[dict] = None,
+            **kwargs) -> 'BaseDataElement':
         """Return a new data element with same type. If ``metainfo`` and
         ``data`` are None, the new data element will have same metainfo and
         data. If metainfo or data is not None, the new result will overwrite it
@@ -252,8 +270,9 @@ class BaseDataElement:
             metainfo (dict, optional): A dict contains the meta information
                 of image, such as ``img_shape``, ``scale_factor``, etc.
                 Defaults to None.
-            data (dict, optional): A dict contains annotations of image or
-                model predictions. Defaults to None.
+            kwargs (dict): A dict containing annotations of image or
+                model predictions.
+
         Returns:
             BaseDataElement: a new data element with same type.
         """
@@ -263,8 +282,8 @@ class BaseDataElement:
             new_data.set_metainfo(metainfo)
         else:
             new_data.set_metainfo(dict(self.metainfo_items()))
-        if data is not None:
-            new_data.set_data(data)
+        if kwargs:
+            new_data.set_data(kwargs)
         else:
             new_data.set_data(dict(self.items()))
         return new_data
@@ -388,7 +407,6 @@ class BaseDataElement:
             self._data_fields.remove(item)
 
     # dict-like methods
-    __setitem__ = __setattr__
     __delitem__ = __delattr__
 
     def get(self, key, default=None) -> Any:
@@ -519,6 +537,7 @@ class BaseDataElement:
         }
 
     def __repr__(self) -> str:
+        """Return the string representation of the object."""
 
         def _addindent(s_: str, num_spaces: int) -> str:
             """This func is modified from `pytorch` https://github.com/pytorch/
diff --git a/mmengine/data/instance_data.py b/mmengine/data/instance_data.py
index fff34bf6e980a2a06362417d9f502cda16d93d38..27d8755982e108ac45880ac1bdcaa6c2990c4920 100644
--- a/mmengine/data/instance_data.py
+++ b/mmengine/data/instance_data.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import itertools
+from collections.abc import Sized
 from typing import List, Union
 
 import numpy as np
@@ -7,8 +8,9 @@ import torch
 
 from .base_data_element import BaseDataElement
 
-IndexType = Union[str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
-                  torch.BoolTensor, torch.cuda.BoolTensor, np.long, np.bool]
+IndexType = Union[str, slice, int, list, torch.LongTensor,
+                  torch.cuda.LongTensor, torch.BoolTensor,
+                  torch.cuda.BoolTensor, np.ndarray]
 
 
 # Modified from
@@ -19,8 +21,37 @@ class InstanceData(BaseDataElement):
     Subclass of :class:`BaseDataElement`. All value in `data_fields`
     should have the same length. This design refer to
     https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
+    InstanceData also supports extra functions: ``index``, ``slice`` and ``cat`` for data fields. The type of value
+    in a data field can be a base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
+    or a customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
 
     Examples:
+        >>> # custom data structure
+        >>> class TmpObject:
+        ...     def __init__(self, tmp) -> None:
+        ...         assert isinstance(tmp, list)
+        ...         self.tmp = tmp
+        ...     def __len__(self):
+        ...         return len(self.tmp)
+        ...     def __getitem__(self, item):
+        ...         if type(item) == int:
+        ...             if item >= len(self) or item < -len(self):  # type:ignore
+        ...                 raise IndexError(f'Index {item} out of range!')
+        ...             else:
+        ...                 # keep the dimension
+        ...                 item = slice(item, None, len(self))
+        ...         return TmpObject(self.tmp[item])
+        ...     @staticmethod
+        ...     def cat(tmp_objs):
+        ...         assert all(isinstance(results, TmpObject) for results in tmp_objs)
+        ...         if len(tmp_objs) == 1:
+        ...             return tmp_objs[0]
+        ...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
+        ...         tmp_list = list(itertools.chain(*tmp_list))
+        ...         new_data = TmpObject(tmp_list)
+        ...         return new_data
+        ...     def __repr__(self):
+        ...         return str(self.tmp)
         >>> from mmengine.data import InstanceData
         >>> import numpy as np
         >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
@@ -30,44 +61,69 @@ class InstanceData(BaseDataElement):
         >>> instance_data.det_labels = torch.LongTensor([2, 3])
         >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
         >>> instance_data.bboxes = torch.rand((2, 4))
+        >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
         >>> len(instance_data)
-        4
+        2
         >>> print(instance_data)
         <InstanceData(
-
             META INFORMATION
             pad_shape: (800, 1196, 3)
             img_shape: (800, 1216, 3)
-
             DATA FIELDS
             det_labels: tensor([2, 3])
             det_scores: tensor([0.8, 0.7000])
             bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                 [0.8101, 0.3105, 0.5123, 0.6263]])
+            polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
             ) at 0x7fb492de6280>
         >>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
         >>> sorted_results.det_scores
         tensor([0.7000, 0.8000])
         >>> print(instance_data[instance_data.det_scores > 0.75])
         <InstanceData(
-
             META INFORMATION
             pad_shape: (800, 1216, 3)
             img_shape: (800, 1196, 3)
-
             DATA FIELDS
-            det_labels: tensor([0])
-            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
+            det_labels: tensor([2])
+            masks: [[11, 21, 31, 41]]
             det_scores: tensor([0.8000])
-        ) at 0x7fb5cf6e2790>
-        >>> instance_data[instance_data.det_scores > 0.75].det_labels
-        tensor([0])
-        >>> instance_data[instance_data.det_scores > 0.75].det_scores
-        tensor([0.8000])
+            bboxes: tensor([[0.9308, 0.4000, 0.6077, 0.5554]])
+            polygons: [[1, 2, 3, 4]]
+        ) at 0x7f64ecf0ec40>
+        >>> print(instance_data[instance_data.det_scores > 1])
+        <InstanceData(
+            META INFORMATION
+            pad_shape: (800, 1216, 3)
+            img_shape: (800, 1196, 3)
+            DATA FIELDS
+            det_labels: tensor([], dtype=torch.int64)
+            masks: []
+            det_scores: tensor([])
+            bboxes: tensor([], size=(0, 4))
+            polygons: [[]]
+        ) at 0x7f660a6a7f70>
+        >>> print(instance_data.cat([instance_data, instance_data]))
+        <InstanceData(
+            META INFORMATION
+            img_shape: (800, 1196, 3)
+            pad_shape: (800, 1216, 3)
+            DATA FIELDS
+            det_labels: tensor([2, 3, 2, 3])
+            bboxes: tensor([[0.7404, 0.6332, 0.1684, 0.9961],
+                        [0.2837, 0.8112, 0.5416, 0.2810],
+                        [0.7404, 0.6332, 0.1684, 0.9961],
+                        [0.2837, 0.8112, 0.5416, 0.2810]])
+            data:
+            polygons: [[1, 2, 3, 4], [5, 6, 7, 8],
+                       [1, 2, 3, 4], [5, 6, 7, 8]]
+            det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
+            masks: [[11, 21, 31, 41], [51, 61, 71, 81],
+                    [11, 21, 31, 41], [51, 61, 71, 81]]
+        ) at 0x7f203542feb0>
     """
 
-    def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray,
-                                                  list]):
+    def __setattr__(self, name: str, value: Sized):
         """setattr is only used to set data.
 
         the value must have the attribute of `__len__` and have the same length
@@ -82,9 +138,8 @@ class InstanceData(BaseDataElement):
                     f'private attribute, which is immutable. ')
 
         else:
-            assert isinstance(value, (torch.Tensor, np.ndarray, list)), \
-                f'Can set {type(value)}, only support' \
-                f' {(torch.Tensor, np.ndarray, list)}'
+            assert isinstance(value,
+                              Sized), 'value must contain `_len__` attribute'
 
             if len(self) > 0:
                 assert len(value) == len(self), f'the length of ' \
@@ -95,6 +150,8 @@ class InstanceData(BaseDataElement):
                                                 f'{len(self)} '
             super().__setattr__(name, value)
 
+    __setitem__ = __setattr__
+
     def __getitem__(self, item: IndexType) -> 'InstanceData':
         """
         Args:
@@ -105,11 +162,13 @@ class InstanceData(BaseDataElement):
         Returns:
             obj:`InstanceData`: Corresponding values.
         """
-        assert len(self) > 0, ' This is a empty instance'
-
+        if isinstance(item, list):
+            item = np.array(item)
+        if isinstance(item, np.ndarray):
+            item = torch.from_numpy(item)
         assert isinstance(
             item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
-                   torch.BoolTensor, torch.cuda.BoolTensor, np.bool, np.long))
+                   torch.BoolTensor, torch.cuda.BoolTensor))
 
         if isinstance(item, str):
             return getattr(self, item)
@@ -121,7 +180,7 @@ class InstanceData(BaseDataElement):
                 # keep the dimension
                 item = slice(item, None, len(self))
 
-        new_data = self.new(data={})
+        new_data = self.__class__(metainfo=self.metainfo)
         if isinstance(item, torch.Tensor):
             assert item.dim() == 1, 'Only support to get the' \
                                     ' values along the first dimension.'
@@ -140,17 +199,36 @@ class InstanceData(BaseDataElement):
                     new_data[k] = v[item]
                 elif isinstance(v, np.ndarray):
                     new_data[k] = v[item.cpu().numpy()]
-                elif isinstance(v, list):
-                    r_list = []
+                elif isinstance(
+                        v, (str, list, tuple)) or (hasattr(v, '__getitem__')
+                                                   and hasattr(v, 'cat')):
                     # convert to indexes from boolTensor
                     if isinstance(item,
                                   (torch.BoolTensor, torch.cuda.BoolTensor)):
-                        indexes = torch.nonzero(item).view(-1)
+                        indexes = torch.nonzero(item).view(
+                            -1).cpu().numpy().tolist()
+                    else:
+                        indexes = item.cpu().numpy().tolist()
+                    slice_list = []
+                    if indexes:
+                        for index in indexes:
+                            slice_list.append(slice(index, None, len(v)))
                     else:
-                        indexes = item
-                    for index in indexes:
-                        r_list.append(v[index])
-                    new_data[k] = r_list
+                        slice_list.append(slice(None, 0, None))
+                    r_list = [v[s] for s in slice_list]
+                    if isinstance(v, (str, list, tuple)):
+                        new_value = r_list[0]
+                        for r in r_list[1:]:
+                            new_value = new_value + r
+                    else:
+                        new_value = v.cat(r_list)
+                    new_data[k] = new_value
+                else:
+                    raise ValueError(
+                        f'The type of `{k}` is `{type(v)}`, which has no '
+                        'attribute of `cat`, so it does not '
+                        f'support slice with `bool`')
+
         else:
             # item is a slice
             for k, v in self.items():
@@ -191,24 +269,30 @@ class InstanceData(BaseDataElement):
                                            'elements in `instances_list` ' \
                                            'have the exact same key '
 
-        new_data = instances_list[0].new(data={})
+        new_data = instances_list[0].__class__(
+            metainfo=instances_list[0].metainfo)
         for k in instances_list[0].keys():
             values = [results[k] for results in instances_list]
             v0 = values[0]
             if isinstance(v0, torch.Tensor):
-                values = torch.cat(values, dim=0)
+                new_values = torch.cat(values, dim=0)
             elif isinstance(v0, np.ndarray):
-                values = np.concatenate(values, axis=0)
-            elif isinstance(v0, list):
-                values = list(itertools.chain(*values))
+                new_values = np.concatenate(values, axis=0)
+            elif isinstance(v0, (str, list, tuple)):
+                new_values = v0[:]
+                for v in values[1:]:
+                    new_values += v
+            elif hasattr(v0, 'cat'):
+                new_values = v0.cat(values)
             else:
                 raise ValueError(
-                    f'Can not concat the {k} which is a {type(v0)}')
-            new_data[k] = values
+                    f'The type of `{k}` is `{type(v0)}` which has no '
+                    'attribute of `cat`')
+            new_data[k] = new_values
         return new_data  # type:ignore
 
     def __len__(self) -> int:
-        """The length of instance data."""
+        """int: the length of InstanceData"""
         if len(self._data_fields) > 0:
             return len(self.values()[0])
         else:
diff --git a/tests/test_data/test_data_element.py b/tests/test_data/test_data_element.py
index 1ff1eb0a0d42cc95719b8809cdc409fcebcc6280..bd0da7a20d90622b79626e09a0e78e192989a178 100644
--- a/tests/test_data/test_data_element.py
+++ b/tests/test_data/test_data_element.py
@@ -112,7 +112,7 @@ class TestBaseDataElement(TestCase):
 
         # test new() with arguments
         metainfo, data = self.setup_data()
-        new_instances = instances.new(metainfo=metainfo, data=data)
+        new_instances = instances.new(metainfo=metainfo, **data)
         assert type(new_instances) == type(instances)
         assert id(new_instances.gt_instances) != id(instances.gt_instances)
         _, new_data = self.setup_data()
diff --git a/tests/test_data/test_instance_data.py b/tests/test_data/test_instance_data.py
index 17fc2bcaa9b7b391a97d75b406b4cecf3715b784..50266d7ca48d861267b578c19b81bc2cad32eb16 100644
--- a/tests/test_data/test_instance_data.py
+++ b/tests/test_data/test_instance_data.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import itertools
 import random
 from unittest import TestCase
 
@@ -9,6 +10,66 @@ import torch
 from mmengine.data import BaseDataElement, InstanceData
 
 
+class TmpObject:
+
+    def __init__(self, tmp) -> None:
+        assert isinstance(tmp, list)
+        if len(tmp) > 0:
+            for t in tmp:
+                assert isinstance(t, list)
+        self.tmp = tmp
+
+    def __len__(self):
+        return len(self.tmp)
+
+    def __getitem__(self, item):
+        if type(item) == int:
+            if item >= len(self) or item < -len(self):  # type:ignore
+                raise IndexError(f'Index {item} out of range!')
+            else:
+                # keep the dimension
+                item = slice(item, None, len(self))
+        return TmpObject(self.tmp[item])
+
+    @staticmethod
+    def cat(tmp_objs):
+        assert all(isinstance(results, TmpObject) for results in tmp_objs)
+        if len(tmp_objs) == 1:
+            return tmp_objs[0]
+        tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
+        tmp_list = list(itertools.chain(*tmp_list))
+        new_data = TmpObject(tmp_list)
+        return new_data
+
+    def __repr__(self):
+        return str(self.tmp)
+
+
+class TmpObjectWithoutCat:
+
+    def __init__(self, tmp) -> None:
+        assert isinstance(tmp, list)
+        if len(tmp) > 0:
+            for t in tmp:
+                assert isinstance(t, list)
+        self.tmp = tmp
+
+    def __len__(self):
+        return len(self.tmp)
+
+    def __getitem__(self, item):
+        if type(item) == int:
+            if item >= len(self) or item < -len(self):  # type:ignore
+                raise IndexError(f'Index {item} out of range!')
+            else:
+                # keep the dimension
+                item = slice(item, None, len(self))
+        return TmpObject(self.tmp[item])
+
+    def __repr__(self):
+        return str(self.tmp)
+
+
 class TestInstanceData(TestCase):
 
     def setup_data(self):
@@ -18,10 +79,18 @@ class TestInstanceData(TestCase):
         instances_infos = [1] * 5
         bboxes = torch.rand((5, 4))
         labels = np.random.rand(5)
+        kps = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
+        ids = (1, 2, 3, 4, 5)
+        name_ids = '12345'
+        polygons = TmpObject(np.arange(25).reshape((5, -1)).tolist())
         instance_data = InstanceData(
             metainfo=metainfo,
             bboxes=bboxes,
             labels=labels,
+            polygons=polygons,
+            kps=kps,
+            ids=ids,
+            name_ids=name_ids,
             instances_infos=instances_infos)
         return instance_data
 
@@ -34,10 +103,6 @@ class TestInstanceData(TestCase):
         with self.assertRaises(AttributeError):
             instance_data._data_fields = 1
 
-        # value only supports (torch.Tensor, np.ndarray, list)
-        with self.assertRaises(AssertionError):
-            instance_data.v = 'value'
-
         # The data length in InstanceData must be the same
         with self.assertRaises(AssertionError):
             instance_data.keypoints = torch.rand((17, 2))
@@ -48,14 +113,15 @@ class TestInstanceData(TestCase):
     def test_getitem(self):
         instance_data = InstanceData()
         # length must be greater than 0
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(IndexError):
             instance_data[1]
 
         instance_data = self.setup_data()
         assert len(instance_data) == 5
         slice_instance_data = instance_data[:2]
         assert len(slice_instance_data) == 2
-
+        slice_instance_data = instance_data[1]
+        assert len(slice_instance_data) == 1
         # assert the index should in 0 ~ len(instance_data) -1
         with pytest.raises(IndexError):
             instance_data[5]
@@ -80,6 +146,40 @@ class TestInstanceData(TestCase):
         bool_tensor = torch.rand(5) > 0.5
         bool_index_instance_data = instance_data[bool_tensor]
         assert len(bool_index_instance_data) == bool_tensor.sum()
+        bool_tensor = torch.rand(5) > 1
+        empty_instance_data = instance_data[bool_tensor]
+        assert len(empty_instance_data) == bool_tensor.sum()
+
+        # test list index
+        list_index = [1, 2]
+        list_index_instance_data = instance_data[list_index]
+        assert len(list_index_instance_data) == len(list_index)
+
+        # test list bool
+        list_bool = [True, False, True, False, False]
+        list_bool_instance_data = instance_data[list_bool]
+        assert len(list_bool_instance_data) == 2
+
+        # test numpy
+        long_numpy = np.random.randint(5, size=2)
+        long_numpy_instance_data = instance_data[long_numpy]
+        assert len(long_numpy_instance_data) == len(long_numpy)
+
+        bool_numpy = np.random.rand(5) > 0.5
+        bool_numpy_instance_data = instance_data[bool_numpy]
+        assert len(bool_numpy_instance_data) == bool_numpy.sum()
+
+        # without cat
+        instance_data.polygons = TmpObjectWithoutCat(
+            np.arange(25).reshape((5, -1)).tolist())
+        bool_numpy = np.random.rand(5) > 0.5
+        with pytest.raises(
+                ValueError,
+                match=('The type of `polygons` is '
+                       f'`{type(instance_data.polygons)}`, '
+                       'which has no attribute of `cat`, so it does not '
+                       f'support slice with `bool`')):
+            bool_numpy_instance_data = instance_data[bool_numpy]
 
     def test_cat(self):
         instance_data_1 = self.setup_data()
@@ -97,6 +197,24 @@ class TestInstanceData(TestCase):
         # Input List length must be greater than 0
         with self.assertRaises(AssertionError):
             InstanceData.cat([])
+        instance_data_2 = instance_data_1.clone()
+        instance_data_2 = instance_data_2[torch.zeros(5) > 0.5]
+        cat_instance_data = InstanceData.cat(
+            [instance_data_1, instance_data_2])
+        cat_instance_data = InstanceData.cat([instance_data_1])
+        assert len(cat_instance_data) == 5
+
+        # test custom data cat
+        instance_data_1.polygons = TmpObjectWithoutCat(
+            np.arange(25).reshape((5, -1)).tolist())
+        instance_data_2 = instance_data_1.clone()
+        with pytest.raises(
+                ValueError,
+                match=('The type of `polygons` is '
+                       f'`{type(instance_data_1.polygons)}` '
+                       'which has no attribute of `cat`')):
+            cat_instance_data = InstanceData.cat(
+                [instance_data_1, instance_data_2])
 
     def test_len(self):
         instance_data = self.setup_data()