Comments (5)
Can you provide a MWE?
Sorry, very late reply, but still relevant. A single-shooting example:
```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx
import optax
import diffrax as dfx
from watermark import watermark

jax.config.update("jax_enable_x64", True)


def c_func(mach):
    # Piecewise-linear drag coefficient as a function of Mach number.
    return jnp.select(
        [mach < 0.4, mach < 0.8, mach < 1.2],
        [0.1, 0.1 * (mach - 0.4) / 0.4 + 0.1, 0.25 * (mach - 0.8) / 0.4 + 0.25],
        default=0.5,
    )


class CannonODE(eqx.Module):
    c: float
    g: float

    def __call__(self, t, y, args):
        # Time is rescaled to [0, 1]; the flight time T multiplies the RHS.
        v = y[1]
        (T,) = args
        speed = jnp.linalg.norm(v)
        mach = speed / 340.0
        c = c_func(mach)
        dp = T * v
        dv = T * jnp.array([-c * v[0] * speed, -c * v[1] * speed - self.g])
        return (dp, dv)


class CannonTrajectory(eqx.Module):
    ode: CannonODE

    def __init__(self, ode):
        self.ode = ode

    def __call__(self, parameter, saveat: dfx.SaveAt):
        QE, v0, T = parameter
        y0 = (jnp.array([0.0, 0.0]), jnp.array([v0 * jnp.cos(QE), v0 * jnp.sin(QE)]))
        term = dfx.ODETerm(self.ode)
        stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
        solver = dfx.Tsit5()
        t0 = saveat.subs.ts[0]
        t1 = saveat.subs.ts[-1]
        dt0 = 0.01
        sol = dfx.diffeqsolve(
            term,
            solver,
            t0,
            t1,
            dt0,
            y0,
            args=(T,),
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            # support forward-mode autodiff, which is used by Levenberg--Marquardt
            adjoint=dfx.DirectAdjoint(),
            max_steps=1024,
        )
        return sol


def residuals(parameter, args):
    traj, target = args
    saveat = dfx.SaveAt(ts=jnp.array([0.0, 1.0]))
    pred_values = traj(parameter, saveat).ys[0][-1, :]  # final position
    return target - pred_values


def residuals_min(parameter, args):
    res = residuals(parameter, args)
    return jnp.sqrt(jnp.dot(res, res))


def main(target):
    v0 = 200.0
    QE0 = 0.01  # jnp.pi / 4
    T0 = 2.0
    ode = CannonODE(c=0.6, g=9.81)
    traj = CannonTrajectory(ode)
    init_parameter = jnp.array([QE0, v0, T0])
    # Note: this passes the optimiser factory itself rather than an instance,
    # which is what triggers the error below.
    solver = optx.OptaxMinimiser(optax.adabelief, rtol=1e-8, atol=1e-8)
    res = optx.minimise(residuals_min, solver, init_parameter, max_steps=128, throw=False, args=(traj, target))
    return res, traj, target


if __name__ == "__main__":
    print(watermark(packages="jax,jaxlib,optimistix,equinox,diffrax,optax"))
    target = jnp.array([100.0, 0.0])
    res, traj, target = main(target)
```
Output:

```
jax       : 0.4.20
jaxlib    : 0.4.14
optimistix: 0.0.5
equinox   : 0.11.2
diffrax   : 0.4.1
optax     : 0.1.7

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  .....
  File "/Users/hkortier/venvs/diffrax/lib/python3.10/site-packages/optimistix/_solver/optax.py", line 90, in init
    opt_state = self.optim.init(y)
AttributeError: '_Closure' object has no attribute 'init'
```
Ah! You want `optax.adabelief(...)`, not just `optax.adabelief`.
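For reference, a minimal sketch of the corrected construction (the learning rate here is just an assumed example value; pass whatever optax hyperparameters you need):

```python
import optax
import optimistix as optx

# optax.adabelief(learning_rate=...) returns a GradientTransformation with
# .init/.update methods; the bare factory optax.adabelief has neither, which
# is why OptaxMinimiser raised the AttributeError above.
solver = optx.OptaxMinimiser(optax.adabelief(learning_rate=1e-3), rtol=1e-8, atol=1e-8)
```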
Ah, thanks for your prompt response! I took this line from https://docs.kidger.site/optimistix/how-to-choose/: `optimistix.OptaxMinimiser(optax.adabelief, learning_rate=1e-3, rtol=1e-8, atol=1e-8)`. However, lower in that text the correct syntax is listed.
Ah, thank you for pointing out the mistake! This should now be fixed in #29, so I'm closing this.
Related Issues (20)
- Extracting intermediate function values/losses from the solve
- Zero implicit gradients when using `ImplicitAdjoint` with CG solver
- Would an exhaustive grid search have a place in `optimistix`?
- Using `optimistix` with an `equinox` model
- Incompatibility with jax 0.4.27
- Possibly of interest
- Unexpected behaviour with JAX version
- Slow compile of least_squares with large dict parameters
- Can't vmap across input using Gauss Newton fwd
- Question: error handling, BFGS minimization, vmap, and best practices
- Parallel multi start
- Optimization across multidimensional array
- Accessing the (success/not) result of the solver
- `UnexpectedTracerError` when enabling `EQX_ON_ERROR=breakpoint`
- Difficulties getting good results
- Damped Newton solve(?) / Scipy and Optimistix
- Support complex-to-real optimization
- AssertionError: assert _is_global_function(fn_primal)
- FutureWarning (unhashable type) thrown with LM solver
- First step of `GradientDescent` optimizer is a no-op