
Comments (8)

andsteing avatar andsteing commented on July 20, 2024

This is probably due to the data augmentation. Check out the "augmented" training data:

https://colab.sandbox.google.com/github/google-research/vision_transformer/blob/master/vit_jax.ipynb#scrollTo=jFqi3h7yMEsB

You can see that those images actually are a lot harder to recognize and thus the training accuracy is lower. Training on these still helps the network to score better on the testset though. You can modify the Colab to train without data augmentation, and you should see how the train accuracy is better than the test accuracy.
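The effect can be seen with a toy simulation (illustrative numbers only, not the actual training pipeline): a fixed "model" scores lower when its accuracy is measured on heavily augmented views of the inputs than on clean ones, even though the model itself never changed.

```python
# Toy simulation: the same model scores lower on augmented training
# inputs than on clean test inputs, because the augmented views are
# simply harder to recognize. All numbers here are illustrative.
import random

random.seed(0)

def model_correct(corruption):
    # A stand-in model: it classifies correctly while corruption < 0.5.
    return corruption < 0.5

def accuracy(corruptions):
    return sum(model_correct(c) for c in corruptions) / len(corruptions)

# Augmented pipeline: crops can remove most of the image, so the
# effective "corruption" is drawn from a wide range.
train_views = [random.uniform(0.0, 0.95) for _ in range(10_000)]
# Clean eval pipeline: little to no corruption.
test_views = [random.uniform(0.0, 0.1) for _ in range(10_000)]

print(f"train (augmented) accuracy ~ {accuracy(train_views):.2f}")
print(f"test  (clean)     accuracy ~ {accuracy(test_views):.2f}")
```

Measuring train accuracy through the clean eval pipeline instead would remove the apparent gap.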

from vision_transformer.

chaoyanghe avatar chaoyanghe commented on July 20, 2024

@andsteing But why doesn't such a phenomenon occur when training ResNet/MobileNet, which also use data augmentation?


chaoyanghe avatar chaoyanghe commented on July 20, 2024

I found that when I use 0.5 for the mean and std as follows, the test accuracy is higher (96%) but the training accuracy is very low (82%); when I change it to the real mean and std, the test accuracy drops to 93% while the training accuracy is good at 95%.

# CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
# CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

CIFAR_MEAN = [0.5, 0.5, 0.5]
CIFAR_STD = [0.5, 0.5, 0.5]
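For reference, the "real" statistics above are the per-channel mean and std of all CIFAR-10 training pixels scaled to [0, 1]. A dependency-free sketch of that computation (the tiny fake batch below stands in for the actual dataset):

```python
# Sketch: per-channel mean/std as used for Normalize(CIFAR_MEAN, CIFAR_STD).
# On the real CIFAR-10 training set this yields roughly
# mean=[0.491, 0.482, 0.447] and std=[0.247, 0.243, 0.262].
import math

def channel_stats(images):
    """images: list of HxWx3 nested lists with values in [0, 1]."""
    sums = [0.0, 0.0, 0.0]
    sq_sums = [0.0, 0.0, 0.0]
    n = 0  # pixels per channel
    for img in images:
        for row in img:
            for px in row:
                for c in range(3):
                    sums[c] += px[c]
                    sq_sums[c] += px[c] ** 2
                n += 1
    mean = [s / n for s in sums]
    std = [math.sqrt(sq / n - m ** 2) for sq, m in zip(sq_sums, mean)]
    return mean, std

# A tiny fake batch of one 2x2 "image" for illustration only.
fake = [[[[0.2, 0.4, 0.6], [0.4, 0.4, 0.6]],
         [[0.6, 0.4, 0.6], [0.8, 0.4, 0.6]]]]
mean, std = channel_stats(fake)
print([round(m, 6) for m in mean])  # [0.5, 0.4, 0.6]
```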


andsteing avatar andsteing commented on July 20, 2024

Did you use ResNet/MobileNet on the same input pipeline? (Using data augmentation on different datasets with different parameters will produce different results)

I just re-ran the Colab with data augmentation disabled and got 0.97265625 on test set and 0.979321 on train set (i.e. slightly higher on train set, as expected).

Where exactly did you modify CIFAR_MEAN and CIFAR_STD?


chaoyanghe avatar chaoyanghe commented on July 20, 2024

My code is based on my own PyTorch implementation. I fixed the low training accuracy by changing the transforms from

transform_train = transforms.Compose([
    transforms.RandomSizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

to

transform_train = transforms.Compose([
    transforms.Resize(args.img_size),
    transforms.RandomCrop(args.img_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

This confirms that the issue stems from RandomSizedCrop.
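For intuition on why that transform is so aggressive: RandomSizedCrop (renamed RandomResizedCrop in later torchvision versions) samples the crop's target area as a uniform fraction of the original image area within scale, so with scale=(0.05, 1.0) a crop can retain as little as 5% of the image. A dependency-free sketch of just the area sampling (the aspect-ratio jitter is omitted):

```python
# Sketch: how much of the image survives RandomSizedCrop(scale=(0.05, 1.0)).
# The crop's target area is drawn uniformly from scale[0]*area to
# scale[1]*area; we report it here as a fraction of the original image.
import random

random.seed(0)

def sample_crop_area_fraction(scale=(0.05, 1.0)):
    return random.uniform(*scale)

fractions = [sample_crop_area_fraction() for _ in range(100_000)]
print(f"min retained area:  {min(fractions):.3f}")                    # near 0.05
print(f"mean retained area: {sum(fractions) / len(fractions):.3f}")   # near 0.525
```

On average roughly half the image is discarded, and in the worst case 95% of it, which explains why accuracy measured on these augmented views is so much lower.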

Also, the 0.5 normalization increases the accuracy by around 3%.

But I still only get 96.5% test accuracy and cannot replicate the 98% test accuracy. May I know what other data augmentation or tricks I should use?


andsteing avatar andsteing commented on July 20, 2024

Indeed, I think in your snippet the observed difference has more to do with RandomSizedCrop() than with Normalize().

It might be interesting to see whether this difference is larger for ViT than for ResNet/MobileNet, but I would expect ResNet/MobileNet to also show a deterioration in training accuracy if you apply RandomSizedCrop() with the same parameters.


chaoyanghe avatar chaoyanghe commented on July 20, 2024

Could you help check my code? It is a single file containing the training loop and tricks.
https://github.com/FedML-AI/FedML/blob/fed-transformer/fedml_experiments/centralized/fed_transformer/main_vit.py

I am not sure whether my reimplementation is correct or not.


chaoyanghe avatar chaoyanghe commented on July 20, 2024

> Indeed, I think in your snippet the observed difference has more to do with RandomSizedCrop() than with Normalize().
>
> It might be interesting to see whether this difference is larger for ViT than for ResNet/MobileNet, but I would expect ResNet/MobileNet to also show a deterioration in training accuracy if you apply RandomSizedCrop() with the same parameters.

I never experienced such a phenomenon before. I don't think ResNet has such an issue when using RandomSizedCrop().

