
[ICLR 2024] MogaNet: Efficient Multi-order Gated Aggregation Network

Home Page: https://arxiv.org/abs/2211.03295

License: Apache License 2.0

Topics: imagenet pytorch vision-transformer image-classification object-detection pose-estimation segmentation instance-segmentation video-prediction 3d-pose-estimation

Introduction

We propose MogaNet, a new family of efficient ConvNets designed through the lens of multi-order game-theoretic interaction, to pursue informative context mining with favorable complexity-performance trade-offs. MogaNet shows excellent scalability and attains competitive results among state-of-the-art models, with more efficient use of model parameters, on ImageNet and a variety of typical vision benchmarks, including COCO object detection, ADE20K semantic segmentation, 2D and 3D human pose estimation, and video prediction.

This repository contains PyTorch implementation for MogaNet (ICLR 2024).

Table of Contents
  1. Catalog
  2. Image Classification
  3. License
  4. Acknowledgement
  5. Citation

Catalog

We plan to release implementations of MogaNet in a few months. Please watch this repository for the latest releases. Currently, this repo is reimplemented from our official implementation in OpenMixup, and we are working on cleaning up the experimental results and code. Models are released on GitHub / Baidu Cloud / Hugging Face.

  • ImageNet-1K Training and Validation Code with timm [code] [models] [Hugging Face 🤗]
  • ImageNet-1K Training and Validation Code in OpenMixup / MMPretrain (TODO)
  • Downstream Transfer to Object Detection and Instance Segmentation on COCO [code] [models] [demo]
  • Downstream Transfer to Semantic Segmentation on ADE20K [code] [models] [demo]
  • Downstream Transfer to 2D Human Pose Estimation on COCO [code] (baselines supported) [models] [demo]
  • Downstream Transfer to 3D Human Pose Estimation (baseline models will be supported)
  • Downstream Transfer to Video Prediction on MMNIST Variants [code] (baselines supported)
  • Image Classification on Google Colab and Notebook Demo [demo]

Image Classification

1. Installation

Please check INSTALL.md for installation instructions.

2. Training and Validation

See TRAINING.md for ImageNet-1K training and validation instructions, or refer to our OpenMixup implementations. We have released pre-trained models on OpenMixup in moganet-in1k-weights. We have also reproduced the ImageNet results with this repo and released args.yaml / summary.csv / model.pth.tar in moganet-in1k-weights. The trained weights can be extracted from the checkpoint programmatically, as sketched below.
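
A minimal sketch of that extraction, assuming model.pth.tar follows the usual timm checkpoint layout (weights nested under 'state_dict', or 'state_dict_ema' for the EMA copy):

import torch

ckpt = torch.load('model.pth.tar', map_location='cpu')
# timm training checkpoints usually nest the weights; fall back to the raw dict otherwise
state_dict = ckpt.get('state_dict_ema') or ckpt.get('state_dict') or ckpt
torch.save(state_dict, 'moganet_weights.pth')  # plain weights for load_state_dict(...)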

Here is a notebook demo of MogaNet that runs through the steps to perform inference with MogaNet for image classification; a condensed sketch of those steps follows.
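
A condensed sketch, assuming this repo's models package registers the moganet_* variants with timm (the model name matches the analysis scripts below) and that pretrained weights can be fetched for it:

import torch
from PIL import Image
from timm import create_model
from timm.data import create_transform

import models  # assumption: this repo's package, whose import registers moganet_* with timm

model = create_model('moganet_tiny', pretrained=True).eval()
transform = create_transform(input_size=224, crop_pct=0.9)  # standard ImageNet eval pipeline
img = transform(Image.open('/path/to/image.JPEG').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    probs = model(img).softmax(dim=-1)
print(probs.argmax(dim=-1).item())  # predicted ImageNet-1K class index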

3. ImageNet-1K Trained Models

Model Resolution Params (M) FLOPs (G) Top-1 | Top-5 (%) Script Download
MogaNet-XT 224x224 2.97 0.80 76.5 | 93.4 args | script model | log
MogaNet-XT 256x256 2.97 1.04 77.2 | 93.8 args | script model | log
MogaNet-T 224x224 5.20 1.10 79.0 | 94.6 args | script model | log
MogaNet-T 256x256 5.20 1.44 79.6 | 94.9 args | script model | log
MogaNet-T* 256x256 5.20 1.44 80.0 | 95.0 config | script model | log
MogaNet-S 224x224 25.3 4.97 83.4 | 96.9 args | script model | log
MogaNet-B 224x224 43.9 9.93 84.3 | 97.0 args | script model | log
MogaNet-L 224x224 82.5 15.9 84.7 | 97.1 args | script model | log
MogaNet-XL 224x224 180.8 34.5 85.1 | 97.4 args | script model | log

4. Analysis Tools

(1) The code to count MACs of MogaNet variants.

python get_flops.py --model moganet_tiny
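
As a rough cross-check of get_flops.py, MACs can also be counted with fvcore (a sketch under the same registration assumption as the inference demo above):

import torch
from fvcore.nn import FlopCountAnalysis
from timm import create_model

import models  # assumption: registers moganet_* with timm

model = create_model('moganet_tiny').eval()
macs = FlopCountAnalysis(model, torch.randn(1, 3, 224, 224))
print(f'{macs.total() / 1e9:.2f} G')  # fvcore counts multiply-accumulates (MACs)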

(2) The code to visualize Grad-CAM activation maps (or variants of Grad-CAM) of MogaNet and other popular architectures.

python cam_image.py --use_cuda --image_path /path/to/image.JPEG --model moganet_tiny --method gradcam
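
For reference, a minimal sketch with the pytorch-grad-cam package (which cam_image.py appears to build on); the target layer below is a hypothetical placeholder and should be replaced by MogaNet's last convolutional stage:

import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
from timm import create_model

import models  # assumption: registers moganet_* with timm

model = create_model('moganet_tiny', pretrained=True).eval()
rgb = np.float32(Image.open('/path/to/image.JPEG').convert('RGB').resize((224, 224))) / 255
input_tensor = preprocess_image(rgb)  # normalizes and adds the batch dimension
target_layer = list(model.children())[-3]  # hypothetical: pick the last feature stage in practice
cam = GradCAM(model=model, target_layers=[target_layer])
grayscale_cam = cam(input_tensor=input_tensor)[0]
overlay = show_cam_on_image(rgb, grayscale_cam, use_rgb=True)  # HxWx3 uint8 heatmap overlay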

(back to top)

5. Downstream Tasks

Object Detection and Instance Segmentation on COCO
  • MogaNet + Mask R-CNN
  • Method Backbone Pretrain Params FLOPs Lr schd box mAP mask mAP Config Download
    Mask R-CNN MogaNet-XT ImageNet-1K 22.8M 185.4G 1x 40.7 37.6 config log / model
    Mask R-CNN MogaNet-T ImageNet-1K 25.0M 191.7G 1x 42.6 39.1 config log / model
    Mask R-CNN MogaNet-S ImageNet-1K 45.0M 271.6G 1x 46.6 42.2 config log / model
    Mask R-CNN MogaNet-B ImageNet-1K 63.4M 373.1G 1x 49.0 43.8 config log / model
    Mask R-CNN MogaNet-L ImageNet-1K 102.1M 495.3G 1x 49.4 44.2 config log / model
    Mask R-CNN MogaNet-T ImageNet-1K 25.0M 191.7G MS 3x 45.3 40.7 config log / model
    Mask R-CNN MogaNet-S ImageNet-1K 45.0M 271.6G MS 3x 48.5 43.1 config log / model
    Mask R-CNN MogaNet-B ImageNet-1K 63.4M 373.1G MS 3x 50.3 44.4 config log / model
    Mask R-CNN MogaNet-L ImageNet-1K 102.1M 495.3G MS 3x 50.6 44.6 config log / model
  • MogaNet + RetinaNet
  • Method Backbone Pretrain Params FLOPs Lr schd box mAP Config Download
    RetinaNet MogaNet-XT ImageNet-1K 12.1M 167.2G 1x 39.7 config log / model
    RetinaNet MogaNet-T ImageNet-1K 14.4M 173.4G 1x 41.4 config log / model
    RetinaNet MogaNet-S ImageNet-1K 35.1M 253.0G 1x 45.8 config log / model
    RetinaNet MogaNet-B ImageNet-1K 53.5M 354.5G 1x 47.7 config log / model
    RetinaNet MogaNet-L ImageNet-1K 92.4M 476.8G 1x 48.7 config log / model
  • MogaNet + Cascade Mask R-CNN
  • Method Backbone Pretrain Params FLOPs Lr schd box mAP mask mAP Config Download
    Cascade Mask R-CNN MogaNet-S ImageNet-1K 77.9M 405.4G MS 3x 51.4 44.9 config log / model
    Cascade Mask R-CNN MogaNet-S ImageNet-1K 82.8M 750.2G GIOU+MS 3x 51.7 45.1 config log / model
    Cascade Mask R-CNN MogaNet-B ImageNet-1K 101.2M 851.6G GIOU+MS 3x 52.6 46.0 config log / model
    Cascade Mask R-CNN MogaNet-L ImageNet-1K 139.9M 973.8G GIOU+MS 3x 53.3 46.1 config -

Semantic Segmentation on ADE20K
  • MogaNet + Semantic FPN
  • Method Backbone Pretrain Params FLOPs Iters mIoU mAcc Config Download
    Semantic FPN MogaNet-XT ImageNet-1K 6.9M 101.4G 80K 40.3 52.4 config log / model
    Semantic FPN MogaNet-T ImageNet-1K 9.1M 107.8G 80K 43.1 55.4 config log / model
    Semantic FPN MogaNet-S ImageNet-1K 29.1M 189.7G 80K 47.7 59.8 config log / model
    Semantic FPN MogaNet-B ImageNet-1K 47.5M 293.6G 80K 49.3 61.6 config log / model
    Semantic FPN MogaNet-L ImageNet-1K 86.2M 418.7G 80K 50.2 63.0 config log / model
  • MogaNet + UperNet
  • Method Backbone Pretrain Params FLOPs Iters mIoU mAcc Config Download
    UperNet MogaNet-XT ImageNet-1K 30.4M 855.7G 160K 42.2 55.1 config log / model
    UperNet MogaNet-T ImageNet-1K 33.1M 862.4G 160K 43.7 57.1 config log / model
    UperNet MogaNet-S ImageNet-1K 55.3M 946.4G 160K 49.2 61.6 config log / model
    UperNet MogaNet-B ImageNet-1K 73.7M 1050.4G 160K 50.1 63.4 config log / model
    UperNet MogaNet-L ImageNet-1K 113.2M 1176.1G 160K 50.9 63.5 config log / model

2D Human Pose Estimation on COCO
  • MogaNet + Top-Down
  • Backbone Input Size Params FLOPs AP AP50 AP75 AR ARM ARL Config Download
    MogaNet-XT 256x192 5.6M 1.8G 72.1 89.7 80.1 77.7 73.6 83.6 config log | model
    MogaNet-XT 384x288 5.6M 4.2G 74.7 90.1 81.3 79.9 75.9 85.9 config log | model
    MogaNet-T 256x192 8.1M 2.2G 73.2 90.1 81.0 78.8 74.9 84.4 config log | model
    MogaNet-T 384x288 8.1M 4.9G 75.7 90.6 82.6 80.9 76.8 86.7 config log | model
    MogaNet-S 256x192 29.0M 6.0G 74.9 90.7 82.8 80.1 75.7 86.3 config log | model
    MogaNet-S 384x288 29.0M 13.5G 76.4 91.0 83.3 81.4 77.1 87.7 config log | model
    MogaNet-B 256x192 47.4M 10.9G 75.3 90.9 83.3 80.7 76.4 87.1 config log | model
    MogaNet-B 384x288 47.4M 24.4G 77.3 91.4 84.0 82.2 77.9 88.5 config log | model

Video Prediction on Moving MNIST
  • Architecture Setting Params FLOPs FPS MSE MAE SSIM PSNR Download
    IncepU (SimVPv1) 200 epoch 58.0M 19.4G 209 32.15 89.05 0.9268 21.84 model | log
    gSTA (SimVPv2) 200 epoch 46.8M 16.5G 282 26.69 77.19 0.9402 22.78 model | log
    ViT 200 epoch 46.1M 16.9G 290 35.15 95.87 0.9139 21.67 model | log
    Swin Transformer 200 epoch 46.1M 16.4G 294 29.70 84.05 0.9331 22.22 model | log
    Uniformer 200 epoch 44.8M 16.5G 296 30.38 85.87 0.9308 22.13 model | log
    MLP-Mixer 200 epoch 38.2M 14.7G 334 29.52 83.36 0.9338 22.22 model | log
    ConvMixer 200 epoch 3.9M 5.5G 658 32.09 88.93 0.9259 21.93 model | log
    Poolformer 200 epoch 37.1M 14.1G 341 31.79 88.48 0.9271 22.03 model | log
    ConvNeXt 200 epoch 37.3M 14.1G 344 26.94 77.23 0.9397 22.74 model | log
    VAN 200 epoch 44.5M 16.0G 288 26.10 76.11 0.9417 22.89 model | log
    HorNet 200 epoch 45.7M 16.3G 287 29.64 83.26 0.9331 22.26 model | log
    MogaNet 200 epoch 46.8M 16.5G 255 25.57 75.19 0.9429 22.99 model | log
    IncepU (SimVPv1) 2000 epoch 58.0M 19.4G 209 21.15 64.15 0.9536 23.99 model | log
    gSTA (SimVPv2) 2000 epoch 46.8M 16.5G 282 15.05 49.80 0.9675 25.97 model | log
    ViT 2000 epoch 46.1M 16.9G 290 19.74 61.65 0.9539 24.59 model | log
    Swin Transformer 2000 epoch 46.1M 16.4G 294 19.11 59.84 0.9584 24.53 model | log
    Uniformer 2000 epoch 44.8M 16.5G 296 18.01 57.52 0.9609 24.92 model | log
    MLP-Mixer 2000 epoch 38.2M 14.7G 334 18.85 59.86 0.9589 24.58 model | log
    ConvMixer 2000 epoch 3.9M 5.5G 658 22.30 67.37 0.9507 23.73 model | log
    Poolformer 2000 epoch 37.1M 14.1G 341 20.96 64.31 0.9539 24.15 model | log
    ConvNeXt 2000 epoch 37.3M 14.1G 344 17.58 55.76 0.9617 25.06 model | log
    VAN 2000 epoch 44.5M 16.0G 288 16.21 53.57 0.9646 25.49 model | log
    HorNet 2000 epoch 45.7M 16.3G 287 17.40 55.70 0.9624 25.14 model | log
    MogaNet 2000 epoch 46.8M 16.5G 255 15.67 51.84 0.9661 25.70 model | log

Video Prediction on Moving FMNIST
  • Architecture Setting Params FLOPs FPS MSE MAE SSIM PSNR Download
    IncepU (SimVPv1) 200 epoch 58.0M 19.4G 209 30.77 113.94 0.8740 21.81 model | log
    gSTA (SimVPv2) 200 epoch 46.8M 16.5G 282 25.86 101.22 0.8933 22.61 model | log
    ViT 200 epoch 46.1M 16.9G 290 31.05 115.59 0.8712 21.83 model | log
    Swin Transformer 200 epoch 46.1M 16.4G 294 28.66 108.93 0.8815 22.08 model | log
    Uniformer 200 epoch 44.8M 16.5G 296 29.56 111.72 0.8779 21.97 model | log
    MLP-Mixer 200 epoch 38.2M 14.7G 334 28.83 109.51 0.8803 22.01 model | log
    ConvMixer 200 epoch 3.9M 5.5G 658 31.21 115.74 0.8709 21.71 model | log
    Poolformer 200 epoch 37.1M 14.1G 341 30.02 113.07 0.8750 21.95 model | log
    ConvNeXt 200 epoch 37.3M 14.1G 344 26.41 102.56 0.8908 22.49 model | log
    VAN 200 epoch 44.5M 16.0G 288 31.39 116.28 0.8703 22.82 model | log
    HorNet 200 epoch 45.7M 16.3G 287 29.19 110.17 0.8796 22.03 model | log
    MogaNet 200 epoch 46.8M 16.5G 255 25.14 99.69 0.8960 22.73 model | log

License

This project is released under the Apache 2.0 license.

Acknowledgement

Our implementation is mainly based on the following codebases. We gratefully thank the authors for their wonderful works.

• pytorch-image-models (timm): PyTorch image models, scripts, pretrained weights.
• PoolFormer: Official PyTorch implementation of MetaFormer.
• ConvNeXt: Official PyTorch implementation of ConvNeXt.
• OpenMixup: Open-source toolbox for visual representation learning.
• MMDetection: OpenMMLab Detection Toolbox and Benchmark.
• MMSegmentation: OpenMMLab Semantic Segmentation Toolbox and Benchmark.
• MMPose: OpenMMLab Pose Estimation Toolbox and Benchmark.
• MMHuman3D: OpenMMLab 3D Human Parametric Model Toolbox and Benchmark.
• OpenSTL: A Comprehensive Benchmark of Spatio-Temporal Predictive Learning.

Citation

If you find this repository helpful, please consider citing:

@inproceedings{iclr2024MogaNet,
  title={MogaNet: Multi-order Gated Aggregation Network},
  author={Siyuan Li and Zedong Wang and Zicheng Liu and Cheng Tan and Haitao Lin and Di Wu and Zhiyuan Chen and Jiangbin Zheng and Stan Z. Li},
  booktitle={International Conference on Learning Representations},
  year={2024}
}

(back to top)


Issues

Unable to train model

Thanks for your significant paper. However, I encountered a problem when I ran the training instructions:

File "/usr/local/lib/python3.10/dist-packages/mmcv/utils/registry.py", line 72, in build_from_cfg
    raise type(e)(f'{obj_cls.__name__}: {e}')
urllib.error.URLError: <urlopen error MaskRCNN: <urlopen error MogaNet_feat: <urlopen error [Errno 104] Connection reset by peer>>>

I appreciate your help!
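
A likely cause, judging from the traceback, is that building the MogaNet_feat backbone tries to download its ImageNet-pretrained weights and the connection is reset. A possible workaround (a sketch, assuming the backbone accepts MMCV's standard init_cfg) is to download the checkpoint manually and point the detection config at the local file:

# in the MMDetection config, after downloading the ImageNet weights by hand
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint='/path/to/moganet_imagenet.pth')))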

About an error when loading pretrained models

Hi! Thanks for your code release!
When I use moganet_base with pretrained=True, I get the following error:

RuntimeError: Error(s) in loading state_dict for MogaNet: Unexpected key(s) in state_dict: "head.weight", "head.bias"

Can you give me some advice?
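
A common fix, sketched under the assumption that these are the ImageNet classifier weights being loaded into a headless backbone: drop the classification-head keys, or load non-strictly:

import torch

state_dict = torch.load('/path/to/moganet_base.pth', map_location='cpu')
for k in ('head.weight', 'head.bias'):  # the unexpected keys reported above
    state_dict.pop(k, None)
model.load_state_dict(state_dict, strict=False)  # model: your MogaNet instance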

Code Issue about MultiOrderGatedAggregation

MogaNet/models/moganet.py

Lines 264 to 333 in cd53ea0

class MultiOrderGatedAggregation(nn.Module):
    """Spatial Block with Multi-order Gated Aggregation.

    Args:
        embed_dims (int): Number of input channels.
        attn_dw_dilation (list): Dilations of three DWConv layers.
        attn_channel_split (list): The relative ratio of split channels.
        attn_act_type (str): The activation type for the Spatial Block.
            Defaults to 'SiLU'.
    """

    def __init__(self,
                 embed_dims,
                 attn_dw_dilation=[1, 2, 3],
                 attn_channel_split=[1, 3, 4],
                 attn_act_type='SiLU',
                 attn_force_fp32=False,
                 ):
        super(MultiOrderGatedAggregation, self).__init__()
        self.embed_dims = embed_dims
        self.attn_force_fp32 = attn_force_fp32
        self.proj_1 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.gate = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.value = MultiOrderDWConv(
            embed_dims=embed_dims,
            dw_dilation=attn_dw_dilation,
            channel_split=attn_channel_split,
        )
        self.proj_2 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        # activation for gating and value
        self.act_value = build_act_layer(attn_act_type)
        self.act_gate = build_act_layer(attn_act_type)
        # decompose
        self.sigma = ElementScale(
            embed_dims, init_value=1e-5, requires_grad=True)

    def feat_decompose(self, x):
        x = self.proj_1(x)
        # x_d: [B, C, H, W] -> [B, C, 1, 1]
        x_d = F.adaptive_avg_pool2d(x, output_size=1)
        x = x + self.sigma(x - x_d)
        x = self.act_value(x)
        return x

    def forward_gating(self, g, v):
        with torch.autocast(device_type='cuda', enabled=False):
            g = g.to(torch.float32)
            v = v.to(torch.float32)
            return self.proj_2(self.act_gate(g) * self.act_gate(v))

    def forward(self, x):
        shortcut = x.clone()
        # proj 1x1
        x = self.feat_decompose(x)
        # gating and value branch
        g = self.gate(x)
        v = self.value(x)
        # aggregation
        if not self.attn_force_fp32:
            x = self.proj_2(self.act_gate(g) * self.act_gate(v))
        else:
            x = self.forward_gating(self.act_gate(g), self.act_gate(v))
        x = x + shortcut
        return x


Hi! Thank you for your great work! The implementation of the MultiOrderGatedAggregation module does not match the paper: the figure in the paper has no shortcut, and FD uses GELU as its activation function. Which one should I follow?

Distributions of the interaction strength

Hi, thank you for your nice work.
Could you share the code for computing the distributions of the interaction strength, which, I believe, offers a new perspective on networks?

Cascade Mask R-CNN Configuration

Congratulations on the ICLR 2024 acceptance.

I apologize if I missed it, but I was unable to find the Cascade Mask R-CNN config file. Would it be possible to share it, or to provide a link to its location?

    What is "trivial interactions" mentioned in the paper?

    In paper, authors wrote " we propose FD(·) to dynamically exclude trivial interactions" and "By re-weighting the trivial interaction component Y − GAP(Y ), FD(·) also increase feature diversities"

    What exactly is this "trivial interactions"? And why taking Y - GAP(Y) can increase feature diversities?
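
For reference, the feat_decompose method quoted in the issue above is where this re-weighting happens. A tiny numerical sketch of the Y − GAP(Y) term: subtracting the per-channel spatial mean removes the component every position shares, so re-weighting the residual spreads features away from their average instead of letting them collapse toward it:

import torch
import torch.nn.functional as F

y = torch.randn(1, 8, 14, 14)                  # feature map Y: [B, C, H, W]
y_d = F.adaptive_avg_pool2d(y, output_size=1)  # GAP(Y): per-channel spatial mean
residual = y - y_d                             # complementary component Y - GAP(Y)
print(residual.mean(dim=(2, 3)))               # ~0 per channel: the shared part is gone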

cooldown epochs

Thank you for your great work!

As far as I know, models such as DeiT and ConvNeXt do not use "cooldown_epochs".
However, the code suggests that MogaNet was trained for 310 epochs (presumably 300 epochs plus 10 cooldown epochs) rather than 300. Were the accuracies in the paper posted on OpenReview all obtained with 310 epochs?

About pretrained models

Thank you for your work on MogaNet. When will the pretrained models be released? Thank you!

How do I write inference code?

I'd like to write inference code for MogaNet.

I'm trying to create the inference code, but I keep getting this error: "EncoderDecoder: 'MogaNet_feat is not in the models registry'".

The mmsegmentation demo code does not work for me either, so any help would be appreciated.
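
For what it's worth, "MogaNet_feat is not in the models registry" usually means the module that defines and registers the backbone was never imported. A sketch with the mmsegmentation 0.x API, where moganet_backbone stands in for whichever file in this repo's segmentation folder defines MogaNet_feat (a hypothetical name):

from mmseg.apis import inference_segmentor, init_segmentor

import moganet_backbone  # hypothetical: importing it registers MogaNet_feat with mmseg

model = init_segmentor('/path/to/config.py', '/path/to/checkpoint.pth', device='cuda:0')
result = inference_segmentor(model, '/path/to/image.jpg')  # per-pixel class predictions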
