mblondel commented on May 5, 2024

This issue is related to #61. We can't use projection_polyhedron with vmap yet because it is based on jaxopt.QuadraticProgramming, which currently relies on cvxpy when there are inequality constraints. The issue will be resolved once we ship our own pure-JAX QP solvers.
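For context, the traceback further down this thread (jaxopt/_src/projection.py, lines 299-301) shows how projection_polyhedron builds its problem: params_obj=(I, -x), params_eq=(A, b), params_ineq=(G, h). Assuming jaxopt's QuadraticProgramming convention of minimizing ½ yᵀQy + cᵀy subject to Ay = b and Gy ≤ h, this is the QP that gets handed to cvxpy whenever inequality constraints are present:

$$
\operatorname{proj}(x) = \arg\min_{y}\ \tfrac{1}{2}\, y^\top I\, y + (-x)^\top y
\quad \text{subject to} \quad A y = b,\ \ G y \le h .
$$

Since $\tfrac{1}{2} y^\top y - x^\top y = \tfrac{1}{2}\lVert y - x \rVert^2 - \tfrac{1}{2}\lVert x \rVert^2$, the minimizer is the point of the polyhedron closest to $x$ in Euclidean distance.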

jecampagne commented on May 5, 2024

OK for vmap for now. I'm experiencing another problem, but it may be related to the way I have coded it.
I have tried the following:
three lines forming a triangle, encoded with the A & b matrix/vector:
y_0 + y_1 = 1
y_1 = 1
-y_0 + y_1 = -1
and, with G & h,
y_0 >= 0 & y_1 >= 0

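import jax.numpy as jnp
from jaxopt.projection import projection_polyhedron

def myproj3(x):
    # Equality constraints A @ y = b (the three lines of the triangle)
    A = jnp.array([[1.0, 1.0], [0.0, 1.0], [-1.0, 1.0]])
    b = jnp.array([1.0, 1.0, -1.0])
    # Inequality constraints G @ y <= h, i.e. y_0 >= 0 and y_1 >= 0
    G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
    h = jnp.array([0.0, 0.0])
    x = projection_polyhedron(x, hyperparams=(A, b, G, h))
    return x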

but calling
myproj3(jnp.array([0., 0.]))
raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_24565/1376122935.py in <module>
----> 1 myproj3(jnp.array([0.,0.]))

/tmp/ipykernel_24565/3808040114.py in myproj3(x)
      4     G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
      5     h = jnp.array([0.0, 0.0])
----> 6     x = projection_polyhedron(x,hyperparams = (A, b, G, h))
      7     return x

/sps/lsst/users/campagne/anaconda3/envs/jaxOptim/lib/python3.8/site-packages/jaxopt/_src/projection.py in projection_polyhedron(x, hyperparams)
    299   I = jnp.eye(len(x))
    300   hyperparams = dict(params_obj=(I, -x), params_eq=(A, b), params_ineq=(G, h))
--> 301   return qp.run(**hyperparams).params[0]
    302 
    303 

/sps/lsst/users/campagne/anaconda3/envs/jaxOptim/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py in wrapped_solver_fun(*args, **kwargs)
    238     args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    239     keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 240     return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
    241 
    242   return wrapped_solver_fun

    [... skipping hidden 5 frame]

/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py in solver_fun_flat(*flat_args)
    230     def solver_fun_flat(*flat_args):
    231       args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 232       return solver_fun(*args, **kwargs)
    233 
    234     solver_fun_flat.defvjp(solver_fun_fwd, solver_fun_bwd)

/lib/python3.8/site-packages/jaxopt/_src/quadratic_prog.py in run(self, init_params, params_obj, params_eq, params_ineq)
    208                                                        self.maxiter, tol=self.tol))
    209     else:
--> 210       sol = base.KKTSolution(*_solve_constrained_qp_cvxpy(params_obj,
    211                                                           params_eq,
    212                                                           params_ineq))

/lib/python3.8/site-packages/jaxopt/_src/quadratic_prog.py in _solve_constrained_qp_cvxpy(params_obj, params_eq, params_ineq)
    110   pb = cp.Problem(cp.Minimize(objective), constraints)
    111   pb.solve()
--> 112   return (jnp.array(x.value), jnp.array(pb.constraints[0].dual_value),
    113           jnp.array(pb.constraints[1].dual_value))
    114 

/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
   3542   if type(object) is np.ndarray:
   3543     _inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
-> 3544     lax._check_user_dtype_supported(_inferred_dtype, "array")
   3545     out = _np_array(object, copy=copy, dtype=dtype)
   3546     if dtype: assert _dtype(out) == dtype

/lib/python3.8/site-packages/jax/_src/lax/lax.py in _check_user_dtype_supported(dtype, fun_name)
   7011     msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
   7012     msg += f" in {fun_name}" if fun_name else ""
-> 7013     raise TypeError(msg)
   7014   if dtype is not None and np_dtype != dtypes.canonicalize_dtype(dtype):
   7015     msg = ("Explicitly requested dtype {} {} is not available, "

TypeError: JAX only supports number and bool dtypes, got dtype object in array

Any idea?

Algue-Rythme commented on May 5, 2024

If I understand your constraints correctly, you want vectors of coordinates (x, y) such that:

x+y == 1  # (1)
y == 1  # (2)
-x+y == -1 # (3)

Clearly the problem is infeasible: (1) and (2) give x = 0 and y = 1, but then the left-hand side of (3) is -0 + 1 = 1, not -1. Your polyhedron is the empty set, which causes the error.
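The substitution argument can be checked numerically. A small sketch in plain NumPy (no jaxopt or cvxpy needed) that solves (1) and (2) and plugs the result into (3):

```python
import numpy as np

# The three equality constraints from the thread, as A @ [x, y] = b:
#   (1)  x + y =  1
#   (2)      y =  1
#   (3) -x + y = -1
A = np.array([[1.0, 1.0], [0.0, 1.0], [-1.0, 1.0]])
b = np.array([1.0, 1.0, -1.0])

# (1) and (2) alone form an invertible 2x2 system with a unique solution.
xy = np.linalg.solve(A[:2], b[:2])  # x = 0, y = 1

# Substitute that solution into (3): its left-hand side would need to be -1.
lhs3 = A[2] @ xy  # -0 + 1 = 1
print(f"(x, y) = {xy}, LHS of (3) = {lhs3}, required = {b[2]}")
```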

In general, with N unknowns you want no more than N linear equality constraints: beyond that, either the problem is infeasible or some equality constraints are redundant (linear combinations of the others).
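One standard way to test an equality system for consistency before handing it to a solver is the Rouché–Capelli criterion: A y = b has a solution exactly when rank(A) equals the rank of the augmented matrix [A | b]. A NumPy sketch follows; the helper name `equalities_consistent` is hypothetical and not part of jaxopt, and the tolerance is an arbitrary choice:

```python
import numpy as np

def equalities_consistent(A, b, tol=1e-10):
    """Return True if A @ y = b has at least one solution.

    Rouché–Capelli criterion: the system is consistent iff
    rank(A) == rank([A | b]).
    """
    A = np.atleast_2d(np.asarray(A, dtype=float))
    b = np.asarray(b, dtype=float).reshape(-1, 1)
    rank_A = np.linalg.matrix_rank(A, tol=tol)
    rank_Ab = np.linalg.matrix_rank(np.hstack([A, b]), tol=tol)
    return bool(rank_A == rank_Ab)

# The thread's three lines: rank(A) = 2 but rank([A | b]) = 3.
print(equalities_consistent([[1, 1], [0, 1], [-1, 1]], [1, 1, -1]))  # False
# Redundant but consistent: the second row is twice the first.
print(equalities_consistent([[1, 1], [2, 2]], [1, 2]))  # True
# A "null" equality constraint 0 @ y = 0 is always consistent.
print(equalities_consistent([[0, 0]], [0]))  # True
```

This only covers the equality block; checking that the inequalities G y <= h are jointly satisfiable needs a linear-programming feasibility check instead.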

I agree that the error message is quite cryptic: we wrap another library (cvxpy) without checking the feasibility of the problem, assuming users check feasibility themselves.

mblondel commented on May 5, 2024

Indeed, the resulting QP was infeasible. With PR #76 we will now raise an exception.

jecampagne commented on May 5, 2024

Of course you would have been right if the system required equalities, but my use case is the domain bounded by the lines:
x+y >=1
y<=1
-x+y >= -1
x>=0
y>=0
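The region described by these five inequalities is a triangle. A NumPy sketch that rewrites each constraint in the single form G @ [x, y] <= h (multiplying the ">=" rows by -1) and checks its corners; the helper name `inside` is hypothetical, and the corners were obtained by intersecting pairs of boundary lines:

```python
import numpy as np

# One row of G @ [x, y] <= h per constraint.
G = np.array([
    [-1.0, -1.0],  #  x + y >= 1   ->  -x - y <= -1
    [ 0.0,  1.0],  #      y <= 1
    [ 1.0, -1.0],  # -x + y >= -1  ->   x - y <=  1
    [-1.0,  0.0],  #  x >= 0       ->  -x     <=  0
    [ 0.0, -1.0],  #  y >= 0       ->      -y <=  0
])
h = np.array([-1.0, 1.0, 1.0, 0.0, 0.0])

def inside(point, tol=1e-12):
    """True if `point` satisfies every inequality G @ point <= h."""
    return bool(np.all(G @ np.asarray(point, dtype=float) <= h + tol))

# Corners of the triangle:
#   x + y = 1 and x = 0  -> (0, 1)
#   x + y = 1 and y = 0  -> (1, 0)
#   y = 1 and x - y = 1  -> (2, 1)
vertices = [(0.0, 1.0), (1.0, 0.0), (2.0, 1.0)]
print([inside(v) for v in vertices])  # [True, True, True]
print(inside((0.0, 0.0)))  # False: violates x + y >= 1
```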

Is it possible to get it with polyhedron projection?

Algue-Rythme commented on May 5, 2024

Yes, of course:

from jaxopt.projection import projection_polyhedron
import jax.numpy as jnp
import numpy as np

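def myproj3(x):
    # Null equality constraint 0 @ y = 0, always satisfied
    A = np.array([[0, 0]])
    b = np.array([0])
    # Inequalities G @ y <= h in coordinates (x, y), one row per constraint:
    # -x - y <= -1,  y <= 1,  x - y <= 1,  -x <= 0,  -y <= 0
    G = np.array([[-1, -1], [0, 1], [1, -1], [-1, 0], [0, -1]])
    h = np.array([-1, 1, 1, 0, 0])
    x = projection_polyhedron(x, hyperparams=(A, b, G, h))
    return x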

Please read the doc: https://jaxopt.github.io/stable/_autosummary/jaxopt.projection.projection_polyhedron.html

It clearly states that inequality constraints go in the matrices (G, h). If you don't want equality constraints, you can put a null constraint in the matrices (A, b), just as I did.
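The expected output for this feasible triangle can be cross-checked without jaxopt or cvxpy. The NumPy sketch below swaps in a different technique, Dykstra's alternating-projection algorithm, which is not what projection_polyhedron uses. It handles only the (G, h) inequality block (the null (A, b) constraint is simply dropped) and runs a fixed number of sweeps. Because the Euclidean projection onto a non-empty closed convex set is unique, its result should agree with projection_polyhedron up to solver tolerance, assuming both converge:

```python
import numpy as np

def project_halfspace(z, g, hi):
    """Euclidean projection of z onto the half-space {y : g @ y <= hi}."""
    violation = g @ z - hi
    if violation <= 0.0:
        return z
    return z - (violation / (g @ g)) * g

def dykstra_polyhedron(x, G, h, n_sweeps=500):
    """Approximate the projection of x onto {y : G @ y <= h}.

    Dykstra's algorithm: cycle through the half-spaces, keeping one
    correction vector per half-space. Sketch only: assumes the set is
    non-empty and has no equality block (A, b).
    """
    y = np.asarray(x, dtype=float).copy()
    corrections = np.zeros((G.shape[0], y.size))
    for _ in range(n_sweeps):
        for i in range(G.shape[0]):
            z = y + corrections[i]
            y = project_halfspace(z, G[i], h[i])
            corrections[i] = z - y
    return y

# The triangle from this thread, in G @ y <= h form.
G = np.array([[-1.0, -1.0], [0.0, 1.0], [1.0, -1.0], [-1.0, 0.0], [0.0, -1.0]])
h = np.array([-1.0, 1.0, 1.0, 0.0, 0.0])

print(dykstra_polyhedron([0.0, 0.0], G, h))  # approximately (0.5, 0.5), on x + y = 1
print(dykstra_polyhedron([3.0, 0.0], G, h))  # approximately (2, 1), a vertex
```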

jecampagne commented on May 5, 2024

@Algue-Rythme Thanks. I was a bit confused after reading the doc and other material where A is the inequality-constraint matrix. Sorry.

jecampagne commented on May 5, 2024

I still get a crash:

import jax.numpy as jnp
import jaxopt
from jaxopt.projection import projection_polyhedron

def f(x):
    return x[0]**2 - x[1]**2

# Null equality constraint plus the five triangle inequalities G @ y <= h
A = jnp.array([[0, 0]])
b = jnp.array([0])
G = jnp.array([[-1, -1], [0, 1], [1, -1], [-1, 0], [0, -1]])
h = jnp.array([-1, 1, 1, 0, 0])
hyperparams = (A, b, G, h)

pg = jaxopt.ProjectedGradient(fun=f, projection=projection_polyhedron)
res_poly = pg.run(init_params=jnp.array([0., 1.]), hyperparams_proj=hyperparams)

The end of the traceback:

....
/python3.8/site-packages/cvxpy/interface/numpy_interface/ndarray_interface.py in const_to_matrix(self, value, convert_scalars)
     46             result = numpy.asarray(value).T
     47         else:
---> 48             result = numpy.asarray(value)
     49         if result.dtype in [complex, numpy.float64]:
     50             return result

python3.8/site-packages/jax/core.py in __array__(self, *args, **kw)
    476 
    477   def __array__(self, *args, **kw):
--> 478     raise TracerArrayConversionError(self)
    479 
    480   def __index__(self):

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[2,2])>with<DynamicJaxprTrace(level=1/3)>
While tracing the function _body_fun at /python3.8/site-packages/jaxopt/_src/loop.py:55 for while_loop, this concrete value was not available in Python because it depends on the value of the argument '_val'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
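For reference, the basic iteration that projected gradient methods build on is x <- P(x - stepsize * grad f(x)), where P is the projection passed in (here projection_polyhedron, fed with hyperparams_proj). Below is a minimal NumPy sketch with a fixed step size; jaxopt.ProjectedGradient has its own stepsize and acceleration options, which this does not reproduce. The box projection is a stand-in chosen only so the example runs on its own:

```python
import numpy as np

def f(x):
    return x[0] ** 2 - x[1] ** 2  # objective from the snippet above

def grad_f(x):
    return np.array([2.0 * x[0], -2.0 * x[1]])  # analytic gradient of f

def projected_gradient(x0, grad, project, stepsize=0.1, maxiter=200):
    """Plain projected gradient: x <- project(x - stepsize * grad(x))."""
    x = np.asarray(x0, dtype=float)
    for _ in range(maxiter):
        x = project(x - stepsize * grad(x))
    return x

def project_box(z):
    # Stand-in projection onto the box [0, 1]^2, used instead of
    # projection_polyhedron so the sketch runs without jaxopt or cvxpy.
    return np.clip(z, 0.0, 1.0)

x_star = projected_gradient([0.5, 0.5], grad_f, project_box)
print(x_star, f(x_star))  # approximately (0, 1) and -1
```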

Algue-Rythme commented on May 5, 2024

This is because by default the run() method will jit the projection (which cannot be jitted currently). You should pass jit=False to ProjectedGradient.
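A minimal sketch of why jitting fails here, using a hypothetical stand-in function rather than jaxopt itself: under jax.jit, arguments become abstract tracers, and any code that converts them to NumPy arrays (as cvxpy's ndarray_interface does via numpy.asarray in the traceback above) raises TracerArrayConversionError:

```python
import jax
import jax.numpy as jnp
import numpy as np

def needs_concrete_values(x):
    # np.asarray asks for a concrete NumPy array, much as cvxpy does when it
    # receives the QP matrices; a JAX tracer cannot provide one.
    return jnp.asarray(np.asarray(x) * 2.0)

x = jnp.ones(2)
eager = needs_concrete_values(x)  # fine: x holds concrete values

try:
    jax.jit(needs_concrete_values)(x)
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    jit_failed = True  # same error class as in the traceback above

print(eager, jit_failed)
```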

jecampagne commented on May 5, 2024

Well, passing jit=False to jaxopt.ProjectedGradient does not fix the crash; I get the same error message.
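The traceback mentions tracing _body_fun "for while_loop" (jaxopt/_src/loop.py), which suggests why jit=False alone may not help: jax.lax.while_loop traces its body with abstract values even when nothing is jitted. A minimal sketch with a hypothetical stand-in step (this is not jaxopt's code, and the actual fix in PR #79 is not reproduced here):

```python
import jax
import jax.numpy as jnp
import numpy as np

def needs_concrete_values(v):
    # Hypothetical stand-in for a step that converts its input to NumPy.
    return jnp.asarray(np.asarray(v) * 2.0)

def body(carry):
    i, v = carry
    return i + 1, needs_concrete_values(v)

# No jax.jit anywhere: lax.while_loop still traces `body` with abstract values.
try:
    jax.lax.while_loop(lambda c: c[0] < 3, body, (0, jnp.ones(2)))
    loop_failed = False
except jax.errors.TracerArrayConversionError:
    loop_failed = True

print(loop_failed)  # True
```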

Algue-Rythme commented on May 5, 2024

PR #79 should fix the bug. It will work with jit=False once merged.

jecampagne commented on May 5, 2024

Ok.
