Thanks a lot for the report. The issue stems from the fact that our `rmatvec` utility currently assumes that `A` in `matvec(x) = dot(A, x)` is square. The current implementation is:
```python
import jax

def _rmatvec(matvec, x):
  """Computes A^T x, from matvec(x) = A x, where A is square."""
  transpose = jax.linear_transpose(matvec, x)
  return transpose(x)[0]
```
If `A` is of size N x P instead of N x N, we would need to do this:
```python
def _rmatvec(matvec, x):  # x is of size N
  # dummy_vector_of_size_P: any input of size P, i.e. what matvec expects.
  transpose = jax.linear_transpose(matvec, dummy_vector_of_size_P)
  return transpose(x)[0]
```
Unlike symbolic differentiation, autodiff doesn't give us a way to get A^T without evaluating `A` at some particular point...
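A small sketch of the failure mode, assuming an illustrative 3 x 2 matrix `A` closed over by `matvec` (the matrix, `b`, and the renamed copy `_rmatvec_square` are hypothetical, not jaxopt code): the square-only version traces `matvec` on a size-N input and fails, while giving `jax.linear_transpose` a size-P example recovers A^T b.

```python
import jax
import jax.numpy as jnp

# Illustrative nonsquare operator: N = 3 rows, P = 2 columns.
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])


def matvec(z):
  """Maps a size-P input z to the size-N output A z."""
  return A @ z


def _rmatvec_square(matvec, x):
  """Copy of the current behavior: uses x (size N) as the example input."""
  transpose = jax.linear_transpose(matvec, x)
  return transpose(x)[0]


b = jnp.array([1.0, 1.0, 1.0])  # size N

try:
  _rmatvec_square(matvec, b)  # traces A @ z with a size-N z: shape error
  square_only_failed = False
except (TypeError, ValueError):
  square_only_failed = True

# With an example of size P, the transpose is well defined: A^T b = [9, 12].
At_b = jax.linear_transpose(matvec, jnp.zeros(2))(b)[0]
```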
I'm not sure if there is a way to generate `dummy_vector_of_size_P` automatically for the user. Is there a way to infer the correct shape? CC @froystig
One way would be to add an optional argument `dummy` to `solve_normal_cg` so that we can write:
```python
def solve_normal_cg(matvec: Callable,
                    b: Any,
                    ridge: Optional[float] = None,
                    init: Optional[Any] = None,
                    dummy: Optional[Any] = None,
                    **kwargs) -> Any:
  def _matvec(x):
    """Computes A^T A x."""
    return _normal_matvec(matvec, x)

  if ridge is not None:
    _matvec = _make_ridge_matvec(_matvec, ridge=ridge)

  # Requires _rmatvec to accept an example input for matvec as `dummy`.
  Ab = _rmatvec(matvec, b, dummy=dummy)
  return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
```
Good idea on using `init`. This made me think of the following approach:
```python
if init is None:
  try:
    Ab = _rmatvec(matvec, b)
  except TypeError as e:
    raise TypeError(
        "`init` is compulsory when `matvec` is nonsquare. "
        f"Original error message: {e}") from e
else:
  # Use init as the example input for matvec when transposing.
  Ab = _rmatvec(matvec, b, init)
```
This way we don't have to introduce a new argument `example_solution`, and the user gets a clear error message.
Another way would be to optionally solve A^T A x = b instead of A^T A x = A^T b:
```python
def solve_normal_cg(matvec: Callable,
                    b: Any,
                    ridge: Optional[float] = None,
                    init: Optional[Any] = None,
                    apply_right: bool = True,
                    **kwargs) -> Any:
  def _matvec(x):
    """Computes A^T A x."""
    return _normal_matvec(matvec, x)

  if ridge is not None:
    _matvec = _make_ridge_matvec(_matvec, ridge=ridge)

  if apply_right:
    y = _rmatvec(matvec, b)  # right-hand side A^T b
  else:
    y = b  # caller has already supplied A^T b
  return jax.scipy.sparse.linalg.cg(_matvec, y, x0=init, **kwargs)[0]
```
But this requires the user to transpose the linear operator manually, which is what we wanted to avoid with this utility...
(I would call it `example_x` or `example_solution` rather than `dummy`, to be clear.)
Passing in an example would be a fine catch-all approach. Once that's given, we can transpose. The resource cost to callers need not be more than constant memory: you can pass `jax.linear_transpose` a 0-strided zeros array, or a `jax.ShapeDtypeStruct`, or really anything at all that has `shape` and `dtype` attributes.
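A minimal sketch of both constant-memory options, assuming a small illustrative nonsquare `A` (3 x 2): `jax.linear_transpose` only reads the `shape` and `dtype` of the example, so a `jax.ShapeDtypeStruct` and a 0-strided NumPy view from `np.broadcast_to` both work.

```python
import jax
import jax.numpy as jnp
import numpy as np

A = jnp.arange(6.0).reshape(3, 2)  # illustrative nonsquare A: N = 3, P = 2


def matvec(z):
  return A @ z


b = jnp.array([1.0, 2.0, 3.0])  # size N

# Option 1: a ShapeDtypeStruct carries only shape and dtype, no data.
example = jax.ShapeDtypeStruct((2,), jnp.float32)
At_b = jax.linear_transpose(matvec, example)(b)[0]  # A^T b = [16, 22]

# Option 2: a 0-strided zeros view; one stored element regardless of P.
zeros_p = np.broadcast_to(np.float32(0.0), (2,))
At_b_strided = jax.linear_transpose(matvec, zeros_p)(b)[0]
```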
If an example is not given but `init` is, then we could fall back to using `init` as an example.
If we have no such information, then we simply can't auto-transpose without making assumptions. Maybe we can take this as an indication that the solve is square.
Note that `jax.eval_shape` might be useful to callers in determining the shape/dtype of a solution, if for any reason they don't have it on hand already.
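For instance, if the solution is a pytree produced by some initializer, `jax.eval_shape` yields matching `jax.ShapeDtypeStruct` leaves without running it, and those can be passed straight to `jax.linear_transpose`. The initializer `init_params`, the data `X`, and the model X w + b (linear in the pytree (w, b)) are hypothetical, chosen only to show the mechanics.

```python
import jax
import jax.numpy as jnp

X = jnp.arange(6.0).reshape(2, 3)  # illustrative data: N = 2, P = 3


def init_params(key):
  """Hypothetical initializer whose output has the solution's structure."""
  return {"w": jax.random.normal(key, (3,)), "b": jnp.zeros(())}


def matvec(params):
  """Linear in the pytree params: X w + b, with b broadcast to all rows."""
  return X @ params["w"] + params["b"]


# eval_shape traces init_params abstractly: no random numbers are drawn.
# example is a dict of ShapeDtypeStructs: "w" of shape (3,), "b" of shape ().
example = jax.eval_shape(init_params, jax.random.PRNGKey(0))

y = jnp.ones(2)  # size N
At_y = jax.linear_transpose(matvec, example)(y)[0]
# At_y["w"] == X^T y == [3, 5, 7] and At_y["b"] == sum(y) == 2.
```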
Sketch summarizing these ideas:
```python
def _rmatvec(matvec, x, example):
  """Computes A^T x, from matvec(z) = A z, given an example input z for matvec"""
  transpose = jax.linear_transpose(matvec, example)
  return transpose(x)[0]

def solve_normal_cg(matvec: Callable,
                    b: Any,
                    ridge: Optional[float] = None,
                    init: Optional[Any] = None,
                    example_solution: Optional[Any] = None,
                    **kwargs) -> Any:
  """
  [...]
  For example_solution, any object with shape/dtype attributes suffices.
  Consider using jax.eval_shape if that's helpful.
  """
  def _matvec(x):
    """Computes A^T A x."""
    return _normal_matvec(matvec, x)

  if ridge is not None:
    _matvec = _make_ridge_matvec(_matvec, ridge=ridge)

  if example_solution is None:
    example_solution = init
  if example_solution is None:
    example_solution = b  # assuming square matvec

  Ab = _rmatvec(matvec, b, example_solution)
  return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
```
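A self-contained, runnable version of the sketch, assuming stand-in definitions for the private helpers `_normal_matvec` and `_make_ridge_matvec` (the versions below are hypothetical re-implementations via `jax.vjp` and `jax.tree_util.tree_map`, not copied from jaxopt). With a full-column-rank 3 x 2 `A` and `b = A @ [1, 2]`, CG on the normal equations recovers [1, 2].

```python
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg


def _rmatvec(matvec, x, example):
  """Computes A^T x, from matvec(z) = A z, given an example input z."""
  transpose = jax.linear_transpose(matvec, example)
  return transpose(x)[0]


def _normal_matvec(matvec, x):
  """Hypothetical stand-in: computes A^T A x with one VJP."""
  matvec_x, vjp = jax.vjp(matvec, x)
  return vjp(matvec_x)[0]


def _make_ridge_matvec(matvec, ridge):
  """Hypothetical stand-in: returns x -> matvec(x) + ridge * x."""
  def ridge_matvec(x):
    return jax.tree_util.tree_map(lambda mx, xi: mx + ridge * xi, matvec(x), x)
  return ridge_matvec


def solve_normal_cg(matvec: Callable,
                    b: Any,
                    ridge: Optional[float] = None,
                    init: Optional[Any] = None,
                    example_solution: Optional[Any] = None,
                    **kwargs) -> Any:
  """Solves (A^T A + ridge I) x = A^T b with CG.

  example_solution only needs shape/dtype attributes (e.g. from jax.eval_shape).
  """
  def _matvec(x):
    return _normal_matvec(matvec, x)

  if ridge is not None:
    _matvec = _make_ridge_matvec(_matvec, ridge=ridge)

  if example_solution is None:
    example_solution = init
  if example_solution is None:
    example_solution = b  # Last resort: assume matvec is square.

  Ab = _rmatvec(matvec, b, example_solution)
  return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]


# Usage: a full-column-rank 3 x 2 operator with b = A @ [1, 2].
A = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
b = jnp.array([1.0, 2.0, 3.0])
example = jax.ShapeDtypeStruct((2,), jnp.float32)

x = solve_normal_cg(lambda z: A @ z, b, example_solution=example)
# Here init doubles as the example; the system is [[3, 1], [1, 3]] x = [4, 5].
x_ridge = solve_normal_cg(lambda z: A @ z, b, ridge=1.0, init=jnp.zeros(2))
```

Because only `example_solution` is used for shape inference, `x0` stays `None` in the first call and CG starts from zeros, which keeps the two concerns decoupled.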
That works. Taking `init` is a stronger requirement: `example_solution` would only need to carry shape/dtype and not any actual data. By contrast, `init` not only carries real data, but must also be used as the initialization point, which affects the solver's behavior. I decoupled the two in the sketch to highlight that we can accept the minimum if we want, and still offer `init = None` behavior. I'm OK with either approach, provided we understand these options.
I assume CG is fine with an all-zero initialization, and that is the default behavior anyway. Whether we use `init` or `example_solution` as the example, what I think is important is to give a clear error message so the user knows how to fix the issue.
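One way to produce such a message, sketched with a hypothetical helper `_rmatvec_with_clear_error` (not part of jaxopt): use `jax.eval_shape` to check that `matvec` accepts the example and returns something shaped like `b` before transposing, and re-raise with an actionable hint. It assumes `b` and the output of `matvec` are single arrays rather than pytrees.

```python
import jax
import jax.numpy as jnp


def _rmatvec_with_clear_error(matvec, b, example):
  """Hypothetical helper: A^T b, failing early with an actionable message.

  Assumes b and matvec's output are single arrays (not pytrees).
  """
  try:
    # Abstract evaluation only: checks shapes/dtypes without computing A z.
    out = jax.eval_shape(matvec, example)
  except (TypeError, ValueError) as e:
    raise TypeError(
        "`matvec` cannot be applied to an input shaped like the example "
        f"(shape {example.shape}). If `matvec` is nonsquare, pass "
        "`example_solution` (or `init`) with the shape/dtype of the solution. "
        f"Original error: {e}") from e
  if out.shape != b.shape or out.dtype != b.dtype:
    raise TypeError(
        f"`matvec` returns {out.shape}/{out.dtype}, but `b` is "
        f"{b.shape}/{b.dtype}; they must match.")
  return jax.linear_transpose(matvec, example)(b)[0]


A = jnp.ones((3, 2))  # nonsquare: N = 3, P = 2
b = jnp.ones(3)

# Correct example of size P: returns A^T b = [3, 3].
ok = _rmatvec_with_clear_error(
    lambda z: A @ z, b, jax.ShapeDtypeStruct((2,), jnp.float32))

# Falling back to b as the example (square assumption) is wrong here.
try:
  _rmatvec_with_clear_error(lambda z: A @ z, b, b)
  message = ""
except TypeError as e:
  message = str(e)
```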
Related Issues
- Type precision issue in BoxOSQP
- Garbage collection issues
- Wrong failure diagnostic print outs from `ZoomLineSearch` under `vmap`
- Attempted boolean conversion of traced array - for hager-zhang
- Expression tree API like CVXPY
- Unnecessary recompilation of _while_loop_lax
- Add type annotations
- Consider switching to pyproject.toml
- OSQP crashing on unexpected params
- `verbose=False` is not working as expected for `NonlinearCG`
- Stochastic L-BFGS algorithm implementation
- Stopping condition 'madsen-nielsen' incorrect
- unit test failures on aarch64 linux with scipy 1.12
- `LevenbergMarquardt` implementation does not accept PyTree parameters
- diag(JTJ) can be more efficient
- JAXOPT Projected Gradient
- "invalid escape sequence" warning in `BoxOSQP` docstring
- Error when taking gradient wrt parameters in BoxOSQP
- Disable warnings in vmap or print to stderr
- Constraint violation causes L-BGFS-B to fail