
Comments (9)

andsteing commented on August 21, 2024

Exactly. That's why I always mentioned the --accum_steps together with the batch size. I'm sorry if that was ambiguous and caused you unnecessary debugging work!

Note that if you check out the source code of tfds, you'll probably find an easy way to work with your uncompressed ImageNet data...
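A minimal sketch of one such route (the directory path and layout below are assumptions, not from this thread) is tfds.ImageFolder, which builds a dataset directly from an uncompressed root/split/class/image directory tree:

import tensorflow_datasets as tfds

# Assumed layout: /data/imagenet/{train,val}/<class_name>/<image>.JPEG
builder = tfds.ImageFolder('/data/imagenet')
ds = builder.as_dataset(split='train', shuffle_files=True)

for example in ds.take(1):
    # Each example is a dict with 'image', 'label' and 'image/filename'.
    print(example['image'].shape, example['label'].numpy())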

You might also find this repo interesting:

https://github.com/lucidrains/vit-pytorch


andsteing commented on August 21, 2024

Hi Zhiyuan

The results from the TensorBoard were achieved using the default flags (and --accum_steps=32) on an 8x 16GB V100 machine with the exact code from this repository.

Note that internally we produced the same results on a 64x TPUv2 (8GB), where we achieved 83.2% after 800 steps (~10 min).

Not sure what is different in your setup, but did you try to reproduce the same results as in the TensorBoard with this repository and the default flags?


ZhiyuanChen commented on August 21, 2024

> Hi Zhiyuan
>
> The results from the TensorBoard were achieved using the default flags (and --accum_steps=32) on an 8x 16GB V100 machine with the exact code from this repository.
>
> Note that internally we produced the same results on a 64x TPUv2 (8GB), where we achieved 83.2% after 800 steps (~10 min).
>
> Not sure what is different in your setup, but did you try to reproduce the same results as in the TensorBoard with this repository and the default flags?

Thank you very much for replying.

I implemented a PyTorch version of the code, and I'm pretty sure the model part is right at least, as the converted pre-trained weights achieved 85% (thank you again for providing those weights, btw). For the training part, I used torchvision's RandomResizedCrop and got 84.1%. I also tried using the transform provided in TensorFlow, converting the output to an ndarray and then to a PyTorch tensor at the end; in that case it was around 76%.

Can I please confirm that the total batch size for the 8x 16GB V100 setup is 512? I have read some of the Flax documentation and it looks really beautiful to me. However, I only have uncompressed ImageNet on my machine, and it seems this implementation requires the compressed tars :(


ZhiyuanChen commented on August 21, 2024

> Can I please confirm that the total batch size for the 8x 16GB V100 setup is 512? I have read some of the Flax documentation and it looks really beautiful to me. However, I only have uncompressed ImageNet on my machine, and it seems this implementation requires the compressed tars :(

def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps):
  """Accumulate gradient over multiple steps to save on memory."""
  if accum_steps and accum_steps > 1:
    assert images.shape[0] % accum_steps == 0, (
        f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
    step_size = images.shape[0] // accum_steps
    l, g = loss_and_grad_fn(params, images[:step_size], labels[:step_size])

    def acc_grad_and_loss(i, l_and_g):
      imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                                   (step_size,) + images.shape[1:])
      lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                   (step_size, labels.shape[1]))
      li, gi = loss_and_grad_fn(params, imgs, lbls)
      l, g = l_and_g
      return (l + li, jax.tree_multimap(lambda x, y: x + y, g, gi))

    l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
    return jax.tree_map(lambda x: x / accum_steps, (l, g))
  else:
    return loss_and_grad_fn(params, images, labels)

So the running (per-step) batch size should be 512 / accum_steps.
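For illustration, a toy sketch of calling the function above (the tiny linear model and shapes are made up, and it assumes the JAX version this repo targets, which still provides jax.tree_multimap): with a total batch of 512 and accum_steps=32, each accumulation step only keeps 512 // 32 = 16 images in memory for the forward/backward pass, while the averaged loss and gradients match those of the full batch.

import jax
import jax.numpy as jnp

# Made-up toy model and data, just to show the slicing behaviour.
params = {'w': jnp.zeros((3, 10))}
images = jnp.zeros((512, 4, 4, 3))                     # total batch of 512
labels = jax.nn.one_hot(jnp.zeros((512,), jnp.int32), 10)

def loss_and_grad_fn(p, imgs, lbls):
  def loss_fn(p):
    logits = jnp.mean(imgs, axis=(1, 2)) @ p['w']      # (batch, 10)
    return -jnp.mean(jnp.sum(lbls * jax.nn.log_softmax(logits), axis=-1))
  return jax.value_and_grad(loss_fn)(p)

# Each of the 32 accumulation steps processes 512 // 32 = 16 images.
loss, grads = accumulate_gradient(loss_and_grad_fn, params, images, labels, 32)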


ZhiyuanChen commented on August 21, 2024

> Exactly. That's why I always mentioned the --accum_steps together with the batch size. I'm sorry if that was ambiguous and caused you unnecessary debugging work!

lol, no worries, I was just a bit too lazy to read the source code (which I should have done, considering I'm trying to reproduce a result).

> Note that if you check out the source code of tfds, you'll probably find an easy way to work with your uncompressed ImageNet data...

:( Thank you, I'll go ahead and find out~ (God, I wish there were a framework that combines the interface of PyTorch with the code quality of Google).

> You might also find this repo interesting:
>
> https://github.com/lucidrains/vit-pytorch

Thank you for sharing it. I have checked it out previously, but it seems our implementation is closer to the original design in this repo (not changing anything unnecessarily).


ZhiyuanChen commented on August 21, 2024

@andsteing Hope I'm not disturbing; may I ask what the difference is between inception_crop and a random resized crop?

      if inception_crop:
        channels = im.shape[-1]
        begin, size, _ = tf.image.sample_distorted_bounding_box(
            tf.shape(im),
            tf.zeros([0, 0, 4], tf.float32),
            area_range=(0.05, 1.0),
            min_object_covered=0,  # Don't enforce a minimum area.
            use_image_if_no_bounding_boxes=True)
        im = tf.slice(im, begin, size)
        # Unfortunately, the above operation loses the depth-dimension. So we
        # need to restore it the manual way.
        im.set_shape([None, None, channels])
        im = tf.image.resize(im, [crop_size, crop_size])
      else:
        im = tf.image.resize(im, [resize_size, resize_size])
        im = tf.image.random_crop(im, [crop_size, crop_size, 3])
        im = tf.image.flip_left_right(im)


andsteing commented on August 21, 2024

It's similar but different. You can check out the difference in the Colab by setting input_pipeline.get_data(..., inception_crop=False) and then visually inspecting the output.

I have set inception_crop=True by default because that's what we used internally, but it might not actually make a big difference in terms of final performance (another thing you can try in the Colab).
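For a rough PyTorch analogue (an approximation for illustration, not code from either repo): the inception crop samples a box covering roughly 5-100% of the image area and resizes it to crop_size, which is close to what torchvision's RandomResizedCrop does, while the else branch resizes to a fixed size first and only jitters the crop position.

import torchvision.transforms as T

# Approximate counterparts of the two branches in the snippet above; the scale
# comes from area_range=(0.05, 1.0), and 384 stands in for crop_size.
inception_like = T.Compose([
    T.RandomResizedCrop(384, scale=(0.05, 1.0)),  # sample 5-100% of the area, resize to 384
    T.RandomHorizontalFlip(),
])

resize_then_crop = T.Compose([
    T.Resize((384, 384)),   # fixed resize, aspect ratio not preserved
    T.RandomCrop(384),      # with resize_size == crop_size this barely moves anything
    T.RandomHorizontalFlip(),
])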


ZhiyuanChen commented on August 21, 2024

> It's similar but different. You can check out the difference in the Colab by setting input_pipeline.get_data(..., inception_crop=False) and then visually inspecting the output.
>
> I have set inception_crop=True by default because that's what we used internally, but it might not actually make a big difference in terms of final performance (another thing you can try in the Colab).

The reason I'm asking is that im = tf.image.random_crop(im, [crop_size, crop_size, 3]) looks like it does nothing at all, since the image has already been resized to 384x384 on the previous line.

I've just tested it locally, and it seems to confirm my hypothesis.

Test code:

import cv2
import numpy as np
import tensorflow as tf

# im is assumed to be a decoded HxWx3 uint8 image (the path is just a placeholder).
im = cv2.imread('example.jpg')

for i in range(100):
    auged = tf.image.resize(im, [384, 384])
    auged = tf.image.random_crop(auged, [384, 384, 3])
    auged = tf.image.flip_left_right(auged)
    auged = auged.numpy().astype(np.uint8)
    cv2.imwrite(f'tf/{i}.jpg', auged)

for i in range(100):
    channels = im.shape[-1]
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(im), tf.zeros([0, 0, 4], tf.float32), area_range=(0.05, 1.0),
        min_object_covered=0, use_image_if_no_bounding_boxes=True)
    auged = tf.slice(im, begin, size)
    auged.set_shape([None, None, channels])
    auged = tf.image.resize(auged, [384, 384])
    auged = auged.numpy().astype(np.uint8)
    cv2.imwrite(f'inception/{i}.jpg', auged)
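For comparison, a minimal variant (the 448 below is an arbitrary larger resize size and the output directory is made up, purely for illustration) where the random crop does change the result:

for i in range(100):
    auged = tf.image.resize(im, [448, 448])              # resize larger than the crop size
    auged = tf.image.random_crop(auged, [384, 384, 3])   # now the crop position actually varies
    auged = auged.numpy().astype(np.uint8)
    cv2.imwrite(f'tf_larger/{i}.jpg', auged)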


ZhiyuanChen commented on August 21, 2024

@andsteing Thank you very much for helping.
We found that it was caused by the interpolation of the position embeddings being ignored.
We have achieved promising results for now, and will release the code once the review process has finished.
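For reference, a minimal sketch of that interpolation (an assumption about the general recipe, not the exact code of either implementation): keep the class-token embedding and bilinearly resample the grid embeddings to the new grid size.

import numpy as np
import scipy.ndimage

def interpolate_posembed(posemb, new_grid, has_cls_token=True):
    # posemb: (1, 1 + old_grid**2, dim) if has_cls_token, else (1, old_grid**2, dim).
    if has_cls_token:
        cls_tok, grid = posemb[:, :1], posemb[0, 1:]
    else:
        cls_tok, grid = posemb[:, :0], posemb[0]
    old_grid = int(np.sqrt(grid.shape[0]))
    grid = grid.reshape(old_grid, old_grid, -1)
    zoom = (new_grid / old_grid, new_grid / old_grid, 1)
    grid = scipy.ndimage.zoom(grid, zoom, order=1)  # bilinear resampling to the new grid
    grid = grid.reshape(1, new_grid * new_grid, -1)
    return np.concatenate([cls_tok, grid], axis=1)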
