diff --git a/docs/zh_cn/tutorials/abstract_data_interface.md b/docs/zh_cn/tutorials/abstract_data_interface.md new file mode 100644 index 0000000000000000000000000000000000000000..ce85d7d7c5a696e9d523bb83507a359d80fa6119 --- /dev/null +++ b/docs/zh_cn/tutorials/abstract_data_interface.md @@ -0,0 +1,443 @@ +# 抽象数æ®æŽ¥å£ + +在模型的è®ç»ƒ/测试过程ä¸ï¼Œç»„件之间往往有大é‡çš„æ•°æ®éœ€è¦ä¼ 递,ä¸åŒçš„算法需è¦ä¼ 递的数æ®ç»å¸¸æ˜¯ä¸ä¸€æ ·çš„, +例如,è®ç»ƒå•é˜¶æ®µæ£€æµ‹å™¨éœ€è¦èŽ·å¾—æ•°æ®é›†çš„æ ‡æ³¨æ¡†ï¼ˆground truth bounding boxesï¼‰å’Œæ ‡ç¾ï¼ˆground truth box labels),è®ç»ƒ Mask R-CNN 时还需è¦å®žä¾‹æŽ©ç (instance masks)。 +è®ç»ƒè¿™äº›æ¨¡åž‹æ—¶çš„代ç 如下所示 + +```python +for img, img_metas, gt_bboxes, gt_labels in data_loader: + loss = retinanet(img, img_metas, gt_bboxes, gt_labels) +``` + +```python +for img, img_metas, gt_bboxes, gt_masks, gt_labels in data_loader: + loss = mask_rcnn(img, img_metas, gt_bboxes, gt_masks, gt_labels) +``` + +å¯ä»¥å‘现,在ä¸åŠ å°è£…的情况下,ä¸åŒç®—法所需数æ®çš„ä¸ä¸€è‡´å¯¼è‡´äº†ä¸åŒç®—法模å—之间接å£çš„ä¸ä¸€è‡´ï¼Œå½±å“了算法库的拓展性,åŒæ—¶ä¸€ä¸ªç®—法库内的模å—为了ä¿æŒå…¼å®¹æ€§å¾€å¾€åœ¨æŽ¥å£ä¸Šå˜åœ¨å†—余。 +ä¸Šè¿°å¼Šç«¯åœ¨ç®—æ³•åº“ä¹‹é—´ä¼šä½“çŽ°åœ°æ›´åŠ æ˜Žæ˜¾ï¼Œå¯¼è‡´åœ¨å®žçŽ°å¤šä»»åŠ¡ï¼ˆåŒæ—¶è¿›è¡Œå¦‚è¯ä¹‰åˆ†å‰²ã€æ£€æµ‹ã€å…³é”®ç‚¹æ£€æµ‹ç‰å¤šä¸ªä»»åŠ¡ï¼‰æ„ŸçŸ¥æ¨¡åž‹æ—¶æ¨¡å—难以å¤ç”¨ï¼ŒæŽ¥å£éš¾ä»¥æ‹“展。 + +为了解决上述问题,MMEngine 定义了一套抽象的数æ®æŽ¥å£æ¥å°è£…模型è¿è¡Œè¿‡ç¨‹ä¸çš„å„ç§æ•°æ®ã€‚å‡è®¾å°†ä¸Šè¿°ä¸åŒçš„æ•°æ®å°è£…è¿› `data_sample` ,ä¸åŒç®—法的è®ç»ƒéƒ½å¯ä»¥è¢«æŠ½è±¡å’Œç»Ÿä¸€æˆå¦‚下代ç + +```python +for img, data_sample in dataloader: + loss = model(img, data_sample) +``` + +通过对å„ç§æ•°æ®æ供统一的å°è£…,抽象数æ®æŽ¥å£ç»Ÿä¸€å¹¶ç®€åŒ–了算法库ä¸å„个模å—的接å£ï¼Œå¯ä»¥è¢«ç”¨äºŽç®—æ³•åº“ä¸ dataset,model,visualizer,和 evaluator 组件之间,或者 model 内å„个模å—之间的数æ®ä¼ 递。 +抽象数æ®æŽ¥å£å®žçŽ°äº†åŸºæœ¬çš„增/åˆ /改/查功能,åŒæ—¶æ”¯æŒä¸åŒè®¾å¤‡ä¹‹é—´çš„è¿ç§»ï¼Œæ”¯æŒç±»å—å…¸å’Œå¼ é‡çš„æ“作,å¯ä»¥å……分满足算法库对于这些数æ®çš„使用è¦æ±‚。 +基于 MMEngine 的算法库å¯ä»¥ç»§æ‰¿è¿™å¥—抽象数æ®æŽ¥å£å¹¶å®žçŽ°è‡ªå·±çš„抽象数æ®æŽ¥å£æ¥é€‚应ä¸åŒç®—法ä¸æ•°æ®çš„特点与实际需è¦ï¼Œåœ¨ä¿æŒç»Ÿä¸€æŽ¥å£çš„åŒæ—¶æ高了算法模å—的拓展性。 + +## 设计 + +一个算法库ä¸çš„æ•°æ®å¯ä»¥è¢«å½’ç±»æˆå…·æœ‰ä¸åŒæ€§è´¨çš„æ•°æ®å…ƒç´ 。一个è®ç»ƒæ ·æœ¬ï¼ˆå¦‚ä¸€å¼ å›¾ç‰‡ï¼‰çš„æ‰€æœ‰æ•°æ®å…ƒç´ æž„æˆäº†ä¸€ä¸ªè®ç»ƒæ ·æœ¬çš„完整数æ®ï¼Œç§°ä¸ºæ ·æœ¬æ•°æ®ã€‚相应地,MMEngine 为数æ®å…ƒç´ å’Œæ ·æœ¬æ•°æ®åˆ†åˆ«å®šä¹‰äº†ä¸€ç§å°è£…。 + +1. æ•°æ®å…ƒç´ çš„å°è£…: æ•°æ®å…ƒç´ 指的是æŸä¸€ç®—法任务上的预测数æ®æˆ–æ ‡æ³¨ï¼Œä¾‹å¦‚æ£€æµ‹æ¡†ï¼Œå®žä¾‹æŽ©ç ,è¯ä¹‰åˆ†å‰²æŽ©ç ç‰ã€‚å› ä¸ºæ ‡æ³¨æ•°æ®å’Œé¢„测数æ®å¾€å¾€å…·æœ‰ç›¸ä¼¼çš„æ€§è´¨ï¼ˆä¾‹å¦‚æ¨¡åž‹çš„é¢„æµ‹æ¡†å’Œæ ‡æ³¨æ¡†å…·æœ‰ç›¸åŒçš„性质),MMEngine 使用相åŒçš„抽象数æ®æŽ¥å£æ¥å°è£…预测数æ®å’Œæ ‡æ³¨æ•°æ®ï¼Œå¹¶æŽ¨è使用命åæ¥åŒºåˆ†ä»–们,如使用 `gt_instances` å’Œ `pred_instances` æ¥åŒºåˆ†æ ‡æ³¨å’Œé¢„测的实例数æ®ã€‚å¦å¤–,我们将数æ®å…ƒç´ 区分为实例级别,åƒç´ çº§åˆ«ï¼Œå’Œæ ‡ç¾çº§åˆ«ã€‚这些类型å„æœ‰è‡ªå·±çš„ç‰¹ç‚¹ï¼Œå› æ¤ï¼ŒMMEngine 定义了数æ®å…ƒç´ 的基类 `BaseDataElement`,并由æ¤æ´¾ç”Ÿå‡ºäº† 3 类数æ®ç»“æž„æ¥å°è£…ä¸åŒç±»åž‹çš„æ ‡æ³¨æ•°æ®æˆ–者模型的预测结果:`InstanceData`, `PixelData`, å’Œ `LabelData`。这些接å£å°†è¢«ç”¨äºŽæ¨¡åž‹å†…å„个模å—之间的数æ®ä¼ 递。 + +2. æ ·æœ¬æ•°æ®çš„å°è£…:一个è®ç»ƒæ ·æœ¬ï¼ˆä¾‹å¦‚ä¸€å¼ å›¾ç‰‡ï¼‰çš„æ‰€æœ‰æ ‡æ³¨å’Œé¢„æµ‹æž„æˆäº†ä¸€ä¸ªæ ·æœ¬æ•°æ®ã€‚ä¸€èˆ¬æƒ…å†µä¸‹ï¼Œä¸€å¼ å›¾ç‰‡å¯ä»¥åŒæ—¶æœ‰å¤šç§ç±»åž‹çš„æ ‡æ³¨å’Œ/或预测(例如,åŒæ—¶æ‹¥æœ‰åƒç´ 级别的è¯ä¹‰åˆ†å‰²æ ‡æ³¨å’Œå®žä¾‹çº§åˆ«çš„æ£€æµ‹æ¡†æ ‡æ³¨ï¼‰ã€‚å› æ¤ï¼ŒMMEngine 定义了 `BaseDataSample`ä½œä¸ºæ ·æœ¬æ•°æ®å°è£…的基类。也就是说,**`BaseDataSample` 的属性会是å„ç§ç±»åž‹çš„æ•°æ®å…ƒç´ **,OpenMMLab 算法库将基于 `BaseDataSample` 实现自己的抽象数æ®æŽ¥å£ï¼Œæ¥å°è£…一个算法库ä¸å•ä¸ªæ ·æœ¬çš„所有相关数æ®ï¼Œä½œä¸º dataset,model,visualizer,和 evaluator 组件之间的数æ®æŽ¥å£ã€‚ + +两ç§ç±»åž‹çš„å°è£…和他们的继承关系如下图所示 + + + +为了ä¿è¯æŠ½è±¡æ•°æ®æŽ¥å£å†…æ•°æ®çš„完整性,抽象数æ®æŽ¥å£å†…部有两ç§æ•°æ®ï¼Œé™¤äº†è¢«å°è£…çš„æ•°æ®ï¼ˆdata)本身,还有一ç§æ˜¯æ•°æ®çš„元信æ¯ï¼ˆmetainfo),例如图片大å°å’Œ ID ç‰ã€‚ +两ç§ç±»åž‹çš„抽象数æ®æŽ¥å£éƒ½å¯ä»¥ä½œä¸º Python 类去使用和æ“作他们的属性。åŒæ—¶ï¼Œå› 为他们å°è£…çš„æ•°æ®å¤§å¤šæ˜¯ Tensor,他们也æ供了类似 Tensor 的基础æ“作。 + +## 用法 + +### BaseDataElement + +MMEngine 为数æ®å…ƒç´ çš„å°è£…æ供了一个基类 `BaseDataElement`。 +基于 `BaseDataElement`,MMEngine 还实现了 `InstanceData`, `PixelData`, `LabelData` å’Œ `GeneralData` 四个典型的å类,å°è£…了实例级别,åƒç´ çº§åˆ«ï¼Œæ ‡ç¾çº§åˆ«å’Œå…¶ä»–普通的数æ®å…ƒç´ ,并针对他们的数æ®ç‰¹æ€§æ”¯æŒäº†ä¸€äº›é¢å¤–的功能。 + +1. `InstanceData`:å°è£…检测框ã€æ¡†å¯¹åº”çš„æ ‡ç¾å’Œå®žä¾‹æŽ©ç ã€ç”šè‡³å…³é”®ç‚¹ç‰å®žä¾‹çº§åˆ«æ•°æ®ï¼Œ`InstanceData` å‡å®šå®ƒå°è£…çš„æ•°æ®å…·æœ‰ç›¸åŒçš„长度 N,N 代表实例的个数,并基于æ¤å‡å®šå¯¹æ•°æ®è¿›è¡Œæ ¡éªŒã€æ”¯æŒå¯¹å®žä¾‹è¿›è¡Œç´¢å¼•å’Œæ‹¼æŽ¥ã€‚ +2. `PixelData`:å°è£…é€åƒç´ 级别的数æ®ï¼Œå¦‚è¯ä¹‰åˆ†å‰²å›¾å’Œæ·±åº¦å›¾ç‰ã€‚`PixelData` å‡å®šå®ƒå°è£…çš„æ•°æ®æœ‰ç›¸åŒçš„长度和宽度,第一和第二维为图片的长宽,第三维为通é“数。`PixelData` 基于æ¤å‡å®šå¯¹æ•°æ®è¿›è¡Œæ ¡éªŒã€æ”¯æŒå¯¹å®žä¾‹è¿›è¡Œç©ºé—´ç»´åº¦çš„索引和å„维度的拼接。 +3. `LabelData`:å°è£…æ ‡ç¾æ•°æ®ï¼Œå¦‚åœºæ™¯åˆ†ç±»æ ‡ç¾ç‰ã€‚ +4. `GeneralData`:`BaseDataElement` çš„ç‰ä»·ç±»ã€‚虽然 `BaseDataElement` å¯ä»¥ä½œä¸ºç‹¬ç«‹çš„模å—被使用,但是我们ä¸æŽ¨èç”¨æˆ·ç›´æŽ¥ä½¿ç”¨åŸºç±»ã€‚å› æ¤ï¼ŒMMEngine é¢å¤–实现了 `GeneralData` 。`GeneralData` ä¿æŒäº†å’Œ `InstanceData`, `PixelData`, ä»¥åŠ `LabelData` 一致的命åä¹ æƒ¯å’Œç»§æ‰¿å±‚æ¬¡ã€‚å®ƒæ‹¥æœ‰å’Œ `BaseDataElement` å®Œå…¨ä¸€æ ·çš„åŠŸèƒ½å’ŒæŽ¥å£ï¼Œå¯¹æ•°æ®å…ƒç´ 没有任何å‡å®šï¼Œä»…支æŒæœ€åŸºæœ¬çš„å¢žåˆ æ”¹æŸ¥åŠŸèƒ½ã€‚æˆ‘ä»¬æŽ¨è用户在实际应用过程ä¸ä½¿ç”¨ `GeneralData` è€Œéž `BaseDataElement` æ¥ä¿æŒä½¿ç”¨çš„一致性,在开å‘过程ä¸ç»§æ‰¿ `BaseDataElement` æ¥ä¿æŒç»§æ‰¿å±‚次的统一。在下文ä¸ï¼Œä¸ºäº†é˜æ˜Žæ•°æ®å…ƒç´ å°è£…的基本用法,我们还是使用 `BaseDataElement` æ¥è¿›è¡Œæ述和用例展示。 + +`BaseDataElement` ä¸å˜åœ¨ä¸¤ç§ç±»åž‹çš„æ•°æ®ï¼Œä¸€ç§æ˜¯ `data` ç±»åž‹ï¼Œå¦‚æ ‡æ³¨æ¡†ã€æ¡†çš„æ ‡ç¾ã€å’Œå®žä¾‹æŽ©ç ç‰ï¼›å¦ä¸€ç§æ˜¯ `metainfo` 类型,包å«æ•°æ®çš„元信æ¯ä»¥ç¡®ä¿æ•°æ®çš„完整性,如 `img_shape`, `img_id` ç‰æ•°æ®æ‰€åœ¨å›¾ç‰‡çš„一些基本信æ¯ï¼Œæ–¹ä¾¿å¯è§†åŒ–ç‰æƒ…况下对数æ®è¿›è¡Œæ¢å¤å’Œä½¿ç”¨ã€‚用户在创建 `BaseDataElement` 的过程ä¸éœ€è¦å¯¹è¿™ä¸¤ç±»å±žæ€§çš„æ•°æ®è¿›è¡Œæ˜¾å¼åœ°åŒºåˆ†å’Œå£°æ˜Žã€‚ + +#### 1. æ•°æ®å…ƒç´ 的创建 + +```python +# å¯ä»¥å£°æ˜Žä¸€ä¸ªç©ºçš„ object +gt_instances = BaseDataElement() + +bboxes = torch.rand((5, 4)) # å‡å®š bboxes 是一个 Nx4 ç»´çš„ tensor,N 代表框的个数 +scores = torch.rand((5,)) # å‡å®šæ¡†çš„分数是一个 N ç»´çš„ tensor,N 代表框的个数 +img_id = 0 # 图åƒçš„ ID +H = 800 # 图åƒçš„高度 +W = 1333 # 图åƒçš„宽度 + +# 显å¼å£°æ˜Ž BaseDataElement çš„å‚æ•° metainfo å’Œ data +gt_instances = BaseDataElement( + metainfo=dict(img_id=img_id, img_shape=(H, W)), + data=dict(bboxes=bboxes, scores=scores)) + +# ä¸æ˜¾å¼å£°æ˜Žçš„æ—¶å€™ï¼Œä¼ å…¥å—典将设置 BaseDataElement çš„å‚æ•° metainfo +gt_instances = BaseDataElement(dict(img_id=img_id, img_shape=(H, W))) +``` + +#### 2. `new` 函数 + +用户å¯ä»¥ä½¿ç”¨ `new()` 函数通过已有的数æ®æŽ¥å£åˆ›å»ºä¸€ä¸ªå…·æœ‰ç›¸åŒçŠ¶æ€å’Œæ•°æ®çš„抽象数æ®æŽ¥å£ã€‚用户å¯ä»¥åœ¨åˆ›å»ºæ–° `BaseDataElement` 时设置 metainfo å’Œ data,使得新的 BaseDataElement 有相åŒçš„状æ€ä½†æ˜¯ä¸åŒçš„æ•°æ®ã€‚ +也å¯ä»¥ç›´æŽ¥ä½¿ç”¨ `new()` æ¥èŽ·å¾—一份深拷è´ã€‚ + +```python +gt_instances = BaseDataElement() + +# å¯ä»¥åœ¨åˆ›å»ºæ–° `BaseDataElement` 时设置 metainfo å’Œ data,使得新的 BaseDataElement 有ä¸åŒçš„æ•°æ®ä½†æ˜¯æ•°æ®åœ¨ç›¸åŒçš„ device 上 +gt_instances1 = gt_instance.new( + metainfo=dict(img_id=1, img_shape=(640, 640)), + data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5,))) +) + +# 也å¯ä»¥å£°æ˜Žä¸€ä¸ªæ–°çš„ object,新的 object 会拥有和 gt_instance 相åŒçš„ data å’Œ metainfo 内容 +gt_instances2 = gt_instances1.new() +``` + +#### 3. å±žæ€§çš„å¢žåŠ ä¸ŽæŸ¥è¯¢ + +用户å¯ä»¥åƒå¢žåŠ ç±»å±žæ€§é‚£æ ·å¢žåŠ `BaseDataElement` 的属性,æ¤æ—¶æ•°æ®ä¼šè¢«**当作 data 类型**å¢žåŠ åˆ° `BaseDataElement` ä¸ã€‚ +如果需è¦å¢žåŠ metainfo 属性,用户应当使用 `set_metainfo`。 +用户å¯ä»¥é€šè¿‡ `metainfo_keys`,`metainfo_values`,和`metainfo_items` æ¥è®¿é—®åªå˜åœ¨äºŽ metainfo ä¸çš„键值, +也å¯ä»¥é€šè¿‡ `data_keys`,`data_values`,和 `data_items` æ¥è®¿é—®åªå˜åœ¨äºŽ data ä¸çš„键值。 +用户还能通过 `keys`,`values`, `items` æ¥è®¿é—® `BaseDataElement` 的所有的属性并且ä¸åŒºåˆ†ä»–们的类型。 + +**注æ„:** + +1. `BaseDataElement` ä¸æ”¯æŒ metainfo å’Œ data 属性ä¸æœ‰åŒåçš„å—段,所以用户应当é¿å… metainfo å’Œ data 属性ä¸è®¾ç½®ç›¸åŒçš„å—段,å¦åˆ™ `BaseDataElement` 会报错。 +2. 考虑到 `InstanceData` å’Œ `PixelData` 支æŒå¯¹æ•°æ®è¿›è¡Œåˆ‡ç‰‡æ“作,为了é¿å… `[]` 用法的ä¸ä¸€è‡´ï¼ŒåŒæ—¶å‡å°‘åŒç§éœ€æ±‚çš„ä¸åŒæ–¹æ³•ï¼Œ`BaseDataElement` ä¸æ”¯æŒåƒå—å…¸é‚£æ ·è®¿é—®å’Œè®¾ç½®å®ƒçš„å±žæ€§ï¼Œæ‰€ä»¥ç±»ä¼¼ `BaseDataElement[name]` çš„å–值赋值æ“作是ä¸è¢«æ”¯æŒçš„。 + +```python +gt_instances = BaseDataElement() +# 设置 gt_instances çš„ meta å—段,img_id å’Œ img_shape 会被作为 metainfo çš„å—段æˆä¸º gt_instances 的属性 +gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)) +assert 'img_shape' in gt_instaces.metainfo_keys() +# 'img_shape' 是 gt_instances 的属性 +assert 'img_shape' in gt_instaces +# img_shape ä¸æ˜¯ gt_instances çš„ data å—段 +assert 'img_shape' not in gt_instaces.data_keys() +# 通过 keys æ¥è®¿é—®æ‰€æœ‰å±žæ€§ +assert 'img_shape' in gt_instaces.keys() +# è®¿é—®ç±»å±žæ€§ä¸€æ ·è®¿é—® 'img_shape' +print(gt_instances.img_shape) + +# 直接设置 gt_instance çš„ scores 属性,默认该数æ®å±žäºŽ data +gt_instances.scores = torch.rand((5,)) +assert 'scores' in gt_instances.data_keys() +# 'scores' 是 gt_instances 的属性 +assert 'scores' in gt_instances +# 通过 keys æ¥è®¿é—®æ‰€æœ‰å±žæ€§ +assert 'scores' in gt_instances.keys() +# scores ä¸æ˜¯ gt_instances çš„ metainfo å—段 +assert 'scores' not in gt_instances.metainfo_keys() +# è®¿é—®ç±»å±žæ€§ä¸€æ ·è®¿é—® 'scores' +print(gt_instances.scores) + +# 设置 gt_instances çš„ data å—段 bboxes +gt_instances.bboxes = torch.rand((5, 4)) +assert 'bboxes' in gt_instances.data_keys() +# 'bboxes' 是 gt_instances 的属性 +assert 'bboxes' in gt_instances +# 通过 keys æ¥è®¿é—®æ‰€æœ‰å±žæ€§ +assert 'bboxes' in gt_instances.keys() +# bboxes ä¸æ˜¯ gt_instances çš„ metainfo å—段 +assert 'bboxes' not in gt_instances.metainfo_keys() +# è®¿é—®ç±»å±žæ€§ä¸€æ ·è®¿é—® 'bboxes' +print(gt_instances.bboxes) + +for k, v in gt_instances.items(): + print(f'{k}: {v}') # åŒ…å« img_shapes, img_id, bboxes,scores + +for k, v in gt_instances.metainfo_items(): + print(f'{k}: {v}') # åŒ…å« img_shapes, img_id + +for k, v in gt_instances.data_items(): + print(f'{k}: {v}') # åŒ…å« bboxes,scores +``` + +#### 4. å±žæ€§çš„åˆ æ”¹ + +`BaseDataElement` 支æŒç”¨æˆ·å¯ä»¥åƒä½¿ç”¨ä¸€ä¸ªç±»ä¸€æ ·å¯¹å®ƒçš„å±žæ€§è¿›è¡Œåˆ æ”¹ +åŒæ—¶ï¼Œ `BaseDataElement` æ”¯æŒ `get` æ¥å…许在访问ä¸åˆ°å˜é‡æ—¶è®¾ç½®é»˜è®¤å€¼ï¼Œä¹Ÿæ”¯æŒ `pop` 在在访问属性åŽåˆ 除属性。 + +```python +gt_instances = BaseDataElement( + metainfo=dict(img_id=0, img_shape=(640, 640)), + data=dict(bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))) + +# 对类的属性进行修改 +gt_instances.img_shape = (1280, 1280) +gt_instances.img_shape # (1280, 1280) +gt_instances.bboxes = gt_instances.bboxes * 2 + +# æ供了å¯è®¾ç½®é»˜è®¤å€¼çš„获å–æ–¹å¼ get +gt_instances.get('img_shape', None) # (640, 640) +gt_instances.get('bboxes', None) # 6x4 tensor + +# å±žæ€§çš„åˆ é™¤ +del gt_instances.img_shape +del gt_instances.bboxes +assert 'img_shape' in gt_instances +assert 'bboxes' not in gt_instances + +# æ供了便æ·çš„å±žæ€§åˆ é™¤å’Œè®¿é—®æ“作 pop +gt_instances.pop('img_shape', None) # None +gt_instances.pop('bboxes', None) # None +``` + +#### 5. ç±»å¼ é‡æ“作 + +用户å¯ä»¥åƒ torch.Tensor é‚£æ ·å¯¹ `BaseDataElement` çš„ data 进行状æ€è½¬æ¢ï¼Œç›®å‰æ”¯æŒ `cuda`, `cpu`, `to`, `numpy` ç‰æ“作。 +å…¶ä¸ï¼Œ`to` 函数拥有和 `torch.Tensor.to()` 相åŒçš„接å£ï¼Œä½¿å¾—用户å¯ä»¥çµæ´»åœ°å°†è¢«å°è£…çš„ tensor 进行状æ€è½¬æ¢ã€‚ + +```python +# 将所有 data 转移到 GPU 上 +cuda_instances = gt_instances.cuda() +cuda_instances = gt_instancess.to('cuda:0') + +# 将所有 data 转移到 cpu 上 +cpu_instances = cuda_instances.cpu() +cpu_instances = cuda_instances.to('cpu') + +# 将所有 data å˜æˆ FP16 +fp16_instances = cuda_instances.to( + device=None, dtype=torch.float16, non_blocking=False, copy=False, + memory_format=torch.preserve_format) + +# 阻æ–所有 data 的梯度 +cpu_instances = cuda_instances.detach() + +# 转移 data 到 numpy array +np_instances = cpu_instances.numpy() +``` + +#### 6. 属性的展示 + +`BaseDataElement` 还实现了 `__nice__` å’Œ `__repr__`ï¼Œå› æ¤ï¼Œç”¨æˆ·å¯ä»¥ç›´æŽ¥é€šè¿‡ `print` 函数看到其ä¸çš„所有数æ®ä¿¡æ¯ã€‚ +åŒæ—¶ï¼Œä¸ºäº†ä¾¿æ·å¼€å‘者 debug,`BaseDataElement` ä¸çš„å±žæ€§éƒ½ä¼šæ·»åŠ è¿› `__dict__` ä¸ï¼Œæ–¹ä¾¿ç”¨æˆ·åœ¨ IDE ç•Œé¢å¯ä»¥ç›´è§‚看到 `BaseDataElement` ä¸çš„内容。 +一个完整的属性展示如下 + +```python +>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) +>>> instance_data = BaseDataElement(metainfo=img_meta) +>>> instance_data.det_labels = torch.LongTensor([0, 1, 2, 3]) +>>> instance_data.det_scores = torch.Tensor([0.01, 0.1, 0.2, 0.3]) +>>> print(results) +<BaseDataElement( + META INFORMATION +img_shape: (800, 1196, 3) +pad_shape: (800, 1216, 3) + DATA FIELDS +shape of det_labels: torch.Size([4]) +shape of det_scores: torch.Size([4]) +) at 0x7f84acd10f90> +``` + +### BaseDataSample + +MMEngine ä¸ºæ ·æœ¬æ•°æ®çš„å°è£…æ供了一个基类 `BaseDataSample`,OpenMMLab çš„æ¯ä¸ªç®—法库都应该继承 `BaseDataSample` å®žçŽ°è‡ªå·±çš„æ ·æœ¬æ•°æ®å°è£…ï¼Œå¹¶è§„çº¦å’Œæ ¡éªŒè¯¥ç®—æ³•åº“ä¸çš„常è§å—æ®µã€‚ç®—æ³•åº“è‡ªå·±å®žçŽ°çš„æ ·æœ¬æ•°æ®å°è£…会作为该算法库内 dataset,visualizer,evaluator,model 组件之间的数æ®æŽ¥å£è¿›è¡Œæµé€šã€‚ +`BaseDataSample` 虽然å¯ä»¥ä½œä¸ºä¸€ä¸ªæ¨¡å—被å•ç‹¬ä½¿ç”¨ï¼Œä½†æ˜¯æˆ‘们ä¸æŽ¨è `BaseDataSample` è¿™ç§ç”¨æ³•ã€‚ + +`BaseDataSample` 内部ä¾ç„¶åŒºåˆ† metainfo å’Œ data,并且支æŒåƒç±»ä¸€æ ·å¯¹å…¶å±žæ€§è¿›è¡Œè®¾ç½®å’Œè°ƒæ•´ï¼Œä¸ºäº†ä¿è¯ç”¨æˆ·ä½“验的一致性,`BaseDataSample` 的外部接å£ç”¨æ³•å’Œ `BaseDataElement` ä¿æŒä¸€è‡´ã€‚ + +åŒæ—¶ï¼Œç”±äºŽ `BaseDataSample` 作为基类一般ä¸ä¼šç›´æŽ¥ä½¿ç”¨ï¼Œä¸ºäº†æ–¹ä¾¿ä¸‹æ¸¸ç®—法库快速定义其å类,并对åç±»çš„å±žæ€§è¿›è¡Œè§„çº¦å’Œæ ¡éªŒã€‚ +`BaseDataSample` é¢å¤–æä¾›äº†ä¸€å¥—å†…éƒ¨æŽ¥å£ `_get_field`, `_del_field` å’Œ `_set_field` æ¥ä¾¿åˆ©å®ƒçš„å类快æ·åœ°å®šä¹‰å’Œè§„约 data å±žæ€§çš„å¢žåˆ æ”¹æŸ¥ã€‚ +`_set_field` ä¸ä¼šè¢«å½“作外部接å£ç›´æŽ¥ä½¿ç”¨ï¼Œè€Œæ˜¯è¢«ç”¨æ¥å®šä¹‰å±žæ€§ï¼ˆproperty) çš„ `setter` 并æä¾›åŸºæœ¬çš„ç±»åž‹æ ¡éªŒã€‚ + +一个简å•ç²—略的实现和用例如下。 + +```python +from abc import ABC +from functools import partial + + +class BaseDataSample(ABC): + + def __init__(self, metainfo=dict(), data=dict()): + self._data_fields = set() + self._metainfo_fields = set() + + # 其他功能实现 + ... + + def _get_field(self, name): + return getattr(self, name) + + def _set_field(self, val, name, dtype): + assert isinstance(val, dtype) + super().__setattr__(name, val) + self._data_fields.add(name) + + def _del_field(self, name): + super().__delattr__(name) + self._data_fields.remove(name) + +``` + +基于 `BaseDataSample`,下游算法库å¯ä»¥å®šä¹‰ `DetDataSample`,并且使用 `BaseDataSample` ä¸çš„接å£ï¼Œå¿«é€Ÿå®šä¹‰ 3 个 property:proposals,gt_instances,pred_instances,并约æŸä»–们的类型。 + +```python +class DetDataSample(BaseDataSample): + + proposals = property( + # 定义了 get 方法,通过 name '_proposals' æ¥è®¿é—®å®žé™…维护的å˜é‡ + fget=partial(BaseDataSample._get_field, name='_proposals'), + # 定义了 set 方法,将实际维护的å˜é‡è®¾ç½®ä¸º '_proposals',并在设置的时候检查类型是å¦æ˜¯ dtype 定义的类型 InstanceData + fset=partial(BaseDataSample._set_field, name='_proposals', dtype=InstanceData), + fdel=partial(BaseDataSample._del_field, name='_proposals'), + doc='Region proposals of an image' + ) + + gt_instances = property( + fget=partial(BaseDataSample._get_field, name='_gt_instances'), + fset=partial(BaseDataSample._set_field, name='_gt_instances', dtype=InstanceData), + fdel=partial(BaseDataSample._del_field, name='_gt_instances'), + doc='Ground truth instances of an image' + ) + + pred_instances = property( + fget=partial(BaseDataSample._get_field, name='_pred_instances'), + fset=partial(BaseDataSample._set_field, name='_pred_instances', dtype=InstanceData), + fdel=partial(BaseDataSample._del_field, name='_pred_instances'), + doc='Predicted instances of an image' + ) +``` + +`DetDataSample` 的用法如下所示,在数æ®ç±»åž‹ä¸ç¬¦åˆè¦æ±‚的时候(例如用 `torch.Tensor` è€Œéž `InstanceData` 定义 proposals 时) ,`DetDataSample` 就会报错。 + +```python +a = DetDataSample() + +a.proposals = InstanceData(data=dict(bboxes=torch.rand((5,4)))) + +assert 'proposals' in a +print(a.proposals) + +del a.proposals +assert 'proposals' not in a +``` + +### 对接å£çš„简化 + +下é¢ä»¥ MMDetection 为例更具体地说明 OpenMMLab 的算法库将如何è¿ç§»ä½¿ç”¨æŠ½è±¡æ•°æ®æŽ¥å£ï¼Œä»¥ç®€åŒ–模å—和组件接å£çš„。我们å‡å®š MMDetection å’Œ MMEngine ä¸å®žçŽ°äº† DetDataSample å’Œ InstanceData。 + +#### 1. 组件接å£çš„简化 + +检测器的外部接å£å¯ä»¥å¾—到显著的简化和统一。MMDet 2.X ä¸å•é˜¶æ®µæ£€æµ‹å™¨å’Œå•é˜¶æ®µåˆ†å‰²ç®—法的接å£å¦‚下。在è®ç»ƒè¿‡ç¨‹ä¸ï¼Œ`SingleStageDetector` 需è¦èŽ·å– +`img`, `img_metas`, `gt_bboxes`, `gt_labels`, `gt_bboxes_ignore` 作为输入,但是 `SingleStageInstanceSegmentor` è¿˜éœ€è¦ `gt_masks`,导致 detector çš„è®ç»ƒæŽ¥å£ä¸ä¸€è‡´ï¼Œå½±å“了代ç çš„çµæ´»æ€§ã€‚ + +```python + +class SingleStageDetector(BaseDetector): + ... + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None): + + +class SingleStageInstanceSegmentor(BaseDetector): + ... + + def forward_train(self, + img, + img_metas, + gt_masks, + gt_labels, + gt_bboxes=None, + gt_bboxes_ignore=None, + **kwargs): +``` + +在 MMDet 3.0 ä¸ï¼Œæ‰€æœ‰æ£€æµ‹å™¨çš„è®ç»ƒæŽ¥å£éƒ½å¯ä»¥ä½¿ç”¨ DetDataSample 统一简化为 `img` å’Œ `data_samples`,ä¸åŒæ¨¡å—å¯ä»¥æ ¹æ®éœ€è¦åŽ»è®¿é—® `data_samples` å°è£…çš„å„ç§æ‰€éœ€è¦çš„属性。 + +```python +class SingleStageDetector(BaseDetector): + ... + + def forward_train(self, + img, + data_samples): + +class SingleStageInstanceSegmentor(BaseDetector): + ... + + def forward_train(self, + img, + data_samples): + +``` + +#### 2. 模å—接å£çš„简化 + +MMDet 2.X ä¸ `HungarianAssigner` å’Œ `MaskHungarianAssigner` 分别用于在è®ç»ƒè¿‡ç¨‹ä¸å°†æ£€æµ‹æ¡†å’Œå®žä¾‹æŽ©ç å’Œæ ‡æ³¨çš„å®žä¾‹è¿›è¡ŒåŒ¹é…。他们内部的匹é…é€»è¾‘å®žçŽ°æ˜¯ä¸€æ ·çš„ï¼Œåªæ˜¯æŽ¥å£å’ŒæŸå¤±å‡½æ•°çš„计算ä¸åŒã€‚ +但是,接å£çš„ä¸åŒä½¿å¾— `HungarianAssigner` ä¸çš„代ç æ— æ³•è¢«å¤ç”¨ï¼Œ`MaskHungarianAssigner` ä¸é‡å†™äº†å¾ˆå¤šå†—余的逻辑。 + +```python +class HungarianAssigner(BaseAssigner): + + def assign(self, + bbox_pred, + cls_pred, + gt_bboxes, + gt_labels, + img_meta, + gt_bboxes_ignore=None, + eps=1e-7): + +class MaskHungarianAssigner(BaseAssigner): + + def assign(self, + cls_pred, + mask_pred, + gt_labels, + gt_mask, + img_meta, + gt_bboxes_ignore=None, + eps=1e-7): +``` + +`InstanceData` å¯ä»¥å°è£…实例的框ã€åˆ†æ•°ã€å’ŒæŽ©ç ,将 `HungarianAssigner` çš„æ ¸å¿ƒå‚æ•°ç®€åŒ–æˆ `pred_instances`,`gt_instancess`,和 `gt_instances_ignore` +使得 `HungarianAssigner` å’Œ `MaskHungarianAssigner` å¯ä»¥åˆå¹¶æˆä¸€ä¸ªé€šç”¨çš„ `HungarianAssigner`。 + +```python +class HungarianAssigner(BaseAssigner): + + def assign(self, + pred_instances, + gt_instancess, + gt_instances_ignore=None, + eps=1e-7): +``` + +## 命å规约 + +为了ä¿æŒä¸åŒä»»åŠ¡æ•°æ®ä¹‹é—´çš„兼容性和统一性,我们建议抽象数æ®æŽ¥å£ä¸å¯¹ç›¸åŒçš„æ•°æ®ä½¿ç”¨ç»Ÿä¸€çš„å—段命å。 +在本文档ä¸ï¼Œæˆ‘们暂时性地在下文列举一些算法方å‘çš„æ ·æœ¬æ•°æ®å°è£…åŠå…¶å±žæ€§çº¦å®šï¼ŒåŽç»ä¼šæœ‰æ›´å…¨é¢çš„文档æ¥æ述命å规约。 +用户在使用å„算法库抽象接å£çš„过程ä¸ï¼Œå¯ä»¥å‡å®šå¯¹åº”çš„æ•°æ®ï¼ˆå¦‚æœ‰ï¼‰åœ¨æ ·æœ¬æ•°æ®å°è£…ä¸æ˜¯æŒ‰ç…§å¦‚下约定进行命å的。 + +### ClsDataSample + +- gt_label (LabelData): æ•°æ®çš„åˆ†ç±»æ ‡ç¾ +- pred_label (LabelData): 模型对数æ®çš„分类预测结果 + +### DetDataSample + +- pred_instances (InstanceData): 模型预测的实例 +- gt_instances (InstanceData): æ ‡æ³¨çš„å®žä¾‹ +- gt_sem_seg (PixelData): è¯ä¹‰åˆ†å‰²çš„æ ‡æ³¨ +- pred_sem_seg (PixelData): è¯ä¹‰åˆ†å‰²ä»»åŠ¡çš„模型预测 +- gt_panoptic_seg (PixelData): å…¨æ™¯åˆ†å‰²çš„æ ‡æ³¨ +- pred_panoptic_seg (PixelData): 全景分割任务的模型预测 +- proposals (InstanceData): 用于åŒé˜¶æ®µæ£€æµ‹å™¨çš„候选框æå +- ignored_instances (InstanceData): 在è®ç»ƒä¸åº”当被忽视的实例 + +### SegDataSample + +- gt_sem_seg (PixelData): è¯ä¹‰åˆ†å‰²çš„æ ‡æ³¨ +- pred_sem_seg (PixelData): è¯ä¹‰åˆ†å‰²ä»»åŠ¡çš„模型预测