diff --git a/docs/zh_cn/get_started/15_minutes.md b/docs/zh_cn/get_started/15_minutes.md new file mode 100644 index 0000000000000000000000000000000000000000..dad771631ea36e10bafb9935111d681dc48421b0 --- /dev/null +++ b/docs/zh_cn/get_started/15_minutes.md @@ -0,0 +1,242 @@ +# 15 分钟上手 MMEngine + +以在 CIFAR-10 æ•°æ®é›†ä¸Šè®ç»ƒä¸€ä¸ª ResNet-50 模型为例,我们将使用 80 行以内的代ç ,利用 MMEngine 构建一个完整的〠+å¯é…置的è®ç»ƒå’ŒéªŒè¯æµç¨‹ï¼Œæ•´ä¸ªæµç¨‹åŒ…å«å¦‚下æ¥éª¤ï¼š + +1. [构建模型](#构建模型) +2. [构建数æ®é›†å’Œæ•°æ®åŠ 载器](#构建数æ®é›†å’Œæ•°æ®åŠ 载器) +3. [æž„å»ºè¯„æµ‹æŒ‡æ ‡](#æž„å»ºè¯„æµ‹æŒ‡æ ‡) +4. [构建执行器并执行任务](#构建执行器并执行任务) + +## 构建模型 + +首先,我们需è¦æž„建一个**模型**,在 MMEngine ä¸ï¼Œæˆ‘们约定这个模型应当继承 `BaseModel`,并且其 `forward` 方法除了接å—æ¥è‡ªæ•°æ®é›†çš„若干å‚数外, +还需è¦æŽ¥å—é¢å¤–çš„å‚æ•° `mode`:对于è®ç»ƒï¼Œæˆ‘ä»¬éœ€è¦ `mode` 接å—å—符串 "loss"ï¼Œå¹¶è¿”å›žä¸€ä¸ªåŒ…å« "loss" å—段的å—典; +对于验è¯ï¼Œæˆ‘ä»¬éœ€è¦ `mode` 接å—å—符串 "predict",并返回åŒæ—¶åŒ…å«é¢„测信æ¯å’ŒçœŸå®žä¿¡æ¯çš„结果。 + +```python +import torch.nn.functional as F +import torchvision +from mmengine.model import BaseModel + + +class MMResNet50(BaseModel): + def __init__(self): + super().__init__() + self.resnet = torchvision.models.resnet50() + + def forward(self, imgs, labels, mode): + x = self.resnet(imgs) + if mode == 'loss': + return {'loss': F.cross_entropy(x, labels)} + elif mode == 'predict': + return x, labels +``` + +## 构建数æ®é›†å’Œæ•°æ®åŠ 载器 + +其次,我们需è¦æž„建è®ç»ƒå’ŒéªŒè¯æ‰€éœ€è¦çš„**æ•°æ®é›† (Dataset)**å’Œ**æ•°æ®åŠ 载器 (DataLoader)**。 +对于基础的è®ç»ƒå’ŒéªŒè¯åŠŸèƒ½ï¼Œæˆ‘们å¯ä»¥ç›´æŽ¥ä½¿ç”¨ç¬¦åˆ PyTorch æ ‡å‡†çš„æ•°æ®åŠ 载器和数æ®é›†ã€‚ + +```python +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201]) +train_dataloader = DataLoader(batch_size=32, + shuffle=True, + dataset=torchvision.datasets.CIFAR10( + 'data/cifar10', + train=True, + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(**norm_cfg) + ]))) + +val_dataloader = DataLoader(batch_size=32, + shuffle=False, + dataset=torchvision.datasets.CIFAR10( + 'data/cifar10', + train=False, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(**norm_cfg) + ]))) +``` + +## æž„å»ºè¯„æµ‹æŒ‡æ ‡ + +为了进行验è¯å’Œæµ‹è¯•ï¼Œæˆ‘们需è¦å®šä¹‰æ¨¡åž‹æŽ¨ç†ç»“果的**è¯„æµ‹æŒ‡æ ‡**ã€‚æˆ‘ä»¬çº¦å®šè¿™ä¸€è¯„æµ‹æŒ‡æ ‡éœ€è¦ç»§æ‰¿ `BaseMetric`, +并实现 `process` å’Œ `compute_metrics` æ–¹æ³•ã€‚å…¶ä¸ `process` 方法接å—æ•°æ®é›†çš„输出和模型 `mode="predict"` +时的输出,æ¤æ—¶çš„æ•°æ®ä¸ºä¸€ä¸ªæ‰¹æ¬¡çš„æ•°æ®ï¼Œå¯¹è¿™ä¸€æ‰¹æ¬¡çš„æ•°æ®è¿›è¡Œå¤„ç†åŽï¼Œä¿å˜ä¿¡æ¯è‡³ `self.results` 属性。 +而 `compute_metrics` æŽ¥å— `results` å‚数,这一å‚数的输入为 `process` ä¸ä¿å˜çš„æ‰€æœ‰ä¿¡æ¯ +(如果是分布å¼çŽ¯å¢ƒï¼Œ`results` ä¸ä¸ºå·²æ”¶é›†çš„,包括å„个进程 `process` ä¿å˜ä¿¡æ¯çš„结果), +利用这些信æ¯è®¡ç®—并返回ä¿å˜æœ‰è¯„æµ‹æŒ‡æ ‡ç»“æžœçš„å—典。 + +```python +from mmengine.evaluator import BaseMetric + +class Accuracy(BaseMetric): + def process(self, data_batch, data_samples): + score, gt = data_samples + # 将一个批次的ä¸é—´ç»“æžœä¿å˜è‡³ `self.results` + self.results.append({ + 'batch_size': len(gt), + 'correct': (score.argmax(dim=1) == gt).sum().cpu(), + }) + + def compute_metrics(self, results): + total_correct = sum(item['correct'] for item in results) + total_size = sum(item['batch_size'] for item in results) + # 返回ä¿å˜æœ‰è¯„æµ‹æŒ‡æ ‡ç»“æžœçš„å—典,其ä¸é”®ä¸ºæŒ‡æ ‡å称 + return dict(accuracy=100 * total_correct / total_size) +``` + +## 构建执行器并执行任务 + +最åŽï¼Œæˆ‘们利用构建好的**模型**,**æ•°æ®åŠ 载器**,**è¯„æµ‹æŒ‡æ ‡**构建一个**执行器 (Runner)**,åŒæ—¶åœ¨å…¶ä¸é…ç½® +**优化器**ã€**工作路径**ã€**è®ç»ƒä¸ŽéªŒè¯é…ç½®**ç‰é€‰é¡¹ï¼Œå³å¯é€šè¿‡è°ƒç”¨ `train()` 接å£å¯åŠ¨è®ç»ƒï¼š + +```python +from torch.optim import SGD +from mmengine.runner import Runner + +runner = Runner( + # 用以è®ç»ƒå’ŒéªŒè¯çš„模型,需è¦æ»¡è¶³ç‰¹å®šçš„接å£éœ€æ±‚ + model=MMResNet50(), + # 工作路径,用以ä¿å˜è®ç»ƒæ—¥å¿—ã€æƒé‡æ–‡ä»¶ä¿¡æ¯ + work_dir='./work_dir', + # è®ç»ƒæ•°æ®åŠ 载器,需è¦æ»¡è¶³ PyTorch æ•°æ®åŠ 载器åè®® + train_dataloader=train_dataloader, + # 优化器包装,用于模型优化,并æä¾› AMPã€æ¢¯åº¦ç´¯ç§¯ç‰é™„åŠ åŠŸèƒ½ + optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)), + # è®ç»ƒé…置,用于指定è®ç»ƒå‘¨æœŸã€éªŒè¯é—´éš”ç‰ä¿¡æ¯ + train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1), + # 验è¯æ•°æ®åŠ 载器,需è¦æ»¡è¶³ PyTorch æ•°æ®åŠ 载器åè®® + val_dataloader=val_dataloader, + # 验è¯é…置,用于指定验è¯æ‰€éœ€è¦çš„é¢å¤–å‚æ•° + val_cfg=dict(), + # 用于验è¯çš„è¯„æµ‹å™¨ï¼Œè¿™é‡Œä½¿ç”¨é»˜è®¤è¯„æµ‹å™¨ï¼Œå¹¶è¯„æµ‹æŒ‡æ ‡ + val_evaluator=dict(type=Accuracy), +) + +runner.train() +``` + +最åŽï¼Œè®©æˆ‘们把以上部分汇总æˆä¸ºä¸€ä¸ªå®Œæ•´çš„,利用 MMEngine 执行器进行è®ç»ƒå’ŒéªŒè¯çš„脚本: + +<a href="https://colab.research.google.com/github/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/get_started.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="在 Colab ä¸æ‰“å¼€"/></a> + +```python +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch.optim import SGD +from torch.utils.data import DataLoader + +from mmengine.evaluator import BaseMetric +from mmengine.model import BaseModel +from mmengine.runner import Runner + + +class MMResNet50(BaseModel): + def __init__(self): + super().__init__() + self.resnet = torchvision.models.resnet50() + + def forward(self, imgs, labels, mode): + x = self.resnet(imgs) + if mode == 'loss': + return {'loss': F.cross_entropy(x, labels)} + elif mode == 'predict': + return x, labels + + +class Accuracy(BaseMetric): + def process(self, data_batch, data_samples): + score, gt = data_samples + self.results.append({ + 'batch_size': len(gt), + 'correct': (score.argmax(dim=1) == gt).sum().cpu(), + }) + + def compute_metrics(self, results): + total_correct = sum(item['correct'] for item in results) + total_size = sum(item['batch_size'] for item in results) + return dict(accuracy=100 * total_correct / total_size) + + +norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201]) +train_dataloader = DataLoader(batch_size=32, + shuffle=True, + dataset=torchvision.datasets.CIFAR10( + 'data/cifar10', + train=True, + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(**norm_cfg) + ]))) + +val_dataloader = DataLoader(batch_size=32, + shuffle=False, + dataset=torchvision.datasets.CIFAR10( + 'data/cifar10', + train=False, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(**norm_cfg) + ]))) + +runner = Runner( + model=MMResNet50(), + work_dir='./work_dir', + train_dataloader=train_dataloader, + optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)), + train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1), + val_dataloader=val_dataloader, + val_cfg=dict(), + val_evaluator=dict(type=Accuracy), +) +runner.train() +``` + +输出的è®ç»ƒæ—¥å¿—如下: + +``` +2022/08/22 15:51:53 - mmengine - INFO - +------------------------------------------------------------ +System environment: + sys.platform: linux + Python: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] + CUDA available: True + numpy_random_seed: 1513128759 + GPU 0: NVIDIA GeForce GTX 1660 SUPER + CUDA_HOME: /usr/local/cuda +... + +2022/08/22 15:51:54 - mmengine - INFO - Checkpoints will be saved to /home/mazerun/work_dir by HardDiskBackend. +2022/08/22 15:51:56 - mmengine - INFO - Epoch(train) [1][10/1563] lr: 1.0000e-03 eta: 0:18:23 time: 0.1414 data_time: 0.0077 memory: 392 loss: 5.3465 +2022/08/22 15:51:56 - mmengine - INFO - Epoch(train) [1][20/1563] lr: 1.0000e-03 eta: 0:11:29 time: 0.0354 data_time: 0.0077 memory: 392 loss: 2.7734 +2022/08/22 15:51:56 - mmengine - INFO - Epoch(train) [1][30/1563] lr: 1.0000e-03 eta: 0:09:10 time: 0.0352 data_time: 0.0076 memory: 392 loss: 2.7789 +2022/08/22 15:51:57 - mmengine - INFO - Epoch(train) [1][40/1563] lr: 1.0000e-03 eta: 0:08:00 time: 0.0353 data_time: 0.0073 memory: 392 loss: 2.5725 +2022/08/22 15:51:57 - mmengine - INFO - Epoch(train) [1][50/1563] lr: 1.0000e-03 eta: 0:07:17 time: 0.0347 data_time: 0.0073 memory: 392 loss: 2.7382 +2022/08/22 15:51:57 - mmengine - INFO - Epoch(train) [1][60/1563] lr: 1.0000e-03 eta: 0:06:49 time: 0.0347 data_time: 0.0072 memory: 392 loss: 2.5956 +2022/08/22 15:51:58 - mmengine - INFO - Epoch(train) [1][70/1563] lr: 1.0000e-03 eta: 0:06:28 time: 0.0348 data_time: 0.0072 memory: 392 loss: 2.7351 +... +2022/08/22 15:52:50 - mmengine - INFO - Saving checkpoint at 1 epochs +2022/08/22 15:52:51 - mmengine - INFO - Epoch(val) [1][10/313] eta: 0:00:03 time: 0.0122 data_time: 0.0047 memory: 392 +2022/08/22 15:52:51 - mmengine - INFO - Epoch(val) [1][20/313] eta: 0:00:03 time: 0.0122 data_time: 0.0047 memory: 308 +2022/08/22 15:52:51 - mmengine - INFO - Epoch(val) [1][30/313] eta: 0:00:03 time: 0.0123 data_time: 0.0047 memory: 308 +... +2022/08/22 15:52:54 - mmengine - INFO - Epoch(val) [1][313/313] accuracy: 35.7000 +``` + +é™¤äº†ä»¥ä¸ŠåŸºç¡€ç»„ä»¶ï¼Œä½ è¿˜å¯ä»¥åˆ©ç”¨**执行器**è½»æ¾åœ°ç»„åˆé…ç½®å„ç§è®ç»ƒæŠ€å·§ï¼Œå¦‚å¼€å¯æ··åˆç²¾åº¦è®ç»ƒå’Œæ¢¯åº¦ç´¯ç§¯ï¼ˆè§ [优化器å°è£…(OptimWrapper)](../tutorials/optim_wrapper.md))ã€é…ç½®å¦ä¹ 率衰å‡æ›²çº¿ï¼ˆè§ [è¯„æµ‹æŒ‡æ ‡ä¸Žè¯„æµ‹å™¨ï¼ˆMetrics & Evaluator)](../tutorials/metric_and_evaluator.md))ç‰ã€‚ diff --git a/docs/zh_cn/tutorials/get_started.md b/docs/zh_cn/tutorials/get_started.md deleted file mode 100644 index 8e9ca6ec2dfa47b2dca92cd8d1a216230b0e9612..0000000000000000000000000000000000000000 --- a/docs/zh_cn/tutorials/get_started.md +++ /dev/null @@ -1,30 +0,0 @@ -# 使用 MMEngine æ¥è®ç»ƒæ¨¡åž‹ - -MMEngine 实现了 OpenMMLab 算法库的新一代è®ç»ƒæž¶æž„,为算法模型的è®ç»ƒã€æµ‹è¯•ã€æŽ¨ç†å’Œå¯è§†åŒ–定义了一套基类与接å£ã€‚ -å’Œ OpenMMLab 算法库的上一代è®ç»ƒæž¶æž„相比,它具有如下三个特点: - -- 统一:为ä¸åŒæ–¹å‘算法模型的è®ç»ƒã€æµ‹è¯•ã€æŽ¨ç†ã€å’Œå¯è§†åŒ–è¿‡ç¨‹è¿›è¡Œäº†æŠ½è±¡å¹¶å®šä¹‰äº†ä¸€å¥—ç»Ÿä¸€çš„æŽ¥å£ -- 清晰:å°è£…的层次与逻辑清晰简å•ï¼ŒæŠ½è±¡çš„定义与接å£æ›´åŠ 清晰,模å—çš„æ‹†åˆ†ä¸Žè¾¹ç•Œæ›´åŠ æ¸…æ™° -- çµæ´»ï¼šåœ¨ç»Ÿä¸€çš„基础框架内,模å—å¯ä»¥çµæ´»æ‹“展和æ’拔,支æŒå„类型算法和å¦ä¹ 范å¼ï¼ŒåŒ…æ‹¬å°‘æ ·æœ¬å’Œé›¶æ ·æœ¬å¦ä¹ ,自监ç£ã€åŠç›‘ç£ã€å’Œå¼±ç›‘ç£å¦ä¹ ,和模型的蒸é¦ã€å‰ªæžã€ä¸Žé‡åŒ–。 - -## 组件 - -MMEngine 将算法模型è®ç»ƒã€æŽ¨ç†ã€æµ‹è¯•å’Œå¯è§†åŒ–过程ä¸çš„å„ä¸ªç»„ä»¶è¿›è¡Œäº†æŠ½è±¡ï¼Œå®šä¹‰äº†å¦‚ä¸‹å‡ ä¸ªç»„ä»¶å’Œä»–ä»¬çš„ç›¸å…³æŽ¥å£ï¼Œè¿™äº›ç»„件的关系如下图所示: - - - -ä»¥ä¸‹æ ¹æ®ä¸Šå›¾ç®€è¿°è¿™äº›æ¨¡å—的功能与è”系,用户å¯ä»¥é€šè¿‡å„个组件的用户文档了解他们。 - -- [执行器(Runner)](./runner.md):负责执行è®ç»ƒã€æµ‹è¯•å’ŒæŽ¨ç†ä»»åŠ¡å¹¶ç®¡ç†è¿™äº›è¿‡ç¨‹ä¸æ‰€éœ€è¦çš„å„个组件。 -- [é’©å(Hook)](./hook.md):负责在è®ç»ƒã€æµ‹è¯•ã€æŽ¨ç†ä»»åŠ¡æ‰§è¡Œè¿‡ç¨‹ä¸çš„特定ä½ç½®æ‰§è¡Œè‡ªå®šä¹‰é€»è¾‘。 -- [æ•°æ®é›†ï¼ˆDataset)](./basedataset.md):负责在è®ç»ƒã€æµ‹è¯•ã€æŽ¨ç†ä»»åŠ¡ä¸æž„建数æ®é›†ï¼Œå¹¶å°†æ•°æ®é€ç»™æ¨¡åž‹ã€‚实际使用过程ä¸ä¼šè¢«æ•°æ®åŠ 载器(DataLoader)å°è£…一层,数æ®åŠ 载器会å¯åŠ¨å¤šä¸ªå进程æ¥åŠ 载数æ®ã€‚ -- [模型(Model)](./model.md):在è®ç»ƒè¿‡ç¨‹ä¸æŽ¥å—æ•°æ®ã€è¾“出 loss,在测试ã€æŽ¨ç†ä»»åŠ¡ä¸æŽ¥å—æ•°æ®ï¼Œå¹¶è¿›è¡Œé¢„测。分布å¼è®ç»ƒç‰æƒ…况下会被模型的å°è£…器(Model Wrapper,如 .`nn.DistributedDataParallel`)å°è£…一层。 -- [è¯„æµ‹æŒ‡æ ‡ä¸Žè¯„æµ‹å™¨ï¼ˆMetrics & Evaluator)](./metric_and_evaluator.md):评测器负责基于数æ®é›†å¯¹æ¨¡åž‹çš„é¢„æµ‹è¿›è¡Œè¯„ä¼°ã€‚è¯„æµ‹å™¨å†…è¿˜æœ‰ä¸€å±‚æŠ½è±¡æ˜¯è¯„æµ‹æŒ‡æ ‡ï¼Œè´Ÿè´£è®¡ç®—å…·ä½“çš„ä¸€ä¸ªæˆ–å¤šä¸ªè¯„æµ‹æŒ‡æ ‡ï¼ˆå¦‚å¬å›žçŽ‡ã€æ£ç¡®çŽ‡ç‰ï¼‰ã€‚ -- [æ•°æ®å…ƒç´ (Data Element)](./data_element.md):评测器,模型和数æ®ä¹‹é—´äº¤æµçš„接å£ä½¿ç”¨æ•°æ®å…ƒç´ 进行å°è£…。 -- [å‚数调度器(Parameter Scheduler)](./param_scheduler.md):è®ç»ƒè¿‡ç¨‹ä¸ï¼Œå¯¹å¦ä¹ 率ã€åŠ¨é‡ç‰å‚数进行动æ€è°ƒæ•´ã€‚ -- [优化器(Optimizer)](./optimizer_wrapper.md):优化器负责在è®ç»ƒè¿‡ç¨‹ä¸æ‰§è¡Œåå‘ä¼ æ’优化模型。实际使用过程ä¸ä¼šè¢«ä¼˜åŒ–器å°è£…(OptimWrapper)å°è£…ä¸€å±‚ï¼Œå®žçŽ°æ¢¯åº¦ç´¯åŠ ã€æ··åˆç²¾åº¦è®ç»ƒç‰åŠŸèƒ½ã€‚ -- [日志管ç†ï¼ˆLogging Modules)](./logging.md)ï¼šè´Ÿè´£ç®¡ç† Runner è¿è¡Œè¿‡ç¨‹ä¸äº§ç”Ÿçš„å„ç§æ—¥å¿—ä¿¡æ¯ã€‚å…¶ä¸æ¶ˆæ¯æž¢çº½ (MessageHub)负责实现组件与组件ã€æ‰§è¡Œå™¨ä¸Žæ‰§è¡Œå™¨ä¹‹é—´çš„æ•°æ®å…±äº«ï¼Œæ—¥å¿—处ç†å™¨ï¼ˆLog Processor)负责对日志信æ¯è¿›è¡Œå¤„ç†ï¼Œå¤„ç†åŽçš„日志会分别å‘é€ç»™æ‰§è¡Œå™¨çš„日志器(Logger)和å¯è§†åŒ–器(Visualizer)进行日志的管ç†ä¸Žå±•ç¤ºã€‚ -- [é…置类(Config)](./config.md):在 OpenMMLab 算法库ä¸ï¼Œç”¨æˆ·å¯ä»¥é€šè¿‡ç¼–写 config æ¥é…ç½®è®ç»ƒã€æµ‹è¯•è¿‡ç¨‹ä»¥åŠç›¸å…³çš„组件。 -- [注册器(Registry)](./registry.md):负责管ç†ç®—法库ä¸å…·æœ‰ç›¸åŒåŠŸèƒ½çš„模å—。MMEngine æ ¹æ®å¯¹ç®—法库模å—çš„æŠ½è±¡ï¼Œå®šä¹‰äº†ä¸€å¥—æ ¹æ³¨å†Œå™¨ï¼Œç®—æ³•åº“ä¸çš„注册器å¯ä»¥ç»§æ‰¿è‡ªè¿™å¥—æ ¹æ³¨å†Œå™¨ï¼Œå®žçŽ°æ¨¡å—的跨算法库调用。 -- [分布å¼é€šä¿¡åŽŸè¯ï¼ˆDistributed Communication Primitives)](./distributed.md):负责在程åºåˆ†å¸ƒå¼è¿è¡Œè¿‡ç¨‹ä¸ä¸åŒè¿›ç¨‹é—´çš„通信。这套接å£å±è”½äº†åˆ†å¸ƒå¼å’Œéžåˆ†å¸ƒå¼çŽ¯å¢ƒçš„区别,åŒæ—¶ä¹Ÿè‡ªåŠ¨å¤„ç†äº†æ•°æ®çš„设备和通信åŽç«¯ã€‚ -- [其他工具(Utils)](./utils.md):还有一些工具性的模å—,如管ç†å™¨æ··å…¥ï¼ˆManagerMixin),它实现了一ç§å…¨å±€å˜é‡çš„创建和获å–æ–¹å¼ï¼ŒRunner 内很多全局å¯è§å¯¹è±¡çš„基类就是 ManagerMixin。