Comments (3)
Hey! In the NNX Transforms section I think we make it clear that you must use the `nnx.*` version of the transform when using NNX objects. Is there something we should improve here?
I think currently you want this:
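```python
from flax import nnx

(loss_values, (y_preds, the_energies, accuracies)), grads = nnx.vmap(
    loss_value_and_grad_fn,
    in_axes=(None, 0, 0),
    state_axes={...: None},  # <<== add this
)(model, x, y)  # hypothetical argument names: the module, then the batched inputs
```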
But once #3963 lands, you will be able to do:
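```python
from flax import nnx

(loss_values, (y_preds, the_energies, accuracies)), grads = nnx.experimental.vmap(
    loss_value_and_grad_fn,
    in_axes=(None, 0, 0),  # <<== just works
)(model, x, y)  # hypothetical argument names: the module, then the batched inputs
```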
And hopefully the experimental version becomes the default.
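For readers less familiar with `in_axes`, here is a minimal plain-JAX sketch of the same pattern, assuming the parameters are an ordinary pytree rather than an NNX module. `None` broadcasts the first argument (the parameters) to every example, while `0` maps over the leading axis of the inputs and targets; `has_aux=True` lets the loss return extra per-example outputs in the same `(loss, aux), grads` shape as the snippets above. The names `loss_fn`, `params`, `xs`, and `ys` and the toy linear model are hypothetical.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    """Loss for a single example: x has shape (d,), y is a scalar."""
    y_pred = jnp.dot(params["w"], x) + params["b"]
    energy = 0.5 * y_pred ** 2  # hypothetical auxiliary quantity
    accuracy = (jnp.sign(y_pred) == jnp.sign(y)).astype(jnp.float32)
    loss = (y_pred - y) ** 2
    return loss, (y_pred, energy, accuracy)

# Differentiate w.r.t. params only (argnums=0), keeping the auxiliary outputs.
loss_value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# None: broadcast params to every example; 0: map over the leading batch axis.
per_example_fn = jax.vmap(loss_value_and_grad_fn, in_axes=(None, 0, 0))

params = {"w": jnp.array([1.0, -2.0, 0.5]), "b": jnp.array(0.1)}
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples, d = 3
ys = jnp.array([1.0, -1.0, 1.0, -1.0])

(loss_values, (y_preds, the_energies, accuracies)), grads = per_example_fn(params, xs, ys)
print(loss_values.shape)  # (4,)
print(grads["w"].shape)   # (4, 3): one gradient per example
```

Because the outputs are stacked along axis 0 by default, `grads` has one entry per example; averaging them over axis 0 gives the usual batch gradient.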
Yes, you're right: the docs are clear that you should use the NNX transforms.
I hit an issue where I just want to map a function over each item in a batch. With the vmap transform, the first time I call the transformed function it receives a single slice of the batch, which is what I was going for. But when it gets called again, it receives the whole batch at once, which makes it crash with a shape error in a call to `concatenate`.
My guess is that the second invocation runs an XLA-compiled version of the same function.
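One way to check what vmap actually passes to the function is to record `x.shape` from Python while it runs. The sketch below uses plain JAX, and `single_example_fn` and `seen_shapes` are hypothetical names. It shows that under `jax.vmap` the function body is traced with per-example shapes, while calling the same function directly, outside the transform, hands it the whole batch. So a call that sees the full batch may point to a code path that invokes the function without the vmap wrapper; that is an assumption, not a diagnosis of this particular report.

```python
import jax
import jax.numpy as jnp

seen_shapes = []  # hypothetical debugging aid: shapes observed by the Python body

def single_example_fn(x):
    """Written for one example of shape (d,)."""
    seen_shapes.append(x.shape)
    return jnp.sum(x ** 2)

batch = jnp.ones((4, 3))  # 4 examples, d = 3

batched_fn = jax.vmap(single_example_fn)
batched_fn(batch)          # body sees x.shape == (3,)
batched_fn(batch)          # un-jitted vmap re-runs the body, still (3,)
single_example_fn(batch)   # called directly: body sees the full (4, 3) batch

print(seen_shapes)  # [(3,), (3,), (4, 3)]
```

If the vmapped function is also wrapped in `jax.jit`, the Python body runs only while tracing, so shape logging like this shows one entry per compilation rather than one per call.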
As for what to improve: could the docs have a section dedicated to working with batches of data?
I was looking for more focused troubleshooting docs for shape bugs when handling batches with vmap, or advice on how to apply a function to each slice of a batch independently without rewriting the function to handle the whole batch at once.
(It's an EBM, an energy-based model, so I don't want to inappropriately mix energy gradients across members of the same batch, and it breaks if I apply the single-item function to the whole batch.)
This might just be my general inexperience with JAX, and a vmap issue rather than an NNX issue; maybe I am missing something about how vmap works.
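For the EBM case, here is a minimal sketch of computing per-example input gradients without rewriting the single-item function. It assumes a toy energy `energy(w, x)` (hypothetical, not the poster's model) that uses `jnp.concatenate` along axis 0 and therefore only works on one example at a time. Wrapping `jax.grad` in `jax.vmap` computes each example's gradient from that example alone. Calling `energy` directly on a `(B, d)` array instead fails with a shape error, similar in spirit to the crash described above.

```python
import jax
import jax.numpy as jnp

def energy(w, x):
    """Hypothetical single-example energy: x has shape (d,), w has shape (2 * d,)."""
    feats = jnp.concatenate([x, x ** 2])  # axis 0: only correct for one example
    return 0.5 * jnp.sum((feats - w) ** 2)

# Gradient of the energy w.r.t. the input x (argnums=1), for one example.
energy_grad_x = jax.grad(energy, argnums=1)

# Broadcast w, map over the batch axis of x: each gradient uses only its own example.
per_example_grads = jax.vmap(energy_grad_x, in_axes=(None, 0))

d = 3
w = jnp.linspace(-1.0, 1.0, 2 * d)
xs = jnp.arange(12.0).reshape(4, d) / 10.0  # batch of 4 examples

grads = per_example_grads(w, xs)
print(grads.shape)  # (4, 3)

# Calling the single-example function on the whole batch breaks:
try:
    energy(w, xs)  # concatenate gives (8, 3), which cannot broadcast against w of shape (6,)
except (TypeError, ValueError) as err:
    print("full-batch call failed:", type(err).__name__)
```

Because each example's energy here depends only on its own input, differentiating the batch-summed energy would give the same numbers; what `jax.vmap` buys is running the single-item code unchanged. Gradients would only mix across examples if the function itself coupled them, for example through a batch mean.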
Related Issues
- Cannot load checkpoint saved in flax 0.5.3 with flax 0.6.1
- Not sure if flax use the GPU
- New python venv fails to pip install the flax mnist example requirements
- Best practice of dealing with sporadic FrozenDict conversions?
- NNX + mutable state + JIT = "Cannot mutate <module name> from a different trace level"
- `serialization.from_state_dict` does not restore to jax.Arrays
- Exceptions are not pickle-able
- Force `fp32` in `attention.MultiHeadDotProductAttention` for softmax operator
- Will nnx.MultiHeadAttention support flash att?
- `save_checkpoint` fails with the most recent orbax release
- flax nn.tabulate Incorrectly Reports FLOPs and VJP FLOPs
- Suboptimal default initialization of q/k/v projections in `nn.MultiHeadDotProductAttention`
- lstm error
- Feature request: Mixture of Experts example
- Significant performance difference of NNX relative to equinox
- typo in nnx_basics.md
- Opaque XLA crash when initializing model
- Is there anyway to analyze activations in flax?
- Dropout seems not compatible with jax.jit