
Sliced Recursive Transformer (SReT)

PyTorch implementation of our paper Sliced Recursive Transformer (ECCV 2022) by Zhiqiang Shen, Zechun Liu and Eric Xing.

[Figure: FLOPs and Params Comparison]

Our Approach

  • Recursion operation: [figure]
  • Sliced Group Self-Attention: [figure]
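The recursion operation shares weights across depth: one block is applied more than once, so the network gets deeper without gaining parameters. A minimal PyTorch sketch of that idea follows. It is not the repository's implementation: `RecursiveBlock` is a hypothetical name, `nn.TransformerEncoderLayer` stands in for a ViT block (an assumption), and any extra components the paper places between recursions are omitted.

```python
import torch
import torch.nn as nn


class RecursiveBlock(nn.Module):
    """Hypothetical sketch: apply one transformer block `num_loops` times.

    Every loop reuses the same weights, so depth grows while the parameter
    count stays that of a single block.
    """

    def __init__(self, dim: int, num_heads: int, num_loops: int = 2):
        super().__init__()
        # Stand-in for a ViT block (assumption, not the SReT block itself).
        self.block = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=num_heads,
            dim_feedforward=4 * dim,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.num_loops = num_loops

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _ in range(self.num_loops):
            x = self.block(x)  # same parameters on every pass
        return x


if __name__ == "__main__":
    count = lambda m: sum(p.numel() for p in m.parameters())
    single = RecursiveBlock(dim=64, num_heads=4, num_loops=1)
    looped = RecursiveBlock(dim=64, num_heads=4, num_loops=2)
    print(count(single) == count(looped))  # True
    print(looped(torch.randn(2, 197, 64)).shape)  # torch.Size([2, 197, 64])
```

Doubling `num_loops` doubles the compute of this stage but leaves the parameter count unchanged, which is the trade-off the sliced attention below is meant to soften.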

Abstract

We present a neat yet effective recursive operation on vision transformers that can improve parameter utilization without adding parameters. This is achieved by sharing weights across the depth of transformer networks. The proposed method can obtain a substantial gain of about 2% simply by using a naive recursive operation, requires no special or sophisticated knowledge of network design principles, and introduces minimal computational overhead to the training procedure. To reduce the additional computation caused by the recursive operation while maintaining its superior accuracy, we propose an approximation based on multiple sliced group self-attentions across recursive layers, which reduces computational cost by 10-30% with minimal performance loss. We call our model Sliced Recursive Transformer (SReT), a novel and parameter-efficient vision transformer design that is compatible with a broad range of other designs for efficient ViT architectures. Our best model achieves a significant improvement over state-of-the-art methods on ImageNet-1K while containing fewer parameters. Its flexible scalability shows great potential for scaling up models and constructing extremely deep vision transformers.
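The idea behind sliced group self-attention can be illustrated by restricting attention to groups of tokens. With N tokens split into G equal slices, the attention maps hold G·(N/G)² = N²/G entries instead of N². The per-token projections and MLP layers cost the same either way, which is why the whole-model saving is much smaller than a factor of G. The sketch below is hypothetical: `SlicedGroupSelfAttention` is an illustrative name, contiguous slicing is an assumption, and the paper's exact slicing scheme across recursive layers may differ.

```python
import torch
import torch.nn as nn


class SlicedGroupSelfAttention(nn.Module):
    """Hypothetical sketch: self-attention computed inside token groups only.

    Tokens are split into `groups` contiguous slices; each slice attends
    only to itself, so attention maps shrink from N^2 to N^2 / groups entries.
    """

    def __init__(self, dim: int, num_heads: int, groups: int):
        super().__init__()
        self.groups = groups
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        if n % self.groups:
            raise ValueError("token count must be divisible by groups")
        # (B, N, D) -> (B * G, N / G, D): each slice becomes its own batch item.
        xg = x.reshape(b * self.groups, n // self.groups, d)
        out, _ = self.attn(xg, xg, xg, need_weights=False)
        # Merge the slices back into the original token order.
        return out.reshape(b, n, d)
```

With `groups=1` this reduces to ordinary multi-head self-attention over all tokens.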

SReT Models

Install timm using:

pip install git+https://github.com/rwightman/pytorch-image-models.git

Create SReT models:

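import torch
import SReT

# Build SReT-S with randomly initialized weights
model = SReT.SReT_S(pretrained=False)
# Run a forward pass on one dummy 224x224 RGB image
print(model(torch.randn(1, 3, 224, 224)))
...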

Load pre-trained SReT models:

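import torch
import SReT

model = SReT.SReT_S(pretrained=False)
# The checkpoint stores the weights under the 'model' key;
# map_location='cpu' also allows loading on machines without a GPU
checkpoint = torch.load('./pre-trained/SReT_S.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()  # switch dropout and similar layers to inference mode
with torch.no_grad():
    print(model(torch.randn(1, 3, 224, 224)))
...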

Train SReT models with knowledge distillation (we recommend training with FKD, which is faster and achieves higher performance):

import torch
import SReT
import kd_loss

criterion_kd = kd_loss.KDLoss()

model = SReT.SReT_S_distill(pretrained=False)
# 'images' (an input batch), 'teacher_outputs' (the teacher's logits for that batch)
# and 'T' (the distillation temperature) are defined elsewhere in the training loop
student_outputs = model(images)
...
# As in MEAL V2, only the teacher's soft labels are used for distillation.
# Note that 'student_outputs' and 'teacher_outputs' are logits before softmax.
loss = criterion_kd(student_outputs/T, teacher_outputs/T)
...
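`kd_loss.KDLoss` is the repository's own module and its code is not shown here. As a hedged illustration only, the sketch below implements one common soft-label distillation loss: the cross-entropy between the teacher's distribution and the student's log-probabilities, averaged over the batch. `SoftTargetKDLoss` is a hypothetical name. If the repository uses a KL divergence instead, the two differ only by the teacher's entropy, which does not depend on the student and so yields the same gradients.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTargetKDLoss(nn.Module):
    """Hypothetical stand-in for kd_loss.KDLoss (not the repository's code).

    Both inputs are logits that the caller has already divided by the
    temperature T; only the teacher's soft labels are used as targets.
    """

    def forward(self, student_logits: torch.Tensor,
                teacher_logits: torch.Tensor) -> torch.Tensor:
        # The teacher is fixed: no gradient flows into its logits.
        soft_targets = F.softmax(teacher_logits.detach(), dim=1)
        log_probs = F.log_softmax(student_logits, dim=1)
        # Per-sample soft cross-entropy, averaged over the batch.
        return -(soft_targets * log_probs).sum(dim=1).mean()
```

It would be called the same way as in the training snippet: `criterion_kd(student_outputs/T, teacher_outputs/T)`.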

Pre-trained Models

We currently provide the last-epoch checkpoints and will add the best ones, together with more models, soon. (⋇ indicates without slice.) We notice that a larger initial learning rate (0.001 × batch_size / 512) combined with a longer warmup of 30 epochs can give better results on SReT.

Model                        FLOPs  #Params  Accuracy (%)  Weights (last)  Weights (best)  Logs  Configuration
SReT_⋇T                      1.4G   4.8M     76.1          link            TBA             link  link
SReT_T                       1.1G   4.8M     76.0          link            TBA             link  link
SReT_⋇LT                     1.4G   5.0M     76.8          link            TBA             link  link
SReT_LT [8-4-1,2-1-1]        1.2G   5.0M     76.7          link            TBA             link  link
SReT_LT [16-14-1,1-1-1]      1.2G   5.0M     76.6          link            TBA             link  link
SReT_⋇S                      4.7G   20.9M    82.0          link            TBA             link  link
SReT_S                       4.2G   20.9M    81.9          link            TBA             link  link
SReT_⋇T_Distill              1.4G   4.8M     77.7          link            TBA             link  link
SReT_T_Distill               1.1G   4.8M     77.6          link            TBA             link  link
SReT_⋇LT_Distill             1.4G   5.0M     77.9          link            TBA             link  link
SReT_LT_Distill              1.2G   5.0M     77.7          link            TBA             link  link
SReT_⋇T_Distill_Finetune384  6.4G   4.9M     79.7          link            TBA             link  link
SReT_⋇S_Distill_Finetune384  18.5G  21.0M    83.8          link            TBA             link  link
SReT_⋇S_Distill_Finetune512  42.8G  21.3M    84.3          link            TBA             link  link
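The learning-rate note before the table (initial lr = 0.001 × batch_size / 512, with 30 warmup epochs) can be turned into a per-epoch schedule as sketched here. Only the scaling rule and the warmup length come from this README; the linear warmup shape, the cosine decay afterwards, `total_epochs=300` and `min_lr` are assumptions, and `sret_lr` is a hypothetical helper.

```python
import math


def sret_lr(epoch: float, batch_size: int, total_epochs: int = 300,
            warmup_epochs: int = 30, base_lr_per_512: float = 1e-3,
            min_lr: float = 1e-5) -> float:
    """Hypothetical schedule: linear warmup, then cosine decay.

    Peak lr follows the README's rule 0.001 x batch_size / 512; everything
    else (warmup shape, cosine decay, total_epochs, min_lr) is assumed.
    """
    peak_lr = base_lr_per_512 * batch_size / 512
    if epoch < warmup_epochs:
        # Ramp linearly from 0 to the peak over the warmup epochs.
        return peak_lr * epoch / warmup_epochs
    # Cosine decay from peak_lr down to min_lr over the remaining epochs.
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

For example, a global batch size of 1024 gives a peak learning rate of 0.002, reached at epoch 30.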

Citation

If you find our code helpful for your research, please cite:

@article{shen2021sliced,
      title={Sliced Recursive Transformer}, 
      author={Zhiqiang Shen and Zechun Liu and Eric Xing},
      year={2021},
      journal={arXiv preprint arXiv:2111.05297}
}

Contact

Zhiqiang Shen (zhiqiangshen0214 at gmail.com or zhiqians at andrew.cmu.edu)
