FYI, 100afac altered the tensor layout, which may reduce the TPU memory footprint; in general, there is still work to do to fully eliminate padding, and GPUs remain highly recommended (see 100afac).
(Speed was also improved by ~3-5X, which should make it practical to use smaller batches.)
from neural-tangents.
FYI, we have finally added ConvTranspose in 780ad0c!
Hey! Thanks for pushing on this. We'd love to iterate on this to get it working for you (though looking at the UNet architecture I am a bit concerned that the vanilla version violates some independence assumptions wrt FanInConcat).
A few things off the top of my head:
- Convolutions generically require quite a bit of storage (a dataset x pixels x dataset x pixels covariance matrix in the general case). It can be helpful to use the batching functionality to run on even modestly large datasets.
- Having said that, I think we ought to be able to do 10 images at a time! Looking at this stack trace you might notice lines like the following:
Size: 12.50G
Operator: op_type="conv_general_dilated"
Shape: f32[409600,1,64,64]{0,1,3,2:T(2,128)}
Unpadded size: 6.25G
Extra memory due to padding: 6.25G (2.0x expansion)
TPUs must store data in blocks of size 8 x 128. To fit arbitrary data into blocks of this size, XLA will often pad data. Here you can see that the raw size of the data is 6.25 GB, but it is getting padded by a factor of 2. I might recommend trying to run this on GPU rather than TPU and seeing whether the calculation will fit into memory, since GPUs don't need to pad. Generally, we have not figured out a way of phrasing our convolutions that avoids TPU padding (since our channel count is 1). This is an ongoing area of work, but I have to say we have limited tools at our disposal to make progress here (though maybe @romanngg can comment if he's more hopeful than myself).
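To make the numbers in the trace concrete, here is a back-of-envelope recomputation under a simplified model of XLA's tiled layout. The layout string f32[409600,1,64,64]{0,1,3,2:T(2,128)} lists dims minor-to-major, and the two physically minor-most dims are padded up to the (sublane, lane) = (2, 128) tile; the size-1 channel dim getting padded to 2 is what produces the 2x expansion. (The helper below is an illustrative sketch, not XLA's actual layout code.)

```python
import math

def padded_bytes(shape, minor_to_major, tile=(2, 128), dtype_bytes=4):
    """Bytes after padding the two physically minor-most dims up to the tile.

    Simplified model: minor_to_major[0] is the lane dim (padded to a multiple
    of tile[1]), minor_to_major[1] is the sublane dim (padded to tile[0]).
    """
    lane_dim, sublane_dim = minor_to_major[0], minor_to_major[1]
    dims = list(shape)
    dims[sublane_dim] = math.ceil(dims[sublane_dim] / tile[0]) * tile[0]
    dims[lane_dim] = math.ceil(dims[lane_dim] / tile[1]) * tile[1]
    return math.prod(dims) * dtype_bytes

shape = (409600, 1, 64, 64)                        # f32 operand from the trace
raw = math.prod(shape) * 4                         # 6.25 GiB unpadded
pad = padded_bytes(shape, minor_to_major=(0, 1, 3, 2))
# The size-1 channel dim is padded 1 -> 2, so pad / raw == 2.0 (12.5 GiB total),
# matching the "2.0x expansion" line in the trace.
```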
Let us know how the GPU works. Glancing at the sizes I would expect it to easily fit on a V100 (since it has 32 GB of RAM, whereas this calculation consumes around 19 GB unpadded).
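The batching functionality mentioned above amounts to computing the large covariance matrix block-by-block so that only one block is resident at a time. A generic pure-Python sketch of the idea (this is not the actual nt.batch implementation, which also handles devices, jit, and kernel objects):

```python
def batched_kernel(kernel_fn, x1, x2, batch_size):
    """Compute kernel_fn(x1, x2) block-by-block to bound peak memory.

    kernel_fn maps two lists of inputs to a nested list (a covariance block);
    we tile the full n1 x n2 result out of batch_size x batch_size blocks.
    """
    n1, n2 = len(x1), len(x2)
    out = [[None] * n2 for _ in range(n1)]
    for i in range(0, n1, batch_size):
        for j in range(0, n2, batch_size):
            block = kernel_fn(x1[i:i + batch_size], x2[j:j + batch_size])
            for bi, row in enumerate(block):
                for bj, v in enumerate(row):
                    out[i + bi][j + bj] = v
    return out

# Toy kernel: pairwise products of scalar "inputs".
dot = lambda a, b: [[ai * bi for bi in b] for ai in a]
K = batched_kernel(dot, [1.0, 2.0, 3.0], [1.0, 2.0, 3.0], batch_size=2)
```

Each call to kernel_fn only ever materializes a batch_size x batch_size block, which is why batching lets modestly large datasets fit in device memory.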
+1 to Sam re padding. Also note that even unpadded, the intermediary NNGP covariance of shape 10x10x(64x64)x(64x64) is 6.25 GB. To propagate this tensor through the NNGP computation from one layer to the next, you need 2X that; unfortunately, due to JAX internals, in practice it requires 3X (see & upvote google/jax#1733, google/jax#1273), resulting in a peak memory consumption of ~19 GB, which would require a 32 GB GPU (note that V100s come in 16 GB and 32 GB varieties, so even a V100 may not be enough). For this reason you'd probably need to work with even smaller batches in this case (see nt.batch), or reduce the image sizes.
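The figures above check out with simple arithmetic (f32 = 4 bytes per element, GiB = 2^30 bytes):

```python
# NNGP covariance of shape 10 x 10 x (64*64) x (64*64) in f32.
n, pix = 10, 64 * 64
cov_bytes = n * n * pix * pix * 4     # bytes for one copy of the covariance
gib = cov_bytes / 2**30               # 6.25 GiB per copy
peak = 3 * gib                        # ~18.75 GiB with the 3X JAX overhead,
                                      # i.e. the ~19 GB peak quoted above
```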
Hi @sschoenholz,
My understanding from #16 was that FanInConcat is theoretically OK, and my superficial understanding of the Greg Yang paper was that these kinds of linear operations do not break the theory, but I might be totally wrong.
Since there was no FanInConcat, I implemented the UNet using ConvTranspose, which resulted in an increase in parameters and less stable SGD training. However, that is different from getting the nngp kernel for ten samples. I will re-run on GPU and report back.
Thanks
FYI, we've just added FanInConcat support in c485052!
Two caveats:
- When concatenating along the channel/feature axis, a Dense or Convolutional layer is required afterwards, so (assuming NHWC, i.e. channel axis -1) you can't have stax.serial(..., stax.FanInConcat(axis=-1), stax.Relu(), ...) for now; this might be implemented later.
- This will not reduce the memory footprint (which should be identical to FanInSum for channel-axis concatenation, and larger for spatial- or batch-axis concatenation).
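A rough way to see the second caveat, under the assumption that the covariance tensor has shape (N, N, H, W, H, W) with channels traced out: concatenating along channels leaves that shape unchanged, while concatenating along a spatial or batch axis doubles a dimension that appears twice in the covariance, quadrupling the entry count. (Illustrative arithmetic only, using the 10-image 64x64 example from this thread.)

```python
# Entry counts of an (N, N, H, W, H, W) covariance under each concat axis,
# for two branches each producing N x H x W x C activations.
n, h, w = 10, 64, 64
base = n * n * (h * w) ** 2            # channel concat: channels are traced
                                       # out, so the shape is unchanged
spatial = n * n * (h * (2 * w)) ** 2   # width concat: W -> 2W, 4x entries
batch = (2 * n) ** 2 * (h * w) ** 2    # batch concat: N -> 2N, 4x entries
```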
Hello @kayhan-batmanghelich, I am currently learning about NTK as well as the UNet architecture. Would you mind sharing the Colab notebook? Thank you very much.