
wavemix's Introduction

WaveMix


Resource-efficient Token Mixing for Images using 2D Discrete Wavelet Transform

WaveMix Architecture

[WaveMix architecture diagram]

WaveMix-Lite

[WaveMix-Lite block diagram]

We propose WaveMix, a novel neural architecture for computer vision that is resource-efficient yet generalizable and scalable. WaveMix networks achieve comparable or better accuracy than state-of-the-art convolutional neural networks, vision transformers, and token mixers on several tasks, establishing new benchmarks for segmentation on Cityscapes and for classification on Places-365, five EMNIST datasets, and iNAT-mini. Remarkably, WaveMix architectures require fewer parameters to reach these benchmarks than the previous state of the art. Moreover, when controlled for the number of parameters, WaveMix requires less GPU RAM, which translates into savings in time, cost, and energy. To achieve these gains we use a multi-level two-dimensional discrete wavelet transform (2D-DWT) in the WaveMix blocks, which has the following advantages: (1) it reorganizes spatial information based on three strong image priors (scale-invariance, shift-invariance, and sparseness of edges), (2) it does so losslessly and without adding parameters, (3) it reduces the spatial size of feature maps, which lowers the memory and time required for forward and backward passes, and (4) it expands the receptive field faster than convolutions do. The whole architecture is a stack of self-similar, resolution-preserving WaveMix blocks, which allows architectural flexibility for various tasks and levels of resource availability.
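
The core operation can be sketched as follows. This is a minimal, illustrative block and not the package's exact code; it assumes the pytorch_wavelets library for the 2D-DWT and uses a single decomposition level: the DWT halves the spatial resolution and yields four sub-bands per channel, a small pointwise MLP mixes them, and a parameter-free upsampling restores the input resolution, so the block stays resolution-preserving.

import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward  # assumed dependency for the 2D-DWT

class ToyWaveBlock(nn.Module):
    """Illustrative single-level WaveMix-style block (assumes even H and W)."""
    def __init__(self, dim, ff_channel, mult=2, dropout=0.5):
        super().__init__()
        self.dwt = DWTForward(J=1, wave='db1', mode='zero')   # 1-level Haar DWT
        self.mix = nn.Sequential(                              # mix the 4 sub-bands per channel
            nn.Conv2d(4 * dim, dim * mult, kernel_size=1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, ff_channel, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(ff_channel, dim, kernel_size=3, padding=1),
        )
        self.norm = nn.BatchNorm2d(dim)

    def forward(self, x):
        yl, yh = self.dwt(x)                                    # yl: (B, C, H/2, W/2); yh[0]: (B, C, 3, H/2, W/2)
        subbands = torch.cat([yl, yh[0].flatten(1, 2)], dim=1)  # (B, 4C, H/2, W/2)
        return self.norm(self.mix(subbands) + x)                # residual; output shape equals input shape

x = torch.randn(1, 64, 32, 32)
print(ToyWaveBlock(dim=64, ff_channel=128)(x).shape)            # torch.Size([1, 64, 32, 32])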

| Task                  | Dataset          | Metric            | Value         |
|-----------------------|------------------|-------------------|---------------|
| Semantic Segmentation | Cityscapes       | Single-scale mIoU | 83.04% (SOTA) |
| Image Classification  | ImageNet-1k      | Accuracy          | 75.32%        |
| Image Classification  | CIFAR-10         | Accuracy          | 95.98%        |
| Image Classification  | Galaxy 10 DECals | Accuracy          | 95.42% (SOTA) |

Parameter Efficiency

| Task                          | Model             | Parameters |
|-------------------------------|-------------------|------------|
| 99% accuracy on MNIST         | WaveMix Lite-8/10 | 3,566      |
| 90% accuracy on Fashion MNIST | WaveMix Lite-8/5  | 7,156      |
| 80% accuracy on CIFAR-10      | WaveMix Lite-32/7 | 37,058     |
| 90% accuracy on CIFAR-10      | WaveMix Lite-64/6 | 520,106    |

The high parameter efficiency is obtained by replacing deconvolution (transposed convolution) layers with parameter-free upsampling layers.
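
To illustrate the point (a sketch, not the package's code): a learned 2x transposed convolution ("deconvolution") carries kernel_size x kernel_size x in_channels x out_channels weights plus biases, whereas nn.Upsample interpolates with no learnable parameters at all.

import torch.nn as nn

deconv = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)     # learned 2x upsampling
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # parameter-free 2x upsampling

print(sum(p.numel() for p in deconv.parameters()))    # 262272 (128*128*4*4 weights + 128 biases)
print(sum(p.numel() for p in upsample.parameters()))  # 0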

This repository implements the models from the following papers: the OpenReview paper, arXiv paper 1, and arXiv paper 2 (full citations below).

Install

$ pip install wavemix

Usage

Semantic Segmentation

import torch
from wavemix.SemSegment import WaveMix

model = WaveMix(
    num_classes = 20,
    depth = 16,
    mult = 2,
    ff_channel = 256,
    final_dim = 256,
    dropout = 0.5,
    level = 4,
    stride = 2
)

img = torch.randn(1, 3, 256, 256)

preds = model(img) # (1, 20, 256, 256)

Image Classification

import torch
from wavemix.classification import WaveMix

model = WaveMix(
    num_classes = 1000,
    depth = 16,
    mult = 2,
    ff_channel = 192,
    final_dim = 192,
    dropout = 0.5,
    level = 3,
    patch_size = 4
)
img = torch.randn(1, 3, 256, 256)

preds = model(img) # (1, 1000)

Single Image Super-resolution

import wavemix, torch
from wavemix.sisr import WaveMix

model = WaveMix(
    depth = 4,
    mult = 2,
    ff_channel = 144,
    final_dim = 144,
    dropout = 0.5,
    level=1,
)

img = torch.randn(1, 3, 256, 256)
out = model(img) # (1, 3, 512, 512)

To use a single Waveblock

import wavemix, torch
from wavemix import Level1Waveblock
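
A minimal usage sketch follows; the constructor arguments are assumed to mirror the block-level parameters documented below, so check the package source for the exact signature and defaults.

block = Level1Waveblock(mult=2, ff_channel=144, final_dim=144, dropout=0.3)  # assumed signature

x = torch.randn(1, 144, 64, 64)  # channel dimension should match final_dim
out = block(x)                   # (1, 144, 64, 64): the block is resolution-preserving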

Parameters

  • num_classes: int.
    Number of classes to classify/segment.
  • depth: int.
    Number of WaveMix blocks.
  • mult: int.
    Expansion of channels in the MLP (FeedForward) layer.
  • ff_channel: int.
    Number of output channels from the MLP (FeedForward) layer.
  • final_dim: int.
    Channel dimension of the tensor after the initial convolutional layers, i.e., the channel dimension fed to the Waveblocks.
  • dropout: float in [0, 1], default 0.
    Dropout rate.
  • level: int.
    Number of levels of the 2D wavelet transform used in the Waveblocks. Currently supports levels 1 to 4.
  • stride: int.
    Stride used in the initial convolutional layers to reduce the input resolution before it is fed to the Waveblocks.
  • initial_conv: str.
    Chooses between strided and patchifying convolutions in the initial convolutional layer; accepts 'strided' or 'pachify'. Used for classification; see the example after this list.
  • patch_size: int.
    Size of each non-overlapping patch when the patchifying convolution is used. Should be a multiple of 4.
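
For example, a small classifier that uses strided initial convolutions instead of patchifying ones could be configured as below. This is a sketch; the exact set of accepted arguments and their defaults may differ between the task-specific model classes.

from wavemix.classification import WaveMix

model = WaveMix(
    num_classes = 10,
    depth = 7,
    mult = 2,
    ff_channel = 128,
    final_dim = 128,
    dropout = 0.5,
    level = 2,
    initial_conv = 'strided',  # strided convolutions instead of patchifying ('pachify')
    stride = 2
)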

Cite the following papers

@misc{p2022wavemix,
    title={WaveMix: Multi-Resolution Token Mixing for Images},
    author={Pranav Jeevan P and Amit Sethi},
    year={2022},
    url={https://openreview.net/forum?id=tBoSm4hUWV}
}

@misc{jeevan2022wavemix,
    title={WaveMix: Resource-efficient Token Mixing for Images},
    author={Pranav Jeevan and Amit Sethi},
    year={2022},
    eprint={2203.03689},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

@misc{jeevan2023wavemix,
      title={WaveMix: A Resource-efficient Neural Network for Image Analysis}, 
      author={Pranav Jeevan and Kavitha Viswanathan and Anandu A S and Amit Sethi},
      year={2023},
      eprint={2205.14375},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

wavemix's People

Contributors

pranavphoenix

wavemix's Issues

How to use this model on Tiny ImageNet for classification

I am a computer vision student who is thinking about using your WaveMix model for my final project. I am planning to use it on the Tiny ImageNet dataset for classification and model evaluation. However, I have to run all my deep learning code on Google Colab, since I do not have an independent and powerful GPU.

So, I have a few questions regarding using the model:

  1. How can I load the Tiny ImageNet dataset in Google Colab? Can you provide Python code for it?
  2. The README says that WaveMix reduces GPU resource usage. Is the T4 GPU on Google Colab enough for training on Tiny ImageNet?
  3. Where is the code for Tiny ImageNet classification? Is there a pre-trained model for it?
  4. In case I want to retrain the model, where is the code for training the model?

How do you actually use this?

Calling model(img) on an arbitrary image resized to 256x256 and unsqueezed to the expected shape (1, 3, 256, 256) does not actually work. What else is the image supposed to go through before it is given to the model for inference? Very frustrating.
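
Roughly the preprocessing described above, as a sketch using torchvision (the file name is a placeholder and no dataset-specific normalization is applied):

from PIL import Image
import torch
from torchvision import transforms

tfm = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),        # float tensor of shape (3, 256, 256) in [0, 1]
])
img = tfm(Image.open('example.jpg').convert('RGB')).unsqueeze(0)  # (1, 3, 256, 256)
preds = model(img)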

about upsampling

In much of the literature, the DWT and IDWT go hand in hand.
Could you please tell me why your paper uses transposed convolution instead of the IDWT for upsampling?

use pretrained Places365, The results are the same for all images

Hello author, I tried to use your pretrained model for the Places365 dataset:

import torch, wavemix
from wavemix.classification import WaveMix

model = WaveMix(
    num_classes = 365,
    depth = 12,
    mult = 2,
    ff_channel = 256,
    final_dim = 256,
    dropout = 0.5,
    level = 2,
    initial_conv = 'pachify',
    patch_size = 8
)

url = 'https://huggingface.co/cloudwalker/wavemix/resolve/main/Saved_Models_Weights/Places365/places365_54.94.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(url))

When I give many different images to the model, I always get the same result:
top_values, top_indices = torch.topk(preds, 5)
print("top_values:", top_values)
print("top_indices :", top_indices)

top_values: tensor([[2.0731, 1.9153, 1.7019, 1.5919, 1.5876]], device='cuda:0')
top_indices: tensor([[ 12, 67, 270, 103, 317]], device='cuda:0')
Did I do something wrong at any step?
