Hi Zhiyuan,
The results in TensorBoard were achieved using the default flags (and --accum_steps=32) on an 8x 16GB V100 machine with the exact code from this repository.
Note that internally we obtained the same results on a 64x TPUv2 (8GB) setup, where we reached 83.2% after 800 steps (~10 min).
I'm not sure what is different in your setup, but did you try to reproduce the TensorBoard results with this repository and the default flags?
Thank you very much for replying.
I implemented a PyTorch version of the code, and I'm fairly sure that at least the model part is right, since the converted pre-trained weights achieve 85% (thank you again for providing those weights, by the way). For training, I used torchvision's RandomResizedCrop and got 84.1%. I also tried the transform provided in the TensorFlow code, converting its output to an ndarray and then to a PyTorch tensor; in that case I got around 76%.
Can I please confirm that the total batch size for 8x 16GB V100 is 512? I have read some Flax documentation and it looks really beautiful to me. However, I only have uncompressed ImageNet on my machine, and this implementation seems to require the compressed tar :(
```python
import jax


def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps):
  """Accumulate gradient over multiple steps to save on memory."""
  if accum_steps and accum_steps > 1:
    assert images.shape[0] % accum_steps == 0, (
        f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
    step_size = images.shape[0] // accum_steps
    # The first chunk initializes the running loss and gradient.
    l, g = loss_and_grad_fn(params, images[:step_size], labels[:step_size])

    def acc_grad_and_loss(i, l_and_g):
      # Slice chunk i (step_size examples) from the NHWC images and [N, C] labels.
      imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                                   (step_size,) + images.shape[1:])
      lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                   (step_size, labels.shape[1]))
      li, gi = loss_and_grad_fn(params, imgs, lbls)
      l, g = l_and_g
      # jax.tree_multimap was removed from JAX; tree_util.tree_map accepts several trees.
      return (l + li, jax.tree_util.tree_map(lambda x, y: x + y, g, gi))

    l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
    # Average over the chunks so the result matches a full-batch mean loss.
    return jax.tree_util.tree_map(lambda x: x / accum_steps, (l, g))
  else:
    return loss_and_grad_fn(params, images, labels)
```
The running batch size should be 512/accum_steps (i.e. 16 with --accum_steps=32).
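To illustrate why splitting a 512-example batch into accum_steps chunks gives the same update as the full batch, here is a small sketch. The toy linear model, the MSE loss and the helper name `accumulate` are assumptions for illustration only, and it uses a plain Python loop instead of `jax.lax.fori_loop` for readability; it is not the repository's training code.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, images, labels):
  # Toy linear model with a mean squared-error loss (illustration only).
  preds = images.reshape(images.shape[0], -1) @ params['w']
  return jnp.mean((preds - labels) ** 2)


loss_and_grad_fn = jax.value_and_grad(loss_fn)


def accumulate(params, images, labels, accum_steps):
  """Averages loss and gradient over accum_steps equal chunks of the batch."""
  step_size = images.shape[0] // accum_steps
  total_l = 0.0
  total_g = jax.tree_util.tree_map(jnp.zeros_like, params)
  for i in range(accum_steps):
    chunk = slice(i * step_size, (i + 1) * step_size)
    l, g = loss_and_grad_fn(params, images[chunk], labels[chunk])
    total_l = total_l + l
    total_g = jax.tree_util.tree_map(jnp.add, total_g, g)
  return total_l / accum_steps, jax.tree_util.tree_map(
      lambda x: x / accum_steps, total_g)


batch_size, accum_steps = 512, 32
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
images = jax.random.normal(k1, (batch_size, 2, 2, 3))
labels = jax.random.normal(k2, (batch_size, 1))
params = {'w': jax.random.normal(k3, (12, 1))}

step_size = batch_size // accum_steps  # 512 / 32 = 16 examples per step
full_l, full_g = loss_and_grad_fn(params, images, labels)
acc_l, acc_g = accumulate(params, images, labels, accum_steps)
```

The equality only holds because the loss is a mean over examples and every chunk has the same size, which is exactly what the `images.shape[0] % accum_steps == 0` assertion guarantees.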
Exactly. That's why I always mentioned the --accum_steps together with the batch size. I'm sorry if that was ambiguous and caused you unnecessary debugging work!
lol, no worries, I was just a bit too lazy to read the source code (which I should have done, considering I'm trying to reproduce a result).
Note that if you check out the source code of tfds, you'll probably find an easy way to work with your uncompressed ImageNet data...
:( Thank you, I'll go ahead and find out~ (God, I wish there were a framework that combined the interface of PyTorch with the code quality of Google.)
You might also find this repo interesting: https://github.com/lucidrains/vit-pytorch
Thank you for sharing it. I had checked it previously, but our implementation seems closer to the original design in this repo (we don't change anything unnecessarily).
@andsteing Hope I'm not disturbing you; can I ask what the difference is between inception_crop and a random resized crop?
```python
import tensorflow as tf


def preprocess(im, crop_size, resize_size, inception_crop=True):
  # Hypothetical wrapper name around the training-time crop logic.
  if inception_crop:
    channels = im.shape[-1]
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(im),
        tf.zeros([0, 0, 4], tf.float32),
        area_range=(0.05, 1.0),
        min_object_covered=0,  # Don't enforce a minimum area.
        use_image_if_no_bounding_boxes=True)
    im = tf.slice(im, begin, size)
    # Unfortunately, the above operation loses the depth-dimension. So we
    # need to restore it the manual way.
    im.set_shape([None, None, channels])
    im = tf.image.resize(im, [crop_size, crop_size])
  else:
    im = tf.image.resize(im, [resize_size, resize_size])
    im = tf.image.random_crop(im, [crop_size, crop_size, 3])
    # Note: flip_left_right always flips; random_flip_left_right flips with p=0.5.
    im = tf.image.flip_left_right(im)
  return im
```
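Roughly speaking, both inception_crop and torchvision's RandomResizedCrop sample a random box by area fraction and aspect ratio, then resize it to crop_size. The visible differences are in the defaults: the snippet above uses area_range=(0.05, 1.0) and relies on TF's default aspect_ratio_range of (0.75, 1.33), while RandomResizedCrop defaults to scale=(0.08, 1.0) and ratio=(3/4, 4/3). TF also returns the whole image if no valid box is found after max_attempts, whereas torchvision falls back to a center crop. The NumPy sketch below only illustrates the box sampling; `sample_crop_box` is a hypothetical helper, and its log-uniform aspect-ratio sampling follows torchvision rather than TF's exact algorithm.

```python
import math

import numpy as np


def sample_crop_box(height, width, rng, area_range=(0.05, 1.0),
                    ratio_range=(3 / 4, 4 / 3), max_attempts=100):
  """Returns (top, left, h, w) of a random box covering part of the image.

  Simplified sketch: pick a target area fraction and an aspect ratio, and
  accept the box only if it fits inside the image.
  """
  area = height * width
  for _ in range(max_attempts):
    target_area = area * rng.uniform(*area_range)
    # Log-uniform aspect ratio, so w/h and h/w are sampled symmetrically.
    ratio = math.exp(rng.uniform(math.log(ratio_range[0]),
                                 math.log(ratio_range[1])))
    w = int(round(math.sqrt(target_area * ratio)))
    h = int(round(math.sqrt(target_area / ratio)))
    if 0 < h <= height and 0 < w <= width:
      top = int(rng.integers(0, height - h + 1))
      left = int(rng.integers(0, width - w + 1))
      return top, left, h, w
  # Fallback after max_attempts failures: use the whole image.
  return 0, 0, height, width


rng = np.random.default_rng(0)
image = np.zeros((384, 512, 3), dtype=np.uint8)  # placeholder image
top, left, h, w = sample_crop_box(image.shape[0], image.shape[1], rng)
crop = image[top:top + h, left:left + w]  # then resize to crop_size x crop_size
```

With area_range starting at 0.05, some crops cover only a small part of the image, which is a much stronger augmentation than a resize followed by a crop of nearly the same size.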
It's similar but different. You can check out the difference in the Colab by setting input_pipeline.get_data(..., inception_crop=False) and then visually inspect the difference.
I have set inception_crop=True by default because that's what we used internally, but it might not actually make a big difference in terms of final performance (another thing you can try in the Colab).
The reason I'm asking is that im = tf.image.random_crop(im, [crop_size, crop_size, 3]) doesn't seem to do anything, because the image has already been resized to 384x384 in the previous line.
I've just tested it locally, and the result seems to confirm my hypothesis.
Test code:
```python
import os

import cv2
import numpy as np
import tensorflow as tf

im = tf.io.decode_jpeg(tf.io.read_file('example.jpg'), channels=3)  # hypothetical input path
for d in ('tf', 'inception'):
  os.makedirs(d, exist_ok=True)

for i in range(100):
  auged = tf.image.resize(im, [384, 384])
  auged = tf.image.random_crop(auged, [384, 384, 3])
  auged = tf.image.flip_left_right(auged)
  auged = auged.numpy().astype(np.uint8)
  cv2.imwrite(f'tf/{i}.jpg', cv2.cvtColor(auged, cv2.COLOR_RGB2BGR))  # cv2 expects BGR
for i in range(100):
  channels = im.shape[-1]
  begin, size, _ = tf.image.sample_distorted_bounding_box(
      tf.shape(im), tf.zeros([0, 0, 4], tf.float32), area_range=(0.05, 1.0),
      min_object_covered=0, use_image_if_no_bounding_boxes=True)
  auged = tf.slice(im, begin, size)
  auged.set_shape([None, None, channels])
  auged = tf.image.resize(auged, [384, 384])
  auged = auged.numpy().astype(np.uint8)
  cv2.imwrite(f'inception/{i}.jpg', cv2.cvtColor(auged, cv2.COLOR_RGB2BGR))
```
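The same observation can be checked without any images: tf.image.random_crop draws its offset uniformly from [0, dim - size] along each axis, so when the resize size equals the crop size the only possible offset is 0. The NumPy sketch below mirrors that offset logic; `random_crop` here is a hypothetical stand-in, not the TF function, and 448 is just an arbitrary example of resizing larger than the crop.

```python
import numpy as np


def random_crop(image, crop_h, crop_w, rng):
  """Crops at an offset drawn uniformly from [0, H - crop_h] x [0, W - crop_w]."""
  height, width = image.shape[:2]
  top = int(rng.integers(0, height - crop_h + 1))
  left = int(rng.integers(0, width - crop_w + 1))
  return image[top:top + crop_h, left:left + crop_w]


rng = np.random.default_rng(0)

# Resized to 384x384, then cropped to 384x384: the only possible offset is (0, 0).
resized_384 = rng.random((384, 384, 3))
same = [random_crop(resized_384, 384, 384, rng) for _ in range(10)]

# Resized to 448x448 (arbitrary larger size), then cropped to 384x384.
resized_448 = rng.random((448, 448, 3))
varied = [random_crop(resized_448, 384, 384, rng) for _ in range(10)]
```

In other words, for the non-inception branch to add any randomness, resize_size has to be larger than crop_size.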
@andsteing Thank you very much for your help.
We found that the gap was caused by skipping the interpolation of the position embeddings (posembed).
We have achieved promising results for now and will release the code once the review process has finished.
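When fine-tuning at a different resolution than pre-training, the number of patches changes (e.g. with 16x16 patches, 224/16 = 14 gives a 14x14 grid, while 384/16 = 24 gives a 24x24 grid), so the learned position embeddings must be resized to the new grid. Below is a minimal sketch, assuming a [1, 1 + N, dim] table whose first entry is the class token; `interpolate_posemb` is a hypothetical helper that uses `jax.image.resize` with bilinear interpolation, and the repository's own checkpoint-loading code may resize differently.

```python
import jax
import jax.numpy as jnp


def interpolate_posemb(posemb, new_num_patches):
  """Bilinearly resizes a ViT position-embedding table to a new patch grid.

  posemb: [1, 1 + gs_old**2, dim]; index 0 along axis 1 is the class token.
  """
  posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
  gs_old = int(round(posemb_grid.shape[0] ** 0.5))
  gs_new = int(round(new_num_patches ** 0.5))
  dim = posemb_grid.shape[-1]
  grid = posemb_grid.reshape(gs_old, gs_old, dim)
  # Only the two spatial axes change size; the embedding axis is unchanged.
  grid = jax.image.resize(grid, (gs_new, gs_new, dim), method='bilinear')
  grid = grid.reshape(1, gs_new * gs_new, dim)
  # The class-token embedding is copied over unchanged.
  return jnp.concatenate([posemb_tok, grid], axis=1)


# Example: hidden size 768 (as in ViT-B), 224 -> 384 resolution with 16x16 patches.
posemb_224 = jax.random.normal(jax.random.PRNGKey(0), (1, 1 + 14 * 14, 768))
posemb_384 = interpolate_posemb(posemb_224, 24 * 24)
```

Without this step, the pre-trained embeddings cannot be matched to the 576 patches of a 384x384 input.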