Traceback (most recent call last):
File ".../python/bug_report.py", line 74, in <module>
loss, grads = loss_fn_w_grad(
File ".../python/bug_report.py", line 56, in loss_fn
output = batched_model(
File ".../python/bug_report.py", line 44, in __call__
return self.output_layer(opt_2st_vec(t))
File ".../python/bug_report.py", line 22, in opt_2st_vec
solution = optx.root_find(obj, solver, x0)
File ".../venv/lib/python3.11/site-packages/optimistix/_root_find.py", line 227, in root_find
return iterative_solve(
File ".../venv/lib/python3.11/site-packages/optimistix/_iterate.py", line 346, in iterative_solve
) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
File ".../venv/lib/python3.11/site-packages/optimistix/_adjoint.py", line 148, in apply
return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
File ".../venv/lib/python3.11/site-packages/optimistix/_ad.py", line 72, in implicit_jvp
root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: pytree does not match out_structure
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File ".../python/bug_report.py", line 74, in <module>
loss, grads = loss_fn_w_grad(
^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/_ad.py", line 79, in __call__
return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 413, in _vprim_transpose
return transpose(cts, *inputs)
^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 211, in _wrapper
cts = rule(inputs, cts_out)
^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 272, in _linear_solve_transpose
cts_vector, _, _ = eqxi.filter_primitive_bind(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 264, in filter_primitive_bind
flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 299, in batch_rule
out = _vprim_p.bind(
^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 337, in _vprim_abstract_eval
outs = abstract_eval(*inputs, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 147, in _wrapper
out = rule(*args)
^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 115, in _linear_solve_abstract_eval
out = eqx.filter_eval_shape(
^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 86, in _linear_solve_impl
out = solver.compute(state, vector, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 632, in compute
solution, result, _ = solver.compute(state, vector, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solver/lu.py", line 62, in compute
vector = ravel_vector(vector, packed_structures)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../venv/lib/python3.11/site-packages/lineax/_solver/misc.py", line 84, in ravel_vector
raise ValueError("pytree does not match out_structure")
ValueError: pytree does not match out_structure
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
import equinox as eqx
import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx
# Run all JAX computations in float64 — the Newton solve below uses tight
# tolerances (1e-8) that are unreliable in float32.
jax.config.update("jax_enable_x64", True)
# Toggle: True -> batch the model via jax.vmap; False -> explicit Python loop
# in loss_fn. The reported crash only occurs on the vmap path.
VMAP = True
def rf(x, args, g):
c = 1 - x[0] - x[1]
f = -x[0] * jnp.exp(-g) - x[1]
return f, c
def opt_2st_vec(g):
    """Solve ``rf(x, ·, g) = 0`` with Newton and return the second root component.

    Args:
        g: array of one element; squeezed to a scalar before the solve.

    Returns:
        Shape-``(1,)`` array holding ``solution.value[1]``.

    NOTE(review): the initial guess is a tuple of plain Python floats (weak
    types); this pytree is part of the reported repro — the transpose of the
    implicit-function linear solve raises "pytree does not match
    out_structure" under vmap + grad. Do not "clean it up" without re-checking
    the bug still reproduces.
    """
    initial_guess = (1 / 2, 1 / 2)
    residual = eqx.Partial(rf, g=g.squeeze())
    newton = optx.Newton(atol=1e-8, rtol=1e-8)
    sol = optx.root_find(residual, newton, initial_guess)
    return jnp.expand_dims(sol.value[1], 0)
class Model(eqx.Module):
    """Linear -> implicit root-find -> linear.

    The single hidden activation parameterizes an inner Newton solve
    (``opt_2st_vec``); gradients flow through it via optimistix's implicit
    differentiation.
    """

    input_layer: eqx.nn.Linear
    output_layer: eqx.nn.Linear

    def __init__(self, n_inputs, key):
        # One hidden unit, no bias: its scalar output is the solver parameter g.
        self.input_layer = eqx.nn.Linear(
            in_features=n_inputs, out_features=1, use_bias=False, key=key
        )
        # Maps the shape-(1,) root back to a shape-(1,) prediction.
        self.output_layer = eqx.nn.Linear(
            in_features=1, out_features=1, use_bias=True, key=key
        )

    def __call__(self, inputs):
        hidden = self.input_layer(inputs)
        # Implicit layer: differentiating through this is what crashes.
        return self.output_layer(opt_2st_vec(hidden))
def loss_fn(
    params,
    static,
    inputs_folding,
    target,
):
    """Mean absolute error of the recombined model over a batch.

    Args:
        params: array leaves of the model (from ``eqx.partition``).
        static: non-array remainder of the model.
        inputs_folding: batch of model inputs, shape ``(batch, n_inputs)``.
        target: targets, shape ``(batch,)``.

    Returns:
        Scalar MAE between ``target`` and the model's first output column.
    """
    model = eqx.combine(params, static)
    # The module-level VMAP flag selects between the vmap'd forward pass
    # (which triggers the reported error) and a plain Python loop.
    if VMAP:
        output = jax.vmap(model)(inputs_folding)
    else:
        output = jnp.array([model(row) for row in inputs_folding])
    return jnp.mean(jnp.abs(target - output[:, 0]))
# Repro driver: build data and model from a fixed seed, then take a gradient.
# The same key is deliberately reused everywhere, matching the original report.
key = jr.PRNGKey(0)
inputs = jr.uniform(key, (128, 10))
target = jr.uniform(key, (128,))
model = Model(inputs.shape[1], key)
params, static = eqx.partition(model, eqx.is_array)
loss_fn_w_grad = eqx.filter_value_and_grad(loss_fn)
# This call raises ValueError("pytree does not match out_structure") when
# VMAP is True; see the traceback above.
loss, grads = loss_fn_w_grad(params, static, inputs, target)
equinox==0.11.3
jax==0.4.25
jaxlib==0.4.25
lineax==0.0.4
optimistix==0.0.6