@shoyer may have an idea.
Yes, this looks like the same issue as #31. You definitely need the dependency on `params` inside the `optimality_fun` passed to `jaxopt.implicit_diff.custom_root`. Otherwise jaxopt is not going to calculate the gradients correctly.
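The reason the dependency matters follows from the implicit function theorem: if the solution x*(θ) satisfies the optimality condition F(x*(θ), θ) = 0, then dx*/dθ = -(∂F/∂x)^{-1} ∂F/∂θ. If F only sees θ through a closure over a fixed value, ∂F/∂θ is zero and so is the computed gradient, no matter how the solver itself used θ. The sketch below illustrates that mechanism in plain JAX with `jax.custom_vjp` rather than jaxopt; the quadratic toy energy, the gradient-descent solver, and all names in it are assumptions for illustration, not jaxopt internals.

```python
import jax
import jax.numpy as jnp

A = 2.0                     # curvature of the toy energy (assumption)
THETA_FIXED = jnp.zeros(2)  # constant captured by the "wrong" optimality function


def energy(x, theta):
    # Toy energy E(x, theta) = A/2 * |x|^2 - <theta, x>; minimizer x* = theta / A
    return 0.5 * A * jnp.sum(x ** 2) - jnp.dot(theta, x)


def optimality_ok(x, theta):
    # F(x, theta) = dE/dx, which depends on theta explicitly
    return jax.grad(energy)(x, theta)


def optimality_bad(x, theta):
    # Ignores its theta argument and uses a captured constant instead
    return jax.grad(energy)(x, THETA_FIXED)


def gd_solve(theta, x0, steps=200, lr=0.1):
    # The solver always minimizes the real energy for the given theta
    def step(x, _):
        return x - lr * optimality_ok(x, theta), None

    x, _ = jax.lax.scan(step, x0, None, length=steps)
    return x


def make_argmin(optimality_fun):
    # Wrap the solver so its gradient comes from optimality_fun, not the loop
    @jax.custom_vjp
    def argmin(theta, x0):
        return gd_solve(theta, x0)

    def fwd(theta, x0):
        x_star = gd_solve(theta, x0)
        return x_star, (x_star, theta)

    def bwd(res, g):
        x_star, theta = res
        # Solve (dF/dx)^T u = g at the solution ...
        jac_x = jax.jacobian(optimality_fun, argnums=0)(x_star, theta)
        u = jnp.linalg.solve(jac_x.T, g)
        # ... then theta_bar = -(dF/dtheta)^T u
        _, vjp_theta = jax.vjp(lambda t: optimality_fun(x_star, t), theta)
        (theta_bar,) = vjp_theta(-u)
        return theta_bar, jnp.zeros_like(x_star)  # no gradient w.r.t. x0

    argmin.defvjp(fwd, bwd)
    return argmin


argmin_ok = make_argmin(optimality_ok)
argmin_bad = make_argmin(optimality_bad)

theta = jnp.array([1.0, -3.0])
x0 = jnp.zeros(2)
grad_ok = jax.grad(lambda t: jnp.sum(argmin_ok(t, x0)))(theta)
grad_bad = jax.grad(lambda t: jnp.sum(argmin_bad(t, x0)))(theta)
print(grad_ok, grad_bad)
```

With this toy energy the minimizer is x* = θ/2, so `argmin_ok` yields a gradient of 0.5 per component for `sum(x*)`, while `argmin_bad` returns zeros even though its forward pass computes exactly the same x*.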
I think something like this should work, but it results in a different strange error:
```python
from jax import jit
import jax.numpy as jnp
from jax_md import energy, quantity
from jaxopt.implicit_diff import custom_root

# `displacement`, `shift` and `run_minimization_scan` are assumed to be
# defined elsewhere (not shown in this comment).

def implicit_diff_3(params, R_init, box_size, use_for_loop=True):
    energy_fn = energy.soft_sphere_pair(displacement, **params)
    force_fn = jit(quantity.force(energy_fn))

    def optimality_fun(sol, params):
        energy_fn = energy.soft_sphere_pair(displacement, **params)
        force_fn = jit(quantity.force(energy_fn))
        return force_fn(sol)

    def solver(params, x):
        del params
        return run_minimization_scan(force_fn, x, shift, use_for_loop, num_steps=19400)

    decorated_solver = custom_root(optimality_fun)(solver)
    R_final = decorated_solver(params, R_init)
    return (energy_fn(R_final, **params), jnp.amax(jnp.abs(force_fn(R_final, **params))))
```
NotImplementedError: Differentiation rule for 'custom_lin' not implemented
I'm a bit confused by the variable names above. In JAXopt, we usually use `optimality_fun(params, hyperparams)` and `solver_fun(init_params, hyperparams)`, where `params` is what is optimized and `hyperparams` is what is differentiated. See e.g. this example.