
Comments (3)

cgarciae commented on July 17, 2024

Hey! In the NNX Transforms section I think we make it clear that you must use the nnx.* version of the transform when using NNX objects. Is there something we should improve here?


cgarciae commented on July 17, 2024

I think currently you want this:

```python
(loss_values, (y_preds, the_energies, accuracies)), grads = nnx.vmap(
  loss_value_and_grad_fn,
  in_axes=(None, 0, 0),
  state_axes={...: None},  # <<== add this
)
```
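For readers less familiar with `in_axes`, here is a plain-JAX sketch (not NNX) of the same pattern: `None` passes the first argument unchanged to every call, while `0` slices the other two arguments along their leading (batch) axis. The loss function, the `params` dict and the shapes are hypothetical stand-ins for the model and data in this thread.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss: params is shared, x and y are ONE example each.
def loss_fn(params, x, y):
    pred = jnp.dot(params["w"], x) + params["b"]  # scalar prediction
    loss = (pred - y) ** 2
    return loss, pred  # (loss, auxiliary output)

# Differentiate w.r.t. params (argnums=0 by default), carrying pred as aux.
loss_value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# in_axes=(None, 0, 0): broadcast params, map over the leading axis of x and y.
batched = jax.vmap(loss_value_and_grad_fn, in_axes=(None, 0, 0))

params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(0.5)}
xs = jnp.ones((4, 3))   # hypothetical batch: 4 examples, 3 features each
ys = jnp.zeros((4,))

(losses, preds), grads = batched(params, xs, ys)
print(losses.shape, preds.shape, grads["w"].shape)  # (4,) (4,) (4, 3)
```

Because `value_and_grad` itself is mapped, `grads` holds one gradient per example (leading axis of size 4); take `jnp.mean(..., axis=0)` over it if a single averaged gradient is wanted.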

But in the near future, with #3963, you will be able to do:

```python
(loss_values, (y_preds, the_energies, accuracies)), grads = nnx.experimental.vmap(
  loss_value_and_grad_fn,
  in_axes=(None, 0, 0), # <<== just works
)
```

And hopefully `experimental` becomes the norm.


bionicles commented on July 17, 2024

Yes, you're right, the docs are clear about using the nnx transforms.

I hit an issue where I just want to map a function over each item in a batch. When using the vmap transform, the first time I call the transformed function it receives a single slice of the batch, which is what I was going for, but when it gets called again it receives the whole batch at once, which makes it crash due to a shape issue in a call to `concatenate`.

I reckon it's using an XLA version of the same function on the second invocation.

As for what to improve: could the docs have a section dedicated to advice on how to work with batches of data?

I was looking for more focused troubleshooting docs for shape bugs when handling batches with vmap, or advice on how to apply a function over slices of a batch independently, without having to change the function to handle the whole batch at once.
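One possible cause of the `concatenate` crash described above, offered as an illustration rather than a diagnosis of this specific code: a function written for a single example hard-codes axes that only make sense once `vmap` has stripped the batch axis. The sketch below (plain JAX; `per_example_fn` and the shapes are hypothetical) maps such a function with `jax.vmap`, shows `jax.lax.map` as a sequential alternative that also applies it slice by slice, and shows that calling it directly on the whole batch fails inside `concatenate`.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example function: x has shape (features,) for ONE example.
def per_example_fn(x):
    # Append the mean as an extra feature. axis=0 is the feature axis here
    # only because the batch axis is absent inside vmap.
    return jnp.concatenate([x, jnp.mean(x, keepdims=True)], axis=0)

batch = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features

# vmap strips the leading batch axis, so the function sees shape (3,);
# the per-example results (4,) are stacked back into shape (4, 4).
out_vmap = jax.vmap(per_example_fn)(batch)

# jax.lax.map applies the function to each slice sequentially;
# same results as vmap, but without vectorizing across the batch.
out_map = jax.lax.map(per_example_fn, batch)

print(out_vmap.shape)  # (4, 4)

# Calling the per-example function on the whole batch is a different
# computation: the mean now has shape (1, 1), and concatenating (4, 3)
# with (1, 1) along axis 0 is a shape error.
try:
    per_example_fn(batch)
    unbatched_failed = False
except (TypeError, ValueError):
    unbatched_failed = True
```

If a function only works on single examples, keeping every call to it behind `vmap` (or `lax.map`) avoids this class of error.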

(It's an EBM, an energy-based model, so I don't want to inappropriately mix energy gradients across members of the same batch, and it breaks if I try to apply the single-item function to the whole batch.)
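For the EBM case, here is a minimal sketch of per-example input gradients, assuming a toy quadratic energy (`energy` and `params` are hypothetical, not the model from this thread). Take `jax.grad` of the single-example energy, then `vmap` it, so each gradient is computed from its own example only.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example energy: E(x) = 0.5 * ||W x||^2.
def energy(params, x):
    h = params["W"] @ x
    return 0.5 * jnp.sum(h ** 2)

# Gradient of the energy with respect to the input x (argnums=1).
energy_grad_x = jax.grad(energy, argnums=1)

# Broadcast params, map over the batch: each call sees one example,
# so gradients are never mixed across batch members.
batched_energy_grad = jax.vmap(energy_grad_x, in_axes=(None, 0))

params = {"W": jnp.eye(3) * 2.0}
xs = jnp.array([[1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0]])

grads = batched_energy_grad(params, xs)  # shape (2, 3); here W^T W x = 4 x
```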

This might just be my general inexperience with JAX, and a vmap issue rather than an NNX issue; maybe I am missing something about how vmap works.
