diff --git a/docs/zh_cn/tutorials/runner.md b/docs/zh_cn/tutorials/runner.md index c7ccea098ca38340b565958f27eb01b13b0c8791..45c696ca51d221955a7a8f4301f87c0f00db7f11 100644 --- a/docs/zh_cn/tutorials/runner.md +++ b/docs/zh_cn/tutorials/runner.md @@ -91,17 +91,18 @@ runner.train() model = FasterRCNN() test_dataset = CocoDataset() test_dataloader = Dataloader(dataset=test_dataset, batch_size=2, num_workers=2) -evaluator = CocoEvaluator(metric='bbox') +metric = CocoMetric() +test_evaluator = Evaluator(metric) # 初始化执行器 -runner = Runner(model=model, test_dataloader=test_dataloader, evaluator=evaluator, - load_checkpoint='./faster_rcnn.pth') +runner = Runner(model=model, test_dataloader=test_dataloader, test_evaluator=test_evaluator, + load_from='./faster_rcnn.pth') # 执行测试 runner.test() ``` -这个例子中我们手动构建了一个 Faster R-CNN 检测模型，以及测试用的 COCO 数据集和对应的 COCO 评测器，并使用这些模块初始化执行器，最后通过调用执行器的 `test` 函数进行模型测试。 +这个例子中我们手动构建了一个 Faster R-CNN 检测模型，以及测试用的 COCO 数据集和使用 COCO 指标的评测器，并使用这些模块初始化执行器，最后通过调用执行器的 `test` 函数进行模型测试。 ### 通过配置文件使用执行器 @@ -146,12 +147,13 @@ test_dataloader = ... 
optimizer = dict(type='SGD', lr=0.01) # 参数调度器配置 param_scheduler = dict(type='MultiStepLR', milestones=[80, 90]) -#评测器配置 -evaluator = dict(type='Accuracy') +#验证和测试的评测器配置 +val_evaluator = dict(type='Accuracy') +test_evaluator = dict(type='Accuracy') # 训练、验证、测试流程配置 train_cfg = dict(by_epoch=True, max_epochs=100) -validation_cfg = dict(interval=1) # 每隔一个 epoch 进行一次验证 +val_cfg = dict(interval=1) # 每隔一个 epoch 进行一次验证 test_cfg = dict() # 自定义钩子 @@ -163,20 +165,40 @@ default_hooks = dict( checkpoint=dict(type='CheckpointHook', interval=1), # 模型保存钩子 logger=dict(type='TextLoggerHook'), # 训练日志钩子 optimizer=dict(type='OptimzierHook', grad_clip=False), # 优化器钩子 - param_scheduler=dict(type='ParamSchedulerHook')) # 参数调度器执行钩子 + param_scheduler=dict(type='ParamSchedulerHook'), # 参数调度器执行钩子 + sampler_seed=dict(type='DistSamplerSeedHook')) # 为每轮次的数据采样设置随机种子的钩子 # 环境配置 env_cfg = dict( - dist_params=dict(backend='nccl'), + cudnn_benchmark=False, + dist_cfg=dict(backend='nccl'), mp_cfg=dict(mp_start_method='fork') ) -# 系统日志配置 -log_cfg = dict(log_level='INFO') +# 日志等级配置 +log_level = 'INFO' + +# 加载权重 +load_from = None +# 恢复训练 +resume = False ``` 一个完整的配置文件主要由模型、数据、优化器、参数调度器、评测器等模块的配置，训练、验证、测试等流程的配置，还有执行流程过程中的各种钩子模块的配置，以及环境和日志等其他配置的字段组成。 通过配置文件构建的执行器采用了懒初始化 (lazy initialization)，只有当调用到训练或测试等执行函数时，才会根据配置文件去完整初始化所需要的模块。 +## 加载权重或恢复训练 + +执行器可以通过 `load_from` 参数加载检查点（checkpoint）文件中的模型权重，只需要将 `load_from` 参数设置为检查点文件的路径即可。 + +```python +runner = Runner(model=model, test_dataloader=test_dataloader, test_evaluator=test_evaluator, + load_from='./faster_rcnn.pth') +``` + +如果是通过配置文件使用执行器，只需修改配置文件中的 `load_from` 字段即可。 + +用户也可通过设置 `resume=True` 加
载检查点中的训练状态信息来恢复训练。当 `load_from` 和 `resume=True` 同时被设置时，执行器将加载 `load_from` 路径对应的检查点文件中的训练状态。如果仅设置 `resume=True`，执行器将会尝试从 `work_dir` 文件夹中寻找并读取最新的检查点文件。 + ## 进阶使用 MMEngine 中的默认执行器能够完成大部分的深度学习任务，但不可避免会存在无法满足的情况。有的用户希望能够对执行器进行更多自定义修改，因此，MMEngine 支持自定义模型的训练、验证以及测试的流程。 @@ -195,48 +217,68 @@ MMEngine 内提供了四种默认的循环： 用户可以通过继承循环基类来实现自己的训练流程。循环基类需要提供两个输入：`runner` 执行器的实例和 `loader` 循环所需要迭代的迭代器。 用户如果有自定义的需求，也可以增加更多的输入参数。MMEngine 中同样提供了 LOOPS 注册器对循环类进行管理，用户可以向注册器内注册自定义的循环模块， -然后在配置文件的 `train_cfg`、`validation_cfg`、`test_cfg` 中增加 `type` 字段来指定使用何种循环。 +然后在配置文件的 `train_cfg`、`val_cfg`、`test_cfg` 中增加 `type` 字段来指定使用何种循环。 用户可以在自定义的循环中实现任意的执行逻辑，也可以增加或删减钩子（hook）点位，但需要注意的是一旦钩子点位被修改，默认的钩子函数可能不会被执行，导致一些训练过程中默认发生的行为发生变化。 因此，我们强烈建议用户按照本文档中定义的循环执行流程图以及[钩子规范](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/hook.html) 去重载循环基类。 ```python -from mmengine.registry import LOOPS +from mmengine.registry import LOOPS, HOOKS from mmengine.runner.loop import BaseLoop +from mmengine.hooks import Hook + +# 自定义验证循环 @LOOPS.register_module() class CustomValLoop(BaseLoop): - def __init__(self, runner, loader, evaluator, loader2): - super().__init__(runner, loader, evaluator) - self.loader2 = runner.build_dataloader(loader2) + def __init__(self, runner, dataloader, evaluator, dataloader2): + super().__init__(runner, dataloader, evaluator) + self.dataloader2 = runner.build_dataloader(dataloader2) def run(self): self.runner.call_hooks('before_val_epoch') - for idx, databatch in enumerate(self.loader): - 
self.runner.call_hooks('before_val_iter', - args=dict(databatch=databatch)) - outputs = self.run_iter(idx, databatch) - self.runner.call_hooks('after_val_iter', - args=dict(databatch=databatch, outputs=outputs)) + for idx, data_batch in enumerate(self.dataloader): + self.runner.call_hooks( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) + outputs = self.run_iter(idx, data_batch) + self.runner.call_hooks( + 'after_val_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs) metric = self.evaluator.evaluate() - for idx, databatch in enumerate(self.loader2): - self.runner.call_hooks('before_val_iter2', - args=dict(databatch=databatch)) - self.run_iter(idx, databatch) - self.runner.call_hooks('after_val_iter2', - args=dict(databatch=databatch, outputs=outputs)) + + # 增加额外的验证循环 + for idx, data_batch in enumerate(self.dataloader2): + # 增加额外的钩子点位 + self.runner.call_hooks( + 'before_valloader2_iter', batch_idx=idx, data_batch=data_batch) + self.run_iter(idx, data_batch) + # 增加额外的钩子点位 + self.runner.call_hooks( + 'after_valloader2_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs) metric2 = self.evaluator.evaluate() ... self.runner.call_hooks('after_val_epoch') + +# 定义额外点位的钩子类 +@HOOKS.register_module() +class CustomValHook(Hook): + def before_valloader2_iter(self, batch_idx, data_batch): + ... + + def after_valloader2_iter(self, batch_idx, data_batch, outputs): + ... 
+ ``` 上面的例子中实现了一个与默认验证循环不一样的自定义验证循环，它在两个不同的验证集上进行验证，同时对第二次验证增加了额外的钩子点位，并在最后对两个验证结果进行进一步的处理。在实现了自定义的循环类之后， -只需要在配置文件的 `validation_cfg` 内设置 `type='CustomValLoop'`，并添加额外的配置即可。 +只需要在配置文件的 `val_cfg` 内设置 `type='CustomValLoop'`，并添加额外的配置即可。 ```python -validation_cfg = dict(type='CustomValLoop', loader2=dict(dataset=dict(type='ValDataset2'), ...)) +# 自定义验证循环 +val_cfg = dict(type='CustomValLoop', dataloader2=dict(dataset=dict(type='ValDataset2'), ...)) +# 额外点位的钩子 +custom_hooks = [dict(type='CustomValHook')] ``` ### 自定义执行器