Comments (9)
I played with the notebook a bit, and here is an updated version https://colab.research.google.com/gist/kazewong/033c89e548ef59b3ceb649dcf2ffe9e5/bayeux_and_flowmc.ipynb
Here are some of the changes:
- I changed the runtime to a T4 GPU, which helps speed up the training.
- Fiddled with the network hyperparameters a bit; this shouldn't really matter.
- Changed the local sampler back to MALA. I think this should be sufficient for this example since the dimensionality is not too high, and in general MALA is faster than HMC.
- Increased `n_chains` to 24. Just playing around with the number; in general, the more the better, and the runtime usually doesn't change on a GPU unless one starts to saturate the compute or memory bandwidth of the GPU.
- Added a diagnostic cell with `print(nf_sampler.get_sampler_state(training=True)['loss_vals'].min(), global_accs.mean())`. The local sampler acceptance shows whether the local sampler is reasonable; anywhere between 0.2 and 0.8 is more or less okay. The global sampler acceptance shows how well the flow has been trained to approximate the target: the higher, the better. In the original notebook it was about 0.02; now it is 0.44.
- Also changed some hyperparameters for flowMC. They are annotated with inline comments. The most important ones are probably `n_loop_training` and `n_loop_production`; they basically control how many times the sampler alternates between the local sampler and the global sampler. The larger they are, the longer the runtime, since you are asking for more samples, but that helps with training and produces more samples. I should make this clearer in the tuning guide (writing one soon!)
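For reference, the knobs discussed above can be gathered in one place. The names below are the flowMC keyword arguments mentioned in this thread; the values are illustrative starting points (only `n_chains=24` and the 50-steps-per-loop figures come from this discussion), not tuned recommendations:

```python
# Hyperparameters discussed above, collected as a plain dict for clarity.
# Names mirror the flowMC keyword arguments referenced in this thread;
# values marked "illustrative" are not from the notebook.
flowmc_settings = {
    "n_chains": 24,           # more is generally better; nearly free on a GPU
    "n_loop_training": 10,    # illustrative: training local/global alternations
    "n_loop_production": 10,  # production alternations (10 vs 30 compared below)
    "n_local_steps": 50,      # MALA steps per loop
    "n_global_steps": 50,     # flow-proposal steps per loop
}

# Quick sanity checks one might run before launching a job.
assert flowmc_settings["n_chains"] > 0
assert flowmc_settings["n_loop_production"] >= 1
```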
With all these changes, I think the flowMC result is more reasonable. Now there is one more problem that is actually interesting; I have had it in the back of my mind but never really finished it.
With everything else being the same as shown in the notebook, this is combining all the chains when I put `n_loop_production=10`.
This is when I use `n_loop_production=30`.
You can see the stripes are distributed in the larger `n_loop_production` case. The performance of the flow is the same, since the training phase is the same for the two; only the length of the production phase is changed.
The reason for this is probably that the samples produced by the local sampler are correlated, while the ones from the global sampler are uncorrelated (or much less correlated). Since `n_local_steps` and `n_global_steps` are the same, the global sampler will fly the chain around for 50 steps per loop, then the local sampler will jitter the chain locally for 50 steps per loop. This probably causes the extra cluster of points around each stripe, and as we increase `n_loop_production`, the global sampler brings the chains to more places, creating more densely sampled stripes.
To solve this, I think there is actual work to do. Basically, we want effective samples from the local sampler instead of every sample, which should provide a much smoother posterior and remove the stripe artifacts.
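The "effective samples" idea above can be sketched with a generic diagnostic (this is not a flowMC API): estimate an integrated autocorrelation time for a chain and thin by it, so the correlated local-sampler output contributes roughly independent draws. The AR(1) chain below is a stand-in for correlated local-sampler output:

```python
import numpy as np

def integrated_autocorr_time(x, max_lag=100):
    """Crude estimate of the integrated autocorrelation time of a 1-D chain."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    var = x.var()
    tau = 1.0
    for lag in range(1, min(max_lag, len(x) - 1)):
        rho = np.mean(x[:-lag] * x[lag:]) / var
        if rho < 0.05:  # stop once the correlation is negligible
            break
        tau += 2.0 * rho
    return tau

rng = np.random.default_rng(0)
# An AR(1) process as a stand-in for correlated MALA output.
chain = np.empty(5000)
chain[0] = 0.0
for t in range(1, len(chain)):
    chain[t] = 0.9 * chain[t - 1] + rng.normal()

tau = integrated_autocorr_time(chain)
# Keep roughly one sample per autocorrelation time.
thinned = chain[:: max(1, int(np.ceil(tau)))]
print(len(chain), len(thinned))
```

Thinning is the crudest option; a fancier version would weight samples by effective sample size per loop, but even this removes most of the within-stripe clustering.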
Last remark: I agree this is a problem HMC can probably solve quite well. flowMC adds the extra layer of a normalizing flow to deal with bad geometries, such as multimodality or really stretched-out, locally correlated geometries (like a donut). This problem is rather unimodal and smooth, so HMC shouldn't have a hard time dealing with it.
Please let me know if you run into more problems regarding this; I am also happy to help make this an example so other users can follow the logic behind this discussion.
from flowmc.
Thanks again -- working on the PR now.
It should do even slightly better than above, since bayeux has machinery to transform the support of models to all of R^n -- right now flowMC has no idea `school_effects` must be positive (other than the `nan` log density), while `numpyro` has the advantage of having that parameter transformed by `softmax`.
jax-ml/bayeux#23 is out now -- if you have time for comments/suggestions, please do!
There's a fair amount of abstraction going on, and it may be easier to play around with it after it merges, then open one or more issues! I'll follow up with a colab using `bayeux`.
A few notes from doing this -- lmk if you'd like these to be separate issues:
- Requiring a `params` or `kwargs` argument makes it difficult for static checkers to make sure I am using the right arguments. I sort of assume that if `flowMC` gets updated, `bayeux` will break -- it seems like you could expand the signature of the nf_models and local_samplers?
- Related: a few of these keyword arguments are similar, but slightly different (`n_layer` vs `n_layers`, and `hidden_size` vs `n_hidden`).
- The random_key_set seems like it is probably an anti-pattern -- in particular, I would like to pass in a JAX PRNG key and have everything "just work" (or even pass such a key to the helper function instead of an int). It seems like maybe `Sampler` has enough information to handle the key splitting itself?
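The "just work" behaviour described in the last bullet is essentially what `jax.random.split` provides. A sketch of a sampler that accepts a single user-supplied PRNG key and derives all its internal keys itself (the class name and attributes here are illustrative, not the flowMC API):

```python
import jax

class ToySampler:
    """Illustrative sampler that splits one PRNG key internally."""

    def __init__(self, rng_key, n_chains):
        # One split per consumer: flow initialization, chain initialization,
        # and a batch of per-chain keys.
        flow_key, init_key, chain_key = jax.random.split(rng_key, 3)
        self.flow_key = flow_key
        self.init_key = init_key
        self.chain_keys = jax.random.split(chain_key, n_chains)

# The user supplies exactly one key; no random_key_set helper needed.
sampler = ToySampler(jax.random.PRNGKey(0), n_chains=24)
print(sampler.chain_keys.shape)
```

Since the key is the only piece of external state, this also makes runs reproducible from a single integer seed without any extra machinery.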
https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing is updated with `flowMC`, including a multimodal distribution towards the end where it seems to do better than numpyro (for some value of "better"!)
I'll probably add an example notebook based on that soon. First I have to fix the bug that doesn't allow setting keyword arguments!
flowMC assumes the likelihood function has the following signature:

```python
def logp(pts: jnp.array, data: PyTree) -> float:
    ...
```

where `pts` should be the parameters you want to sample, and `data` is a pytree containing auxiliary data you don't have to sample over.
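A minimal likelihood matching that signature (an isotropic Gaussian, purely illustrative):

```python
import jax.numpy as jnp

def logp(pts, data):
    # `pts` is one flat parameter vector for a single chain; `data` holds
    # auxiliary information (here a fixed mean). Returns a single scalar.
    mu = data["mu"]
    return -0.5 * jnp.sum((pts - mu) ** 2)

val = logp(jnp.zeros(3), {"mu": jnp.ones(3)})
print(val)  # scalar log density: -1.5
```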
So in order to get this to work, I think you need to modify your likelihood, defined around

```python
def log_prob(pts):
    print("HEY!", pts)
    return bx_model.log_density(pts)
```

in two ways:
- The input should be a flattened array of the struct tuple. Currently, init is a struct tuple with shape ((8,), (8,), (8,8)) in your example. Say the parameters you want to sample are a combination of `avg_effect`, `avg_stddev`, and `school_effects`, which in total are 80 parameters; then your input would be an (80,) jax array, and within the likelihood you need to restructure the parameters into the struct tuple in order to pass it to your model. I am not familiar with struct tuples from tfp, so I don't know how to do that, but that should be a good start.
- Currently `log_prob(init)` returns 8 numbers, which makes me think you actually have fewer parameters than 80. Ideally `bx_model.log_density(pts)` should return a single scalar. flowMC handles the vmapping under the hood, so the likelihood should be a function of only one chain instead of all the chains.
To be more concrete, the init points should have shape (`n_chains`, `n_dim`), where `n_chains` is the number of chains and `n_dim` is the dimension of the parameters you want to sample over.
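One way to do the flatten/restructure step described above is `jax.flatten_util.ravel_pytree`, which returns both the flat vector and a function that restores the original structure. The parameter names and shapes below follow the eight-schools model in this thread (one `avg_effect`, one `avg_stddev`, eight `school_effects` per chain); the density itself is a placeholder where a real model would call `bx_model.log_density`:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Structured parameters for a SINGLE chain.
params = {
    "avg_effect": jnp.zeros(()),
    "avg_stddev": jnp.zeros(()),
    "school_effects": jnp.zeros(8),
}

# `flat` is the (n_dim,) vector flowMC expects; `unravel` undoes it.
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (10,) -- n_dim for one chain

def log_prob(pts, data=None):
    p = unravel(pts)  # back to the structured form inside the likelihood
    # Placeholder density; a real model would evaluate bx_model.log_density(p).
    return -0.5 * jnp.sum(p["school_effects"] ** 2) - 0.5 * p["avg_effect"] ** 2

# Initial positions for all chains: shape (n_chains, n_dim).
n_chains = 24
init = jnp.tile(flat, (n_chains, 1))
print(init.shape)  # (24, 10)
```

`ravel_pytree` also avoids hand-written reshapes, which may help with the flatten/unflatten overhead mentioned later in this thread.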
Let me know whether this helps resolve the problem. If this works in the end, would you mind if I link this example on our doc page so others can take a look at this as well?
P.S. Would you mind pointing me to the LearnBayesStat episode?
Thanks for the pointers!
I was able to get my example running -- I had to do some funny things to get my state flat (and unflat), which I would guess hurts performance, particularly on accelerators where `reshape` is not free.
Performance is also terrible-ish, which I assume is user error -- here is the example with 12 chains:
I updated the colab with the code that actually runs. I'll keep looking at this tomorrow, but continue to appreciate suggestions for improving performance if you have any: `bayeux` tries to set generally sensible defaults (which are also provided to the user). I would guess that I need to adjust these defaults to get better sampling! https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing
For what it is worth, I included samples from `numpyro`'s NUTS sampler at the bottom -- a well-tuned sampler does 7 HMC steps with step size 0.6.
Update: I ran with HMC (which secretly requires a `condition_matrix` argument -- `np.eye(n_dim)` is a sensible default) using the `numpyro` tuning arguments and added more epochs:
Does the fact that I am ignoring the `data` argument matter here? `bayeux` assumes that the user will just close over data, like `log_density = functools.partial(log_prob, data=data)`.
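The close-over-data pattern mentioned above, in isolation (the toy density here is illustrative):

```python
import functools

def log_prob(pts, data):
    # Toy two-argument log density: parameters plus auxiliary data.
    return -0.5 * sum((p - d) ** 2 for p, d in zip(pts, data))

data = [1.0, 2.0]
# Bind `data` once; the sampler only ever sees a one-argument function.
log_density = functools.partial(log_prob, data=data)
print(log_density([1.0, 2.0]))  # 0.0 -- data is already bound
```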
Also, the podcast is here: https://open.spotify.com/episode/1wRsmH8xXTpO8JOWajgWpL?si=df762a337c7d4361 (the website https://learnbayesstats.com/ has not yet been updated with @marylou-gabrie's episode). She did an excellent job describing the algorithm, and I'm hoping `bayeux` will allow users of other PPLs to give flowMC a try. Also, I personally hate benchmarks for MCMC and appreciated Dr. Gabrié's nuance there! I'll certainly let you know if/when this merges over there!
Ok, got a little nerd-sniped by this, but it looks like setting more local steps and more loops gets reasonable performance:
```python
n_local_steps = 200,
n_global_steps = 50,
n_loop_production = 4,
```
![image](https://private-user-images.githubusercontent.com/2295568/300872963-f29be100-952b-47ca-8c93-d3f0a7777ba3.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTc4MDY4NTYsIm5iZiI6MTcxNzgwNjU1NiwicGF0aCI6Ii8yMjk1NTY4LzMwMDg3Mjk2My1mMjliZTEwMC05NTJiLTQ3Y2EtOGM5My1kM2YwYTc3NzdiYTMucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDYwOCUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA2MDhUMDAyOTE2WiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9MWU5M2YwZDAxZjIwYTIyMGVjMGQ2ZjhkOTQ5MWVhN2Q1YzhjYzgzOWU1ODE2NThiNDMyMTA4NmI5MDc0OWE3NCZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.80ayfBG1AEsoygFDV-9SOUCOVcOyh3kw9I6UHoJoRT4)
which compares reasonably well with `numpyro`.
I'm going off the podcast here, but I suppose this is a problem that is well suited to HMC, rather than being some wild statistical mechanics problem with symmetries to deal with? I'll try to cook up one of those tomorrow along with a PR to make this a little more ergonomic.
> * requiring a `params` or `kwargs` argument makes it difficult for static checkers to make sure I am using the right arguments. I sort of assume that if `flowMC` gets updated, `bayeux` will break -- it seems like you could expand the signature of the nf_models and local_samplers?
> * related, a few of these keyword arguments are similar, but slightly different (`n_layer` vs `n_layers` and `hidden_size` vs `n_hidden`)
> * the random_key_set seems like it is probably an anti-pattern -- in particular, I would like to pass in a jax prngkey and have everything "just work" (or even pass such a key to the helper function instead of an int). it seems like maybe `Sampler` has enough information to handle the key splitting itself?
- Do you mean `params` in the different local_samplers? Currently that is used to maintain a somewhat more unified API across the different local_samplers. When I was making this code a year ago I wasn't paying much attention to typing, so it is rather unsatisfactory as it currently stands. It seems a bit tricky to handle different params with a static checker while providing a unified API to flowMC, since different local samplers might have different numbers of params, coming in different shapes and types. There might be a solution in some of the equinox examples. I will see what I can do about this; the minimum will probably be a run-time check during initialization.
- Noted, will clean it up in the next version.
- Yeah, I have wanted to get that guy out for a while; will clean it up.
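One possible shape for the checker-friendly, unified-yet-typed params discussed above: each local sampler gets its own small config class instead of an untyped `params` dict, with a run-time check at initialization. Everything here is an illustrative sketch, not the flowMC API (the `condition_matrix` default echoes the `np.eye(n_dim)` suggestion earlier in this thread):

```python
from dataclasses import dataclass

@dataclass
class MALAParams:
    # Hypothetical per-sampler config; field names are illustrative.
    step_size: float = 0.1

@dataclass
class HMCParams:
    step_size: float = 0.1
    n_leapfrog: int = 7
    # A real version would also carry condition_matrix, defaulting to
    # np.eye(n_dim) once n_dim is known.

def make_local_sampler(params):
    # Run-time check during initialization, as suggested above; a static
    # checker also sees the exact fields of each dataclass.
    if not isinstance(params, (MALAParams, HMCParams)):
        raise TypeError(f"unsupported params type: {type(params).__name__}")
    return params

print(make_local_sampler(MALAParams(step_size=0.05)))
```

This keeps one entry point while letting each sampler document its own parameters, which also sidesteps the `n_layer` vs `n_layers` class of naming drift.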