Hi there!
I'm a student working on implementing a Variational Autoencoder using this library, and I came across a bug where a function transformed with `ppl.log_prob` produces wrong results when the function's jaxpr has more than one equation, which happens, for instance, when `jit` or `nest` are used, or when the `name` argument is passed to `ppl.random_variable`.
For instance, consider the following random variables:
```python
def f(rng):
  eps = ppl.random_variable(tfd.Normal(0, 1))(rng)
  y = eps * 2
  return y

def g(rng):
  eps = ppl.random_variable(tfd.Normal(0, 1), name='name')(rng)
  y = eps * 2
  return y
```
Here, the function `f` produces a jaxpr with only one equation, while `g` produces one with two equations (one that takes in `rng` and outputs a sample from `Normal(0, 1)`, and another that multiplies that sample by 2).
When transformed by `ppl.log_prob`, however, these functions produce different outputs:

```python
ppl.log_prob(f)(0.0)
# -1.6120857
ppl.log_prob(g)(0.0)
# -2.305233
```
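For reference, the two numbers above match the following arithmetic (a quick sanity check with plain `math`, independent of oryx):

```python
import math

# log-density of Normal(0, 1) at eps = 0: log(1/sqrt(2*pi))
base = -0.5 * math.log(2 * math.pi)
# inverse log-det Jacobian of y = 2 * eps: log(1/2)
ildj = math.log(0.5)

print(base + ildj)      # correct result, matches f: about -1.6120857
print(base + 2 * ildj)  # ildj counted twice, matches g: about -2.305233
```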
Here, the function `f` produces the correct result of `log(1/2) + log(1/sqrt(2*pi))`, but the function `g` adds the transformation term twice, producing `log(1/2) + log(1/2) + log(1/sqrt(2*pi))`.
Since this problem only happens for `log_prob`, but not for `inverse_and_ildj`, I believe the issue is in the `reducer` function shown below:
```python
def reducer(env, eqn, curr_log_prob, new_log_prob):
  if (isinstance(curr_log_prob, FailedLogProb)
      or isinstance(new_log_prob, FailedLogProb)):
    # If `curr_log_prob` is `None` that means we were unable to compute
    # a log_prob elsewhere, so the propagate failed.
    return failed_log_prob
  if eqn.primitive in log_prob_registry and new_log_prob is None:
    # We are unable to compute a log_prob for this primitive.
    return failed_log_prob
  if new_log_prob is not None:
    cells = [env.read(var) for var in eqn.outvars]
    ildjs = sum([cell.ildj.sum() for cell in cells if cell.top()])
    return curr_log_prob + new_log_prob + ildjs
  return curr_log_prob
```
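To make the suspected double counting concrete, here is a toy sketch of the accumulation I believe is happening (plain Python, not oryx's actual propagate machinery; `Cell` and the intermediate state values are stand-ins):

```python
import math

class Cell:
    """Stand-in for a propagate cell that stores an accumulated ildj."""
    def __init__(self, ildj):
        self.ildj = ildj

eps = Cell(math.log(0.5))            # set when equation 2 (y = 2 * eps) runs first
base = -0.5 * math.log(2 * math.pi)  # log_prob of the Normal(0, 1) sample at 0.0

# Nested propagate inside `nest`: `reducer` reads eps as an *outvar* of the
# random_variable equation and adds its ildj to that equation's state.
eqn1_state = base + eps.ildj         # log(1/sqrt(2*pi)) + log(1/2)

# Outer aggregation then adds eps.ildj once more, producing the buggy value.
total = eqn1_state + eps.ildj
print(total)  # about -2.305233 instead of the correct -1.6120857
```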
What I believe is happening is that the equations are processed in reverse order, and the outvar ildjs are counted twice. For example, for the function `g` above, we have two equations:
- Equation 1: `eps = nest(random_variable(rng))`
- Equation 2: `y = 2 * eps`
First, equation 2 is processed and `eps.ildj` assumes the correct value of `log(1/2)`. Then, equation 1 is processed; it has a `nest` primitive that triggers another call to `propagate` for the `random_variable` primitive, which then calls `reducer`. However, since `eps.ildj` already has the value `log(1/2)`, it gets added to the state of the equation, which becomes `log(1/2) + log(1/sqrt(2*pi))`. I believe that, in this step, the correct value would be only `log(1/sqrt(2*pi))`.
After equation 1 is processed, the results are aggregated using the `reducer` function. However, while the state assigned to equation 1 is already `log(1/2) + log(1/sqrt(2*pi))`, the value `eps.ildj` is added again in this step, which leads to the wrong result.
If I replace the loop over `eqn.outvars` with one over `eqn.invars`, the issue in this case is solved, since in the nested calls to `reducer` the `cell.ildj` will still be undefined. I'm not totally sure this won't lead to problems in other cases, though. In any case, I've opened a pull request that changes this line and adds a small test case corresponding to this issue, in case this is indeed what's causing the bug.
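Sketching the arithmetic of the proposed fix in the same toy terms (assuming, per the description above, that the nested call now reads the invar `rng`, which carries no ildj, so the ildj term is accumulated only once; where exactly that single contribution enters depends on propagate's bookkeeping):

```python
import math

base = -0.5 * math.log(2 * math.pi)  # log N(0 | 0, 1) at the observed value
eps_ildj = math.log(0.5)             # ildj attached to eps by equation 2

# Nested call reads the invar rng, whose cell has no ildj yet, so the
# random_variable equation contributes just the base log_prob.
eqn1_state = base

# The ildj term then enters exactly once during the outer aggregation.
total = eqn1_state + eps_ildj
print(total)  # about -1.6120857, the expected result
```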