Comments (2)
OK, so if you were just trying to evaluate this on a single device, you'd just remove this line from create_optimizer:

    optimizer = flax.jax_utils.replicate(optimizer)

The replicate fn broadcasts the params to every "device" in use: on a single GPU that's just 1 copy (which is kind of silly), but on multi-GPU or TPU it's 4, 8, etc. Notice that you didn't replicate the model above, but you did replicate the optimizer (again, both hold params inside them for convenience), which is why the second call broke.
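To make the broadcast concrete, here's a minimal plain-JAX sketch of what replicate does to a params pytree (the toy params dict is a hypothetical stand-in, not the generator's real params); flax.jax_utils.replicate is essentially this tree_map plus device placement, and unreplicate is the reverse:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the params held inside the optimizer (hypothetical values).
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}

# Replication adds a leading device axis: one identical copy of every
# leaf per local device (1 on a single GPU, 8 on a v3-8 TPU, etc.).
n = jax.local_device_count()
replicated = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (n,) + x.shape), params)

# Undoing it (roughly what flax.jax_utils.unreplicate does) is just
# taking device 0's copy of every leaf.
unreplicated = jax.tree_util.tree_map(lambda x: x[0], replicated)
```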
Now, assuming you do want to do replicated SPMD computation across multiple devices, you would keep that replicate call in create_optimizer, but you also need to define a model eval function to be pmapped so that it can use the replicated parameters stored across devices, for example:
    @jax.pmap
    def eval_w_pmap(model, x, prng_key):
        with flax.nn.stochastic(prng_key):
            # model is just a container for the replicated params
            return model(x)

    ldc = jax.local_device_count()
    pmap_test_input = jax.random.normal(jax.random.PRNGKey(1), (ldc, 1, 256, 256, 3))
    pmap_rngs = jax.random.split(jax.random.PRNGKey(0), ldc)
    eval_w_pmap(generator_optimizer.target, pmap_test_input, pmap_rngs).shape  # (ldc, 1, 256, 256, 3)
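Since flax.nn is the pre-Linen API, here is the same replicated-params-plus-pmapped-eval pattern sketched in plain JAX, runnable on any device count (the dense apply_fn and the shapes are assumptions for illustration, not the original generator):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for model.__call__: a single dense layer.
def apply_fn(params, x):
    return x @ params["w"] + params["b"]

n = jax.local_device_count()
params = {"w": jnp.eye(4), "b": jnp.zeros((4,))}

# Same replication step as flax.jax_utils.replicate: one copy per device,
# stacked along a new leading device axis.
rep_params = jax.tree_util.tree_map(
    lambda a: jnp.broadcast_to(a, (n,) + a.shape), params)

# pmap maps over the leading (device) axis of every argument, so both
# the replicated params and the input must have leading dim == n.
p_eval = jax.pmap(apply_fn)
x = jnp.ones((n, 2, 4))       # one shard per device
out = p_eval(rep_params, x)   # results stacked back along the device axis
```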
I hope that helps explain what's going on - please let me know if it's still not clear!
Thank you @levskaya, this is a very clear explanation!