Comments (4)
Hi Jon, I was able to reproduce this, but after upgrading to the latest flax and jax (which also required a new version of jaxlib) the problem went away. Can you give that a try?
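Because the fix here was upgrading flax, jax and jaxlib together, it can help to confirm which versions are actually installed before and after upgrading. The sketch below uses only the standard library's `importlib.metadata` (Python 3.8 or newer). The helper name `installed_versions` is hypothetical and is not part of flax or jax.

```python
from importlib import metadata


def installed_versions(packages=("flax", "jax", "jaxlib")):
    """Return {package: version string, or None if it is not installed}.

    Hypothetical helper for bug reports; it only reads installed package metadata.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # distribution not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting this output into an issue makes it easier to tell whether a mismatch between jax and jaxlib versions is involved.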
Hi @jondeaton, I believe your returns should be target vectors, so maybe try changing it to returns = jax.random.normal(key, (batch_size, 1)) (though the fact that this is the error that appears should be considered a bug!)
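The original snippet is not shown in this thread, so the following is only a sketch of why target shape matters. It assumes the model emits one value per example with shape (batch_size, 1) and that `returns` feeds a mean-squared-error loss. Under NumPy-style broadcasting, which `jax.numpy` follows, a flat (batch_size,) target is silently broadcast against the (batch_size, 1) predictions instead of raising an error. `batch_size`, `predictions` and `mse` are illustrative names.

```python
import jax
import jax.numpy as jnp

batch_size = 4  # hypothetical batch size for illustration

key = jax.random.PRNGKey(0)
pred_key, target_key = jax.random.split(key)

# Assumption: the model outputs one scalar per example, shape (batch_size, 1).
predictions = jax.random.normal(pred_key, (batch_size, 1))


def mse(predictions, targets):
    # Squared error averaged over every element of the (broadcast) difference.
    return jnp.mean((predictions - targets) ** 2)


# Flat targets: (4, 1) - (4,) broadcasts to (4, 4), so every prediction is
# compared with every target and mse() still returns a scalar, with no error.
flat_returns = jax.random.normal(target_key, (batch_size,))
print((predictions - flat_returns).shape)  # (4, 4)

# Column-vector targets, as suggested above: the shapes line up element-wise.
returns = jax.random.normal(target_key, (batch_size, 1))
print((predictions - returns).shape)  # (4, 1)
print(mse(predictions, returns))
```

Adding an explicit `assert predictions.shape == targets.shape` inside a loss function is a cheap way to catch this kind of silent broadcast early.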
Thanks for the suggestion! Unfortunately, I'm still hitting the same problem after changing returns as you recommended. However, after reinstalling flax at HEAD I now get a different error and stack trace (with exactly the same code):
Traceback (most recent call last):
  File "test.py", line 25, in <module>
    _, model = CNN.create_by_shape(key, [(input_shape, jnp.float32)])
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/nn/base.py", line 261, in wrapper
    return super_fn(*args, **kwargs)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/nn/base.py", line 381, in create_by_shape
    return jax_utils.partial_eval_by_shape(lazy_create, input_specs)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/jax_utils.py", line 86, in partial_eval_by_shape
    output_shapes = jax.eval_shape(lazy_fn, *input_structs)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 2042, in eval_shape
    out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 273, in abstract_eval_fun
    instantiate=True)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 354, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/jax_utils.py", line 79, in lazy_fn
    master = leaves[0].trace.master
AttributeError: 'function' object has no attribute 'master'
Any idea why this is happening, or how I could contribute a patch to fix it? This is blocking the RL project I'm trying to use flax for, so I'd be happy to send a PR, but I'm not sure where to start looking.
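The traceback passes through `jax.eval_shape`, which flax's `partial_eval_by_shape` calls with abstract inputs built from the (shape, dtype) specs given to `create_by_shape`. The sketch below shows that shape-only evaluation in plain JAX, independent of flax. `toy_forward` is a hypothetical stand-in for a model's forward pass, not the thread's CNN, and the input shape is made up for illustration.

```python
import jax
import jax.numpy as jnp


def toy_forward(x):
    # Hypothetical stand-in for a forward pass: a fixed linear map to 10 outputs.
    w = jnp.ones((x.shape[-1], 10), dtype=x.dtype)
    return x @ w


# Input specs in the same (shape, dtype) form passed to create_by_shape above.
input_specs = [((8, 32), jnp.float32)]

# Abstract inputs carry only shape and dtype; no array data is allocated.
input_structs = [jax.ShapeDtypeStruct(shape, dtype) for shape, dtype in input_specs]

# Trace the function abstractly and return the output shape/dtype without computing it.
out = jax.eval_shape(toy_forward, *input_structs)
print(out.shape, out.dtype)  # (8, 10) float32
```

If a call like this works in an environment but the flax helper still fails inside `lazy_fn`, that points toward an incompatibility between the installed flax and jax versions rather than the input specs themselves.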
I managed to run it without errors as well using the latest version of flax, so I am closing this issue.