
jaxopt's Introduction

JAXopt


⚠️ We are in the process of merging JAXopt into Optax. Because of this, JAXopt is now in maintenance mode and we will not be implementing new features ⚠️

Hardware accelerated, batchable and differentiable optimizers in JAX.

  • Hardware accelerated: our implementations run on GPU and TPU, in addition to CPU.
  • Batchable: multiple instances of the same optimization problem can be automatically vectorized using JAX's vmap.
  • Differentiable: optimization problem solutions can be differentiated with respect to their inputs either implicitly or via autodiff of unrolled algorithm iterations.
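Here is a minimal pure-JAX sketch of the three properties. It does not use JAXopt's API. A fixed-iteration gradient-descent solver for ridge regression is batched over several regularization strengths with `jax.vmap`. Its solution is then differentiated with respect to the regularization strength by unrolling the iterations. The objective, step size and iteration count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def ridge_objective(w, l2reg, X, y):
    # Half mean squared error plus an L2 penalty.
    residuals = X @ w - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(w ** 2)


def solve_ridge(l2reg, X, y, stepsize=0.1, maxiter=500):
    # Plain gradient descent with a static iteration count. The loop is traced,
    # so reverse-mode autodiff can differentiate through it (unrolling).
    grad_fn = jax.grad(ridge_objective)

    def step(_, w):
        return w - stepsize * grad_fn(w, l2reg, X, y)

    return jax.lax.fori_loop(0, maxiter, step, jnp.zeros(X.shape[1]))


X = jax.random.normal(jax.random.PRNGKey(0), (50, 5))
y = jax.random.normal(jax.random.PRNGKey(1), (50,))

# Batchable: solve three problem instances at once.
l2regs = jnp.array([0.01, 0.1, 1.0])
W = jax.vmap(solve_ridge, in_axes=(0, None, None))(l2regs, X, y)

# Differentiable: sensitivity of the solution to the regularization strength.
dw_dl2reg = jax.jacobian(solve_ridge)(0.1, X, y)
```

Unrolling stores every iterate for the backward pass. Implicit differentiation, the other option listed above, avoids that cost.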

Installation

To install the latest release of JAXopt, use the following command:

$ pip install jaxopt

To install the development version, use the following command instead:

$ pip install git+https://github.com/google/jaxopt

Alternatively, it can be installed from source with the following command:

$ python setup.py install

Cite us

Our implicit differentiation framework is described in this paper. To cite it:

@article{jaxopt_implicit_diff,
  title={Efficient and Modular Implicit Differentiation},
  author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy 
    and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian 
    and Vert, Jean-Philippe},
  journal={arXiv preprint arXiv:2105.15183},
  year={2021}
}

Disclaimer

JAXopt is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

jaxopt's People

Contributors

algue-rythme, amelieh, amir-saadat, aymgal, ayush-1506, emilyfertig, fabianp, fllinares, froystig, geggo, geoffnn, gowerrobert, hawkinsp, ianwilliamson, liutianlin0121, marcocuturi, mblondel, michaelsdr, neilgirdhar, phinate, pipme, q-berthet, srvasude, superbobry, tachukao, vikas-sindhwani, vroulet, yashk2810, yueshengys, zaccharieramzi


jaxopt's Issues

`solver.init` -> `solver.init_state` is an undocumented breaking change

When running old code that used jaxopt==0.1, I found that it no longer runs due to the API change

params, state = solver.init(...) --> state = solver.init_state(...)

but I could not find this change in the release notes or the documentation. It would be nice to have it written down somewhere, in case people prototyped with early jaxopt releases.
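For anyone porting code written against 0.1, the migration is mechanical. The initial parameters are no longer returned; only the state is. Below is a minimal sketch of a hypothetical compatibility helper (not part of JAXopt). It assumes the solver exposes `init_state(init_params, *args, **kwargs)`.

```python
def init_like_v01(solver, init_params, *args, **kwargs):
    """Hypothetical helper reproducing the jaxopt==0.1 calling convention.

    Old style:      params, state = solver.init(init_params, ...)
    Current style:  state = solver.init_state(init_params, ...)
                    params, state = solver.update(params, state, ...)
    """
    state = solver.init_state(init_params, *args, **kwargs)
    # The 0.1 API also handed back the initial parameters unchanged.
    return init_params, state
```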

Keep track of number of function / gradient evaluations

Some methods, such as line-search based methods, require more function/gradient evaluations than others. It would be great to keep track of the number of such calls in the state, for instance with state.num_fun_calls and state.num_grad_calls. This would make it possible to plot the objective value as a function of these counts and therefore compare methods objectively.
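Here is a minimal sketch of the proposal. It assumes gradient descent with a backtracking (Armijo) line search. `GDState`, `num_fun_calls` and `num_grad_calls` are hypothetical names taken from the suggestion above, not existing JAXopt fields. The counters are stored as arrays in the state, so they stay correct under `jax.jit`; a Python-side counter would only count traces.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GDState(NamedTuple):
    stepsize: jnp.ndarray        # last accepted step size
    num_fun_calls: jnp.ndarray   # objective evaluations so far
    num_grad_calls: jnp.ndarray  # gradient evaluations so far


def init_state():
    zero = jnp.array(0, dtype=jnp.int32)
    return GDState(jnp.array(1.0, dtype=jnp.float32), zero, zero)


def update(fun, params, state, max_stepsize=1.0, decrease=0.5, max_backtracks=30):
    # One value-and-gradient evaluation per iteration.
    value, grad = jax.value_and_grad(fun)(params)
    sq_norm = jnp.sum(grad ** 2)

    def armijo_ok(t):
        # Sufficient-decrease test; each call costs one objective evaluation.
        return fun(params - t * grad) <= value - 0.5 * t * sq_norm

    def cond(carry):
        _, n_evals, accepted = carry
        return jnp.logical_and(jnp.logical_not(accepted), n_evals < max_backtracks)

    def body(carry):
        t, n_evals, _ = carry
        t = t * decrease
        return t, n_evals + 1, armijo_ok(t)

    t0 = jnp.array(max_stepsize, dtype=jnp.float32)
    carry = (t0, jnp.array(1, dtype=jnp.int32), armijo_ok(t0))
    t, n_ls, _ = jax.lax.while_loop(cond, body, carry)

    new_state = GDState(
        stepsize=t,
        num_fun_calls=state.num_fun_calls + 1 + n_ls,
        num_grad_calls=state.num_grad_calls + 1,
    )
    return params - t * grad, new_state
```

Record `state.num_fun_calls` next to the objective value at each iteration. That gives the x-axis for the comparison plot.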

Incompatible shape in solve_normal_cg

When A.shape = (N, P) for N != P, I run into shape errors when trying to use solve_normal_cg for fitting the normal equations.

I have a small reproducible example below for N > P, but the error also occurs when P > N.

import jax.numpy as jnp
import numpy as np
N = 1000
P = 3
prob = np.random.uniform(0.01, 0.5, size=P)
h2g = 0.1
X = np.random.binomial(2, p=prob, size=(N, P))
b = np.random.normal(size=(P)) * np.sqrt(h2g / P)
y = X @ b + np.sqrt(1 - h2g) * np.random.normal(size=(N,))

import jaxopt as jopt
jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [11], in <module>
----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)

File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:151, in solve_normal_cg(matvec, b, ridge, init, **kwargs)
    148 if ridge is not None:
    149   _matvec = _make_ridge_matvec(_matvec, ridge=ridge)
--> 151 Ab = _rmatvec(matvec, b)
    153 return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]

File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:114, in _rmatvec(matvec, x)
    112 def _rmatvec(matvec, x):
    113   """Computes A^T x, from matvec(x) = A x, where A is square."""
--> 114   transpose = jax.linear_transpose(matvec, x)
    115   return transpose(x)[0]

File ~/miniconda3/lib/python3.9/site-packages/jax/_src/api.py:2211, in linear_transpose(fun, reduce_axes, *primals)
   2208 in_dtypes = map(dtypes.dtype, in_avals)
   2210 in_pvals = map(pe.PartialVal.unknown, in_avals)
-> 2211 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
   2212                                              instantiate=True)
   2213 out_avals, _ = unzip2(out_pvals)
   2214 out_dtypes = map(dtypes.dtype, out_avals)

File ~/miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:505, in trace_to_jaxpr(fun, pvals, instantiate)
    503 with core.new_main(JaxprTrace) as main:
    504   fun = trace_to_subjaxpr(fun, main, instantiate)
--> 505   jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    506   assert not env
    507   del main, fun, env

File ~/miniconda3/lib/python3.9/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
    163 gen = gen_static_args = out_store = None
    165 try:
--> 166   ans = self.f(*args, **dict(self.params, **kwargs))
    167 except:
    168   # Some transformations yield from inside context managers, so we have to
    169   # interrupt them before reraising the exception. Otherwise they will only
    170   # get garbage-collected at some later time, running their cleanup tasks only
    171   # after this exception is handled, which can corrupt the global state.
    172   while stack:

Input In [11], in <lambda>(x)
----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)

File ~/miniconda3/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4196, in dot(a, b, precision)
   4194   return lax.mul(a, b)
   4195 if _max(a_ndim, b_ndim) <= 2:
-> 4196   return lax.dot(a, b, precision=precision)
   4198 if b_ndim == 1:
   4199   contract_dims = ((a_ndim - 1,), (0,))

File ~/miniconda3/lib/python3.9/site-packages/jax/_src/lax/lax.py:667, in dot(lhs, rhs, precision, preferred_element_type)
    664   return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
    665                      precision=precision, preferred_element_type=preferred_element_type)
    666 else:
--> 667   raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
    668       lhs.shape, rhs.shape))

TypeError: Incompatible shapes for dot: got (1000, 3) and (1000,).
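The traceback points at `_rmatvec`, whose docstring assumes a square `A`. It transposes `matvec` using `b` (shape `(N,)`) as the example input, while `matvec` expects shape `(P,)`. Below is a minimal sketch of a workaround, assuming the caller knows the shape of `x`. `solve_normal_cg_rect` and its `x_shape` argument are hypothetical, not JAXopt API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def solve_normal_cg_rect(matvec, b, x_shape, ridge=None, **cg_kwargs):
    # Example input of the *domain* shape (P,), used only for its shape/dtype,
    # so that jax.linear_transpose builds A^T: R^N -> R^P.
    x_example = jnp.zeros(x_shape, dtype=b.dtype)
    transpose = jax.linear_transpose(matvec, x_example)

    def rmatvec(y):
        return transpose(y)[0]

    def normal_matvec(x):
        # (A^T A + ridge * I) x
        out = rmatvec(matvec(x))
        return out if ridge is None else out + ridge * x

    return cg(normal_matvec, rmatvec(b), **cg_kwargs)[0]
```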

SQP

Hi, are there any plans to implement SQP with OSQP iterations? Regards,

trust-region optimization

Hi all,

Thanks for the awesome library. It would be fantastic to eventually see some options for trust-region optimization. For low-dimensional problems in my field (e.g., variance components or their generalizations), it is a much more robust approach for inference than other second-order methods.

Ridge Regularization in linear solve: consistency

It will be easier to discuss this on GitHub than in an internal Google doc.

The issue

Current state of API:

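| Function        | Without ridge | With ridge r > 0 | Remark                         |
|-----------------|---------------|------------------|--------------------------------|
| solve_cg        | Ax=b          | (A+rI)x=b        | well posed because A is PSD    |
| solve_gmres     | Ax=b          | (A+rI)x=b        | ill-posed if A=-rI             |
| solve_bicgstab  | Ax=b          | (A+rI)x=b        | ill-posed if A=-rI             |
| solve_normal_cg | A^TAx=A^Tb    | (A^TA+rI)x=A^Tb  | well posed because A^TA is PSD |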

There are consistency issues here: with ridge regularization, we would expect (A^T+rI)(A+rI)x=(A^T+rI)b for solve_normal_cg. As it stands, solve_cg, solve_gmres and solve_bicgstab are interchangeable with each other when r > 0, but not with solve_normal_cg. Worse: when r=0 they are all interchangeable with each other (at least for PD matrices).

Discussion:

Tikhonov regularization regularizes with A^TA+rI, just like solve_cg. This guarantees a well-posed problem.

Another observation: most scikit-learn solvers for ridge regression use the A^TA+rI trick.
Nobody uses A+rI on a general matrix A: in general, it only makes sense for a PSD matrix.

Solution

Two solutions:

  1. Change (A^TA+rI)x=A^Tb into (A^T+rI)(A+rI)x=(A^T+rI)b, but in this case the problem is ill-posed for A=-rI.
  2. Remove ridge regularization from gmres/bicgstab, because currently the regularization may produce a matrix with a worse condition number (unexpected behavior when regularizing a system). This happens in particular for NSD matrices.

I am in favor of the second option, to remain consistent with the literature, unless we can show that the A+rI approach makes sense.
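Here is a small numerical illustration of the argument. The indefinite 2×2 `A` and `r = 0.9` are arbitrary choices made for illustration. Shifting by rI can push an eigenvalue toward zero. Tikhonov regularization of the PSD normal matrix can only raise its smallest eigenvalue.

```python
import jax.numpy as jnp


def cond(M):
    # 2-norm condition number from the singular values.
    s = jnp.linalg.svd(M, compute_uv=False)
    return s.max() / s.min()


A = jnp.diag(jnp.array([-1.0, 2.0]))  # symmetric but indefinite
r = 0.9
eye = jnp.eye(2)

shifted = A + r * eye          # ridge as in solve_gmres / solve_bicgstab
tikhonov = A.T @ A + r * eye   # ridge as in solve_normal_cg

# cond(A) = 2, cond(shifted) ≈ 29: the shift made conditioning worse.
# cond(A^T A) = 4, cond(tikhonov) ≈ 2.58: Tikhonov improved it.
```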

Using refine_regularization=0. breaks jit compile in EqualityConstrainedQP

Traceback (most recent call last):
  File "/pscratch/sd/g/gnegiar/neural-dict-pinns/src/main.py", line 291, in <module>
    ), d_params = value_and_grad(
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/flax/linen/transforms.py", line 310, in wrapped_fn
    return trafo_fn(module_scopes, *args, **kwargs)
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/flax/core/lift.py", line 201, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/flax/core/lift.py", line 939, in inner
    return jitted(mutable, variable_groups, rng_groups, *args)
  File "/pscratch/sd/g/gnegiar/neural-dict-pinns/src/main.py", line 220, in forward
    lam_star, dual_star = solve_QP(
  File "/pscratch/sd/g/gnegiar/neural-dict-pinns/src/qp_layer.py", line 49, in solve_QP_jaxopt
    sol_pytree = qp_layer.run(x0, (Q, c), (A, b)).params
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/implicit_diff.py", line 252, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/implicit_diff.py", line 206, in solver_fun_fwd
    res = solver_fun(*args, **kwargs)
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/eq_qp.py", line 198, in run
    primal, dual_eq = self.solve(matvec, target, init_params,
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py", line 176, in solve_gmres
    return jax.scipy.sparse.linalg.gmres(matvec, b, tol=tol, **kwargs)[0]
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jax/_src/scipy/sparse/linalg.py", line 686, in gmres
    x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py", line 34, in ridge_matvec
    return tree_add_scalar_mul(matvec(v), ridge, v)
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/tree_util.py", line 46, in tree_add_scalar_mul
    return tree_multimap(lambda x, y: x + scalar * y, tree_x, tree_y)
  File "/global/homes/g/gnegiar/.conda/envs/neural-dict/lib/python3.9/site-packages/jaxopt/_src/tree_util.py", line 46, in <lambda>
    return tree_multimap(lambda x, y: x + scalar * y, tree_x, tree_y)
jax._src.errors.TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(float64[1001])>with<DynamicJaxprTrace(level=3/2)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

Add binary_sparsemax_loss and projection_unit_cube

We should add binary_sparsemax_loss, the binary classification counterpart of multiclass_sparsemax_loss; see section 4.4 of https://arxiv.org/abs/1901.02324. This loss is also known as the modified Huber loss. The associated mapping is the projection onto the unit cube, so let's add projection_unit_cube, which is easy to implement using projection_box. Once done, both need to be added to the documentation.
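Here is a minimal sketch under one possible parametrization. This is our assumption, not necessarily how JAXopt would define it. Take labels y ∈ {0, 1} and Ω(μ) = μ²/2 restricted to [0, 1]. The Fenchel-Young loss is then Ω*(θ) + Ω(y) − θy. Its gradient in θ is projection_unit_cube(θ) − y. It also equals the modified Huber loss of the margin 2θ − 1, divided by 8. Both function names are hypothetical. The unit-cube projection is simply a box projection with bounds 0 and 1.

```python
import jax
import jax.numpy as jnp


def projection_unit_cube(x):
    # Euclidean projection onto [0, 1]^d: a box projection with lower=0, upper=1.
    return jnp.clip(x, 0.0, 1.0)


def binary_sparsemax_loss(label, logit):
    # Fenchel-Young loss generated by Omega(mu) = mu**2 / 2 on [0, 1]:
    #   Omega*(theta) = max_{mu in [0,1]} theta * mu - mu**2 / 2,
    # attained at mu = projection_unit_cube(theta).
    mu = projection_unit_cube(logit)
    conjugate = logit * mu - 0.5 * mu ** 2
    return conjugate + 0.5 * label ** 2 - logit * label


def modified_huber(margin):
    # Reference: 0 if z >= 1, (1 - z)**2 if -1 <= z < 1, -4z otherwise.
    return jnp.where(margin >= 1, 0.0,
                     jnp.where(margin >= -1, (1 - margin) ** 2, -4 * margin))
```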

Publication to PyPI

Hi all, thanks so much for the work on this project! It's going to have some awesome applications for sure.

I'm currently writing a library on differentiable operations for high-energy physics (here), and wanted to include jaxopt as a dependency.

PyPI doesn't allow arbitrary repositories as dependencies, so I wanted to ask whether you are considering publishing the project on PyPI. :)

Mini Batch with OptaxSolver

Discussed in #150

Originally posted by jecampagne January 14, 2022
Hello,
Let me give you a snippet

import jax
import jax.numpy as jnp
import jaxopt

# jax.config.update replaces the older `from jax.config import config` import.
jax.config.update("jax_enable_x64", True)

import numpy as np
#######
def ridgeless_reg_objective(params, X, y):
  residuals = jnp.dot(X, params) - y
  return jnp.mean(residuals ** 2)

def gen_x_unif_sphere(r=1.0, d=20,ns=50, seed=42):
    """ Generate ns vectors uniform on the sphere of radius r in d-dimension
        d=20.
        data = gen_x_unif_sphere(r=np.sqrt(d), d=d)
        jnp.allclose(jnp.sum(data*data,axis=1),jnp.array([d]*data.shape[0]))
        
    """
    key = jax.random.PRNGKey(seed)
    x = jax.random.normal(key, shape=(ns,d))
    norm = jnp.linalg.norm(x, axis=1)
    x_normed = r * x / norm.reshape(x.shape[0],1)
    return x_normed

def func(x):
    def p1(x):
        return x
    def p2(x):
        return (x*x - 1.)/jnp.sqrt(2.)
    def p3(x):
        return x*(x*x-3.)/jnp.sqrt(6.)
    def p4(x):
        return (x*x*x*x - 6. * x*x +3.)/(2.*jnp.sqrt(6.))
    c1 = jnp.sqrt(2./5.)
    c2 = jnp.sqrt(1./5.)
    return c1*(p1(x)+p2(x))+c2*p3(x)


def gen_y(X, beta):
    Xbeta = X @ beta.T  # <beta, Xi>
    y = func(Xbeta)
    return y

def rho_prime(x):
    """
        rho(x)=ReLU(x) => rho'(x)=Heaviside(x), with ReLU'(0)=0 by convention.
    """
    return jnp.heaviside(x, 0.)

def gen_Phi(X, W):
    N, d = W.shape
    Nd = N*d
    ns = X.shape[0]
    XW = X @ W.T
    rhoXW = rho_prime(XW)/jnp.sqrt(Nd)
    return jnp.tile(X,reps=N) * jnp.repeat(rhoXW,repeats=d, axis=1)

##########
d = 15
# gamma = ln(n)/ln(d)
# psi   = ln(Nd)/ln(d)
gamma = 3.0
ns = int(d**gamma)
psi = 4.0
N = int(d**(psi-1))
print(f"gamma:{gamma}, d:{d}, ns:{ns}, N:{N}, Nd:{N*d}")

X = gen_x_unif_sphere(r=np.sqrt(d), d=d, ns=ns)
print("X.shape",X.shape)

key = jax.random.PRNGKey(70)
beta = jax.random.normal(key, shape=(1,d))
norm = jnp.linalg.norm(beta, axis=1)
beta =  beta / norm


Y = gen_y(X, beta)
print("Y.shape",Y.shape)

W = gen_x_unif_sphere(r=1,d=d,ns=N,seed=60)
print("W.shape",W.shape)


Phi= gen_Phi(X,W)
print("Phi.shape",Phi.shape)

#######

gives

gamma:3.0, d:15, ns:3375, N:3375, Nd:50625
X.shape (3375, 15)
Y.shape (3375, 1)
W.shape (3375, 15)
Phi.shape (3375, 50625)

Now fit the regression Y = Phi theta using Adam:

import optax
from jaxopt import OptaxSolver
init_theta = jnp.zeros((Phi.shape[1], Y.shape[1]))  # theta has shape (Nd, 1)
solver = OptaxSolver(opt=optax.adam(1e-2), fun=ridgeless_reg_objective, maxiter=1000)
theta_adam = solver.run(init_theta, X=Phi, y=Y)

Of course, one could instead compute the least-squares minimum-norm solution theta_lsq, although even here the matrices are already sizeable. With that LSQ solution, I compute the training MSE as

diff_lsq = Phi @ theta_lsq - Y
MSE_lsq = (diff_lsq.T @ diff_lsq)/diff_lsq.shape[0]

which is equal to DeviceArray([[5.98907097e-28]]) and is fine, as Nd >> ns.

With Adam, on the other hand,

diff_adam = Phi @ theta_adam.params - Y
MSE_adam = (diff_adam.T @ diff_adam)/diff_adam.shape[0]

gives DeviceArray([[0.01393647]]), which is far worse.

I wonder whether there are some batching options or parameters to tune to get better results. Thanks.
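For the baseline, when ns ≪ Nd the least-squares minimum-norm solution can be computed from an ns × ns system instead of the Nd × Nd normal equations: θ = Φᵀ(ΦΦᵀ)⁻¹Y. This assumes ΦΦᵀ is invertible, i.e. Φ has full row rank. Below is a minimal sketch with small random stand-ins for Φ and Y; the shapes are assumptions chosen to run quickly. Plain gradient descent started from zero on this underdetermined problem converges to the same min-norm solution. Adaptive methods such as Adam are not known to share this property in general.

```python
import jax
import jax.numpy as jnp


def min_norm_lstsq(Phi, Y):
    # Solve the small (ns x ns) system (Phi Phi^T) alpha = Y, then map back.
    # theta = Phi^T alpha lies in the row space of Phi, so it has minimum norm.
    gram = Phi @ Phi.T
    alpha = jnp.linalg.solve(gram, Y)
    return Phi.T @ alpha


ns, Nd = 20, 200  # small stand-ins for the shapes in the issue (ns << Nd)
Phi = jax.random.normal(jax.random.PRNGKey(0), (ns, Nd))
Y = jax.random.normal(jax.random.PRNGKey(1), (ns, 1))
theta_lsq = min_norm_lstsq(Phi, Y)
```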

Linear solver benchmarking on EqQPs

I benchmarked GMRES (currently the default in jaxopt's eq_qp) against other scipy solvers: MINRES and LGMRES. I sampled random equality-constrained QP KKT matrices and targets, and found pretty stark differences between the solvers. Primal dim: 1500, dual dim: 1000.

TL;DR: LGMRES is much faster, and GMRES is the slowest of the three. NB: MINRES requires the matrix of the linear system to be symmetric (though possibly indefinite), while GMRES and LGMRES don't.

I recall a discussion about linear solvers living in the jaxopt package vs jax.scipy.sparse.linalg.
I don't have time right now to implement LGMRES in JAX (and would be grateful if anyone has the bandwidth and the interest), but I will try to get around to it in the next few weeks. Should the solver's code live in jaxopt?

Colab for repro: https://colab.research.google.com/drive/1Ge1-gmuknDQq0rHpnSrHpvwLbU23i6YG?usp=sharing
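Here is a small sketch of why MINRES applies to these systems at all. An equality-constrained QP's KKT matrix [[Q, Aᵀ], [A, 0]] is symmetric but indefinite. The random positive definite Q and full-row-rank A below are assumptions. The snippet checks both properties and solves the system with `jax.scipy.sparse.linalg.gmres`, the solver used as the baseline above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

n, m = 6, 2  # primal and dual dimensions (tiny, for illustration)
kq, ka, kc, kb = jax.random.split(jax.random.PRNGKey(0), 4)
M = jax.random.normal(kq, (n, n))
Q = M @ M.T + jnp.eye(n)           # symmetric positive definite
A = jax.random.normal(ka, (m, n))  # full row rank with probability 1
c = jax.random.normal(kc, (n,))
b = jax.random.normal(kb, (m,))

# KKT system for  min 0.5 x^T Q x + c^T x  s.t.  A x = b
K = jnp.block([[Q, A.T], [A, jnp.zeros((m, m))]])
rhs = jnp.concatenate([-c, b])

eigs = jnp.linalg.eigvalsh(K)  # K is symmetric, so eigvalsh applies
z, _ = gmres(lambda v: K @ v, rhs)
x, nu = z[:n], z[n:]
```

With Q positive definite and A of full row rank, K has n positive and m negative eigenvalues. That rules out CG but not MINRES.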


KKT conditions when the primal solution is a pytree

Hi,
Congrats on the great tool!
Inspired by the QuadraticProgramming example, I wrote code that differentiates through KKT conditions. It works whenever the primal solution variable is a jnp array, but not when it's a generic pytree, in which case I get the following error:

TypeError: Tree structure of cotangent input PyTreeDef(([(*, *), (), (*, *)], *, None)), does not match structure of primal output PyTreeDef(([(*, *), (), (*, *), (*, *), (), (*, *)], *, None))

where I'm pretty sure [(*, *), (), (*, *)] represents the primal solution and PyTreeDef(([(*, *), (), (*, *)], *, None)) could represent the optimality function.

I was able to make it work by storing the primal solution in a single jnp array and reshaping it into the appropriate pytree whenever needed, but it's not clean or efficient. I was wondering if there's a bug in the current codebase (I only found tests for single jnp arrays) or I'm misusing the interface (I'm not a jax expert).

To make it easier to reproduce, I modified the quadratic_prog.py file so that the model returns a list of one array instead of an array for the primal variables (leaving both dual variables the same). Then I modified obj_fun, eq_fun and ineq_fun to use primal_var[0] instead of primal_var. If I understand correctly, this should still work. However, it doesn't: this test line fails an assertion on an array that should be all zeros but instead is:
([DeviceArray([ 0.43999994, -1.3199999 ], dtype=float32), DeviceArray([-0.44000003, 1.32 ], dtype=float32)], DeviceArray([2.9802322e-08], dtype=float32), None)

Looking at the numbers of the problem I believe [0.44,-1.32] is the gradient of the obj_fun w.r.t. the primal and [-0.44,+1.32] the gradient of the equality constraint w.r.t. the primal times the dual. They should have been added up together to have [0,0] as expected. I feel this may be fundamentally the same problem I was facing in my own research code since there I also found one of the values had the shape of the primal variable twice instead of once.

Notice also that the test on the line just above (checking that the primal solution is correct) still holds provided we check sol[0][0] instead of sol[0] (since sol[0] is now a 1-element list).

Is differentiation through KKT supposed to work for general pytrees? If so, what should I have done to make it work in the quadratic_prog.py example?

Thanks!
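Here is a minimal pure-JAX sketch of the behavior the test expects. It assumes the small equality-constrained QP with Q = 2·[[2, 0.5], [0.5, 1]], c = [1, 1], A = [[1, 1]], b = [1]; its solution x = [0.25, 0.75], ν = −2.75 can be checked by hand. With a list-valued primal, `jax.grad` of the Lagrangian returns a pytree with the same structure as the primal. The objective and constraint contributions are therefore summed leaf by leaf, whereas the reported output keeps them as separate entries. `kkt_residual` is a hypothetical helper, not JAXopt's implementation.

```python
import jax
import jax.numpy as jnp

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])


def obj_fun(primal):
    x = primal[0]  # the primal is a list holding a single array
    return 0.5 * x @ Q @ x + c @ x


def eq_fun(primal):
    return A @ primal[0] - b


def kkt_residual(primal, dual_eq):
    def lagrangian(p):
        return obj_fun(p) + dual_eq @ eq_fun(p)
    # Same pytree structure as `primal`: gradient terms are added per leaf.
    stationarity = jax.grad(lagrangian)(primal)
    return stationarity, eq_fun(primal)


primal = [jnp.array([0.25, 0.75])]
dual_eq = jnp.array([-2.75])
stationarity, feasibility = kkt_residual(primal, dual_eq)
```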

Avoid state.aux=None in the state returned by init_state

Currently, when has_aux=True, state.aux is None when state is returned by init_state and state.aux is equal to fun(params, *args, **kwargs)[1] when state is returned by update. This is problematic as it can trigger a jit recompilation. One way would be to set state.aux to some dummy values of the correct type when returned by init_state.
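Here is one way to implement the suggestion, sketched in plain JAX. `jax.eval_shape` traces `fun` abstractly, without running any FLOPs, and returns the shapes and dtypes of its outputs. Zero-filled placeholders with the right structure can be built from them. `init_aux` and the toy `fun` are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp


def init_aux(fun, params, *args, **kwargs):
    # Abstractly evaluate fun -> (value, aux) to get shapes/dtypes only.
    _, aux_shapes = jax.eval_shape(fun, params, *args, **kwargs)
    # Zeros with the same pytree structure, shapes and dtypes as the aux
    # that update() will later store, so jit sees a consistent state type.
    return jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), aux_shapes)


def fun(w, X, y):
    residuals = X @ w - y
    loss = jnp.mean(residuals ** 2)
    return loss, {"residuals": residuals, "max_abs": jnp.max(jnp.abs(residuals))}
```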

pre_update for OptaxSolver run function?

Hi,
Defining a pre_update function works well in OptaxSolver when using the run_iterator function (like in the MNIST example). However, it does not work with the run function. I checked the implementation and indeed there is no call to the pre_update inside 'run' but there is one inside 'run_iterator'. Is this a small bug or is this by design?

Thanks!

Reusing forward computation for implicit diff

With some solvers (e.g. Newton's method), it is possible to reuse some of the computations for solving the implicit differentiation linear system more efficiently. Since @custom_root and @custom_fixed_point accept a solve argument for specifying a linear solver, the output of the forward pass solver could be given as an argument to solve.
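Here is a minimal sketch of the idea for scalar Newton root finding. It is written directly with `jax.custom_vjp` rather than JAXopt's `@custom_root` and `solve` hooks, whose plumbing would differ. The forward pass keeps the derivative ∂f/∂x from its last Newton step, and the backward pass reuses it instead of recomputing it. The function `f`, the starting point and the iteration count are assumptions for illustration. After convergence, the stored derivative matches the one at the root.

```python
import jax
import jax.numpy as jnp


def f(x, theta):
    # Optimality condition f(x, theta) = 0; its root is x*(theta).
    return x ** 3 + x - theta


dfdx = jax.grad(f, argnums=0)
dfdtheta = jax.grad(f, argnums=1)


def newton(theta, num_iters=20):
    x = jnp.zeros_like(theta)
    slope = jnp.ones_like(theta)
    for _ in range(num_iters):
        slope = dfdx(x, theta)  # computed anyway by Newton's method
        x = x - f(x, theta) / slope
    return x, slope


@jax.custom_vjp
def root(theta):
    return newton(theta)[0]


def root_fwd(theta):
    x, slope = newton(theta)
    return x, (x, slope, theta)  # keep the forward-pass derivative as a residual


def root_bwd(res, g):
    x, slope, theta = res
    # Implicit function theorem: dx*/dtheta = -(df/dx)^{-1} df/dtheta,
    # reusing `slope` from the forward pass instead of re-evaluating df/dx.
    return (-g * dfdtheta(x, theta) / slope,)


root.defvjp(root_fwd, root_bwd)
```

For a vector-valued problem, the stored quantity would be a factorization of the Jacobian rather than a scalar slope.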

name 'signature' is not defined

I get the error "JaxStackTraceBeforeTransformation: NameError: name 'signature' is not defined" with the following example when I set implicit_diff=True; it works when I set implicit_diff=False.

import jaxopt
import jax 
import jax.numpy as jnp 
import matplotlib.pyplot as plt 
jax.config.update("jax_enable_x64", True)

X = jax.random.normal(jax.random.PRNGKey(1), (100, 10))
y =  jax.random.normal(jax.random.PRNGKey(2), (100, 1))

def ridge_reg_objective(params, l2reg, X, y):
  residuals = jnp.dot(X, params) - y
  return jnp.mean(residuals ** 2) + l2reg*jnp.linalg.norm(params)

def ridge_reg_solution(l2reg, X, y):
  gd = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500, implicit_diff=False)  # True triggers the NameError
  return gd.run(init_params, l2reg=l2reg, X=X, y=y).params

l2reg = 0.01
init_params = jax.random.normal(jax.random.PRNGKey(0), shape=(X.shape[1],))
print(jax.jacobian(ridge_reg_solution, argnums=0)(l2reg, X, y))

Complex gradients

Opening this as a nota bene. When optimizing over complex parameters, the gradient must be conjugated. Because of this, all jaxopt optimizers would currently be incorrect on complex parameters.

Moreover, any optimizer that relies on second-order moments (e.g., Adam) must also use the squared complex modulus instead of just the square. Current jaxopt solvers might be affected as well. I'm unsure what implicit diff would do with complex parameters, but perhaps we could output a warning that it is probably incorrect for now.

I realized this while using Optax on a model with complex weights, and thought it might be good to incorporate this in jaxopt solvers as well, as users might 1) not be aware of it, and 2) find it really hard to debug on their side.
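Here is a small illustration of the convention in plain JAX; the objective and step size are assumptions. For a real-valued `f` of a complex `z`, `jax.grad` returns ∂f/∂x − i ∂f/∂y. For example, it gives 6 − 8j for f(z) = |z|² at z = 3 + 4j, so a descent step has to conjugate it.

```python
import jax
import jax.numpy as jnp


def sq_abs(z):
    return jnp.real(z) ** 2 + jnp.imag(z) ** 2


target = 1.0 + 2.0j


def loss(z):
    return sq_abs(z - target)  # real-valued function of a complex parameter


z = jnp.array(0.0 + 0.0j)
for _ in range(50):
    g = jax.grad(loss)(z)
    # Descend along conj(g). With JAX's convention, z - lr * g would move
    # the imaginary part away from the minimizer.
    z = z - 0.25 * jnp.conj(g)

# Adam-like methods: accumulate jnp.abs(g) ** 2 (real, >= 0), not g ** 2.
```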

"ImportError: Cannot import name 'proximal_gradient' from jaxopt" when running python examples/sparse_coding.py

Hi,
Thank you for the great work.

I am trying to run python examples/sparse_coding.py, but I am getting an ImportError:

python examples/sparse_coding.py
Traceback (most recent call last):
File "examples/sparse_coding.py", line 30, in <module>
from jaxopt import proximal_gradient
ImportError: cannot import name 'proximal_gradient' from 'jaxopt' (/home/gem/repos/jaxopt/env/lib/python3.8/site-packages/jaxopt-0.0.1-py3.8.egg/jaxopt/__init__.py)

Initial stepsize not exposed in LBFGS constructor [question/bug?]

I see that LbfgsState contains a stepsize and that LBFGS.init_state hard-codes it to 1. I also see that the LBFGS.update method performs a line search in which the initial step size is set from this LBFGS state.

I have a particularly ill-conditioned problem that requires tiny initial steps, but I was surprised that the initial stepsize could not be set in the LBFGS constructor or elsewhere as far as I could see. Is this an oversight or an intentional part of the design? If it's intentional, is there an idiomatic way to set an initial stepsize when using LBFGS.run that I have overlooked?

Thanks in advance, and thanks for a really cool library.

OptaxSolver how to proceed?

Hi,
Here is a use case:

import jax
import jax.numpy as jnp

# Volume of a box
def vol(x): 
    return x[0]*x[1]*x[2]
# Surface of the box
def surf(x):
    return 2.*(x[0]*x[1]+x[0]*x[2]+x[1]*x[2])
# Constraint on total surface
def g(x): return surf(x) - 24

# Lagrangian: p[0:3] = (x1, x2, x3), p[3] = Lagrange multiplier
@jax.jit
def Lag(p): 
    return vol(p[0:3]) - p[3]*g(p[0:3])
@jax.jit
def neg_Lag(p):
    return -Lag(p)

I can solve this Lagrangian-based optimization problem by hand, with damped Newton steps, like this:

# Gradient and Hessian of the Lagrangian
gLag = jax.jacfwd(Lag)
hLag = jax.hessian(Lag)

def solveLagrangian(p,lr=0.1): 
    # Damped Newton step; solve() avoids forming the inverse Hessian explicitly.
    return p - lr*jnp.linalg.solve(hLag(p), gLag(p))

p_cur = jnp.array([1.5,0.5,1.0,0.1])

for t in range(200):

    if (t % 10) == 0:
        print(t, p_cur, Lag(p_cur))

    new_p = solveLagrangian(p_cur)
    
    rel_err = jnp.max(jnp.abs(p_cur - new_p))
    if rel_err < 1e-6:
        print(f"Converged after {t} epochs")
        break
    
    p_cur = new_p

p_fin=p_cur
v_fin = vol(p_fin[0:3])
s_fin = surf(p_fin[0:3])

print("p_fin: ",p_fin,": True x=y=z=2, lambda=0.5" )
print("v_fin: ",v_fin,": True vol  = 2^3")
print("s_fin: ",s_fin,": True surf = 24")

I get

0 [1.5 0.5 1.  0.1] 2.6
10 [1.83358314 1.55167781 1.77716864 0.40679326] 7.609872257244211
20 [1.94433372 1.84887174 1.92882398 0.46981662] 7.95680842439189
30 [1.98087313 1.94785567 1.97583383 0.48971189] 7.9948966103715815
40 [1.99336478 1.98188252 1.99164835 0.49643995] 7.999385438210385
50 [1.99769054 1.99369051 1.99709684 0.49876193] 7.999925528308225
60 [1.99919524 1.99780094 1.9989888  0.4995687 ] 7.999990956279981
70 [1.99971946 1.99923335 1.99964755 0.49984966] 7.9999989009303
80 [1.99990219 1.9997327  1.99987712 0.49994759] 7.99999986639723
90 [1.9999659  1.9999068  1.99995716 0.49998173] 7.999999983757804
100 [1.99998811 1.9999675  1.99998506 0.49999363] 7.999999998025361
110 [1.99999585 1.99998867 1.99999479 0.49999778] 7.999999999759932
Converged after 112 epochs
p_fin:  [1.99999664 1.99999082 1.99999578 0.4999982 ] : True x=y=z=2, lambda=0.5
v_fin:  7.9999329788029625 : True vol  = 2^3
s_fin:  23.999865957438498 : True surf = 24

Okay, now is it possible to get the same result with OptaxSolver?

import optax
import jaxopt

opt = optax.adagrad(0.01)
solver = jaxopt.OptaxSolver(opt=opt, fun=neg_Lag, maxiter=2000)
params = jnp.array([1.5,0.5,1.0,0.1])
state = solver.init_state(params)  # older releases: params, state = solver.init(params)
print('init', params, neg_Lag(params))
for i in range(2000):
    params, state = solver.update(params=params, state=state)
    if i%100 == 0: 
        print(i, params, neg_Lag(params))

Here I get:

init [1.5 0.5 1.  0.1] -2.6
0 [1.50534522 0.50953463 1.00741998 0.10999854] -2.797381684399013
100 [1.42831204 0.63479103 1.0126782  0.28539948] -6.057683328070636
200 [1.28418844 0.59294737 0.86638545 0.37066902] -7.7856194581887825
300 [1.18203126 0.5047476  0.75689433 0.43842879] -9.33122178685714
400 [1.1023296  0.4200589  0.67028698 0.49659404] -10.755253639874878
500 [1.03695639 0.34601625 0.598102   0.54828646] -12.072981877227912
600 [0.98179502 0.28104647 0.53587287 0.59520355] -13.298704241023824
700 [0.93451151 0.22319941 0.48100081 0.63840745] -14.444779039627818
800 [0.89367761 0.17099112 0.43183736 0.67861745] -15.521400626325187
900 [0.85837707 0.12333851 0.38726128 0.71634778] -16.536982715800075
1000 [0.82800999 0.07943679 0.34646534 0.75198133] -17.49857403108653
1100 [0.80218737 0.03866963 0.30883957 0.78581251] -18.412193035926062
1200 [7.80670563e-01 5.49785239e-04 2.73902885e-01 8.18073632e-01] -19.28308214582214
1300 [ 0.76333447 -0.0353201   0.24126048  0.84895208] -20.115900881866448
1400 [ 0.75014358 -0.06927386  0.21057603  0.87860182] -20.914875993347422
1500 [ 0.7411345  -0.10159995  0.18155273  0.90715143] -21.683921834137884
1600 [ 0.7364005  -0.13255408  0.15392011  0.93470985] -22.426739725247586
1700 [ 0.73607461 -0.16236803  0.12742484  0.96137061] -23.146900944390644
1800 [ 0.74030784 -0.19125565  0.10182456  0.987215  ] -23.847914006566935
1900 [ 0.74924029 -0.21941633  0.07688466  1.01231443] -24.533272887793775

which is clearly not the right way to go.
(NB: using Lag as the objective does not change anything: no convergence; same with sgd/adam...)

Is there a solution to get Optax solver working?
Thanks
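A quick check explains the divergence. At the analytic solution p* = (2, 2, 2, 0.5), the Hessian of the Lagrangian is indefinite. So p* is a saddle point, not a minimum of `neg_Lag` (or of `Lag`), and both are unbounded below. Newton's method in the manual loop converges because it looks for any stationary point ∇Lag(p) = 0. OptaxSolver, by contrast, minimizes its objective. The snippet re-declares `Lag` (without `jax.jit`) so that it runs on its own.

```python
import jax
import jax.numpy as jnp


def vol(x):
    return x[0] * x[1] * x[2]


def surf(x):
    return 2.0 * (x[0] * x[1] + x[0] * x[2] + x[1] * x[2])


def Lag(p):
    # p[0:3] = box sides, p[3] = Lagrange multiplier
    return vol(p[0:3]) - p[3] * (surf(p[0:3]) - 24.0)


p_star = jnp.array([2.0, 2.0, 2.0, 0.5])
grad_norm = jnp.linalg.norm(jax.grad(Lag)(p_star))   # 0: p_star is stationary
eigs = jnp.linalg.eigvalsh(jax.hessian(Lag)(p_star))  # mixed signs: a saddle
```

A common alternative is to minimize ½‖∇Lag(p)‖². Its global minimizers are exactly the stationary points. However, local minima with a nonzero gradient can still trap a first-order method.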

Issue with gradients wrt optimality fn parameters through root finding vjp

First of all, thanks a lot for this library! Really useful tools!
I'm interested in computing at least second-order gradients through root finding, and I'm seeing some odd behavior that I wanted to report.

Maybe I'm doing something wrong, but in the following schematic case I silently get the wrong gradients:

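import jax
from jaxopt import Bisection

# F is the user's optimality function (not shown in the report).
def inv_f(x, aux):
  bisec = Bisection(optimality_fun=F, lower=0.0, upper=1., 
                    check_bracket=False, unroll=True)
  return bisec.run(aux=aux).params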

# Here I extract the value part of the vjp, but the grad part also gives wrong results
test_fn = lambda aux: jax.value_and_grad(inv_f)(0.5, aux)[0] 

jax.grad(test_fn)(1.) # Returns 0 instead of the expected gradients

Here I'm only trying to get gradients of the value returned by jax.value_and_grad, but the gradients of the gradients returned by jax.value_and_grad are also wrong (but not as obvious).

I made a small demo notebook that reproduces this issue here.

As a reference I've also implemented my own implicit gradients, bypassing the jaxopt ones, and they seem to give me the correct answer.

Reading the source code of jaxopt, it is not immediately obvious to me why this doesn't work... Sorry I couldn't directly suggest a PR, but I hope this report is still useful (and that I'm not just using jaxopt wrong).
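For cross-checking, here is a minimal hand-rolled reference in the spirit of the author's own implementation; it is not JAXopt's `Bisection`. It runs bisection for a fixed number of iterations and gets its derivative from the implicit function theorem via `jax.custom_jvp`. The optimality function F(x, θ) = x² − θ and the bracket [0, 2] are assumptions. Calling `root` (not the raw `bisect`) inside the JVP rule keeps second-order derivatives correct, because the root computed there is itself differentiated through the same rule.

```python
import jax
import jax.numpy as jnp


def F(x, theta):
    return x ** 2 - theta  # root x* = sqrt(theta) for theta in (0, 4)


def bisect(theta, lower=0.0, upper=2.0, num_iters=60):
    lo = jnp.full_like(theta, lower)
    hi = jnp.full_like(theta, upper)
    for _ in range(num_iters):
        mid = 0.5 * (lo + hi)
        # F(lower) < 0 < F(upper): keep the half-interval that brackets the root.
        go_right = F(mid, theta) < 0
        lo = jnp.where(go_right, mid, lo)
        hi = jnp.where(go_right, hi, mid)
    return 0.5 * (lo + hi)


@jax.custom_jvp
def root(theta):
    return bisect(theta)


@root.defjvp
def root_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    x = root(theta)  # not bisect(theta): keeps higher-order derivatives correct
    dF_dx = jax.grad(F, argnums=0)(x, theta)
    dF_dtheta = jax.grad(F, argnums=1)(x, theta)
    # Implicit function theorem: dx*/dtheta = -dF_dtheta / dF_dx
    return x, -dF_dtheta / dF_dx * theta_dot
```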

OSQP init_params

How should we initialize the parameters in this problem setting?

import jax.numpy as jnp
from jaxopt import OSQP

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
h = jnp.array([0.0, 0.0])

qp = OSQP()
sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params

print(sol.primal)
print(sol.dual_eq)
print(sol.dual_ineq)

TypeError: run() missing 1 required positional argument: 'init_params'

Linear solvers don't use the `init` argument to warm start solvers.

As an example, the init argument in solve_gmres is not passed to the x0 argument in jax.scipy.sparse.linalg.gmres.

def solve_gmres(matvec: Callable,
                b: Any,
                ridge: Optional[float] = None,
                init: Optional[Any] = None,
                tol: float = 1e-5,
                **kwargs) -> Any:
  """Solves ``A x = b`` using gmres.
  Args:
    matvec: product between ``A`` and a vector.
    b: pytree.
    ridge: optional ridge regularization.
    init: optional initialization to be used by gmres.
    **kwargs: additional keyword arguments for solver.
  Returns:
    pytree with same structure as ``b``.
  """
  if ridge is not None:
    matvec = _make_ridge_matvec(matvec, ridge=ridge)
  return jax.scipy.sparse.linalg.gmres(matvec, b, tol=tol, **kwargs)[0]

I can take care of this soon.
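Here is a sketch of the fix described above: forward `init` to the `x0` argument of `jax.scipy.sparse.linalg.gmres`. `_make_ridge_matvec` is re-implemented here only so that the snippet runs standalone. Its body is an assumption modeled on the `ridge_matvec` frame in the earlier traceback, which computes `matvec(v) + ridge * v`.

```python
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def _make_ridge_matvec(matvec: Callable, ridge: float) -> Callable:
    # v -> A v + ridge * v, leaf by leaf so pytrees work too.
    def ridge_matvec(v):
        return jax.tree_util.tree_map(lambda Av, x: Av + ridge * x, matvec(v), v)
    return ridge_matvec


def solve_gmres(matvec: Callable,
                b: Any,
                ridge: Optional[float] = None,
                init: Optional[Any] = None,
                tol: float = 1e-5,
                **kwargs) -> Any:
    """Solves ``A x = b`` using gmres, warm-started from ``init`` if given."""
    if ridge is not None:
        matvec = _make_ridge_matvec(matvec, ridge=ridge)
    return gmres(matvec, b, x0=init, tol=tol, **kwargs)[0]
```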

Infinities and NaNs in quadratic_prog when c=0

Hi,

I'm using QuadraticProgramming in the special case c=0 (the all-zeros vector). AFAIK this is still well defined, as it just minimizes a (Q-weighted) squared l2 norm of the primal subject to some equality constraints (I don't have inequalities).

However, both my research code and the following modification of this test diverge even for a single step (maxiter=1).

The modification just involves setting c=0, so:

def test_qp_eq_only_c_zero(self):
  Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
  c = jnp.array([0.0, 0.0]) #ONLY CHANGE
  A = jnp.array([[1.0, 1.0]])
  b = jnp.array([1.0])
  qp = QuadraticProgramming(tol=1e-7)
  hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
  sol = qp.run(**hyperparams).params
  self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
  self._check_derivative_A_and_b(qp, hyperparams, A, b)

Is there a way to fix it? If it involves calling another linear solver, is there a way to specify the solver from the high-level QP function? I haven't seen it.

Thanks!
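With c = 0 the problem remains a well-posed equality-constrained QP whenever Q is positive definite and A has full row rank, and its solution can be read off a single KKT linear system. A direct solve gives a reference value to compare against when debugging the NaNs. `solve_eq_qp_kkt` is a hypothetical helper; it says nothing about which linear solver QuadraticProgramming uses internally.

```python
import jax.numpy as jnp


def solve_eq_qp_kkt(Q, c, A, b):
    # KKT system of  min 0.5 x^T Q x + c^T x  s.t.  A x = b:
    #   [Q  A^T] [x]   [-c]
    #   [A   0 ] [y] = [ b]
    n, m = Q.shape[0], A.shape[0]
    kkt = jnp.block([[Q, A.T], [A, jnp.zeros((m, m))]])
    rhs = jnp.concatenate([-c, b])
    sol = jnp.linalg.solve(kkt, rhs)
    return sol[:n], sol[n:]  # primal x, equality multiplier y


Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
c = jnp.array([0.0, 0.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
x, y = solve_eq_qp_kkt(Q, c, A, b)  # x = [0.25, 0.75], y = [-1.75]
```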

Possible bug in tree_where when tree has a single leaf

I'm not entirely sure, but I suspect that tree_where assumes that a pytree with a single leaf is an np array, which can lead to errors such as "TypeError: Field elements must be 2- or 3-tuples,". In my case, I had a list containing a single NumPy array. I couldn't test my hypothesis directly because I was using tree_where through OSQP, but when I changed my code so that the primal variable in OSQP had two leaves, it worked.
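A minimal sketch of a leaf-wise where built on jax.tree_util.tree_map, which treats a list holding one array as an ordinary single-leaf pytree, so no special case for one leaf is needed. `tree_where_sketch` is a hypothetical name, not jaxopt's implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def tree_where_sketch(cond, tree_x, tree_y):
    # Apply jnp.where leaf by leaf; the container structure of tree_x is kept.
    return jax.tree_util.tree_map(lambda a, b: jnp.where(cond, a, b), tree_x, tree_y)


x = [np.array([1.0, 2.0])]    # a list with a single NumPy leaf, as in the report
y = [np.array([-1.0, -2.0])]
out = tree_where_sketch(True, x, y)  # a one-element list holding [1., 2.]
```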

tree_util.tree_vdot fails on a simple example.

def tree_vdot(tree_x, tree_y):

See this colab for details on bug and fix.

Here's a minimal example to reproduce the bug:

from jaxopt import tree_util
a =(1.0, {"k1": 1.0, "k2": (1.0, 1.0)}, 1.0)
x =(1.0, {"k1": 2.0, "k2": (3.0, 4.0)}, 5.0)
out = tree_util.tree_vdot(a,x)

The issue is caused by unexpected behaviour of jax.tree_util.tree_multimap.

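One way to fix this is to compute the leaf-wise products with tree_multimap, sum each product, and then reduce the per-leaf sums with addition:

import operator

import jax.numpy as jnp
from jax import tree_util as tu

def tree_vdot(tree_x, tree_y):
  """Compute the inner product <tree_x, tree_y>."""
  # Leaf-wise products (tree_map accepts several trees; older JAX called this tree_multimap).
  prod_pair = tu.tree_map(lambda x, y: x * y, tree_x, tree_y)
  # Sum the entries of each product, then add up the per-leaf sums.
  sums = tu.tree_map(jnp.sum, prod_pair)
  return tu.tree_reduce(operator.add, sums)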

I would be happy to implement this fix, implement new accompanying unit test, and make a pull request. More details here
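An independent cross-check for the reproduction example, assuming both trees share the same structure: flatten each tree into one vector with jax.flatten_util.ravel_pytree and take an ordinary dot product. `tree_vdot_ravel` is a hypothetical helper, not part of jaxopt.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_vdot_ravel(tree_x, tree_y):
    # Same structure means the leaves are flattened in the same order.
    flat_x, _ = ravel_pytree(tree_x)
    flat_y, _ = ravel_pytree(tree_y)
    return jnp.vdot(flat_x, flat_y)


a = (1.0, {"k1": 1.0, "k2": (1.0, 1.0)}, 1.0)
x = (1.0, {"k1": 2.0, "k2": (3.0, 4.0)}, 5.0)
result = tree_vdot_ravel(a, x)  # 1 + 2 + 3 + 4 + 5 = 15.0
```

A correct tree_vdot should agree with this value on the example above.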

Bisection hanging

I am trying to use jaxopt.Bisection to replace scipy.optimize.bisect in a computational model, but Bisection hangs when I run my code.

The basic structure includes 2 functions that are both jitted (so I assume it should be able to compile ok):

@jit
def f1(parameters):
    ....
    return jax.numpy.array([a,b,c])

@jit
def opt_fun(x):
    f1(x,params)
    .... 
    return float_value

When I call scipy.optimize.bisect(opt_fun,x0,x1), it runs with no issue, but jaxopt.Bisection(opt_fun,x0,x1).run(None) hangs at ~10% CPU usage and 55% memory usage on a 2018 i9 MacBook Pro with 32 GB of memory.

I acknowledge I may be using this incorrectly and that this is possibly not the intended use case but any direction would be very helpful. My intention is to use this computational model with numpyro in the future and having a jax version of the bisection root finding would be incredibly helpful.
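For comparison, a minimal sketch of a bisection that stays entirely inside JAX: a fixed number of halvings in jax.lax.fori_loop, so the loop is traced once and can be jitted. This is not jaxopt's implementation, and `bisect_sketch` is a hypothetical name. It assumes the function changes sign on [lower, upper] and returns a JAX scalar; an objective that converts its result with float(...) cannot run under tracing.

```python
import jax
import jax.numpy as jnp


def bisect_sketch(fun, lower, upper, num_iters=50):
    lower = jnp.asarray(lower, dtype=float)
    upper = jnp.asarray(upper, dtype=float)

    def body(_, state):
        lo, hi, f_lo = state
        mid = 0.5 * (lo + hi)
        f_mid = fun(mid)
        # If f(mid) has the sign of f(lo), the root lies in [mid, hi].
        same_sign = jnp.sign(f_mid) == jnp.sign(f_lo)
        lo_new = jnp.where(same_sign, mid, lo)
        f_lo_new = jnp.where(same_sign, f_mid, f_lo)
        hi_new = jnp.where(same_sign, hi, mid)
        return lo_new, hi_new, f_lo_new

    lo, hi, _ = jax.lax.fori_loop(0, num_iters, body, (lower, upper, fun(lower)))
    return 0.5 * (lo + hi)


def g(x):
    return x ** 2 - 2.0


root = jax.jit(lambda lo, hi: bisect_sketch(g, lo, hi))(0.0, 2.0)  # close to sqrt(2)
```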

Jacobian Estimate Error

Based on section 2.3 of the paper, I am curious to see how the jacobian error varies as I change x_t, holding \theta fixed (to use notation from the paper).

I thought this would be possible for instance by using custom_jvp and overwriting the primal_out with x_t. This would require me, though, to add x_t as an auxiliary argument to custom_jvp. I was wondering if this is possible.
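One possible alternative to custom_jvp: the implicit-function Jacobian estimate J(x_t) = -[∂_1 F(x_t, θ)]^{-1} ∂_2 F(x_t, θ), presumably the quantity studied in section 2.3, can be evaluated directly at any x_t with θ held fixed. A sketch on a toy problem chosen (an assumption, not taken from the paper) so that the exact root x* = (1, 1) is known for θ = (2, 2) and the exact Jacobian is diag(1/4).

```python
import jax
import jax.numpy as jnp


def F(x, theta):
    # Optimality condition of the toy inner problem
    #   min_x 0.25 * sum(x**4) + 0.5 * sum(x**2) - theta @ x
    return x ** 3 + x - theta


def implicit_jacobian(x, theta):
    # -[d_x F(x, theta)]^{-1} d_theta F(x, theta), evaluated at an arbitrary x.
    A = jax.jacobian(F, argnums=0)(x, theta)
    B = jax.jacobian(F, argnums=1)(x, theta)
    return -jnp.linalg.solve(A, B)


theta = jnp.array([2.0, 2.0])
x_star = jnp.array([1.0, 1.0])            # exact root: 1**3 + 1 - 2 = 0
J_true = implicit_jacobian(x_star, theta)  # diag(1/4)

# Jacobian error as x_t moves away from x_star, with theta fixed.
deltas = [0.0, 0.01, 0.1, 0.5]
errors = [float(jnp.linalg.norm(implicit_jacobian(x_star + d, theta) - J_true))
          for d in deltas]
```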

vmap support in QPs

Hi,
I'm having a problem with projection_polyhedron under vmap:

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import jaxopt
from jaxopt.projection import projection_l2_ball, projection_box, projection_l1_ball, projection_polyhedron

def myproj3(x):
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])
    G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
    h = jnp.array([0.0, 0.0])    
    x = projection_polyhedron(x,hyperparams = (A, b, G, h))
    return x

rng_key = jax.random.PRNGKey(42)
x = jax.random.uniform(rng_key, (5000,2), minval=-3,maxval=3)
p1_x = jax.vmap(myproj3)(x)  # myproj3 takes a single argument, so the default in_axes=0 applies
fig, ax = plt.subplots(figsize=(5,5))
ax.scatter(x[:,0],x[:,1],s=0.5)
ax.scatter(p1_x[:,0],p1_x[:,1],s=0.5,c='g')
ax.set_xlabel("X")
ax.set_ylabel("Y")
plt.show()

First, I had to install cvxpy
#!pip install cvxpy
Then, I got this error

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[2])>with<BatchTrace(level=1/1)>
  with val = DeviceArray([[-2.37103211,  2.33759997],
                          [ 2.76953806, -2.37750394],
                          [-0.87246632,  0.73224625],
                          ...,
                          [ 2.29799773,  2.81894884],
                          [ 2.4022714 ,  0.80693103],
                          [-0.41563116,  2.83898531]], dtype=float64)
       batch_dim = 0

Does anyone have a hint? Thanks
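Having to install cvxpy suggests that projection_polyhedron was delegating to a cvxpy-based QP solver, which works on NumPy arrays; converting a vmap tracer to NumPy is exactly what the TracerArrayConversionError forbids. For this particular polyhedron, {x : x1 + x2 = 1, x ≥ 0}, which is the 2-D probability simplex, a sort-based Euclidean projection written only with jax.numpy can be vmapped. A sketch; `projection_simplex_sketch` is a hypothetical name swapped in for the general polyhedron projection.

```python
import jax
import jax.numpy as jnp


def projection_simplex_sketch(x, s=1.0):
    # Euclidean projection onto {x : x >= 0, sum(x) = s} by the sort-based method.
    n = x.shape[0]
    u = jnp.sort(x)[::-1]             # entries in decreasing order
    cssv = jnp.cumsum(u) - s          # cumulative sums minus the target sum
    ind = jnp.arange(1, n + 1)
    cond = u - cssv / ind > 0         # true on a leading run of indices
    rho = jnp.sum(cond)               # length of that run (always >= 1)
    tau = cssv[rho - 1] / rho         # threshold subtracted from every entry
    return jnp.maximum(x - tau, 0.0)


x = jax.random.uniform(jax.random.PRNGKey(42), (5000, 2), minval=-3, maxval=3)
p_x = jax.vmap(projection_simplex_sketch)(x)  # pure JAX, so vmap can trace it
```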

Problem with vmapping rootfinder

Hello Everyone.

First of all, thanks for this great library!

I'm not sure if the following issue is with jaxopt or jax itself, but I started having problems applying vmap to a root solver.
I tried following the example in 'gradient_descent_test.py' (particularly 'test_jit_and_vmap').
However, I'm getting a TracerArrayConversionError when I try to evaluate the vmapped function.

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[1])>with<BatchTrace(level=1/1)>
  with val = DeviceArray([[-17.14141909],[ 58.908974  ],....]], dtype=float64)
       batch_dim = 0
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Below is a simple example to reproduce the issue. Any help would be greatly appreciated.
Thanks in advance.
John

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as np
from   jax import random,vmap

from jaxopt import linear_solve
from jaxopt import ScipyRootFinding

def func(x, params):
    a,b = params
    return (x - a) * (x - b)
kwargs = {'implicit_diff_solve':linear_solve.solve_normal_cg, 'method':'hybr', 'tol':1e-10}
rootfinder = ScipyRootFinding(optimality_fun=func, **kwargs)

init = np.zeros(1)
def solve(params):
    root, info = rootfinder.run(init, params)
    return root
key        = random.PRNGKey(1235711)
param_list = random.uniform(key, (100, 2), minval = -10, maxval = 10)

root       = solve(param_list[0]) # this is ok
root_list  = vmap(solve)(param_list) # this doesn't work

p.s. I'm using the following packages
python 3.8.8
jax 0.2.19
jaxlib 0.1.70
jaxopt 0.1
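ScipyRootFinding wraps SciPy's scipy.optimize.root, which works on NumPy arrays outside of JAX tracing, so vmap cannot batch it; the TracerArrayConversionError is the symptom. If a pure-JAX solver is acceptable, a fixed number of Newton steps in jax.lax.fori_loop can be vmapped. This sketch swaps Newton's method in for SciPy's 'hybr' method, starts from x = 0 as in the report, and assumes the derivative never vanishes at an iterate (since f'(0) = -(a + b), it needs a + b ≠ 0). `newton_root` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def func(x, params):
    a, b = params
    return (x - a) * (x - b)


def newton_root(params, num_iters=50):
    f = lambda x: func(x, params)
    df = jax.grad(f)  # scalar derivative f'(x)

    def body(_, x):
        return x - f(x) / df(x)  # one Newton step

    return jax.lax.fori_loop(0, num_iters, body, jnp.zeros_like(params[0]))


param_list = jnp.array([[1.0, 3.0], [-2.0, 5.0], [4.0, 6.0]])
roots = jax.vmap(newton_root)(param_list)  # one root per parameter pair
```

For a quadratic with roots a < b, Newton from 0 converges to a when 0 < (a + b) / 2, so the three roots here are 1, -2 and 4.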

init_params in BoxOSQP().run

Thanks for a very cool and useful library! I have a question about init_params in BoxOSQP().run.
I copied and ran the code from the BoxOSQP tutorial:

import jax.numpy as jnp
from jaxopt import BoxOSQP

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
l = jnp.array([1.0, -jnp.inf, -jnp.inf])
u = jnp.array([1.0, 0.0, 0.0])

qp = BoxOSQP()
sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params

and got

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_11430/875702596.py in <module>
      9 
     10 qp = BoxOSQP()
---> 11 sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
     12 
     13 print(sol.primal)

TypeError: run() missing 1 required positional argument: 'init_params'

Simply adding init_params=None to the run call makes it work, but I am not confident whether that is correct.

sol = qp.run(init_params=None,params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params

print(sol.primal)
print(sol.dual_eq)
print(sol.dual_ineq)

(DeviceArray([0.25004143, 0.7500388 ], dtype=float32), DeviceArray([ 1.        , -0.25003824, -0.75000846], dtype=float32))
[-2.7502573e+00  0.0000000e+00  3.0822962e-09]
(DeviceArray([0.0000000e+00, 0.0000000e+00, 3.0822962e-09], dtype=float32), DeviceArray([ 2.7502573, -0.       ,  0.       ], dtype=float32))

Thanks in advance.
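One way to gain some confidence in the init_params=None result is to check the printed values against the problem's optimality conditions. The sketch below assumes that the printed sol.dual_eq is the multiplier y in the stationarity condition Q x + c + A^T y = 0; the small residual it produces is consistent with that reading. Feasibility l ≤ A x ≤ u is checked up to the solver's tolerance.

```python
import jax.numpy as jnp

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
l = jnp.array([1.0, -jnp.inf, -jnp.inf])
u = jnp.array([1.0, 0.0, 0.0])

# Values copied from the printed output above.
x = jnp.array([0.25004143, 0.7500388])
y = jnp.array([-2.7502573e+00, 0.0, 3.0822962e-09])

stationarity = Q @ x + c + A.T @ y  # close to zero at a KKT point
z = A @ x                           # should lie in [l, u] up to tolerance
```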

Problem differentiating through `solver.run` in `OptaxSolver`

I've been trying to use OptaxSolver to perform a simple function minimization, since I want to differentiate through its solution (the fixed point of the solver), but I ran into an issue I'm not familiar with.

Here's a MWE for the error message:

import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def pipeline(param_for_grad, data):
    def to_minimize(latent):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(3e-4), implicit_diff=True)

    initial, _ = solver.init(init_params = 5.)

    result, _ = solver.run(init_params = initial)

    return result

jax.value_and_grad(pipeline)(2., data=6.)

which yields this error:

CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.

My versions are:

jax==0.2.20
jaxlib==0.1.71
jaxopt==0.0.1
optax==0.0.9

Am I doing something very silly? I'm also wondering whether this example is within the scope of the solver API. I noticed that this doesn't occur with solver.update, only with solver.run.

Thanks :)
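The exception text points at the cause: to_minimize closes over param_for_grad, so the implicit-differentiation rule has no explicit input through which to return that gradient. In jaxopt the usual pattern is to make the differentiated quantities extra arguments of the objective, e.g. to_minimize(latent, param_for_grad, data), and pass them through solver.run(init_params, param_for_grad, data). The mechanism is easiest to see in plain JAX. The sketch below is not jaxopt code: `argmin_latent` and its gradient-descent inner solver are illustrative assumptions. It defines a custom_vjp argmin whose parameters are explicit arguments and whose backward pass uses the implicit function theorem.

```python
import jax
import jax.numpy as jnp


def inner_obj(latent, theta, data):
    # Negative Gaussian log-likelihood of data with mean theta * latent, up to a constant.
    return 0.5 * (data - theta * latent) ** 2


# Optimality condition F(latent, theta, data) = d inner_obj / d latent.
optimality = jax.grad(inner_obj, argnums=0)


@jax.custom_vjp
def argmin_latent(theta, data):
    # Inner solver: gradient descent from latent = 5; theta and data are explicit inputs.
    def step(_, latent):
        return latent - 0.1 * optimality(latent, theta, data)
    init = jnp.asarray(5.0, dtype=jnp.result_type(theta, data))
    return jax.lax.fori_loop(0, 500, step, init)


def _argmin_fwd(theta, data):
    latent = argmin_latent(theta, data)
    return latent, (latent, theta, data)


def _argmin_bwd(res, g):
    latent, theta, data = res
    # Implicit function theorem at the solution:
    #   d latent / d(theta, data) = -(dF/dlatent)^{-1} dF/d(theta, data)
    dF_dlatent = jax.grad(optimality, argnums=0)(latent, theta, data)
    _, vjp_fn = jax.vjp(lambda t, d: optimality(latent, t, d), theta, data)
    return vjp_fn(-g / dF_dlatent)


argmin_latent.defvjp(_argmin_fwd, _argmin_bwd)


def pipeline(param_for_grad, data):
    return argmin_latent(param_for_grad, jnp.asarray(data))


value, grad_theta = jax.value_and_grad(pipeline)(2.0, 6.0)
```

For this objective the minimizer is latent* = data / θ = 3, with derivative -data / θ² = -1.5, which the custom backward pass reproduces without differentiating through the loop.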

robust_training.py crashes on GPU

The example runs fine on CPU but crashes when run on GPU. I believe the problem is in the innocent-looking normalize function.

Below is the stack trace

RESOURCE_EXHAUSTED: Out of memory while trying to allocate 188160000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  179.44MiB
              constant allocation:         0B
        maybe_live_out allocation:  179.44MiB
     preallocated temp allocation:         0B
                 total allocation:  358.89MiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 179.44MiB
		Entry Parameter Subshape: f32[60000,28,28,1]
		==========================

	Buffer 2:
		Size: 179.44MiB
		Operator: op_type="div" op_name="jit(true_divide)/div" source_file="examples/deep_learning/robust_training.py" source_line=52
		XLA Label: fusion
		Shape: f32[60000,28,28,1]
		==========================

	Buffer 3:
		Size: 4B
		Entry Parameter Subshape: f32[]
		==========================



Traceback (most recent call last):
  File "examples/deep_learning/robust_training.py", line 141, in <module>
    train_ds, test_ds = load_datasets()
  File "examples/deep_learning/robust_training.py", line 63, in load_datasets
    train_ds['image'], test_ds['image'] = map(normalize, (train_ds['image'], test_ds['image']))
  File "examples/deep_learning/robust_training.py", line 52, in normalize
    return jnp.asarray(images).astype(jnp.float32) / 255.
  File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 6585, in deferring_binary_op
    return binary_op(self, other)
  File "/home/pedregosa/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1100, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 188160000 bytes.


CC @GeoffNN
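The numbers in the report are consistent with normalize being the culprit: 60000 × 28 × 28 × 1 float32 values occupy 188,160,000 bytes, about 179.44 MiB, exactly the failed allocation, and jnp.asarray moves the whole dataset to the device at once. At that size the allocation is modest, so device memory may already have been largely occupied. A sketch of a host-side alternative, assuming the rest of the example only needs per-batch device arrays (not checked against the example):

```python
import numpy as np


def normalize(images):
    # Host-side normalization with NumPy: the full dataset stays in host RAM
    # instead of being copied to the accelerator in one piece.
    return np.asarray(images, dtype=np.float32) / 255.0


# Size of the array that failed to allocate: 60000 x 28 x 28 x 1 float32 values.
n_bytes = 60000 * 28 * 28 * 1 * np.dtype(np.float32).itemsize
print(n_bytes, n_bytes / 2**20)  # 188160000 179.443359375
```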

Wrong ridge_objective function in the ridge_reg_implicit_diff example

There is an error in the ridge_objective function:
return 0.5 * jnp.mean(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2)

As correctly stated in Figure 1 in the paper https://arxiv.org/pdf/2105.15183.pdf
the ridge objective should be
return 0.5 * jnp.sum(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2)
i.e. change the jnp.mean to jnp.sum.
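The two versions differ by more than a constant factor on the objective value: with n samples, the jnp.mean variant scales only the data-fitting term by 1/n, so its minimizer for a given lam equals the jnp.sum minimizer for n * lam, i.e. the effective regularization strength changes. A small closed-form check; the function names are illustrative, not the example's.

```python
import jax
import jax.numpy as jnp


def ridge_objective(params, lam, X, y):
    # Sum of squared residuals, as in Figure 1 of the paper.
    residuals = jnp.dot(X, params) - y
    return 0.5 * jnp.sum(residuals ** 2) + 0.5 * lam * jnp.sum(params ** 2)


def ridge_closed_form(lam, X, y):
    # Minimizer of ridge_objective: (X^T X + lam * I)^{-1} X^T y.
    return jnp.linalg.solve(X.T @ X + lam * jnp.eye(X.shape[1]), X.T @ y)


def ridge_closed_form_mean(lam, X, y):
    # Minimizer of the jnp.mean variant: (X^T X / n + lam * I)^{-1} X^T y / n.
    n = X.shape[0]
    return jnp.linalg.solve(X.T @ X / n + lam * jnp.eye(X.shape[1]), X.T @ y / n)


key_X, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_X, (20, 3))
y = jax.random.normal(key_y, (20,))
lam = 0.5

w = ridge_closed_form(lam, X, y)
grad_at_w = jax.grad(ridge_objective)(w, lam, X, y)  # close to zero at the minimizer
```

With n = 20 here, ridge_closed_form_mean(lam, X, y) matches ridge_closed_form(20 * lam, X, y).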

CustomVJPException plus memory leak when using a for loop instead of a scan.

I've been using jaxopt.implicit_diff.custom_root to differentiate through a jax-md energy minimization routine, and I have noticed that if my solver uses a Python for loop, I get a CustomVJPException and an additional memory leak.

This memory leak only seems to show up when I get the CustomVJPException, not when I modify my code to prevent the exception. I believe the underlying reason for the exception is the same as in issue #31 and seems to stem from how jax-md defines its energy functions.

I'd like to know how to change that part of jax-md to prevent the CustomVJPException in the first place, but I haven't managed to come up with a simplified version that would let me pinpoint the source of the error. I can give it another shot if that helps.

Here's a colab demo that demonstrates the issue.
https://colab.research.google.com/drive/1f_3EmFQpvW1p7A1AcNw8uqX5T79fjXRS?usp=sharing

poor GPU utilization on the deep learning examples

When running the deep learning examples (e.g., deep_learning/flax_image_classif.py), GPU utilization never goes above 5%, while for the equivalent Flax example it is around 90% and the example runs more than 20x faster.

My guess is that there's a crucial @jax.jit directive missing somewhere.
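The guess can be tested in isolation: jitting the per-step update compiles the whole step into one XLA computation instead of dispatching each operation separately from Python. A toy sketch, with a least-squares model standing in for the Flax classifier; whether the examples actually lack a jit has not been verified here.

```python
import jax
import jax.numpy as jnp


def loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)


@jax.jit
def update(w, X, y, lr=0.1):
    # One gradient step, compiled as a single XLA computation.
    return w - lr * jax.grad(loss)(w, X, y)


key_X, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_X, (256, 10))
y = X @ jax.random.normal(key_w, (10,))
w = jnp.zeros(10)
for _ in range(100):
    w = update(w, X, y)  # compiled on the first call, reused afterwards
```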

Vmap over outer_objective?

Is it possible to apply vmap over outer_objective?

To be concrete, consider the lasso example: outer_objective takes theta, init_inner and data as arguments.

Is it possible to provide the OptaxSolver with the objective function vmap(outer_objective, in_axes=(None, 0, 0))?

EDIT: Yes, you can!
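A sketch of the pattern, with a hypothetical outer_objective standing in for the lasso example's. vmap returns one value per task; since a solver such as OptaxSolver minimizes a scalar objective, the per-task values are averaged here (how to combine them is an assumption).

```python
import jax
import jax.numpy as jnp


def outer_objective(theta, init_inner, data):
    # Hypothetical per-task loss: theta is shared, init_inner and data belong to one task.
    X, y = data
    w = init_inner * jnp.exp(theta)
    return jnp.mean((X @ w - y) ** 2)


# theta is broadcast (None); init_inner and every leaf of data are mapped over axis 0.
batched_outer = jax.vmap(outer_objective, in_axes=(None, 0, 0))


def mean_outer_objective(theta, init_inner, data):
    # Scalar objective: the average of the per-task values.
    return jnp.mean(batched_outer(theta, init_inner, data))


num_tasks, n, d = 4, 8, 3
kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
data = (jax.random.normal(kx, (num_tasks, n, d)), jax.random.normal(ky, (num_tasks, n)))
init_inner = jax.random.normal(kw, (num_tasks, d))
theta = jnp.array(0.0)

value, g_theta = jax.value_and_grad(mean_outer_objective)(theta, init_inner, data)
```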

Projections in the KL sense

When the projection set is contained in the non-negative orthant (simplex, Birkhoff, ...), it makes sense to do a projection in the KL sense. We should start a new module jaxopt.kl_projection for that.
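Two facts such a module could build on, sketched in plain JAX with hypothetical function names (not part of jaxopt): for an entrywise-positive q, the minimizer of KL(p || q) over the simplex is simply q / sum(q), and the KL projection of a positive matrix onto the Birkhoff polytope is the limit of Sinkhorn's alternating row and column normalizations.

```python
import jax
import jax.numpy as jnp


def kl_projection_simplex_sketch(q):
    # argmin_{p in simplex} KL(p || q) for an entrywise-positive q: renormalize.
    return q / jnp.sum(q)


def kl_projection_birkhoff_sketch(K, num_iters=200):
    # Sinkhorn iterations on an entrywise-positive matrix K: alternately rescale
    # rows and columns. The limit is the KL projection of K onto the doubly
    # stochastic matrices (the Birkhoff polytope).
    def body(_, P):
        P = P / jnp.sum(P, axis=1, keepdims=True)  # rows sum to 1
        P = P / jnp.sum(P, axis=0, keepdims=True)  # columns sum to 1
        return P

    return jax.lax.fori_loop(0, num_iters, body, K)


K = jax.random.uniform(jax.random.PRNGKey(0), (4, 4), minval=0.5, maxval=1.5)
P = kl_projection_birkhoff_sketch(K)
```

For inputs with very small entries, performing the same updates in the log domain is likely to be more stable.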
