Comments (8)
This is probably due to the data augmentation. Check out the "augmented" training data: those images are actually a lot harder to recognize, which is why the training accuracy is lower. Training on them still helps the network score better on the test set, though. You can modify the Colab to train without data augmentation, and you should see that the train accuracy then comes out higher than the test accuracy.
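A minimal sketch of that check (PyTorch here for concreteness; `model` stands for the already-trained network, and any normalization used at training time should be applied here as well):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Evaluate on the *un-augmented* training images: no crops, no flips.
eval_tf = transforms.Compose([transforms.ToTensor()])
clean_train = datasets.CIFAR10("./data", train=True, download=True, transform=eval_tf)
loader = DataLoader(clean_train, batch_size=256)

model.eval()  # `model` is assumed to be the trained network
correct = 0
with torch.no_grad():
    for images, labels in loader:
        correct += (model(images).argmax(dim=1) == labels).sum().item()
print("train accuracy without augmentation:", correct / len(clean_train))

If augmentation is the cause, the accuracy measured this way should again be above the test accuracy.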
---
@andsteing but why isn't there such a phenomenon when training with ResNet/MobileNet, which also use data augmentation?
---
I found that when I use 0.5 for the mean and std, as follows, the test accuracy is higher (96%) while the training accuracy is very low (82%); but when I change it to the real mean and std, the test accuracy drops to 93% while the training accuracy is a good 95%.
# Real per-channel CIFAR-10 statistics:
# CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
# CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
CIFAR_MEAN = [0.5, 0.5, 0.5]
CIFAR_STD = [0.5, 0.5, 0.5]
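For reference, the commented-out "real" values are the per-channel statistics of the CIFAR-10 training set; a small sketch to reproduce them (the download path is arbitrary):

from torchvision import datasets

train = datasets.CIFAR10("./data", train=True, download=True)
pixels = train.data / 255.0   # uint8 array of shape (50000, 32, 32, 3)
print("mean:", pixels.mean(axis=(0, 1, 2)))  # ~ [0.4914, 0.4822, 0.4465]
print("std: ", pixels.std(axis=(0, 1, 2)))   # ~ [0.2470, 0.2435, 0.2616]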
---
Did you use ResNet/MobileNet with the same input pipeline? (Using data augmentation on different datasets with different parameters will produce different results.)
I just re-ran the Colab with data augmentation disabled and got 0.97265625 on the test set and 0.979321 on the train set (i.e. slightly higher on the train set, as expected).
Where exactly did you modify CIFAR_MEAN and CIFAR_STD?
---
My code is based on my own PyTorch implementation. I fixed the issue of low training accuracy by changing the transforms from
transform_train = transforms.Compose([
    # RandomSizedCrop is torchvision's deprecated alias of RandomResizedCrop;
    # scale=(0.05, 1.0) lets a crop cover as little as 5% of the image area.
    transforms.RandomSizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
to
transform_train = transforms.Compose([
    transforms.Resize(args.img_size),       # resize without aggressive rescaling
    transforms.RandomCrop(args.img_size),   # random square crop at full scale
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
This confirms that the issue comes from RandomSizedCrop.
Also, the 0.5 normalization increases the accuracy by around 3%.
But I still only get 96.5% test accuracy and cannot reproduce the 98% test accuracy. May I know what other data augmentation or tricks I should use?
---
Indeed, I think the difference observed in your snippet has more to do with RandomSizedCrop() than with Normalize().
It might be interesting to see whether this difference is larger for ViT than for ResNet/MobileNet, but I would expect ResNet/MobileNet to also show a deterioration in training accuracy if you apply RandomSizedCrop() with the same parameters.
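A short sketch of what "the same parameters" means here (RandomResizedCrop is the current torchvision name; the size 224 is just an example):

from torchvision import transforms

# scale is the fraction of the original image *area* the crop may cover, so
# scale=(0.05, 1.0) allows crops containing as little as 5% of the image.
aggressive = transforms.RandomResizedCrop(224, scale=(0.05, 1.0))  # as in the snippet above
default    = transforms.RandomResizedCrop(224)                     # default scale=(0.08, 1.0)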
---
Could you help check my code? It is a single file that contains the training code and tricks:
https://github.com/FedML-AI/FedML/blob/fed-transformer/fedml_experiments/centralized/fed_transformer/main_vit.py
I am not sure whether my reimplementation is correct.
---
> I would expect for ResNet/MobileNet also to have a deterioration in training accuracy if you apply RandomSizedCrop() with the same parameters.
I have never experienced such a phenomenon before; I don't think ResNet has this issue when using RandomSizedCrop().