
flowmc's Introduction

flowMC

Normalizing-flow enhanced sampling package for probabilistic inference



flowMC is a JAX-based Python package for normalizing-flow enhanced Markov chain Monte Carlo (MCMC) sampling. The code is open source under the MIT license, and it is under active development.

  • Just-in-time compilation is supported.
  • Native support for GPU acceleration.
  • Suited for problems with multi-modality.
  • Minimal tuning.

Installation

The simplest way to install the package is through pip:

pip install flowMC

This will install the latest stable release and its dependencies. flowMC is based on JAX and Equinox; by default, installing flowMC will automatically install the versions of JAX and Equinox available on PyPI. Note that JAX does not install GPU support by default. If you want to use a GPU with JAX, you need to install JAX with GPU support according to the JAX documentation. At the time of writing, this is the command to install JAX with GPU support:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
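To verify that JAX can actually see the GPU afterwards, one can run a quick sanity check (not flowMC-specific):

import jax
print(jax.devices())  # should list GPU devices if GPU support is installed correctly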

If you want to install the latest development version of flowMC, you can clone this repo and install it locally:

git clone https://github.com/kazewong/flowMC.git
cd flowMC
pip install -e .

Requirements

Here is a list of the packages used in the main library:

* Python 3.9+
* JAX
* jaxlib
* equinox

To visualize the inference results in the examples, we require the following packages in addition to the above:

* matplotlib
* corner
* arviz

The test suite is based on pytest. To run the tests, install pytest and run pytest at the root directory of this repo.
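For example:

pip install pytest
pytest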

Attribution

If you used flowMC in your research, we would really appreciate it if you could at least cite the following papers:

@article{Wong:2022xvh,
    author = "Wong, Kaze W. k. and Gabri\'e, Marylou and Foreman-Mackey, Daniel",
    title = "{flowMC: Normalizing flow enhanced sampling package for probabilistic inference in JAX}",
    eprint = "2211.06397",
    archivePrefix = "arXiv",
    primaryClass = "astro-ph.IM",
    doi = "10.21105/joss.05021",
    journal = "J. Open Source Softw.",
    volume = "8",
    number = "83",
    pages = "5021",
    year = "2023"
}

@article{Gabrie:2021tlu,
    author = "Gabri\'e, Marylou and Rotskoff, Grant M. and Vanden-Eijnden, Eric",
    title = "{Adaptive Monte Carlo augmented with normalizing flows}",
    eprint = "2105.12603",
    archivePrefix = "arXiv",
    primaryClass = "physics.data-an",
    doi = "10.1073/pnas.2109420119",
    journal = "Proc. Nat. Acad. Sci.",
    volume = "119",
    number = "10",
    pages = "e2109420119",
    year = "2022"
}

This will help flowMC gain more recognition, and the main benefit for you is that the flowMC community will grow and the package will be continuously improved. If you believe in the magic of open-source software, please support us by attributing our software in your work.

flowMC is a Jax implementation of methods described in:

Efficient Bayesian Sampling Using Normalizing Flows to Assist Markov Chain Monte Carlo Methods. Gabrié M., Rotskoff G. M., Vanden-Eijnden E. - ICML INNF+ workshop 2021 - pdf

Adaptive Monte Carlo augmented with normalizing flows. Gabrié M., Rotskoff G. M., Vanden-Eijnden E. - PNAS 2022 - doi, arxiv

flowmc's People

Contributors

kazewong, marylou-gabrie, qazalbash, tedwards2412


flowmc's Issues

Write testing functions for sampler module

There are no actual test scripts for the package yet. I found some instability since commit 453456d914d40176f924564000ded665cdfa96a7 on main.

We need testing mostly on the sampler behavior.

Inconsistent and missing docstrings on public methods.

Raising issue as part of JOSS review openjournals/joss-reviews#5021

All public methods should have docstrings describing their functionality. Some are missing. E.g.,

class HMC(LocalSamplerBase):
    def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable:
        super().__init__(logpdf, jit, params)
        self.potential = lambda x: -self.logpdf(x)
        self.grad_potential = jax.grad(self.potential)

        self.params = params
        if "condition_matrix" in params:
            self.inverse_metric = params["condition_matrix"]
        else:
            self.inverse_metric = 1

        if "step_size" in params:
            self.step_size = params["step_size"]

        if "n_leapfrog" in params:
            self.n_leapfrog = params["n_leapfrog"]
        else:
            raise NotImplementedError

        self.kinetic = lambda p, params: 0.5 * (p**2 * params['inverse_metric']).sum()
        self.grad_kinetic = jax.grad(self.kinetic)

    def get_initial_hamiltonian(self, rng_key: jax.random.PRNGKey, position: jnp.array, params: dict):
        momentum = jax.random.normal(rng_key, shape=position.shape) * params['inverse_metric'] ** -0.5
        return self.potential(position) + self.kinetic(momentum, params)

Nice to have/suggestion: give short examples in the docstrings of how to use the object.
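For instance, a short usage example in the HMC docstring might look something like this (a sketch only, following the constructor signature quoted above):

class HMC(LocalSamplerBase):
    """
    Hamiltonian Monte Carlo local sampler.

    Example:
        >>> hmc = HMC(logpdf, jit=True,
        ...           params={"step_size": 0.1, "n_leapfrog": 10})
    """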

It would also be good to address inconsistent docstring formatting. E.g.,

class Sampler():
    """
    Sampler class that host configuration parameters, NF model, and local sampler

    Args:
        n_dim (int): Dimension of the problem.
        rng_key_set (Tuple): Tuple of random number generator keys.
        local_sampler (Callable): Local sampler maker
        sampler_params (dict): Parameters for the local sampler.
        likelihood (Callable): Likelihood function.
        nf_model (Callable): Normalizing flow model.
        n_loop_training (int, optional): Number of training loops. Defaults to 2.
        n_loop_production (int, optional): Number of production loops. Defaults to 2.
        n_local_steps (int, optional): Number of local steps per loop. Defaults to 5.
        n_global_steps (int, optional): Number of global steps per loop. Defaults to 5.
        n_chains (int, optional): Number of chains. Defaults to 5.
        n_epochs (int, optional): Number of epochs per training loop. Defaults to 5.
        learning_rate (float, optional): Learning rate for the NF model. Defaults to 0.01.
        max_samples (int, optional): Maximum number of samples fed to training the NF model. Defaults to 10000.
        momentum (float, optional): Momentum for the NF model. Defaults to 0.9.
        batch_size (int, optional): Batch size for the NF model. Defaults to 10.
        use_global (bool, optional): Whether to use global sampler. Defaults to True.
        logging (bool, optional): Whether to log the training process. Defaults to True.
        nf_variable (None, optional): Mean and variance variables for the NF model. Defaults to None.
        keep_quantile (float, optional): Quantile of chains to keep when training the normalizing flow model. Defaults to 0..
        local_autotune (None, optional): Auto-tune function for the local sampler. Defaults to None.
        train_thinning (int, optional): Thinning for the data used to train the normalizing flow. Defaults to 1.

and

class MLP(nn.Module):
    """
    Multi-layer perceptron in Flax.

    Parameters
    ----------
    features : list of int
        The number of features in each layer.
    activation : callable
        The activation function at each level
    use_bias : bool
        Whether to use bias in the layers.
    init_weight_scale : float
        The initial weight scale for the layers.
    kernel_init : callable
        The kernel initializer for the layers.
        We use a gaussian kernel with a standard deviation of `init_weight_scale` by default.
    """

Contribution guide

Raising issue as part of JOSS review openjournals/joss-reviews#5021

The contribution guide is brief and assumes everyone is confident with how to use GitHub and make PRs.

Suggestions:

I also recommend adding details/a link to the contribution guide in your documentation.

Slow compilation of RQSpline flow operations

Currently, the training and sampling phases take around 3.5 minutes on an A100 node with an Icelake CPU for an RQSpline model with 10 layers and [128,128] conditioners.

See whether swapping make_flow with a scan can solve this compilation problem. This could very well be an issue associated with how distrax forms the bijector; in that case, there won't be much I can do.
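For reference, a minimal sketch of the scan idea in generic JAX (not the actual make_flow code): lax.scan compiles a single layer body and iterates it, instead of unrolling one graph per layer.

import jax
import jax.numpy as jnp

# Apply 10 identically-shaped affine "layers" with one compiled body.
def apply_layer(x, layer_params):
    scale, shift = layer_params
    return x * scale + shift, None  # (carry, per-step output)

x = jnp.ones(4)
layer_params = (jnp.full((10, 4), 1.01), jnp.zeros((10, 4)))  # stacked parameters
y, _ = jax.lax.scan(apply_layer, x, layer_params)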

Possible issue with the Sampler function in the tutorials and docs

Thanks for the code, this seems super cool; trying to have a go with it now. I've run into this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [1], line 22
     19 MALA_Sampler = MALA(log_posterior, True, {"step_size": step_size})
     20 local_sampler_caller = lambda x: MALA_Sampler.make_sampler()
---> 22 nf_sampler = Sampler(n_dim,
     23                     rng_key_set,
     24                     local_sampler_caller,
     25                     {"step_size": step_size},
     26                     log_posterior,
     27                     model,
     28                     n_local_steps = 50,
     29                     n_global_steps = 50,
     30                     n_epochs = 30,
     31                     learning_rate = 1e-2,
     32                     batch_size = 1000,
     33                     n_chains = n_chains)
     35 nf_sampler.sample(initial_position)
     36 chains,log_prob,local_accs, global_accs = nf_sampler.get_sampler_state().values()

File ~/Library/Mobile Documents/com~apple~CloudDocs/PhD/Turing.jl/flowMC/src/flowMC/sampler/Sampler.py:71, in Sampler.__init__(self, n_dim, rng_key_set, local_sampler, likelihood, nf_model, n_loop_training, n_loop_production, n_local_steps, n_global_steps, n_chains, n_epochs, learning_rate, max_samples, momentum, batch_size, use_global, logging, nf_variable, keep_quantile, local_autotune, train_thinning)
     68 rng_key_init, rng_keys_mcmc, rng_keys_nf, init_rng_keys_nf = rng_key_set
     70 self.likelihood = likelihood
---> 71 self.likelihood_vec = jax.jit(jax.vmap(self.likelihood))
     72 self.local_sampler_class = local_sampler
     73 self.local_sampler = local_sampler.make_sampler()

File ~/miniforge3/envs/flowmc/lib/python3.9/site-packages/jax/_src/api.py:1647, in vmap(fun, in_axes, out_axes, axis_name, axis_size, spmd_axis_name)
   1512 def vmap(fun: F,
   1513          in_axes: Union[int, Sequence[Any]] = 0,
   1514          out_axes: Any = 0,
   1515          axis_name: Optional[Hashable] = None,
   1516          axis_size: Optional[int] = None,
   1517          spmd_axis_name: Optional[Hashable] = None) -> F:
   1518   """Vectorizing map. Creates a function which maps ``fun`` over argument axes.
   1519
   1520   Args:
   (...)
   1645   See the :py:func:`jax.pmap` docstring for more examples involving collectives.
   1646   """
-> 1647   _check_callable(fun)
   1648   docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
   1649             "but with additional array axes over which {fun} is mapped.")
   1650   if fun.__doc__:

File ~/miniforge3/envs/flowmc/lib/python3.9/site-packages/jax/_src/api.py:181, in _check_callable(fun)
    179   raise TypeError(f"staticmethod arguments are not supported, got {fun}")
    180 if not callable(fun):
--> 181   raise TypeError(f"Expected a callable value, got {fun}")
    182 if _isgeneratorfunction(fun):
    183   raise TypeError(f"Expected a function, got a generator function: {fun}")

TypeError: Expected a callable value, got {'step_size': 0.1}

Looking at the source code, it seems that Sampler's docstring and the tutorials etc. don't match the actual arguments, meaning they error out; the sampler_params dict is missing from the function's arguments.

Thanks!

Acceptance ratio possibly calculated incorrectly for Hamiltonian Monte Carlo implementation

Raising issue as part of JOSS review openjournals/joss-reviews#5021

The logarithm of the acceptance ratio computed in the lines

momentum = jax.random.normal(key1, shape=position.shape) * params['inverse_metric']**-0.5
proposed_position, proposed_momentum = leapfrog_step(position, momentum, params)
proposed_ham = self.potential(proposed_position) + self.kinetic(proposed_momentum, params)
log_acc = H - proposed_ham

appears to compute the acceptance ratio as the difference between the Hamiltonian H before the momentum resampling step in line 70 and the Hamiltonian for the proposed state pair at the final point in the simulated trajectory. The momentum resampling step is itself a Markov kernel which leaves the joint distribution on the positions and momenta invariant: a Gibbs-sampling-like step which samples independently from the conditional distribution under the target of the momentum given the position (here equal to the marginal distribution, as the momentum and position are independent). There is then a second Metropolis-Hastings based Markov kernel which simulates the Hamiltonian dynamics forward in time from the current position-momentum state pair to a new state pair, with the logarithm of the acceptance ratio being the difference between the Hamiltonians at the start and end points of the simulated trajectory (provided a volume-preserving, time-reversible integrator is used). I think therefore that the H value used to compute log_acc needs to be recalculated between the current lines 70 and 71, that is, adding a line

    H = self.potential(position) + self.kinetic(momentum, params) 

To avoid recomputing the potential energy function (which remains the same before and after the momentum resampling), you could instead pass through the value of the potential energy function, rather than H, as the input argument and return value of hmc_kernel.
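In context, the suggested fix looks something like this (a sketch against the lines quoted above, with the surrounding names as in the issue):

momentum = jax.random.normal(key1, shape=position.shape) * params['inverse_metric']**-0.5
# Recompute the Hamiltonian *after* resampling the momentum, so that log_acc
# compares only the start and end points of the simulated trajectory.
H = self.potential(position) + self.kinetic(momentum, params)
proposed_position, proposed_momentum = leapfrog_step(position, momentum, params)
proposed_ham = self.potential(proposed_position) + self.kinetic(proposed_momentum, params)
log_acc = H - proposed_ham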

No tests for correctness of components

Raising issue as part of JOSS review openjournals/joss-reviews#5021

The tests in the test modules in the test directory do not appear to have any assert statements checking that the components being tested behave as expected; instead they appear to be integration tests that will only fail if there is an error in the functions called. While such tests are useful, I would say it is also essential to have unit tests for the correctness of the key individual components in the package, checking that invariants are maintained, that output from functions is of the expected types/shapes/values, and so on.

A non-exhaustive list of some things that could be checked (an example test follows the list):

  • The Markov kernels which make local proposals governed by a step-size parameter and with a Metropolis-Hastings accept step always accept / give an acceptance probability of 1, when the step size is zero corresponding to proposing to move to the current point. It could also be checked that the acceptance probability on average decreases monotonically as the step size is made larger.
  • The leapfrog integrator used in the Hamiltonian Monte Carlo implementation is time reversible.
  • The various Markov kernels produce Markov chains with empirical moments matching known analytic values to within some Monte Carlo error which scales inversely with the square root of the number of samples for a tractable known target distribution such as a standard normal (with the chains initialised from samples from the target distribution to avoid issues with initial convergence to stationarity)
  • The Markov kernels behave deterministically when passed random number generators with fixed initial states / seeds.
  • The normalizing flow models are able to learn a close approximation to a simple target distribution such as a multivariate normal distribution with random mean and covariance matrix.
  • The Jacobian log determinant iteratively computed in evaluating the normalizing flow model is consistent with the log determinant of the Jacobian of the forward map computed explicitly using the JAX jacobian differential operator.
  • The normalizing flow map (overall) and affine coupling components are bijections.
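As an illustration of the second item, a self-contained pytest-style check that a generic leapfrog integrator is time reversible (the integrator here is a textbook implementation, not flowMC's own):

import jax
import jax.numpy as jnp

def leapfrog(q, p, grad_potential, step_size, n_steps):
    # Standard leapfrog integrator for a separable Hamiltonian.
    p = p - 0.5 * step_size * grad_potential(q)
    for _ in range(n_steps - 1):
        q = q + step_size * p
        p = p - step_size * grad_potential(q)
    q = q + step_size * p
    p = p - 0.5 * step_size * grad_potential(q)
    return q, p

def test_leapfrog_time_reversible():
    # Integrate forward, flip the momentum, integrate again:
    # this should return to the starting state (up to float error).
    grad_potential = jax.grad(lambda q: 0.5 * jnp.sum(q ** 2))
    q0, p0 = jnp.array([1.0, -0.5]), jnp.array([0.3, 0.7])
    q1, p1 = leapfrog(q0, p0, grad_potential, 0.1, 10)
    q2, p2 = leapfrog(q1, -p1, grad_potential, 0.1, 10)
    assert jnp.allclose(q2, q0, atol=1e-5)
    assert jnp.allclose(-p2, p0, atol=1e-5)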

No separate requirements specified for tests

Raising issue as part of JOSS review openjournals/joss-reviews#5021

The tests seem to have additional dependencies beyond those of the flowMC package - for example test/test_normalizingFlow.py imports from sklearn, optax and matplotlib, none of which are specified in the requirements in the README.md or setup.cfg files.

It would be useful to document the additional requirements for running the tests - for example by adding an [options.extras_require] section to the setup.cfg file with the additional test dependencies and/or listing the dependencies for running the tests in the README.md file.
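Concretely, the setup.cfg addition might look something like this (the package list is illustrative):

[options.extras_require]
test =
    pytest
    scikit-learn
    optax
    matplotlib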

Whitening

Add an option in RealNVP to have a mean and covariance for the base distribution (see the sketch after this list):

  1. Add keyword argument to the class, with default to standard Gaussian
  2. Modify sample method
  3. Modify logprob method
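A minimal sketch of what the change amounts to, assuming a diagonal covariance for simplicity (function and argument names hypothetical):

import jax
import jax.numpy as jnp

# Base distribution N(mean, diag(cov_diag)) instead of a standard Gaussian.
def base_log_prob(z, mean, cov_diag):
    return -0.5 * jnp.sum((z - mean) ** 2 / cov_diag + jnp.log(2 * jnp.pi * cov_diag))

def base_sample(rng_key, n_samples, mean, cov_diag):
    eps = jax.random.normal(rng_key, (n_samples, mean.shape[0]))
    return mean + eps * jnp.sqrt(cov_diag)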

Put short examples in generated functions docs

Currently, none of the generated functions in the local samplers has a usage example. For instance, the generated function below does not have documentation showing how to use the kernel.

def mala_sampler(rng_key, n_steps, initial_position):
    logp = lp(initial_position)
    n_chains = rng_key.shape[0]
    acceptance = jnp.zeros((n_chains, n_steps))
    all_positions = (jnp.zeros((n_chains, n_steps) + initial_position.shape[-1:])) + initial_position[:, None]
    all_logp = (jnp.zeros((n_chains, n_steps)) + logp[:, None])
    state = (rng_key, all_positions, all_logp, acceptance, self.params)
    for i in tqdm(range(1, n_steps)):
        state = mala_update(i, state)
    return state[:-1]

It would be nice to add documentation to these generated functions.
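For example, a docstring for the generated mala_sampler might read (a sketch, inferred from the code above):

def mala_sampler(rng_key, n_steps, initial_position):
    """
    Run n_steps of MALA for a batch of chains.

    Args:
        rng_key: (n_chains,) array of PRNG keys, one per chain.
        n_steps: number of MALA steps to take.
        initial_position: (n_chains, n_dim) array of starting positions.

    Returns:
        Tuple (rng_key, all_positions, all_logp, acceptance).
    """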

Add autotuning and production loops to the Sampler

Currently, the user has to tune the local sampler. Since the tuning basically amounts to looking at the acceptance rate and changing the step size, there should be a simple tuning loop we can write for the Sampler before the training and production runs.

Also, we currently rely on the user not to pick samples across multiple training loops, in order to enforce detailed balance. Add a production run loop after tuning in the Sampler class to generate detailed-balanced chains by default.

The pseudo-code of the sampler should look something like:

  1. Local tuning
  2. Global tuning (training and sampling)
  3. Production
  4. Output

Dependencies pinning

Jax has a tendency to roll out unstable, breaking changes pretty frequently. We should pin the version of jax.

Currently 0.4.1 is stable

Normalizing flow with Flax seems slower on A100 than Quadro6000

Tested this behaviour on a workstation with a Quadro6000 vs a cluster node with A100.

The current program seems to be pretty CPU-bound, which probably causes this behaviour, since the A100 node's CPU has a lower clock speed.

Need to figure out how to pack more computation onto the GPU.

Progress bar to be moved from inner loops to outer loops

Currently, tqdm runs on the local sampling steps and training steps, and no information is provided on the completion of the outer loops. It would be more useful to the user to have info on the progress at the level of the outer loops.

Unify the API for the different model of flows

  • RealNVP and RQSpline don't take their arguments in the same order (in particular, the first two are flipped).
  • RQSpline takes a list of hidden widths while RealNVP takes only one, which also restricts the MLPs to a single hidden layer.
  • It would be useful to provide default values for model parameters.

The tutorials would need to be updated after such changes.

Tests fail on Python 3.8 and 3.9 due to upstream issue in `distrax`

Raising issue as part of JOSS review openjournals/joss-reviews#5021

Running the tests in clean Python 3.8 or 3.9 environments I get a series of

TypeError: Subscripted generics cannot be used with class and instance checks

errors which appear to be due to google-deepmind/distrax/issues/224. It looks like the most recent 0.1.3 release of distrax on PyPI fixes these issues, as manually updating to that version and rerunning the tests in the Python 3.8 and Python 3.9 environments fixes the errors; bumping the pinned version of distrax to 0.1.3 in setup.cfg should therefore fix this.

Make command line tools to recommend diagnostics

I think it would be beneficial to have a command line tool that performs the list of checks we usually do and flags warnings.

For example, from helping people with their problems, there are a number of checks one usually does:

  1. Are there NaNs in the log_prob?
  2. Is the local/global sampler accepting too many or too few proposals?
  3. Is there a huge discrepancy in the log_prob?

These are essentially the diagnostics suggested on the FAQ page, but I think it would be nice to have a button the user can hit to get recommendations on what to do.
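A sketch of what one such check could look like (thresholds and names hypothetical):

import jax.numpy as jnp

def check_log_prob(log_prob):
    # Flag common pathologies in sampled log-probabilities.
    if jnp.any(jnp.isnan(log_prob)):
        print("WARNING: NaN detected in log_prob.")
    spread = jnp.max(log_prob) - jnp.min(log_prob)
    if spread > 1e3:
        print(f"WARNING: large spread in log_prob ({spread:.1f}); check for stuck chains.")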

No documentation of how to run tests

Raising issue as part of JOSS review openjournals/joss-reviews#5021

There are what appear to be a set of tests in the test directory at the root of the repository but there are no details given that I can see of how these are expected to be run - for example is it expected that pytest is used to run the tests?

Slow NF learning

Try running the examples and check the NF samples. It doesn't train fast enough; need to experiment a bit with the default parameters.

Add parallelization over multiple devices

Currently the code runs on one device, which doesn't allow scaling to larger computational networks such as TPU pods.

Parallelizing over the local sampler should be relatively simple, since it does not require communication between devices. Note that if a single evaluation of the likelihood demands more RAM than what's available on the chip (TPUv4 has 8 GB of RAM per core; gradients of functions may cause problems), the computation may need to be sharded across multiple devices, but that should be taken care of separately.

Evaluation of the global sampler should be similar to the local sampler.

Training the normalizing flow requires collecting data from multiple devices and updating weights in a somewhat synchronized fashion. Have a look at pmap to see how to deal with that: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html
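A minimal pmap sketch for the embarrassingly parallel local-sampling part (local_step is a hypothetical single-chain update; on a single-device machine this runs with n_devices = 1):

import jax
import jax.numpy as jnp

def local_step(rng_key, position):
    # Placeholder update: one random-walk step per device.
    return position + 0.1 * jax.random.normal(rng_key, position.shape)

n_devices = jax.device_count()
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)
positions = jnp.zeros((n_devices, 3))
new_positions = jax.pmap(local_step)(keys, positions)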

Automated test workflows

Raising issue as part of JOSS review openjournals/joss-reviews#5021

Issue: you should set up a GitHub action to automatically run your tests for PRs and commits to master, and report code coverage.

Nice to have: an automated workflow to build your documentation, and an automated workflow to push package releases to PyPI.

Implementing HMC as a local sampler

I found the exploration power of MALA to be limited, especially in cases where different parameters may need different step sizes.

I am not sure if HMC is the answer, but having a more capable sampler is always cool, if not useful.

Benchmark for normalizing flow example included compilation time

Currently, the examples of training the normalizing flow on readthedocs include compilation time, which is not specified and may confuse some readers. It also seems really slow; this could be a problem related to Mac compatibility issues.

Fix: warm up the flow before actually training it.
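The idea in generic JAX terms (not the actual flow code): call the jitted function once so compilation happens before the timed region.

import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sin(x) ** 2)
x = jnp.ones(1000)
f(x).block_until_ready()   # warm-up: triggers compilation
t0 = time.perf_counter()
f(x).block_until_ready()   # timed: execution only
print(f"run time: {time.perf_counter() - t0:.6f} s")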

Compiling local sampler is slow

Compiling the sampler for a complicated likelihood function seems pretty slow (testing on a gravitational wave example now).

I think this is related to defining the main loop inside the sampler.

Experimenting with abstracting that out for performance

Packaging

  1. Write a wrapper for the local sampler.
  2. Bundle the local sampler and global sampler together.
  3. Provide functionality for users to supply their own likelihood or local sampler.
  4. Test Jaxopt.
  5. Write notebooks explaining the idea.

Suggested changes to text of JOSS paper

Raising issue as part of JOSS review openjournals/joss-reviews#5021

Some suggested changes to the text of the JOSS paper:

However the estimation of models'parameters
However the estimation of models' parameters
(missing space between models' and parameters)

A common strategy to explore parameter space is to sample through a Markov Chain Monte Carlo (MCMC).
A common strategy to explore a model's parameter space is to approximately sample the posterior distribution on the parameters with a Markov Chain Monte Carlo (MCMC) method.
(Markov chain Monte Carlo is a descriptor of a class of algorithms and wouldn't usually be used as an object grammatically; MCMC methods sample from a distribution so helpful to specify what distribution this is)

Yet even MCMC methods can struggle to faithfully represent the parameter space when only relying on local updates.
Yet even MCMC methods can struggle to faithfully represent the posterior distribution on the parameter space when only relying on local updates.
(the Markov chain samples generated by an MCMC method represent a distribution on the parameter space, typically a posterior distribution in a Bayesian inference context, not the parameter space itself)

FlowMC is a Python library for accelerated Markov Chain Monte Carlo (MCMC)
FlowMC is a Python library implementing accelerated MCMC [methods/algorithms]
(no need to repeat acronym definition and again better to use MCMC as a qualifier)

At its core, FlowMC uses a local sampler and a learnable global sampler in tandem to efficiently sample posterior distributions
At its core, FlowMC combines Metropolis-Hastings based Markov kernels proposing local moves and global moves (using a learned approximation to the target distribution) to efficiently sample posterior distributions with non-trivial geometry
(to me local and global sampler are too vague and in particular global sampler is a bit misleading as the moves proposed using the normalizing flow model are still accepted or rejected within a Metropolis acceptance step meaning the moves are still 'local' when the proposals are rejected)

While multiple chains of the local sampler generate samples over the region of interest in the target parameter space, the package uses these samples to train a normalizing flow model, then use it to propose global jumps across the parameter space.
While multiple chains using Markov kernels with local moves generate approximate samples from the target posterior distribution on the parameter space, the package uses these samples to train a normalizing flow model approximating the target distribution, and uses this approximation to propose global jumps across the parameter space within a Metropolis independence sampler.

MALA → Metropolis adjusted Langevin algorithm (MALA)
GPUs → graphics processing units (GPUs)
TPUs → tensor processing units (TPUs)
SIMD → single instruction multiple data (SIMD)
(all abbreviations should be defined on first usage)

As soon as the dimension of exceeds 3 or 4, it is necessary to resort to a robust sampling strategy such as a MCMC

I would not be so specific here, as I would say MCMC methods are not necessarily the best choice for target distributions on spaces of dimension ~5, with quadrature methods, rejection sampling and importance sampling methods all being potentially viable in these sorts of dimensionalities. I would say something vaguer like 'In high dimensions' or 'For parameter spaces with more than a few dimensions'.

Jax → JAX (capitalisation should be consistent throughout)

GPU and TPU supports → GPU and TPU support

The entire algorithm belongs to the class of adaptive MCMCs → The entire algorithm belongs to the class of adaptive MCMC methods

the chains previous steps → the chains' previous steps

Use of Accelerator → Use of Accelerators

leverage Just-In-Time → leverage JIT (abbreviation has been previously defined so should be used for subsequent usages)

the log-posterior function → a function to evaluate the logarithm of the (unnormalized) density of the posterior distribution of interest

`progress_bar_scan` credit

Raising issue as part of JOSS review openjournals/joss-reviews#5021

The progress_bar_scan, i.e.,

def progress_bar_scan(num_samples, message=None):
    "Progress bar for a JAX scan"
    if message is None:
        message = f"Running for {num_samples:,} iterations"
    tqdm_bars = tqdm(range(num_samples))
    tqdm_bars.set_description(message)

is seemingly adapted from this (excellent) blog post https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/

There is no problem in using this, but you should acknowledge credit to it in the decorator's docstring.

Installation fails in Python 3.11 environment due to TensorFlow dependency

Raising issue as part of JOSS review openjournals/joss-reviews/issues/5021

The README.md file lists the Python version requirement as Python 3.8+ and the setup.cfg file has python_requires = >=3.7; however, on trying to install flowMC using pip into a Python 3.11 environment, installation fails with an ERROR: Failed building wheel for tensorstore error. This appears to be due to wheels only being available for TensorFlow (and related dependencies) for Python versions 3.7 to 3.10 currently. Given that Python 3.11 is now in stable release, it would be helpful to put more specific Python version requirements in the README.md and setup.cfg files. It would also make sense to have the minimum Python version in the README.md and setup.cfg files match (currently 3.8 and 3.7 respectively); as the latest versions of NumPy and JAX now support Python 3.8+ in line with NEP 29, a minimum of Python 3.8 might make sense.

Parsing arguments other than x into log-posterior probability functions

Hi guys, thanks for the really cool code, which I'm currently hoping to integrate into my own workflow!

Is it possible to pass arguments other than x into log-posterior probability functions? This is particularly important when attempting to sample a distribution that results from comparing observed data to a generated model that depends on the sampled parameters.

As an example, the docs specify that a potential target distribution might be:

def log_posterior(x):
    return -0.5 * jnp.sum(x ** 2)

but it seems impossible to write a posterior function that compares to some observed data. As a very simple example:

import numpy as np
import jax.numpy as jnp

n_dim = 2

data_x = np.arange(1, 10)
m = 2
c = 5
observation = m * data_x + c

def log_posterior(x, observation):
    x_model = np.arange(1, 10)
    m = x[0]
    c = x[1]
    y = m * x_model + c
    gaussian = -((y - observation) ** 2) / (2 * np.sqrt(observation) ** 2)
    return jnp.sum(gaussian)

Is this a possible posterior function that just isn't recorded in the docs? I appreciate that it would be hypothetically possible to "generate" or load the observed data within the posterior function, but this becomes much less practically feasible when generating the data is computationally expensive and it makes sense to do that operation only once.
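One common JAX-side workaround (not flowMC-specific, so treat it as a sketch) is to fix the data argument before handing the function to the sampler, e.g. with functools.partial or a closure:

import functools
import jax.numpy as jnp

data_x = jnp.arange(1, 10)
observation = 2 * data_x + 5

def log_posterior(x, observation):
    y = x[0] * data_x + x[1]
    return -jnp.sum((y - observation) ** 2 / (2 * observation))

# The sampler only ever sees a single-argument callable.
log_posterior_fixed = functools.partial(log_posterior, observation=observation)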
