Comments (32)

lucasb-eyer commented on April 27, 2024

💯 to this change. It aligns the mental model with TF2's tf.Module and PyTorch's nn.Module a lot more, and both of these have converged to where they are now after many years of mistakes, so this is a good thing.

(NOTE: We may choose to keep the safe-guarding behavior of .shared() that makes it hard to accidentally copy and paste code that re-uses modules. We can achieve that by having modules default to raising an error when __call__ is invoked a second time, unless .shared() was called on the module instance first.)

Please don't. Re-using an instance is the common, intuitive, friction-less way of sharing weights; this would just add annoying overhead for the sake of avoiding a mistake which, frankly, I have never encountered. An explicit :share method was how it was done in Torch7, and it was annoying and painful and does not exist anymore in PyTorch.
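
For reference, a minimal PyTorch sketch of that reuse-based sharing (the module here is made up purely for illustration):

import torch.nn as nn

class TwoTowers(nn.Module):
    def __init__(self):
        super().__init__()
        # a single Linear instance; reusing it below is what shares its weights
        self.proj = nn.Linear(16, 16)

    def forward(self, a, b):
        # calling the same instance twice is the entire "sharing API";
        # no explicit :share / .shared() step is needed
        return self.proj(a) + self.proj(b)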

Regarding the __init__ vs __call__ separation, I don't think that it makes good code impossible, so if someone creates a monster hydra code because of that, it's probably the author's fault, not the library's. Using dataclass (or attr.s) for this is an interesting idea. However, usually what is done in __init__ is just normalization of convenience parameters, for example allowing filter-size to be passed as (3,3) or as 3, and then turning 3 into (3,3) in __init__, such that __call__ is cleaner to read, and really you can skip reading __init__ with that in mind. I think this is a good thing.
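
As a small illustration of that normalization pattern (a generic sketch, not the actual flax API):

class Conv:
    def __init__(self, features, kernel_size):
        # __init__ only normalizes convenience arguments:
        # accept kernel_size=3 or kernel_size=(3, 3), always store a tuple
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.features = features
        self.kernel_size = kernel_size

    def __call__(self, x):
        # the forward pass only sees the normalized attributes,
        # so it can be read on its own
        ...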

Finally, I think you can have an even more convincing example for modules which have more than just the obvious __call__, like the VAE example here which currently is not trivial to understand: I either have to do a lot of guess-work about FLAX internals, or go back and read the whole docs. Whereas after your proposal (and in PyTorch) it can be much more straightforward.

srush commented on April 27, 2024

Nice, I like this change. It is a good start.

However, if you are making such a breaking change, this feels too conservative.

Core Issues:

  • This function still violates Pythonic conventions. nn.Dense is seemingly making mutable changes to some internal state buffer that is invisible to the user and not transparent in the syntax. (I know this happens in TF, but flax should be better.)
  def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)

=> Does this mean?

  def __call__(self, x):
    x = nn.Dense(self, features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(self, features=16)(x)

(Or alternatively pytorch / sonnet 2 syntax which both do this better)

  • Params are still treated differently from Layers, and use string-based naming, which seems dangerous and tempting to abuse.
bias = self.param('bias', (self.features,), self.bias_init) 

=> ?

bias = nn.Param(self, (self.features,), self.bias_init)

akolesnikoff commented on April 27, 2024

I would be in favor of this change, as the proposed way of creating new modules is better aligned with my default mental model: layers are classes and particular instances of these layers (with associated weights and parameters) are objects that process the data by being called. I also fully agree that variable sharing will become more intuitive.

Besides, I actually see a separation between __init__ and __call__ as a potential win. Conceptually, I imagine that __init__ should admit static parameters, like number of channels, and __call__ should admit actual data that is processed. Currently, these different types of parameters are all mixed together.

shoyer commented on April 27, 2024

Clearly in apply_a we expect the Dense parameters to be shared when we call apply_a multiple times on an instance of MyModule but not between iterations of the loop. But what if we take the dense_factory from apply_a and turn it into a method (_dense_factory)? A seemingly innocent refactor will now cause all the Dense modules in apply_b to be shared.

My expectation from reading this code is that all Dense parameters in both examples would be unshared. If you want to use the same parameters, you need to use the same Dense object.
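
For comparison, this is exactly how it behaves in PyTorch (a quick self-contained sketch):

import torch
import torch.nn as nn

x = torch.ones(4, 123)

# a new Linear per iteration: three independent parameter sets
for _ in range(3):
    x = nn.Linear(123, 123)(x)

# one instance reused: the three calls share a single parameter set
shared = nn.Linear(123, 123)
for _ in range(3):
    x = shared(x)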

srush commented on April 27, 2024

Speaking as a teacher, the XLA docs scare me. They are very jargon-heavy. It would be like asking numpy students to read the BLAS docs.

srush commented on April 27, 2024

Thanks @avital !

I really like the new API, thanks for putting the work into it and being direct about the tradeoffs. I will definitely be using it for my next project (probably without @nn.compact, but that is totally okay if they are compatible).

I found this helpful: https://colab.research.google.com/github/google/flax/blob/master/docs/notebooks/linen_intro.ipynb

avital commented on April 27, 2024

I would be in favor of this change, as the proposed way of creating new modules is better aligned with my default mental model: layers are classes and particular instances of these layers (with associated weights and parameters) are objects that process the data by being called. I also fully agree that variable sharing will become more intuitive.

Yes. One caveat is that while layers have parameters on them, those parameters will be immutable and you'd still need to mutate your parameters at the top level rather than within your module. This is due to our desire to allow you to use vanilla transformations such as jit and pmap, which don't work with mutations.
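
Concretely, the top-level, purely functional update looks something like the sketch below; apply_model, the loss, and the parameter-pytree layout are made-up placeholders for illustration, not the actual flax API:

import jax
import jax.numpy as jnp

def apply_model(params, x):
    # stand-in for a module's forward pass; params is a plain pytree
    return x @ params["kernel"] + params["bias"]

def loss_fn(params, batch):
    preds = apply_model(params, batch["x"])
    return jnp.mean((preds - batch["y"]) ** 2)

@jax.jit
def train_step(params, batch, lr=0.1):
    grads = jax.grad(loss_fn)(params, batch)
    # no in-place mutation: build and return a new parameter pytree
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)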

Besides, I actually see a separation between __init__ and __call__ as a potential win. Conceptually, I imagine that __init__ should admit static parameters, like number of channels, and __call__ should admit actual data that is processed. Currently, these different types of parameters are all mixed together.

Yes. The issue is that by simply letting people use __init__ and __call__ arbitrarily, you often end up with things like this, where you really have to move up and down many times to be able to fully follow the flow of what the module's forward pass does. Hence the restriction that comes from using dataclasses encourages __init__ to be as dumb as possible.

danielsuo commented on April 27, 2024

Thanks for this proposal! I agree with the other comments:

  • How __init__ and __call__ might separate responsibility (user-created monster hydras notwithstanding, @lucasb-eyer)
  • Removing .shared(). I understand the rationale for keeping it (one less thing to debug), but in this case, it makes sense to opt for less friction vs. more safety if that's the common user expectation. If we really wanted to be extra, we could provide some flax linting utilities (FL201: Did you mean to reuse a module?) :)

Yes. One caveat is that while layers have parameters on them, those parameters will be immutable and you'd still need to mutate your parameters at the top level rather than within your module. This is due to our desire to allow you to use vanilla transformations such as jit and pmap, which don't work with mutations.

Do you mean passing modules directly into jit? One of the things I tried to do away with during my weekend excursion was flax.nn.Model, given the constraint that flax.nn.Module must be immutable. The solution was not great: have an instance method that returns a new flax.nn.Module when you update parameters or state.
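
For what it's worth, a bare-bones sketch of that "return a new module on update" pattern (the class and method names are invented for illustration; this is not the actual flax.nn.Model API):

from dataclasses import dataclass, replace
import jax.numpy as jnp

@dataclass(frozen=True)
class FrozenDense:
    kernel: jnp.ndarray
    bias: jnp.ndarray

    def __call__(self, x):
        return x @ self.kernel + self.bias

    def with_params(self, kernel, bias):
        # immutable update: build a fresh instance, leave self untouched
        return replace(self, kernel=kernel, bias=bias)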

shoyer commented on April 27, 2024

In general I really like the look of this! I think it would be a significant improvement/simplification of Flax's mental model.

👍 for eliminating the use of __new__ in Modules

👍 for eliminating .partial().

👍 for eliminating .shared(). I don't think we need the safeguard -- it is quite common to intentionally reuse models in neural net code

👍 for encouraging the use of dataclasses (in particular, @dataclass(frozen=True) to enforce immutability)

👎 for requiring dataclasses, and not allowing __init__ methods to be written explicitly. Even if this were possible to enforce in a clean way (I have my doubts), sometimes __init__ can be a nice way to write this, as @lucasb-eyer writes in #208 (comment).

👍 for the proposed transition plan, which looks quite practical.

shoyer commented on April 27, 2024

One question arises: how does this change affect (if at all) the way we initialize Flax models? Do we still stick with Module.init and call methods, except that these are now normal methods instead of class methods?

jesseengel commented on April 27, 2024

👍 to everything said by @lucasb-eyer, @srush, and @shoyer. I think having separate __init__ and __call__ is actually a huge net positive. It allows people to just think in Python instead of "thinking in Flax" like we have to do with TF.

FWIW, I don't see the hydra thing as much of a disadvantage. In many cases it requires people to be more explicit, and you can see what's going on in the submodule itself, instead of hiding things in implicit behind-the-scenes work. It also makes it easier to access model attributes from outside the module if you want to hack things later, say in a colab notebook.

Also, I think it's great to allow access to __call__ directly, rather than redirecting to some other function like apply. I'm running into challenges with this in Keras at the moment, as I'm trying to work around some aspects of the forced programming model, but it's inflexible if I only have access to call and not __call__, and requires me to dig deep into the Keras base layer code, which is a mess. Let's not make the same mistake for Flax.

david-waterworth commented on April 27, 2024

This should also help with my confusion in #16 (comment), where not calling partial before create_by_shape results in the model being created with different parameters from those it was trained with.

cghawthorne commented on April 27, 2024

Can you add an example of how this would work with an equivalent to module_method?

shoyer commented on April 27, 2024

(Or alternatively pytorch / sonnet 2 syntax which both do this better)

@srush Could you kindly clarify what you mean by this?

Is this just a reference to how PyTorch / Sonnet 2 use explicit attribute assignment for submodules? e.g., self.dense = nn.Dense(features=16)?

This does make module hierarchies and when mutation is happening very clear. The downside is that layers get specified in __init__, which is separated from where they are used.

jesseengel commented on April 27, 2024

Is this just a reference to how PyTorch / Sonnet 2 use explicit attribute assignment for submodules? e.g., self.dense = nn.Dense(features=16)?

This does make module hierarchies and when mutation is happening very clear. The downside is that layers get specified in __init__, which is separated from where they are used.

@shoyer To be clear, I think a lot of people actually consider that an upside. It separates creation/ownership from usage, so it's much clearer when reuse is happening, and it's easier to access submodules from outside the class itself for more creative routing of shared parameters.

The mental overhead of having a little boilerplate is a small price to pay for such explicit clarity and Python-native interaction paradigms (using Python's built-in object attributes, vs. some behind-the-scenes implicit naming scheme).

shoyer commented on April 27, 2024

To be clear, I think a lot of people actually consider that an upside. It separates creation/ownership from usage, so it's much clearer when reuse is happening, and it's easier to access submodules from outside the class itself for more creative routing of shared parameters.

Absolutely, these are all real advantages. On the other hand, I've also had cases where separating initialization/use of layers made my code harder to read and modify because two different parts of the code need to be kept in sync. You also can't use input shapes to determine the shapes of variables. It is not clear to me (personally) which is better/worse in general. It may depend on the context.

Keras lets you write things both ways, which is convenient for users, but of course imposes an even higher cost in terms of complexity.

For JAX, there is one additional consideration, which is whether the module abstraction is amenable to functional transformations -- one of the core strengths of JAX. My understanding is that this is hard to do with Python's mutable object model.

srush commented on April 27, 2024

I consider inline initialization a Keras design flaw. It mixes functional and structural concerns and makes it very hard to reason about, document, and analyze modules.

However, whether or not you agree with this, the fact that it is causing the library to have ill-defined semantics, with very minimal benefits ("less scrolling up?"), should be a red flag that it is maybe a problem.

jheek commented on April 27, 2024

I think we should not use words like "normal" or "pythonic". They are really vague statements that essentially refer to similarity with existing programming paradigms that are common in the Python world. We shouldn't strive to please the status quo.

I think the points raised by @srush are important. Although sharing becomes clearer with explicit construction, it still isn't quite like an object that owns its parameters.

Consider the following example:

class MyModule(flax.Module):
  def apply_a(self, x):
    def inner_dense_factory():
      return nn.Dense(123)
    for i in range(3):
      x = inner_dense_factory()(x) 
    return x

  def apply_b(self, x):
    for i in range(3):
      x = self._dense_factory()(x) 
    return x

  def _dense_factory(self):
    return nn.Dense(123, self.my_fancy_init)

Clearly in apply_a we expect the Dense parameters to be shared when we call apply_a multiple times on an instance of MyModule but not between iterations of the loop. But what if we take the dense_factory from apply_a and turn it into a method (_dense_factory)? A seemingly innocent refactor will now cause all the Dense modules in apply_b to be shared.

Of course we could add annotation trickery to distinguish between module methods that have a scope and "inline methods". But the mental model is still significantly more complex than plain old Python objects.

srush commented on April 27, 2024

Perhaps I am missing something, but I don't really understand the example above. The implied semantics feel really complicated to me as state seems to bind to functions in a way I cannot trace.

Btw, I don't know if it is helpful, but here is a proof-of-concept of the sort of pure world I like (not saying flax needs to go this way).

https://github.com/srush/parallax

# Everything is immutable: @module = dataclass(frozen=True, repr=False)
@module
class Dense(Module):

    # All parameter-holders are explicitly declared.
    weight : Parameter
    bias   : Parameter

    # Setup replaces __init__; it creates shapes and binds lazy initializers.
    @staticmethod
    def setup(in_size, out_size):
        return Dense.init(
            weight = Parameter.setup((out_size, in_size), init.xavier_normal_),
            bias   = Parameter.setup((out_size,), init.normal_))

    # Forward is just like standard pytorch.
    def forward(self, input):
        return self.weight @ input + self.bias

"Sharing" would requires a manual split of the parameter into two parts like this.

@module
class BinaryNetwork(Module):

    # No difference between modules and parameters
    dense1  : Dense
    dense2  : Dense
    dense3  : Dense
    dropout : Dropout

    @staticmethod
    def setup(input_size, hidden_size):
        return BinaryNetwork.init(
            dense1  = Dense.setup(input_size, hidden_size),
            dense2  = Dense.setup(hidden_size, hidden_size),
            dense3  = Dense.setup(hidden_size, 1),
            dropout = Dropout.setup(rate=0.2)
        )

    def forward(self, input):

        # Standard usage works out of the box.
        x = torch.tanh(self.dense1(input))

        # Stochastic modules (have random seed already)
        x = self.dropout(x)

        # Shared params / recurrence requires split (like RNG)
        dense2_a, dense2_b = self.dense2.split(2)
        x = torch.tanh(dense2_a(x))
        x = torch.tanh(dense2_b(x))

        return torch.sigmoid(self.dense3(torch.tanh(x)))

lucasb-eyer commented on April 27, 2024

Yep, I was about to say the same as @shoyer: the example is convoluted, but we are creating a new Dense object each time, so I would definitely not expect weight sharing. Any sharing happening in that code would be weird magic under the hood that is very confusing.

lucasb-eyer commented on April 27, 2024

@srush I fail to see how your example semantically differs from plain PyTorch/nn code? It's "create object at init, use object to apply at forward" semantics, the remaining differences from plain PyTorch/nn look like mostly syntax to me? (edit: not saying this is bad, I like PyTorch/nn)

srush commented on April 27, 2024

@lucasb-eyer Sorry, I should have explained better. The fact that it looks like PyTorch syntax is a red herring: unlike PyTorch, the implementation is pure / immutable.

It's "create declarative skeleton at init, (engine fills in tensors), (engine distributes RNG to module), use objects statelessly to apply at forward"

layer = BinaryNetwork.setup(5, 10)

# Initialize parameters -> stateful, hidden
rng = rng_state()
layer = layer.initialize(rng)

for i in range(10):
    rng = rng_state()
    layer = layer.init_state(rng, mode="train")
    grads = grad(layer.forward)(x)
    layer = layer.update(lambda a, b: a + b, grads)

lucasb-eyer commented on April 27, 2024

I see, yeah I was missing the "use it" code, should've checked your repo. My personal opinion is that classes are the wrong concept to build something pure/immutable/functional.

A few colleagues and I have an internal codebase built on jax, which uses flax in a completely pure/functional way, and flax was open to some design changes to make using flax in that way possible and nice. I think it is very close to your example code actually. We made a simplified version of it public just now, see here: https://github.com/google-research/big_transfer/tree/master/bit_jax

However, all of this pure, pretty, neat, readable stuff goes to 💣 💩 ⚡ the moment you want to add BatchNorm :)
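
To make that concrete: the pain is that BatchNorm's running statistics are extra state that a pure setup has to thread in and out of every call instead of mutating on the module. A rough, illustrative sketch (names and state layout are made up):

import jax.numpy as jnp

def batch_norm(x, state, momentum=0.99, eps=1e-5, train=True):
    if train:
        mean, var = jnp.mean(x, axis=0), jnp.var(x, axis=0)
        new_state = {
            "mean": momentum * state["mean"] + (1 - momentum) * mean,
            "var": momentum * state["var"] + (1 - momentum) * var,
        }
    else:
        mean, var = state["mean"], state["var"]
        new_state = state
    y = (x - mean) / jnp.sqrt(var + eps)
    return y, new_state  # every caller now has to carry new_state around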

srush commented on April 27, 2024

Nice, I will check it out. Maybe what needs to happen is for the jax community to just have an nn.functional module like PyTorch's, so different module systems can use the same layers.

@lucasb-eyer I am still stuck on one point that keeps me bothered by all these solutions: when you read the code below, what are the internal/informal semantics going on in your head? Particularly: Where do you imagine that name is stored? Do you believe this code knows it is in an object? Do you have a type in your head for x? How do you reason about whether this line of code knows if it is the first or last time it is called? Could this code be tested independently of its system?

    x = nn.Dense(x, num_classes, name="conv_head", kernel_init=nn.initializers.zeros)

Until I can answer these questions, I just can't imagine this will be the final state of a reliable module system.

lucasb-eyer commented on April 27, 2024

but jax.lax and jax.numpy pretty much correspond to nn.functional :) The next step is deciding how bookkeeping of variables/parameters happens, and that is where all the frameworks' opinions differ (and mine differs again, and so does yours).

Regarding your second paragraph, I agree that the line has too much magic (also, where are the dense's w/b tracked? a global collection maybe? 😨). And my understanding is that @avital's proposal in the OP is exactly about reducing this magic and, effectively, being closer to "plain python" or PyTorch semantics.

srush commented on April 27, 2024

but jax.lax and jax.numpy pretty much correspond to nn.functional :)

That's not true. nn.functional contains clean, functional nn implementations of conv/dense/rnn/etc. that could be used with any module system; none of that is in jax.lax or jax.numpy: https://pytorch.org/docs/stable/nn.functional.html

The next step is deciding how bookkeeping of variables/parameters happens,

I agree. That's what I'm interested in.

And my understanding is that @avital 's proposal in the OP is exactly about reducing this magic and, effectively, being closer to "plain python" or PyTorch semantics.

It gets halfway there, I'm arguing it needs to be really solved.

srush commented on April 27, 2024

@lucasb-eyer Very neat paper though!

lucasb-eyer commented on April 27, 2024

That's not true, nn.functional is clean functional nn implementations of conv/dense/rnn/etc that could be used with any module system, none of that is in jax.lax or jax.numpy

Not true either. jax.lax has a pretty powerful implementation of conv (jax.lax.conv_general_dilated), and similarly for pooling (jax.lax.reduce_window) and linear (jax.lax.dot_general).

I was about to concede it's missing an RNN, but there is actually none in nn.functional either. The only remaining non-trivial entry of nn.functional that is missing from jax.{nn,lax,numpy} is ctc_loss, and I'm sure jaxers would happily accept a PR for jax.nn.ctc_loss. So I maintain my point that torch.nn.functional β‰ˆ jax.{nn,lax,numpy}.
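
For example, 2x2 max pooling falls out of jax.lax.reduce_window directly (a small sketch, assuming NHWC inputs):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((8, 32, 32, 3))  # NHWC batch of images

# 2x2 max pooling with stride 2, written directly against lax.reduce_window
pooled = lax.reduce_window(
    x,
    -jnp.inf,            # init value for the max reduction
    lax.max,             # elementwise reduction computation
    window_dimensions=(1, 2, 2, 1),
    window_strides=(1, 2, 2, 1),
    padding="VALID",
)
print(pooled.shape)  # (8, 16, 16, 3)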

It gets halfway there, I'm arguing it needs to be really solved.

I went back to read it, and I actually agree with the points in your first comment in this thread.

Thanks :)

srush commented on April 27, 2024

Oh well, now I feel silly. It does seem like the lax functions are just much more general than the PyTorch implementations. I honestly never found reduce_window on my own (the doc saying "Wraps XLA's ReduceWindow operator" doesn't really help). The Stax implementation does make it clear though.

lucasb-eyer commented on April 27, 2024

No worries. jax.lax is extremely powerful; I like its API a lot (reminds me of BLAS, but in the times of XLA), and it is criminally under-documented!

j-towns commented on April 27, 2024

FYI (you may already know this): most of the ops in lax (things like reduce_window) are documented in more detail in the XLA operation semantics docs. I guess we ought to copy more of those docs over to JAX.

avital commented on April 27, 2024

It's been a while, and sorry for not posting more in this thread. We've gone through a major API redesign aligned with the goals originally described in this thread.

Our new Linen API came out of many user group discussions, trying to find a solution that empowers our users while staying relatively simple and exposing the full power of JAX.

All of our examples have been ported, and multiple large projects have transitioned using our upgrade guide, so now we're making it the official API.

Please check it out! And please post any questions or suggestions for improvements on our discussion board.

The old flax.nn API is being deprecated.
