Comments (6)
The optimizers in Optax are meant for minimization, not root finding. If we take a look at your function neg_Lag, we can see that the minimum does not exist: the Lagrange multiplier x[3] is allowed to take any value in ℝ, so as long as your surface constraint is nonzero, the multiplier term can be used to reach any real number. Optax is working as expected by diverging, since the minimum does not exist anyway.
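For intuition, here is a tiny sketch (using made-up stand-ins for the objective and the surface constraint, not your actual vol) showing that increasing the multiplier alone drives a negative Lagrangian to minus infinity whenever the constraint residual is nonzero:

import jax.numpy as jnp

# Toy negative Lagrangian -(f(x) - x[3] * g(x)) with a nonzero residual g(x).
def toy_neg_lag(x):
    f = x[0] * x[1] * x[2]          # stand-in objective (e.g. a volume)
    g = x[0] + x[1] + x[2] - 10.0   # stand-in constraint residual, nonzero here
    return -(f - x[3] * g)

x = jnp.array([1.0, 1.0, 1.0, 0.0])
for lam in [1e1, 1e3, 1e5]:
    # The value decreases without bound as the multiplier grows.
    print(toy_neg_lag(x.at[3].set(lam)))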
So you must:
- either use an optimizer to perform constrained optimization of your function vol; here ProjectedGradient might be deceiving since your feasible set is not convex (see the API sketch after this list),
- or use a root-finding algorithm to find the zero of jax.grad(Lag). Unfortunately, we currently lack options for multidimensional root finding (the best we have for now is ScipyRootFinding); other options based on fixed-point finding will be available soon,
- or reformulate the multidimensional root-finding problem as an optimization problem: instead of finding p such that jax.grad(Lag)(p) = 0, we seek to minimize ||jax.grad(Lag)(p)||^2.
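As an illustration of the first option only, here is a minimal ProjectedGradient sketch with a hypothetical box projection standing in for the feasible set; the real fixed-surface constraint is not convex, so this only shows the API, not an equivalent problem:

import jax.numpy as jnp
import jaxopt
from jaxopt.projection import projection_box

# Hypothetical stand-in for -vol (ProjectedGradient minimizes its fun argument).
def neg_vol(x):
    return -(x[0] * x[1] * x[2])

pg = jaxopt.ProjectedGradient(fun=neg_vol, projection=projection_box, maxiter=500)
# hyperparams_proj are the (lower, upper) bounds forwarded to projection_box.
res = pg.run(jnp.array([1.5, 0.5, 1.0]), hyperparams_proj=(0.0, 2.0))
print(res.params)  # should converge to the box corner [2., 2., 2.]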
The third option (minimizing the squared gradient norm) works:
opt = optax.adagrad(0.1)

@jax.jit
def objective_fun(p):
    delta = gLag(p)
    return jnp.sum(delta**2)  # minimize gradient norm

solver = jaxopt.OptaxSolver(opt=opt, fun=objective_fun, maxiter=2000)
init_params = jnp.array([1.5, 0.5, 1.0, 0.1])
params, state = solver.init(init_params)
print('init', params, neg_Lag(params))

@jax.jit
def jitted_update(params, state):
    return solver.update(params=params, state=state)

for i in range(20 * 1000):
    params, state = jitted_update(params, state)
    if i % 100 == 0:
        print(i, params, Lag(params), objective_fun(params))
However, this algorithm is far from efficient (note the 20,000 update steps above).
@mblondel you are absolutely right, my problem is a root finding and not a minimization. Sorry, I had forgotten this point when I was using my (old) solveLagrangian that I had JAX-ified; it does exactly that root search thanks to a Newton step.
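For reference, a minimal sketch of such a Newton iteration on gLag (an illustrative reconstruction, not the actual solveLagrangian code):

import jax
import jax.numpy as jnp

def newton_root(gLag, p0, num_steps=20):
    # Newton's method on the stationarity condition gLag(p) = 0:
    # solve J(p) dp = -gLag(p) and update p <- p + dp.
    jac = jax.jacfwd(gLag)
    p = p0
    for _ in range(num_steps):
        dp = jnp.linalg.solve(jac(p), -gLag(p))
        p = p + dp
    return p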
Thanks for the discussion of the different methods and the snippet, too. I am not sure that I can contribute, but your library is really nice and I encourage the new implementations.
rf = jaxopt.ScipyRootFinding(optimality_fun=gLag, method='hybr')
rf.run(jnp.array([1.5,0.5,1.0,0.1]))
gives
OptStep(params=DeviceArray([2. , 2. , 2. , 0.5], dtype=float64), state=ScipyRootInfo(fun_val=DeviceArray([-2.05051975e-10, 3.02247116e-11, -3.61966457e-10,
6.25949070e-10], dtype=float64), success=True, status=1))
So I wonder why you rejected this method in your comment? Maybe I have misunderstood something.
"So I wonder why you rejected this method in your comment? Maybe I have misunderstood something."
I don't reject it; I was just mentioning that we had nothing else for this purpose.
For some reason you wanted to use Optax, so I showed you an example with Optax.
But ScipyRootFinding is fine too.
Note: I am not @Algue-Rythme :)
Oh, sorry @mblondel, I was also in contact with @Algue-Rythme in another thread :)
Anyway, your code with Optax was really nice.
Thanks a lot.