google / jaxopt
Hardware accelerated, batchable and differentiable optimizers in JAX.
Home Page: https://jaxopt.github.io
License: Apache License 2.0
I'm not entirely sure, but I suspect that tree_where assumes that a pytree with a single leaf is an np array, which can lead to errors such as "TypeError: Field elements must be 2- or 3-tuples". In my case, I had a list containing a single NumPy array. I couldn't test my hypothesis directly because I was using tree_where through OSQP, but when I hacked my code so that the primal variable in OSQP had two leaves, the code worked.
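A leaf-wise where can be sketched with jax.tree_util.tree_map. This assumes tree_where(cond, x, y) is meant to select between two pytrees with the same structure; it is an illustration, not a copy of jaxopt's implementation. Mapping over leaves keeps a one-element list a list.

```python
import jax
import jax.numpy as jnp


def tree_where(cond, tree_x, tree_y):
    """Leaf-wise jnp.where over two pytrees with the same structure (sketch)."""
    # tree_map visits each leaf, so containers such as a one-element list
    # are preserved instead of being treated as a single array.
    return jax.tree_util.tree_map(
        lambda x, y: jnp.where(cond, x, y), tree_x, tree_y)


primal = [jnp.array([1.0, 2.0])]   # a list holding a single array
fallback = [jnp.zeros(2)]
out = tree_where(True, primal, fallback)
```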
This will allow the user to override the default behavior (jax.value_and_grad).
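One way such an override could look is sketched below, using a hypothetical solver with a `value_and_grad` keyword that falls back to `jax.value_and_grad(fun)` when not given. The function and argument names are illustrative assumptions, not jaxopt's API.

```python
import jax
import jax.numpy as jnp


def gradient_descent(fun, init_params, stepsize=0.1, maxiter=100,
                     value_and_grad=None):
    """Plain gradient descent with an optional user-supplied value_and_grad (hypothetical API)."""
    if value_and_grad is None:
        # Default behavior: let JAX differentiate `fun`.
        value_and_grad = jax.value_and_grad(fun)
    params = init_params
    for _ in range(maxiter):
        _, grad = value_and_grad(params)
        params = params - stepsize * grad
    return params


def fun(x):
    return jnp.sum(x ** 2)


def analytic_value_and_grad(x):
    # Hand-written value and gradient of sum(x**2).
    return jnp.sum(x ** 2), 2.0 * x


x0 = jnp.array([1.0, -2.0])
default_sol = gradient_descent(fun, x0)
custom_sol = gradient_descent(fun, x0, value_and_grad=analytic_value_and_grad)
```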
In the never-ending list of nice algorithms to add, there is SR1.
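For reference, a minimal sketch of the SR1 (symmetric rank-one) update of a Hessian approximation B, with the usual safeguard of skipping the update when the denominator is tiny. This is the textbook formula, not an existing jaxopt function.

```python
import jax.numpy as jnp


def sr1_update(B, s, y, r=1e-8):
    """SR1 update: B + v v^T / (v^T s) with v = y - B s.

    s = x_{k+1} - x_k and y = grad_{k+1} - grad_k. The update is skipped
    when |v^T s| < r * ||s|| * ||v||, since SR1 could otherwise divide by ~0.
    """
    v = y - B @ s
    denom = jnp.dot(v, s)
    skip = jnp.abs(denom) < r * jnp.linalg.norm(s) * jnp.linalg.norm(v)
    safe_denom = jnp.where(skip, 1.0, denom)  # avoid 0/0 in the skipped branch
    return jnp.where(skip, B, B + jnp.outer(v, v) / safe_denom)


H = jnp.array([[2.0, 0.5], [0.5, 1.0]])  # true Hessian of a quadratic
s = jnp.array([1.0, 0.0])
y = H @ s
B_new = sr1_update(jnp.eye(2), s, y)
```

After one update, B_new satisfies the secant condition B_new @ s == y.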
Based on section 2.3 of the paper, I am curious to see how the Jacobian error varies as I change x_t, holding \theta fixed (to use notation from the paper).
I thought this would be possible, for instance, by using custom_jvp and overwriting the primal_out with x_t. This would require me, though, to add x_t as an auxiliary argument to custom_jvp. I was wondering if this is possible.
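If the optimality function F(x, θ) can be written down explicitly, the implicit-function-theorem Jacobian -[∂_x F]^{-1} ∂_θ F can simply be evaluated at any chosen point x_t, without custom_jvp. The sketch below assumes a ridge-regression optimality condition as a stand-in problem; `optimality_fun` and `implicit_jacobian` are illustrative names.

```python
import jax
import jax.numpy as jnp


def optimality_fun(x, theta, A, b):
    # Gradient of 0.5*||Ax - b||^2 + 0.5*theta*||x||^2; zero at x*(theta).
    return A.T @ (A @ x - b) + theta * x


def implicit_jacobian(x_hat, theta, A, b):
    """Implicit-function-theorem Jacobian evaluated at an arbitrary x_hat."""
    dF_dx = jax.jacobian(optimality_fun, argnums=0)(x_hat, theta, A, b)
    dF_dtheta = jax.jacobian(optimality_fun, argnums=1)(x_hat, theta, A, b)
    return -jnp.linalg.solve(dF_dx, dF_dtheta)


A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = jnp.array([1.0, 0.0, 1.0])
theta = 0.5
x_star = jnp.linalg.solve(A.T @ A + theta * jnp.eye(2), A.T @ b)
true_jac = implicit_jacobian(x_star, theta, A, b)

errors = []
for eps in (1e-1, 1e-2, 1e-3):
    x_hat = x_star + eps  # perturbed x_t, theta held fixed
    errors.append(jnp.linalg.norm(implicit_jacobian(x_hat, theta, A, b) - true_jac))
```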
It will be easier to discuss this on GitHub rather than in an internal Google doc.
Current state of the API:
| Function | Without ridge | With ridge r > 0 | Remark |
|---|---|---|---|
| solve_cg | Ax=b | (A+rI)x=b | well-posed because A is PSD |
| solve_gmres | Ax=b | (A+rI)x=b | ill-posed if A=-rI |
| solve_bicgstab | Ax=b | (A+rI)x=b | ill-posed if A=-rI |
| solve_normal_cg | A^TAx=A^Tb | (A^TA+rI)x=A^Tb | well-posed because A^TA is PSD |
There are consistency issues here: with ridge regularization, we would expect (A^T+rI)(A+rI)x=(A^T+rI)b for solve_normal_cg. Consequently, solve_cg, solve_gmres and solve_bicgstab are all interchangeable when r > 0, but not with solve_normal_cg. Worse: when r=0, they are all interchangeable with each other (at least for PD matrices).
Tikhonov regularization regularizes with A^TA+rI, just like solve_normal_cg. This guarantees a well-posed problem.
Other observation: most scikit-learn solvers for Ridge regression use the A^TA+rI trick.
No one uses A+rI on a general matrix A: in general, it only makes sense to do so on a PSD matrix.
Two solutions:
(A^TA+rI)x=A^Tb into (A^T+rI)(A+rI)x=(A^T+rI)b, but in this case the problem is ill-posed for A=-rI.
I am in favor of the second option, to remain consistent with the literature, unless we can prove that the A+rI approach makes sense.
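The inconsistency can be checked with dense linear algebra standing in for the iterative solvers (an assumption: solve_cg, solve_gmres, etc. are replaced by jnp.linalg.solve on the systems from the table). For a PD matrix A and r > 0, the A+rI system and the (A^T+rI)(A+rI) normal form agree, while A^TA+rI gives a different answer; at r = 0 all three coincide.

```python
import jax.numpy as jnp

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
b = jnp.array([1.0, 2.0])
I = jnp.eye(2)


def formulations(r):
    # Ridge as in solve_cg / solve_gmres / solve_bicgstab.
    x_cg = jnp.linalg.solve(A + r * I, b)
    # Ridge as in solve_normal_cg.
    x_normal = jnp.linalg.solve(A.T @ A + r * I, A.T @ b)
    # The normal form one would "expect" from the A+rI formulation.
    x_product = jnp.linalg.solve((A.T + r * I) @ (A + r * I),
                                 (A.T + r * I) @ b)
    return x_cg, x_normal, x_product


x_cg, x_normal, x_product = formulations(1.0)
```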
Hi,
Defining a pre_update function works well in OptaxSolver when using the run_iterator function (like in the MNIST example). However, it does not work with the run function. I checked the implementation and indeed there is no call to the pre_update inside 'run' but there is one inside 'run_iterator'. Is this a small bug or is this by design?
Thanks!
Currently, when has_aux=True, state.aux is None when state is returned by init_state, and state.aux is equal to fun(params, *args, **kwargs)[1] when state is returned by update. This is problematic, as it can trigger a jit recompilation. One way to fix this would be to set state.aux to dummy values of the correct type when returned by init_state.
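One way to obtain dummy values "of the correct type" without running fun is jax.eval_shape, which returns the output shapes and dtypes. The sketch below builds zero-filled placeholders for the aux output; `dummy_aux` is a hypothetical helper, not part of jaxopt.

```python
import jax
import jax.numpy as jnp


def fun(params, x):
    # Hypothetical objective returning (value, aux), as with has_aux=True.
    residual = x @ params
    return jnp.sum(residual ** 2), {"residual": residual}


def dummy_aux(fun, params, *args):
    """Zeros with the same pytree structure, shapes and dtypes as fun(...)[1]."""
    out_shape = jax.eval_shape(fun, params, *args)  # no FLOPs are executed
    return jax.tree_util.tree_map(
        lambda s: jnp.zeros(s.shape, s.dtype), out_shape[1])


params = jnp.ones(2)
x = jnp.ones((3, 2))
aux0 = dummy_aux(fun, params, x)
```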
Some methods, such as line-search-based methods, require more function/gradient evaluations than others. It would be great to keep track of the number of such calls in the state. For instance, we could include state.num_fun_calls and state.num_grad_calls. This would allow us to plot the objective value as a function of these numbers and therefore compare various methods objectively.
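A minimal sketch of carrying such counters in the state, using a plain-Python gradient descent with backtracking line search. `CountingState` and `gd_backtracking` are hypothetical; the convention here is that one `value_and_grad` call counts as one function call and one gradient call. Inside jit or lax loops the counters would have to be carried as arrays in the loop state instead.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CountingState(NamedTuple):
    num_fun_calls: int
    num_grad_calls: int


def gd_backtracking(fun, params, maxiter=20, stepsize=1.0, shrink=0.5):
    value_and_grad = jax.value_and_grad(fun)
    state = CountingState(0, 0)
    for _ in range(maxiter):
        value, grad = value_and_grad(params)
        state = state._replace(num_fun_calls=state.num_fun_calls + 1,
                               num_grad_calls=state.num_grad_calls + 1)
        t = stepsize
        while True:  # Armijo backtracking; each trial costs one function call
            candidate = params - t * grad
            state = state._replace(num_fun_calls=state.num_fun_calls + 1)
            if fun(candidate) <= value - 0.5 * t * jnp.vdot(grad, grad):
                break
            t *= shrink
        params = candidate
    return params, state


sol, counts = gd_backtracking(lambda x: jnp.sum(x ** 2), jnp.array([1.0, 2.0]))
```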
We should add binary_sparsemax_loss, the binary classification counterpart of multiclass_sparsemax_loss; see section 4.4 of https://arxiv.org/abs/1901.02324. This loss is also known as the modified Huber loss. The associated mapping is the projection onto the unit cube, so let's add projection_unit_cube, which is easy to implement using projection_box. Once done, they need to be added to the documentation.
I see that LbfgsState contains a stepsize and that LBFGS.init_state hard-codes it to 1. I also see that the LBFGS.update method performs a line search in which the initial step size is set from this LBFGS state.
I have a particularly ill-conditioned problem that requires tiny initial steps, but I was surprised that the initial stepsize could not be set in the LBFGS constructor or elsewhere as far as I could see. Is this an oversight or an intentional part of the design? If it's intentional, is there an idiomatic way to set an initial stepsize when using LBFGS.run that I have overlooked?
Thanks in advance, and thanks for a really cool library.
I would have expected the line search in LBFGS to use the same values for its jit and unroll as the enclosing LBFGS optimizer; however, this doesn't seem to be the case:
Lines 258 to 262 in 0ff6be8
Is this intentional?
How should we initialize the parameters in this problem setting?
import jax.numpy as jnp
from jaxopt import OSQP
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
h = jnp.array([0.0, 0.0])
qp = OSQP()
sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
print(sol.primal)
print(sol.dual_eq)
print(sol.dual_ineq)
TypeError: run() missing 1 required positional argument: 'init_params'
I've been using jaxopt.implicit_diff.custom_root for differentiating through a jax-md energy minimization routine, and I have noticed that if I use a Python for loop in my solver, I get a CustomVJPException and an additional memory leak.
This memory leak only seems to show up when I get the CustomVJPException, and not when I modify my code to prevent the exception from happening. I believe the underlying reason for that exception is the same as in issue #31 and seems to stem from how jax-md defines its energy functions.
I'd like to know how to change that part of jax-md to prevent the CustomVJPException from happening in the first place, but I haven't managed to come up with a simplified version that would let me pinpoint the source of the error. I can give it another shot if that helps you.
Here's a colab demo that demonstrates the issue.
https://colab.research.google.com/drive/1f_3EmFQpvW1p7A1AcNw8uqX5T79fjXRS?usp=sharing
With some solvers (e.g., Newton's method), it is possible to reuse some of the computations to solve the implicit differentiation linear system more efficiently. Since @custom_root and @custom_fixed_point accept a solve argument for specifying a linear solver, the output of the forward-pass solver could be given as an argument to solve.
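To illustrate the kind of reuse meant here, a sketch on a quadratic inner problem f(x, θ) = 0.5 xᵀHx - θᵀx: the Newton forward pass factorizes H, and the same Cholesky factor solves the implicit-differentiation system, since the optimality condition Hx - θ = 0 gives dx*/dθ = H^{-1}. `newton_solve` and `implicit_vjp` are hypothetical names, and how the factor would be threaded into jaxopt's `solve` argument is left open.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def newton_solve(hess, theta):
    """Forward pass: one Newton step from x = 0 solves H x = theta exactly.

    Returns the solution and the Cholesky factor of H for later reuse.
    """
    factor = cho_factor(hess)
    return cho_solve(factor, theta), factor


def implicit_vjp(factor, cotangent):
    # VJP of x*(theta) = H^{-1} theta is H^{-T} u = H^{-1} u (H symmetric),
    # computed with the cached factor instead of a fresh linear solve.
    return cho_solve(factor, cotangent)


H = jnp.array([[4.0, 1.0], [1.0, 3.0]])
theta = jnp.array([1.0, 2.0])
x_star, factor = newton_solve(H, theta)
vjp_reused = implicit_vjp(factor, jnp.array([1.0, 0.0]))
```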
The example runs fine on CPU but crashes when run on GPU. I believe the problem is in the innocent-looking normalize function.
Below is the stack trace:
RESOURCE_EXHAUSTED: Out of memory while trying to allocate 188160000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 179.44MiB
constant allocation: 0B
maybe_live_out allocation: 179.44MiB
preallocated temp allocation: 0B
total allocation: 358.89MiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 179.44MiB
Entry Parameter Subshape: f32[60000,28,28,1]
==========================
Buffer 2:
Size: 179.44MiB
Operator: op_type="div" op_name="jit(true_divide)/div" source_file="examples/deep_learning/robust_training.py" source_line=52
XLA Label: fusion
Shape: f32[60000,28,28,1]
==========================
Buffer 3:
Size: 4B
Entry Parameter Subshape: f32[]
==========================
Traceback (most recent call last):
File "examples/deep_learning/robust_training.py", line 141, in <module>
train_ds, test_ds = load_datasets()
File "examples/deep_learning/robust_training.py", line 63, in load_datasets
train_ds['image'], test_ds['image'] = map(normalize, (train_ds['image'], test_ds['image']))
File "examples/deep_learning/robust_training.py", line 52, in normalize
return jnp.asarray(images).astype(jnp.float32) / 255.
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 6585, in deferring_binary_op
return binary_op(self, other)
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/_src/api.py", line 416, in cache_miss
out_flat = xla.xla_call(
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/core.py", line 627, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 690, in _xla_call_impl
out = compiled_fun(*args)
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1100, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 188160000 bytes.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "examples/deep_learning/robust_training.py", line 141, in <module>
train_ds, test_ds = load_datasets()
File "examples/deep_learning/robust_training.py", line 63, in load_datasets
train_ds['image'], test_ds['image'] = map(normalize, (train_ds['image'], test_ds['image']))
File "examples/deep_learning/robust_training.py", line 52, in normalize
return jnp.asarray(images).astype(jnp.float32) / 255.
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 6585, in deferring_binary_op
return binary_op(self, other)
File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1100, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 188160000 bytes.
CC @GeoffNN
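The failing allocation is one float32 copy of the training images (60000 × 28 × 28 × 1 × 4 bytes = 188,160,000 bytes), and the buffer dump shows two such copies live on the device. A sketch of one workaround, assuming the rest of the example can consume batches: normalize on the host with NumPy and move only batches to the accelerator. `iterate_batches` is a hypothetical helper.

```python
import jax.numpy as jnp
import numpy as np


def normalize(images):
    # Host-side conversion: the full dataset never has to fit on the GPU at once.
    return np.asarray(images, dtype=np.float32) / 255.0


def iterate_batches(images, batch_size):
    # Transfer one batch at a time to the default JAX device.
    for start in range(0, len(images), batch_size):
        yield jnp.asarray(images[start:start + batch_size])


images = np.full((4, 28, 28, 1), 255, dtype=np.uint8)
normalized = normalize(images)
batches = list(iterate_batches(normalized, 3))
```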
Hi all, thanks so much for the work on this project! It's going to have some awesome applications for sure.
I'm currently writing a library of differentiable operations for high-energy physics (here), and wanted to include jaxopt as a dependency.
PyPI doesn't allow arbitrary repos as dependencies, so I wanted to ask whether you're considering publishing the project on PyPI. :)
jit and unroll don't seem to be documented in https://jaxopt.github.io/stable/_autosummary/jaxopt.GradientDescent.html#jaxopt-gradientdescent
When running old code that used jaxopt==0.1, I found that it no longer runs due to the paradigm shift of
params, state = solver.init(...) --> state = solver.init_state(...)
but I could not find this in the release notes or documentation. It would be nice to have this written down somewhere, in case people prototyped with early jaxopt releases.
This is a great package! Are bound constraints for LBFGS on the roadmap?
Is it possible to apply vmap over outer_objective?
To be concrete, consider the lasso example: outer_objective takes theta, init_inner, data as arguments. Is it possible to provide OptaxSolver with the following objective function: vmap(outer_objective, in_axes=(None, 0, 0))?
---------------------------------------------------------------- EDITED --------------------------------------------------------
Yes, you can!
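One caveat: the vmapped function returns one value per task, while a solver's objective is expected to be a scalar, so the batch is typically reduced (here by a mean). The sketch uses a hypothetical `outer_objective` with a placeholder in place of the lasso inner solve; `mean_outer_objective` could then be passed as `fun` to a solver such as OptaxSolver (not exercised here).

```python
import jax
import jax.numpy as jnp


def outer_objective(theta, init_inner, data):
    # Hypothetical stand-in for one task of the lasso example.
    X, y = data
    w = init_inner + theta  # placeholder for "solve the inner problem"
    return jnp.mean((X @ w - y) ** 2)


# theta is shared; init_inner and every leaf of data are batched on axis 0.
batched = jax.vmap(outer_objective, in_axes=(None, 0, 0))


def mean_outer_objective(theta, init_inner, data):
    # Reduce the per-task values to the scalar an optimizer expects.
    return jnp.mean(batched(theta, init_inner, data))


theta = jnp.ones(3)
init_inner = jnp.zeros((4, 3))
data = (jnp.ones((4, 5, 3)), jnp.zeros((4, 5)))
value = mean_outer_objective(theta, init_inner, data)
grad = jax.grad(mean_outer_objective)(theta, init_inner, data)
```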
Hi,
I'm using QuadraticProgramming in the special case of c=0 (all zeros as a vector). AFAIK this is still well-defined, as it's just minimizing l2 norm squared of the primal subject to some equality constraints (I don't have inequalities).
However, both my research code and the following modification of this test diverge, even for a single step (maxiter=1).
The modification just involves setting c=0:
def test_qp_eq_only_c_zero(self):
  Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
  c = jnp.array([0.0, 0.0])  # ONLY CHANGE
  A = jnp.array([[1.0, 1.0]])
  b = jnp.array([1.0])
  qp = QuadraticProgramming(tol=1e-7)
  hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
  sol = qp.run(**hyperparams).params
  self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
  self._check_derivative_A_and_b(qp, hyperparams, A, b)
Is there a way to fix it? If it involves calling another linear solver, is there a way to specify the solver from the high-level QP function? I haven't seen it.
Thanks!
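As a sanity check that the c = 0 problem is well posed (Q is positive definite and A has full row rank), the KKT system of the test problem can be solved directly with dense linear algebra. This does not call jaxopt, so it only confirms the expected answer, not why the solver diverges; the dual sign follows the Lagrangian 0.5 xᵀQx + cᵀx + νᵀ(Ax - b), which may differ from jaxopt's convention.

```python
import jax.numpy as jnp

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.zeros(2)  # the c = 0 case from the test above
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])

# KKT conditions of min 0.5 x^T Q x + c^T x  s.t.  A x = b:
#   [Q  A^T] [x ]   [-c]
#   [A   0 ] [nu] = [ b]
kkt = jnp.block([[Q, A.T], [A, jnp.zeros((1, 1))]])
rhs = jnp.concatenate([-c, b])
sol = jnp.linalg.solve(kkt, rhs)
primal, dual_eq = sol[:2], sol[2:]
```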
Hi,
Congrats on the great tool!
Inspired by the QuadraticProgramming example, I built code that differentiates through KKT conditions. My code works whenever the primal solution variable is a jnp array, but not when it's a generic pytree, giving me the following error:
TypeError: Tree structure of cotangent input PyTreeDef(([(*, *), (), (*, *)], *, None)), does not match structure of primal output PyTreeDef(([(*, *), (), (*, *), (*, *), (), (*, *)], *, None))
where I'm pretty sure [(*, *), (), (*, *)] represents the primal solution and PyTreeDef(([(*, *), (), (*, *)], *, None)) could represent the optimality function.
I was able to make it work by storing the primal solution in a single jnp array and reshaping it into the appropriate pytree whenever needed, but it's not clean or efficient. I was wondering if there's a bug in the current codebase (I only found tests for single jnp arrays) or I'm misusing the interface (I'm not a jax expert).
To make it easier to reproduce, I modified the quadratic_prog.py file by making the model return a list of one array instead of an array for the primal variables (leaving both dual variables the same). Then I modified obj_fun, eq_fun and ineq_fun to use primal_var[0] instead of primal_var. If I understand correctly, this should still work. However, it doesn't: this test line raises an assertion error for an array that should be all zeros but instead is:
([DeviceArray([ 0.43999994, -1.3199999 ], dtype=float32), DeviceArray([-0.44000003, 1.32 ], dtype=float32)], DeviceArray([2.9802322e-08], dtype=float32), None)
Looking at the numbers of the problem I believe [0.44,-1.32] is the gradient of the obj_fun w.r.t. the primal and [-0.44,+1.32] the gradient of the equality constraint w.r.t. the primal times the dual. They should have been added up together to have [0,0] as expected. I feel this may be fundamentally the same problem I was facing in my own research code since there I also found one of the values had the shape of the primal variable twice instead of once.
Notice also that the test on the line just above (checking that the primal solution is correct) still holds provided we check sol[0][0] instead of sol[0] (since sol[0] is now a 1-element list).
Is differentiation through KKT supposed to work for general pytrees? If so, what should I have done to make it work in the quadratic_prog.py example?
Thanks!
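The printed residual contains two arrays where one is expected, which is exactly what Python's + does on lists: it concatenates them. One possible explanation (an assumption, not checked against the jaxopt source) is that two gradient pytrees are combined with + somewhere; for pytrees, a leaf-wise sum via tree_map gives the expected zero residual.

```python
import jax
import jax.numpy as jnp

grad_obj = [jnp.array([0.44, -1.32])]  # gradient of obj_fun (one-element list)
grad_eq = [jnp.array([-0.44, 1.32])]   # equality-constraint term

concatenated = grad_obj + grad_eq      # list concatenation: two leaves
summed = jax.tree_util.tree_map(jnp.add, grad_obj, grad_eq)  # leaf-wise sum
```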
Hi all,
Thanks for the awesome library. It would be fantastic to eventually see some options for trust-region optimization. For smaller dimensional problems in my field (e.g., variance components, or their generalization), it provides a much more robust approach for inference compared with 2nd order counterparts.
Opening this as a nota bene: when optimizing over complex parameters, the gradient must be conjugated. Currently, all jaxopt optimizers would be incorrect on complex parameters because of this.
Moreover, if any optimizer relies on second-order moments (e.g., Adam), it must also use the squared complex modulus of the gradient instead of just its square. Current jaxopt solvers might be affected as well. I'm unsure what implicit diff would do with complex parameters, but perhaps we could output a warning that it is currently probably incorrect.
I realized this while using Optax on a model with complex weights; I thought it might be good to incorporate this into jaxopt solvers as well, as users might 1) not be aware of this and 2) find it really hard to debug on their side.
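A small illustration of JAX's convention, assuming a real-valued loss of a complex parameter: jax.grad returns the conjugate of the steepest-ascent direction, so a descent step must conjugate it, and an Adam-style second moment should use |g|² rather than g².

```python
import jax
import jax.numpy as jnp


def f(z):
    # Real-valued loss |z|^2, written via real and imaginary parts.
    return jnp.real(z) ** 2 + jnp.imag(z) ** 2


z = jnp.array(3.0 + 4.0j)
g = jax.grad(f)(z)                    # 6 - 8j: conjugate of the ascent direction 2z
lr = 0.1
naive_step = z - lr * g               # moves the imaginary part the wrong way
correct_step = z - lr * jnp.conj(g)   # conjugate before stepping
second_moment = jnp.abs(g) ** 2       # real |g|^2, whereas g**2 is complex
```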
When A.shape = (N, P) with N != P, I run into shape errors when trying to use solve_normal_cg to fit the normal equations.
I have a small reproducible example below for N > P, but the error also occurs when P > N.
import jax.numpy as jnp
import numpy as np
N = 1000
P = 3
prob = np.random.uniform(0.01, 0.5, size=P)
h2g = 0.1
X = np.random.binomial(2, p=prob, size=(N, P))
b = np.random.normal(size=(P)) * np.sqrt(h2g / P)
y = X @ b + np.sqrt(1 - h2g) * np.random.normal(size=(N,))
import jaxopt as jopt
jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [11], in <module>
----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:151, in solve_normal_cg(matvec, b, ridge, init, **kwargs)
148 if ridge is not None:
149 _matvec = _make_ridge_matvec(_matvec, ridge=ridge)
--> 151 Ab = _rmatvec(matvec, b)
153 return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:114, in _rmatvec(matvec, x)
112 def _rmatvec(matvec, x):
113 """Computes A^T x, from matvec(x) = A x, where A is square."""
--> 114 transpose = jax.linear_transpose(matvec, x)
115 return transpose(x)[0]
File ~/miniconda3/lib/python3.9/site-packages/jax/_src/api.py:2211, in linear_transpose(fun, reduce_axes, *primals)
2208 in_dtypes = map(dtypes.dtype, in_avals)
2210 in_pvals = map(pe.PartialVal.unknown, in_avals)
-> 2211 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
2212 instantiate=True)
2213 out_avals, _ = unzip2(out_pvals)
2214 out_dtypes = map(dtypes.dtype, out_avals)
File ~/miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:505, in trace_to_jaxpr(fun, pvals, instantiate)
503 with core.new_main(JaxprTrace) as main:
504 fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
506 assert not env
507 del main, fun, env
File ~/miniconda3/lib/python3.9/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
163 gen = gen_static_args = out_store = None
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
168 # Some transformations yield from inside context managers, so we have to
169 # interrupt them before reraising the exception. Otherwise they will only
170 # get garbage-collected at some later time, running their cleanup tasks only
171 # after this exception is handled, which can corrupt the global state.
172 while stack:
Input In [11], in <lambda>(x)
----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
File ~/miniconda3/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4196, in dot(a, b, precision)
4194 return lax.mul(a, b)
4195 if _max(a_ndim, b_ndim) <= 2:
-> 4196 return lax.dot(a, b, precision=precision)
4198 if b_ndim == 1:
4199 contract_dims = ((a_ndim - 1,), (0,))
File ~/miniconda3/lib/python3.9/site-packages/jax/_src/lax/lax.py:667, in dot(lhs, rhs, precision, preferred_element_type)
664 return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
665 precision=precision, preferred_element_type=preferred_element_type)
666 else:
--> 667 raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
668 lhs.shape, rhs.shape))
TypeError: Incompatible shapes for dot: got (1000, 3) and (1000,).
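The traceback shows _rmatvec calling jax.linear_transpose(matvec, x) with b (shape (N,)) as the example input, and its docstring says A is square; for a rectangular A, the transpose has to be traced around an input of shape (P,). A sketch of that fix, with `solve_normal_cg_rect` as a hypothetical name and jax.scipy.sparse.linalg.cg doing the solve:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def solve_normal_cg_rect(matvec, b, x_example):
    """Solve A^T A x = A^T b for an N x P operator A (sketch).

    x_example: any array with the shape/dtype of the unknown x, i.e. (P,).
    """
    # Trace the transpose around an input of shape (P,), not around b (N,).
    rmatvec = jax.linear_transpose(matvec, x_example)
    normal_matvec = lambda x: rmatvec(matvec(x))[0]
    Atb = rmatvec(b)[0]
    return cg(normal_matvec, Atb, x0=jnp.zeros_like(x_example))[0]


X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # N = 3, P = 2
y = jnp.array([1.0, 2.0, 3.0])
coef = solve_normal_cg_rect(lambda x: X @ x, y, jnp.zeros(X.shape[1]))
```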
I get the error "JaxStackTraceBeforeTransformation: NameError: name 'signature' is not defined" with the following example when I set implicit_diff=True, but it works when I set implicit_diff=False:
import jaxopt
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
jax.config.update("jax_enable_x64", True)
X = jax.random.normal(jax.random.PRNGKey(1), (100, 10))
y = jax.random.normal(jax.random.PRNGKey(2), (100, 1))
def ridge_reg_objective(params, l2reg, X, y):
    # note: y has shape (100, 1), so residuals broadcast to shape (100, 100)
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2) + l2reg*jnp.linalg.norm(params)
def ridge_reg_solution(l2reg, X, y):
    # set implicit_diff=True to reproduce the NameError
    gd = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500, implicit_diff=False)
    return gd.run(init_params, l2reg=l2reg, X=X, y=y).params
l2reg = 0.01
init_params = jax.random.normal(jax.random.PRNGKey(0), shape=(X.shape[1],))
print(jax.jacobian(ridge_reg_solution, argnums=0)(l2reg, X, y))
Hi,
Here is a use case:
import jax
import jax.numpy as jnp
# Volume of a box
def vol(x):
    return x[0]*x[1]*x[2]
# Surface of the box
def surf(x):
    return 2.*(x[0]*x[1]+x[0]*x[2]+x[1]*x[2])
# Constraint on total surface
def g(x): return surf(x) - 24
# Lagrangian: p[0:3] = (x1, x2, x3), p[3] = Lagrange multiplier
@jax.jit
def Lag(p):
    return vol(p[0:3]) - p[3]*g(p[0:3])
@jax.jit
def neg_Lag(p):
    return -Lag(p)
I can solve this Lagrangian-based optimization problem by hand like this:
# Gradient and Hessian of the Lagrangian
gLag = jax.jacfwd(Lag)
hLag = jax.hessian(Lag)
def solveLagrangian(p,lr=0.1):
    # Damped Newton step on the stationarity condition grad Lag(p) = 0
    return p - lr*jnp.linalg.inv(hLag(p)) @ gLag(p)
p_cur = jnp.array([1.5,0.5,1.0,0.1])
for t in range(200):
    if (t % 10) == 0:
        print(t, p_cur, Lag(p_cur))
    new_p = solveLagrangian(p_cur)
    rel_err = jnp.max(jnp.abs(p_cur - new_p))
    if rel_err < 1e-6:
        print(f"Converged after {t} epochs")
        break
    p_cur = new_p
p_fin=p_cur
v_fin = vol(p_fin[0:3])
s_fin = surf(p_fin[0:3])
print("p_fin: ",p_fin,": True x=y=z=2, lambda=0.5" )
print("v_fin: ",v_fin,": True vol = 2^3")
print("s_fin: ",s_fin,": True surf = 24")
I get
0 [1.5 0.5 1. 0.1] 2.6
10 [1.83358314 1.55167781 1.77716864 0.40679326] 7.609872257244211
20 [1.94433372 1.84887174 1.92882398 0.46981662] 7.95680842439189
30 [1.98087313 1.94785567 1.97583383 0.48971189] 7.9948966103715815
40 [1.99336478 1.98188252 1.99164835 0.49643995] 7.999385438210385
50 [1.99769054 1.99369051 1.99709684 0.49876193] 7.999925528308225
60 [1.99919524 1.99780094 1.9989888 0.4995687 ] 7.999990956279981
70 [1.99971946 1.99923335 1.99964755 0.49984966] 7.9999989009303
80 [1.99990219 1.9997327 1.99987712 0.49994759] 7.99999986639723
90 [1.9999659 1.9999068 1.99995716 0.49998173] 7.999999983757804
100 [1.99998811 1.9999675 1.99998506 0.49999363] 7.999999998025361
110 [1.99999585 1.99998867 1.99999479 0.49999778] 7.999999999759932
Converged after 112 epochs
p_fin: [1.99999664 1.99999082 1.99999578 0.4999982 ] : True x=y=z=2, lambda=0.5
v_fin: 7.9999329788029625 : True vol = 2^3
s_fin: 23.999865957438498 : True surf = 24
Okay, now is it possible to get the result with the Optax solver?
import optax
import jaxopt
opt = optax.adagrad(0.01)
solver = jaxopt.OptaxSolver(opt=opt, fun=neg_Lag, maxiter=2000)
init_params = jnp.array([1.5,0.5,1.0,0.1])
params, state = solver.init(init_params)
print('init', params, neg_Lag(params))
for i in range(2000):
    params, state = solver.update(params=params, state=state)
    if i%100 == 0:
        print(i, params, neg_Lag(params))
Here I get:
init [1.5 0.5 1. 0.1] -2.6
0 [1.50534522 0.50953463 1.00741998 0.10999854] -2.797381684399013
100 [1.42831204 0.63479103 1.0126782 0.28539948] -6.057683328070636
200 [1.28418844 0.59294737 0.86638545 0.37066902] -7.7856194581887825
300 [1.18203126 0.5047476 0.75689433 0.43842879] -9.33122178685714
400 [1.1023296 0.4200589 0.67028698 0.49659404] -10.755253639874878
500 [1.03695639 0.34601625 0.598102 0.54828646] -12.072981877227912
600 [0.98179502 0.28104647 0.53587287 0.59520355] -13.298704241023824
700 [0.93451151 0.22319941 0.48100081 0.63840745] -14.444779039627818
800 [0.89367761 0.17099112 0.43183736 0.67861745] -15.521400626325187
900 [0.85837707 0.12333851 0.38726128 0.71634778] -16.536982715800075
1000 [0.82800999 0.07943679 0.34646534 0.75198133] -17.49857403108653
1100 [0.80218737 0.03866963 0.30883957 0.78581251] -18.412193035926062
1200 [7.80670563e-01 5.49785239e-04 2.73902885e-01 8.18073632e-01] -19.28308214582214
1300 [ 0.76333447 -0.0353201 0.24126048 0.84895208] -20.115900881866448
1400 [ 0.75014358 -0.06927386 0.21057603 0.87860182] -20.914875993347422
1500 [ 0.7411345 -0.10159995 0.18155273 0.90715143] -21.683921834137884
1600 [ 0.7364005 -0.13255408 0.15392011 0.93470985] -22.426739725247586
1700 [ 0.73607461 -0.16236803 0.12742484 0.96137061] -23.146900944390644
1800 [ 0.74030784 -0.19125565 0.10182456 0.987215 ] -23.847914006566935
1900 [ 0.74924029 -0.21941633 0.07688466 1.01231443] -24.533272887793775
which is clearly not the right way to go.
(NB: if I use Lag as the function, this does not change the problem: no convergence; same with sgd/adam...)
Is there a way to get the Optax solver working?
Thanks
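A plausible reason both Optax runs drift: the solution p* = (2, 2, 2, 0.5) is a saddle point of Lag, not a minimum of Lag or -Lag. The multiplier has zero curvature on the diagonal while its cross terms are nonzero, so the Hessian is indefinite; the hand-written damped Newton loop converges because it looks for zeros of the gradient regardless of curvature. The sketch checks this at p* and defines one possible alternative objective, 0.5 * ||grad Lag(p)||^2, whose minima include the saddle point. Passing it to OptaxSolver is untested here.

```python
import jax
import jax.numpy as jnp


def vol(x):
    return x[0] * x[1] * x[2]


def surf(x):
    return 2.0 * (x[0] * x[1] + x[0] * x[2] + x[1] * x[2])


def Lag(p):
    # p[0:3] = (x1, x2, x3), p[3] = Lagrange multiplier
    return vol(p[0:3]) - p[3] * (surf(p[0:3]) - 24.0)


p_star = jnp.array([2.0, 2.0, 2.0, 0.5])  # known solution of the box problem
grad_norm = jnp.linalg.norm(jax.grad(Lag)(p_star))
eigs = jnp.linalg.eigvalsh(jax.hessian(Lag)(p_star))


def stationarity_residual(p):
    # Scalar objective that is zero exactly at stationary points of Lag.
    return 0.5 * jnp.sum(jax.grad(Lag)(p) ** 2)
```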
I am trying to use jaxopt.Bisection to replace scipy.optimize.bisect in a computational model, but Bisection hangs when I run my code.
The basic structure includes 2 functions that are both jitted (so I assume they should compile OK):
@jit
def f1(parameters):
    ....
    return jax.numpy.array([a,b,c])
@jit
def opt_fun(x):
    f1(x,params)
    ....
    return float_value
When I call scipy.optimize.bisect(opt_fun, x0, x1), it runs with no issue, but jaxopt.Bisection(opt_fun, x0, x1).run(None) hangs with ~10% CPU usage and 55% memory usage on an i9 2018 MacBook Pro with 32GB of memory.
I acknowledge I may be using this incorrectly and that this is possibly not the intended use case but any direction would be very helpful. My intention is to use this computational model with numpyro in the future and having a jax version of the bisection root finding would be incredibly helpful.
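For comparison, a minimal pure-JAX bisection built on jax.lax.while_loop is sketched below, as an illustration of what a jit-compatible root-finding loop looks like. It is not jaxopt.Bisection's implementation; it assumes fun(lower) and fun(upper) differ in sign and that fun returns a JAX scalar (converting to a Python float inside a traced function fails under jit).

```python
import jax
import jax.numpy as jnp


def bisection(fun, lower, upper, tol=1e-6, maxiter=100):
    """Root of fun in [lower, upper], assuming a sign change (sketch)."""
    lower = jnp.asarray(lower, dtype=jnp.float32)
    upper = jnp.asarray(upper, dtype=jnp.float32)
    sign_lower = jnp.sign(fun(lower))

    def cond(carry):
        lo, hi, it = carry
        return (hi - lo > tol) & (it < maxiter)

    def body(carry):
        lo, hi, it = carry
        mid = 0.5 * (lo + hi)
        same_sign = jnp.sign(fun(mid)) == sign_lower
        # Keep the half-interval that still brackets the sign change.
        lo = jnp.where(same_sign, mid, lo)
        hi = jnp.where(same_sign, hi, mid)
        return lo, hi, it + 1

    lo, hi, _ = jax.lax.while_loop(
        cond, body, (lower, upper, jnp.array(0, dtype=jnp.int32)))
    return 0.5 * (lo + hi)


root = bisection(lambda x: x ** 2 - 2.0, 0.0, 2.0)
root_jit = jax.jit(lambda lo, hi: bisection(lambda x: x ** 2 - 2.0, lo, hi))(0.0, 2.0)
```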
Hi, are there any plans to implement SQP with OSQP iterations? Regards,
Traceback (most recent call last):
File "/pscratch/sd/g/gnegiar/neural-dict-pinns/src/main.py", line 291, in <module>
), d_params = value_and_grad(
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/flax/linen/transforms.py", line 310, in wrapped_fn
return trafo_fn(module_scopes, *args, **kwargs)
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/flax/core/lift.py", line 201, in wrapper
y, out_variable_groups_xs_t = fn(
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/flax/core/lift.py", line 939, in inner
return jitted(mutable, variable_groups, rng_groups, *args)
File "/pscratch/sd/g/gnegiar/neural-dict-pinns/src/main.py", line 220, in forward
lam_star, dual_star = solve_QP(
File "/pscratch/sd/g/gnegiar/neural-dict-pinns/src/qp_layer.py", line 49, in solve_QP_jaxopt
sol_pytree = qp_layer.run(x0, (Q, c), (A, b)).params
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/implicit_diff.py", line 252, in wrapped_solver_fun
return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/implicit_diff.py", line 206, in solver_fun_fwd
res = solver_fun(*args, **kwargs)
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/eq_qp.py", line 198, in run
primal, dual_eq = self.solve(matvec, target, init_params,
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py", line 176, in solve_gmres
return jax.scipy.sparse.linalg.gmres(matvec, b, tol=tol, **kwargs)[0]
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jax/_src/scipy/sparse/linalg.py", line 686, in gmres
x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py", line 34, in ridge_matvec
return tree_add_scalar_mul(matvec(v), ridge, v)
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/tree_util.py", line 46, in tree_add_scalar_mul
return tree_multimap(lambda x, y: x + scalar * y, tree_x, tree_y)
File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/tree_util.py", line 46, in <lambda>
return tree_multimap(lambda x, y: x + scalar * y, tree_x, tree_y)
jax._src.errors.TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(float64[1001])>with<DynamicJaxprTrace(level=3/2)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
Standard baselines are able to reach 90%+ accuracy without much tweaking (see for example https://github.com/kuangliu/pytorch-cifar), while our example never goes beyond 70% accuracy on the validation set.
jaxopt/jaxopt/_src/tree_util.py
Line 51 in 1d55378
See this colab for details on bug and fix.
Here's a minimal example to reproduce the bug:
from jaxopt import tree_util
a =(1.0, {"k1": 1.0, "k2": (1.0, 1.0)}, 1.0)
x =(1.0, {"k1": 2.0, "k2": (3.0, 4.0)}, 5.0)
out = tree_util.tree_vdot(a,x)
The issue is caused by unexpected behaviour of jax.tree_util.tree_multimap.
One way to fix this is to compute the leaf-wise product with tree_multimap (tree_map in current JAX), and then reduce using a sum:
import operator
import jax.numpy as jnp
import jax.tree_util as tu
def tree_vdot(tree_x, tree_y):
    """Compute the inner product <tree_x, tree_y>."""
    # tree_multimap was later removed from JAX; tree_map accepts several trees
    prod_pair = tu.tree_map(lambda x, y: x*y, tree_x, tree_y)
    sums = tu.tree_map(jnp.sum, prod_pair)
    return tu.tree_reduce(operator.add, sums)
I would be happy to implement this fix, add an accompanying unit test, and make a pull request. More details here
See also #106 for a discussion
I've been trying to use OptaxSolver to perform a simple function minimization, since I want to differentiate through its solution (the fixed point of the solver), but I ran into an issue I'm not familiar with.
Here's a MWE for the error message:
import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax
def pipeline(param_for_grad, data):
    def to_minimize(latent):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)
    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(3e-4), implicit_diff=True)
    initial, _ = solver.init(init_params = 5.)
    result, _ = solver.run(init_params = initial)
    return result
jax.value_and_grad(pipeline)(2., data=6.)
which yields this error:
CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.
My versions are:
jax==0.2.20
jaxlib==0.1.71
jaxopt==0.0.1
optax==0.0.9
Am I doing something very silly? I'm also wondering whether this example is within the scope of the solver API. I noticed that this doesn't occur with solver.update, only with solver.run.
Thanks :)
Hi,
I'm having a problem with projection_polyhedron:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import jaxopt
from jaxopt.projection import projection_l2_ball, projection_box, projection_l1_ball, projection_polyhedron

def myproj3(x):
    # Polyhedron {x : A x = b, G x <= h}, i.e. x1 + x2 = 1 and x >= 0.
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])
    G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
    h = jnp.array([0.0, 0.0])
    x = projection_polyhedron(x, hyperparams=(A, b, G, h))
    return x

rng_key = jax.random.PRNGKey(42)
x = jax.random.uniform(rng_key, (5000, 2), minval=-3, maxval=3)
# myproj3 takes a single argument, so vmap maps over axis 0 only.
p1_x = jax.vmap(myproj3, in_axes=0)(x)
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(x[:, 0], x[:, 1], s=0.5)
ax.scatter(p1_x[:, 0], p1_x[:, 1], s=0.5, c='g')
ax.set_xlabel("X")
ax.set_ylabel("Y")
plt.show()
First, I had to install cvxpy:
!pip install cvxpy
Then I got this error:
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[2])>with<BatchTrace(level=1/1)>
with val = DeviceArray([[-2.37103211, 2.33759997],
[ 2.76953806, -2.37750394],
[-0.87246632, 0.73224625],
...,
[ 2.29799773, 2.81894884],
[ 2.4022714 , 0.80693103],
[-0.41563116, 2.83898531]], dtype=float64)
batch_dim = 0
Does anyone have a hint? Thanks.
Hello everyone,
First of all, thanks for this great library!
I'm not sure if the following issue is with jaxopt or jax itself, but I started having problems applying vmap to a root solver.
I tried following the example in 'gradient_descent_test.py' (particularly 'test_jit_and_vmap').
However, I'm getting a TracerArrayConversionError when I try to evaluate the vmapped function.
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[1])>with<BatchTrace(level=1/1)>
with val = DeviceArray([[-17.14141909],[ 58.908974 ],....]], dtype=float64)
batch_dim = 0
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Below is a simple example to reproduce the issue. Any help would be greatly appreciated.
Thanks in advance.
John
import jax
from jax.config import config; config.update("jax_enable_x64", True)
import jax.numpy as np
from jax import random, vmap
from jaxopt import linear_solve
from jaxopt import ScipyRootFinding

def func(x, params):
    a, b = params
    return (x - a) * (x - b)

kwargs = {'implicit_diff_solve': linear_solve.solve_normal_cg, 'method': 'hybr', 'tol': 1e-10}
rootfinder = ScipyRootFinding(optimality_fun=func, **kwargs)
init = np.zeros(1)

def solve(params):
    root, info = rootfinder.run(init, params)
    return root

key = random.PRNGKey(1235711)
param_list = random.uniform(key, (100, 2), minval=-10, maxval=10)
root = solve(param_list[0])  # this is ok
root_list = vmap(solve)(param_list)  # this doesn't work
P.S. I'm using the following package versions:
python 3.8.8
jax 0.2.19
jaxlib 0.1.70
jaxopt 0.1
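The traceback suggests that `ScipyRootFinding` hands its inputs to SciPy (`scipy.optimize.root`), which needs concrete NumPy arrays, so `vmap` cannot trace through it. A solver written purely with JAX operations does not have this limitation. Below is a minimal sketch on the same `func`, using a hand-rolled fixed-iteration Newton method. The helper `newton_root` is hypothetical and not a jaxopt API.

```python
import jax
import jax.numpy as jnp


def func(x, params):
    a, b = params
    return (x - a) * (x - b)


def newton_root(params, x0=0.0, num_iters=60):
    # Hypothetical helper: Newton's method built from JAX primitives only,
    # so vmap can batch it over params.
    dfunc = jax.grad(func, argnums=0)

    def step(_, x):
        return x - func(x, params) / dfunc(x, params)

    return jax.lax.fori_loop(0, num_iters, step, jnp.asarray(x0, dtype=params.dtype))


key = jax.random.PRNGKey(1235711)
param_list = jax.random.uniform(key, (100, 2), minval=-10, maxval=10)
roots = jax.vmap(newton_root)(param_list)  # one root per parameter pair
```

Because every step is a JAX operation, the same function can also be wrapped in `jax.jit`.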
There is an error in the ridge_objective function:
return 0.5 * jnp.mean(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2)
As correctly stated in Figure 1 of the paper https://arxiv.org/pdf/2105.15183.pdf, the ridge objective should be
return 0.5 * jnp.sum(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2)
i.e., jnp.mean should be changed to jnp.sum.
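The two versions differ only by a rescaling of the regularization strength. With n samples, 0.5 * mean(r**2) + 0.5 * lam * ||w||**2 equals (1/n) * (0.5 * sum(r**2) + 0.5 * (n * lam) * ||w||**2), so the mean version with lam has the same minimizer as the sum version with n * lam. Here is a small check of this relation. It uses hypothetical random data and the closed-form ridge solution (X^T X + lam I)^{-1} X^T y of the sum version.

```python
import jax
import jax.numpy as jnp


def ridge_objective_mean(params, lam, X, y):
    residuals = X @ params - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2)


def ridge_objective_sum(params, lam, X, y):
    residuals = X @ params - y
    return 0.5 * jnp.sum(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2)


def ridge_solution_sum(lam, X, y):
    # Closed-form minimizer of the sum version: (X^T X + lam I)^{-1} X^T y.
    return jnp.linalg.solve(X.T @ X + lam * jnp.eye(X.shape[1]), X.T @ y)


key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
n, d, lam = 20, 5, 0.1
X = jax.random.normal(key_x, (n, d))
y = jax.random.normal(key_y, (n,))
w = jax.random.normal(key_w, (d,))

# Objective identity: n * mean-version(lam) == sum-version(n * lam).
lhs = n * ridge_objective_mean(w, lam, X, y)
rhs = ridge_objective_sum(w, n * lam, X, y)

# The sum-version minimizer with n * lam is a stationary point of the mean version with lam.
w_star = ridge_solution_sum(n * lam, X, y)
grad_mean = jax.grad(ridge_objective_mean)(w_star, lam, X, y)
```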
Originally posted by jecampagne January 14, 2022
Hello,
Here is a snippet:
import jax
import jax.numpy as jnp
import jaxopt
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
#######
def ridgeless_reg_objective(params, X, y):
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2)

def gen_x_unif_sphere(r=1.0, d=20, ns=50, seed=42):
    """ Generate ns vectors uniform on the sphere of radius r in d-dimension
    d=20.
    data = gen_x_unif_sphere(r=np.sqrt(d), d=d)
    jnp.allclose(jnp.sum(data*data,axis=1),jnp.array([d]*data.shape[0]))
    """
    key = jax.random.PRNGKey(seed)
    x = jax.random.normal(key, shape=(ns, d))
    norm = jnp.linalg.norm(x, axis=1)
    x_normed = r * x / norm.reshape(x.shape[0], 1)
    return x_normed

def func(x):
    def p1(x):
        return x
    def p2(x):
        return (x*x - 1.)/jnp.sqrt(2.)
    def p3(x):
        return x*(x*x-3.)/jnp.sqrt(6.)
    def p4(x):
        return (x*x*x*x - 6. * x*x +3.)/(2.*jnp.sqrt(6.))
    c1 = jnp.sqrt(2./5.)
    c2 = jnp.sqrt(1./5.)
    return c1*(p1(x)+p2(x))+c2*p3(x)

def gen_y(X, beta):
    Xbeta = X @ beta.T  # <beta, Xi>
    y = func(Xbeta)
    return y

def rho_prime(x):
    """
    rho(x)=ReLU(x) => rho'(x)=Heaviside(x) with by convention ReLU'(0)=0.
    """
    return jnp.heaviside(x, 0.)

def gen_Phi(X, W):
    N, d = W.shape
    Nd = N*d
    ns = X.shape[0]
    XW = X @ W.T
    rhoXW = rho_prime(XW)/jnp.sqrt(Nd)
    return jnp.tile(X, reps=N) * jnp.repeat(rhoXW, repeats=d, axis=1)
##########
d = 15
# gamma = ln(n)/ln(d)
# psi = ln(Nd)/ln(d)
gamma = 3.0
ns = int(d**gamma)
psi = 4.0
N = int(d**(psi-1))
print(f"gamma:{gamma}, d:{d}, ns:{ns}, N:{N}, Nd:{N*d}")
X = gen_x_unif_sphere(r=np.sqrt(d), d=d, ns=ns)
print("X.shape", X.shape)
key = jax.random.PRNGKey(70)
beta = jax.random.normal(key, shape=(1, d))
norm = jnp.linalg.norm(beta, axis=1)
beta = beta / norm
Y = gen_y(X, beta)
print("Y.shape", Y.shape)
W = gen_x_unif_sphere(r=1, d=d, ns=N, seed=60)
print("W.shape", W.shape)
Phi = gen_Phi(X, W)
print("Phi.shape", Phi.shape)
#######
This prints:
gamma:3.0, d:15, ns:3375, N:3375, Nd:50625
X.shape (3375, 15)
Y.shape (3375, 1)
W.shape (3375, 15)
Phi.shape (3375, 50625)
Now fit the regression Y = Phi theta using Adam:
import optax
from jaxopt import OptaxSolver
init_theta = jnp.zeros((Phi.shape[1], Y.shape[1]))  # one coefficient per column of Phi
solver = OptaxSolver(opt=optax.adam(1e-2), fun=ridgeless_reg_objective, maxiter=1000)
theta_adam = solver.run(init_theta, X=Phi, y=Y)
Of course, one could instead compute the minimum-norm least-squares solution theta_lsq, although the matrices are already sizeable. With the LSQ solution, the training MSE computed as
diff_lsq = Phi @ theta_lsq - Y
MSE_lsq = (diff_lsq.T @ diff_lsq)/diff_lsq.shape[0]
is DeviceArray([[5.98907097e-28]]), which is fine since Nd >> ns.
With Adam, however,
diff_adam = Phi @ theta_adam.params - Y
MSE_adam = (diff_adam.T @ diff_adam)/diff_adam.shape[0]
gives DeviceArray([[0.01393647]]), which is far worse.
I wonder whether there are batching options or other parameters to tune to get better results. Thanks.
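The snippet does not show how `theta_lsq` was computed. Here is a minimal sketch of one way to get the minimum-norm least-squares solution in JAX. It assumes `jnp.linalg.lstsq`, which, like NumPy's, returns the minimum-norm solution when there are more unknowns than equations. A small random overparameterized problem stands in for `Phi` and `Y`.

```python
import jax
import jax.numpy as jnp

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
n, p = 30, 200  # more features than samples, like Nd >> ns
Phi = jax.random.normal(key_x, (n, p))
Y = jax.random.normal(key_y, (n, 1))

# SVD-based least squares; for an underdetermined system this is the min-norm solution.
theta_lsq = jnp.linalg.lstsq(Phi, Y)[0]

diff = Phi @ theta_lsq - Y
mse = (diff.T @ diff) / diff.shape[0]  # interpolates the training data, so close to 0
```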
Thanks for a very cool and useful library! I have a question about init_params in BoxOSQP().run.
I just copied and ran the code from the BoxOSQP tutorial:
import jax.numpy as jnp
from jaxopt import BoxOSQP
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
l = jnp.array([1.0, -jnp.inf, -jnp.inf])
u = jnp.array([1.0, 0.0, 0.0])
qp = BoxOSQP()
sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
and got
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_11430/875702596.py in <module>
9
10 qp = BoxOSQP()
---> 11 sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
12
13 print(sol.primal)
TypeError: run() missing 1 required positional argument: 'init_params'
Simply adding init_params=None to the arguments of run made it work, but I am not sure whether that is correct:
sol = qp.run(init_params=None, params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
print(sol.primal)
print(sol.dual_eq)
print(sol.dual_ineq)
(DeviceArray([0.25004143, 0.7500388 ], dtype=float32), DeviceArray([ 1. , -0.25003824, -0.75000846], dtype=float32))
[-2.7502573e+00 0.0000000e+00 3.0822962e-09]
(DeviceArray([0.0000000e+00, 0.0000000e+00, 3.0822962e-09], dtype=float32), DeviceArray([ 2.7502573, -0. , 0. ], dtype=float32))
Thanks in advance.
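One way to check that the output is sensible: at the printed solution both bounds x >= 0 are inactive, so the QP reduces to an equality-constrained one whose KKT system is linear. Below is a minimal cross-check in plain JAX. It assumes, rather than verifies in advance, that the bounds are inactive, and then checks x >= 0 afterwards.

```python
import jax.numpy as jnp

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])
a_eq = jnp.array([[1.0, 1.0]])  # the equality row x1 + x2 = 1
b_eq = jnp.array([1.0])

# Stationarity Q x + c + a_eq^T nu = 0 together with feasibility a_eq x = b_eq.
kkt = jnp.block([[Q, a_eq.T], [a_eq, jnp.zeros((1, 1))]])
rhs = jnp.concatenate([-c, b_eq])
sol = jnp.linalg.solve(kkt, rhs)
x, nu = sol[:2], sol[2]
```

The primal part is (0.25, 0.75), matching what `BoxOSQP` reports, and the multiplier is -2.75, matching the first entry of the printed `dual_eq`.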
When running the deep learning examples (e.g. deep_learning/flax_image_classif.py), GPU utilization never goes above 5%, while for the equivalent Flax example it is around 90% and that example runs more than 20x faster.
My guess is that a crucial @jax.jit decorator is missing somewhere.
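For reference, this is the pattern the guess refers to. Without `jax.jit`, a per-step update is dispatched to the accelerator one operation at a time. Below is a generic minimal sketch using a toy least-squares step, not the example's actual code. Compiling the whole step typically removes most of that per-operation overhead while computing the same result.

```python
import jax
import jax.numpy as jnp


def loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)


def sgd_step(w, X, y):
    # One gradient-descent step with a fixed step size of 0.1.
    return w - 0.1 * jax.grad(loss)(w, X, y)


# jax.jit traces sgd_step once per input shape/dtype and reuses the compiled program.
sgd_step_jit = jax.jit(sgd_step)

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (256, 10))
y = jax.random.normal(key_y, (256,))

w_eager = jnp.zeros(10)
w_jit = jnp.zeros(10)
for _ in range(5):
    w_eager = sgd_step(w_eager, X, y)
    w_jit = sgd_step_jit(w_jit, X, y)
```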
Hi,
Thank you for the great work.
I am trying to run python examples/sparse_coding.py, but I am getting an ImportError:
python examples/sparse_coding.py
Traceback (most recent call last):
  File "examples/sparse_coding.py", line 30, in <module>
    from jaxopt import proximal_gradient
ImportError: cannot import name 'proximal_gradient' from 'jaxopt' (/home/gem/repos/jaxopt/env/lib/python3.8/site-packages/jaxopt-0.0.1-py3.8.egg/jaxopt/__init__.py)
This will require renaming optimality_fun to fun in Bisection.
As an example, the init argument of solve_gmres is not passed to the x0 argument of jax.scipy.sparse.linalg.gmres:
def solve_gmres(matvec: Callable,
                b: Any,
                ridge: Optional[float] = None,
                init: Optional[Any] = None,
                tol: float = 1e-5,
                **kwargs) -> Any:
    """Solves ``A x = b`` using gmres.

    Args:
      matvec: product between ``A`` and a vector.
      b: pytree.
      ridge: optional ridge regularization.
      init: optional initialization to be used by gmres.
      **kwargs: additional keyword arguments for solver.

    Returns:
      pytree with same structure as ``b``.
    """
    if ridge is not None:
        matvec = _make_ridge_matvec(matvec, ridge=ridge)
    # ``init`` is never forwarded here (it should be passed as ``x0=init``).
    return jax.scipy.sparse.linalg.gmres(matvec, b, tol=tol, **kwargs)[0]
I can take care of this soon.
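A minimal sketch of the fix, which forwards `init` as `x0` to `jax.scipy.sparse.linalg.gmres`. `make_ridge_matvec` below is a hypothetical local stand-in for jaxopt's private `_make_ridge_matvec`, assumed to compute A x + ridge * x leaf by leaf.

```python
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def make_ridge_matvec(matvec: Callable, ridge: float) -> Callable:
    # Hypothetical stand-in for jaxopt's private helper: x -> A x + ridge * x.
    def ridge_matvec(x):
        return jax.tree_util.tree_map(lambda ax, xi: ax + ridge * xi, matvec(x), x)
    return ridge_matvec


def solve_gmres(matvec: Callable,
                b: Any,
                ridge: Optional[float] = None,
                init: Optional[Any] = None,
                tol: float = 1e-5,
                **kwargs) -> Any:
    """Solves ``A x = b`` (or ``(A + ridge I) x = b``) using gmres, starting from ``init``."""
    if ridge is not None:
        matvec = make_ridge_matvec(matvec, ridge=ridge)
    # The fix: pass ``init`` as the initial guess ``x0``.
    return gmres(matvec, b, x0=init, tol=tol, **kwargs)[0]


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x_plain = solve_gmres(lambda v: A @ v, b, init=jnp.ones(2))
x_ridge = solve_gmres(lambda v: A @ v, b, ridge=1.0, init=jnp.ones(2))
```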
First of all, thanks a lot for this library! Really useful tools!
I'm interested in getting gradients of at least second order through root finding, and I'm observing odd behavior that I wanted to report.
Maybe I'm doing something wrong, but in the following schematic case I silently get the wrong gradients:
def inv_f(x, aux):
    bisec = Bisection(optimality_fun=F, lower=0.0, upper=1.,
                      check_bracket=False, unroll=True)
    return bisec.run(aux=aux).params

# Here I extract the value part of the vjp, but the grad part also gives wrong results
test_fn = lambda aux: jax.value_and_grad(inv_f)(0.5, aux)[0]
jax.grad(test_fn)(1.)  # Returns 0 instead of the expected gradients
Here I'm only trying to get gradients of the value returned by jax.value_and_grad, but the gradients of the gradients returned by jax.value_and_grad are also wrong (though less obviously).
I made a small demo notebook that reproduces this issue here.
As a reference I've also implemented my own implicit gradients, bypassing the jaxopt ones, and they seem to give me the correct answer.
Reading the jaxopt source code, it is not immediately obvious to me why this doesn't work. Sorry I couldn't directly suggest a PR, but I hope this report is still useful (and that I'm not just using jaxopt wrong).
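For comparison, here is a minimal sketch of hand-written implicit gradients for a bisection root that stay correct at second order. It does not use jaxopt. `F` and `bisect_root` are hypothetical: F(x, theta) = x**3 - theta has the root x* = theta**(1/3), whose first and second derivatives at theta = 1 are 1/3 and -2/9. The key design choice is that the JVP rule calls `bisect_root` itself. Differentiating the rule again therefore re-enters the same implicit rule instead of going through the bisection loop.

```python
import jax
import jax.numpy as jnp


def F(x, theta):
    # Hypothetical optimality function with root x* = theta ** (1/3).
    return x ** 3 - theta


@jax.custom_jvp
def bisect_root(theta):
    # Plain bisection on [0, 2]; F is increasing in x on this bracket.
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        go_right = F(mid, theta) < 0
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    lo0 = jnp.zeros_like(theta)
    hi0 = jnp.full_like(theta, 2.0)
    lo, hi = jax.lax.fori_loop(0, 60, body, (lo0, hi0))
    return 0.5 * (lo + hi)


@bisect_root.defjvp
def bisect_root_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    x = bisect_root(theta)  # calling the custom_jvp function keeps higher orders working
    dF_dx = jax.grad(F, argnums=0)(x, theta)
    dF_dtheta = jax.grad(F, argnums=1)(x, theta)
    # Implicit function theorem: dx*/dtheta = -(dF/dtheta) / (dF/dx).
    return x, -dF_dtheta / dF_dx * theta_dot


first = jax.grad(bisect_root)(1.0)             # should be close to 1/3
second = jax.grad(jax.grad(bisect_root))(1.0)  # should be close to -2/9
```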
I benchmarked GMRES (currently the default in jax.eq_qp) against two other SciPy solvers, MINRES and LGMRES. I sampled random equality-constrained QP KKT matrices and right-hand sides, and found stark differences between the solvers. Primal dimension: 1500, dual dimension: 1000.
TL;DR: LGMRES is much faster, and GMRES is the slowest of the three. NB: MINRES requires the matrix of the linear system to be symmetric (though possibly indefinite), while GMRES and LGMRES don't.
I recall a discussion about whether linear solvers should live in the jaxopt package or in jax.scipy.sparse.linalg.
I don't have time right now to implement LGMRES in JAX (and would be grateful if anyone has the bandwidth and interest), but I will try to get around to it in the next few weeks. Should the solver's code live in jaxopt?
Colab for repro: https://colab.research.google.com/drive/1Ge1-gmuknDQq0rHpnSrHpvwLbU23i6YG?usp=sharing
When the projection set is contained in the non-negative orthant (simplex, Birkhoff polytope, ...), it makes sense to project in the KL sense. We should start a new module, jaxopt.kl_projection, for that.
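Here is a minimal sketch of what two entries of such a module could look like. The names are hypothetical, and nothing here is an existing jaxopt API. For the simplex, setting the gradient of the Lagrangian of the generalized KL divergence sum(p * log(p / y) - p + y) to zero gives p proportional to y, so the KL projection is a plain normalization. For the Birkhoff polytope, the KL projection of a positive matrix is the limit of Sinkhorn's alternating row and column normalizations.

```python
import jax
import jax.numpy as jnp


def generalized_kl(p, y):
    # sum(p * log(p / y) - p + y) for entrywise positive p and y.
    return jnp.sum(p * jnp.log(p / y) - p + y)


def kl_projection_simplex(y):
    # argmin over the simplex of generalized_kl(p, y) is y / sum(y).
    return y / jnp.sum(y)


def kl_projection_birkhoff(Y, num_iters=200):
    # Sinkhorn-Knopp: alternately normalize rows and columns of a positive matrix.
    def body(_, P):
        P = P / jnp.sum(P, axis=1, keepdims=True)
        return P / jnp.sum(P, axis=0, keepdims=True)
    return jax.lax.fori_loop(0, num_iters, body, Y)


y = jnp.array([1.0, 2.0, 3.0, 4.0])
p_simplex = kl_projection_simplex(y)

key_q, key_Y = jax.random.split(jax.random.PRNGKey(0))
q = jax.nn.softmax(jax.random.normal(key_q, (4,)))  # some other point of the simplex
P = kl_projection_birkhoff(jnp.exp(jax.random.normal(key_Y, (4, 4))))
```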