
diffrax's Introduction

Diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Diffrax is a JAX-based library providing numerical differential equation solvers.

Features include:

  • ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
  • lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit solvers);
  • vmappable everything (including the region of integration);
  • using a PyTree as the state;
  • dense solutions;
  • multiple adjoint methods for backpropagation;
  • support for neural differential equations.

From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small tightly-written library.
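Here is one way to picture a unified treatment. Write every equation as dy = f(t, y) dX(t), and let only the driving control X change:

  • X(t) = t for an ODE;
  • a Brownian path for an SDE;
  • an arbitrary observed path for a CDE.

The sketch below is plain Python, not Diffrax's internals. The helper `euler_controlled` is hypothetical. It applies the same Euler update y + f(t, y)·ΔX whichever kind of equation is being solved.

```python
import math
import random
from typing import Callable, Sequence

def euler_controlled(f: Callable[[float, float], float], y0: float,
                     ts: Sequence[float], xs: Sequence[float]) -> float:
    """Euler scheme for dy = f(t, y) dX, given the control X sampled at the times ts."""
    y = y0
    for i in range(len(ts) - 1):
        dx = xs[i + 1] - xs[i]  # increment of the control over [ts[i], ts[i + 1]]
        y = y + f(ts[i], y) * dx
    return y

def decay(t: float, y: float) -> float:
    return -y

def diffusion(t: float, y: float) -> float:
    return 0.5 * y

n = 1000
ts = [i / n for i in range(n + 1)]

# ODE: the control is time itself, so dX = dt.
y_ode = euler_controlled(decay, 1.0, ts, ts)

# SDE: the control is a sampled Brownian path, so dX = dW (Euler-Maruyama, Ito sense).
rng = random.Random(0)
ws = [0.0]
for _ in range(n):
    ws.append(ws[-1] + rng.gauss(0.0, math.sqrt(1.0 / n)))
y_sde = euler_controlled(diffusion, 1.0, ts, ws)

# CDE: the control is some other observed path, here X(t) = sin(t).
y_cde = euler_controlled(decay, 1.0, ts, [math.sin(t) for t in ts])
```

For the ODE and CDE cases the exact solutions at t = 1 are exp(-1) and exp(-sin 1). The Euler results approach these as the grid is refined.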

Installation

pip install diffrax

Requires Python 3.9+, JAX 0.4.13+, and Equinox 0.10.11+.

Documentation

Available at https://docs.kidger.site/diffrax.

Quick example

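from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

# Vector field of the ODE dy/dt = -y; diffrax calls it as f(t, y, args).
def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
# Solve from t0=0 to t1=1, starting with step size dt0=0.1.
# By default only the value at t1 is saved, in solution.ys.
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)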

Here, Dopri5 refers to the Dormand--Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.

Citation

If you found this library useful in academic research, please cite: (arXiv link)

@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}

(Also consider starring the project on GitHub.)

See also: other libraries in the JAX ecosystem

Always useful
Equinox: neural networks and everything not already in core JAX!
jaxtyping: type annotations for shape/dtype of arrays.

Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).

Scientific computing
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
PySR: symbolic regression. (Non-JAX honourable mention!)

Awesome JAX
Awesome JAX: a longer list of other JAX projects.

diffrax's People

Contributors

abocquet, allen-adastra, amir-saadat, andyelking, cholberg, ciupakabra, federicov, fpepin, hawkinsp, jacobusmmsmit, jakevdp, jatentaki, joglekara, lockwo, mahdi-shafiei, packquickly, particularlypythonicbs, patrick-kidger, randl, rdaems, rehmoritz, simipixel, slishak, stefanocortinovis, thibmonsel, tttc3, vivelev


diffrax's Issues

Support Ito OtD for SDEs

  • At the moment we support discretise-then-optimise (DtO) for both Itô and Stratonovich SDEs, of course.
  • Reversible solvers only exist for Stratonovich, so that's a moot point.
  • The "ODE" optimise-then-discretise (OtD) approach corresponds to Stratonovich.

The only one remaining is Ito OtD.

Check on IController

The latent ODE example seems to fail if using IController(rtol=1e-5, atol=1e-5, unvmap_dt=True). Possibly indicative of something going wrong with the adaptivity?

Add support for online settings

This is relevant for CDEs in particular, but in principle one could also want to solve an ODE or SDE continuously as part of some larger program.

Related to #5?

Space-time Levy area / full Levy area

At the moment all of our SDE solvers are Levy-area-free. It should be relatively straightforward to add support for different kinds of Levy area, by extending the evaluate interface.

Use Kahan summation?

At least as an option, this might be a nice thing to have.
Would need to think about the correct way of backpropagating through it + making sure XLA doesn't compile it away.

Cf. also Section VIII.5 of Hairer-Lubich-Wanner, which is the only reference I know of that actually discusses this.
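Kahan (compensated) summation keeps a running correction term for the low-order bits lost in each addition. The sketch below is plain Python, and its function names are hypothetical. It applies the technique to the time accumulation t = t0 + dt + dt + ..., which is where it would matter in a time-stepping loop.

The line computing `c` is algebraically zero. That is exactly why a compiler allowed to reassociate floating-point arithmetic could optimise it away, which is the XLA concern mentioned above.

```python
import math

def naive_times(t0: float, dt: float, n: int) -> float:
    """Accumulate t = t0 + dt + dt + ... with ordinary floating-point addition."""
    t = t0
    for _ in range(n):
        t += dt
    return t

def kahan_times(t0: float, dt: float, n: int) -> float:
    """The same accumulation using Kahan (compensated) summation."""
    t = t0
    c = 0.0  # running compensation for lost low-order bits
    for _ in range(n):
        y = dt - c
        s = t + y
        c = (s - t) - y  # algebraically zero; numerically the rounding error of t + y
        t = s
    return t

n, dt = 100_000, 0.1
exact = math.fsum([dt] * n)  # correctly rounded sum of n copies of dt
naive_err = abs(naive_times(0.0, dt, n) - exact)
kahan_err = abs(kahan_times(0.0, dt, n) - exact)
```

Backpropagating through `kahan_times` would also differentiate through `c`, which is one of the open questions in the issue.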

Tests + fixes

  • Semi-implicit Euler
  • Reversible Heun
  • Leapfrog/midpoint

Support weakly increasing times in global interpolation routines

At the moment all global interpolation routines require that the sequence of times be strictly increasing. In practice it's often helpful to admit weakly increasing times as well, e.g. to pad variable-length time series.

In particular this is in contrast to dense interpolation (generated by diffeqsolve(saveat=SaveAt(dense=True))) which will sometimes use weakly increasing times. (Although this is a detail that is hidden from the user.)

Operations that need addressing:

  • linear_interpolation
  • rectilinear_interpolation
  • backward_hermite_coeffs

(The corresponding classes should already handle this.)
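Here is a scalar sketch of linear interpolation that accepts weakly increasing times:

  • at a repeated time, the value comes from the last sample with that time (a right-continuous convention, which is an assumption; other conventions are possible);
  • repeated trailing times from padding simply hold the final value.

Using `bisect_right` guarantees that the chosen interval has positive length, so there is no division by zero. `weak_linear_interp` is a hypothetical name, and this is not diffrax.linear_interpolation.

```python
import bisect
from typing import Sequence

def weak_linear_interp(ts: Sequence[float], ys: Sequence[float], t: float) -> float:
    """Piecewise-linear interpolation that tolerates repeated (weakly increasing) times."""
    if any(b < a for a, b in zip(ts, ts[1:])):
        raise ValueError("ts must be weakly increasing")
    if t < ts[0]:
        return ys[0]
    if t >= ts[-1]:
        return ys[-1]  # also covers padding that repeats the final time
    # bisect_right skips past every repeat of t, so ts[i] <= t < ts[i + 1]:
    # the interval always has positive length.
    i = bisect.bisect_right(ts, t) - 1
    t_left, t_right = ts[i], ts[i + 1]
    frac = (t - t_left) / (t_right - t_left)
    return ys[i] + frac * (ys[i + 1] - ys[i])

# A jump at t = 1 (two samples share that time), and a padded series.
jump_value = weak_linear_interp([0.0, 1.0, 1.0, 2.0], [0.0, 1.0, 5.0, 6.0], 1.0)
padded_value = weak_linear_interp([0.0, 1.0, 2.0, 2.0, 2.0], [0.0, 1.0, 4.0, 4.0, 4.0], 1.5)
```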

Add differentiable event handling

One thing that needs thinking about is how this compares to the use of jump processes as a driving control, which morally speaking do something very similar.

  • Discrete terminating events
  • Discrete non-terminating events
  • Continuous terminating events
  • Continuous non-terminating events

The continuous events can be implemented by using the most recent dense_info and performing a nonlinear solve to locate the event location.
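Locating a continuous event inside the last step amounts to finding a root of the event condition, evaluated along the dense interpolant. The sketch below makes two simplifying choices:

  • it uses bisection as the nonlinear solve (Newton or Brent would converge faster);
  • it assumes the condition changes sign over the step.

`locate_event` is a hypothetical name. `math.cos` stands in for the solver's dense interpolant on that step, which is an assumption made to keep the example self-contained.

```python
import math
from typing import Callable, Optional

def locate_event(cond: Callable[[float, float], float], t0: float, t1: float,
                 interp: Callable[[float], float], tol: float = 1e-12,
                 max_iter: int = 200) -> Optional[float]:
    """Bisection for a root of cond(t, interp(t)) on [t0, t1]."""
    g0 = cond(t0, interp(t0))
    g1 = cond(t1, interp(t1))
    if g0 * g1 > 0:
        return None  # no sign change: no event detected on this step
    lo, hi = t0, t1
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        gm = cond(mid, interp(mid))
        if g0 * gm <= 0:
            hi = mid  # the sign change lies in [lo, mid]
        else:
            lo, g0 = mid, gm  # the sign change lies in [mid, hi]
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)

# Event "y crosses zero" on a step from t = 1.5 to t = 1.7, with y(t) = cos(t).
t_event = locate_event(lambda t, y: y, 1.5, 1.7, math.cos)
```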

New solvers

SDE:

  • Milstein
  • Euler-Heun
  • SRKs (Additive)
  • Talay (Ito commutative)
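As a concrete reference point for the Milstein item, the sketch below integrates scalar Itô geometric Brownian motion, dy = mu·y dt + sigma·y dW. It uses both Euler-Maruyama and Milstein on the same sampled Brownian path, and compares each against the exact solution on that path.

Scalar noise is trivially commutative, so this does not cover the general commutative-noise case. The function name is hypothetical, and this is not Diffrax's implementation.

```python
import math
import random

def em_and_milstein_gbm(mu: float, sigma: float, y0: float, T: float, n: int,
                        seed: int = 0):
    """Euler-Maruyama and Milstein for dy = mu*y dt + sigma*y dW (Ito), sharing one
    Brownian path; also returns the exact solution on that path."""
    rng = random.Random(seed)
    dt = T / n
    y_em = y_mil = y0
    w = 0.0
    for _ in range(n):
        dw = rng.gauss(0.0, math.sqrt(dt))
        w += dw
        y_em = y_em + mu * y_em * dt + sigma * y_em * dw
        # Milstein adds 0.5 * b * b' * (dW^2 - dt); for b(y) = sigma*y, b*b' = sigma^2 * y.
        y_mil = (y_mil + mu * y_mil * dt + sigma * y_mil * dw
                 + 0.5 * sigma**2 * y_mil * (dw**2 - dt))
    y_exact = y0 * math.exp((mu - 0.5 * sigma**2) * T + sigma * w)
    return y_em, y_mil, y_exact

y_em, y_mil, y_exact = em_and_milstein_gbm(0.5, 0.5, 1.0, 1.0, 1000)
```

The extra Milstein term raises the strong order from 0.5 to 1 for this kind of noise.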

Implicit:

DAE:

  • Lots of overlap with Rosenbrock methods, get them in at the same time.
  • Standard DAE->ODE conversion (+optional nonlinear projections)
  • Set made_jump=True if dconstraint/dz=0
  • Mass matrices: work with these or work with constraint functions? Former is a little more general; maybe less efficient in some cases? Also think about traced mass matrices, possibly space-varying (e.g. quasilinear problems), possibly rank-varying.

Symplectic:

  • Higher-order symplectic methods
  • A lot of the lower-order methods (Stormer-Verlet, Leapfrog, Kick-drift-kick, Semi-implicit Euler) are all essentially the same thing. It would be good to write out something making this explicit, plus maybe add some special variants if you really want a particular variant.
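The kick-drift-kick form of leapfrog is sketched below for a separable Hamiltonian H(q, p) = p²/(2m) + V(q). `leapfrog_kdk` is a hypothetical name.

The sketch shows why these methods coincide. Merging the trailing half kick of one step with the leading half kick of the next gives alternating full kicks and drifts. Apart from the half kicks at the very start and end, that is semi-implicit (symplectic) Euler. Velocity Verlet is algebraically the same update as kick-drift-kick.

```python
from typing import Callable, Tuple

def leapfrog_kdk(force: Callable[[float], float], q: float, p: float, dt: float,
                 n: int, m: float = 1.0) -> Tuple[float, float]:
    """Kick-drift-kick leapfrog, where force(q) = -V'(q)."""
    for _ in range(n):
        p = p + 0.5 * dt * force(q)  # half kick
        q = q + dt * p / m           # full drift
        p = p + 0.5 * dt * force(q)  # half kick
    return q, p

# Harmonic oscillator V(q) = q^2 / 2, run to t = 1000 with dt = 0.1.
q, p = leapfrog_kdk(lambda x: -x, 1.0, 0.0, 0.1, 10_000)
energy = 0.5 * p * p + 0.5 * q * q

# Time reversibility: flip the momentum and integrate back to the start.
q_back, p_back = leapfrog_kdk(lambda x: -x, q, -p, 0.1, 10_000)
```

Two properties are easy to check here. The energy stays close to its initial value of 0.5 over long times rather than drifting. The method is also time-reversible up to round-off.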

Other:

Handling discontinuities in time derivative?

Hi,
first of all, let me say that this looks like an amazing project.
I am looking forward to playing around with this :).

In a concrete problem I am dealing with, I have a forced system where the external force is piecewise constant. The external force changes at specific time points (t1, ..., tn), causing a discontinuity of the time derivative.
I would like to use adaptive step-size solvers for increased accuracy, but naively applying adaptive step-size solvers will "waste" a lot of steps to find the point of change.

Would including the change points in SaveAt avoid this problem?
Or is there some other recommended way to handle this?
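One common workaround is to integrate each interval between change points separately, restarting the solver at every change point. That way no step ever straddles a discontinuity.

The sketch below is plain Python, with a fixed-step RK4 standing in for an adaptive solver. `solve_piecewise` and `rk4` are hypothetical helpers, and this is not necessarily the approach Diffrax recommends.

```python
import math
from typing import Callable, Sequence

VectorField = Callable[[float, float, float], float]  # f(t, y, u), u = current forcing level

def rk4(f: VectorField, t0: float, t1: float, y0: float, u: float, n: int) -> float:
    """Classic fixed-step RK4 on [t0, t1], holding the forcing level u constant."""
    h = (t1 - t0) / n
    t, y = t0, y0
    for _ in range(n):
        k1 = f(t, y, u)
        k2 = f(t + h / 2, y + h / 2 * k1, u)
        k3 = f(t + h / 2, y + h / 2 * k2, u)
        k4 = f(t + h, y + h * k3, u)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return y

def solve_piecewise(f: VectorField, change_points: Sequence[float],
                    levels: Sequence[float], y0: float,
                    steps_per_piece: int = 20) -> float:
    """Integrate across change points, restarting the solver at each one.

    levels[i] is the forcing on [change_points[i], change_points[i + 1]], so each
    piece sees a smooth vector field, including at its own endpoints.
    """
    y = y0
    for (a, b), u in zip(zip(change_points, change_points[1:]), levels):
        y = rk4(f, a, b, y, u, steps_per_piece)
    return y

# Example: y' = -y + u(t), with u piecewise constant.
change_points = [0.0, 1.0, 2.5, 4.0]
levels = [1.0, -2.0, 0.5]
y_end = solve_piecewise(lambda t, y, u: -y + u, change_points, levels, y0=0.0)

# Exact solution for comparison: on each piece y relaxes exponentially towards u.
y_exact = 0.0
for (a, b), u in zip(zip(change_points, change_points[1:]), levels):
    y_exact = u + (y_exact - u) * math.exp(-(b - a))
```

Passing the forcing level per piece removes any ambiguity about which value applies exactly at a change point.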

Add links to references

We have a lot of BibTeX references specified throughout the documentation. It would be nice to add:

  • DOI references to the BibTeX itself
@article{blahblah,
    ...
    doi={...}
}
  • arXiv/DOI hyperlinks adjacent to the BibTeX
[arXiv](...)
@article{...
}

Specify solver orders for SDEs more precisely

At the moment:

  • the ODE order is specified
  • an SDE order is specified -- in principle for whatever the most general type of noise that solver is expecting. (e.g. general noise for Euler, commutative noise for Milstein)

This isn't very flexible. For example what about solving an additive-noise problem with Heun? In this case Heun gets strong order 1, rather than the 0.5 it is specified with. At the moment orders are (only) used for adaptive stepping, so practically speaking this can be handled by passing the appropriate local_order to diffrax.PIDController, but this isn't well-advertised.

Generally speaking, determining this automatically seems to be a huge can of worms. The strong convergence orders of every kind of solver for every kind of noise simply aren't known. (Which is the reason we punt the problem into user-land instead.) Moreover, we'd probably need to do some pretty involved introspection of the information we're passed to determine what kind of order to expect, and that may well be wrong in edge cases.
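To see why the stated order matters for adaptive stepping, suppose the scaled local error estimate behaves like err ≈ C·dt^k. Requiring err(dt_new) = 1 gives dt_new = dt·err^(-1/k). So the same error estimate produces different step proposals depending on which k is assumed.

A minimal I-controller-style sketch follows. `propose_step`, its safety factor and its clipping bounds are hypothetical choices. They are not diffrax.PIDController's signature or defaults.

```python
def propose_step(dt: float, err_norm: float, error_order: float,
                 safety: float = 0.9, factormin: float = 0.2,
                 factormax: float = 10.0) -> float:
    """Propose the next step, assuming the scaled error behaves like C * dt**error_order."""
    if err_norm == 0.0:
        return dt * factormax  # no measurable error: grow as fast as allowed
    factor = safety * err_norm ** (-1.0 / error_order)
    return dt * min(factormax, max(factormin, factor))

# The same error estimate (4x the tolerance) under two assumed orders.
dt_if_order_2 = propose_step(0.1, 4.0, 2.0)
dt_if_order_1 = propose_step(0.1, 4.0, 1.0)
```

Underestimating the order makes the controller shrink steps more aggressively than necessary. Overestimating it does the opposite.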

`IController`: investigate removing the `stop_gradient`s

At the moment there are several stop_gradients in there, primarily to avoid instability observed in the backward pass. A few bugs have been fixed in IController since then, so perhaps this is no longer the case?

It may still be nice to keep as an option, since not backpropagating through rejected steps improves efficiency with minimal change to the gradient.

Local interpolations: switch from weird 4th order thing to standard 3rd order Hermite

During the early phases of the development of this library, we pretty much copied the interpolation routines wholesale from torchdiffeq without much thought. In particular this means a 4th order polynomial interpolation scheme is used for several solvers -- which Dopri5 uses with the appropriate care, but most other algorithms do not.

Where no better choice is available, these other algorithms should switch to 3rd order Hermite interpolation, which is the standard choice when a custom dense interpolation scheme hasn't been developed.
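Third-order Hermite interpolation over a step [t0, t1] needs only four quantities: the endpoint values y0 and y1, and the vector-field evaluations f0 = f(t0, y0) and f1 = f(t1, y1). The latter are often already available, for example for FSAL solvers. Here is a minimal scalar sketch; the function name is hypothetical.

```python
def hermite_interpolate(t0: float, t1: float, y0: float, y1: float,
                        f0: float, f1: float, t: float) -> float:
    """Cubic Hermite interpolant on [t0, t1] matching y and dy/dt at both endpoints."""
    h = t1 - t0
    s = (t - t0) / h  # normalised time in [0, 1]
    h00 = (1 + 2 * s) * (1 - s) ** 2  # weights y0
    h10 = s * (1 - s) ** 2            # weights h * f0
    h01 = s ** 2 * (3 - 2 * s)        # weights y1
    h11 = s ** 2 * (s - 1)            # weights h * f1
    return h00 * y0 + h * h10 * f0 + h01 * y1 + h * h11 * f1

# It reproduces any cubic exactly, e.g. y(t) = t^3 - 2t + 1 with y'(t) = 3t^2 - 2.
def poly(t: float) -> float:
    return t ** 3 - 2 * t + 1

def dpoly(t: float) -> float:
    return 3 * t ** 2 - 2

value = hermite_interpolate(0.5, 2.0, poly(0.5), poly(2.0), dpoly(0.5), dpoly(2.0), 1.3)
```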

Short term TO-DOs

  • Initial automated stepsize selection is giving different values to torchdiffeq (and needs using on the latent ODE example)
  • _step is getting recompiled 61 steps into a CNF?
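For reference on the first TO-DO, a widely used heuristic for the initial step size comes from Hairer, Nørsett and Wanner (Solving Ordinary Differential Equations I, Section II.4). To our knowledge, torchdiffeq's initial-step selection is based on it.

The sketch below is a plain-Python rendering of that heuristic, using an RMS norm scaled by atol + rtol·|y0|. It is an assumption that this is the right baseline to compare against. `select_initial_step` is a hypothetical name, not Diffrax's or torchdiffeq's function.

```python
import math
from typing import Callable, List, Sequence

def rms_norm(x: Sequence[float]) -> float:
    return math.sqrt(sum(v * v for v in x) / len(x))

def select_initial_step(f: Callable[[float, List[float]], List[float]], t0: float,
                        y0: List[float], order: int, rtol: float, atol: float) -> float:
    """Initial step-size heuristic in the style of Hairer-Norsett-Wanner, Sec. II.4."""
    scale = [atol + rtol * abs(v) for v in y0]
    f0 = f(t0, y0)
    d0 = rms_norm([v / s for v, s in zip(y0, scale)])
    d1 = rms_norm([v / s for v, s in zip(f0, scale)])
    # First guess: a small fraction of the solution's scale relative to its rate of change.
    h0 = 1e-6 if d0 < 1e-5 or d1 < 1e-5 else 0.01 * d0 / d1
    # One explicit Euler step to estimate the second derivative.
    y1 = [v + h0 * w for v, w in zip(y0, f0)]
    f1 = f(t0 + h0, y1)
    d2 = rms_norm([(a - b) / s for a, b, s in zip(f1, f0, scale)]) / h0
    if max(d1, d2) <= 1e-15:
        h1 = max(1e-6, h0 * 1e-3)
    else:
        h1 = (0.01 / max(d1, d2)) ** (1.0 / (order + 1))
    return min(100 * h0, h1)

# Same problem as the quick example: dy/dt = -y with y0 = [2, 3], for a 5th order method.
h = select_initial_step(lambda t, y: [-v for v in y], 0.0, [2.0, 3.0],
                        order=5, rtol=1e-3, atol=1e-6)
```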

Example for solving forced ODE

Very cool work!

It would be great to have an example on how to solve an ODE that is directly forced by some signal x(t), e.g. a forced mass-spring-damper

m y'' + r y' + k y = x(t).

Do I understand correctly that the controlled ODEs are "forced" by the derivative of x(t)?
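A standard way to handle a second-order forced system like this is to rewrite it as a first-order system in the state (y, v), where v = y':

  y' = v
  v' = (x(t) - r·v - k·y)/m

Here the forcing enters through x(t) itself, not through its derivative. The sketch below writes that vector field with the f(t, y, args) signature used in the quick example. The name `vector_field` and the packing of args as (m, r, k, x) are assumptions made for illustration.

```python
import math
from typing import Callable, Tuple

State = Tuple[float, float]  # (position y, velocity y')

def vector_field(t: float, y: State,
                 args: Tuple[float, float, float, Callable[[float], float]]) -> State:
    """First-order form of m y'' + r y' + k y = x(t)."""
    m, r, k, x = args
    position, velocity = y
    acceleration = (x(t) - r * velocity - k * position) / m
    return velocity, acceleration

# Example parameters with a sinusoidal forcing signal.
forcing = lambda t: math.sin(2.0 * t)
derivative_at_start = vector_field(0.0, (1.0, 0.0), (1.0, 0.2, 4.0, forcing))
```

With Diffrax this should plug into ODETerm(vector_field) and diffeqsolve(..., y0=(jnp.array(1.0), jnp.array(0.0)), args=(m, r, k, forcing)), since the state may be a PyTree. The forcing function would then need to be JAX-traceable, e.g. written with jax.numpy rather than math.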

Add documentation

  • docstrings etc.
  • The full extent of things we can vmap -- vmap all the things!
  • Make sure to include solver.step as an explicit part of the public API, for use independent of diffeqsolve.
