Comments (11)
How far along is this implementation? Would contributions be appreciated? I may have some time to work on a TensorFlow version.
from probability.
@junpenglao You can always run it with batch size = 1... There would probably be some overhead that a direct unbatched implementation wouldn't impose; if that overhead becomes a problem, we (or you!) could look into ways to optimize it.
Yes, we are actively working on it. NUTS (unlike vanilla HMC) has data-dependent recursive control flow, which makes it a bit tricky to implement in graph-mode TF (i.e., without relying on TF Eager). There is therefore no definite timeline, but it is a priority.
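To illustrate what "data-dependent recursive control flow" means here, below is a toy sketch in plain Python. All names and numbers are hypothetical stand-ins ("state" is an int, `leapfrog` just increments it, and a "u-turn" fires at a fixed threshold); the point is only the control-flow shape, not NUTS itself. The recursion depth and the outer loop's trip count depend on the data, which is exactly what is awkward to express as a static TF graph:

```python
def leapfrog(state):
    # Stand-in for one leapfrog integration step.
    return state + 1

def u_turn(state):
    # Stand-in for the u-turn termination check.
    return state >= 6

def build_tree(state, depth):
    """Recursive doubling: a depth-d call takes up to 2**d leapfrog steps."""
    if depth == 0:
        state = leapfrog(state)
        return state, u_turn(state)
    state, stop = build_tree(state, depth - 1)
    if stop:
        return state, True
    return build_tree(state, depth - 1)

# The number of doublings is not known when the "graph" is built; it
# depends on when the u-turn check fires.
state, depth, stop = 0, 0, False
while not stop:
    state, stop = build_tree(state, depth)
    depth += 1
```

With the toy threshold above, the loop runs three doublings (1 + 2 + 3 leapfrog steps) before stopping.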
Awesome! Thanks for letting me know, and definitely thanks for working on this!
We are all happy to know you're interested in helping! You should know that we're going for an implementation that can run in graph mode, and can batch together multiple independent chains running in tandem (or at least their leapfrog steps). Since different chains can take different numbers of leapfrog steps during one update, the solution is proving surprisingly complicated: it basically amounts to writing a multi-stage compiler targeting TF graphs that implements the logic needed for batching and recursion (which should, however, be reusable).
Given the scope, off the top of my head I would guess the implementation roadmap will take at least another full-time engineer-month. If you are interested in diving in, we can look for a place that would make sense, but it may take a while to get you up to speed on the design and what has happened so far.
Thanks for the reply! That sounds quite involved, and I am not sure whether I have the time to spare for that right now, but I'll get in touch if I can.
@axch Sounds exciting. How do you mean different chains? Running separate NUTS chains in parallel? What does batching do and how does it help?
Yes, independent NUTS chains in parallel. For an example of how it can help, consider Bayesian linear regression, on, say, the cover type data set. That has 54 features (plus the intercept) and a half-million data points. The dominant cost of running NUTS on that will be the matrix-vector multiplies of a 500,000 x 55 matrix with a 55-vector that happen in the leapfrog steps.
But suppose we are able to run, say, 100 NUTS chains in parallel, and synchronize them on the leapfrog steps. Then, to run the leapfrogs in batch, we just need one (gradient of a) matrix-matrix multiply of a 500,000 x 55 matrix with a 55 x 100 matrix, and on modern accelerators that's much faster than 100 separate matrix-vector multiplies. (It's even faster on CPU, but not as dramatically.)
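A minimal NumPy sketch of that batching win (sizes here are shrunk from the thread's 500,000 x 55 example so it runs instantly; the shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes: the thread's example is a 500,000 x 55 design
# matrix with 100 chains; smaller here for speed.
n_data, n_features, n_chains = 2_000, 55, 100
X = rng.normal(size=(n_data, n_features))        # shared design matrix
theta = rng.normal(size=(n_features, n_chains))  # one parameter column per chain

# Unbatched: one matrix-vector product per chain.
per_chain = np.stack([X @ theta[:, k] for k in range(n_chains)], axis=1)

# Batched: a single matrix-matrix product covers every chain at once,
# which is what lets accelerators win on the leapfrog steps.
batched = X @ theta

assert np.allclose(per_chain, batched)
```

The two computations are numerically identical; the batched form just exposes all the work to one large, accelerator-friendly matmul.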
That's a really interesting idea. If I understand correctly, the gradient computation is performed once per leapfrog step across all the parallel chains, so you basically create a new Hamiltonian system that contains n copies of the original system? But how would that help synchronize the tree building? The recursive tree building would still be different in each chain.
@junpenglao I wouldn't call it one Hamiltonian system, but rather 100 Hamiltonian systems whose energy function we can evaluate in parallel with batching (to get 100 energies, or rather, 100 energy gradients).
You identified the difficulty exactly: the recurrent tree building is different in each chain, and that is why the design involves what I called a "compiler" earlier, though you can just think of it as an elaborate Python program for building the TensorFlow graph. What happens at the bottom is that, in order to synchronize the leapfrogs, we have to let the control logic at least partially de-synchronize: if chain A ends up needing 5 u-turn checks to get to its next leapfrog while chain B needs only 2, we have to do all 5 anyway, and waste some compute power while chain B waits. But, at least on the cover type example I gave earlier, that should be fine: those u-turn checks are dot products of 55-vectors, so it's ok to waste some in order to get the best batching for the operation involving the 500,000-row matrix.
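A toy sketch of that de-synchronized padding, in plain Python with hypothetical per-chain counts (a real implementation would express this as masked batched ops inside the graph):

```python
# Hypothetical per-chain counts: chain A needs 5 u-turn checks before
# its next leapfrog, chain B needs only 2.
checks_needed = [5, 2]
n_chains = len(checks_needed)

useful = [0] * n_chains  # checks that do real work
wasted = [0] * n_chains  # padded checks run only to stay in lockstep

# Every chain executes max(checks_needed) check steps; a mask marks
# which steps are live. The wasted checks are cheap (dot products of
# 55-vectors), the payoff is keeping the expensive batched matmul of
# the leapfrog steps synchronized.
for step in range(max(checks_needed)):
    for k in range(n_chains):
        if step < checks_needed[k]:
            useful[k] += 1
        else:
            wasted[k] += 1

assert useful == [5, 2]
assert wasted == [0, 3]
```

Here chain B idles through 3 of the 5 check steps, which is the compute traded away in exchange for batching the large matrix operation.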
OK, I see, thanks for the explanation! Would running the different chains out of sync also be a potential option (FYI, that's how PyMC3 does it)?