
Comments (12)

sschoenholz commented on July 28, 2024

Thanks for taking the time to try out NT and raise this issue! I think it is likely that a layer-wise scheme for computing the NTK will be more memory efficient than computing jacobian(vjp(jvp(fn))) all at once (since the latter requires at least as much memory as computing logit gradients). We went with our approach because (especially at the time) it seemed hard to handle arbitrary functions with a layer-wise scheme. However, we've been thinking about ways to improve our empirical NTK computation in the general case and your issue will certainly help us motivate this work.

While I am not all that surprised by the memory issues, I do find your computation time problems very surprising (my assumption would have been that while our memory usage might be higher than PyTorch, our computation time should be competitive). Not only that but apparently NT is like 600x slower, which seems a touch slow!

Is it possible for you to share more information about your current usage? In particular, if you feel comfortable doing so, it would be extremely helpful if you could share your model + example with us (and sharing the PyTorch code so that we could benchmark would be even better). If you'd rather not share on github, sending something to me (schsam at google.com) or Roman (romann at google.com) would be much appreciated.

In the meantime, if you port your layer-wise kernel code into JAX to produce a kernel_fn(x1, x2) then you ought to be able to plug it directly into the batching code and have it work.
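
As a rough illustration of what "plugging a kernel_fn(x1, x2) into batching" amounts to, here is a minimal serial-batching sketch in plain JAX. It is not NT's batching code; `kernel_fn` here is a hypothetical stand-in (a plain Gram matrix) and `batched_kernel` is a name invented for this example.

```python
import jax
import jax.numpy as jnp


def kernel_fn(x1, x2):
    # Hypothetical stand-in for a layer-wise NTK: any function that maps
    # (n1, d) and (n2, d) inputs to an (n1, n2) kernel block would do.
    return x1 @ x2.T


def batched_kernel(kernel_fn, x1, x2, batch_size):
    # Assemble the full kernel from (batch_size x batch_size) blocks so that
    # only one block's worth of intermediates is alive at a time.
    rows = []
    for i in range(0, x1.shape[0], batch_size):
        row = [
            kernel_fn(x1[i:i + batch_size], x2[j:j + batch_size])
            for j in range(0, x2.shape[0], batch_size)
        ]
        rows.append(jnp.concatenate(row, axis=1))
    return jnp.concatenate(rows, axis=0)


x = jax.random.normal(jax.random.PRNGKey(0), (20, 7))
k_batched = batched_kernel(kernel_fn, x, x, batch_size=5)
k_full = kernel_fn(x, x)
```

Each block is independent, so the per-block cost (including any fixed overhead) is paid once per pair of batches.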

from neural-tangents.

maciejkorzepa commented on July 28, 2024

Thanks for the quick reply!

Here's a gist to reproduce the issue: https://gist.github.com/maciejkorzepa/23bbb4443c2fcb93927eb36d9ef2091c

I can also share the PyTorch code if that would help but I would need to edit it a bit first.

I would be happy to port my layer-wise code into JAX, but I just used JAX today for the first time, so it might take me a while before I know how to do it... Also, the rest of the code is in PyTorch and I don't think I have the time to change everything to JAX, so my idea was to create an identical model in JAX, load the parameters from the PyTorch model into it, compute the kernel, and transform it into a PyTorch tensor again.


maciejkorzepa commented on July 28, 2024

Here's PyTorch code that computes the kernel for 80 examples in under 1 second on a 12GB Titan X GPU:
https://gist.github.com/maciejkorzepa/af3b5ef3676d982f8c4298f224960c53


maciejkorzepa commented on July 28, 2024

I found out some new things:

  1. the non-batched version works very fast as long as it fits in memory (and the input always has the same shape)
  2. the batched version is very slow
  3. computations for a 10x80 kernel fit into memory while those for a 25x25 kernel don't, which is strange as the latter is smaller.
  4. if I evaluate the kernel function for the same input multiple times, it gets very slow (and if the input changes, it stays very fast each time). Probably there's no use case for doing that, but still it's strange to me.

You can see these issues and findings here:
https://gist.github.com/maciejkorzepa/4c4bb6c445dc41449f90df14f04e67a9

Also, completely unrelated to the issue I believe, but I seem unable to import jax without errors about dynamic libraries if I don't import torch first. It looks like importing torch sets up some paths correctly, without which jax can't be imported.


sschoenholz commented on July 28, 2024

Thanks for following up! I've been digging into the code and profiling. While I don't have a solution yet, here are some comments on your investigations:

  1. I think this observation may actually be due to JAX's asynchronous dispatch which is hiding the actual work from your timing code (https://jax.readthedocs.io/en/latest/async_dispatch.html). Try adding a .block_until_ready() to the kernel computation before ending the timing code.

  2. Related to 1, in my profiling I find that while batching is slower than computing the whole kernel in one shot, it's probably not as bad as your profiling would suggest (owing to 1). To give some numbers for context: computing a 20x20 kernel with no batching takes ~3s, with batch size 10 this goes up to 4.5s and with batch size 5 this goes up to 7.5s. Together this points to an overhead of about 0.3s per batch. We should definitely try to get this number down! However, my main takeaway from this discussion is that it is very advantageous to compute the empirical NTK in a layer-wise fashion where possible.

A small note that we will fix imminently is that when you call

nt.batch(jit(nt.empirical_ntk_fn(apply_fn)), batch_size=10, device_count=1)

it actually invokes the parallel batching code which has more overhead than the serial batching code (about 0.45 s per batch). For now I think this invocation uses the serial code:

nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=10)

  3. That is very interesting! Will investigate.

  4. Is also strange / interesting. I suspect, looking at your timing, that this might be due to the aforementioned asynchronous dispatch. Does it still happen with a block_until_ready()?
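
To make the asynchronous-dispatch point concrete, here is a minimal timing sketch. The `gram` function is only a hypothetical stand-in for an expensive kernel computation; the pattern to note is the warm-up call plus `.block_until_ready()` before stopping the clock.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def gram(x):
    # Stand-in for an expensive kernel computation.
    return x @ x.T


x = jax.random.normal(jax.random.PRNGKey(0), (512, 256))
gram(x).block_until_ready()  # warm-up call so jit compilation is not timed

start = time.perf_counter()
k = gram(x)                  # may return before the device has finished
dispatch_time = time.perf_counter() - start

k.block_until_ready()        # wait until the result is actually computed
total_time = time.perf_counter() - start
```

Without the final `block_until_ready()`, only `dispatch_time` would be measured, which can make a computation look much faster than it is.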

One final observation that is orthogonal to the current discussion is that writing x_train = np.random.randn(n, d, d, 3).astype(np.float32) generates the data on CPU and then requires a transfer to GPU. You can either cast to a JAX numpy array:

import jax.numpy as jnp
x_train = jnp.array(np.random.randn(n, d, d, 3).astype(np.float32))

which will still invoke a transfer but only the first time you use the array. An alternative method would be to generate the training data using JAX's rng:

from jax import random
x_train = random.normal(random.PRNGKey(0), (n, d, d, 3))

The issue with importing JAX and pytorch simultaneously is interesting. I'm sure the JAX team would be interested if you wouldn't mind making an issue over there. Otherwise, let me know and I can bug them in person / raise an issue.


maciejkorzepa commented on July 28, 2024

Thanks for investigating the issue! Actually, asynchronous dispatch does explain some of my observations - the non-batched version is quite slow after I synchronize before measuring time. Also, the execution time is the same now both for new and previous inputs.

Did you time the computation on the model that I provided? If yes, on what GPU? I'm getting much higher run times: the non-batched version for a 20x20 kernel takes ~6s and the batched version with batch size 10 takes ~8s (and ~15s for batch size 5). Also, the first run (for the non-batched version) with jit compilation takes ~40s. Is it expected to take this long?

Regarding the jax / torch import issue, it's gone now after I reinstalled jax for cuda 10.2. The cluster I'm using has many versions of cuda available, so perhaps that was the issue.

I also made a small attempt to implement layer-wise kernel computation (on a toy example):
https://gist.github.com/maciejkorzepa/02403aacfd53a92f0e2a202daefd86f6

Currently, I can compute the kernel of each layer for Jacobians of the output of that layer instead of Jacobians of the output of the last layer. Consequently, the shapes of the kernels of all layers but the last are wrong. I could do it for the last layer, but propagating inputs from the processed layers through all layers up to the last one, repeated for each layer, sounds very inefficient. Anyway, as it's my first experience with writing jax code, I would very much appreciate any hints on how to move forward and how to do it in an efficient way. I don't need a general solution for all architectures, only one for networks that I can define like in the gist I shared.


sschoenholz commented on July 28, 2024

Ok! So I think I may have made some progress. I would like to understand why NT is slower than the sample you provided and then, separately, think about other timing issues so let's revisit those aspects of your latest post a bit later.

Let's think about what we're doing when we do backpropagation. During backpropagation for a composed function, we iteratively take a cotangent from the output of one function to the cotangent of its input using the vector-jacobian product. The cotangent of the output should have the same dimension as the output and the cotangent of the input should have the same dimension as the input.

So far so good. Now, what should we choose as the output cotangent for the whole function (the starting point of backpropagation)? If the function outputs a scalar (e.g. a loss) then the choice is unambiguous and we choose the 1d-vector [1.0]. What if there are more outputs though (e.g. in your case, there is a batch dimension)? There is not a "correct" choice, because it depends on what we are trying to compute.

For this reason, in JAX if you take the gradient of a vector-valued function then it will raise an error. In what I think is a very unfortunate choice, TensorFlow (and now I assume PyTorch), will not raise an error but will implicitly use an all-ones cotangent, [1.0, 1.0, ..., 1.0]. What does this end up doing? This ends up producing a gradient that is the sum, $\sum_i \partial_j f_i(x)$. In your case, this amounts to computing the sum over all the inputs (even though you will get gradients that have the right shape during backprop). Thus, if you look at the "NTK" produced by your PyTorch approach you will find that all of the outputs for the different entries are identical.
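
A small sketch of the point about the all-ones cotangent, using a toy model with one scalar output per example (the function `f` and the shapes are assumptions made for illustration): pulling back a vector of ones gives the sum of the per-example gradients, not the individual rows of the Jacobian.

```python
import jax
import jax.numpy as jnp


def f(w, x):
    # Toy model: one scalar output per example in the batch.
    return jnp.tanh(x @ w)


w = jax.random.normal(jax.random.PRNGKey(0), (3,))
x = jax.random.normal(jax.random.PRNGKey(1), (4, 3))

# Pull back an all-ones cotangent through the batched function.
out, vjp_fn = jax.vjp(lambda w_: f(w_, x), w)
(summed_grad,) = vjp_fn(jnp.ones_like(out))

# Full Jacobian: one row of parameter gradients per example, shape (4, 3).
jac = jax.jacrev(lambda w_: f(w_, x))(w)
```

Here `summed_grad` matches `jac.sum(axis=0)`, i.e. $\sum_i \partial_j f_i(x)$.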

I am therefore pretty sure that the reason why your code is so much faster is because it's not doing work that must be done. To see how JAX solves this problem check out their jacrev implementation here. They do one backpropagation per output using the standard basis (e.g. for the first output they use the cotangent [1.0, 0.0, ..., 0.0] etc...). Of course they don't use a for-loop since that might be slow but instead use vmap to vectorize over the different basis vectors.
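
The "one backpropagation per output, vectorized with vmap" idea can be sketched as follows. This is a simplified illustration on a toy function `f` (an assumption for this example), not a copy of JAX's jacrev source.

```python
import jax
import jax.numpy as jnp


def f(w, x):
    # Toy model: one scalar output per example in the batch.
    return jnp.tanh(x @ w)


w = jax.random.normal(jax.random.PRNGKey(0), (3,))
x = jax.random.normal(jax.random.PRNGKey(1), (4, 3))

out, vjp_fn = jax.vjp(lambda w_: f(w_, x), w)

# Standard basis of the output space: row i is the cotangent [0, ..., 1, ..., 0].
basis = jnp.eye(out.shape[0])

# One pullback per basis vector, vectorized with vmap instead of a for-loop.
(jac_manual,) = jax.vmap(vjp_fn)(basis)

jac_reference = jax.jacrev(lambda w_: f(w_, x))(w)
```

Each basis vector recovers one row of the Jacobian, so the cost grows with the number of outputs.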

So what does this mean? This means, at least, that our speeds / memory usage are not so different in the sense that you likely need to iterate your approach once per input which will incur a factor of around 100 slowdown. Nonetheless, I do believe there ought to be some advantage we can gain from doing in place updating of the NTK during backprop but I anticipate it being less of an advantage than we might expect.

What do you think? Does this make sense to you or have I missed something?


maciejkorzepa commented on July 28, 2024

I'm convinced that my pytorch implementation is doing the right thing: I calculated the relative error wrt the kernel computed in the most naive way, that is, calculating Jacobians one at a time, collecting them into a matrix J and calculating the matrix product J@J.t(). You can see the updated part at the very end of the gist.
https://gist.github.com/maciejkorzepa/af3b5ef3676d982f8c4298f224960c53
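
For readers following along in JAX, the naive reference computation described here (one Jacobian row at a time, then a matrix product) might look like the sketch below. The two-layer model, its sizes, and the helper names are assumptions made for this example; they are not taken from the gist.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_params(key, d_in=5, d_hidden=8):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, d_hidden)) / jnp.sqrt(d_in),
        "w2": jax.random.normal(k2, (d_hidden, 1)) / jnp.sqrt(d_hidden),
    }


def apply_fn(params, x):
    # Scalar output for a single example x of shape (d_in,).
    h = jnp.tanh(x @ params["w1"])
    return (h @ params["w2"])[0]


def flat_grad(params, x):
    # Gradient of the scalar output wrt all parameters, flattened to a vector.
    grads = jax.grad(apply_fn)(params, x)
    return ravel_pytree(grads)[0]


params = init_params(jax.random.PRNGKey(0))
xs = jax.random.normal(jax.random.PRNGKey(1), (6, 5))

# One Jacobian row per example, collected into J of shape (n, n_params).
J = jnp.stack([flat_grad(params, x) for x in xs])
ntk = J @ J.T
```

This is slow and memory-hungry for real models, but it is a convenient ground truth for checking faster implementations.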

The reason why I get (almost) the same values in the kernel is, I think, simply because both the inputs and the parameters are random high dimensional. Values in neural tangents' NTK kernel are also almost the same (well, slightly more different but that probably depends on how we initialize network parameters).

To calculate layer-wise Jacobians in PyTorch, I use its autograd to back-propagate and obtain the gradients of the final output wrt the model parameters. To calculate the Jacobian of layer m, autograd needs to calculate the derivative of the final function wrt the output of layer m. Then it calculates the vjp, where the vector v is that derivative. By default autograd calculates the vjp by accumulating the Jacobian over the batch dimension, but by using hooks I can process the intermediate derivatives myself and calculate the Jacobian of the final output wrt the layer's parameters without summing over the batch dimension (or, as I did for linear layers, I can calculate JJ^T implicitly).
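
One way to see why JJ^T can be formed implicitly for a linear layer: if the layer computes y = xW and the scalar output depends on W only through y, the per-example gradient wrt W is the outer product of the input x_i and the backpropagated derivative delta_i, so the kernel entry is (x_i . x_j)(delta_i . delta_j). The sketch below checks this in JAX on a toy setup; `head`, `X`, `W`, and `v` are hypothetical names for this illustration, not the gist's code.

```python
import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(k1, (6, 4))   # layer inputs, one row per example
W = jax.random.normal(k2, (4, 3))   # weights of the linear layer
v = jax.random.normal(k3, (3,))     # fixed weights of the toy head


def head(y):
    # Everything after the linear layer, mapping its output to a scalar.
    return jnp.dot(jnp.tanh(y), v)


def model(W, x):
    return head(x @ W)


# Explicit: per-example gradients wrt W, flattened into J of shape (6, 12).
J = jax.vmap(lambda x: jax.grad(model)(W, x).ravel())(X)
K_explicit = J @ J.T

# Implicit: never materialize J, use inputs and backpropagated derivatives.
Y = X @ W
delta = jax.vmap(jax.grad(head))(Y)          # d(output)/d(y), shape (6, 3)
K_implicit = (X @ X.T) * (delta @ delta.T)   # elementwise product of two Gram matrices
```

The implicit form costs two small Gram matrices instead of a Jacobian with one column per weight, which matters most for wide layers.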

Actually, I did some calculations and realized that even if I can compute a 100x100 block of the kernel in 1 second, it will take something like half a day to a day to calculate the kernel for 30000 examples. So I'm thinking I might be better off calculating Jacobians on the GPU and storing them in RAM, instead of recomputing them excessively. For 30000 examples and a model with 2.5M parameters the Jacobians (for a single output) will be almost 300 GB and that's relatively close to what I have at my disposal, so I might just go with a batch size of 15k and have just a couple of forward and backward passes over the full dataset rather than hundreds.
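
The back-of-the-envelope numbers above can be reproduced as follows, assuming float32 storage (4 bytes per entry) and 1 second per 100x100 block; both are assumptions used only for this estimate.

```python
n_examples = 30_000
n_params = 2_500_000
bytes_per_float = 4  # assumption: float32

# Memory for storing one Jacobian row per example (single output).
jacobian_bytes = n_examples * n_params * bytes_per_float
jacobian_gb = jacobian_bytes / 1e9

# Time for computing the kernel in 100x100 blocks at 1 second per block.
block = 100
seconds_per_block = 1.0
blocks_per_side = n_examples // block
all_blocks = blocks_per_side ** 2
# Using symmetry, only the upper triangle (including the diagonal) is needed.
upper_blocks = blocks_per_side * (blocks_per_side + 1) // 2

hours_all = all_blocks * seconds_per_block / 3600
hours_symmetric = upper_blocks * seconds_per_block / 3600
```

This gives 300 GB for the Jacobians, 25 hours for all 90000 blocks, and roughly 12.5 hours if symmetry is exploited.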

However it would be great if neural tangents could also handle such big problems!


sschoenholz commented on July 28, 2024

Thanks for adding the check! You're clearly correct and you've come up with a super clever method! For my own sanity, I'll have to do some digging to figure out what exactly is going on. In the meantime, though, I took a stab at getting your technique working in JAX. Here is a preliminary version that I've been playing around with. It's still slow compared with your PyTorch version, jit-compilation takes too long, and the batching overhead is still too high. Nonetheless, progress! We'd love to add something based on this method for computing empirical kernels into NT (unclear whether the sketch above is how we'd want the implementation to go). Would you be interested in working with us on the implementation (if you don't want to get pulled from your research we will of course credit you in whatever solution we end up implementing)?


maciejkorzepa commented on July 28, 2024

Thanks for working on the new implementation! It looks great and I can see it does basically exactly what I'm doing in PyTorch. I was wondering whether you can get an implicit JJ^T for convolutional layers in the same way as for dense layers, but it doesn't seem that straightforward, so I just went with the explicit JJ^T. If there's some implicit way to do it, I guess it would be more beneficial for the higher-level layers, where the feature maps are rather small and there are many weights. For early layers, it's usually the reverse: the feature maps are very large and there are fewer weights (due to the lower number of channels), so the explicit JJ^T might not be very expensive for them.

I would love to get more involved in the implementation if possible although I might need to catch up with JAX a bit more and thoroughly understand your implementation draft first.


uditsaxena commented on July 28, 2024

I'm also interested in something similar.

I want to calculate the empirical NTK for k outputs/classes and n examples, where n could be the size of MNIST, for large dense networks. These networks are a few layers deep with, say, 10 classes, and the width varies from 10 to 10000, so computing the NTK is turning out to be infeasible for me since the output NTK would be of size nk x nk.
I came across this post and went through the gists mentioned.

Correct me if I'm wrong, but the code in the gist here (same as above) (https://gist.github.com/maciejkorzepa/af3b5ef3676d982f8c4298f224960c53) to calculate the NTK by accumulating it per layer is incorrect, right? If we're accumulating the per layer kernels, we're ignoring the weights interactions across layers, right? That would mean the NTK computation would be incorrect. I tried to compute the NTK naively using torch.autograd and it doesn't show the same value.

Have I misunderstood?


romanngg commented on July 28, 2024

Thanks a lot for pointing out this issue and providing such a detailed repro!

I haven't followed the layer-wise Jacobian discussion, so I can't comment on it, but the discussion between @sschoenholz and @maciejkorzepa about how the Jacobian is computed in JAX did point out a way we can speed it up. Namely, if there are no interactions between different batch elements (e.g. no batch norm), then we can pull back an all-ones vector along the batch dimension (not the logit dimension), instead of the full identity matrix. This can now be specified with the new vmap_axes argument passed to nt.empirical_ntk_fn (see more at https://neural-tangents.readthedocs.io/en/latest/neural_tangents.empirical.html and f15b652)

With this setting the example above: (thanks again for writing it up!!)
https://gist.github.com/maciejkorzepa/23bbb4443c2fcb93927eb36d9ef2091c
can be now computed in under 1 second in NT:
https://colab.research.google.com/gist/romanngg/66574cca1dc1a6a7781c14745aeb1141/empirical_ntk_speedup.ipynb
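
The idea behind exploiting independent batch elements can be sketched in plain JAX as below. This is only an illustration of why the shortcut is valid, not NT's implementation of vmap_axes; the toy model and helper names are assumptions for this example.

```python
import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": jax.random.normal(k1, (5, 8)) / jnp.sqrt(5.0),
    "w2": jax.random.normal(k2, (8, 3)) / jnp.sqrt(8.0),
}
x = jax.random.normal(k3, (4, 5))  # batch of 4 examples


def apply_fn(params, x):
    # (batch, d_in) -> (batch, k); no interaction across batch elements.
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


# Generic route: differentiate every (example, logit) output of the whole batch.
jac_full = jax.jacrev(apply_fn)(params, x)


# Batch-aware route: Jacobian of a single example, vectorized over the batch.
def single(params, xi):
    return apply_fn(params, xi[None])[0]   # shape (k,)


jac_vmap = jax.vmap(jax.jacrev(single), in_axes=(None, 0))(params, x)


def to_matrix(jac):
    # Flatten each leaf from (batch, k, *param_shape) to (batch * k, n_params_leaf).
    leaves = jax.tree_util.tree_leaves(jac)
    return jnp.concatenate([leaf.reshape(4 * 3, -1) for leaf in leaves], axis=1)


ntk_full = to_matrix(jac_full) @ to_matrix(jac_full).T
ntk_vmap = to_matrix(jac_vmap) @ to_matrix(jac_vmap).T
```

Both routes give the same (batch * k) x (batch * k) kernel here, but the second only needs k backward passes over single examples rather than batch * k passes over the full batch. With batch norm or any other cross-example coupling the two would no longer agree.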

