
Comments (9)

kazewong commented on August 16, 2024

I played with the notebook a bit, and here is an updated version https://colab.research.google.com/gist/kazewong/033c89e548ef59b3ceb649dcf2ffe9e5/bayeux_and_flowmc.ipynb

Here are some of the changes:

  1. I changed the runtime to a T4 GPU, which helps speed up the training.
  2. Fiddled with the network hyperparameters a bit; this shouldn't really matter.
  3. Changed the local sampler back to MALA. I think this should be sufficient for this example since the dimensionality is not too high, and in general MALA is faster than HMC.
  4. Increased `n_chains` to 24. I was just playing around with the number; in general, the more the better, and the runtime usually doesn't change on a GPU unless you start to saturate the GPU's compute or memory bandwidth.
  5. Added a diagnostic cell with `print(nf_sampler.get_sampler_state(training=True)['loss_vals'].min(), global_accs.mean())`. The local sampler acceptance shows whether the local sampler is reasonable; anywhere between 0.2 and 0.8 is more or less okay. The global sampler acceptance shows how well the flow has been trained to approximate the target; the higher the better. In the original notebook it was about 0.02, now it is 0.44.
  6. Also changed some hyperparameters for flowMC; they are annotated with inline comments. The most important ones are probably `n_loop_training` and `n_loop_production`: they control how many times the sampler alternates between the local sampler and the global sampler. The larger they are, the longer the runtime, since you are asking for more samples, but that helps with training and produces more samples. I should make this clearer in the tuning guide (writing one soon!).
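The diagnostic readout in item 5 can be mocked up as follows. The arrays here are made-up stand-ins for the `local_accs`/`global_accs` arrays in flowMC's sampler state; their exact shapes are an assumption, not flowMC's actual API.

```python
import numpy as np

# Mock per-step acceptance indicators, shaped (n_chains, n_steps),
# standing in for the acceptance arrays in the real sampler state.
local_accs = np.array([[1, 0, 1, 1], [0, 1, 1, 0]])
global_accs = np.array([[0, 0, 1, 0], [0, 1, 0, 0]])

local_rate = local_accs.mean()    # ~0.2-0.8 means the local sampler is healthy
global_rate = global_accs.mean()  # higher means the flow fits the target better

print(local_rate, global_rate)  # → 0.625 0.25
```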

With all these changes, I think the flowMC result is more reasonable. Now there is one more problem that is actually interesting; I have had it in the back of my mind but never really finished investigating it.

With everything else the same as shown in the notebook, this is the result of combining all the chains with `n_loop_production=10`:

[image: combined chains, n_loop_production=10]

This is with `n_loop_production=30`:

[image: combined chains, n_loop_production=30]

You can see the stripes are more spread out in the larger `n_loop_production` case. The performance of the flow is the same, since the training phase is identical in the two runs; only the length of the production phase changes.

The reason is probably that the samples produced by the local sampler are correlated, while the ones from the global sampler are uncorrelated (or far less correlated). Since `n_local_steps` and `n_global_steps` are the same, the global sampler will fly each chain around for 50 steps per loop, then the local sampler will jitter the chain locally for 50 steps per loop. This probably causes the extra cluster of points around each stripe, and as we increase `n_loop_production`, the global sampler brings the chains to more places, creating more densely sampled stripes.

To solve this, I think there is actual work to do. Basically, we want to keep only the effective samples from the local sampler instead of every sample, which should give a much smoother posterior and remove the stripe artifacts.
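The "effective samples" idea can be sketched as thinning the correlated local-sampler output by an estimated autocorrelation time. This is plain Python/numpy with a synthetic AR(1) chain standing in for local-sampler output, not flowMC's actual machinery:

```python
import numpy as np

def autocorr_time(chain, max_lag=50):
    # Crude integrated autocorrelation time from the normalized autocovariance.
    x = chain - chain.mean()
    var = np.dot(x, x) / len(x)
    tau = 1.0
    for lag in range(1, min(max_lag, len(x) - 1)):
        rho = np.dot(x[:-lag], x[lag:]) / (len(x) * var)
        if rho <= 0:  # stop once the correlation has died out
            break
        tau += 2.0 * rho
    return tau

rng = np.random.default_rng(0)
# An AR(1) chain mimicking correlated local-sampler samples.
chain = np.zeros(2000)
for t in range(1, len(chain)):
    chain[t] = 0.9 * chain[t - 1] + rng.normal()

tau = autocorr_time(chain)
thinned = chain[:: max(1, int(tau))]  # keep roughly one effective sample per tau
print(len(thinned) < len(chain))  # → True
```

Keeping one sample per autocorrelation time is what removes the dense local jitter around each stripe while preserving the global structure.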

Last remark: I agree this is a problem HMC can probably solve quite well. flowMC adds the extra layer of a normalizing flow to deal with bad geometries, such as multimodality or strongly stretched-out local correlations (like a donut). This problem is unimodal and smooth, so HMC shouldn't have a hard time with it.

Please let me know if you have more questions about this; I am also happy to turn this into an example so other users can follow the logic behind this discussion.

from flowmc.

ColCarroll commented on August 16, 2024

Thanks again -- working on the PR now.

It should do even slightly better than the above, since bayeux has machinery to transform the support of models to all of R^n -- right now flowMC has no idea that `school_effects` must be positive (other than the nan log density), while numpyro has the advantage of having that parameter transformed by softmax.
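The support transformation in question can be sketched with a log/exp change of variables. This is a minimal stand-in, not bayeux's actual machinery: `log_density` and the half-normal-like target are hypothetical.

```python
import math

def log_density(sigma):
    # Hypothetical target over a positive parameter: -inf outside the support,
    # a half-normal-like log density inside it.
    if sigma <= 0:
        return float("-inf")
    return -0.5 * sigma**2

def transformed_log_density(y):
    # Sample y in R instead of sigma > 0, with sigma = exp(y).
    # The change of variables adds the log-Jacobian log|d sigma / d y| = y.
    sigma = math.exp(y)
    return log_density(sigma) + y

# Any real y now has finite log density; the sampler never sees -inf.
print(transformed_log_density(0.0))  # → -0.5
```

With this reparameterization the sampler explores an unconstrained space, rather than learning the positivity constraint only through rejected nan/-inf evaluations.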


ColCarroll commented on August 16, 2024

jax-ml/bayeux#23 is out now -- if you have time for comments/suggestions, please do!

There's a fair amount of abstraction going on, and it may be easier to play around with it after it merges, then open one or more issues! I'll follow up with a colab using bayeux.

A few notes from doing this -- lmk if you'd like these to be separate issues:

  • Requiring a `params` or `kwargs` argument makes it difficult for static checkers to verify that I am using the right arguments. I sort of assume that if `flowMC` gets updated, `bayeux` will break -- it seems like you could expand the signatures of the nf_models and local_samplers?
  • Related: a few of these keyword arguments are similar but slightly different (`n_layer` vs `n_layers`, and `hidden_size` vs `n_hidden`).
  • The `random_key_set` seems like it is probably an anti-pattern -- in particular, I would like to pass in a JAX PRNGKey and have everything "just work" (or even pass such a key to the helper function instead of an int). It seems like `Sampler` has enough information to handle the key splitting itself?


ColCarroll commented on August 16, 2024

https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing is updated with flowMC, including a multimodal distribution towards the end where it seems to do better than numpyro (for some value of "better"!)

I'll probably add an example notebook based on that soon. First I have to fix the bug that doesn't allow setting keyword arguments!


kazewong commented on August 16, 2024

flowMC expects the likelihood function to have the following signature:

```python
def logp(pts: jnp.array, data: PyTree) -> float:
    ...
```

where pts should be the parameters you want to sample, and data is a pytree containing auxiliary data you don't have to sample over.
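A minimal stand-in with that shape, using numpy in place of `jax.numpy` and a hypothetical diagonal-Gaussian target (not any actual bayeux or flowMC code):

```python
import numpy as np

def logp(pts, data):
    # pts: flat (n_dim,) array of the parameters being sampled.
    # data: auxiliary quantities that are fixed, not sampled over.
    mu, sigma = data["mu"], data["sigma"]
    # Diagonal-Gaussian log density (up to a constant), returned as one scalar.
    return float(-0.5 * np.sum(((pts - mu) / sigma) ** 2))

data = {"mu": np.zeros(3), "sigma": np.ones(3)}
print(logp(np.array([1.0, 0.0, 0.0]), data))  # → -0.5
```

The key points are that `pts` is a single flat vector (one chain, not all chains) and the return value is a single scalar.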

So in order to get this to work, I think you need to modify your likelihood, defined around

```python
def log_prob(pts):
    print("HEY!", pts)
    return bx_model.log_density(pts)
```

in two ways:

  1. The input should be a flattened array of the struct tuple. Currently, `init` is a struct tuple with shapes `((8,), (8,), (8, 8))` in your example. If the parameters you want to sample are a combination of `avg_effect`, `avg_stddev`, and `school_effects`, which is 80 parameters in total, then your input would be a `(80,)` jax array, and within the likelihood you would need to restructure the flat array back into the struct tuple before passing it to your model. I am not familiar with struct tuples from tfp, so I don't know exactly how to do that, but that should be a good start.
  2. Currently `log_prob(init)` returns 8 numbers, which makes me think you actually have fewer than 80 parameters. Ideally `bx_model.log_density(pts)` should return a single scalar. flowMC handles the vmapping under the hood, so the likelihood should be a function of a single chain rather than all the chains.

To be more concrete, the init points should have a shape (n_chains, n_dim), where n_chains is the number of chains, and n_dim is the dimension of the parameters you want to sample over.
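The flatten/restructure step can be sketched as below. This is plain numpy with hypothetical parameter names mirroring the example above; for real JAX pytrees, `jax.flatten_util.ravel_pytree` does this job.

```python
import numpy as np

# Hypothetical structured parameters, mirroring the ((8,), (8,), (8, 8)) example.
shapes = {"avg_effect": (8,), "avg_stddev": (8,), "school_effects": (8, 8)}

def flatten(params):
    # Concatenate every leaf into one (n_dim,) vector, in a fixed key order.
    return np.concatenate([params[k].ravel() for k in shapes])

def unflatten(flat):
    # Invert flatten: slice the vector back apart and restore each shape.
    params, start = {}, 0
    for k, shp in shapes.items():
        size = int(np.prod(shp))
        params[k] = flat[start:start + size].reshape(shp)
        start += size
    return params

params = {k: np.zeros(s) for k, s in shapes.items()}
flat = flatten(params)
print(flat.shape)  # → (80,), i.e. 8 + 8 + 64 parameters
```

A likelihood wrapper would then call `unflatten(pts)` at the top and pass the restructured parameters to the model.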

Let me know whether this helps resolve the problem. If this works in the end, would you mind if I link this example on our doc page so others can take a look as well?

P.S. Would you mind pointing me to the LearnBayesStat episode?


ColCarroll commented on August 16, 2024

Thanks for the pointers!

I was able to get my example running -- I had to do some funny things to flatten (and unflatten) my state, which I would guess hurts performance, particularly on accelerators where reshape is not free.

Performance is also terrible-ish, which I assume is user error -- here is the example with 12 chains:

[image: flowMC samples with 12 chains]

I updated the colab with the code that actually runs. I'll keep looking at this tomorrow, but continue to appreciate suggestions for improving performance if you have any: bayeux tries to set generally sensible defaults (which are also provided to the user). I would guess that I need to adjust these defaults to get better sampling! https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing

For what it is worth, I included samples from numpyro's NUTS sampler at the bottom -- a well-tuned sampler does 7 HMC steps with step size 0.6.

[image: numpyro NUTS samples]


ColCarroll commented on August 16, 2024

Update: I ran with HMC (which secretly requires a condition_matrix argument -- np.eye(n_dim) is a sensible default) using the numpyro tuning arguments and added more epochs:
[image: HMC results with the numpyro tuning arguments and more epochs]

Does the fact that I am ignoring the `data` argument matter here? bayeux assumes that the user will just close over data, like `log_density = functools.partial(log_prob, data=data)`.
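That closing-over pattern is just standard `functools.partial`; a minimal sketch with a hypothetical `log_prob` and `data`:

```python
import functools

def log_prob(pts, data):
    # Hypothetical log density that needs auxiliary data.
    return -0.5 * sum((p - d) ** 2 for p, d in zip(pts, data))

data = [1.0, 2.0]
# Bind data once; the result has the single-argument signature samplers expect.
log_density = functools.partial(log_prob, data=data)
print(log_density([0.0, 2.0]))  # → -0.5
```

Because `data` is baked in up front, a sampler that only ever passes `pts` can safely ignore the second argument.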

Also, the podcast is here: https://open.spotify.com/episode/1wRsmH8xXTpO8JOWajgWpL?si=df762a337c7d4361 (the website https://learnbayesstats.com/ has not been updated with @marylou-gabrie's episode). She did an excellent job describing the algorithm, and I'm hoping bayeux will let users of other PPLs give flowMC a try. Also, I personally hate benchmarks for MCMC and appreciated Dr. Gabrié's nuance there! I'll certainly let you know if/when this merges over there!


ColCarroll commented on August 16, 2024

Ok, got a little nerd-sniped by this, but it looks like setting more local steps and more loops gets reasonable performance:

```python
n_local_steps=200,
n_global_steps=50,
n_loop_production=4,
```

[image: flowMC samples with the settings above]

which compares reasonably well with numpyro:

[image: numpyro samples]

I'm going off the podcast here, but I suppose this is a problem that is well suited to HMC, rather than some wild statistical-mechanics problem with symmetries to deal with? I'll try to cook up one of those tomorrow, along with a PR to make this a little more ergonomic.


kazewong commented on August 16, 2024

> * requiring a `params` or `kwargs` argument makes it difficult for static checkers to make sure I am using the right arguments. I sort of assume that if `flowMC` gets updated, `bayeux` will break -- it seems like you could expand the signature of the nf_models and local_samplers?
>
> * related, a few of these keyword arguments are similar, but slightly different (`n_layer` vs `n_layers` and `hidden_size` vs `n_hidden`)
>
> * the random_key_set seems like it is probably an anti-pattern -- in particular, I would like to pass in a jax prngkey and have everything "just work" (or even pass such a key to the helper function instead of an int). it seems like maybe `Sampler` has enough information to handle the key splitting itself?
  1. Do you mean the `params` in the different local_samplers? Currently that is used to maintain a somewhat unified API across the different local samplers. When I wrote this code a year ago I wasn't paying much attention to typing, so it is rather unsatisfactory as it currently stands. It seems tricky to handle different params with a static checker while providing a unified API to flowMC, since different local samplers may take different numbers of params with different shapes and types. There might be a solution in some of the equinox examples; I will see what I can do about this. The minimum will probably be a run-time check during initialization.
  2. Noted, will clean it up in the next version.
  3. Yeah, I have wanted to get that out for a while; will clean it up.

