
levskaya commented on April 27, 2024

Sorry, it took me a bit to figure out what was going on.
A Model should be pmap'able; what's happening here is a subtle bug.

First, a short-term workaround is to wrap the call in a passthrough lambda:

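```python
import jax
from flax import nn  # legacy flax.nn API (pre-Linen)

layer = nn.Dense.partial(features=1)
key = jax.random.PRNGKey(0)
# pmap maps over the leading axis (size 4), so this needs 4 devices.
x = jax.random.normal(key, (4, 20, 2))
# Initialize parameters from a single per-device slice of the input.
_, params = layer.init(key, x[0, ...])
layer_m = nn.Model(layer, params)

# Workaround: pmap a passthrough lambda instead of the Model instance itself.
jax.pmap(lambda z: layer_m(z))(x)
```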

Now, what's going on:

  • In google/jax#2073, a great change made about two months ago to improve XLA call-stack metadata, JAX started reading the __name__ attribute of the pmap'd function, which in this case is our callable Model instance.
  • The problem is that in baf43e7, another refactoring of the base Flax code about a month ago, we overrode __getattr__ on Model to pass through and fetch the requested attribute from the Module. Inside that override we evaluate issubclass(fetched_attr, flax.nn.Module). For __name__ the fetched attribute is a string, and issubclass(<string object>, flax.nn.Module) raises a TypeError, because issubclass requires a class as its first argument.
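To make the failure mode concrete, here is a minimal pure-Python sketch of the mechanism described above. It uses neither Flax nor JAX: `Module`, `BrokenModel`, `FixedModel`, and `function_name` are hypothetical stand-ins, and `function_name` only assumes the name lookup behaves like `getattr(fn, "__name__", default)`; it is not JAX's actual code, and the `isinstance(attr, type)` guard is one possible fix, not necessarily the one Flax adopted. The key detail is that `getattr` with a default only swallows `AttributeError`, so the `TypeError` from `issubclass` escapes.

```python
class Module:
    """Hypothetical stand-in for flax.nn.Module (illustration only)."""

    @staticmethod
    def apply(params, x):
        # Identity "layer" so the example stays dependency-free.
        return x


class BrokenModel:
    """Hypothetical stand-in for a Model with the buggy __getattr__."""

    def __init__(self, module, params):
        self.module = module  # a Module *class*, as in nn.Model(layer, params)
        self.params = params

    def __call__(self, x):
        return self.module.apply(self.params, x)

    def __getattr__(self, name):
        # Python calls this only when normal lookup fails, e.g. for
        # "__name__", which instances do not define.
        attr = getattr(self.module, name)
        # Bug: issubclass() raises TypeError when attr is not a class,
        # e.g. the string returned for "__name__".
        if issubclass(attr, Module):
            pass  # placeholder for any special handling of Module subclasses
        return attr


class FixedModel(BrokenModel):
    """Same passthrough, but only calls issubclass() on actual classes."""

    def __getattr__(self, name):
        attr = getattr(self.module, name)
        if isinstance(attr, type) and issubclass(attr, Module):
            pass  # placeholder for any special handling of Module subclasses
        return attr


def function_name(fn):
    # Hypothetical stand-in for a transform reading fn.__name__ for metadata.
    # The default only covers AttributeError; a TypeError propagates.
    return getattr(fn, "__name__", "<unnamed>")


broken = BrokenModel(Module, params={})
try:
    function_name(broken)
    broken_error = None
except TypeError as err:
    broken_error = err

print("broken:", repr(broken_error))
print("fixed:", function_name(FixedModel(Module, params={})))  # fixed: Module
print("lambda:", function_name(lambda z: broken(z)))  # lambda: <lambda>
```

The lambda case also shows why the workaround helps: a lambda has its own `__name__` (`"<lambda>"`), so the lookup never reaches the Model's `__getattr__`.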

We almost always use a Model inside an optimizer or call it indirectly from another function, so I suspect we have no unit test that applies jit or pmap directly to a Model. My apologies for letting this slip through; we'll try to get a fix in ASAP.
