pyro-ppl / numpyro
Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
Home Page: https://num.pyro.ai
License: Apache License 2.0
There are various samplers available upstream in the latest version of JAX. Let's use these samplers instead.
The error happens because the new version of JAX requires a specific pattern for using jit's static_argnums. See jax-ml/jax#595 for more context. The failed functions include:
In addition, I think that we can simplify the implementation of these functions by using the custom_transforms decorator, as in cumsum, cumprod.
cc @neerajprad
Let's host the jax/jaxlib wheels for CI so that we can track JAX master directly for development.
Related issue - #41.
It will be nice to provide an easy interface for doing predictions (ideally a vectorized version of TracePredictive), so that users do not need to write this code themselves, e.g. for the baseball example. More detailed discussion on the design is needed, and hopefully we'll have more insights on this after attempting it in Pyro pyro-ppl/pyro#1725.
Currently, cumsum/cumprod do not have JVP rules in JAX yet, and these operators are necessary for the simplex constraint. We should make custom primitives to support them.
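For reference, a minimal sketch of attaching a JVP rule to cumsum, using jax.custom_jvp (the modern replacement for custom_transforms; purely illustrative, since jnp.cumsum is differentiable in current JAX):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def cumsum(x):
    return jnp.cumsum(x)

@cumsum.defjvp
def cumsum_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    # cumsum is linear, so its JVP is just cumsum applied to the tangent
    return cumsum(x), jnp.cumsum(x_dot)

print(jax.jacfwd(cumsum)(jnp.array([1., 2., 3.])))  # lower-triangular ones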
Refer to #44. We can implement Dirichlet sample and grad methods similar to the ones we have in PyTorch.
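For context, the standard construction (a minimal sketch mirroring the Gamma-based approach used in PyTorch; jax.random.gamma is assumed available) normalizes independent Gamma draws:

import jax.numpy as jnp
from jax import random

def dirichlet_sample(key, alpha):
    # a Dirichlet(alpha) draw is a vector of independent
    # Gamma(alpha_i, 1) draws normalized to sum to one
    gammas = random.gamma(key, alpha)
    return gammas / gammas.sum(-1, keepdims=True)

print(dirichlet_sample(random.PRNGKey(0), jnp.array([1., 2., 3.])))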
Based on recent examples from @neerajprad of translation_rule for custom primitive functions, it seems that we can make standard_gamma jittable. Let's update standard_gamma with a translation_rule.
From discussion in #70:
- z and do the flattening. The latter will be better for debuggability, and I'm not sure if there is a significant overhead under JIT.
- hmc_kernel can have a separate get_kernel method that would return either an HMC or NUTS kernel depending on args. This can be called by the user or from inside mcmc.
- Use lax.scan to only scan z by default. @fehiepsi
- Change lax.cond to lax.while.
Currently, only discrete distributions support argcheck. We should do it for continuous distributions too.
As discussed in pyro-ppl/pyro#1790, to support models written in Pyro, we can have a compatibility
wrapper around distributions and inference algorithms.
I think we may have to think about some of the interface issues here, because the efficient way of doing things in Pyro (using classes, storing state) may not be ideal in Numpyro. Some of this can be abstracted within this compatibility module, but other things like user for-loops may need to be rewritten in JAX to make them more efficient.
Here are some of the questions that we will need to think about (from pyro-ppl/pyro#1790), along with my initial thoughts. I might expand this list as I work more on this, but feel free to add your thoughts:
- For pyro.get_param_store to work, we'll need to simply cache the optimizer state somewhere towards the end of SVI training. A more faithful implementation will unfortunately need to copy arrays every time the optimizer state is updated, and that will be too expensive. It might be best not to provide a compatibility wrapper for this, at least to begin with.
- for ..: svi.step() will be suboptimal, and should ideally be replaced by lax.fori_loop when users are on the JAX backend. This could be an order of magnitude faster! I think the solution here might be to absorb the for loop within SVI's API so that Numpyro and Pyro can handle it differently.
- pyro.module: the stax NNs don't have any parameters that need to be registered; all parameters need to be cycled back and forth with the update function. I think it should be okay to skip pyro.module since we don't need to register any NN parameters.
This is to track minor API issues as we notice them, that will be nice to clean up. This is not super urgent - these are known things that can be cleaned up before release.
- Initialize from Uniform(-2, 2) directly on unconstrained parameters, like in Stan.
- Change initialize_model(rng, model, model_args, model_kwargs) to initialize_model(rng, model, *args, **kwargs).
- If model or potential_fn is not jittable, I think we'll end up throwing an exception, given that sample_kernel has a jit decorator. I think the right approach would be to pass a jit_compile=True flag to hmc so that the user does not need to modify the source code in case their model is not jittable.
To make the internals of the sampler and logpdf methods visible to the tracer and compatible with JAX, we need to:
- Rewrite them using jax.numpy operations.
- Use JAX's functional PRNG (PRNGKey), and not the default used by scipy distributions, which is numpy's globally mutable mtrand. Currently, the PRNGKey is being passed in through the random_state kwarg to the distribution's _rvs method.
Currently, only the normal distribution is wrapped. We can similarly wrap the following distributions (we may need to wrap over the samplers / rewrite the logic using jax operations):
Continuous
Discrete
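For reference, the functional PRNG pattern these wrappers need to thread through looks like this (a minimal sketch):

from jax import random

# keys are explicit values that must be split and never reused,
# instead of numpy's global mutable state
key = random.PRNGKey(0)
key, subkey = random.split(key)
x = random.normal(subkey, shape=(3,))  # e.g. the key passed as random_state to _rvs
print(x)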
This issue tracks tasks for MCMC. Here are the necessary ingredients to be able to benchmark. Feel free to add yourself to whatever interests you.
This issue is observed while running the LGT example. It is resolved if we replace refresh=True with refresh=False in
t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=True)
Already Done:
Currently, we only develop and test for float32 arrays. When the functionalities are ready, we should test for float64 arrays too.
Currently, we compile twice: once at init and once at sample, so it would be better if we only compiled once.
One way is to modify tscan with an additional parameter skip to skip the first warmup_steps, and to put warmup_state into a lax.cond that decides whether it will be updated or not (depending on whether we are in the warmup phase or the sampling phase).
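The lax.cond part of that idea might look like this minimal sketch (the integer counter is a stand-in for the real warmup_state, purely illustrative):

import jax.numpy as jnp
import jax.lax as lax

def step(i, warmup_state, warmup_steps=500):
    # update warmup_state only while we are still in the warmup phase
    return lax.cond(i < warmup_steps,
                    lambda s: s + 1,  # warmup phase: adapt
                    lambda s: s,      # sampling phase: leave unchanged
                    warmup_state)

print(step(10, jnp.array(0)))   # 1: still adapting
print(step(600, jnp.array(0)))  # 0: frozen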
This will be important for inference on models that contain distributions with constrained support. I think PyTorch's transforms module is pretty nice, and we can follow the same pattern here.
cc. @fehiepsi
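As a strawman, a PyTorch-style transform in JAX could look like this minimal sketch (class and method names are illustrative, not an existing NumPyro API):

import jax.numpy as jnp

class ExpTransform:
    """Bijection from the real line to the positive reals."""

    def __call__(self, x):
        return jnp.exp(x)

    def inv(self, y):
        return jnp.log(y)

    def log_abs_det_jacobian(self, x, y):
        # log |d/dx exp(x)| = x
        return x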
There are a bunch of features available in arviz. We can utilize that great library without adding a dependency by making a utility function which converts our fori_append results to a dictionary arviz_dict. Then in arviz, we just need to use arviz.from_dict to get density and trace plots, and a bunch of new stats such as loo, waic, ...
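For example, something along these lines (a minimal sketch; the random arrays stand in for collected samples, and arviz is assumed installed):

import numpy as np
import arviz as az

# arrange draws as (num_chains, num_samples, ...) arrays keyed by site name
arviz_dict = {"mu": np.random.randn(2, 500),
              "tau": np.exp(np.random.randn(2, 500))}
idata = az.from_dict(posterior=arviz_dict)
print(az.summary(idata))  # means, sd, ess, r_hat, ...
az.plot_trace(idata)      # density + trace plots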
This will allow us to create new transform classes out of existing ones, e.g. LogitTransform as SigmoidTransform().inv.
It will also be nice if client code can just take in a list of transforms and call .inv on them without having any knowledge of whether the original or the inverted transform was passed in.
e.g.
Numpyro
>>> t = SigmoidTransform()
>>> t.inv
<bound method SigmoidTransform.inv of <numpyro.distributions.constraints.SigmoidTransform object at 0x7f92680ce710>>
>>> t.inv.inv
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-10-15abeba100ea> in <module>
----> 1 t.inv.inv
AttributeError: 'function' object has no attribute 'inv'
Pytorch
>>> t = SigmoidTransform()
>>> t.inv
_InverseTransform()
>>> t.inv.inv
SigmoidTransform()
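For reference, the PyTorch behavior above comes from a small wrapper class; a minimal sketch of the same pattern (names mirror PyTorch's, not an existing NumPyro class):

class Transform:
    @property
    def inv(self):
        # return a view whose own inverse is this transform
        return _InverseTransform(self)

class _InverseTransform(Transform):
    def __init__(self, base):
        self._base = base

    @property
    def inv(self):
        return self._base  # inv of inv recovers the original

With this, t.inv.inv is t again, and client code never needs to know which direction it was handed.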
As it stands, unlike other primitives, our tscan implementation is not jittable. It is worth making it work well with jax.jit
for benchmarking purposes.
@stefanwebb pointed out that Pyro's NUTS on the earnings latin square model gives extremely different results from Stan. cc. @jpchen
To debug this, I have tried running the model on Pyro's NUTS and Numpyro's NUTS and both return results which are very far off from Stan with high r_hat values indicating that the procedure hasn't converged. Creating this issue to track progress on investigating this bug / discrepancy.
Some notes:
- Tried using dist.HalfCauchy(1.) (instead of dist.Uniform(0, 100), which is more faithful to the Stan implementation) to see if that helps convergence.
Pyro code:
import csv
from collections import defaultdict

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC

torch.set_default_tensor_type('torch.DoubleTensor')

use_uniform = False

def scale():
    return dist.Uniform(0., 100.) if use_uniform else dist.HalfCauchy(1.)

def model(data):
    eth = data['eth']
    age = data['age']
    x = data['x']
    y = data['y']
    mu_a1 = pyro.sample('mu_a1', dist.Normal(0., 1.))
    mu_a2 = pyro.sample('mu_a2', dist.Normal(0., 1.))
    sigma_a1 = pyro.sample('sigma_a1', scale())
    sigma_a2 = pyro.sample('sigma_a2', scale())
    mu_b1 = pyro.sample('mu_b1', dist.Normal(0., 1.))
    mu_b2 = pyro.sample('mu_b2', dist.Normal(0., 1.))
    sigma_b1 = pyro.sample('sigma_b1', scale())
    sigma_b2 = pyro.sample('sigma_b2', scale())
    mu_c = pyro.sample('mu_c', dist.Normal(0., 1.))
    sigma_c = pyro.sample('sigma_c', scale())
    mu_d = pyro.sample('mu_d', dist.Normal(0., 1.))
    sigma_d = pyro.sample('sigma_d', scale())
    nage = pyro.plate("n_age", 3, dim=-1)
    neth = pyro.plate("neth", 4, dim=-2)
    with neth:
        a1 = pyro.sample('a1', dist.Normal(10 * mu_a1, sigma_a1))
        a2 = pyro.sample('a2', dist.Normal(mu_a2, sigma_a2))
    with nage:
        b1 = pyro.sample('b1', dist.Normal(10 * mu_b1, sigma_b1))
        b2 = pyro.sample('b2', dist.Normal(0.1 * mu_b2, sigma_b2))
    with neth, nage:
        c = pyro.sample('c', dist.Normal(10 * mu_c, sigma_c))
        d = pyro.sample('d', dist.Normal(0.1 * mu_d, sigma_d))
    y_hat = a1[eth].squeeze(-1) + a2[eth].squeeze(-1) * x + b1[age] + b2[age] * x + c[eth, age] + d[eth, age] * x
    sigma_y = pyro.sample('sigma_y', scale())
    with pyro.plate('N', 1059):
        pyro.sample('obs', dist.Normal(y_hat, sigma_y), obs=y)

data = defaultdict(list)
with open('earnings.csv', 'r') as f:
    csv_reader = csv.DictReader(f)
    for row in csv_reader:
        data['x'].append(float(row['x']))
        data['y'].append(float(row['y']))
        data['age'].append(int(row['age']) - 1)
        data['eth'].append(int(row['eth']) - 1)
data['x'] = torch.tensor(data['x'])
data['y'] = torch.tensor(data['y'])
data['age'] = torch.tensor(data['age'], dtype=torch.long)
data['eth'] = torch.tensor(data['eth'], dtype=torch.long)

nuts_kernel = NUTS(model, max_tree_depth=6, jit_compile=True, ignore_jit_warnings=True)
posterior_fully_pooled = MCMC(nuts_kernel,
                              num_samples=500,
                              warmup_steps=500,
                              num_chains=2).run(data)
print(posterior_fully_pooled.marginal(['a1', 'a2', 'b1', 'b2']).diagnostics())
marginals = posterior_fully_pooled.marginal(['a1', 'a2', 'b1', 'b2'])
for k, v in marginals.empirical.items():
    print(k, v.mean)
Numpyro code:
import csv
from collections import defaultdict

import jax.numpy as np
from jax.random import PRNGKey

import numpyro.distributions as dist
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.util import fori_collect

use_uniform = False

def scale():
    return dist.Uniform(0., 100.) if use_uniform else dist.HalfCauchy(1.)

def model(data):
    eth = data['eth']
    age = data['age']
    x = data['x']
    y = data['y']
    mu_a1 = sample('mu_a1', dist.Normal(0., 1.))
    mu_a2 = sample('mu_a2', dist.Normal(0., 1.))
    sigma_a1 = sample('sigma_a1', scale())
    sigma_a2 = sample('sigma_a2', scale())
    mu_b1 = sample('mu_b1', dist.Normal(0., 1.))
    mu_b2 = sample('mu_b2', dist.Normal(0., 1.))
    sigma_b1 = sample('sigma_b1', scale())
    sigma_b2 = sample('sigma_b2', scale())
    mu_c = sample('mu_c', dist.Normal(0., 1.))
    sigma_c = sample('sigma_c', scale())
    mu_d = sample('mu_d', dist.Normal(0., 1.))
    sigma_d = sample('sigma_d', scale())
    a1 = sample('a1', dist.Normal(10 * np.broadcast_to(mu_a1, (4,)), sigma_a1))
    a2 = sample('a2', dist.Normal(np.broadcast_to(mu_a2, (4,)), sigma_a2))
    b1 = sample('b1', dist.Normal(10 * np.broadcast_to(mu_b1, (3,)), sigma_b1))
    b2 = sample('b2', dist.Normal(0.1 * np.broadcast_to(mu_b2, (3,)), sigma_b2))
    c = sample('c', dist.Normal(10 * np.broadcast_to(mu_c, (4, 3)), sigma_c))
    d = sample('d', dist.Normal(0.1 * np.broadcast_to(mu_d, (4, 3)), sigma_d))
    y_hat = a1[eth] + a2[eth] * x + b1[age] + b2[age] * x + c[eth, age] + d[eth, age] * x
    sigma_y = sample('sigma_y', scale())
    sample('obs', dist.Normal(y_hat, sigma_y), obs=y)

data = defaultdict(list)
with open('earnings.csv', 'r') as f:
    csv_reader = csv.DictReader(f)
    for row in csv_reader:
        data['x'].append(float(row['x']))
        data['y'].append(float(row['y']))
        data['age'].append(int(row['age']) - 1)
        data['eth'].append(int(row['eth']) - 1)
data['x'] = np.array(data['x'])
data['y'] = np.array(data['y'])
data['age'] = np.array(data['age']).astype(np.int64)
data['eth'] = np.array(data['eth']).astype(np.int64)

init_params, potential_fn, transform_fn = initialize_model(PRNGKey(0), model, data)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, 2000)
hmc_states = fori_collect(2000, sample_kernel, hmc_state,
                          transform=lambda hmc_state: transform_fn(hmc_state.z))
print(hmc_states)
Stan results
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
a1[1] 5.60 1.57 8.55 -9.18 -1.05 5.24 11.48 22.41 30 1.14
a1[2] 5.44 1.58 8.57 -9.37 -1.17 5.14 11.46 22.07 29 1.14
a1[3] 5.00 1.59 8.60 -9.92 -1.64 4.65 11.04 21.43 29 1.14
a1[4] 5.48 1.57 8.55 -9.31 -1.16 5.14 11.46 22.14 30 1.14
a2[1] 0.07 0.03 0.20 -0.34 -0.06 0.07 0.19 0.42 34 1.06
a2[2] 0.07 0.03 0.20 -0.34 -0.05 0.07 0.19 0.42 33 1.06
a2[3] 0.08 0.03 0.20 -0.33 -0.04 0.08 0.20 0.44 33 1.07
a2[4] 0.07 0.03 0.20 -0.33 -0.05 0.07 0.20 0.43 33 1.06
b1[1] 3.97 1.52 8.68 -14.80 -1.11 3.81 9.56 20.60 33 1.14
b1[2] 2.36 1.52 8.66 -16.69 -2.77 2.25 8.05 18.87 33 1.14
b1[3] 1.78 1.52 8.69 -17.57 -3.48 1.64 7.32 18.27 33 1.13
b2[1] -0.01 0.02 0.16 -0.29 -0.09 -0.02 0.06 0.38 60 1.03
b2[2] 0.02 0.02 0.16 -0.25 -0.06 0.01 0.09 0.40 59 1.03
b2[3] 0.02 0.02 0.16 -0.24 -0.06 0.02 0.10 0.42 59 1.03
c[1,1] -2.43 2.17 7.72 -14.76 -7.90 -3.95 2.13 13.98 13 1.36
c[1,2] -2.36 2.17 7.72 -14.89 -7.85 -3.87 2.21 14.05 13 1.36
c[1,3] -2.46 2.17 7.72 -14.83 -7.98 -4.04 2.13 13.95 13 1.36
c[2,1] -2.37 2.17 7.72 -14.80 -7.87 -3.89 2.23 14.05 13 1.36
c[2,2] -2.51 2.17 7.73 -14.93 -8.04 -4.02 2.11 14.02 13 1.36
c[2,3] -2.38 2.17 7.72 -14.79 -7.89 -3.87 2.24 13.83 13 1.36
c[3,1] -2.41 2.17 7.72 -14.89 -7.89 -3.89 2.23 13.97 13 1.36
c[3,2] -2.44 2.17 7.72 -14.80 -7.94 -3.92 2.19 14.01 13 1.36
c[3,3] -2.43 2.17 7.72 -14.75 -7.92 -3.93 2.19 13.97 13 1.36
c[4,1] -2.44 2.17 7.72 -14.85 -7.94 -3.97 2.17 13.99 13 1.36
c[4,2] -2.38 2.17 7.72 -14.75 -7.86 -3.88 2.19 14.04 13 1.36
c[4,3] -2.42 2.17 7.72 -14.80 -7.91 -3.93 2.19 13.95 13 1.36
d[1,1] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[1,2] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[1,3] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[2,1] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[2,2] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[2,3] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[3,1] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[3,2] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[3,3] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[4,1] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[4,2] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[4,3] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
mu_a1 0.54 0.16 0.85 -0.94 -0.12 0.50 1.13 2.17 30 1.14
mu_a2 0.07 0.03 0.20 -0.34 -0.05 0.07 0.20 0.43 33 1.06
mu_b1 0.25 0.13 0.85 -1.58 -0.27 0.24 0.79 1.87 42 1.11
mu_b2 0.04 0.09 0.98 -1.90 -0.61 0.03 0.68 2.06 125 1.03
mu_c -0.24 0.22 0.77 -1.48 -0.80 -0.40 0.22 1.40 13 1.36
mu_d -0.17 0.21 1.14 -2.22 -0.98 -0.22 0.63 2.09 29 1.11
sigma_a1 0.96 0.11 1.92 0.02 0.12 0.34 1.01 5.30 319 1.01
sigma_a2 0.01 0.00 0.03 0.00 0.00 0.01 0.02 0.08 328 1.01
sigma_b1 4.07 0.28 6.07 0.14 1.14 2.19 4.43 20.92 487 1.02
sigma_b2 0.12 0.03 0.30 0.00 0.02 0.04 0.09 0.87 138 1.05
sigma_c 0.16 0.01 0.13 0.02 0.06 0.13 0.22 0.48 232 1.02
sigma_d 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.01 179 1.01
sigma_y 0.88 0.00 0.02 0.84 0.86 0.87 0.89 0.91 633 1.01
Multivariate distributions do not have loc and scale parameters, and each should stand on its own. So it would be nice to make a common interface for them.
We can thin out our distribution wrappers and let numpy do the dispatching; e.g., np.log or np.exp will be correctly dispatched by numpy, even though they incur a small dispatch cost. As with discrete.py, we can move all the lightly wrapped continuous distributions into continuous.py.
Often, the arguments to the model and guide are different, and it is cumbersome to pass the same set of args to both and have some be unused. This suggests a further refactoring of the SVI class such that:
- model_args and guide_args can be specified separately, and are jittable. These can potentially change during the course of training, e.g. data batches, or other dynamic arguments.
- kwargs will contain static arguments common to both and can simply be specified during the initial call to svi rather than piped in through the init and update functions. These can be functions passed to the model or guide as well as static arrays.
Even with the disable_jit flag, it is inconvenient to debug our existing code, since we have many utilities that use these lax primitives, which need to be manually rewritten each time we want to debug something. Let's build light wrappers around these to make debugging easy.
We should also move this upstream to JAX if we find it useful.
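One possible shape for such a wrapper (a minimal sketch; the debug flag and the Python-loop fallback are hypothetical):

import jax.lax as lax

def fori_loop(lower, upper, body_fun, init_val, debug=False):
    # in debug mode, run a plain Python loop so prints and breakpoints
    # work; otherwise defer to the traced lax primitive
    if debug:
        val = init_val
        for i in range(lower, upper):
            val = body_fun(i, val)
        return val
    return lax.fori_loop(lower, upper, body_fun, init_val)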
For convenience, we are currently using many JAX functions that haven't been exposed publicly, like _promote_args. Let us reimplement (and customize as needed) these functions in util.py, so that we don't incur a dependency on private functions that might be removed, or whose API might change without notice.
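A local replacement could be as small as this sketch (the name and exact semantics are assumptions; JAX's private helper similarly handles dtype promotion and broadcasting):

import jax.numpy as jnp

def promote_args(*args):
    # promote to a common dtype, then broadcast to a common shape
    dtype = jnp.result_type(*args)
    return jnp.broadcast_arrays(*(jnp.asarray(a, dtype=dtype) for a in args))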
#54 implements both recursive and iterative NUTS. It has been shown that the iterative method outperforms the recursive method in terms of reducing overhead. However, currently iterative NUTS consumes a bit more memory than recursive NUTS, as demonstrated in this gist (benchmarked at latent dim = 10, 100, and 1000).
Right now we don't have any diagnostics to measure the quality of mixing. It will be nice to either implement diagnostics like we have done in Pyro, or perhaps just use the arviz library for this (it is preferable to offload this to arviz if possible).
I would like to suggest a way to use support information for default transforms and incorporate it in potential_fn:
tscan(..., transform=lambda latent: z_transform_fn(latent.z))
JAX is cool, but docs/examples of its features are still lacking. Let's share tips on using JAX here. :)
A pytree is a dict/tuple/array or a combination of them. Many JAX transformations work with pytree arguments. For example, grad(f)(x) will return a dict of grad arrays if x is a dict of arrays. There are also many utilities available in jax.tree_util, such as:
- tree_map: as in the example of lax.scan below,
- tree_multimap: as in the implementation of velocity_verlet.
We can do the following
import jax.lax as lax
import jax.numpy as np
from jax.tree_util import tree_map

def f(trace, i):
    # create a new trace given the current trace; lax.scan threads the
    # first return value as the carry and stacks the second
    next_trace = tree_map(lambda a: a + 1, trace)
    return next_trace, next_trace

initial_trace = {"a": np.array([1., 2.])}
num_samples = 10
_, traces = lax.scan(f, initial_trace, np.arange(num_samples))
print(traces)
which returns
{'a': array([[ 2., 3.],
[ 3., 4.],
[ 4., 5.],
[ 5., 6.],
[ 6., 7.],
[ 7., 8.],
[ 8., 9.],
[ 9., 10.],
[10., 11.],
[11., 12.]], dtype=float32)}
Hopefully we can jit the whole HMC loop using the above pattern.
Starting with JAX version 0.1.28, as @neerajprad observed, tests for Binomial/Multinomial fail. I am not sure of the reason, because outside of the tests, these distributions work well in notebooks.
This will be needed for certain distributions, and will be a good exercise in implementing primitive operations in jax.
Since we use float32 values by default instead of float64 like scipy.stats does, having support for logits in the multinomial, bernoulli, and binomial distributions is important, and will make operations that manipulate probabilities more numerically stable.
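For instance, a Bernoulli log-pmf written directly in terms of logits (a minimal sketch, using log sigmoid(x) = -logaddexp(0, -x)) never forms the probability explicitly:

import jax.numpy as jnp

def bernoulli_logpmf(value, logits):
    # log p = log sigmoid(logits), log (1 - p) = log sigmoid(-logits)
    return (value * -jnp.logaddexp(0., -logits)
            + (1 - value) * -jnp.logaddexp(0., logits))

print(bernoulli_logpmf(1., 30.))  # ~0.0, stable even for extreme logits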
Currently, the logpmf method in rv_discrete does a lot of args checking and substitutes NaN or -Inf values for the pmf of out-of-support values. These operations are opaque to the JAX tracer and I think should simply be removed, or done in a way that doesn't get in the way of operations like grad.
Currently, numpyro's distributions are based on scipy's implementation. The approach is a bit different from (though simpler than) the jax.scipy.stats module. It would be nice if we only maintained the frozen wrapper for the purpose of Pyro-style modelling and relied on the upstream implementation of logpdf.
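For reference, the upstream implementations can already be called directly, e.g.:

import jax.numpy as np
from jax.scipy.stats import norm

print(norm.logpdf(np.array([0., 1.]), loc=0., scale=1.))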
For many models that seem to work well in Pyro with the default values for trajectory_length and step_size (see test_mcmc.py), the behavior in numpyro can be finicky, in that either HMC / NUTS is too slow or we get wrong results, despite our tests running for many more steps than in Pyro.
Batching is currently not implemented for custom_transforms like xlogy and xlog1py, so functions using these primitives will fail with vmap.
With #53, we can turn off distribution args checking with a context manager called validation_disabled(). Disabling is required because the args checking doesn't work with JIT (which isn't itself an issue, since we shouldn't need to JIT these extra debugging-related features anyway).
The current context manager is a hack because it is a global flag, and we might call a distribution's logpdf method after the distribution is initialized, which may result in the flag being applied incorrectly. Given scipy's inheritance pattern and use of frozen vs non-frozen instances, I couldn't see an easy way of storing this flag in the constructor (like we do in PyTorch), but it is worth addressing this as we clean up the interface.
We can start with a basic implementation that just does distributions broadcasting under the hood and stores the plate information at sample sites for any inference algorithms.
Due to issues faced in #51, #50 and #81, it seems that we would be better off not inheriting from scipy and having to work around their class hierarchy which was designed to support its own use cases. Instead, I think it will help if our design was closer to PyTorch, in terms of Pyro <--> Numpyro compatibility (#66), and more generally reuse of functionality like constraint checks and bijective transforms that are useful for HMC and SVI.
I think to begin with we can start by only supporting the sample and log_prob methods and adding more as we go along. It should be relatively straightforward to do so for our existing implementations. In cases where we don't have a JAX sampler available, we could just use scipy's sampler internally (as suggested by @fehiepsi).
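A minimal sketch of what a PyTorch-style distribution could look like in JAX (names and signatures are illustrative, not a settled design):

import jax.numpy as np
from jax import random

class Normal:
    def __init__(self, loc=0., scale=1.):
        self.loc, self.scale = loc, scale

    def sample(self, key, sample_shape=()):
        eps = random.normal(key, sample_shape + np.shape(self.loc))
        return self.loc + self.scale * eps

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        return -0.5 * z ** 2 - np.log(self.scale) - 0.5 * np.log(2 * np.pi)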
The following repro returns different results for different runs:
import numpy as onp
import numpyro.distributions as dist

onp.random.seed(0)
dist.bernoulli(0.5).rvs(size=(100,))
JIT is currently not working as expected in the minipyro example - the jitted function within step is being compiled each time step is called. If we pre-compile and cache the function, it gives static results.
Currently, the SVI class just holds constant properties; it is much like a function, IMO. So I am creating this thread to discuss the possibility of splitting SVI's methods into functions: svi_init (to get the initial svi_state/opt_state) and svi_step (to update the svi_state/opt_state). To avoid repeated arguments such as model, guide, ..., users can use functools.partial, or we can keep the SVI class as a wrapper.
I list here some advantages which I have in mind (see the sketch after this list):
- The training loop can be run with lax.fori_loop (as in this lax example).
- Params can be used directly, whether registered with pyro.param or not. This will be helpful when params come from jax.experimental.stax.
Note that this is just for discussion, not a request to change. We'd better focus on adding more distributions and benchmarking. Some initial tests suggest that the dispatch overhead is large but it scales pretty well.
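To make the idea concrete, here is a minimal sketch (svi_init/svi_step, the toy loss, and plain gradient descent are all stand-ins for the real SVI machinery):

from functools import partial

import jax
import jax.numpy as np

def svi_init(params):
    # the svi_state here is just the parameter pytree
    return params

def svi_step(i, state, loss_fn, lr=0.1):
    grads = jax.grad(loss_fn)(state)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, state, grads)

loss = lambda p: np.sum((p["mu"] - 3.) ** 2)
state = svi_init({"mu": np.array(0.)})
# because svi_step is a pure function of its state, the loop jits away
state = jax.lax.fori_loop(0, 100, partial(svi_step, loss_fn=loss), state)
print(state["mu"])  # ~3.0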
Due to issues like jax-ml/jax#480, it makes sense to use discrete samplers directly from scipy and then transfer the results back to the device using jax.device_put(). I checked that this is often an order of magnitude faster on the CPU, and it should be safe since the samplers aren't reparametrized. Once we have a JAX-native multinomial distribution, we can change to that later.
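Something like this minimal sketch (scipy's multinomial sampler chosen for illustration):

import numpy as onp
import scipy.stats
from jax import device_put

# draw on the host with scipy, then move the result onto the device
probs = onp.array([0.2, 0.3, 0.5])
samples = scipy.stats.multinomial.rvs(n=10, p=probs, size=100)
samples = device_put(samples)
print(samples.shape)  # (100, 3)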
I would have expected the second call to Beta.sample() to be much faster for two reasons: no compilation cost for standard_gamma, and faster samples from the compiled kernel.
This seems to not be the case.
In [4]: %timeit Beta(1.1, 1.1).sample(PRNGKey(1), (1000,))
2.24 s ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [5]: %timeit Beta(1.1, 1.1).sample(PRNGKey(1), (1000,))
2.3 s ± 46.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Currently, we use laxtuple to make HMCState. This, however, makes it impossible to pickle HMCState (unless we transform it to a dict/tuple/namedtuple).
These distributions will be useful for a multilevel model, which I intend to turn into a hierarchical-model example.
(split from #29 to track progress)
Here is a list of various models to benchmark. @neerajprad Please add more if needed.
grad can pass through np.where in JAX. It took a bit of effort to achieve the same behaviour in other frameworks.
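A quick demonstration (minimal sketch):

import jax
import jax.numpy as np

# only the selected branch contributes to the gradient
f = lambda x: np.where(x > 0, x ** 2, 0.)
print(jax.grad(f)(3.))   # 6.0
print(jax.grad(f)(-3.))  # 0.0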