<div align="center">
<img src="https://user-images.githubusercontent.com/58739961/187154444-fce76639-ac8d-429b-9354-c6fac64b7ef8.jpg" width="600"/>
<div> </div>
<div align="center">
<b><font size="5">OpenMMLab website</font></b>
<sup>
<a href="https://openmmlab.com">
<i><font size="4">HOT</font></i>
</a>
</sup>
<b><font size="5">OpenMMLab platform</font></b>
<sup>
<a href="https://platform.openmmlab.com">
<i><font size="4">TRY IT OUT</font></i>
</a>
</sup>
</div>
<div> </div>
[🤔Reporting Issues](https://github.com/open-mmlab/mmengine/issues/new/choose)
</div>
<div align="center">
English | [简体中文](README_zh-CN.md)
</div>
## Introduction
MMEngine is a foundational library for training deep learning models based on PyTorch. It runs on Linux, Windows, and macOS.
Major features:
1. A general and powerful runner:
   - Users can train different models with several lines of code, e.g., training ImageNet in 80 lines (compared with the PyTorch example, which needs more than 400 lines).
   - It can train models from popular libraries such as TIMM, TorchVision, and Detectron2.
2. An open framework with unified interfaces:
   - Users can apply the same code to all OpenMMLab 2.x projects. For example, MMRazor 1.x can compress models in all OpenMMLab 2.x projects with 40% less code than MMRazor 0.x.
   - Support for upstream and downstream projects is simplified. Currently, MMEngine can run on Nvidia CUDA, Mac MPS, AMD, MLU, and other devices.
3. A `legoified` training process:
   - Dynamic training, optimization, and data augmentation strategies, such as early stopping
   - Arbitrary forms of model weight averaging, including Exponential Moving Average (EMA) and Stochastic Weight Averaging (SWA)
   - Visualize and log whatever you want
   - Fine-grained optimization strategies for each parameter group
   - Flexible control of mixed precision training
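To make the weight-averaging item concrete, here is a minimal, framework-free sketch of an exponential moving average over a list of parameter values. The `ema_update` helper and the toy numbers are assumptions made for illustration and are not MMEngine's API. Here `momentum` is taken to be the weight on the newest value; some libraries use the opposite convention.

```python
def ema_update(averaged, current, momentum=0.1):
    """Blend `current` into the running average.

    Hypothetical helper: `momentum` is the weight given to the new value.
    """
    return [(1 - momentum) * a + momentum * c
            for a, c in zip(averaged, current)]


# Toy "weights" after three training steps (made-up numbers).
steps = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]

averaged = list(steps[0])  # start the average from the first snapshot
for weights in steps[1:]:
    averaged = ema_update(averaged, weights, momentum=0.5)

print(averaged)
```

The averaged values lag behind the latest weights, which is the smoothing effect that averaged models rely on at evaluation time.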
## Installation
Before installing MMEngine, please make sure that PyTorch has been successfully installed following the [official guide](https://pytorch.org/get-started/locally/).
Install MMEngine
```bash
pip install -U openmim
mim install mmengine
```
Verify the installation
```bash
python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'
```
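If you prefer to check from a script without importing the packages, the following sketch uses only the Python standard library. The `installed_version` helper is a hypothetical name introduced here, and it assumes that the import name and the distribution name are the same, which may not hold for every install variant.

```python
from importlib import metadata, util


def installed_version(package):
    """Return a version string for `package`, or None if it cannot be imported."""
    if util.find_spec(package) is None:
        return None
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        # Importable, but no distribution metadata under this exact name.
        return "unknown"


for name in ("torch", "mmengine"):
    print(name, installed_version(name) or "not installed")
```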
## Get Started
As an example, we will train a ResNet-50 model on the CIFAR-10 dataset, building a complete, configurable training and validation process with MMEngine in fewer than 80 lines of code.
<details>
<summary>Build Models</summary>
First, we need to define a **Model** that 1) inherits from `BaseModel`, and 2) accepts an additional argument `mode` in the `forward` method, in addition to those arguments related to the dataset. During training, the value of `mode` is "loss" and the `forward` method should return a dict containing the key "loss". During validation, the value of `mode` is "predict" and the forward method should return results containing both predictions and labels.
```python
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
```
</details>
<details>
<summary>Build Datasets</summary>
Next, we need to create a **Dataset** and **DataLoader** for training and validation.
In this case, we simply use built-in datasets supported in TorchVision.
```python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Per-channel mean and std used to normalize CIFAR-10 images
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))
```
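As a small worked example of what the normalization step does, the sketch below applies the same per-channel formula, `(value - mean) / std`, to a single pixel in plain Python. The `normalize_pixel` helper and the pixel values are invented for illustration; the real transform operates on whole image tensors that `ToTensor` has already scaled to the range [0, 1].

```python
norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])


def normalize_pixel(rgb, mean, std):
    """Hypothetical helper: normalize one RGB pixel channel by channel."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]


# A made-up pixel: red at the mean, green one std above, blue one std below.
pixel = [0.491, 0.681, 0.246]
normalized = normalize_pixel(pixel, **norm_cfg)
print([round(v, 3) for v in normalized])
```

After normalization each channel is expressed in units of standard deviations from the dataset mean.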
</details>
<details>
<summary>Build Metrics</summary>
To validate and test the model, we need to define a **Metric**, such as accuracy, to evaluate it. This metric needs to inherit from `BaseMetric` and implement the `process` and `compute_metrics` methods.
```python
from mmengine.evaluator import BaseMetric


class Accuracy(BaseMetric):

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # Save the results of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # Returns a dictionary with the results of the evaluated metrics,
        # where the key is the name of the metric
        return dict(accuracy=100 * total_correct / total_size)
```
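One reason to store both the number of correct predictions and the batch size is that batches can differ in size, so averaging per-batch accuracies would weight a small final batch too heavily. The plain-Python sketch below uses invented numbers in the same record shape that `process` appends, and compares the two ways of aggregating.

```python
# Per-batch records shaped like the dicts appended in `process`.
# The numbers are invented; the last batch is deliberately smaller.
results = [
    {'batch_size': 32, 'correct': 24},
    {'batch_size': 32, 'correct': 28},
    {'batch_size': 16, 'correct': 16},
]

# Same aggregation as `compute_metrics`: sum first, divide once.
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
accuracy = 100 * total_correct / total_size

# Naive alternative: mean of per-batch accuracies (over-weights the small batch).
naive = sum(100 * r['correct'] / r['batch_size'] for r in results) / len(results)

print(f'{accuracy:.2f} vs {naive:.2f}')
```

With these numbers the summed version gives 85.00 while the naive mean gives 87.50.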
</details>
<details>
<summary>Build a Runner</summary>
Finally, we can construct a **Runner** from the previously defined Model, DataLoader, and Metric, together with a few other configs, as shown below.
```python
from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    # a wrapper to execute back propagation and gradient update, etc.
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # set some training configs like epochs
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
```
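To illustrate how `max_epochs` and `val_interval` interact, the sketch below lists the epochs after which validation would run, assuming validation is triggered after every `val_interval`-th epoch counted from 1. The `validation_epochs` helper is hypothetical and not part of MMEngine, and the exact trigger rule may differ between versions.

```python
def validation_epochs(max_epochs, val_interval):
    """Hypothetical helper: 1-based epochs after which validation runs,
    assuming it fires whenever the epoch number is a multiple of val_interval."""
    return [e for e in range(1, max_epochs + 1) if e % val_interval == 0]


train_cfg = dict(by_epoch=True, max_epochs=5, val_interval=1)
every_epoch = validation_epochs(train_cfg['max_epochs'], train_cfg['val_interval'])
every_other = validation_epochs(5, 2)

print(every_epoch)
print(every_other)
```

Under this assumption, the config in this example validates after each of the five epochs, while `val_interval=2` would validate only after epochs 2 and 4.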
</details>
<details>
<summary>Launch Training</summary>
```python
runner.train()
```
</details>
## Contributing
We appreciate all contributions to improve MMEngine. Please refer to [CONTRIBUTING.md](CONTRIBUTING.md) for the contributing guidelines.
## License
This project is released under the [Apache 2.0 license](LICENSE).
## Projects in OpenMMLab
- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab few-shot learning toolbox and benchmark.
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework.