Comments (9)
In the vanilla ViT, the forward pass takes an image and returns its transformed output. In the distilled ViT, the DistillWrapper forward pass takes an image plus labels and returns the loss. FastAI gets tripped up by this because it automatically supplies the model with labels and expects the forward pass to return a transformed output. If there were a way to decouple the forward pass from the loss calculation, I suspect it would work.
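To make the mismatch concrete, here's a minimal sketch with toy stand-in modules (these are not the real vit-pytorch classes, just linear layers with the same forward-pass shapes):

```python
import torch
import torch.nn as nn

# Toy stand-ins, just to illustrate the API mismatch described above.
class VanillaViTLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(8, 3)

    def forward(self, img):                  # image -> logits
        return self.head(img)

class DistillWrapperLike(nn.Module):
    def __init__(self, student):
        super().__init__()
        self.student = student

    def forward(self, img, labels):          # image + labels -> scalar loss
        return nn.functional.cross_entropy(self.student(img), labels)

x = torch.randn(4, 8)
y = torch.randint(0, 3, (4,))
vit = VanillaViTLike()
print(vit(x).shape)                          # torch.Size([4, 3])
print(DistillWrapperLike(vit)(x, y).shape)   # torch.Size([])
```

FastAI's training loop calls `model(x)` and then `loss_func(pred, y)`, so the second signature doesn't fit.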
from vit-pytorch.
@lwomalley Hi Logan! I know of FastAI, but I'm not too familiar with their API. What would it take to be compatible?
@lwomalley yup, but the problem is that the distillation comes with an auxiliary loss that gets returned. Will FastAI know to add this to the main loss it calculates?
Don't know if this helps, but are "vit" and "distilled vit" layers? What is the auxiliary loss, and how would it be used in a normal training loop?
This is where the loss is calculated, but AFAIK we have the option of callbacks, and we can also "replace" this method on the Learner itself with a new implementation/different code.
https://github.com/fastai/fastai/blob/master/fastai/learner.py#L172-L173
@tyoc213 yeah, that won't work as-is, because that line would also need to add the distillation loss https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/distill.py#L120 I could return the logits as the first element of a tuple and the auxiliary loss as the second, but FastAI would still need to sum the auxiliary loss into the loss it calculates.
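The tuple-return idea might look roughly like this (toy stand-in modules, not the real DistillWrapper; the KL soft-label term is just one plausible form of the auxiliary loss):

```python
import torch
import torch.nn as nn

# Sketch of returning (logits, auxiliary_loss) from the forward pass.
class TupleDistillLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.student = nn.Linear(8, 3)
        self.teacher = nn.Linear(8, 3)

    def forward(self, img):
        logits = self.student(img)
        with torch.no_grad():
            teacher_logits = self.teacher(img)
        aux = nn.functional.kl_div(logits.log_softmax(dim=-1),
                                   teacher_logits.softmax(dim=-1),
                                   reduction='batchmean')
        return logits, aux                    # (main output, auxiliary loss)

x = torch.randn(4, 8)
y = torch.randint(0, 3, (4,))
logits, aux = TupleDistillLike()(x)
main = nn.functional.cross_entropy(logits, y)
total = main + 0.5 * aux                      # the summing step FastAI would have to do
total.backward()
```

The last two lines are exactly the part a stock FastAI loop doesn't know to perform, which is the sticking point.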
mmmm, so the forward is this https://github.com/fastai/fastai/blob/master/fastai/learner.py#L169 in fastai, and it returns a tuple IIRC. We can then apply whatever we want with a transform at https://github.com/fastai/fastai/blob/master/fastai/learner.py#L174 which is "call the transforms that answer to the after_loss event". All the transforms have access to the Learner and other things, see https://docs.fast.ai/callback.core.html#Callback
after_loss: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).
@tyoc213 I see, so I'd have to store the auxiliary loss on the instance somewhere, and then in the callback it would be fetched and added to the main loss?
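That store-then-fetch pattern could be sketched like this in plain PyTorch (toy modules and a hypothetical `DistillAdapterSketch` name, not real vit-pytorch or fastai code):

```python
import torch
import torch.nn as nn

class DistillAdapterSketch(nn.Module):
    """Hypothetical adapter: forward returns logits only (what fastai
    expects), stashing the auxiliary distillation term on the instance."""
    def __init__(self, student, teacher, alpha=0.5):
        super().__init__()
        self.student, self.teacher, self.alpha = student, teacher, alpha
        self.aux_loss = None

    def forward(self, img):
        logits = self.student(img)
        with torch.no_grad():
            teacher_logits = self.teacher(img)
        # stash the soft-label term for an after_loss-style callback to fetch
        self.aux_loss = nn.functional.kl_div(logits.log_softmax(dim=-1),
                                             teacher_logits.softmax(dim=-1),
                                             reduction='batchmean')
        return logits

# What the after_loss callback would then do, in plain PyTorch terms:
model = DistillAdapterSketch(nn.Linear(8, 3), nn.Linear(8, 3))
x, y = torch.randn(4, 8), torch.randint(0, 3, (4,))
loss = nn.functional.cross_entropy(model(x), y)   # fastai computes this
loss = loss + model.alpha * model.aux_loss        # the callback adds this
loss.backward()
```

In an actual fastai Callback, the addition would live in an `after_loss` method that mutates the Learner's loss attribute; the exact attribute name varies across fastai versions, so check it against the learner.py source linked above.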
@tyoc213 do you have any code examples of how you are using ViT with FastAI?
No, but maybe this tiny bit can help: https://youtu.be/4w3sEgqDvSo?t=1148
Related Issues (20)
- Potential regression with PT 2.0 and CUDA 12.2/CuDNN 8.9.4 HOT 1
- Not correctly understanding the Multi Head Attention part of the ViT implementation... HOT 3
- CvT with 1 channel input data HOT 2
- Layernorm in Cross attention HOT 4
- how to train HOT 2
- Questions about distill_loss HOT 1
- Trouble loading ViT - Dino structure for channels>3?
- Question regarding 1d fft use HOT 1
- Masking attention with batches
- can we use CvT model for segmentation?
- Multi-target Regression Question
- Problems regarding training 3D Vision transformer : model does not converge
- Add implementation of LongVit HOT 4
- A question with ViT 3d
- Cuda memory for 3D VIT HOT 2
- PyPi page markdown render HOT 1
- CrossViT does not handle other than three channel images HOT 2
- Non-deterministic results based on group_max_seq_len in NaViT HOT 3
- Whether to include pre-trained models
- Request for Pre-trained Weights for Vit