Comments (7)
In the proposed syntax, why can’t the user write z = self.encoder(x) in apply? What happens if they do? And why do the encoder and decoder modules have to be shared? What happens if they’re not?
from flax.
why can’t the user write z = self.encoder(x) in apply?
That will work too. In this case it also doesn't matter, because the function is a one-liner. If it were more complicated, one might want to call the module method from within apply for code reuse.
And why do the encoder and decoder modules have to be shared? What happens if they’re not?
In the example, the encoder and decoder are each used at most once, so partial would also work here.
The idea is to promote using shared in the constructor, because otherwise using a module multiple times would bite you (each call would create its own set of parameters instead of reusing one).
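To see why reuse without shared bites, here is a dependency-free sketch that imitates (rather than calls) the pre-Linen behaviour: a module used without a fixed name gets a fresh auto-generated name, and therefore a fresh parameter, on every call, whereas binding one name up front makes every call reuse the same parameter. ParamScope, dense and the Dense_<n> naming scheme are hypothetical stand-ins, not Flax APIs.

```python
import functools
import itertools


class ParamScope:
    """Hypothetical parameter store: one entry per module name."""

    def __init__(self):
        self.params = {}
        self._counter = itertools.count()

    def fresh_name(self, prefix):
        # Every unnamed use gets a new name, like an auto-naming policy.
        return f"{prefix}_{next(self._counter)}"

    def get_or_create(self, name, init_fn):
        # Parameters are created the first time a name is seen, then reused.
        if name not in self.params:
            self.params[name] = init_fn()
        return self.params[name]


def dense(scope, x, name=None):
    """Hypothetical toy 'module': y = w * x with a single scalar weight."""
    name = name or scope.fresh_name("Dense")
    w = scope.get_or_create(name, lambda: 0.5)
    return w * x


# Unshared: applying the same module twice silently creates two weights.
scope = ParamScope()
dense(scope, dense(scope, 1.0))
print(sorted(scope.params))  # ['Dense_0', 'Dense_1']

# "Shared" (in spirit): bind one name up front so both calls reuse one weight.
shared_scope = ParamScope()
encoder = functools.partial(dense, shared_scope, name="encoder")
encoder(encoder(1.0))
print(sorted(shared_scope.params))  # ['encoder']
```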
This would be useful for normalizing flows as well. In the meantime, there seems to be a namespace issue with the _shared_module approach. After modifying this line to use cls.__name__, my implementation works again.
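```python
import flax.nn  # pre-Linen flax.nn API

# `Transform` is the commenter's own base class (assumed to be a
# flax.nn.Module subclass with `transform` and `inverse_and_log_det_jac`).


def compose_transforms(transforms):
    class TransformSequence(Transform):
        def _shared_modules(self):
            # .shared() so each sub-transform reuses its parameters
            # across both module methods.
            return [t.shared() for t in transforms]

        @flax.nn.module_method
        def transform(self, x):
            transforms = self._shared_modules()
            for t in transforms:
                x = t.transform(x)
            return x

        @flax.nn.module_method
        def inverse_and_log_det_jac(self, y):
            transforms = self._shared_modules()
            log_det_jac = 0.0
            # Invert in reverse order, accumulating each step's log|det J|.
            for t in reversed(transforms):
                y, term = t.inverse_and_log_det_jac(y)
                log_det_jac += term
            return y, log_det_jac

    return TransformSequence
```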
Is there a better workaround for now?
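Setting the naming issue aside, the composition logic itself can be checked without Flax. The sketch below keeps the same method names but swaps the Flax modules for a hypothetical scalar AffineTransform (y = scale * x + shift): transform applies the steps in order, and inverse_and_log_det_jac undoes them in reverse while summing each step's log|det J|, which is how log-determinants of a composition combine.

```python
import math


class AffineTransform:
    """Hypothetical scalar bijection y = scale * x + shift (scale != 0)."""

    def __init__(self, scale, shift):
        self.scale = scale
        self.shift = shift

    def transform(self, x):
        return self.scale * x + self.shift

    def inverse_and_log_det_jac(self, y):
        x = (y - self.shift) / self.scale
        # dx/dy = 1 / scale, so the inverse contributes -log|scale|.
        return x, -math.log(abs(self.scale))


class PlainTransformSequence:
    """Composes bijections: forward in order, inverse in reverse order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def transform(self, x):
        for t in self.transforms:
            x = t.transform(x)
        return x

    def inverse_and_log_det_jac(self, y):
        log_det_jac = 0.0
        for t in reversed(self.transforms):
            y, term = t.inverse_and_log_det_jac(y)
            log_det_jac += term  # log-dets of a composition add up
        return y, log_det_jac


flow = PlainTransformSequence([AffineTransform(2.0, 1.0), AffineTransform(0.5, -3.0)])
y = flow.transform(4.0)  # 0.5 * (2 * 4 + 1) - 3 = 1.5
x, ldj = flow.inverse_and_log_det_jac(y)
# x recovers 4.0; ldj is ~0 because the scales 2 and 0.5 cancel.
print(x, ldj)
```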
My main worry about supporting __init__ is that it could lead to incorrect assumptions about how Flax works, due to its apparent similarity to normal class syntax. If I saw a call like the proposed AutoEncoder without knowing anything about Flax, I would expect the only valid arguments to AutoEncoder(*args, **kwargs) to be those that appear explicitly in __init__, but that isn't how Flax works.
Some ideas:
- Use a different name from __init__, e.g., setup (we already use init for variables).
- Consider (conditionally?) switching to an explicit setup/call split like Keras or Haiku: AutoEncoder(**params)(x). This would probably be more pervasive than you want.
The magic separation of arguments between __init__ and apply also worries me a little bit. I don't know if there is a good way to do this, but I do think passing all **kwargs to __init__ (or setup) is a better alternative than using introspection.
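The introspection-based separation can be sketched in plain Python to show what it hides from the caller: inspect.signature decides which keyword arguments go to setup, and everything else is forwarded to apply, so the call site alone does not reveal the split. AutoEncoderSpec and route_by_introspection are hypothetical names for illustration, not the proposal's actual implementation.

```python
import inspect


class AutoEncoderSpec:
    """Hypothetical module: hyperparameters in setup, inputs in apply."""

    def setup(self, latents, features):
        self.latents = latents
        self.features = features

    def apply(self, x):
        return (x, self.latents, self.features)


def route_by_introspection(module, **kwargs):
    """Send each kwarg to setup if setup declares it, otherwise to apply."""
    # Signature of a bound method, so `self` is already excluded.
    setup_names = set(inspect.signature(module.setup).parameters)
    setup_kwargs = {k: v for k, v in kwargs.items() if k in setup_names}
    apply_kwargs = {k: v for k, v in kwargs.items() if k not in setup_names}
    module.setup(**setup_kwargs)
    return module.apply(**apply_kwargs)


# Nothing at the call site says that `x` is an input and the rest are settings.
result = route_by_introspection(AutoEncoderSpec(), x="batch", latents=8, features=784)
print(result)  # ('batch', 8, 784)
```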
@mattwescott I think your issue was introduced by a recent change to the default name policy. This is probably a bug that should be fixed; I'll look into it ASAP.
- Use a different name from __init__, e.g., setup (we already use init for variables).
I agree with this proposal.
- Consider (conditionally?) switching to an explicit setup/call split like Keras or Haiku: AutoEncoder(**params)(x). This would probably be more pervasive than you want.
That's a very big change to make and takes away a key advantage of calling modules as functions.
The magic separation of arguments between __init__ and apply also worries me a little bit. I don't know if there is a good way to do this, but I do think passing all **kwargs to __init__ (or setup) is a better alternative than using introspection.
I really dislike the introspection as well. Passing all kwargs to setup is probably fine; the main downside is that they are also passed to all the module methods, which would then probably need something like **unused_kwargs.
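The pass-everything alternative and its downside can be sketched the same way. This assumes a hypothetical PassAllModule base whose __call__ forwards all call-time kwargs to both setup and apply; it is an illustration of the trade-off, not Flax's API. A method that omits **unused_kwargs rejects the hyperparameter it never asked for.

```python
class PassAllModule:
    """Hypothetical base class: all call-time kwargs go to setup and apply."""

    def __call__(self, *inputs, **kwargs):
        self.setup(**kwargs)
        return self.apply(*inputs, **kwargs)


class Repeat(PassAllModule):
    def setup(self, times, **unused_kwargs):
        self.times = times

    def apply(self, x, **unused_kwargs):
        # `times` arrives here too; **unused_kwargs swallows it.
        return [x] * self.times


class StrictRepeat(PassAllModule):
    def setup(self, times):
        self.times = times

    def apply(self, x):  # no **unused_kwargs
        return [x] * self.times


print(Repeat()(1.0, times=3))  # [1.0, 1.0, 1.0]

strict_error = None
try:
    StrictRepeat()(1.0, times=3)
except TypeError as err:  # apply rejects the unexpected `times` keyword
    strict_error = err
print(type(strict_error).__name__)  # TypeError
```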
This is no longer relevant since Linen has landed, so I'm closing this for now.