Comments (12)
This issue is related to #61. We can't use projection_polyhedron with vmap yet because it is based on jaxopt.QuadraticProgramming, which currently relies on cvxpy when there are inequality constraints. The issue will be resolved once we ship our own pure-JAX QP solvers.
OK for vmap for now. I am now experiencing another problem, but it may be related to the way I have coded it.
I have tried the following: three lines forming a triangle, given by the A & b matrix/vector
y_0 + y_1 = 1
y_1 = 1
-y_0 + y_1 = -1
and, with G & h,
y_0 >= 0 & y_1 >= 0
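import jax.numpy as jnp
from jaxopt.projection import projection_polyhedron

def myproj3(x):
    # Equality constraints A @ x == b: the three lines of the triangle
    A = jnp.array([[1.0, 1.0], [0.0, 1.0], [-1.0, 1.0]])
    b = jnp.array([1.0, 1.0, -1.0])
    # Inequality constraints G @ x <= h: x >= 0 componentwise
    G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
    h = jnp.array([0.0, 0.0])
    x = projection_polyhedron(x, hyperparams=(A, b, G, h))
    return x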
but calling
myproj3(jnp.array([0., 0.]))
gives the error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_24565/1376122935.py in <module>
----> 1 myproj3(jnp.array([0.,0.]))
/tmp/ipykernel_24565/3808040114.py in myproj3(x)
4 G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
5 h = jnp.array([0.0, 0.0])
----> 6 x = projection_polyhedron(x,hyperparams = (A, b, G, h))
7 return x
/sps/lsst/users/campagne/anaconda3/envs/jaxOptim/lib/python3.8/site-packages/jaxopt/_src/projection.py in projection_polyhedron(x, hyperparams)
299 I = jnp.eye(len(x))
300 hyperparams = dict(params_obj=(I, -x), params_eq=(A, b), params_ineq=(G, h))
--> 301 return qp.run(**hyperparams).params[0]
302
303
/sps/lsst/users/campagne/anaconda3/envs/jaxOptim/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py in wrapped_solver_fun(*args, **kwargs)
238 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
239 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 240 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
241
242 return wrapped_solver_fun
[... skipping hidden 5 frame]
/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py in solver_fun_flat(*flat_args)
230 def solver_fun_flat(*flat_args):
231 args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 232 return solver_fun(*args, **kwargs)
233
234 solver_fun_flat.defvjp(solver_fun_fwd, solver_fun_bwd)
/lib/python3.8/site-packages/jaxopt/_src/quadratic_prog.py in run(self, init_params, params_obj, params_eq, params_ineq)
208 self.maxiter, tol=self.tol))
209 else:
--> 210 sol = base.KKTSolution(*_solve_constrained_qp_cvxpy(params_obj,
211 params_eq,
212 params_ineq))
/lib/python3.8/site-packages/jaxopt/_src/quadratic_prog.py in _solve_constrained_qp_cvxpy(params_obj, params_eq, params_ineq)
110 pb = cp.Problem(cp.Minimize(objective), constraints)
111 pb.solve()
--> 112 return (jnp.array(x.value), jnp.array(pb.constraints[0].dual_value),
113 jnp.array(pb.constraints[1].dual_value))
114
/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
3542 if type(object) is np.ndarray:
3543 _inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
-> 3544 lax._check_user_dtype_supported(_inferred_dtype, "array")
3545 out = _np_array(object, copy=copy, dtype=dtype)
3546 if dtype: assert _dtype(out) == dtype
/lib/python3.8/site-packages/jax/_src/lax/lax.py in _check_user_dtype_supported(dtype, fun_name)
7011 msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
7012 msg += f" in {fun_name}" if fun_name else ""
-> 7013 raise TypeError(msg)
7014 if dtype is not None and np_dtype != dtypes.canonicalize_dtype(dtype):
7015 msg = ("Explicitly requested dtype {} {} is not available, "
TypeError: JAX only supports number and bool dtypes, got dtype object in array
Any idea?
If my understanding of your constraints is correct, you want vectors of coordinates (x, y) such that:
x + y == 1    # (1)
y == 1        # (2)
-x + y == -1  # (3)
Clearly the problem is infeasible: from (1) and (2) we have x = 0 and y = 1, which contradicts (3). Your polyhedron is an empty set, which causes the error.
In general, with N unknowns you want no more than N linear equality constraints, because beyond that either your problem is infeasible or some equality constraints are redundant (linear combinations of others).
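To make the counting argument concrete, here is a minimal NumPy sketch (NumPy is our choice here, not something the thread uses) of the Rouché–Capelli test: A x = b has a solution exactly when rank(A) equals the rank of the augmented matrix [A | b]. `is_consistent` is a hypothetical helper name, and the test only covers equality constraints.

```python
import numpy as np

def is_consistent(A, b):
    """Rouché–Capelli: A @ x == b is solvable iff rank(A) == rank([A | b])."""
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    return np.linalg.matrix_rank(A) == np.linalg.matrix_rank(np.column_stack([A, b]))

# The three "lines" from the question: 3 equations in 2 unknowns.
A = np.array([[1.0, 1.0], [0.0, 1.0], [-1.0, 1.0]])
b = np.array([1.0, 1.0, -1.0])

rank_A = np.linalg.matrix_rank(A)
rank_Ab = np.linalg.matrix_rank(np.column_stack([A, b]))
print(rank_A, rank_Ab)  # 2 3 -> the third equation is not a combination of the first two

# Keeping only equations (1) and (2) gives a consistent system (x = 0, y = 1).
print(is_consistent(A[:2], b[:2]))
```

A rank gap means the extra equation contradicts the others; equal ranks with more than N rows means some rows are redundant.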
I agree that the error message is quite cryptic, since we wrap another library (cvxpy) without checking the feasibility of the problem, assuming users would check the feasibility themselves.
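One way a caller could check feasibility up front is a "phase-1" linear program: minimize a zero objective subject to the constraints and see whether the solver finds any feasible point. The sketch below assumes SciPy's `scipy.optimize.linprog` (not used anywhere in this thread); `polyhedron_is_nonempty` is a hypothetical helper, and we make no claim about what PR #76 does internally.

```python
import numpy as np
from scipy.optimize import linprog

def polyhedron_is_nonempty(G, h, A=None, b=None):
    """Return True if {x : A @ x == b, G @ x <= h} contains at least one point.

    Phase-1 style check: minimize the zero objective over the constraints.
    Every variable is left free, since linprog otherwise defaults to x >= 0.
    """
    n = np.asarray(G).shape[1]
    res = linprog(c=np.zeros(n), A_ub=G, b_ub=h, A_eq=A, b_eq=b,
                  bounds=(None, None))
    return res.status == 0  # status 2 means the problem is infeasible

# First attempt in the thread: three equality lines plus x >= 0.
A_lines = np.array([[1.0, 1.0], [0.0, 1.0], [-1.0, 1.0]])
b_lines = np.array([1.0, 1.0, -1.0])
G_pos = -np.eye(2)
h_pos = np.zeros(2)
empty = not polyhedron_is_nonempty(G_pos, h_pos, A_lines, b_lines)

# The triangle written as inequalities, plus a null equality constraint 0 @ x == 0.
G_tri = np.array([[-1.0, -1.0], [0.0, 1.0], [1.0, -1.0], [-1.0, 0.0], [0.0, -1.0]])
h_tri = np.array([-1.0, 1.0, 1.0, 0.0, 0.0])
nonempty = polyhedron_is_nonempty(G_tri, h_tri, np.zeros((1, 2)), np.zeros(1))
print(empty, nonempty)
```

The first system should be reported as empty and the triangle as non-empty.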
Indeed, the resulting QP was infeasible. With PR #76 we will now raise an exception.
Of course you would have been right if the system had required equalities, but my use case is the domain bounded by the lines
x+y >=1
y<=1
-x+y >= -1
x>=0
y>=0
Is it possible to do this with projection_polyhedron?
Yes of course:
from jaxopt.projection import projection_polyhedron
import jax.numpy as jnp
import numpy as np

def myproj3(x):
    A = np.array([[0, 0]])
    b = np.array([0])
    G = np.array([[-1, -1], [0, 1], [1, -1], [-1, 0], [0, -1]])
    h = np.array([-1, 1, 1, 0, 0])
    x = projection_polyhedron(x, hyperparams=(A, b, G, h))
    return x
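To sanity-check what the projection should return on this triangle without cvxpy, here is a minimal NumPy sketch of Dykstra's alternating-projection algorithm onto the intersection of the half-spaces G x <= h. This is a swapped-in technique for illustration only (jaxopt solves a QP instead), and `project_halfspace` / `dykstra_polyhedron` are hypothetical names.

```python
import numpy as np

def project_halfspace(z, g, h):
    """Euclidean projection of z onto the half-space {x : g @ x <= h}."""
    violation = g @ z - h
    if violation <= 0:
        return z
    return z - (violation / (g @ g)) * g

def dykstra_polyhedron(x0, G, h, maxiter=1000, tol=1e-10):
    """Project x0 onto {x : G @ x <= h} with Dykstra's algorithm."""
    x = np.asarray(x0, dtype=float)
    G = np.asarray(G, dtype=float)
    h = np.asarray(h, dtype=float)
    increments = np.zeros_like(G)  # one correction vector per half-space
    for _ in range(maxiter):
        x_prev = x
        for i in range(G.shape[0]):
            y = project_halfspace(x + increments[i], G[i], h[i])
            increments[i] = x + increments[i] - y
            x = y
        # Simple heuristic stopping rule: x barely moved during a full sweep.
        if np.linalg.norm(x - x_prev) < tol:
            break
    return x

G = np.array([[-1, -1], [0, 1], [1, -1], [-1, 0], [0, -1]])
h = np.array([-1, 1, 1, 0, 0])
print(dykstra_polyhedron([0.0, 0.0], G, h))  # midpoint of the edge x + y = 1
print(dykstra_polyhedron([3.0, 3.0], G, h))  # vertex (2, 1)
```

A point already inside the triangle, such as (1, 0.5), is returned unchanged.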
Please read the doc: https://jaxopt.github.io/stable/_autosummary/jaxopt.projection.projection_polyhedron.html
It clearly states that inequality constraints should be put in the matrices (G, h). If you don't want equality constraints, you can put a null constraint in the matrices (A, b), just like I did.
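Going from the ">=" form in the question to the G x <= h form expected by projection_polyhedron is just a sign flip on each ">=" row. A minimal sketch, with `to_leq_form` as a hypothetical helper, that rebuilds the G and h used in the snippet above:

```python
import numpy as np

def to_leq_form(constraints):
    """Hypothetical helper: turn (coeffs, sense, rhs) rows into G, h with G @ x <= h."""
    rows, rhs = [], []
    for coeffs, sense, value in constraints:
        coeffs = np.asarray(coeffs, dtype=float)
        if sense == "<=":
            rows.append(coeffs)
            rhs.append(value)
        elif sense == ">=":
            # a @ x >= c is the same constraint as (-a) @ x <= -c
            rows.append(-coeffs)
            rhs.append(-value)
        else:
            raise ValueError(f"unsupported sense {sense!r}")
    return np.stack(rows), np.asarray(rhs, dtype=float)

G, h = to_leq_form([
    ([1, 1], ">=", 1),    # x + y >= 1
    ([0, 1], "<=", 1),    # y <= 1
    ([-1, 1], ">=", -1),  # -x + y >= -1
    ([1, 0], ">=", 0),    # x >= 0
    ([0, 1], ">=", 0),    # y >= 0
])
print(G)
print(h)
```

The null equality constraint A = [[0, 0]], b = [0] is satisfied by every x, since 0 @ x == 0, so it does not restrict the feasible set.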
@Algue-Rythme Thanks, I was a bit confused by the doc and other material where A is the inequality matrix. Sorry.
I still get a crash:
import jax.numpy as jnp
import jaxopt
from jaxopt.projection import projection_polyhedron

def f(x):
    return x[0]**2 - x[1]**2

A = jnp.array([[0, 0]])
b = jnp.array([0])
G = jnp.array([[-1, -1], [0, 1], [1, -1], [-1, 0], [0, -1]])
h = jnp.array([-1, 1, 1, 0, 0])
hyperparams = (A, b, G, h)
pg = jaxopt.ProjectedGradient(fun=f, projection=projection_polyhedron)
res_poly = pg.run(init_params=jnp.array([0., 1.]), hyperparams_proj=hyperparams)
The end of the traceback:
....
/python3.8/site-packages/cvxpy/interface/numpy_interface/ndarray_interface.py in const_to_matrix(self, value, convert_scalars)
46 result = numpy.asarray(value).T
47 else:
---> 48 result = numpy.asarray(value)
49 if result.dtype in [complex, numpy.float64]:
50 return result
python3.8/site-packages/jax/core.py in __array__(self, *args, **kw)
476
477 def __array__(self, *args, **kw):
--> 478 raise TracerArrayConversionError(self)
479
480 def __index__(self):
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[2,2])>with<DynamicJaxprTrace(level=1/3)>
While tracing the function _body_fun at /python3.8/site-packages/jaxopt/_src/loop.py:55 for while_loop, this concrete value was not available in Python because it depends on the value of the argument '_val'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
This is because, by default, the run() method will jit the projection (which currently cannot be jitted). You should pass jit=False to ProjectedGradient.
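A minimal illustration of why jitting breaks a projection that hands its input to NumPy, assuming only JAX and NumPy: inside jax.jit the argument is an abstract tracer, so converting it to a NumPy array (as cvxpy does with its inputs) raises TracerArrayConversionError, the same error as in the traceback above. `numpy_projection` is a hypothetical stand-in (a box projection via np.clip), not jaxopt code.

```python
import jax
import jax.numpy as jnp
import numpy as np

def numpy_projection(x):
    # Hypothetical stand-in for a cvxpy-backed projection: it converts its
    # input to a NumPy array, which only works for concrete values.
    return jnp.asarray(np.clip(np.asarray(x), 0.0, 1.0))

x = jnp.array([2.0, -1.0])
eager = numpy_projection(x)  # eager call works: x is a concrete array
print(eager)

jit_error = None
try:
    jax.jit(numpy_projection)(x)  # under jit, x is an abstract tracer
except jax.errors.TracerArrayConversionError as err:
    jit_error = err
print(type(jit_error).__name__)
```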
Well, well: passing jit=False to the jaxopt.ProjectedGradient call does not solve the crash and leads to the same error message.
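The traceback above says the failure happened "while tracing the function _body_fun ... for while_loop". We have not checked what PR #79 changes, but one plausible reading is that the solver loop was a jax.lax.while_loop, which traces its body even when nothing is wrapped in jax.jit, so turning off jit alone would not help. A minimal sketch of that effect, with hypothetical function names:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def numpy_step(x):
    # Hypothetical NumPy-based step (e.g. a cvxpy-backed projection).
    return jnp.asarray(np.clip(np.asarray(x), 0.0, 1.0))

def python_loop(x, n=3):
    # Plain Python loop: every iteration sees a concrete array.
    for _ in range(n):
        x = numpy_step(x - 0.5)
    return x

def lax_loop(x, n=3):
    # lax.while_loop traces `body` once with abstract values, jit or not.
    def body(state):
        i, x = state
        return i + 1, numpy_step(x - 0.5)
    return lax.while_loop(lambda s: s[0] < n, body, (0, x))[1]

x0 = jnp.array([2.0, 2.0])
python_result = python_loop(x0)

loop_error = None
try:
    lax_loop(x0)
except jax.errors.TracerArrayConversionError as err:
    loop_error = err
print(python_result, type(loop_error).__name__)
```

A plain Python loop keeps the values concrete, which is why a Python-level (non-traced) solver loop can call NumPy-based code.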
PR #79 should fix the bug. It will work with jit=False after the merge.
Ok.