
mmengine's Introduction

 

Introduction | Installation | Get Started | 📘Documentation | 🤔Reporting Issues

English | 简体中文

What's New

v0.10.4 was released on 2024-4-23.

Highlights:

  • Support custom artifact_location in MLflowVisBackend #1505
  • Enable exclude_frozen_parameters for DeepSpeedEngine._zero3_consolidated_16bit_state_dict #1517

Read Changelog for more details.

Introduction

MMEngine is a foundational library for training deep learning models based on PyTorch. It serves as the training engine of all OpenMMLab codebases, which support hundreds of algorithms in various research areas. Moreover, MMEngine is generic enough to be applied to non-OpenMMLab projects as well. Its highlights are as follows:

  • Integrates mainstream large-scale model training frameworks
  • Supports a variety of training strategies
  • Provides a user-friendly configuration system
  • Covers mainstream training monitoring platforms

Installation

Supported PyTorch Versions

MMEngine             PyTorch        Python
main                 >=1.6, <=2.1   >=3.8, <=3.11
>=0.9.0, <=0.10.4    >=1.6, <=2.1   >=3.8, <=3.11

Before installing MMEngine, please ensure that PyTorch has been successfully installed following the official guide.

Install MMEngine

pip install -U openmim
mim install mmengine

Verify the installation

python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'

Get Started

Taking the training of a ResNet-50 model on the CIFAR-10 dataset as an example, we will use MMEngine to build a complete, configurable training and validation process in less than 80 lines of code.

Build Models

First, we need to define a model which 1) inherits from BaseModel and 2) accepts an additional argument mode in the forward method, in addition to those arguments related to the dataset.

  • During training, the value of mode is "loss", and the forward method should return a dict containing the key "loss".
  • During validation, the value of mode is "predict", and the forward method should return results containing both predictions and labels.

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel

class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels

Build Datasets

Next, we need to create Datasets and DataLoaders for training and validation. In this case, we simply use built-in datasets supported in TorchVision.

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))
val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))

Build Metrics

To validate and test the model, we need to define a Metric called accuracy to evaluate the model. This metric needs to inherit from BaseMetric and implement the process and compute_metrics methods.

from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # Save the results of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # Returns a dictionary with the results of the evaluated metrics,
        # where the key is the name of the metric
        return dict(accuracy=100 * total_correct / total_size)

Build a Runner

Finally, we can construct a Runner using the previously defined Model, DataLoaders, and Metric, together with some other configs, as shown below.

from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    # a wrapper to execute back propagation and gradient update, etc.
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # set some training configs like epochs
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)

Launch Training

runner.train()

Learn More

Tutorials
Advanced tutorials
Examples
Common Usage
Design
Migration guide

Contributing

We appreciate all contributions to improve MMEngine. Please refer to CONTRIBUTING.md for the contributing guideline.

Citation

If you find this project useful in your research, please consider citing:

@article{mmengine2022,
  title   = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
  author  = {MMEngine Contributors},
  howpublished = {\url{https://github.com/open-mmlab/mmengine}},
  year={2022}
}

License

This project is released under the Apache 2.0 license.

Ecosystem

Projects in OpenMMLab

  • MIM: MIM installs OpenMMLab packages.
  • MMCV: OpenMMLab foundational library for computer vision.
  • MMEval: A unified evaluation library for multiple machine learning libraries.
  • MMPreTrain: OpenMMLab pre-training toolbox and benchmark.
  • MMagic: OpenMMLab Advanced, Generative and Intelligent Creation toolbox.
  • MMDetection: OpenMMLab detection toolbox and benchmark.
  • MMYOLO: OpenMMLab YOLO series toolbox and benchmark.
  • MMDetection3D: OpenMMLab's next-generation platform for general 3D object detection.
  • MMRotate: OpenMMLab rotated object detection toolbox and benchmark.
  • MMTracking: OpenMMLab video perception toolbox and benchmark.
  • MMPose: OpenMMLab pose estimation toolbox and benchmark.
  • MMSegmentation: OpenMMLab semantic segmentation toolbox and benchmark.
  • MMOCR: OpenMMLab text detection, recognition, and understanding toolbox.
  • MMHuman3D: OpenMMLab 3D human parametric model toolbox and benchmark.
  • MMSelfSup: OpenMMLab self-supervised learning toolbox and benchmark.
  • MMFewShot: OpenMMLab fewshot learning toolbox and benchmark.
  • MMAction2: OpenMMLab's next-generation action understanding toolbox and benchmark.
  • MMFlow: OpenMMLab optical flow toolbox and benchmark.
  • MMDeploy: OpenMMLab model deployment framework.
  • MMRazor: OpenMMLab model compression toolbox and benchmark.
  • Playground: A central hub for gathering and showcasing amazing projects built upon OpenMMLab.


mmengine's Issues

`AverageModel` has bug in updating judgement

if self.steps % self.interval == 0:
    avg_param = (
        itertools.chain(self.module.parameters(),
                        self.module.buffers())
        if self.update_buffers else self.parameters())
    src_param = (
        itertools.chain(model.parameters(), model.buffers())
        if self.update_buffers else model.parameters())
    for p_avg, p_src in zip(avg_param, src_param):
        device = p_avg.device
        p_src_ = p_src.detach().to(device)
        if self.steps == 0:
            p_avg.detach().copy_(p_src_)
        else:
            p_avg.detach().copy_(
                self.avg_func(p_avg.detach(), p_src_,
                              self.steps.to(device)))
self.steps += 1

self.steps starts from 0. Should we change this condition to (self.steps + 1) % self.interval == 0?
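The difference is easy to see with a small standalone sketch (not MMEngine code), assuming a hypothetical interval of 4:

interval = 4
steps = list(range(8))
# current condition: the averaged model is updated at step counts 0, 4, ...
updated_now = [s for s in steps if s % interval == 0]             # [0, 4]
# proposed condition: the first update happens only after `interval` calls
updated_proposed = [s for s in steps if (s + 1) % interval == 0]  # [3, 7]
print(updated_now, updated_proposed)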

After users have set paramwise_cfg, how do they know whether it matches their expectations? Should a script be provided so that, after running it, users can easily see which parameters are frozen and how the hyper-parameters differ across parameter groups? If there is no time to develop this for now, it can be kept as a future feature request.

After users have finished the settings, how do they know whether they match their expectations? Should a script be provided so that, after running it, users can easily see which parameters are frozen and how the hyper-parameters differ across parameter groups? If there is no time to develop this for now, it can be kept as a future feature request.

Originally posted by @hhaAndroid in #25 (comment)
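A hedged sketch of the kind of inspection script being requested (plain PyTorch; inspect_param_groups is a hypothetical helper name):

def inspect_param_groups(model, optimizer):
    # list the parameters that will not be updated
    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(f'frozen: {name}')
    # show the hyper-parameters of each optimizer param group
    for i, group in enumerate(optimizer.param_groups):
        print(f'group {i}: lr={group["lr"]}, '
              f'weight_decay={group.get("weight_decay", 0)}, '
              f'num_params={len(group["params"])}')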

CI codecov uses `--source mmdet`

Thanks for your error report and we appreciate it a lot.

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. I have read the FAQ documentation but cannot get the expected help.
  3. The bug has not been fixed in the latest version.

Describe the bug
A clear and concise description of what the bug is.

coverage run --branch --source mmdet -m pytest tests/

Reproduction

  1. What command or script did you run?

     A placeholder for the command.

  2. Did you make any modifications to the code or config? Did you understand what you modified?
  3. What dataset did you use?

Environment

  1. Please run python mmdet/utils/collect_env.py to collect necessary environment information and paste it here.
  2. You may add additional information that may be helpful for locating the problem, such as
    • How you installed PyTorch [e.g., pip, conda, source]
    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

Error traceback
If applicable, paste the error traceback here.

A placeholder for the traceback.

Bug fix
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

'Runner' object has no attribute 'log_buffer'

When I run fcos_r50_caffe_fpn_gn-head_1x_coco.py, which has the setting default_hooks = dict(optimizer=dict(type='OptimizerHook', grad_clip=dict(max_norm=35, norm_type=2))), the program reports the error below.

File "/mnt/cache/wangjiabao1.vendor/workspace/refactor/mmengine/mmengine/hooks/optimizer_hook.py", line 98, in after_train_iter
    getattr(hook, fn_name)(self, **kwargs)
  File "/mnt/cache/wangjiabao1.vendor/workspace/refactor/mmengine/mmengine/hooks/optimizer_hook.py", line 98, in after_train_iter
    outputs=self.runner.outputs)
    getattr(hook, fn_name)(self, **kwargs)
  File "/mnt/cache/wangjiabao1.vendor/workspace/refactor/mmengine/mmengine/hooks/optimizer_hook.py", line 98, in after_train_iter
    getattr(hook, fn_name)(self, **kwargs)
  File "/mnt/cache/wangjiabao1.vendor/workspace/refactor/mmengine/mmengine/hooks/optimizer_hook.py", line 98, in after_train_iter
  File "/mnt/cache/wangjiabao1.vendor/workspace/refactor/mmengine/mmengine/runner/runner.py", line 1304, in call_hook
    getattr(hook, fn_name)(self, **kwargs)
  File "/mnt/cache/wangjiabao1.vendor/workspace/refactor/mmengine/mmengine/hooks/optimizer_hook.py", line 98, in after_train_iter
    getattr(hook, fn_name)(self, **kwargs)
  File "/mnt/cache/wangjiabao1.vendor/workspace/refactor/mmengine/mmengine/hooks/optimizer_hook.py", line 98, in after_train_iter
    getattr(hook, fn_name)(self, **kwargs)
  File "/mnt/cache/wangjiabao1.vendor/workspace/refactor/mmengine/mmengine/hooks/optimizer_hook.py", line 98, in after_train_iter
    getattr(hook, fn_name)(self, **kwargs)
  File "/mnt/cache/wangjiabao1.vendor/workspace/refactor/mmengine/mmengine/hooks/optimizer_hook.py", line 98, in after_train_iter
    runner.log_buffer.update({'grad_norm': float(grad_norm)},
    runner.log_buffer.update({'grad_norm': float(grad_norm)},
    runner.log_buffer.update({'grad_norm': float(grad_norm)},
AttributeError: 'Runner' object has no attribute 'log_buffer'
    runner.log_buffer.update({'grad_norm': float(grad_norm)},
    runner.log_buffer.update({'grad_norm': float(grad_norm)},
AttributeError: 'Runner' object has no attribute 'log_buffer'
AttributeError: 'Runner' object has no attribute 'log_buffer'
AttributeError: 'Runner' object has no attribute 'log_buffer'
    runner.log_buffer.update({'grad_norm': float(grad_norm)},
AttributeError: 'Runner' object has no attribute 'log_buffer'
AttributeError: 'Runner' object has no attribute 'log_buffer'
    runner.log_buffer.update({'grad_norm': float(grad_norm)},
AttributeError: 'Runner' object has no attribute 'log_buffer'
    runner.log_buffer.update({'grad_norm': float(grad_norm)},
AttributeError: 'Runner' object has no attribute 'log_buffer'
phoenix-srun: error: SH-IDC1-10-140-0-252: tasks 0-7: Exited with exit code 1
phoenix-srun: Terminating job step 1084231.0

I wonder whether log_buffer needs to be replaced with message_hub or logger.
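A hedged sketch of what the replacement inside the hook could look like (the 'train/grad_norm' key name is an assumption, not the actual fix):

# instead of runner.log_buffer.update(...), record the scalar via the message hub
runner.message_hub.update_scalar('train/grad_norm', float(grad_norm))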

Document the functional differences between config formats in the Config docs

The range of config content supported by the YAML/JSON/PY formats is not the same. For example, the JSON format does not support tuples, so a tuple in a PY config becomes a list after being dumped to JSON.
In addition, some functional interfaces are only supported for Python configs; in such cases, the interfaces and documentation should say so.
The API documentation or the config tutorial should clearly list the level of support for each config format and the limitations/differences between the formats.
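A minimal sketch of the tuple-to-list behaviour, using plain json rather than the Config API to illustrate the point:

import json

cfg = {'img_scale': (640, 640)}          # tuple in a Python config
restored = json.loads(json.dumps(cfg))   # round trip through JSON
print(type(restored['img_scale']))       # <class 'list'>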

Enable automatically loading latest checkpoint from ceph

Describe the feature

Motivation
Since storage is limited, more and more users save their checkpoints in ceph and leave no checkpoints in the local working directory. However, when resuming a job, the auto-resume function is only able to find checkpoints in the local path and cannot automatically load checkpoints saved in ceph.

To solve this issue, a naive solution can be described as below:
When saving checkpoints during training, no matter where a checkpoint is saved, also save last_checkpoint.txt in both the local and ceph working directories, indicating the real path of the latest checkpoint (which can be either local storage or ceph). When auto-resuming a checkpoint during training, read the file and load the checkpoint from the path it contains. Thus, users can safely use auto-resume with a command like the one below

sh ./tools/slurm_train.sh $PATITION $CONFIG $WORK_DIR --auto-resume

Or users can manually resume the model in a unified way no matter where the latest checkpoint is saved like below:

sh ./tools/slurm_train.sh $PATITION $CONFIG $WORK_DIR --load-from $WORK_DIR/last_checkpoint --resume

The last_checkpoint.txt serves as a soft link to the latest checkpoint across platforms and works for any kind of storage.
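A minimal sketch of the idea with hypothetical helper names (a real implementation would go through the file client so the marker can also live on ceph):

import os

def record_last_checkpoint(work_dir, ckpt_path):
    # ckpt_path may point to local disk or to ceph, e.g. 's3://bucket/epoch_12.pth'
    with open(os.path.join(work_dir, 'last_checkpoint'), 'w') as f:
        f.write(ckpt_path)

def find_latest_checkpoint(work_dir):
    marker = os.path.join(work_dir, 'last_checkpoint')
    if os.path.exists(marker):
        with open(marker) as f:
            return f.read().strip()
    return None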

Related resources
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.

Additional context
Detectron2 has a similar design.

Explain the priority design of meta info in code.

That won't work. In that case, with lazy_init=True, the content of meta consists of the user-provided meta (high priority) and the class attribute BaseDataset.META dict (low priority); full_init is then called and reads the meta from the annotation file (medium priority), and it is unclear how the keys in the medium-priority meta should override the keys of the high-priority and low-priority metas.

Originally posted by @GT9505 in #7 (comment)

Add documentation of evaluation on multiple datasets with multiple metrics.

Describe the feature
Add documentation to show how to evaluate multiple datasets with multiple metrics and use one of the metrics of a dataset as the best indicator.

Motivation
Users might need to evaluate different metrics on multiple datasets.
In such a case, only one metric on one dataset needs to be selected to indicate whether the model is the best model and should be saved.
It is unnecessary to officially support this feature in MMEngine, but MMEngine allows users to create a new Loop class to support it. Therefore, we should update the documentation to show such an example.
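As a hedged sketch of the configuration side only (the per-dataset Loop itself is not shown), multiple metrics can already be passed to the evaluator as a list, and the checkpoint hook can be told which metric marks the best model; MetricA, MetricB and the 'accuracy/top1' key are placeholders:

val_evaluator = [dict(type=MetricA), dict(type=MetricB)]
default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', save_best='accuracy/top1'))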

Related resources
See a previous PR in MMSeg open-mmlab/mmsegmentation#1461

Additional context
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.

Support visualizing learning rate status before training

Describe the feature

Motivation
A clear and concise description of the motivation of the feature.
Ex1. It is inconvenient when [....].
Ex2. There is a recent paper [....], which is very helpful for [....].

Related resources
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.

Additional context
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.

Clear some `type: ignore` flags

Describe the feature

Motivation
A clear and concise description of the motivation of the feature.
Ex1. It is inconvenient when [....].
Ex2. There is a recent paper [....], which is very helpful for [....].

Related resources
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.

Additional context
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.

Fully support different file clients in `BaseDataset`.

Describe the feature
We have already supported multiple file clients in BaseDataset, but some arguments still do not support them.
In particular, the usage of the os.path package may cause incompatibilities with different file clients; please check it.
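A hedged sketch of what this implies, using the file client's path utilities instead of os.path so that non-disk backends also work (the s3 paths are made-up examples):

from mmengine.fileio import FileClient

ann_file = 's3://my-bucket/annotations/train.json'
file_client = FileClient.infer_client(uri=ann_file)
# join_path keeps the backend-specific scheme and separator intact,
# whereas os.path.join is only correct for local disk paths
img_path = file_client.join_path('s3://my-bucket/images', 'train', '0001.jpg')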

Support auto-scaling LR in param_scheduler

Describe the feature

Motivation
It is quite common that users need to update the LR based on the number of GPUs they use. A brief solution might be:
add an argument like default_batchsize somewhere; when the param_scheduler starts to be initialized, calculate the real batch size and then scale the LR by the ratio of the two. This enables different repos to set different default_batchsize values for their own needs, as sketched below.
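A minimal sketch of the linear scaling rule under these assumptions (default_batchsize and the other names are hypothetical):

def auto_scale_lr(base_lr, default_batchsize, batch_size_per_gpu, num_gpus):
    # the LR grows proportionally with the real total batch size
    real_batchsize = batch_size_per_gpu * num_gpus
    return base_lr * real_batchsize / default_batchsize

# e.g. a config tuned for default_batchsize=16, launched on 8 GPUs with 4 images each
print(auto_scale_lr(0.01, default_batchsize=16, batch_size_per_gpu=4, num_gpus=8))  # 0.02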

Related resources
See auto_scale_lr in mmdet

Additional context
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
