Hi,
I think only a couple of small changes are needed.
To use implicit differentiation with `solver.run`, you should (1) explicitly expose, in the signature of `fun`, the arguments with respect to which you'd like to differentiate the solver's solution, and (2) avoid using keyword arguments in the call to `solver.run`.
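In your MWE:
```python
import jax
import jax.scipy as jsp
import optax
from jaxopt import OptaxSolver

def pipeline(param_for_grad, data):
    # param_for_grad is now an explicit argument of the objective
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)
    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)
    initial, _ = solver.init(init_params = 5.)
    # param_for_grad is passed positionally, not as a keyword argument
    result, _ = solver.run(initial, param_for_grad)
    return result

jax.value_and_grad(pipeline)(2., data=6.)
```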
P.S. I also made a small change to the learning rate so that Adam converges in this example with the default maximum number of steps.
Thanks @phinate for the question and @fllinares for the answer!
Indeed, as Felipe explained, your `param_for_grad` was in scope (this is what is meant by a closed-over value), but it wasn't an explicit argument of `run`.
By the way, since `run` calls `init` for you, the line `initial, _ = solver.init(init_params = 5.)` is not needed. You can just set `initial = 5` and then call `run(initial, param_for_grad)`.
We are working on documentation; hopefully these things will become clearer soon.
We decided to use explicit variables because with closed-over variables there is no way to tell which ones need to be differentiated and which don't. This is problematic if you have several big variables in your scope, such as data matrices.
I'm not sure what your `setup_objective` function is doing, but I would try to decompose it:
```python
def pipeline(param_for_grad, latent):
    res = intermediary_step(param_for_grad, **kwargs)

    def objective_fun(params, intermediary_result):
        ...  # do not use param_for_grad or res here!

    solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
    return solver.run(init_params, intermediary_result=res * latent).params

jax.jacobian(pipeline)(param_for_grad, latent)
```
The key idea is to use function composition so that the chain rule applies. You may have to tweak it to your problem, but you get the idea.
hi again @mblondel, sorry to resurrect this from the dead -- my solution that uses `closure_convert` suddenly started leaking tracers after one of the `jax`/`jaxlib` updates, and it's a bit of a nightmare to debug. Luckily, I found a fairly simple MWE:
```python
from functools import partial

import jax
import jax.numpy as jnp
import jaxopt
import optax

# dummy model for test purposes
class Model:
    x: jax.Array

    def __init__(self, x) -> None:
        self.x = x

    def logpdf(self, pars, data):
        return jnp.sum(pars*data*self.x)

@partial(jax.jit, static_argnames=["objective_fn"])
def _minimize(
    objective_fn,
    init_pars,
    lr,
):
    # this is the line added from our discussion above
    converted_fn, aux_pars = jax.closure_convert(objective_fn, init_pars)
    # aux_pars seems to be empty -- would have assumed it was the closed-over vals or similar?
    solver = jaxopt.OptaxSolver(
        fun=converted_fn, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
    )
    return solver.run(init_pars, *aux_pars)[0]

@partial(jax.jit, static_argnames=["model"])
def fit(
    data,
    model,
    init_pars,
    lr=1e-3,
):
    def fit_objective(pars):
        return -model.logpdf(pars, data)

    fit_res = _minimize(fit_objective, init_pars, lr)
    return fit_res

def pipeline(x):
    model = Model(x)
    mle_pars = fit(
        model=model,
        data=jnp.array([5.0, 5.0]),
        init_pars=jnp.array([1.0, 1.1]),
        lr=1e-3,
    )
    return mle_pars

jax.jacrev(pipeline)(jnp.asarray(0.5))
# >> JaxStackTraceBeforeTransformation
```
(this is `jaxopt==0.6`)
Another thing to note: the jaxpr tracing induced by `closure_convert` seems to really fill up the cache, which made this quite problematic in practice (I had to use @patrick-kidger's hack from this JAX issue). Just a health warning for anyone else interested in this type of solution!
I can't see an immediate way, but if we could cast this example into the decomposed form you suggested above, that would be the best way to get around this issue (i.e., avoid `closure_convert` altogether).
Thanks both @mblondel & @fllinares for your replies and the helpful information!
I'm struggling a little with this because moving `param_for_grad` explicitly into the `to_minimize` call is a bit cumbersome for my use case; the way I'm actually building this objective function looks more like:
```python
def setup_objective(param_for_grad, **kwargs):
    to_minimize = complicated_function(param_for_grad, **kwargs)
    return to_minimize

def pipeline(param_for_grad, **kwargs):
    obj = setup_objective(param_for_grad, **kwargs)
    solver = OptaxSolver(fun=obj, opt=optax.adam(5e-2), implicit_diff=True)
    # ... etc ...
    return result
```
Parametrizing directly with `param_for_grad` would mean rebuilding the objective via `complicated_function` every time it is called in the minimization loop, even though the objective does not change during the minimization (`param_for_grad` stays fixed throughout).
Am I missing something about how to set up this problem nicely? Or, for implicit diff, do I really need to be explicit in the way you described, even though `param_for_grad` is only used to construct the objective and not to evaluate it?
Thanks again for the quick response earlier, and sorry if this is unclear!
By the way, you can also take a look at `jax.lax.custom_root`, which supports closed-over variables. CC @shoyer
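For comparison, here is a minimal sketch of `jax.lax.custom_root` differentiating through a value that is only closed over: `f` closes over `a`, and the derivative of the root `sqrt(a)` with respect to `a` still comes out as `1 / (2 * sqrt(a))`. The Newton solver and the scalar `tangent_solve` are illustrative choices, not taken from jaxopt.

```python
import jax
import jax.numpy as jnp


def sqrt_via_root(a):
    def f(x):
        # `a` is closed over; custom_root still differentiates with respect to it.
        return x ** 2 - a

    def solve(g, x0):
        # A fixed number of Newton steps on g; custom_root does not differentiate through them.
        def newton_step(x, _):
            return x - g(x) / jax.grad(g)(x), None

        x, _ = jax.lax.scan(newton_step, x0, None, length=20)
        return x

    def tangent_solve(g, y):
        # g is linear and scalar here, so g(x) == g(1.0) * x.
        return y / g(1.0)

    return jax.lax.custom_root(f, jnp.ones_like(a), solve, tangent_solve)


root = sqrt_via_root(2.0)
droot_da = jax.grad(sqrt_via_root)(2.0)
```

The derivative comes from the implicit function theorem applied to `f(x, a) = 0` at the returned root, so it does not depend on how many Newton steps `solve` takes, as long as it converges.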
thanks @mblondel - is adding support for closed-over variables something that could be in scope for jaxopt later on? We can certainly provide metadata about which variables require diffing and which don't.
Could you sketch what this would look like on the user side?
could we co-opt a `static_args`-like API for this?
Edit: I guess this is equivalent to the `argnums` kwarg, so would that be sufficient?
```python
def pipeline(param_for_grad, data):
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)
    initial, _ = solver.init(init_params = 5.)
    result, _ = solver.run(initial, param_for_grad)
    return result

# proposed API for discussion; jaxopt does not provide an `annotate` function
pipeline = jaxopt.annotate(pipeline, diff_args=(0,))
jax.value_and_grad(pipeline)(2., data=6.)
```
We could support handling closed-over variables via `jax.closure_convert` (like `jax.lax.custom_root`), but the tradeoff is that it requires tracing Python functions to a jaxpr. This means you can't use dynamic Python control flow or Python-level debugging.
The ideal solution would probably be either to (1) encourage users to use `closure_convert` themselves, or (2) possibly add an optional argument for opting into automatic closure conversion, e.g., `closure_convert=True`.
thanks @shoyer - can we use `closure_convert` now to achieve the desired behavior?
Hi again all -- we followed @shoyer's suggestion of using `jax.closure_convert`, with a function pretty much lifted from the `jax` docs:
```python
import jax
import optax
from jaxopt import OptaxSolver

def _minimize(objective_fn, lhood_pars, lr):
    converted_fn, aux_pars = jax.closure_convert(objective_fn, lhood_pars)
    # aux_pars seems to be empty, took that line from docs example
    solver = OptaxSolver(fun=converted_fn, opt=optax.adam(lr), implicit_diff=True)
    return solver.run(lhood_pars, *aux_pars)[0]
```
where `objective_fn` is usually created on the fly with the aforementioned `setup_objective` function.
Using this does allow autodiff with no errors, but we see a pretty substantial slowdown compared to wrapping an equivalent Adam optimiser with this more explicit implementation of the two-phase method. We were hoping to move away from that implementation to keep up with `jax` releases and with other software that also follows `jax`, and to benefit from the far more active development of `jaxopt`.
As a side note, we also have an additional performance bottleneck from external software constraints that is hard to disentangle from the switch to `jaxopt` (we can't JIT some parts of the pipeline because of the changed `jax` version). However, based on comparisons where the JIT was removed from the old program, I don't think it comes close to explaining the ~10x slowdown.
Is any drop in performance expected from using `closure_convert`, perhaps given what I said earlier about the complexity of the `setup_objective` function?
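On the "aux_pars seems to be empty" comment in `_minimize`: as far as I can tell from JAX's behaviour, `jax.closure_convert` only hoists closed-over values that might carry derivatives, i.e. tracers from an enclosing transformation such as `jax.jit` or `jax.grad`; concrete arrays stay baked into the converted function. So when the function is called eagerly, an empty list is expected. A small sketch to check this (the helpers `count_hoisted` and `traced` are hypothetical):

```python
import jax
import jax.numpy as jnp


def count_hoisted(scale):
    """Hypothetical helper: how many closed-over values closure_convert hoists out."""
    def objective(p):
        return jnp.sum((p * scale) ** 2)  # closes over `scale`

    _, hoisted = jax.closure_convert(objective, jnp.ones(2))
    return len(hoisted)


# Called eagerly: `scale` is a concrete array, so nothing is hoisted.
eager_count = count_hoisted(jnp.asarray(0.5))

traced_counts = []


@jax.jit
def traced(scale):
    # Under jit, `scale` is a tracer, so closure_convert hoists it as an explicit argument.
    traced_counts.append(count_hoisted(scale))  # runs once, at trace time
    return scale


traced(jnp.asarray(0.5))
```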
Any follow up on this? Does your objective seem decomposable in the way I describe?
> Any follow up on this? Does your objective seem decomposable in the way I describe?
I've just thought about this a bit. I'm not 100% sure this would work, but one potential way for us to decompose the problem could be to build our statistical model (expensive boilerplate), which has the `logpdf` method we want to call, and then construct the objective like this:
```python
def pipeline(param_for_grad):
    # build the expensive model once, outside the solver loop
    res = model(param_for_grad)

    def objective_fun(params, model):
        # negative log-likelihood, so that minimizing it gives the MLE
        return -model.logpdf(params)

    solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
    return solver.run(init_params, model=res).params

jax.jacobian(pipeline)(param_for_grad)
```
Provided this model was registered as a pytree, do you think this would resolve the problem? It's not something we've implemented, but could be if this would work.
Just an update on this: we've managed to get jit working on our side with `closure_convert`, and the performance has recovered, so @shoyer got it right on that count despite my (incorrect) assumption -- thanks!
If it's helpful, I'd be happy to summarise this thread as a small documentation entry via a PR @mblondel, since it could come up again for other users with similar use cases.
Thanks both for the fast and attentive help!
What would the code snippet look like?
I explored this a bit, and my particular workflow works if the `Model` class is made a Pytree, which lets me pass the model as an explicit argument to the objective function while keeping jit across the optimization procedure. I think this also means that the parameters we differentiate with respect to are no longer closed over, since the Pytree carries them.
```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jaxopt
import optax
from simple_pytree import Pytree

# dummy model for test purposes
class Model(Pytree):
    x: jax.Array

    def __init__(self, x) -> None:
        self.x = x

    def logpdf(self, pars, data):
        return jsp.stats.norm.logpdf(data, loc=pars*self.x, scale=1.0).sum()

@jax.jit
def pipeline(param_for_grad):
    data = jnp.array([5.0, 5.0])
    init_pars = jnp.array([1.0, 1.1])
    lr = 1e-3
    model = Model(param_for_grad)

    def fit(pars, model, data):
        def fit_objective(pars, model, data):
            return -model.logpdf(pars, data)

        solver = jaxopt.OptaxSolver(
            fun=fit_objective, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
        )
        return solver.run(pars, model=model, data=data)[0]

    return fit(init_pars, model, data)

jax.jacrev(pipeline)(jnp.asarray(0.5))
# > Array([-1.33830826e+01,  7.10542736e-15], dtype=float64, weak_type=True)
```
I don't know if there's another potential issue I'm smearing over with this approach, but it works without `closure_convert`! It may be hard to coerce a complicated model into a Pytree, but that's more something for us to worry about.
I think you're doing the right thing by making the model a PyTree, i.e. I don't think you're smearing over any issue.
This is the same approach Equinox uses ubiquitously, and this handles all the complexity of Diffrax just fine!