Comments (10)

AlexeyG commented on April 27, 2024

@adarob could you provide a minimal repro for this?

adarob commented on April 27, 2024

Doesn't work:

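import jax
from flax import nn  # legacy (pre-Linen) flax.nn API

@jax.jit
def rnd():
  # nn.make_rng() only runs while jit traces rnd; its keys are fixed into the compiled function
  return (jax.random.randint(nn.make_rng(), (5,), 0, 10),
          jax.random.randint(nn.make_rng(), (5,), 0, 10))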

with nn.stochastic(jax.random.PRNGKey(0)):
  for _ in range(5):
    print(rnd())

Output:

(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
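
A hedged reading of why the first example repeats: jax.jit executes the Python body of rnd only while tracing it. During that single trace nn.make_rng() advances the Python-side RNG state and hands back two keys that are fixed into the compiled computation; later calls hit the compilation cache and never call make_rng again. The sketch below shows the same trace-time-only behavior in plain JAX, with a hypothetical Python counter (trace_count) standing in for the RNG state:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical stand-in for Python-side RNG state

@jax.jit
def step(x):
    global trace_count
    # Python side effect: executes only while JAX traces `step`,
    # not on later calls that reuse the cached compilation.
    trace_count += 1
    return x + 1

results = [step(jnp.float32(1.0)) for _ in range(5)]
print(trace_count)  # 1: traced once, then the cached version is reused
```

Anything that must change between calls should therefore be an explicit argument, as rng is in the working example.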

Works:

@jax.jit
def rnd(rng):
  # rng is a traced argument, so each call can receive a fresh key
  with nn.stochastic(rng):
    return (jax.random.randint(nn.make_rng(), (5,), 0, 10),
            jax.random.randint(nn.make_rng(), (5,), 0, 10))

with nn.stochastic(jax.random.PRNGKey(0)):
  for _ in range(5):
    print(rnd(nn.make_rng()))

Output:

(DeviceArray([8, 5, 6, 6, 7], dtype=int32), DeviceArray([4, 9, 7, 1, 5], dtype=int32))
(DeviceArray([9, 3, 1, 6, 0], dtype=int32), DeviceArray([6, 0, 5, 3, 9], dtype=int32))
(DeviceArray([2, 7, 8, 8, 1], dtype=int32), DeviceArray([9, 2, 5, 0, 6], dtype=int32))
(DeviceArray([0, 1, 2, 8, 1], dtype=int32), DeviceArray([5, 4, 6, 1, 1], dtype=int32))
(DeviceArray([1, 8, 4, 8, 3], dtype=int32), DeviceArray([1, 3, 6, 6, 4], dtype=int32))
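
For comparison, a sketch of the working pattern in plain JAX without flax.nn, assuming only jax.random: the key is threaded explicitly, split outside the jitted function on every iteration, and split again inside it to get the two subkeys. The name rnd mirrors the example above; the rest is illustrative.

```python
import jax

@jax.jit
def rnd(key):
    # `key` is a traced argument, so a new value on each call gives new samples.
    k1, k2 = jax.random.split(key)
    return (jax.random.randint(k1, (5,), 0, 10),
            jax.random.randint(k2, (5,), 0, 10))

key = jax.random.PRNGKey(0)
outputs = []
for _ in range(5):
    key, subkey = jax.random.split(key)  # advance the key explicitly, outside jit
    outputs.append(rnd(subkey))
    print(outputs[-1])
```

Because no state lives outside the function's arguments, the same subkey always reproduces the same samples, and jit never has to guess how hidden state should evolve.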

levskaya commented on April 27, 2024

In that second example you meant to write rng in place of nn.make_rng(), no?

adarob commented on April 27, 2024

I don't know which line you're referring to, but it looks like what I intended.

levskaya commented on April 27, 2024

Ah, my apologies, I misread it the first time.

jheek commented on April 27, 2024

This is part of a larger issue concerning mixing state and JAX transformations. nn.stochastic should throw an exception in this case, because mixing JAX transformations with internal state is ambiguous. I will make a PR for this, but it might lead to some false positives that need to be fixed.

avital commented on April 27, 2024

I think #125 is the PR that should address this.

Effectively, it should make your code, @adarob, throw an explicit error, and then you can decide how to deal with the PRNGs. For example, if you're using vmap, you will have to explicitly choose whether to split the PRNG key or reuse it.
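
A minimal sketch of the two vmap choices in plain JAX (add_noise is a hypothetical function, not part of flax): passing one key with in_axes=None reuses it for every batch element, while jax.random.split gives each element its own key.

```python
import jax
import jax.numpy as jnp

def add_noise(key, x):
    # One scalar Gaussian draw per call.
    return x + jax.random.normal(key, ())

key = jax.random.PRNGKey(0)
xs = jnp.zeros(4)

# Reuse: in_axes=None broadcasts the same key to every batch element,
# so every element receives identical noise.
shared = jax.vmap(add_noise, in_axes=(None, 0))(key, xs)

# Split: one key per batch element, so each element gets independent noise.
keys = jax.random.split(key, xs.shape[0])
independent = jax.vmap(add_noise)(keys, xs)
```

Neither choice is "correct" in general: reusing the key suits cases like sharing a dropout mask across a batch, while splitting is what you usually want for independent samples.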

jheek commented on April 27, 2024

Btw, together with the Haiku folks and the JAX core team, we are also looking into automatically supporting things like stateful and stochastic in combination with JAX transforms. But for now, we just try to avoid silent errors.

avital commented on April 27, 2024

@jheek, assigning to you because I believe you're looking into this.

jheek commented on April 27, 2024

nn.stochastic correctly throws an error, but it does now extend into init_by_shape (as of PR #159).
