Comments (3)

mblondel commented on May 5, 2024

@shoyer may have an idea.

shoyer commented on May 5, 2024

Yes, this looks like the same issue as #31.

The optimality_fun passed to jaxopt.implicit_diff.custom_root definitely needs to depend on params. Otherwise jaxopt is not going to calculate the gradients correctly.

I think something like this should work, but it results in a different, strange error:

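import jax.numpy as jnp
from jax import jit
from jax_md import energy, quantity
from jaxopt.implicit_diff import custom_root

# displacement, shift and run_minimization_scan are defined earlier in the thread.

def implicit_diff_3(params, R_init, box_size, use_for_loop=True):
    energy_fn = energy.soft_sphere_pair(displacement, **params)
    force_fn = jit(quantity.force(energy_fn))

    # Optimality condition: zero force at the minimum, with the energy
    # rebuilt from the `params` argument.
    def optimality_fun(sol, params):
        energy_fn = energy.soft_sphere_pair(displacement, **params)
        force_fn = jit(quantity.force(energy_fn))
        return force_fn(sol)

    def solver(params, x):
        del params  # not used by the minimizer
        return run_minimization_scan(force_fn, x, shift, use_for_loop, num_steps=19400)

    decorated_solver = custom_root(optimality_fun)(solver)
    R_final = decorated_solver(params, R_init)

    return (energy_fn(R_final, **params), jnp.amax(jnp.abs(force_fn(R_final, **params))))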

NotImplementedError: Differentiation rule for 'custom_lin' not implemented

mblondel commented on May 5, 2024

I'm a bit confused by the variable names above. In JAXopt, we usually use optimality_fun(params, hyperparams) and solver_fun(init_params, hyperparams), where params is what is optimized and hyperparams is what is differentiated. See e.g. this example.
