Comments (4)
@Zhu-haow
When I drop part of the pretrained model weights, for example the projection 'embedding' layer (between the R50 and the Transformer Encoder)
Actually, the weights depend on the architecture, so if you remove any intermediate layer, the rest of the network becomes "useless" and needs to be retrained.
from vision_transformer.
In fact, I keep all layers. I only discard the pretrained weights of a single conv layer and randomly re-initialize that conv. Nothing else changes. But when I train this model, I get significantly lower accuracy (10 percent).
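The change described here (keep every layer, but discard the pretrained weights of one conv and re-initialize it randomly) can be sketched on a flat dictionary of NumPy arrays. This is a minimal sketch under stated assumptions: it assumes the checkpoint is an `.npz` file with '/'-joined parameter names such as `embedding/kernel`, loadable with something like `dict(np.load(path))`. The toy shapes below are made up, the helper `reinit_layer` is hypothetical, and the N(0, 1/fan_in) draw is a simplified stand-in for LeCun-normal initialization, not necessarily what the repository uses.

```python
import numpy as np


def reinit_layer(params, prefix, seed=0):
    """Hypothetical helper: copy `params`, re-initializing arrays under `prefix/`.

    Kernels get N(0, 1/fan_in) (a simplified LeCun-normal init, an assumption);
    every other array under the prefix (e.g. the bias) is set to zeros.
    Arrays outside the prefix are copied unchanged, so all other
    pretrained weights are kept.
    """
    rng = np.random.default_rng(seed)
    new_params = {}
    for key, value in params.items():
        if not key.startswith(prefix + '/'):
            new_params[key] = value.copy()  # keep pretrained weight
        elif key.endswith('/kernel'):
            fan_in = int(np.prod(value.shape[:-1]))  # HWIO conv kernel: H * W * in
            std = 1.0 / np.sqrt(fan_in)
            new_params[key] = rng.normal(0.0, std, value.shape).astype(value.dtype)
        else:
            new_params[key] = np.zeros_like(value)  # e.g. embedding/bias
    return new_params


# Hypothetical toy "checkpoint" using ViT-style '/'-joined keys.
params = {
    'embedding/kernel': np.ones((1, 1, 8, 4), np.float32),
    'embedding/bias': np.ones((4,), np.float32),
    'Transformer/encoder_norm/scale': np.ones((4,), np.float32),
}
new_params = reinit_layer(params, 'embedding')
```

Only the arrays under `embedding/` change; everything else is an exact copy of the pretrained values, which matches the "nothing changes except this" setup.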
@Zhu-haow Hi Zhu, I'm new to Flax and JAX and am having trouble fine-tuning the pretrained model on a self-defined dataset. Would you mind sharing how to build a self-defined dataset and how to do the fine-tuning?
I have tried the command in the main README, but the job keeps getting killed. I'm kind of stuck now... It would be great to learn from you. Many thanks!
Hope this helps. You can use the code below to load your data:
```python
import pathlib
import tensorflow as tf

# Placeholder root folder with one subfolder per class, e.g. dataset/cat/1.jpg
data_path = pathlib.Path('path/to/your/dataset')
all_image_paths = list(data_path.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]  # list of all image paths
label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())  # class names = subfolder names
label_to_index = dict((name, index) for index, name in enumerate(label_names))
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
```
Then use `data = data.map(_pp, tf.data.experimental.AUTOTUNE)`, where `data` is your dataset (`ds` above), to map your data through the preprocessing function `_pp`.
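`_pp` itself is not shown in this comment. As a rough illustration, a hypothetical stand-in for `_pp` that works directly on the (path, label) pairs built above could decode each file, resize it, and rescale the pixels. This is not the repository's own function; the 224x224 size and the [-1, 1] scaling are assumptions to check against the checkpoint and the repository's preprocessing. A toy two-class folder is written to a temporary directory so the sketch runs end to end.

```python
import os
import pathlib
import tempfile

import tensorflow as tf

IMAGE_SIZE = 224  # assumption: fine-tuning resolution


def _pp(path, label):
    """Hypothetical preprocessing for one (file path, integer label) pair."""
    image = tf.io.decode_image(tf.io.read_file(path), channels=3,
                               expand_animations=False)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])  # returns float32
    image = image / 127.5 - 1.0  # assumption: model expects pixels in [-1, 1]
    return {'image': image, 'label': label}


# Toy dataset: one blank 32x48 PNG per class in a temporary folder.
root = pathlib.Path(tempfile.mkdtemp())
for class_name in ['cat', 'dog']:
    os.makedirs(root / class_name)
    png = tf.io.encode_png(tf.zeros([32, 48, 3], tf.uint8))
    tf.io.write_file(str(root / class_name / '0.png'), png)

# Same (path, label) construction as in the comment above.
paths = [str(p) for p in sorted(root.glob('*/*'))]
label_names = sorted(p.name for p in root.iterdir() if p.is_dir())
label_to_index = {name: i for i, name in enumerate(label_names)}
labels = [label_to_index[pathlib.Path(p).parent.name] for p in paths]

ds = tf.data.Dataset.from_tensor_slices((paths, labels))
ds = ds.map(_pp, num_parallel_calls=tf.data.AUTOTUNE).batch(2)
batch = next(iter(ds))
```

For real data, point `root` at your dataset folder instead; shuffling before `map` and adding `.prefetch(tf.data.AUTOTUNE)` after batching are common additions.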