
Comments

fllinares commented on May 6, 2024

Hi,

I think only a couple of small changes would be needed.

To use implicit differentiation with solver.run, you should (1) explicitly expose, in the signature of fun, the args with respect to which you'd like to differentiate the solver's solution, and (2) avoid using keyword arguments in the call to solver.run.

In your MWE:

import jax
import jax.scipy as jsp
import optax
from jaxopt import OptaxSolver

def pipeline(param_for_grad, data):
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)

    initial, _ = solver.init(init_params = 5.)

    result, _ = solver.run(initial, param_for_grad)

    return result

jax.value_and_grad(pipeline)(2., data=6.)

P.S. I also made a small change to the learning rate so that Adam converges in this example with the default maximum number of steps.


mblondel commented on May 6, 2024

Thanks @phinate for the question and @fllinares for the answer!

Indeed, as Felipe explained, your param_for_grad was in scope (this is what is meant by a closed-over value), but it wasn't an explicit argument of run.

By the way, since run calls init for you, the line

initial, _ = solver.init(init_params = 5.)

is not needed. You can just set initial = 5 and then call run(initial, param_for_grad).
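
For instance, a minimal sketch of the simplified version (imports and to_minimize as in the MWE above):

def pipeline(param_for_grad, data):
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)
    # run() calls init() internally, so a plain initial value is enough
    result, _ = solver.run(5., param_for_grad)
    return result

jax.value_and_grad(pipeline)(2., data=6.)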


mblondel commented on May 6, 2024

We are working on the documentation; hopefully these things will become clearer soon.


mblondel commented on May 6, 2024

We decided to use explicit variables because, with closed-over variables, there is no way to tell which need to be differentiated and which don't. This is problematic if you have several big variables in your scope, such as data matrices.
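
For illustration, a small made-up sketch of that explicit style with a data matrix (the ridge-style objective and the names theta, X and y are just placeholders): only theta ends up being differentiated through the solver, while the big arrays X and y are simply passed through as explicit arguments.

import jax
import jax.numpy as jnp
import jaxopt
import optax

def ridge_solution(theta, X, y):
    # theta: regularization strength we want to differentiate through
    # X, y: large data arrays that are only passed through
    def fun(params, theta, X, y):
        residual = X @ params - y
        return 0.5 * jnp.sum(residual ** 2) + theta * jnp.sum(params ** 2)

    solver = jaxopt.OptaxSolver(fun=fun, opt=optax.adam(1e-2),
                                implicit_diff=True, maxiter=2000)
    return solver.run(jnp.zeros(X.shape[1]), theta, X, y).params

X = jnp.ones((10, 3))
y = jnp.ones(10)
jax.jacobian(ridge_solution)(0.1, X, y)  # derivative of the solution w.r.t. theta only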


mblondel commented on May 6, 2024

I'm not sure what your setup_objective function is doing, but I would try to decompose it:

    def pipeline(param_for_grad, latent):
        res = intermediary_step(param_for_grad, **kwargs)

        def objective_fun(params, intermediary_result):
            [...]   # do not use param_for_grad or res here!

        solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
        return solver.run(init_params, intermediary_result=res * latent).params

    jax.jacobian(pipeline)(param_for_grad, latent)

The key idea is to use function composition so that the chain rule will apply. You may have to tweak it to your problem but you get the idea.
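
For concreteness, here is a toy filled-in version of that sketch (the quadratic objective and intermediary_step are made-up stand-ins):

import jax
import jax.numpy as jnp
import jaxopt
import optax

def intermediary_step(param_for_grad):
    # stand-in for the expensive setup that depends on param_for_grad
    return param_for_grad ** 2

def pipeline(param_for_grad, latent):
    res = intermediary_step(param_for_grad)

    def objective_fun(params, intermediary_result):
        # depends only on its explicit arguments, not on closed-over values
        return jnp.sum((params - intermediary_result) ** 2)

    solver = jaxopt.OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2),
                                implicit_diff=True, maxiter=2000)
    # pass the intermediary result positionally so implicit diff can see it
    return solver.run(jnp.zeros(1), res * latent).params

jax.jacobian(pipeline)(1.5, jnp.ones(1))
# the minimizer is res * latent = param_for_grad**2 * latent,
# so the analytic derivative is 2 * param_for_grad * latent = 3 here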


phinate commented on May 6, 2024

hi again @mblondel, sorry to resurrect this from the dead -- my solution that uses closure_convert randomly started leaking tracers with one of the jax/jaxlib updates, and it's a bit of a nightmare to debug. Luckily, I found a fairly simple MWE:

from functools import partial

import jax
import jax.numpy as jnp
import jaxopt
import optax


# dummy model for test purposes
class Model:
    x: jax.Array
    def __init__(self, x) -> None:
        self.x = x
    def logpdf(self, pars, data):
        return jnp.sum(pars*data*self.x)

@partial(jax.jit, static_argnames=["objective_fn"])
def _minimize(
    objective_fn,
    init_pars,
    lr,
):
    # this is the line added from our discussion above
    converted_fn, aux_pars = jax.closure_convert(objective_fn, init_pars)
    # aux_pars seems to be empty -- would have assumed it was the closed-over vals or similar?
    solver = jaxopt.OptaxSolver(
        fun=converted_fn, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
    )
    return solver.run(init_pars, *aux_pars)[0]


@partial(jax.jit, static_argnames=["model"])
def fit(
    data,
    model,
    init_pars,
    lr = 1e-3,
):
    def fit_objective(pars):
        return -model.logpdf(pars, data)

    fit_res = _minimize(fit_objective, init_pars, lr)
    return fit_res

def pipeline(x):
    model = Model(x)
    mle_pars = fit(
        model=model,
        data=jnp.array([5.0, 5.0]),
        init_pars=jnp.array([1.0, 1.1]),
        lr=1e-3,
    )
    return mle_pars

jax.jacrev(pipeline)(jnp.asarray(0.5))
# >> JaxStackTraceBeforeTransformation

(this is jaxopt==0.6)

Another thing to note: the jaxpr tracing induced by closure_convert seems to really fill up the cache, which made this quite problematic in practice (I had to use @patrick-kidger's hack from this JAX issue). Just a health warning for anyone else interested in this type of solution!

I can't see an immediate way, but if we could cast this example into the form you referenced above with the decomposed derivatives, that would be the best way to get around this issue (i.e. avoid closure_convert altogether).


phinate commented on May 6, 2024

Thanks both @mblondel & @fllinares for your replies, and the helpful information!

I'm struggling a little with this because the suggestion of moving param_for_grad into the to_minimize call explicitly is a bit cumbersome for my use case; the way I'm actually making this objective function looks more like:

def setup_objective(param_for_grad, **kwargs):
    to_minimize = complicated_function(param_for_grad, **kwargs)
    return to_minimize

def pipeline(param_for_grad, **kwargs):
    obj = setup_objective(param_for_grad, **kwargs)
    solver = OptaxSolver(fun=obj, opt=optax.adam(5e-2), implicit_diff=True)
    ... etc ...
    return result

Parametrizing directly with param_for_grad would mean constructing the objective via complicated_function every time it is called in the minimization loop, even though strictly it doesn't change with respect to param_for_grad during the minimization.

Am I missing something here in terms of nicely setting up this problem? Or for implicit diff, do I really need to be explicit in the way you described, even though param_for_grad is only used in the construction of the objective, and not for its evaluation?

Thanks again for the quick response earlier, and sorry if this is somehow unclear!


mblondel commented on May 6, 2024

By the way, you can also take a look at jax.lax.custom_root, which supports closed-over variables. CC @shoyer
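
For reference, a minimal made-up sketch of jax.lax.custom_root with a closed-over parameter a (custom_root performs the closure conversion internally):

import jax

def newton_sqrt(a):
    # f closes over `a`
    def f(x):
        return x ** 2 - a

    def solve(f, x0):
        # a few Newton iterations; not differentiated through
        x = x0
        for _ in range(20):
            x = x - f(x) / (2.0 * x)
        return x

    def tangent_solve(g, y):
        # g is linear in x; solve g(x) = y for a scalar x
        return y / g(1.0)

    return jax.lax.custom_root(f, a, solve, tangent_solve)

jax.grad(newton_sqrt)(2.0)  # analytically 1 / (2 * sqrt(2)) ≈ 0.3536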


lukasheinrich commented on May 6, 2024

thanks @mblondel - is this something that could be in scope for jaxopt later on, i.e. adding support for closed-over variables? We can certainly provide metadata on which variables require diffing and which don't.


mblondel commented on May 6, 2024

Could you sketch what this would look like on the user side?


lukasheinrich commented on May 6, 2024

could we co-opt the static_args-like API for this?

Edit: I guess this is equivalent to the argnums kwarg... so would that be sufficient?

def pipeline(param_for_grad, data):
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)

    initial, _ = solver.init(init_params = 5.)

    result, _ = solver.run(initial, param_for_grad)

    return result

pipeline = jaxopt.annotate(pipeline, diff_args = (0,))

jax.value_and_grad(pipeline)(2., data=6.)


shoyer commented on May 6, 2024

We could support handling closed-over variables via jax.closure_convert (like jax.lax.custom_root does), but the tradeoff is that it requires tracing Python functions to a jaxpr. This means you can't use dynamic Python control flow or debugging.

The ideal solution would probably be either to (1) encourage users to use closure_convert themselves, or (2) possibly add an optional argument to allow for opting into automatic closure conversion, e.g., closure_convert=True.


lukasheinrich commented on May 6, 2024

thanks @shoyer - can we use closure_convert now to achieve the desired behavior?


phinate commented on May 6, 2024

Hi again all -- we followed @shoyer's suggestion of using jax.closure_convert, with a function pretty much ripped from the jax docs:

def _minimize(objective_fn, lhood_pars, lr):
    converted_fn, aux_pars = jax.closure_convert(objective_fn, lhood_pars)
    # aux_pars seems to be empty, took that line from docs example
    solver = OptaxSolver(fun=converted_fn, opt=optax.adam(lr), implicit_diff=True)
    return solver.run(lhood_pars, *aux_pars)[0]

where objective_fn is usually created on-the-fly with the aforementioned setup_objective function.

Using this does allow autodiff with no errors, but we encounter a pretty substantial slowdown compared to wrapping an equivalent Adam optimiser using this more explicit implementation of the two-phase method. We were hoping to transition away from this in the interest of keeping up with jax releases and other software that also follows jax, as well as the far more active effort in jaxopt.

As a side note, we do have an additional performance bottleneck coming from external software constraints that is hard to decorrelate from the change to jaxopt (we lose the ability to JIT some parts of the pipeline due to the changed jax version), but based on comparisons with the JIT removed from the old program, I don't think it's nearly enough to explain the ~10x slowdown.

Is there any expected drop in performance from using closure_convert, perhaps given my previous statements on the complexity of the setup_objective function?



mblondel commented on May 6, 2024

Any follow up on this? Does your objective seem decomposable in the way I describe?


phinate commented on May 6, 2024

Any follow up on this? Does your objective seem decomposable in the way I describe?

Have just thought about this a bit -- I'm not 100% sure this would work, but one potential way for us to decompose the problem could be to build our statistical model (the expensive boilerplate) up front, call its logpdf method in the objective, and then construct the pipeline like this:

def pipeline(param_for_grad):
    res = model(param_for_grad)  # build the (expensive) model once, outside the objective

    def objective_fun(params, model):
        return model.logpdf(params)

    solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
    return solver.run(init_params, model=res).params

jax.jacobian(pipeline)(param_for_grad)

Provided this model was registered as a pytree, do you think this would resolve the problem? It's not something we've implemented yet, but it could be if this would work.


phinate commented on May 6, 2024

Just an update on this: we've managed to get the jit working on our side with closure_convert, and we see the performance recover, so @shoyer got it right on that count despite my (incorrect) assumption -- thanks!

If it's helpful, I'd be happy to summarise this thread as a small entry into the documentation via a PR @mblondel, since it could come up again for other users with similar use cases.

Thanks both for the fast and attentive help!


mblondel commented on May 6, 2024

What would the code snippet look like?


phinate commented on May 6, 2024

I explored this a bit, and my particular workflow here was made possible by making the Model class a Pytree, which allows me to feed the model in as an explicit argument to the objective function while keeping jit across the optimization procedure. I think this also means that the relevant parameters for grad are no longer closed over, since the Pytree contains that information.

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jaxopt
import optax
from simple_pytree import Pytree

# dummy model for test purposes
class Model(Pytree):
    x: jax.Array
    def __init__(self, x) -> None:
        self.x = x
    def logpdf(self, pars, data):
        return jsp.stats.norm.logpdf(data, loc=pars*self.x, scale=1.0).sum()


@jax.jit
def pipeline(param_for_grad):
    data = jnp.array([5.0, 5.0])
    init_pars = jnp.array([1.0, 1.1])
    lr = 1e-3

    model = Model(param_for_grad)

    def fit(pars, model, data):
        def fit_objective(pars, model, data):
            return -model.logpdf(pars, data)

        solver = jaxopt.OptaxSolver(
            fun=fit_objective, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
        )
        return solver.run(init_pars, model=model, data=data)[0]

    return fit(init_pars, model, data)

jax.jacrev(pipeline)(jnp.asarray(0.5))
# > Array([-1.33830826e+01,  7.10542736e-15], dtype=float64, weak_type=True)

Don't know if there's another potential issue above that I'm smearing over with this approach, but it works without closure_convert! It may be hard to coerce a complicated model into a Pytree, but that's possibly something for us to worry about more.
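
For completeness, a rough sketch of the same Model registered with jax.tree_util directly, in case one wants to avoid the simple_pytree dependency (field handling kept deliberately minimal):

import jax
import jax.scipy as jsp

@jax.tree_util.register_pytree_node_class
class Model:
    x: jax.Array
    def __init__(self, x) -> None:
        self.x = x
    def logpdf(self, pars, data):
        return jsp.stats.norm.logpdf(data, loc=pars*self.x, scale=1.0).sum()

    def tree_flatten(self):
        # x is a dynamic leaf; no static auxiliary data here
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)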


patrick-kidger commented on May 6, 2024

I think you're doing the right thing by making the model a PyTree, i.e. I don't think you're smearing over any issue.

This is the same approach Equinox uses ubiquitously, and this handles all the complexity of Diffrax just fine!
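
(For anyone curious, a minimal sketch of the same model as an Equinox module, whose instances are pytrees automatically:)

import equinox as eqx
import jax
import jax.scipy as jsp

class Model(eqx.Module):
    x: jax.Array

    def logpdf(self, pars, data):
        return jsp.stats.norm.logpdf(data, loc=pars*self.x, scale=1.0).sum()

# an instance of Model is already a pytree, so it can be passed straight into solver.run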

