Comments (7)

romanngg commented on May 6, 2024

FYI, 100afac altered the tensor layout, which might reduce the TPU memory footprint; in general, there's still work to do to fully eliminate padding, and GPUs are still recommended (see 100afac for details).

(Speed was also improved by ~3-5X, which should make it practical to use smaller batches.)

romanngg commented on May 6, 2024

FYI, we have finally added ConvTranspose in 780ad0c!
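
A minimal usage sketch, for reference; the widths, strides, and input shapes below are illustrative assumptions, not prescribed by the commit:

    # Hypothetical example of stax.ConvTranspose in an infinite-width network.
    from jax import random
    from neural_tangents import stax

    _, _, kernel_fn = stax.serial(
        stax.ConvTranspose(128, (3, 3), strides=(2, 2), padding='SAME'),  # upsampling layer
        stax.Relu(),
        stax.Flatten(),
        stax.Dense(1),
    )

    x = random.normal(random.PRNGKey(0), (4, 8, 8, 1))  # NHWC inputs
    nngp = kernel_fn(x, None, 'nngp')  # (4, 4) NNGP kernel matrix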

sschoenholz commented on May 6, 2024

Hey! Thanks for pushing on this. We'd love to iterate on this to get it working for you (though looking at the UNet architecture I am a bit concerned that the vanilla version violates some independence assumptions wrt FanInConcat).

A few things off the top of my head:

  1. Convolutions generically require quite a bit of storage (a dataset x pixels x dataset x pixels covariance matrix in the general case). It can be helpful to use the batching functionality to run on even modestly large datasets (see the sketch at the end of this comment).

  2. Having said that, I think we ought to be able to do 10 images at a time! Looking at this stack trace you might notice lines like the following:

Size: 12.50G
     Operator: op_type="conv_general_dilated"
     Shape: f32[409600,1,64,64]{0,1,3,2:T(2,128)}
     Unpadded size: 6.25G
     Extra memory due to padding: 6.25G (2.0x expansion)

TPUs must store data in blocks of size 8 x 128. To fit arbitrary data into blocks of this size, XLA will often pad it. Here you can see that the raw size of the data is 6.25 GB, but it is getting padded by a factor of 2. I might recommend trying to run this on GPU rather than TPU to see whether the calculation fits into memory, since GPUs don't need to pad. Generally, we have not figured out a way of phrasing our convolutions that doesn't get padded by the TPU (since our channel count is 1). This is an ongoing area of work, but I have to say we have limited tools at our disposal to make progress here (though maybe @romanngg can comment if he's more hopeful than myself).
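
To make the 2.0x expansion concrete, here is a rough back-of-the-envelope check of the dump above (assuming the trailing 64-sized dimension is the one padded up to the 128-lane boundary):

    # float32 = 4 bytes/element; shape from the dump is f32[409600, 1, 64, 64].
    unpadded_gb = 409600 * 1 * 64 * 64 * 4 / 2**30   # -> 6.25, matches "Unpadded size"
    padded_gb   = 409600 * 1 * 64 * 128 * 4 / 2**30  # -> 12.5, matches "Size"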

Let us know how the GPU works out. Glancing at the sizes, I would expect it to easily fit on a V100 (since it has 32 GB of RAM whereas this calculation consumes around 19 GB unpadded).
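
As for the batching mentioned in point 1, here is a minimal sketch of nt.batch; the architecture, batch size, and input shapes are illustrative assumptions, not the UNet from this issue:

    from jax import random
    import neural_tangents as nt
    from neural_tangents import stax

    # Illustrative architecture -- not the UNet discussed here.
    _, _, kernel_fn = stax.serial(
        stax.Conv(128, (3, 3), padding='SAME'),
        stax.Relu(),
        stax.Flatten(),
        stax.Dense(1),
    )

    # Compute the kernel a few inputs at a time instead of all at once,
    # trading compute time for a smaller peak memory footprint;
    # store_on_device=False accumulates the result in host memory.
    batched_kernel_fn = nt.batch(kernel_fn, batch_size=2, store_on_device=False)

    x = random.normal(random.PRNGKey(0), (10, 64, 64, 1))  # NHWC, 64x64 images
    nngp = batched_kernel_fn(x, None, 'nngp')  # (10, 10) NNGP kernel matrix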

romanngg commented on May 6, 2024

+1 to Sam re padding. Also note that even unpadded, the intermediate NNGP covariance of shape 10 x 10 x (64 x 64) x (64 x 64) is 6.25 GB. To propagate this tensor through the NNGP computation from one layer to the next, you need 2X that; unfortunately, due to JAX internals, in practice it requires 3X (see & upvote google/jax#1733, google/jax#1273), which results in a peak memory consumption of ~19 GB and would require a 32 GB GPU (note that V100s come in 16 GB and 32 GB varieties, so even a V100 may not be enough). For this reason you'd probably need to work with even smaller batches in this case (see nt.batch), or reduce the image sizes.
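The arithmetic, for reference (pure float32 element counting, no library calls):

    covariance_gb = 10 * 10 * 64**2 * 64**2 * 4 / 2**30  # -> 6.25
    peak_gb = 3 * covariance_gb                          # -> 18.75, i.e. the ~19 GB peak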

kayhan-batmanghelich commented on May 6, 2024

Hi @sschoenholz ,

My understanding from #16 was that FanInConcat is theoretically OK, and my superficial understanding of the Greg Yang paper is that these kinds of linear operations do not break the theory, but I might be totally wrong.

Since FanInConcat was not available, I implemented the UNet using ConvTranspose, which resulted in more parameters and less stable SGD training. However, that is a separate issue from computing the NNGP kernel for ten samples. I will re-run on GPU and report back.

Thanks

romanngg commented on May 6, 2024

FYI, we've just added FanInConcat support in c485052!

Two caveats:

  1. When concatenating along the channel/feature axis, a Dense or Convolutional layer is required afterwards, so (assuming NHWC, i.e. channel axis -1) you can't have stax.serial(..., stax.FanInConcat(axis=-1), stax.Relu(), ...) for now; this might be implemented later. See the sketch after this list.
  2. This will not reduce the memory footprint (which should be identical to FanInSum for channel axis concatenation, and larger for spatial or batch axis concatenation).
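
A minimal sketch of a concatenation block that satisfies caveat 1, assuming NHWC inputs; the branch structure and layer widths are illustrative assumptions:

    from neural_tangents import stax

    _, _, kernel_fn = stax.serial(
        stax.FanOut(2),
        stax.parallel(
            stax.serial(stax.Conv(128, (3, 3), padding='SAME'), stax.Relu()),
            stax.Identity(),
        ),
        stax.FanInConcat(axis=-1),               # concatenate along the channel axis
        stax.Conv(128, (3, 3), padding='SAME'),  # Dense/Conv required right after
        stax.Relu(),
    )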

n17dccn151 commented on May 6, 2024

Hello @kayhan-batmanghelich, I am currently learning about NTK as well as the UNet architecture. Would you mind sharing your Colab notebook? Thank you very much.
