mblondel commented on May 5, 2024

Thanks a lot for the report! All functions that involve SciPy wrappers currently can't be jitted or vmapped. We may be able to use host_callback on the JAXopt side so that they work transparently for the user.

CC @fllinares @shoyer
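The host-callback idea can be sketched as follows. This is not JAXopt's implementation. It is a minimal sketch that swaps in jax.pure_callback, the newer callback API that replaces host_callback for this kind of use, and calls scipy.optimize.root directly. The helper name _scipy_root_host is hypothetical. Batching goes through jax.lax.map, which invokes the host solver once per row instead of relying on a vmap rule. Gradients do not flow through the callback.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize

jax.config.update("jax_enable_x64", True)


def _scipy_root_host(init, params):
    # Hypothetical helper: runs on the host with concrete NumPy arrays,
    # so SciPy can be called like any ordinary Python function.
    a, b = params
    sol = scipy.optimize.root(lambda x: (x - a) * (x - b), init,
                              method="hybr", tol=1e-10)
    # pure_callback requires exactly the shape and dtype declared in solve().
    return np.asarray(sol.x, dtype=init.dtype)


@jax.jit
def solve(params):
    init = jnp.zeros(1)
    out_type = jax.ShapeDtypeStruct(init.shape, init.dtype)
    # The solver is opaque to JAX; only its output shape/dtype is declared.
    return jax.pure_callback(_scipy_root_host, out_type, init, params)


key = jax.random.PRNGKey(1235711)
param_list = jax.random.uniform(key, (100, 2), minval=-10, maxval=10)
# lax.map loops over the leading axis inside compiled code,
# calling the host solver once per row.
root_list = jax.lax.map(solve, param_list)
```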

johnjmolina commented on May 5, 2024

@mblondel Thanks for the quick reply. I understand the issue. For now, I will just avoid using the SciPy wrappers.

richinex commented on May 5, 2024

Modifying the last line to use map instead of vmap works:

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as np
from   jax import random

from jaxopt import linear_solve
from jaxopt import ScipyRootFinding

def func(x, params):
    a,b = params
    return (x - a) * (x - b)
kwargs = {'implicit_diff_solve':linear_solve.solve_normal_cg, 'method':'hybr', 'tol':1e-10}
rootfinder = ScipyRootFinding(optimality_fun=func, **kwargs)

init = np.zeros(1)
def solve(params):
    root, info = rootfinder.run(init, params)
    return root
key        = random.PRNGKey(1235711)
param_list = random.uniform(key, (100, 2), minval = -10, maxval = 10)

root       = solve(param_list[0]) # this is ok
root_list  = list(map(solve, param_list)) # this works

mblondel commented on May 5, 2024

map just computes each operation sequentially, while vmap compiles to vectorized operations (e.g., vector-vector products become matrix-vector products), so map should generally be slower. The issue is that SciPy solvers are black-box functions from JAX's point of view, so they cannot be vmapped.
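The difference can be seen on a toy example (the names below are illustrative, not from the thread): a per-row dot product computed with Python's map, with jax.lax.map, and with jax.vmap. All three return the same numbers, but only vmap turns the batch of vector-vector products into one matrix-vector product, which jax.make_jaxpr lets you inspect.

```python
import jax
import jax.numpy as jnp

M = jnp.arange(12.0).reshape(4, 3)
v = jnp.array([1.0, 2.0, 3.0])


def row_dot(row):
    return jnp.dot(row, v)


# Python map: a separate dot product per row, dispatched one after another.
by_map = jnp.stack(list(map(row_dot, M)))
# lax.map: still one row at a time, but as a single compiled loop.
by_lax_map = jax.lax.map(row_dot, M)
# vmap: the batched rows become one matrix-vector product.
by_vmap = jax.vmap(row_dot)(M)

# The batched jaxpr should show a dot_general over the whole (4, 3) matrix.
jaxpr = jax.make_jaxpr(jax.vmap(row_dot))(M)
print(jaxpr)
```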

richinex commented on May 5, 2024

You're right, map is slower. That is why it is a temporary fix.

richinex commented on May 5, 2024

Somehow jax.scipy.optimize.minimize (not jaxopt's) is vmappable, but it only supports the BFGS algorithm.

mblondel commented on May 5, 2024

Well, it's written in JAX, unlike SciPy's solvers...
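Because jax.scipy.optimize.minimize is traced like any other JAX code, it can be vmapped directly. A minimal sketch on a toy quadratic follows; the loss and targets are illustrative, not from the thread. The method argument is keyword-only and has no default.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def loss(x, target):
    # Toy objective (illustrative): minimized at x == target.
    return jnp.sum((x - target) ** 2)


def solve(target):
    x0 = jnp.zeros_like(target)
    # "BFGS" is the algorithm mentioned in the comment above.
    return minimize(loss, x0, args=(target,), method="BFGS").x


key = jax.random.PRNGKey(0)
targets = jax.random.uniform(key, (5, 2), minval=-10.0, maxval=10.0)
# The whole BFGS loop (line search included) is batched by vmap.
solutions = jax.vmap(solve)(targets)
```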

shoyer commented on May 5, 2024

In principle, we could define a custom primitive with a batching rule for the SciPy solvers.
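A full custom primitive touches JAX internals. A lighter-weight way to attach a batching rule to an opaque host function is jax.custom_batching.custom_vmap, swapped in here purely for illustration. In this sketch the hypothetical helper _host_solve solves x - a = 0 with scipy.optimize.root, and the vmap rule sends the whole batch to SciPy in one call instead of one call per element. The rule is only valid because this toy system decouples across elements; a real solver would need a batching strategy that matches its problem, and no gradient rule is defined.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize
from jax.custom_batching import custom_vmap


def _host_solve(a):
    # Hypothetical black-box solver on concrete NumPy data:
    # solves x - a = 0 for every entry of `a` in a single SciPy call.
    a = np.asarray(a)
    flat = a.ravel()
    sol = scipy.optimize.root(lambda x: x - flat, np.zeros(flat.size),
                              method="hybr")
    return sol.x.reshape(a.shape).astype(a.dtype)


@custom_vmap
def scipy_root(a):
    # Unbatched path: one host call for one problem.
    out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(_host_solve, out_type, a)


@scipy_root.def_vmap
def _scipy_root_vmap(axis_size, in_batched, a):
    # in_batched holds one bool per argument; a batched `a` carries the
    # batch on axis 0, so the whole batch goes to SciPy in one call.
    (a_is_batched,) = in_batched
    out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(_host_solve, out_type, a), a_is_batched


a_batch = jnp.array([0.1, 0.5, 0.9, -2.0])
roots = jax.jit(jax.vmap(scipy_root))(a_batch)
```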

mrzv commented on May 5, 2024

I'm running into a similar problem with Bisection. Is it also coming from SciPy, or is Bisection implemented directly in JAX and this is a different problem? Tweaking the example above to:

import jax
import jax.numpy as np
from   jax import random,vmap

from jaxopt import Bisection

def func(x, a):
    return x - a

rootfinder = Bisection(optimality_fun=func, lower = 0., upper = 1.)

def solve(a):
    return rootfinder.run(a=a[0]).params

key        = random.PRNGKey(1235711)
param_list = random.uniform(key, (4,1), minval = 0, maxval = 1)

root       = solve(param_list[0])
root_list  = vmap(solve)(param_list) # this doesn't work

I get:

jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<BatchTrace(level=1/1)> with
  val = DeviceArray([False, False, False, False], dtype=bool, weak_type=True)
  batch_dim = 0
The problem arose with the `bool` function. 
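The traceback points at a Python-level bool conversion rather than at SciPy. A toy reproduction follows; the function names are hypothetical and this is not jaxopt's code. An `if` on a traced comparison fails under vmap, while expressing the same check with jnp.where keeps it traceable.

```python
import jax
import jax.numpy as jnp


def checked_root(a, lower=0.0, upper=1.0):
    # Python `if` needs a concrete bool, but under vmap/jit the
    # comparison below is an abstract tracer.
    if (lower - a) * (upper - a) > 0:
        raise ValueError("root of x - a is not bracketed")
    return a  # the root of x - a


def masked_root(a, lower=0.0, upper=1.0):
    # Same check expressed as data: unbracketed entries become NaN.
    bracketed = (lower - a) * (upper - a) <= 0
    return jnp.where(bracketed, a, jnp.nan)


a_batch = jnp.array([0.2, 0.5, 1.5])
try:
    jax.vmap(checked_root)(a_batch)
except jax.errors.ConcretizationTypeError as err:
    print(type(err).__name__)

print(jax.vmap(masked_root)(a_batch))
```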

mblondel commented on May 5, 2024

It works if you set check_bracket=False like this:

rootfinder = Bisection(optimality_fun=func, lower = 0., upper = 1., check_bracket=False)

We check the bracketing interval by default because it's easy to make a mistake. Once your code is correct, you can disable it if you want to jit or vmap.

Ideally, we need better documentation and, if possible, a clearer error message.

mrzv commented on May 5, 2024

Ah, that makes perfect sense. Thanks!

This is my first foray into JAX, so forgive me if this question is naive, but is it possible to hide this from the user using the checkify module? Or am I misunderstanding how it works?
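The thread does not say whether jaxopt could adopt checkify, but the idea can be sketched. Below, a hypothetical bisect helper (a plain fixed-iteration bisection, not jaxopt's Bisection) records the bracket check with checkify.check instead of a Python `if`, so it still runs under jit and vmap. The failure is returned as an error value and raised only when err.throw() is called.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def bisect(a, lower=0.0, upper=1.0, num_iters=40):
    # Hypothetical helper: root of f(x) = x - a on [lower, upper].
    f = lambda x: x - a
    # Recorded as data instead of a Python `if`, so it survives jit/vmap.
    checkify.check(f(lower) * f(upper) <= 0, "root is not bracketed")

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # Keep the half-interval whose endpoints still have opposite signs.
        move_lo = jnp.sign(f(mid)) == jnp.sign(f(lo))
        return jnp.where(move_lo, mid, lo), jnp.where(move_lo, hi, mid)

    init = (jnp.asarray(lower), jnp.asarray(upper))
    lo, hi = jax.lax.fori_loop(0, num_iters, body, init)
    return 0.5 * (lo + hi)


checked = jax.jit(checkify.checkify(jax.vmap(bisect)))

err, roots = checked(jnp.array([0.2, 0.7]))
err.throw()  # nothing to raise: every problem is bracketed

err_bad, _ = checked(jnp.array([0.2, 1.5]))
print(err_bad.get())  # a message describing the failed check
```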
