
NumPyro

Probabilistic programming powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.

Docs and Examples | Forum


What is NumPyro?

NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX for automatic differentiation and JIT compilation to GPU / CPU. NumPyro is under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.

NumPyro is designed to be lightweight and focuses on providing a flexible substrate that users can build on:

  • Pyro Primitives: NumPyro programs can contain regular Python and NumPy code, in addition to Pyro primitives like sample and param. The model code should look very similar to Pyro except for some minor differences between the PyTorch and NumPy APIs. See the example below.
  • Inference algorithms: NumPyro supports a number of inference algorithms, with a particular focus on MCMC algorithms like Hamiltonian Monte Carlo, including an implementation of the No-U-Turn Sampler. Additional MCMC algorithms include MixedHMC (which can accommodate discrete latent variables) as well as HMCECS (which only computes the likelihood for subsets of the data in each iteration). One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the Verlet integrator that includes multiple gradient computations. With JAX, we can compose jit and grad to compile the entire integration step into an XLA-optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree-building stage in NUTS (this is possible using Iterative NUTS). There is also a basic Variational Inference implementation together with many flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI). The variational inference implementation supports a number of features, including support for models with discrete latent variables (see TraceGraph_ELBO and TraceEnum_ELBO).
  • Distributions: The numpyro.distributions module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's functional pseudo-random number generator. The design of the distributions module largely follows from PyTorch. A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in torch.distributions. In addition to distributions, constraints and transforms are very useful when operating on distribution classes with bounded support. Finally, distributions from TensorFlow Probability (TFP) can directly be used in NumPyro models.
  • Effect handlers: Like in Pyro, primitives like sample and param can be given nonstandard interpretations using effect handlers from the numpyro.handlers module, and these can be easily extended to implement custom inference algorithms and inference utilities.

A Simple Example - 8 Schools

Let us explore NumPyro using a simple example. We will use the eight schools example from Gelman et al., Bayesian Data Analysis: Sec. 5.5, 2003, which studies the effect of coaching on SAT performance in eight schools.

The data is given by:

>>> import numpy as np

>>> J = 8
>>> y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
>>> sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

Here, y are the treatment effects and sigma the standard errors. We build a hierarchical model for the study where we assume that the group-level parameters theta for each school are sampled from a Normal distribution with unknown mean mu and standard deviation tau, while the observed data are in turn generated from a Normal distribution with mean and standard deviation given by theta (true effect) and sigma, respectively. This allows us to estimate the population-level parameters mu and tau by pooling from all the observations, while still allowing for individual variation amongst the schools using the group-level theta parameters.

>>> import numpyro
>>> import numpyro.distributions as dist

>>> # Eight Schools example
... def eight_schools(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     with numpyro.plate('J', J):
...         theta = numpyro.sample('theta', dist.Normal(mu, tau))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

Let us infer the values of the unknown parameters in our model by running MCMC using the No-U-Turn Sampler (NUTS). Note the usage of the extra_fields argument in MCMC.run. By default, we only collect samples from the target (posterior) distribution when we run inference using MCMC. However, collecting additional fields like potential energy or the acceptance probability of a sample can be easily achieved by using the extra_fields argument. For a list of possible fields that can be collected, see the HMCState object. In this example, we will additionally collect the potential_energy for each sample.

>>> from jax import random
>>> from numpyro.infer import MCMC, NUTS

>>> nuts_kernel = NUTS(eight_schools)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

We can print the summary of the MCMC run, and examine if we observed any divergences during inference. Additionally, since we collected the potential energy for each of the samples, we can easily compute the expected log joint density.

>>> mcmc.print_summary()  # doctest: +SKIP

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.14      3.18      3.87     -0.76      9.50    115.42      1.01
       tau      4.12      3.58      3.12      0.51      8.56     90.64      1.02
  theta[0]      6.40      6.22      5.36     -2.54     15.27    176.75      1.00
  theta[1]      4.96      5.04      4.49     -1.98     14.22    217.12      1.00
  theta[2]      3.65      5.41      3.31     -3.47     13.77    247.64      1.00
  theta[3]      4.47      5.29      4.00     -3.22     12.92    213.36      1.01
  theta[4]      3.22      4.61      3.28     -3.72     10.93    242.14      1.01
  theta[5]      3.89      4.99      3.71     -3.39     12.54    206.27      1.00
  theta[6]      6.55      5.72      5.66     -1.43     15.78    124.57      1.00
  theta[7]      4.81      5.95      4.19     -3.90     13.40    299.66      1.00

Number of divergences: 19

>>> pe = mcmc.get_extra_fields()['potential_energy']
>>> print('Expected log joint density: {:.2f}'.format(np.mean(-pe)))  # doctest: +SKIP
Expected log joint density: -54.55

The r_hat values above 1 for the split Gelman-Rubin diagnostic indicate that the chain has not fully converged. The low effective sample size (n_eff), particularly for tau, and the number of divergent transitions also look problematic. Fortunately, this is a common pathology that can be rectified by using a non-centered parameterization for theta in our model. This is straightforward to do in NumPyro by using a TransformedDistribution instance together with a reparameterization effect handler. Let us rewrite the same model, but instead of sampling theta from a Normal(mu, tau), we will sample it from a base Normal(0, 1) distribution that is transformed using an AffineTransform. Note that by doing so, NumPyro runs HMC by generating samples theta_base for the base Normal(0, 1) distribution instead. We see that the resulting chain does not suffer from the same pathology: the Gelman-Rubin diagnostic is 1 for all the parameters and the effective sample size looks quite good!
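
In symbols, with \tilde{\theta}_j playing the role of theta_base and \tau used as a standard deviation (matching dist.Normal(mu, tau)), the two models imply the same distribution for each theta_j, since mu plus tau times a standard Normal draw is Normal with mean mu and standard deviation tau. What changes is the geometry HMC sees: in the non-centered form, the prior scale of the sampled variable no longer depends on tau.

```latex
\begin{aligned}
\text{centered:}\quad & \theta_j \sim \mathcal{N}(\mu, \tau), \\
\text{non-centered:}\quad & \tilde{\theta}_j \sim \mathcal{N}(0, 1), \qquad
\theta_j = \mu + \tau\,\tilde{\theta}_j, \qquad j = 1, \dots, J.
\end{aligned}
```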

>>> from numpyro.infer.reparam import TransformReparam

>>> # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     with numpyro.plate('J', J):
...         with numpyro.handlers.reparam(config={'theta': TransformReparam()}):
...             theta = numpyro.sample(
...                 'theta',
...                 dist.TransformedDistribution(dist.Normal(0., 1.),
...                                              dist.transforms.AffineTransform(mu, tau)))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

>>> nuts_kernel = NUTS(eight_schools_noncentered)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
>>> mcmc.print_summary(exclude_deterministic=False)  # doctest: +SKIP

                   mean       std    median      5.0%     95.0%     n_eff     r_hat
           mu      4.08      3.51      4.14     -1.69      9.71    720.43      1.00
          tau      3.96      3.31      3.09      0.01      8.34    488.63      1.00
     theta[0]      6.48      5.72      6.08     -2.53     14.96    801.59      1.00
     theta[1]      4.95      5.10      4.91     -3.70     12.82   1183.06      1.00
     theta[2]      3.65      5.58      3.72     -5.71     12.13    581.31      1.00
     theta[3]      4.56      5.04      4.32     -3.14     12.92   1282.60      1.00
     theta[4]      3.41      4.79      3.47     -4.16     10.79    801.25      1.00
     theta[5]      3.58      4.80      3.78     -3.95     11.55   1101.33      1.00
     theta[6]      6.31      5.17      5.75     -2.93     13.87   1081.11      1.00
     theta[7]      4.81      5.38      4.61     -3.29     14.05    954.14      1.00
theta_base[0]      0.41      0.95      0.40     -1.09      1.95    851.45      1.00
theta_base[1]      0.15      0.95      0.20     -1.42      1.66   1568.11      1.00
theta_base[2]     -0.08      0.98     -0.10     -1.68      1.54   1037.16      1.00
theta_base[3]      0.06      0.89      0.05     -1.42      1.47   1745.02      1.00
theta_base[4]     -0.14      0.94     -0.16     -1.65      1.45    719.85      1.00
theta_base[5]     -0.10      0.96     -0.14     -1.57      1.51   1128.45      1.00
theta_base[6]      0.38      0.95      0.42     -1.32      1.82   1026.50      1.00
theta_base[7]      0.10      0.97      0.10     -1.51      1.65   1190.98      1.00

Number of divergences: 0

>>> pe = mcmc.get_extra_fields()['potential_energy']
>>> # Compare with the earlier value
>>> print('Expected log joint density: {:.2f}'.format(np.mean(-pe)))  # doctest: +SKIP
Expected log joint density: -46.09

Note that for the class of distributions with loc/scale parameters, such as Normal, Cauchy and StudentT, we also provide a LocScaleReparam reparameterizer to achieve the same purpose. The corresponding code is:

from numpyro.infer.reparam import LocScaleReparam
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
    theta = numpyro.sample('theta', dist.Normal(mu, tau))

Now, let us assume that we have a new school for which we have not observed any test scores, but we would like to generate predictions. NumPyro provides a Predictive class for such a purpose. Note that in the absence of any observed data, we simply use the population-level parameters to generate predictions. The Predictive utility conditions the unobserved mu and tau sites to values drawn from the posterior distribution from our last MCMC run, and runs the model forward to generate predictions.

>>> from numpyro.infer import Predictive

>>> # New School
... def new_school():
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     return numpyro.sample('obs', dist.Normal(mu, tau))

>>> predictive = Predictive(new_school, mcmc.get_samples())
>>> samples_predictive = predictive(random.PRNGKey(1))
>>> print(np.mean(samples_predictive['obs']))  # doctest: +SKIP
3.9886456

More Examples

For some more examples on specifying models and doing inference in NumPyro:

Pyro users will note that the API for model specification and inference is largely the same as Pyro, including the distributions API, by design. However, there are some important core differences (reflected in the internals) that users should be aware of. For example, NumPyro has no global parameter store or random state, which makes it possible to leverage JAX's JIT compilation. Users may also need to write their models in a more functional style that works better with JAX. Refer to the FAQs for a list of differences.
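
Both differences can be seen in a few lines of plain JAX. The sketch below shows that a key must be split explicitly to get new random numbers, and that a Python side effect inside a jitted function runs only while JAX traces it, which is why state is better threaded through inputs and outputs. The function names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import random

# 1) No global random state: the same key always reproduces the same draw.
key = random.PRNGKey(0)
a = random.normal(key)
b = random.normal(key)
key, subkey = random.split(key)  # fresh randomness requires an explicit split
c = random.normal(subkey)

# 2) Side effects invisible to the tracer execute only at trace time.
trace_log = []


@jax.jit
def double(x):
    trace_log.append("traced")  # Python side effect: runs once, during tracing
    return 2.0 * x


double(jnp.ones(3))
double(jnp.ones(3))  # same shape and dtype: the cached compiled version is reused


# The functional alternative threads state through inputs and outputs.
@jax.jit
def double_and_count(x, count):
    return 2.0 * x, count + 1


out, count = double_and_count(jnp.ones(3), 0)
out, count = double_and_count(out, count)
print(len(trace_log), int(count))
```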

Overview of inference algorithms

We provide an overview of most of the inference algorithms supported by NumPyro and offer some guidelines about which inference algorithms may be appropriate for different classes of models.

MCMC

  • NUTS, which is an adaptive variant of HMC, is probably the most commonly used inference algorithm in NumPyro. Note that NUTS and HMC are not directly applicable to models with discrete latent variables, but in cases where the discrete variables have finite support and summing them out (i.e. enumeration) is tractable, NumPyro will automatically sum out discrete latent variables and perform NUTS/HMC on the remaining continuous latent variables. As discussed above, model reparameterization may be important in some cases to get good performance. Note that, generally speaking, we expect inference to be harder as the dimension of the latent space increases. See the bad geometry tutorial for additional tips and tricks.
  • MixedHMC can be an effective inference strategy for models that contain both continuous and discrete latent variables.
  • HMCECS can be an effective inference strategy for models with a large number of data points. It is applicable to models with continuous latent variables. See here for an example.
  • BarkerMH is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables.
  • HMCGibbs combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user.
  • DiscreteHMCGibbs combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically.
  • SA is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate, and it may be a good choice for models with non-differentiable log densities. Note that SA generally requires a very large number of samples, as mixing tends to be slow. On the plus side, individual steps can be fast.

Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see restrictions). Enumerated sites need to be marked with infer={'enumerate': 'parallel'} like in the annotation example.

Nested Sampling

Stochastic variational inference

  • Variational objectives
    • Trace_ELBO is our basic ELBO implementation.
    • TraceMeanField_ELBO is like Trace_ELBO but computes part of the ELBO analytically if doing so is possible.
    • TraceGraph_ELBO offers variance reduction strategies for models with discrete latent variables. Generally speaking, this ELBO should always be used for models with discrete latent variables.
    • TraceEnum_ELBO offers variable enumeration strategies for models with discrete latent variables. Generally speaking, this ELBO should always be used for models with discrete latent variables when enumeration is possible.
  • Automatic guides (appropriate for models with continuous latent variables)
    • AutoNormal and AutoDiagonalNormal are our basic mean-field guides. If the latent space is non-Euclidean (due to e.g. a positivity constraint on one of the sample sites), an appropriate bijective transformation is automatically used under the hood to map from the unconstrained space (where the Normal variational distribution is defined) to the corresponding constrained space (note this is true for all automatic guides). These guides are a great place to start when trying to get variational inference to work on a model you are developing.
    • AutoMultivariateNormal and AutoLowRankMultivariateNormal also construct Normal variational distributions but offer more flexibility, as they can capture correlations in the posterior. Note that these guides may be difficult to fit in the high-dimensional setting.
    • AutoDelta is used for computing point estimates via MAP (maximum a posteriori estimation). See here for example usage.
    • AutoBNAFNormal and AutoIAFNormal offer flexible variational distributions parameterized by normalizing flows.
    • AutoDAIS is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model.
    • AutoSurrogateLikelihoodDAIS is a powerful variational inference algorithm that leverages HMC and that supports data subsampling.
    • AutoSemiDAIS constructs a posterior approximation like AutoDAIS for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables.
    • AutoLaplaceApproximation can be used to compute a Laplace approximation.

Stein Variational Inference

See the docs for more details.

Installation

Limited Windows Support: Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this JAX issue for more details. Alternatively, you can install Windows Subsystem for Linux and use NumPyro on it as on a Linux system. See also CUDA on Windows Subsystem for Linux and this forum post if you want to use GPUs on Windows.

To install NumPyro with the latest CPU version of JAX, you can use pip:

pip install numpyro

If compatibility issues arise during execution of the above command, you can instead force the installation of a known compatible CPU version of JAX with

pip install numpyro[cpu]

To use NumPyro on the GPU, you need to install CUDA first and then use the following pip command:

pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

If you need further guidance, please have a look at the JAX GPU installation instructions.

To run NumPyro on Cloud TPUs, you can look at some JAX on Cloud TPU examples.

For Cloud TPU VM, you need to set up the TPU backend as detailed in the Cloud TPU VM JAX Quickstart Guide. After you have verified that the TPU backend is properly set up, you can install NumPyro using the pip install numpyro command.

Default Platform: JAX will use the GPU by default if a CUDA-supported jaxlib package is installed. You can use the set_platform utility, numpyro.set_platform("cpu"), at the beginning of your program to switch to the CPU.

You can also install NumPyro from source:

git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev]  # contains additional dependencies for NumPyro development

You can also install NumPyro with conda:

conda install -c conda-forge numpyro

Frequently Asked Questions

  1. Unlike in Pyro, numpyro.sample('x', dist.Normal(0, 1)) does not work. Why?

    You are most likely using a numpyro.sample statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key (PRNGKey) to generate samples from. NumPyro's inference algorithms use the seed handler to thread in a random number generator key, behind the scenes.

    Your options are:

    • Call the distribution directly and provide a PRNGKey, e.g. dist.Normal(0, 1).sample(PRNGKey(0))

    • Provide the rng_key argument to numpyro.sample. e.g. numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0)).

    • Wrap the code in a seed handler, used either as a context manager or as a function that wraps over the original callable. e.g.

      from numpyro import handlers

      with handlers.seed(rng_seed=0):  # random.PRNGKey(0) is used
          x = numpyro.sample('x', dist.Beta(1, 1))    # uses a PRNGKey split from random.PRNGKey(0)
          y = numpyro.sample('y', dist.Bernoulli(x))  # uses a different PRNGKey split from the last one

      or as a higher-order function:

      def fn():
          x = numpyro.sample('x', dist.Beta(1, 1))
          y = numpyro.sample('y', dist.Bernoulli(x))
          return y
      
      print(handlers.seed(fn, rng_seed=0)())
  2. Can I use the same Pyro model for doing inference in NumPyro?

    As you may have noticed from the examples, NumPyro supports all Pyro primitives like sample, param, plate and module, and effect handlers. Additionally, we have ensured that the distributions API is based on torch.distributions, and the inference classes like SVI and MCMC have the same interface. This, along with the similarity in the API for NumPy and PyTorch operations, ensures that models containing Pyro primitive statements can be used with either backend with some minor changes. Some differences, along with the changes needed, are noted below:

    • Any torch operation in your model will need to be written in terms of the corresponding jax.numpy operation. Additionally, not all torch operations have a numpy counterpart (and vice-versa), and sometimes there are minor differences in the API.
    • pyro.sample statements outside an inference context will need to be wrapped in a seed handler, as mentioned above.
    • There is no global parameter store, and as such using numpyro.param outside an inference context will have no effect. To retrieve the optimized parameter values from SVI, use the SVI.get_params method. Note that you can still use param statements inside a model and NumPyro will use the substitute effect handler internally to substitute values from the optimizer when running the model in SVI.
    • PyTorch neural network modules will need to be rewritten as stax, flax, or haiku neural networks. See the VAE and ProdLDA examples for differences in syntax between the two backends.
    • JAX works best with functional code, particularly if we would like to leverage JIT compilation, which NumPyro does internally for many inference subroutines. As such, if your model has side effects that are not visible to the JAX tracer, it may need to be rewritten in a more functional style.

    For most small models, changes required to run inference in NumPyro should be minor. Additionally, we are working on pyro-api which allows you to write the same code and dispatch it to multiple backends, including NumPyro. This will necessarily be more restrictive, but has the advantage of being backend agnostic. See the documentation for an example, and let us know your feedback.

  3. How can I contribute to the project?

    Thanks for your interest in the project! You can take a look at beginner-friendly issues that are marked with the good first issue tag on GitHub. Also, please feel free to reach out to us on the forum.

Future / Ongoing Work

In the near term, we plan to work on the following. Please open new issues for feature requests and enhancements:

  • Improving robustness of inference on different models, profiling and performance tuning.
  • Supporting more functionality as part of the pyro-api generic modeling interface.
  • More inference algorithms, particularly those that require second order derivatives or use HMC.
  • Integration with Funsor to support inference algorithms with delayed sampling.
  • Other areas motivated by Pyro's research goals and application focus, and interest from the community.

Citing NumPyro

The motivating ideas behind NumPyro and a description of Iterative NUTS can be found in this paper that appeared in NeurIPS 2019 Program Transformations for Machine Learning Workshop.

If you use NumPyro, please consider citing:

@article{phan2019composable,
  title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
  author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
  journal={arXiv preprint arXiv:1912.11554},
  year={2019}
}

as well as

@article{bingham2019pyro,
  author    = {Eli Bingham and
               Jonathan P. Chen and
               Martin Jankowiak and
               Fritz Obermeyer and
               Neeraj Pradhan and
               Theofanis Karaletsos and
               Rohit Singh and
               Paul A. Szerlip and
               Paul Horsfall and
               Noah D. Goodman},
  title     = {Pyro: Deep Universal Probabilistic Programming},
  journal   = {J. Mach. Learn. Res.},
  volume    = {20},
  pages     = {28:1--28:6},
  year      = {2019},
  url       = {http://jmlr.org/papers/v20/18-403.html}
}

numpyro's People

Contributors

amifalk, brendancooley, deoxyribose, eb8680, elchorro, fehiepsi, freddyaboulton, fritzo, hessammehr, jpchen, juanitorduz, kylejcaron, lumip, marcogorelli, martinjankowiak, neerajprad, olaronning, omarfsosa, ordabayevy, pierreglaser, quattro, raulpl, tare, tcbegley, theorashid, tillahoffmann, tuannguyen27, vanamsterdam, xidulu, yayami3

numpyro's Issues

Avoid compiling 2 times in HMC

Currently, we compile twice: once at init and once at sample. It would be better to compile only once.

One way is to modify tscan with additional parameters: a skip parameter to skip the first warmup_steps, and a lax.cond around warmup_state to decide whether it is updated or not (depending on whether we are in the warmup phase or the sampling phase).
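
A toy sketch of the lax.cond idea, assuming a hypothetical step function whose state is a position and a step size: the warmup flag is passed as a traced argument, so warmup and sampling iterations share a single compiled function. The counter records how many times JAX traces step; the update rules are stand-ins, not HMC.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_count = []  # appended to only when JAX traces `step`


def adapt(state):
    # hypothetical warmup-only update, standing in for step-size adaptation
    x, step_size = state
    return x, step_size * 0.9


def keep(state):
    # sampling phase: leave the adaptation state untouched
    return state


@jax.jit
def step(state, is_warmup):
    trace_count.append(1)
    x, step_size = state
    x = x + step_size  # stand-in for one MCMC transition
    return lax.cond(is_warmup, adapt, keep, (x, step_size))


state = (jnp.array(0.0), jnp.array(1.0))
for i in range(6):
    state = step(state, i < 3)  # the first 3 iterations are warmup
print(len(trace_count), float(state[1]))
```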

Use jax master for CI

Let's host the jax/jaxlib wheels for CI so that we can track JAX master directly for development.

Checklist for first release

  • Add sphinx to generate documentation.
  • Add docstring to existing modules.
  • Update README - what is currently supported, some small examples, and future plans.
    • Add pypi version badge to readme
    • Add readthedocs badge to readme
  • Include examples in unit-test to ensure compatibility with release.
  • A couple of good notebooks with explanations (not for benchmarking).
    • local global trend. @fehiepsi
    • bayesian regression. @neerajprad
    • bayesian NNs. @martinjankowiak who has volunteered to contribute a Bayesian NN tutorial! Not a release blocker though, we can publish it whenever it is ready.
  • Clean up minor API issues - #144
  • Register docs on readthedocs (or publish on github.io or somewhere public)
    https://numpyro.readthedocs.io/en/latest/distributions.html

Already Done:

  • A few starter issues for external contributors.
  • More interesting examples:
    • SVI - maybe a VAE or SS-VAE example with MNIST.
    • HMC - baseball, hierarchical modeling, bayesian regression.
  • Better coverage for popular distributions - #1
  • Working HMC, NUTS - doesn't have to be a fully optimized implementation.
  • Benchmarking HMC, NUTS vis-a-vis Pyro, Stan, PyMC3

Make MCMC diagnostics available in Numpyro

Right now we don't have any diagnostics to measure the quality of mixing. It would be nice to either implement diagnostics like we have done in Pyro, or perhaps just use the arviz library for this (offloading this to arviz is preferable if possible).
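
For reference, a NumPy sketch of the split Gelman-Rubin statistic reported as r_hat in the summaries above: each chain is cut in half, and the between-half variance is compared with the within-half variance. The function name split_rhat and the simulated chains are illustrative; current NumPyro releases also ship numpyro.diagnostics.split_gelman_rubin.

```python
import numpy as np


def split_rhat(samples):
    """Split-R-hat for draws shaped (num_chains, num_samples)."""
    half = samples.shape[1] // 2
    # Split every chain in half so that drift within a chain also shows up.
    chains = np.concatenate([samples[:, :half], samples[:, -half:]], axis=0)
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()          # W: mean within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)       # B: scaled variance of chain means
    var_plus = (n - 1) / n * within + between / n       # pooled variance estimate
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))       # four well-mixed chains
stuck = mixed + np.arange(4)[:, None]    # chains centred at different values
print(split_rhat(mixed), split_rhat(stuck))
```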

Add logits support to discrete distributions

Since we use float32 values by default instead of float64 like scipy.stats does, having support for logits in multinomial, bernoulli and binomial distributions is important, and will make operations that manipulate probabilities more numerically stable.

Benchmark HMC/NUTS on various models against popular frameworks

(split from #29 to track progress)

Here is a list of various models to benchmark. @neerajprad Please add more if you find they are needed.

  • Logistic regression on covtype dataset. Comparing against performance of PyMC/Stan/Edward which is reported in Simple, Distributed, and Accelerated Probabilistic Programming. @fehiepsi
  • Semi-supervised HMM model. Comparing against Pyro and Stan @fehiepsi
  • Baseball. Comparing against Pyro and Stan
  • Bayesian regression. Comparing against Pyro.
  • A hierarchical model
  • Time series model: Local global trend model
  • Move the change point detection model to examples. It is a great example to show that grad can pass through np.where in JAX; it took a bit of effort to achieve the same behaviour in other frameworks. It would be better to make a notebook once multi-chain is supported, where we can point out that some initializations cause chains to fail.

Follow-up issues for HMC

From discussion in #70:

  • Uniform handling of packing / unpacking pytree values. I think we can just flatten everything in the kernel itself and unflatten the result received. Alternatively (if this is relatively cheap), we can also just let the adapter function take in z and do the flattening. The latter will be better for debuggability, and I'm not sure if there is a significant overhead under JIT.
  • The current hmc_kernel can have a separate get_kernel method that would return either an HMC or NUTS kernel depending on args. This can be called by the user or from inside mcmc.
  • Default kinetic energy function. This shouldn't be needed as an arg from the user, at least not while we are using euclidean KE.
  • Pyro style model/guide syntax and wrapper utilities so that users don't need to specify the PE computation explicitly.
  • Clamp probability values (also, provide the logit parametrization), which will really help stabilize HMC trajectories as pointed out by @fehiepsi.
  • Correct the behaviour of find_reasonable_step_size
  • Either drop z_grad, potential_grad from HMCState or modify the behaviour of lax.scan to only scan z by default. @fehiepsi
  • Add context manager to convert lax.cond to lax.while. @neerajprad Unfortunately, using while_loop requires that initial_state have the same format as the output, but we don't know the output format without evaluating one of the true or false functions.
  • Investigate why test_unnormalized_normal keeps failing in @fehiepsi's system.
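
The flatten-in-the-kernel option from the first item can be sketched with jax.flatten_util.ravel_pytree, which turns any pytree of latent values into one flat vector and returns a function that restores the original structure. The dictionary of latent values is hypothetical.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# hypothetical latent values with mixed shapes
z = {"mu": jnp.array(0.5), "theta": jnp.arange(3.0)}

flat_z, unravel = ravel_pytree(z)  # one flat vector plus an inverse function
flat_z = flat_z + 1.0              # the kernel works on the flat vector only
z_new = unravel(flat_z)            # restore the user's structure at the end
print(flat_z.shape, z_new)
```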

[discussion] use functional approach for SVI

Currently, the SVI class just holds constant properties; it is much like a function IMO. So I created this thread to discuss the possibility of splitting SVI's methods into functions: svi_init (to get the initial svi_state/opt_state) and svi_step (to update svi_state/opt_state). To avoid repeating arguments such as model, guide, ..., users can use functools.partial, or we can keep the SVI class as a wrapper.

I list here some advantages which I have in mind:

  • We can easily choose jit or non-jit in the main application.
  • We can JIT the whole SVI loop by using lax.fori_loop (as in this lax example).
  • The SVI class will be cleaner.
  • We can define initial opt_state outside of svi loops and we can freely choose to use pyro.param or not. This will be helpful when params come from jax.experimental.stax.

Note that this is just for discussion, not a request for change. We should focus on adding more distributions and benchmarking first. Some initial tests suggest that dispatch overhead is large, but it scales pretty well.
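A minimal sketch of the functional style proposed above. svi_init and svi_step are the names from this thread; the quadratic loss_fn (a stand-in for a negative ELBO) and the plain gradient-descent update (in place of a real optimizer) are assumptions made to keep the example self-contained.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def loss_fn(params, data):
    # stand-in for a negative ELBO; real SVI would run the model and guide here
    return jnp.sum((params - data) ** 2)


def svi_init(init_params):
    # the SVI state is just a pytree (here, the parameter array itself)
    return init_params


def svi_step(loss_fn, data, step_size, svi_state):
    # one plain gradient-descent update; returns a new state and mutates nothing
    grads = jax.grad(loss_fn)(svi_state, data)
    return svi_state - step_size * grads


data = jnp.array([1.0, -2.0, 3.0])
# bind the repeated arguments once instead of storing them on an object
step = partial(svi_step, loss_fn, data, 0.1)

# option 1: jit a single step and drive it from a Python loop
step_jit = jax.jit(step)
state = svi_init(jnp.zeros(3))
for _ in range(100):
    state = step_jit(state)


# option 2: compile the entire training loop with lax.fori_loop
@partial(jax.jit, static_argnums=1)
def run(svi_state, num_steps):
    return lax.fori_loop(0, num_steps, lambda i, s: step(s), svi_state)


final_state = run(svi_init(jnp.zeros(3)), 100)
```

Both options share the same pure step function; only the driver changes, which is the flexibility the list above argues for.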

Do not use private utility functions from JAX

For convenience, we currently use many JAX functions that are not publicly exposed, such as _promote_args. Let us reimplement (and customize as needed) these functions in util.py, so that we don't depend on private functions that might be removed or whose API might change without notice.
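As a rough illustration, an argument-promotion helper can be rebuilt from public jax.numpy APIs only. promote_args is a hypothetical name, and this sketch covers dtype promotion plus broadcasting; it does not claim to reproduce every detail of the private JAX helper.

```python
import jax.numpy as jnp


def promote_args(*args):
    """Hypothetical util.py helper built only from public jax.numpy APIs.

    Promotes all arguments to a common dtype, then broadcasts them to a common shape.
    """
    dtype = jnp.result_type(*args)
    arrays = [jnp.asarray(arg, dtype=dtype) for arg in args]
    return jnp.broadcast_arrays(*arrays)


a, b = promote_args(jnp.array([1, 2], dtype=jnp.int32), 1.5)
```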

API issues to clean up before release

This is to track minor API issues as we notice them, that will be nice to clean up. This is not super urgent - these are known things that can be cleaned up before release.

  • It will be nice to see other diagnostic information, such as the acceptance rate and step size, since those are the easiest things to check if the model is misspecified. @neerajprad
  • Support initialization from Uniform(-2, 2) directly on unconstrained parameters like in Stan.
  • Renaming: Potential candidates for renaming (up for discussion)
    • heuristic_step_size: The name makes it appear as though it is a float arg and not a boolean arg. Don't have a better idea at the moment though.
    • num_warmup_steps --> num_warmup or warmup_steps
  • Certain functions need args and kwargs to be a tuple and a dict so that repacking args / kwargs does not trigger JIT recompilation. However, this style isn't very Pythonic and should be avoided when there is no need to JIT. So let's change initialize_model(rng, model, model_args, model_kwargs) to initialize_model(rng, model, *args, **kwargs).
  • If the model or potential_fn is not jittable, I think we'll end up throwing an exception, given that sample_kernel has a jit decorator. I think the right approach is to pass a jit_compile=True flag to hmc so that users do not need to modify their source code when their model is not jittable.
  • Use default args such as loc=0, scale=1 in distributions
  • Rename args in Pareto for consistency.
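The jit_compile idea from the list above can be sketched as a factory that only wraps the kernel in jax.jit when asked. make_sample_kernel is a hypothetical name and the gradient step is a toy stand-in for a real HMC transition; the point is that a potential with Python control flow on values works under plain jax.grad but fails under jit.

```python
import jax


def make_sample_kernel(potential_fn, step_size=0.1, jit_compile=True):
    """Hypothetical factory; a gradient step stands in for a real HMC transition."""

    def sample_kernel(z):
        return z - step_size * jax.grad(potential_fn)(z)

    # jit only on request, so models with Python control flow on values still run
    return jax.jit(sample_kernel) if jit_compile else sample_kernel


def non_jittable_potential(z):
    # a Python `if` on a traced value fails under jit but works with plain grad
    if z > 0:
        return z ** 2
    return -z


kernel = make_sample_kernel(non_jittable_potential, jit_compile=False)
z_next = kernel(1.0)  # 1.0 - 0.1 * 2.0
```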

Revise standard_gamma sampler

Based on recent examples of @neerajprad for translation_rule of custom primitive functions, it seems that we can make standard_gamma jittable. Let's update standard_gamma with

  • translation_rule
  • using lax.cond
  • using various utilities from lax instead of relying on numpy
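One way to make a gamma sampler traceable is to express the rejection loop with lax.while_loop (rather than lax.cond, since the number of iterations is data dependent). The sketch below uses Marsaglia and Tsang's method with the usual boost for alpha < 1; it is an illustration under these assumptions, not NumPyro's implementation, and it does not use a custom translation_rule.

```python
import jax
import jax.numpy as jnp
from jax import lax


def standard_gamma(key, alpha):
    """Draw one Gamma(alpha, 1) sample with Marsaglia and Tsang's method."""
    alpha = jnp.asarray(alpha, dtype=jnp.float32)
    # the method needs alpha >= 1: sample Gamma(alpha + 1) and correct afterwards
    boost = alpha < 1.0
    a = jnp.where(boost, alpha + 1.0, alpha)
    d = a - 1.0 / 3.0
    c = 1.0 / jnp.sqrt(9.0 * d)

    def cond_fn(state):
        _, accepted, _ = state
        return jnp.logical_not(accepted)

    def body_fn(state):
        key, _, _ = state
        key, key_x, key_u = jax.random.split(key, 3)
        x = jax.random.normal(key_x, dtype=jnp.float32)
        v = (1.0 + c * x) ** 3
        u = jax.random.uniform(key_u, dtype=jnp.float32)
        # guard the log so that proposals with v <= 0 never produce NaN
        safe_v = jnp.where(v > 0, v, 1.0)
        log_ratio = 0.5 * x ** 2 + d - d * safe_v + d * jnp.log(safe_v)
        accepted = (v > 0) & (jnp.log(u) < log_ratio)
        return key, accepted, d * safe_v

    key, key_boost = jax.random.split(key)
    init = (key, jnp.array(False), jnp.zeros((), jnp.float32))
    _, _, sample = lax.while_loop(cond_fn, body_fn, init)
    # Gamma(alpha) = Gamma(alpha + 1) * U ** (1 / alpha) for alpha < 1
    u = jax.random.uniform(key_boost, dtype=jnp.float32)
    return jnp.where(boost, sample * u ** (1.0 / alpha), sample)


# jittable and vmappable: one key per sample, shared alpha
sample_many = jax.jit(jax.vmap(standard_gamma, in_axes=(0, None)))
```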

Convenient way to turn off distributions arg checks

With #53, we can turn off distribution args checking with a context manager called validation_disabled(). Disabling is required because the args checking doesn't work with JIT (which isn't an issue in itself, since we shouldn't need to JIT these debugging-related features anyway).

The current context manager is a hack because it is a global flag, and we might call the distributions' logpdf method after the distribution is initialized which may result in the flag being applied incorrectly. Given scipy's inheritance pattern and use of frozen vs non-frozen instances, I couldn't see an easy way of storing this flag in the constructor (like we do in PyTorch), but it is worth addressing this as we clean up the interface.

Support float64 array

Currently, we only develop and test for float32 array. When the functionalities are ready, we should test for float64 array too.

Remove dependency on scipy.stats for distributions

Due to issues faced in #51, #50 and #81, it seems that we would be better off not inheriting from scipy and having to work around their class hierarchy which was designed to support its own use cases. Instead, I think it will help if our design was closer to PyTorch, in terms of Pyro <--> Numpyro compatibility (#66), and more generally reuse of functionality like constraint checks and bijective transforms that are useful for HMC and SVI.

I think to begin with we can start by only supporting sample, log_prob methods and adding more as we go along. It should be relatively straightforward to do so for our existing implementations. In cases where we don't have a JAX sampler available, we could just use scipy's sampler internally (as suggested by @fehiepsi).

Collections of tips while using jax

JAX is cool but docs/examples of its features are still lacking. Let's share tips while using jax here. :)

pytree

A pytree is a dict/tuple/array or a nested combination of them. Many JAX transformations work with pytree arguments. For example, grad(f)(x) returns a dict of gradient arrays if x is a dict of arrays. There are also many utilities available in jax.tree_util, such as

  • tree_map: as in the below example of lax.scan,
  • tree_multimap: as in the implementation of velocity_verlet.
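A small example of grad over a dict pytree, followed by a leaf-by-leaf update. Note that in recent JAX releases tree_multimap has been removed and tree_map itself accepts several trees, which is what this sketch uses; the function f and the 0.1 step are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map


def f(params):
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
grads = jax.grad(f)(params)  # a dict with the same keys and shapes as params

# tree_map over two trees applies the update leaf by leaf
new_params = tree_map(lambda p, g: p - 0.1 * g, params, grads)
```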

lax.scan

We can do the following

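import jax.lax as lax
import jax.numpy as np
from jax.tree_util import tree_map

def f(trace, i):
    # create a new trace given the current trace
    next_trace = tree_map(lambda a: a + 1, trace)
    # lax.scan expects (carry, output): the carry is passed to the next step,
    # and the outputs of all steps are stacked along a new leading axis
    return next_trace, next_trace

initial_trace = {"a": np.array([1., 2.])}
num_samples = 10
final_trace, traces = lax.scan(f, initial_trace, np.arange(num_samples))
print(traces)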

which returns

{'a': array([[ 2.,  3.],
        [ 3.,  4.],
        [ 4.,  5.],
        [ 5.,  6.],
        [ 6.,  7.],
        [ 7.,  8.],
        [ 8.,  9.],
        [ 9., 10.],
        [10., 11.],
        [11., 12.]], dtype=float32)}

Hopefully we can JIT the whole HMC loop using this pattern.
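A sketch of that idea: a whole sampling loop compiled as one jitted lax.scan over per-step PRNG keys. A random-walk Metropolis transition on a standard normal target stands in for an HMC step here; both the target and the transition are assumptions chosen to keep the example short.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def log_density(z):
    # standard normal target, used only to keep the example self-contained
    return -0.5 * jnp.sum(z ** 2)


def mh_step(z, key):
    # random-walk Metropolis transition standing in for an HMC step
    key_prop, key_acc = jax.random.split(key)
    proposal = z + 0.5 * jax.random.normal(key_prop, z.shape)
    log_accept_ratio = log_density(proposal) - log_density(z)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_accept_ratio
    z_new = jnp.where(accept, proposal, z)
    return z_new, z_new  # (carry for the next step, sample to collect)


@partial(jax.jit, static_argnums=2)
def run_chain(init_z, rng_key, num_samples):
    keys = jax.random.split(rng_key, num_samples)
    _, samples = lax.scan(mh_step, init_z, keys)
    return samples


samples = run_chain(jnp.zeros(2), jax.random.PRNGKey(0), 5000)
```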

Separate out model and guide args in SVI

Often, the arguments to the model and guide are different, and it is cumbersome to pass the same set of args to both and leave some unused. This suggests a further refactoring of the SVI class such that:

  • model_args and guide_args can be specified separately, and are jittable. These can potentially change during the course of training, e.g. data batches, or other dynamic arguments.
  • kwargs will contain static arguments common to both, and can simply be specified during the initial call to svi rather than being piped through the init and update functions. These can be functions passed to the model or guide, as well as static arrays.
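A minimal sketch of that split: static kwargs are bound once when the update function is built, while model_args and guide_args are passed separately on every (jitted) call. make_svi_update, model_loss and guide_loss are hypothetical names; the two toy losses stand in for the model and guide terms of an ELBO.

```python
import jax
import jax.numpy as jnp


def make_svi_update(model_loss, guide_loss, step_size, **static_kwargs):
    """Hypothetical SVI constructor: static kwargs are bound once, here."""

    def loss(params, model_args, guide_args):
        return (model_loss(params, *model_args, **static_kwargs)
                + guide_loss(params, *guide_args, **static_kwargs))

    @jax.jit
    def update(params, model_args, guide_args):
        # model_args and guide_args are pytrees of arrays, so a new data batch
        # with the same shapes reuses the compiled function
        grads = jax.grad(loss)(params, model_args, guide_args)
        return params - step_size * grads

    return update


def model_loss(params, batch, scale):
    return jnp.sum((params - batch) ** 2) / scale


def guide_loss(params, prior_mean, scale):
    return (params - prior_mean) ** 2 / scale


update = make_svi_update(model_loss, guide_loss, step_size=0.1, scale=1.0)
new_params = update(jnp.array(0.0), (jnp.array([1.0, 1.0]),), (jnp.array(0.0),))
```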

standard_gamma not caching compiled function

I would have expected the second call to Beta.sample() to be much faster for two reasons: no compilation cost for standard_gamma, and faster sampling from the compiled kernel.

This does not seem to be the case.

In [4]: %timeit Beta(1.1, 1.1).sample(PRNGKey(1), (1000,))
2.24 s ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: %timeit Beta(1.1, 1.1).sample(PRNGKey(1), (1000,))
2.3 s ± 46.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

[discussion] use transforms in potential_fn

I would like to suggest a way to use support information for default transforms and incorporate it in potential_fn:

  • initialize_model takes an additional argument named transforms, which is a dict mapping each latent variable to its user-defined transform. If no transform is given for a variable, we use the default transform.
  • initialize_model returns a z transform fn (in addition to the initial unconstrained params and potential_fn) for transforming unconstrained variables to constrained variables. Users can then pass this transform fn to the transform argument of tscan: tscan(..., transform=lambda latent: z_transform_fn(latent.z))
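A sketch of a potential_fn that applies per-site transforms and adds the log-Jacobian term, plus the matching z_transform_fn. make_potential_fn, the toy model (x ~ Normal(0, 1), s ~ Exponential(1)) and the transforms dict layout are all hypothetical; each transform returns the constrained value together with log |d constrained / d unconstrained|.

```python
import jax
import jax.numpy as jnp


def log_joint(z):
    # hypothetical model: x ~ Normal(0, 1) on the real line, s ~ Exponential(1) on (0, inf)
    return -0.5 * z["x"] ** 2 - z["s"]


# default transforms from unconstrained to constrained space, keyed by site name;
# each returns (constrained value, log |d constrained / d unconstrained|)
default_transforms = {
    "x": lambda u: (u, 0.0),
    "s": lambda u: (jnp.exp(u), u),
}


def make_potential_fn(log_joint, transforms):
    """Hypothetical helper returning (potential_fn, z_transform_fn)."""

    def z_transform_fn(u):
        return {name: transforms[name](value)[0] for name, value in u.items()}

    def potential_fn(u):
        z, log_det = {}, 0.0
        for name, value in u.items():
            z[name], site_log_det = transforms[name](value)
            log_det = log_det + site_log_det
        # potential energy = -(log density at constrained values + log |Jacobian|)
        return -(log_joint(z) + log_det)

    return potential_fn, z_transform_fn


potential_fn, z_transform_fn = make_potential_fn(log_joint, default_transforms)
```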

Disable generic args checking for discrete distributions

Currently the logpmf method in rv_discrete does a lot of args checking and substituting NaN or -Inf values for pmf of out of support values. These operations are opaque to the JAX tracer and I think should simply be removed or done in a way that doesn't come in the way of operations like grad.
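A sketch of doing the support check inside the traced computation instead: out-of-support values get -inf through jnp.where, and a "safe" substitute keeps the unused branch from producing NaN gradients. poisson_log_prob is a hypothetical stand-alone function, not a method of any existing class.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def poisson_log_prob(rate, k):
    k = jnp.asarray(k)
    # out-of-support values get -inf via jnp.where instead of Python-side checks
    in_support = (k >= 0) & (jnp.floor(k) == k)
    # substitute a safe value so the unused branch never produces NaN gradients
    k_safe = jnp.where(in_support, k, 0.0)
    log_p = k_safe * jnp.log(rate) - rate - gammaln(k_safe + 1.0)
    return jnp.where(in_support, log_p, -jnp.inf)
```

The function stays traceable, so grad, jit and vmap all apply to it.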

Make tscan jittable

As it stands, unlike other primitives, our tscan implementation is not jittable. It is worth making it work well with jax.jit for benchmarking purposes.

Thin out distribution wrappers and aggregate in continuous.py

We can thin out our distribution wrappers and let numpy do the dispatching; e.g. np.log or np.exp will be correctly dispatched by numpy, even though they incur a small dispatch cost. As in discrete.py, we can move all the lightly wrapped continuous distributions into continuous.py.

Support exporting to arviz InferenceData

There are a bunch of features available in arviz. We can use that great library without adding a dependency by writing a utility function that converts our fori_append results to a dictionary arviz_dict. Then, in arviz, we just need to call arviz.from_dict to get density, trace_plot, and a bunch of new stats such as loo, waic, ...
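A minimal sketch of such a converter, written with NumPy only so arviz stays an optional dependency. to_arviz_dict is a hypothetical name, and it assumes the collected samples for each chain are stored contiguously, chain after chain, along the leading axis.

```python
import numpy as np


def to_arviz_dict(samples, num_chains=1):
    """Hypothetical helper: reshape {name: (num_chains * num_draws, ...)} samples
    into the (chain, draw, ...) layout that arviz expects.

    Assumes the draws of each chain are stored contiguously, chain after chain.
    """
    return {
        name: np.asarray(value).reshape((num_chains, -1) + np.shape(value)[1:])
        for name, value in samples.items()
    }
```

With arviz installed, the result can then be passed as arviz.from_dict(posterior=to_arviz_dict(samples, num_chains=2)); stats such as loo and waic additionally need pointwise log-likelihood values, which this sketch does not produce.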

Make HMCState picklable

Currently, we use laxtuple to make HMCState. However, this makes it impossible to pickle HMCState (unless we convert it to a dict/tuple/namedtuple).

Build wrappers for lax primitives to disable tracing in debug mode

Even with the disable_jit flag, it is inconvenient to debug our existing code, since many of our utilities use lax primitives that need to be manually rewritten each time we want to debug something. Let's build light wrappers around these to make debugging easy.

We should also move this upstream to JAX if we find it useful.
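A sketch of such wrappers: in debug mode they fall back to plain Python control flow, so loop bodies see concrete values and can print or hit breakpoints; otherwise they call the real lax primitives. The debug_mode switch and wrapper names are hypothetical.

```python
from contextlib import contextmanager

from jax import lax

_DEBUG = {"enabled": False}  # hypothetical global switch


@contextmanager
def debug_mode():
    previous = _DEBUG["enabled"]
    _DEBUG["enabled"] = True
    try:
        yield
    finally:
        _DEBUG["enabled"] = previous


def fori_loop(lower, upper, body_fun, init_val):
    if _DEBUG["enabled"]:
        # plain Python loop: body_fun sees concrete values, so print/pdb work
        val = init_val
        for i in range(lower, upper):
            val = body_fun(i, val)
        return val
    return lax.fori_loop(lower, upper, body_fun, init_val)


def cond(pred, true_fun, false_fun, operand):
    if _DEBUG["enabled"]:
        return true_fun(operand) if pred else false_fun(operand)
    return lax.cond(pred, true_fun, false_fun, operand)
```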

Make inference utilities available for prediction

Related issue - #41.

It will be nice to provide an easy interface for doing predictions (ideally a vectorized version of TracePredictive), so that users do not need to write code to do this by themselves. e.g. baseball. More detailed discussion on the design is needed, and hopefully, we'll have some more insights on this after attempting this in Pyro pyro-ppl/pyro#1725.
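A sketch of a vectorized predictive: draw one predictive sample per posterior sample by vmapping a single-sample function over the posterior dict and a batch of PRNG keys. The linear-regression model and the predict_one / predictive names are hypothetical.

```python
import jax
import jax.numpy as jnp


def predict_one(params, rng_key, x):
    # hypothetical linear-regression model: y ~ Normal(a + b * x, sigma)
    mean = params["a"] + params["b"] * x
    return mean + params["sigma"] * jax.random.normal(rng_key, x.shape)


def predictive(posterior_samples, rng_key, x):
    """Draw one predictive sample of y per posterior sample, vectorized with vmap."""
    num_samples = posterior_samples["a"].shape[0]
    keys = jax.random.split(rng_key, num_samples)
    # map over the leading axis of every array in the samples dict and over the
    # keys; x is shared across all posterior samples
    return jax.vmap(predict_one, in_axes=(0, 0, None))(posterior_samples, keys, x)
```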

Add support for plate context manager

We can start with a basic implementation that just does distributions broadcasting under the hood and stores the plate information at sample sites for any inference algorithms.

Add support for InverseTransform

This will allow us to create new transform classes out of existing ones. e.g. LogitTransform as SigmoidTransform().inv.

It will also be nice if the client code can just take in a list of transforms and call .inv on it without having any knowledge of whether the original or the inverted transform was passed in.

e.g.
Numpyro

>>> t = SigmoidTransform()

>>> t.inv
 <bound method SigmoidTransform.inv of <numpyro.distributions.constraints.SigmoidTransform object at 0x7f92680ce710>>

>>> t.inv.inv
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-10-15abeba100ea> in <module>
----> 1 t.inv.inv

AttributeError: 'function' object has no attribute 'inv'

PyTorch

>>> t = SigmoidTransform()

>>> t.inv
 _InverseTransform()

>>> t.inv.inv
 SigmoidTransform()
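A pure-Python sketch modelled on the PyTorch behaviour above: .inv is a property returning a Transform object, and inverting an inverse hands back the original object. It uses math on scalars to stay dependency-free; the class layout is an assumption, not NumPyro's implementation.

```python
import math


class Transform:
    def __call__(self, x):
        raise NotImplementedError

    def _inverse(self, y):
        raise NotImplementedError

    @property
    def inv(self):
        # a Transform object rather than a bound method, so `.inv` can be chained
        return _InverseTransform(self)


class _InverseTransform(Transform):
    def __init__(self, transform):
        self._transform = transform

    @property
    def inv(self):
        # inverting twice hands back the original object
        return self._transform

    def __call__(self, x):
        return self._transform._inverse(x)

    def _inverse(self, y):
        return self._transform(y)

    def __repr__(self):
        return "_InverseTransform()"


class SigmoidTransform(Transform):
    def __call__(self, x):
        return 1.0 / (1.0 + math.exp(-x))

    def _inverse(self, y):
        return math.log(y) - math.log1p(-y)

    def __repr__(self):
        return "SigmoidTransform()"


def LogitTransform():
    # a new transform built from an existing one
    return SigmoidTransform().inv
```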

Implement JVP rules for cumsum/cumprod

Currently, cumsum/cumprod do not have JVP rules in JAX. These operators are necessary for the simplex constraint, so we should make custom primitives to support them.
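Current JAX releases can already differentiate jnp.cumsum and jnp.cumprod, and jax.custom_jvp is a lighter alternative to defining custom primitives. The sketch below shows how such rules could be written with jax.custom_jvp; the cumprod rule assumes every entry of x is nonzero.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def cumsum(x):
    return jnp.cumsum(x)


@cumsum.defjvp
def _cumsum_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # cumsum is linear, so its JVP is cumsum applied to the tangent
    return cumsum(x), jnp.cumsum(x_dot)


@jax.custom_jvp
def cumprod(x):
    return jnp.cumprod(x)


@cumprod.defjvp
def _cumprod_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    out = cumprod(x)
    # d(prod_{k<=i} x_k) = prod_{k<=i} x_k * sum_{k<=i} x_dot_k / x_k
    # (assumes every entry of x is nonzero)
    return out, out * jnp.cumsum(x_dot / x)
```

Because both tangent expressions are linear in x_dot, reverse mode (jax.grad) works through these rules as well.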

NUTS doesn't converge on a stan model

@stefanwebb pointed out that Pyro's NUTS on the earnings latin square model gives extremely different results from Stan. cc. @jpchen

To debug this, I have tried running the model on Pyro's NUTS and Numpyro's NUTS and both return results which are very far off from Stan with high r_hat values indicating that the procedure hasn't converged. Creating this issue to track progress on investigating this bug / discrepancy.

Some notes:

  • It is possible that my translation of the model to Pyro/Numpyro is wrong.
  • I have tried changing the default tensor type to Double in Pyro and changing the scale parameters to dist.HalfCauchy(1.) (instead of dist.Uniform(0, 100) which is more faithful to the Stan implementation) to see if that helps convergence. While it does seem to help somewhat, we still get very different results. This seems to help quite a lot on Numpyro (still checking on Pyro).
  • Numpyro is much faster than Pyro (I think also faster than Stan), but seems to give incorrect results. Not surprising since the underlying issue, either in my code or in the inference algorithm, is likely the same for both the implementations.

Pyro code:

import csv
from collections import defaultdict

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC


torch.set_default_tensor_type('torch.DoubleTensor')
use_uniform = False


def scale():
    return dist.Uniform(0., 100.) if use_uniform else dist.HalfCauchy(1.)


def model(data):
    eth = data['eth']
    age = data['age']
    x = data['x']
    y = data['y']
    mu_a1 = pyro.sample('mu_a1', dist.Normal(0., 1.))
    mu_a2 = pyro.sample('mu_a2', dist.Normal(0., 1.))
    sigma_a1 = pyro.sample('sigma_a1', scale())
    sigma_a2 = pyro.sample('sigma_a2', scale())

    mu_b1 = pyro.sample('mu_b1', dist.Normal(0., 1.))
    mu_b2 = pyro.sample('mu_b2', dist.Normal(0., 1.))
    sigma_b1 = pyro.sample('sigma_b1', scale())
    sigma_b2 = pyro.sample('sigma_b2', scale())

    mu_c = pyro.sample('mu_c', dist.Normal(0., 1.))
    sigma_c = pyro.sample('sigma_c', scale())
    mu_d = pyro.sample('mu_d', dist.Normal(0., 1.))
    sigma_d = pyro.sample('sigma_d', scale())

    nage = pyro.plate("n_age", 3, dim=-1)
    neth = pyro.plate("neth", 4, dim=-2)

    with neth:
        a1 = pyro.sample('a1', dist.Normal(10 * mu_a1, sigma_a1))
        a2 = pyro.sample('a2', dist.Normal(mu_a2, sigma_a2))

    with nage:
        b1 = pyro.sample('b1', dist.Normal(10 * mu_b1, sigma_b1))
        b2 = pyro.sample('b2', dist.Normal(0.1 * mu_b2, sigma_b2))

    with neth, nage:
        c = pyro.sample('c', dist.Normal(10 * mu_c, sigma_c))
        d = pyro.sample('d', dist.Normal(0.1 * mu_d, sigma_d))

    y_hat = a1[eth].squeeze(-1) + a2[eth].squeeze(-1) * x + b1[age] + b2[age] * x + c[eth, age] + d[eth, age] * x
    sigma_y = pyro.sample('sigma_y', scale())
    with pyro.plate('N', 1059):
        pyro.sample('obs', dist.Normal(y_hat, sigma_y), obs=y)


data = defaultdict(list)
with open('earnings.csv', 'r') as f:
    csv_reader = csv.DictReader(f)
    for row in csv_reader:
        data['x'].append(float(row['x']))
        data['y'].append(float(row['y']))
        data['age'].append(int(row['age']) - 1)
        data['eth'].append(int(row['eth']) - 1)


data['x'] = torch.tensor(data['x'])
data['y'] = torch.tensor(data['y'])
data['age'] = torch.tensor(data['age'], dtype=torch.long)
data['eth'] = torch.tensor(data['eth'], dtype=torch.long)


nuts_kernel = NUTS(model, max_tree_depth=6, jit_compile=True, ignore_jit_warnings=True)
posterior_fully_pooled = MCMC(nuts_kernel,
                              num_samples=500,
                              warmup_steps=500,
                              num_chains=2).run(data)

print(posterior_fully_pooled.marginal(['a1', 'a2', 'b1', 'b2']).diagnostics())
marginals = posterior_fully_pooled.marginal(['a1', 'a2', 'b1', 'b2'])
for k, v in marginals.empirical.items():
    print(k, v.mean)

Numpyro code:

import csv
from collections import defaultdict

from jax.random import PRNGKey

import numpyro.distributions as dist
from numpyro.handlers import sample
import jax.numpy as np

from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.util import fori_collect

use_uniform = False


def scale():
    return dist.Uniform(0., 100.) if use_uniform else dist.HalfCauchy(1.)


def model(data):
    eth = data['eth']
    age = data['age']
    x = data['x']
    y = data['y']
    mu_a1 = sample('mu_a1', dist.Normal(0., 1.))
    mu_a2 = sample('mu_a2', dist.Normal(0., 1.))
    sigma_a1 = sample('sigma_a1', scale())
    sigma_a2 = sample('sigma_a2', scale())

    mu_b1 = sample('mu_b1', dist.Normal(0., 1.))
    mu_b2 = sample('mu_b2', dist.Normal(0., 1.))
    sigma_b1 = sample('sigma_b1', scale())
    sigma_b2 = sample('sigma_b2', scale())

    mu_c = sample('mu_c', dist.Normal(0., 1.))
    sigma_c = sample('sigma_c', scale())
    mu_d = sample('mu_d', dist.Normal(0., 1.))
    sigma_d = sample('sigma_d', scale())

    a1 = sample('a1', dist.Normal(10 * np.broadcast_to(mu_a1, (4,)), sigma_a1))
    a2 = sample('a2', dist.Normal(np.broadcast_to(mu_a2, (4,)), sigma_a2))

    b1 = sample('b1', dist.Normal(10 * np.broadcast_to(mu_b1, (3,)), sigma_b1))
    b2 = sample('b2', dist.Normal(0.1 * np.broadcast_to(mu_b2, (3,)), sigma_b2))

    c = sample('c', dist.Normal(10 * np.broadcast_to(mu_c, (4, 3)), sigma_c))
    d = sample('d', dist.Normal(0.1 * np.broadcast_to(mu_d, (4, 3)), sigma_d))

    y_hat = a1[eth] + a2[eth] * x + b1[age] + b2[age] * x + c[eth, age] + d[eth, age] * x
    sigma_y = sample('sigma_y', scale())
    sample('obs', dist.Normal(y_hat, sigma_y), obs=y)


data = defaultdict(list)
with open('earnings.csv', 'r') as f:
    csv_reader = csv.DictReader(f)
    for row in csv_reader:
        data['x'].append(float(row['x']))
        data['y'].append(float(row['y']))
        data['age'].append(int(row['age']) - 1)
        data['eth'].append(int(row['eth']) - 1)


data['x'] = np.array(data['x'])
data['y'] = np.array(data['y'])
data['age'] = np.array(data['age']).astype(np.int64)
data['eth'] = np.array(data['eth']).astype(np.int64)

init_params, potential_fn, transform_fn = initialize_model(PRNGKey(0), model, data)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, 2000)
hmc_states = fori_collect(2000, sample_kernel, hmc_state,
                          transform=lambda hmc_state: transform_fn(hmc_state.z))
print(hmc_states)

Stan results

               mean se_mean    sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
a1[1]          5.60    1.57  8.55   -9.18   -1.05    5.24   11.48   22.41    30 1.14
a1[2]          5.44    1.58  8.57   -9.37   -1.17    5.14   11.46   22.07    29 1.14
a1[3]          5.00    1.59  8.60   -9.92   -1.64    4.65   11.04   21.43    29 1.14
a1[4]          5.48    1.57  8.55   -9.31   -1.16    5.14   11.46   22.14    30 1.14
a2[1]          0.07    0.03  0.20   -0.34   -0.06    0.07    0.19    0.42    34 1.06
a2[2]          0.07    0.03  0.20   -0.34   -0.05    0.07    0.19    0.42    33 1.06
a2[3]          0.08    0.03  0.20   -0.33   -0.04    0.08    0.20    0.44    33 1.07
a2[4]          0.07    0.03  0.20   -0.33   -0.05    0.07    0.20    0.43    33 1.06
b1[1]          3.97    1.52  8.68  -14.80   -1.11    3.81    9.56   20.60    33 1.14
b1[2]          2.36    1.52  8.66  -16.69   -2.77    2.25    8.05   18.87    33 1.14
b1[3]          1.78    1.52  8.69  -17.57   -3.48    1.64    7.32   18.27    33 1.13
b2[1]         -0.01    0.02  0.16   -0.29   -0.09   -0.02    0.06    0.38    60 1.03
b2[2]          0.02    0.02  0.16   -0.25   -0.06    0.01    0.09    0.40    59 1.03
b2[3]          0.02    0.02  0.16   -0.24   -0.06    0.02    0.10    0.42    59 1.03
c[1,1]        -2.43    2.17  7.72  -14.76   -7.90   -3.95    2.13   13.98    13 1.36
c[1,2]        -2.36    2.17  7.72  -14.89   -7.85   -3.87    2.21   14.05    13 1.36
c[1,3]        -2.46    2.17  7.72  -14.83   -7.98   -4.04    2.13   13.95    13 1.36
c[2,1]        -2.37    2.17  7.72  -14.80   -7.87   -3.89    2.23   14.05    13 1.36
c[2,2]        -2.51    2.17  7.73  -14.93   -8.04   -4.02    2.11   14.02    13 1.36
c[2,3]        -2.38    2.17  7.72  -14.79   -7.89   -3.87    2.24   13.83    13 1.36
c[3,1]        -2.41    2.17  7.72  -14.89   -7.89   -3.89    2.23   13.97    13 1.36
c[3,2]        -2.44    2.17  7.72  -14.80   -7.94   -3.92    2.19   14.01    13 1.36
c[3,3]        -2.43    2.17  7.72  -14.75   -7.92   -3.93    2.19   13.97    13 1.36
c[4,1]        -2.44    2.17  7.72  -14.85   -7.94   -3.97    2.17   13.99    13 1.36
c[4,2]        -2.38    2.17  7.72  -14.75   -7.86   -3.88    2.19   14.04    13 1.36
c[4,3]        -2.42    2.17  7.72  -14.80   -7.91   -3.93    2.19   13.95    13 1.36
d[1,1]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[1,2]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[1,3]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[2,1]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[2,2]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[2,3]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[3,1]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[3,2]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[3,3]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[4,1]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[4,2]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[4,3]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
mu_a1          0.54    0.16  0.85   -0.94   -0.12    0.50    1.13    2.17    30 1.14
mu_a2          0.07    0.03  0.20   -0.34   -0.05    0.07    0.20    0.43    33 1.06
mu_b1          0.25    0.13  0.85   -1.58   -0.27    0.24    0.79    1.87    42 1.11
mu_b2          0.04    0.09  0.98   -1.90   -0.61    0.03    0.68    2.06   125 1.03
mu_c          -0.24    0.22  0.77   -1.48   -0.80   -0.40    0.22    1.40    13 1.36
mu_d          -0.17    0.21  1.14   -2.22   -0.98   -0.22    0.63    2.09    29 1.11
sigma_a1       0.96    0.11  1.92    0.02    0.12    0.34    1.01    5.30   319 1.01
sigma_a2       0.01    0.00  0.03    0.00    0.00    0.01    0.02    0.08   328 1.01
sigma_b1       4.07    0.28  6.07    0.14    1.14    2.19    4.43   20.92   487 1.02
sigma_b2       0.12    0.03  0.30    0.00    0.02    0.04    0.09    0.87   138 1.05
sigma_c        0.16    0.01  0.13    0.02    0.06    0.13    0.22    0.48   232 1.02
sigma_d        0.00    0.00  0.00    0.00    0.00    0.00    0.00    0.01   179 1.01
sigma_y        0.88    0.00  0.02    0.84    0.86    0.87    0.89    0.91   633 1.01

earnings.tar.gz
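This issue leans on r_hat to judge convergence. As a reference point, here is the classic split-R̂ statistic of Gelman et al. in NumPy for one scalar quantity; Stan, Pyro and NumPyro may compute refined variants (for example rank-normalized R̂), so values need not match their output exactly.

```python
import numpy as np


def split_rhat(chains):
    """Split potential scale reduction factor for one scalar quantity.

    chains: array of shape (num_chains, num_draws).
    """
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # split each chain into two halves so within-chain drift also shows up
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()
    between = n * split.mean(axis=1).var(ddof=1)
    var_plus = (n - 1) / n * within + between / n
    return np.sqrt(var_plus / within)
```

Values close to 1 indicate the chains agree; the Rhat column in the Stan table above is this kind of diagnostic.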

Enable jit compilation in minipyro

JIT is currently not working as expected in the minipyro example: the jitted function within step is recompiled each time step is called. If we instead pre-compile and cache the function, it returns static results.
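One plausible reading of "static results" (an assumption, not a diagnosis of minipyro) is that the cached function closes over values that change between steps: a jitted function reads closed-over Python values only while tracing, so later changes are ignored. Passing such values as arguments keeps them dynamic.

```python
import jax
import jax.numpy as jnp

scale = 1.0


@jax.jit
def f_closure(x):
    # `scale` is read from the enclosing scope only while tracing, then baked in
    return x * scale


@jax.jit
def f_arg(x, scale):
    # passing the value as an argument keeps it dynamic
    return x * scale


first = f_closure(jnp.array(1.0))
scale = 2.0
second = f_closure(jnp.array(1.0))  # cache hit: still computed with scale == 1.0
third = f_arg(jnp.array(1.0), scale)
```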

Defer implementation `logpdf`, `pdf` to `jax.scipy.stats`

Currently, numpyro's distributions are based on scipy's implementation. The approach is a bit different from (though simpler than) the jax.scipy.stats module. It would be nice to maintain only the frozen wrapper for the purpose of Pyro modelling and rely on the upstream implementation of logpdf.
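A sketch of such a thin frozen wrapper: parameters and sampling live in the class, while logpdf is delegated to jax.scipy.stats.norm.logpdf. The class name and method set are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class Normal:
    """Hypothetical thin wrapper: sampling lives here, logpdf is delegated upstream."""

    def __init__(self, loc=0.0, scale=1.0):
        self.loc, self.scale = loc, scale
        self.batch_shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))

    def sample(self, rng_key, sample_shape=()):
        shape = tuple(sample_shape) + self.batch_shape
        return self.loc + self.scale * jax.random.normal(rng_key, shape)

    def logpdf(self, value):
        return norm.logpdf(value, loc=self.loc, scale=self.scale)
```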

Support compatibility module in numpyro to support Pyro API

As discussed in pyro-ppl/pyro#1790, to support models written in Pyro, we can have a compatibility wrapper around distributions and inference algorithms.

I think we may have to think about some of the interface issues here, because the efficient way of doing things in Pyro (using classes, storing state) may not be ideal in Numpyro. Some of this can be abstracted within the compatibility module, but other things, like user-written for loops, may need to be rewritten in JAX to make them more efficient.

Some of the questions that we will need to think about (from pyro-ppl/pyro#1790), and my initial thoughts. I might expand this list as I work more on this, but feel free to add your thoughts:

  • There is no global param store, so to get something like pyro.get_param_store to work, we'll need to simply cache the optimizer state somewhere towards the end of svi training. A more faithful implementation will unfortunately need to copy arrays every time the optimizer state is updated, and that will be too expensive. It might be best to not provide a compatibility wrapper for this, at least to begin with.
  • the loop for ..: svi.step() will be suboptimal, and should ideally be replaced by lax.fori_loop when users are on the JAX backend. This could be an order of magnitude faster! - I think the solution here might be to absorb the for loop within SVI's API so that Numpyro and Pyro can handle it differently.
  • There is no pyro.module, because the stax NNs don't have any parameters that need to be registered. All parameters need to be cycled back and forth with the update function. I think it should be okay to skip pyro.module as well since we don't need to register any NN parameters.

Wrap distributions in scipy.stats for use with JAX

To make the internals of the sampler and logpdf methods visible to the tracer and compatible with jax, we need to:

  • Write the operations in terms of jax.numpy operations.
  • Use JAX's counter-based random number generator (via PRNGKey), not the default used by scipy distributions, which is numpy's globally mutable mtrand. Currently, the PRNGKey is passed in through the random_state kwarg to the distribution's _rvs method.
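The difference from numpy's global mtrand is that JAX randomness is an explicit, functional key: the same key always yields the same draws, and fresh randomness comes from splitting keys.

```python
from jax import random

key = random.PRNGKey(0)
key, subkey_a = random.split(key)
key, subkey_b = random.split(key)

draw_1 = random.normal(subkey_a, (3,))
draw_2 = random.normal(subkey_a, (3,))  # same key, so exactly the same numbers
draw_3 = random.normal(subkey_b, (3,))  # fresh key, fresh numbers
```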

Currently, only the normal distribution is wrapped. We can similarly wrap the following distributions (we may need to wrap over the samplers / rewrite the logic using jax operations):

Continuous

Discrete

Iterative NUTS consumes more memory than recursive in big models

#54 implements both recursive and iterative NUTS. It has been shown that the iterative method outperforms the recursive method in terms of overhead. However, iterative NUTS currently consumes a bit more memory than recursive NUTS, as demonstrated in this gist.

Current benchmark (for tree depth = 10)

  • latent dim = 10

    • iterative: 791 µs, 233.23 MiB
    • recursive: 3.71 s, 232.61 MiB
  • latent dim = 100

    • iterative: 1.34 ms, 297.91 MiB
    • recursive: 3.76 s, 283.20 MiB
  • latent dim = 1000

    • iterative: 8.46 ms, 503.02 MiB!
    • recursive: 3.79 s, 380.55 MiB

Batching for custom transforms

Batching is currently not implemented for custom_transforms like xlogy and xlog1py so functions using these primitives will fail with vmap.
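If xlogy is written in terms of ordinary jax.numpy operations, batching rules come for free and vmap works. The sketch uses the "double where" pattern so the unused branch never evaluates log(0); recent JAX also provides jax.scipy.special.xlogy.

```python
import jax
import jax.numpy as jnp


def xlogy(x, y):
    # x * log(y), defined as 0 when x == 0 (even if y == 0)
    x_is_zero = x == 0
    safe_y = jnp.where(x_is_zero, 1.0, y)  # avoid log(0) in the unused branch
    return jnp.where(x_is_zero, 0.0, x * jnp.log(safe_y))


batched = jax.vmap(xlogy)(jnp.array([0.0, 2.0, 1.0]), jnp.array([0.0, jnp.e, 1.0]))
```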

Utilities for MCMC

This issue tracks tasks for MCMC. Here are the ingredients necessary to be able to benchmark. Feel free to add yourself to the items you are interested in.

  • Port dual_averaging
  • Port welford_covariance
  • Port velocity_verlet @fehiepsi
  • Port find_reasonable_step_size @fehiepsi
  • Port WarmupAdapter @fehiepsi
  • Port build_tree @fehiepsi
  • Port sample
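A sketch of a pytree-aware velocity_verlet (leapfrog) step, assuming an identity mass matrix so the kinetic energy is 0.5 * |p|^2. It illustrates the integrator on the list, not the ported implementation; the harmonic-oscillator potential is only for checking energy conservation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map


def velocity_verlet(potential_fn):
    """One leapfrog step for H(q, p) = U(q) + 0.5 * |p|^2 (identity mass matrix assumed)."""
    grad_fn = jax.grad(potential_fn)

    def update(q, p, step_size):
        # half step for momentum, full step for position, half step for momentum
        p = tree_map(lambda p, g: p - 0.5 * step_size * g, p, grad_fn(q))
        q = tree_map(lambda q, p: q + step_size * p, q, p)
        p = tree_map(lambda p, g: p - 0.5 * step_size * g, p, grad_fn(q))
        return q, p

    return update


def potential_fn(q):
    return 0.5 * jnp.sum(q["x"] ** 2)  # harmonic oscillator


def energy(q, p):
    return potential_fn(q) + 0.5 * jnp.sum(p["x"] ** 2)


step = jax.jit(velocity_verlet(potential_fn))
q, p = {"x": jnp.array([1.0])}, {"x": jnp.array([0.0])}
for _ in range(100):
    q, p = step(q, p, 0.1)
```

Leapfrog is symplectic, so the total energy oscillates near its initial value of 0.5 instead of drifting, which is what makes it suitable for HMC.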

Use upstream samplers

There are various samplers available upstream in the latest version of JAX. Let's use these samplers instead.
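For example, current jax.random provides gamma and beta samplers (among others, such as dirichlet) that could replace hand-written ones like standard_gamma. The parameter values below are illustrative.

```python
import jax

key_gamma, key_beta = jax.random.split(jax.random.PRNGKey(0))
gamma_draws = jax.random.gamma(key_gamma, 2.5, shape=(10000,))  # Gamma(2.5, 1)
beta_draws = jax.random.beta(key_beta, 2.0, 5.0, shape=(10000,))  # Beta(2, 5)
```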

Use scipy samplers directly for discrete distributions

Due to issues like google/jax#480, it makes sense to use discrete samplers directly from scipy and then transfer the results back to the device using jax.device_put(). I checked that this is often an order of magnitude faster on CPU, and it should be safe since the samplers aren't reparametrized. Once we have a JAX-native multinomial distribution, we can switch to that.
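A sketch of the host-side pattern, substituting NumPy's Generator.multinomial for a scipy sampler. multinomial_numpy is a hypothetical helper; deriving the NumPy seed from a JAX key keeps results reproducible under the JAX PRNG.

```python
import jax
import numpy as np


def multinomial_numpy(rng_key, n, probs, num_samples):
    """Hypothetical helper: sample multinomial counts on the host, then move to device."""
    # derive a NumPy seed from the JAX key so results stay reproducible
    seed = int(jax.random.randint(rng_key, (), 0, 2 ** 31 - 1))
    rng = np.random.default_rng(seed)
    counts = rng.multinomial(n, probs, size=num_samples).astype(np.int32)
    return jax.device_put(counts)
```

Since no gradients flow through these draws, moving the sampling to the host does not affect reparametrized inference.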
