From dcab0f5055783dada916b365ec95151c03ed5a46 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Mon, 29 Aug 2022 14:47:31 +0800
Subject: [PATCH] [Docs] Add resume training examples (#407)

* [Docs] Add resume training examples

* refine

* rename filename

* minor refinement

* fix comments

* resolve comments

* resolve comments
---
 docs/zh_cn/examples/resume_training.md | 36 ++++++++++++++++++++++++++
 docs/zh_cn/index.rst                   |  1 +
 2 files changed, 37 insertions(+)
 create mode 100644 docs/zh_cn/examples/resume_training.md

diff --git a/docs/zh_cn/examples/resume_training.md b/docs/zh_cn/examples/resume_training.md
new file mode 100644
index 00000000..a7382597
--- /dev/null
+++ b/docs/zh_cn/examples/resume_training.md
@@ -0,0 +1,36 @@
+# Resume Training
+
+Resuming training means continuing from the state saved by a previous training run. That state includes the model weights and the states of the optimizer and the optimizer's parameter scheduler.
+
+## Automatically Resume Training
+
+Users can enable automatic resumption through the `resume` parameter of `Runner`. When training is launched with `resume` set to `True`, `Runner` loads the latest checkpoint from `work_dir`. If `work_dir` contains a latest checkpoint (for example, because the previous run was interrupted), training resumes from that checkpoint. Otherwise (for example, because the previous run stopped before saving a checkpoint, or because this is a new training job), training starts from scratch. The following example enables automatic resumption:
+
+```python
+runner = Runner(
+    model=ResNet18(),
+    work_dir='./work_dir',
+    train_dataloader=train_dataloader_cfg,
+    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
+    train_cfg=dict(by_epoch=True, max_epochs=3),
+    resume=True,  # resume from the latest checkpoint in work_dir, if one exists
+)
+runner.train()
+```
+
+## Specify the Checkpoint Path
+
+To resume from a specific checkpoint, set the `load_from` parameter in addition to `resume=True`. Note that if only `load_from` is set and `resume=True` is not, `Runner` loads only the weights from the checkpoint and restarts training instead of continuing from the previous state.
+
+```python
+runner = Runner(
+    model=ResNet18(),
+    work_dir='./work_dir',
+    train_dataloader=train_dataloader_cfg,
+    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
+    train_cfg=dict(by_epoch=True, max_epochs=3),
+    load_from='./work_dir/epoch_2.pth',  # checkpoint to resume from
+    resume=True,  # without this, only the weights are loaded and training restarts
+)
+runner.train()
+```
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index 49bb92f8..b42652b5 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -29,6 +29,7 @@
    :maxdepth: 1
    :caption: Examples
 
+   examples/resume_training.md
    examples/speed_up_training.md
 
 .. toctree::
-- 
GitLab