Comments (10)
@adarob could you provide a minimal repro for this?
from flax.
Doesn't work:

import jax
from flax import nn

@jax.jit
def rnd():
  return (jax.random.randint(nn.make_rng(), (5,), 0, 10),
          jax.random.randint(nn.make_rng(), (5,), 0, 10))

with nn.stochastic(jax.random.PRNGKey(0)):
  for _ in range(5):
    print(rnd())
Output:
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
Works:

@jax.jit
def rnd(rng):
  with nn.stochastic(rng):
    return (jax.random.randint(nn.make_rng(), (5,), 0, 10),
            jax.random.randint(nn.make_rng(), (5,), 0, 10))

with nn.stochastic(jax.random.PRNGKey(0)):
  for _ in range(5):
    print(rnd(nn.make_rng()))
Output:
(DeviceArray([8, 5, 6, 6, 7], dtype=int32), DeviceArray([4, 9, 7, 1, 5], dtype=int32))
(DeviceArray([9, 3, 1, 6, 0], dtype=int32), DeviceArray([6, 0, 5, 3, 9], dtype=int32))
(DeviceArray([2, 7, 8, 8, 1], dtype=int32), DeviceArray([9, 2, 5, 0, 6], dtype=int32))
(DeviceArray([0, 1, 2, 8, 1], dtype=int32), DeviceArray([5, 4, 6, 1, 1], dtype=int32))
(DeviceArray([1, 8, 4, 8, 3], dtype=int32), DeviceArray([1, 3, 6, 6, 4], dtype=int32))
In that second example you meant to write rng in place of nn.make_rng(), no?
I don't know which line you're referring to but it looks like what I intended.
Ah, my apologies I misread it on the first read.
This is part of a larger issue concerning mixing state and jax transformations. nn.stochastic should throw an exception in this case because mixing jax transformations and internal state is ambiguous. I will make a PR for this, but it might lead to some false positives that need to be fixed.
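The ambiguity can be reproduced without Flax at all. Here is a minimal sketch in plain JAX (the counter is a made-up example, not Flax internals) of why jit silently freezes Python-side state:

```python
import jax
import jax.numpy as jnp

# jax.jit traces a function once and caches the compiled result, so any
# Python-side state read during tracing is baked in as a constant.
counter = {"n": 0}

@jax.jit
def stateful():
    counter["n"] += 1  # side effect runs only while tracing
    return jnp.asarray(counter["n"])

print(stateful())  # 1
print(stateful())  # still 1: the cached trace never reruns the increment
```

This mirrors what happens to nn.make_rng() inside a jitted function: the key produced during tracing is reused on every subsequent call.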
I think #125 is the PR that should address this.
Effectively, it should make your code throw an explicit error, @adarob, and then you can decide how to deal with the PRNGs. E.g. if you're using vmap, you will have to explicitly choose whether to split the PRNGs or reuse one.
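The split-vs-reuse choice under vmap can be made explicit in plain JAX; a sketch (draw is a made-up helper, not part of any API):

```python
import jax

# draw is a hypothetical per-example sampler used to illustrate the choice.
def draw(key):
    return jax.random.randint(key, (3,), 0, 10)

key = jax.random.PRNGKey(0)

# Split: each of the 4 batch elements gets an independent key.
split_out = jax.vmap(draw)(jax.random.split(key, 4))

# Reuse: every batch element sees the same key, so all rows are identical.
reuse_out = jax.vmap(draw, in_axes=None, axis_size=4)(key)

print(split_out.shape, reuse_out.shape)  # (4, 3) (4, 3)
```

Neither option is wrong in general, which is exactly why an explicit error beats a silent default.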
Btw, we are also looking into automatically supporting things like stateful and stochastic in combination with jax transforms, together with the Haiku folks and the jax core team. But for now we just try to avoid silent errors.
@jheek assigning to you because I believe you're looking into this
nn.stochastic correctly throws an error, and it now extends into init_by_shape as well (as of PR #159).
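With the error in place, the fix is to thread keys explicitly. One way to write the original example without a stochastic scope inside jit, sketched with plain jax.random rather than the Flax API:

```python
import jax

# Pass keys in as arguments so jit sees them as inputs rather than values
# captured from hidden state at trace time.
@jax.jit
def rnd(key):
    k1, k2 = jax.random.split(key)
    return (jax.random.randint(k1, (5,), 0, 10),
            jax.random.randint(k2, (5,), 0, 10))

key = jax.random.PRNGKey(0)
for _ in range(5):
    key, sub = jax.random.split(key)
    print(rnd(sub))  # a fresh pair of arrays on every iteration
```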