Comments (6)

mblondel commented on May 5, 2024

Thanks a lot for the report. The issue stems from the fact that our rmatvec utility currently assumes that A in matvec(x) = dot(A, x) is square. The current implementation is

def _rmatvec(matvec, x):
  """Computes A^T x, from matvec(x) = A x, where A is square."""
  transpose = jax.linear_transpose(matvec, x)
  return transpose(x)[0]

If A is of size N x P instead of N x N, what we would need to do instead is

def _rmatvec(matvec, x): # x is of size N 
  transpose = jax.linear_transpose(matvec, dummy_vector_of_size_P)
  return transpose(x)[0]

Unlike symbolic diff, autodiff doesn't give us a way to get A^T without evaluating A at some particular point...
I'm not sure if there is a way to generate dummy_vector_of_size_P automatically for the user. Is there a way to infer the correct shape? CC @froystig
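To make the shape issue concrete, here is a minimal sketch, assuming a dense N x P matrix `A` and a hypothetical helper `rmatvec` that takes the example input explicitly (neither is part of the jaxopt code quoted above). The example vector is only used for its shape and dtype:

```python
import jax
import jax.numpy as jnp

N, P = 5, 3
A = jnp.arange(N * P, dtype=jnp.float32).reshape(N, P)

def matvec(x):
    # x has size P, the result has size N.
    return jnp.dot(A, x)

def rmatvec(matvec, y, example_x):
    # Hypothetical helper: example_x only tells JAX the shape/dtype of
    # matvec's input, so the transposed map takes size-N vectors to size-P.
    transpose = jax.linear_transpose(matvec, example_x)
    return transpose(y)[0]

y = jnp.ones(N, dtype=jnp.float32)
Aty = rmatvec(matvec, y, jnp.zeros(P, dtype=jnp.float32))  # A^T y, size P
```

Passing `y` (size N) as the example instead would make JAX trace `matvec` with an input of the wrong size, which is the failure described above.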

One way would be to add an optional argument dummy to solve_normal_cg so that we can write:

def solve_normal_cg(matvec: Callable,
                    b: Any,
                    ridge: Optional[float] = None,
                    init: Optional[Any] = None,
                    dummy: Optional[Any] = None,
                    **kwargs) -> Any:

  def _matvec(x):
    """Computes A^T A x."""
    return _normal_matvec(matvec, x)

  if ridge is not None:
    _matvec = _make_ridge_matvec(_matvec, ridge=ridge)

  Ab = _rmatvec(matvec, b, dummy=dummy)

  return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]

from jaxopt.

mblondel commented on May 5, 2024

Good idea on using init. This made me think of this approach:

if init is None:
  try:
    Ab = _rmatvec(matvec, b)
  except TypeError as e:
    raise TypeError("`init` is compulsory when `matvec` is nonsquare.") from e
else:
  Ab = _rmatvec(matvec, b, init)

This way we don't have to introduce a new argument example_solution and the user gets a clear error message.

mblondel commented on May 5, 2024

Another way would be to optionally solve A^T A x = b instead of A^T A x = A^T b:

def solve_normal_cg(matvec: Callable,
                    b: Any,
                    ridge: Optional[float] = None,
                    init: Optional[Any] = None,
                    apply_right: bool = True,
                    **kwargs) -> Any:

  def _matvec(x):
    """Computes A^T A x."""
    return _normal_matvec(matvec, x)

  if ridge is not None:
    _matvec = _make_ridge_matvec(_matvec, ridge=ridge)

  if apply_right: 
    y = _rmatvec(matvec, b)
  else:
    y = b

  return jax.scipy.sparse.linalg.cg(_matvec, y, x0=init, **kwargs)[0]

but this requires the user to manually transpose the linear operator, which is what we wanted to avoid with this utility...

froystig commented on May 5, 2024

(I would call it example_x or example_solution rather than dummy to be clear.)

Passing in an example would be a fine catch-all approach. Once that's given, we can transpose. The resource cost to callers need not be more than constant memory: you can pass jax.linear_transpose a 0-strided zeros array, or a jax.ShapeDtypeStruct, or anything at all that has shape and dtype attributes really.
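As a small illustration of the constant-memory point, the sketch below passes a `jax.ShapeDtypeStruct` as the example, so no size-P array is ever allocated for it. The matrix `A` and the sizes are made up for the example:

```python
import jax
import jax.numpy as jnp

A = jnp.ones((4, 2), dtype=jnp.float32)  # N = 4, P = 2

def matvec(x):
    return A @ x

# Only shape and dtype are described; no data is stored.
example = jax.ShapeDtypeStruct((2,), jnp.float32)

# jax.eval_shape reports the shape/dtype of matvec's output without running it.
out_spec = jax.eval_shape(matvec, example)

transpose = jax.linear_transpose(matvec, example)
b = jnp.ones(4, dtype=jnp.float32)
Atb = transpose(b)[0]  # A^T b, size 2
```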

If an example is not given, but init is, then we could back off to using init as an example.

If we have no such information, then we simply can't auto-transpose without assumptions. Maybe we can take this as the indication of a square solve.

Note that jax.eval_shape might be useful to callers in determining the shape/dtype of a solution, if for any reason they don't have it on hand already.

Sketch summarizing these ideas:

def _rmatvec(matvec, x, example):
  """Computes A^T x, from matvec(z) = A z, given an example input z for matvec"""
  transpose = jax.linear_transpose(matvec, example)
  return transpose(x)[0]

def solve_normal_cg(matvec: Callable,
                    b: Any,
                    ridge: Optional[float] = None,
                    init: Optional[Any] = None,
                    example_solution: Optional[Any] = None,
                    **kwargs) -> Any:
  """
  [...]

  For example_solution, any object with shape/dtype attributes suffices.
  Consider using jax.eval_shape if that's helpful.
  """

  def _matvec(x):
    """Computes A^T A x."""
    return _normal_matvec(matvec, x)

  if ridge is not None:
    _matvec = _make_ridge_matvec(_matvec, ridge=ridge)

  if example_solution is None:
    example_solution = init
  if example_solution is None:
    example_solution = b  # assuming square matvec

  Ab = _rmatvec(matvec, b, example_solution)

  return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
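A self-contained version of the sketch above, for trying it on a nonsquare problem. The `_normal_matvec` here is a stand-in written with `jax.vjp` and the ridge term is inlined; both are assumptions for illustration and may differ from jaxopt's own helpers. The matrix and right-hand side are arbitrary test data:

```python
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def _rmatvec(matvec, x, example):
    """Computes A^T x, given an example input for matvec."""
    transpose = jax.linear_transpose(matvec, example)
    return transpose(x)[0]

def _normal_matvec(matvec, x):
    """Stand-in: computes A^T A x via a VJP at x."""
    matvec_x, vjp = jax.vjp(matvec, x)
    return vjp(matvec_x)[0]

def solve_normal_cg(matvec: Callable,
                    b: Any,
                    ridge: Optional[float] = None,
                    init: Optional[Any] = None,
                    example_solution: Optional[Any] = None,
                    **kwargs) -> Any:
    def _matvec(x):
        out = _normal_matvec(matvec, x)
        if ridge is not None:
            out = out + ridge * x  # inlined ridge term (assumption)
        return out

    if example_solution is None:
        example_solution = init
    if example_solution is None:
        example_solution = b  # assuming square matvec

    Atb = _rmatvec(matvec, b, example_solution)
    return cg(_matvec, Atb, x0=init, **kwargs)[0]

# Nonsquare test problem: A is 5 x 3 with full column rank.
A = jnp.array([[1., 0., 0.],
               [0., 2., 0.],
               [0., 0., 3.],
               [1., 1., 1.],
               [0., 1., 0.]], dtype=jnp.float32)
b = jnp.array([1., 2., 3., 4., 5.], dtype=jnp.float32)

x = solve_normal_cg(lambda v: A @ v, b,
                    example_solution=jax.ShapeDtypeStruct((3,), jnp.float32))
```

With neither `init` nor `example_solution` given, this version falls back to `b` as the example, which only makes sense for a square `matvec`.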

froystig commented on May 5, 2024

That works. Taking init is a stronger requirement: example_solution would only need to carry shape/dtype and not any actual data. By contrast, init not only carries real data, but also must be used as the initialization point. The latter affects the solver's behavior. I decoupled the two in the sketch in order to highlight that we can accept the minimum if we want, and still offer init = None behavior. I'm OK with taking either approach, provided we understand these options.

mblondel commented on May 5, 2024

I assume CG is fine with an all-zero initialization, and that this is the default behavior anyway. Whether we use init or example_solution as the example, what I think is important is to give a clear error message so the user knows what they have to do in order to fix the issue.
