
jaxsw's Introduction

Simple Ocean Models in JAX

Motivation

Sea surface height is a gateway variable to other important ocean properties, e.g. sea surface temperature and geostrophic currents. There are many massive models that attempt to model this, e.g. NEMO, MOM6 and MITgcm. However, they are very expensive and quite difficult to run, so there are many small models that are useful approximations, e.g. Quasi-Geostrophic and Shallow Water. This repo attempts to showcase how we can use some modern tools to construct dynamical systems for PDEs.

What makes this different from the many existing implementations is that we will be using JAX. JAX is basically NumPy on steroids: the API is very similar, but we also get modern tooling and speed. Most importantly, JAX is differentiable. Having a differentiable model is important because it allows us to:

  • Learn some of the hyperparameters if necessary
  • Embed this in machine learning models where differentiability is needed

Why Not PyTorch?

We could easily just use PyTorch. However, JAX has some advantages over other frameworks like PyTorch and TensorFlow:

  • Familiar NumPy-like API, which is nice for newcomers in the scientific community
  • CPU/GPU/TPU capabilities with minimal code changes
  • Gradients as function transformations (e.g. jax.grad) instead of gradient information stored in the tensors
  • Functional-style programming, which is easier to read for newcomers
  • Auto-vectorization, so we can easily parallelize the operators over multiple dimensions without code changes (note: TensorFlow has this)
  • JIT compilation speeds up the code by a lot (note: both PyTorch and TensorFlow have this)

Applications

This library will be relatively general but this will be a development platform for the following applications:

  • Generate Simulations
  • Surrogate Models
  • Data Assimilation

Main Components

Without making it too complicated, we settled on a few key objects that make up the package.

Domain

This will be the object that defines the grids where all of the fields live. It will be easy to access the coordinates, boundaries, grids and cell volumes. We don't need to store the grid all of the time; instead, we just generate it as needed.
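A minimal sketch of such a domain object, assuming a uniform, node-centred rectilinear grid. The `Domain` class and its fields are hypothetical illustrations, not the package's actual API; the point is that only bounds and sizes are stored, while coordinates, the full grid and cell volumes are generated on demand.

```python
from typing import NamedTuple, Tuple

import jax.numpy as jnp


class Domain(NamedTuple):
    """Hypothetical uniform domain: stores only bounds and sizes, never the grid itself."""
    xmin: Tuple[float, ...]  # lower bound per dimension
    xmax: Tuple[float, ...]  # upper bound per dimension
    nx: Tuple[int, ...]      # number of points per dimension (endpoints included)

    @property
    def dx(self):
        # uniform spacing per dimension for a node-centred grid
        return tuple((b - a) / (n - 1) for a, b, n in zip(self.xmin, self.xmax, self.nx))

    @property
    def coords(self):
        # 1D coordinate vectors, generated on demand
        return tuple(jnp.linspace(a, b, n) for a, b, n in zip(self.xmin, self.xmax, self.nx))

    @property
    def grid(self):
        # full coordinate grid of shape (*nx, ndim), generated on demand
        return jnp.stack(jnp.meshgrid(*self.coords, indexing="ij"), axis=-1)

    @property
    def cell_volume(self):
        # length/area/volume of a single grid cell
        vol = 1.0
        for d in self.dx:
            vol *= d
        return vol
```

For example, `Domain((0.0, 0.0), (1.0, 2.0), (11, 21))` describes a 2D domain with 0.1 spacing in both directions without ever holding an `(11, 21, 2)` array in memory.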

Operators

This will be a suite of functions for different gradient calculations and combined operations for well-known equations. We will primarily focus on finite difference operators with the FiniteDiffX package. At a later date, we can introduce spectral and finite volume methods.

Integrators

We will use the diffrax package for the time integration. We'll use the method-of-lines technique to formulate all of our PDEs: we discretize in space and calculate the RHS of the equation for the state at time $t$. Then we propagate the state through the time integrator to get the state at $t+\Delta t$.
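To make the method-of-lines idea concrete, here is a minimal sketch in plain JAX: the spatial discretisation turns the PDE into an ODE right-hand side `rhs(t, u, ...)`, and any time stepper can then advance it. A hand-written classical RK4 step stands in for diffrax so the example has no extra dependencies; the 1D periodic diffusion equation and all function names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def rhs(t, u, nu, dx):
    """Semi-discrete 1D diffusion du/dt = nu * d2u/dx2 with periodic boundaries."""
    lap = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2
    return nu * lap


def rk4_step(f, t, u, dt, *args):
    """One classical fourth-order Runge-Kutta step for du/dt = f(t, u, *args)."""
    k1 = f(t, u, *args)
    k2 = f(t + dt / 2, u + dt / 2 * k1, *args)
    k3 = f(t + dt / 2, u + dt / 2 * k2, *args)
    k4 = f(t + dt, u + dt * k3, *args)
    return u + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def integrate(u0, dt, n_steps, nu, dx):
    """Advance u0 by n_steps RK4 steps with jax.lax.scan; returns (final state, trajectory)."""
    def body(carry, _):
        t, u = carry
        u_next = rk4_step(rhs, t, u, dt, nu, dx)
        return (t + dt, u_next), u_next

    t0 = jnp.asarray(0.0, dtype=u0.dtype)
    (_, u_final), trajectory = jax.lax.scan(body, (t0, u0), None, length=n_steps)
    return u_final, trajectory
```

With diffrax, the same `rhs` (adapted to the `(t, y, args)` vector-field signature) would be wrapped in a `diffrax.ODETerm` and passed to `diffrax.diffeqsolve` with a solver such as `diffrax.Tsit5()`; only the time-stepping part changes.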

Params, State & Equations of Motion

We will have a general API for how we store parameters, initialize states and pass both through the equation of motion. To handle what is differentiable and what is not, we will use the equinox package.

Configs

We will use the hydra package to keep track of the configurations and to initialize parameters for experiments.


Installation

pip

We can directly install it via pip from the GitHub repository:

pip install "git+https://github.com/jejjohnson/jaxsw.git"

Cloning

We can also clone the git repository

git clone https://github.com/jejjohnson/jaxsw.git
cd jaxsw

poetry

The easiest way to get started is to use poetry, which also installs all of the necessary dev packages:

poetry install

pip

We can also install it via pip:

pip install .

Conda

We also have a conda environment with all of the equivalent dependencies.

conda env create -f environments/jax_linux_cpu.yaml
conda activate jaxsw

Contributions


Acknowledgements

  • qg_utils - useful functions for dealing with QG equations
  • jaxdf - Nice API for defining operators for PDEs.
  • jax-cfd - Nice API for defining PDEs
  • invobs-data-assimilation - Nice API for Dynamical Systems
  • MASSH - The differentiable QG and SW models applied to sea surface height interpolation.
  • qgm_pytorch - Quasi-Geostrophic Model in PyTorch
  • QGNet - QG implementation in PyTorch with convolutions.

jaxsw's People

Contributors

jejjohnson, roxyboy


jaxsw's Issues

TODO: Functional API

A functional API for the standard spatial operator types that we typically see in numerical methods. We will target the following methods:

  • Finite Difference - just a standard wrapper for FiniteDiffX
  • Finite Volume - needs interpolation and simple finite difference operators
  • Spectral Methods - just a standard wrapper for FiniteDiffX

——
Grid Operators

These are operators that are useful for finite volume methods to do grid-to-grid interpolation and gradients. They are also useful for finite difference/volume methods for padding and boundary conditions.

API

  • Grid.interp
  • Grid.diff
  • Grid.pad
  • (Grid.cumsum)
  • Grid.boundary

u_avg = grid.interp(u, axis=0, method="mean", operation="linear")

du_dx = grid.difference(u, step_size=dx, axis=0, method="")

u_pad = grid.pad(u, axis=0, method="constant", values=0)

u_dirichlet = grid.boundary(u, axis=0, method="dirichlet")
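The first two proposed operations can be sketched as plain functions. This is a minimal illustration, assuming "mean" interpolation means averaging neighbouring points along one axis (for example, faces to centres) and that differences are taken between neighbouring points; `interp_mean` and `difference` are hypothetical names, not the proposed `Grid` methods.

```python
import jax.numpy as jnp


def interp_mean(u, axis=0):
    """Average neighbouring points along `axis`; the output is one point shorter (hypothetical helper)."""
    n = u.shape[axis]
    left = jnp.take(u, jnp.arange(0, n - 1), axis=axis)
    right = jnp.take(u, jnp.arange(1, n), axis=axis)
    return 0.5 * (left + right)


def difference(u, step_size, axis=0):
    """Difference between neighbouring points divided by the spacing; one point shorter (hypothetical helper)."""
    return jnp.diff(u, axis=axis) / step_size
```

Both results live at the midpoints between the original points, which is exactly where a staggered variable sits.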

Documentation | Examples: Shallow Water

Tutorial: 12 Steps to Navier Stokes

This is the canonical 12 Steps to Navier-Stokes tutorial. All of it will be implemented with the package at three levels:

  • Low-Level API - FiniteDiffX and Diffrax (PRIORITY)
  • Mid-level API - kernex (kmap | kscan)
  • High-level API - PDiffrax

Steps

1D Problems

  • 1D Linear Advection
  • 1D Non-Linear Advection
  • 1D Diffusion
  • 1D Burgers

2D Problems

  • 2D Linear Advection
  • 2D Non-Linear Advection
  • 2D Diffusion
  • 2D Burgers

Elliptic Problems

  • Laplace Equation
  • Poisson Equation

Navier-Stokes

  • 2D Navier-Stokes - Cavity Flow
  • 2D Navier-Stokes - Channel Flow

Arakawa C-Grid

I want a working demo that showcases how one can define an Arakawa C-grid for the SW and QG PDEs. I think this is a high bang-for-buck improvement we can make when solving these PDEs with simple finite difference schemes.

Example API

I like the API seen in this codebase. However, we would simply define the state with the appropriate staggering. We can then use the transformations in the grid functions module to move the variables between grids, e.g., u --> v, v --> u, u,v --> H, etc.
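A minimal sketch of those transformations, assuming one common C-grid layout: `nx` by `ny` tracer cells with `h` at cell centres (shape `(nx, ny)`), `u` on the x-faces (shape `(nx + 1, ny)`) and `v` on the y-faces (shape `(nx, ny + 1)`). The function names are hypothetical; moving `u` onto the interior `v` points uses the usual four-point average.

```python
import jax.numpy as jnp

# Assumed (hypothetical) C-grid layout for nx x ny tracer cells:
#   h: (nx, ny) at cell centres
#   u: (nx + 1, ny) on x-faces (west/east edges of each cell)
#   v: (nx, ny + 1) on y-faces (south/north edges of each cell)


def u_to_center(u):
    """Average the two x-faces of each cell onto the cell centre -> (nx, ny)."""
    return 0.5 * (u[:-1, :] + u[1:, :])


def v_to_center(v):
    """Average the two y-faces of each cell onto the cell centre -> (nx, ny)."""
    return 0.5 * (v[:, :-1] + v[:, 1:])


def u_to_v_interior(u):
    """Four-point average of u onto the interior v points -> (nx, ny - 1)."""
    return 0.25 * (u[:-1, :-1] + u[1:, :-1] + u[:-1, 1:] + u[1:, 1:])
```

The boundary `v` points are left out on purpose: what goes there depends on the boundary condition, which is exactly the part a grid API would need to handle.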

Example PDEs

Shallow Water Equations

The most obvious choice is the SW equations, as we have to solve a system of 3 state variables simultaneously. There is a lot of potential for the solution to blow up from numerical instabilities, so the C-grid should be very helpful.
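To show what staggering looks like in practice, here is a minimal sketch of the 1D linearised shallow water equations, ∂h/∂t = -H ∂u/∂x and ∂u/∂t = -g ∂h/∂x, on a periodic staggered grid. It assumes `h[i]` sits at cell centre i and `u[i]` on the face between cells i and i+1, and uses a forward-backward time step; the function names are hypothetical and this is not the package's SW model.

```python
import jax.numpy as jnp


def sw_rhs(h, u, g, H, dx):
    """Linear 1D shallow water tendencies on a periodic staggered grid (hypothetical helper).

    h[i] lives at cell centre i; u[i] lives on the face between cells i and i+1.
    """
    dh_dt = -H * (u - jnp.roll(u, 1)) / dx      # flux divergence around each cell
    du_dt = -g * (jnp.roll(h, -1) - h) / dx     # pressure gradient across each face
    return dh_dt, du_dt


def forward_backward_step(h, u, dt, g, H, dx):
    """Update u with the current h, then h with the new u (forward-backward scheme)."""
    _, du_dt = sw_rhs(h, u, g, H, dx)
    u_new = u + dt * du_dt
    dh_dt, _ = sw_rhs(h, u_new, g, H, dx)
    return h + dt * dh_dt, u_new
```

Because each difference lands exactly where the other variable lives, no averaging is needed, and the total mass `sum(h)` is conserved up to round-off.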

Quasi-Geostrophic Equations

Based on an implementation from Louis Thiry, we can define the QG model using the staggered grid for the velocities, potential vorticity and streamfunction/SSH.

Shallow-Water Model

Tutorial: Learning Problems

An overview of the learning problems available

Parameter Estimation

$$ \boldsymbol{\theta}^* = \underset{\boldsymbol{\theta}}{\text{argmin}}\hspace{2mm}\mathcal{L}(\boldsymbol{\theta}) $$

  • ODE Parameters (e.g. Lorenz63, Lorenz96)
  • PDE Parameters (e.g. Diffusivity Coefficient)
  • UDE (e.g. QG, SW)
    • Parameterizations
    • Hybrid Models
    • Surrogate Model
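For the PDE-parameter case, a minimal sketch: simulate 1D periodic diffusion with an explicit Euler stepper, compare the final state to observations with a mean-squared-error loss, and differentiate through the whole unrolled simulation with `jax.grad` to fit the diffusivity ν. The simulator, the plain gradient-descent loop and its learning rate are illustrative assumptions tuned for this toy problem; in practice one would swap in an optimizer from a library.

```python
import jax
import jax.numpy as jnp


def simulate(nu, u0, dx, dt, n_steps):
    """Explicit Euler for du/dt = nu * d2u/dx2 on a periodic grid; returns the final state."""
    def step(u, _):
        lap = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2
        return u + dt * nu * lap, None

    u_final, _ = jax.lax.scan(step, u0, None, length=n_steps)
    return u_final


def loss(nu, u0, u_obs, dx, dt, n_steps):
    """Mean-squared misfit between the simulated final state and the observations."""
    return jnp.mean((simulate(nu, u0, dx, dt, n_steps) - u_obs) ** 2)


def fit_nu(nu_init, u0, u_obs, dx, dt, n_steps, lr=0.3, n_iters=50):
    """Plain gradient descent on nu; gradients flow through the unrolled simulation."""
    grad_fn = jax.jit(jax.grad(loss), static_argnums=(5,))  # n_steps must be static for scan
    nu = jnp.asarray(nu_init)
    for _ in range(n_iters):
        nu = nu - lr * grad_fn(nu, u0, u_obs, dx, dt, n_steps)
    return nu
```

The learning rate only suits the scale of this toy problem; the transferable part is that `jax.grad` of the loss with respect to a physical parameter comes for free once the model is written in JAX.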

State Estimation

$$ \mathbf{x}^* = \underset{\mathbf{x}}{\text{argmin}}\hspace{2mm}\mathcal{U}(\mathbf{x}) $$

  • Inverse Problem
    • Gradients - Unrolling, Implicit Diff
  • Noisy Observations
  • Missing Observations

Bi-Level Optimizations

  • Inverse Problem
    • Gradients - Unrolling, Implicit Diff/Adjoint
  • 4DVarNet (Learning-to-Learn)

Generic PDE Terms

This will feature some fairly generic PDE terms that keep cropping up that we may want to solve.

Diffusion

A generic diffusion term, ν (∂²η/∂x² + ∂²η/∂y² + ∂²η/∂z²), that we may want to solve:

$$ \nu \left( \frac{\partial^2 \eta}{\partial x^2} + \frac{\partial^2 \eta}{\partial y^2} + \frac{\partial^2 \eta}{\partial z^2}\right) $$

  • Naive
    • 1D
    • 2D
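A minimal sketch of the naive 2D case: second-order central differences for each second derivative, evaluated on interior points only, with boundary handling left to the caller. The function name and the `(x, y)` axis ordering are assumptions.

```python
import jax.numpy as jnp


def diffusion_term(eta, nu, dx, dy):
    """nu * (d2eta/dx2 + d2eta/dy2) at interior points (axis 0 = x, axis 1 = y); hypothetical helper."""
    d2x = (eta[2:, 1:-1] - 2.0 * eta[1:-1, 1:-1] + eta[:-2, 1:-1]) / dx**2
    d2y = (eta[1:-1, 2:] - 2.0 * eta[1:-1, 1:-1] + eta[1:-1, :-2]) / dy**2
    return nu * (d2x + d2y)
```

The 1D case is the same stencil applied along a single axis; the output is two points shorter in every differentiated direction.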

Advection

Some generic advection terms that we may want to solve, i.e. u ∂η/∂x + v ∂η/∂y + w ∂η/∂z

$$ u \frac{\partial \eta}{\partial x} + v \frac{\partial \eta}{\partial y} + w \frac{\partial \eta}{\partial z} $$

  • Naive
    • 1D
    • 2D
  • Upwind
    • 1D
    • 2D
  • #21

Note: I've also seen this term referred to as the "determinant Jacobian".
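The upwind idea in 1D is to take the one-sided difference from the side the flow comes from. A minimal first-order sketch on a periodic grid; the function name is hypothetical.

```python
import jax.numpy as jnp


def upwind_advection(eta, u, dx):
    """First-order upwind approximation of u * d(eta)/dx on a periodic 1D grid (hypothetical helper)."""
    d_backward = (eta - jnp.roll(eta, 1)) / dx   # upstream side when u > 0
    d_forward = (jnp.roll(eta, -1) - eta) / dx   # upstream side when u < 0
    return jnp.where(u > 0, u * d_backward, u * d_forward)
```

A useful sanity check: with `u * dt / dx = 1`, one Euler step `eta - dt * upwind_advection(eta, u, dx)` shifts the profile exactly one cell downstream.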

Package Tutorial List

This is my mega-task list where I just dump all of my ideas into a single issue. At a later date, I will break these up into much smaller issues.

Package Elements

This will go over some specifics of the jaxsw package to showcase how it works and some things one can do. Everything will be taught within the context of PDEs.

5 Levels of Granularity of PDE's

Tutorial: Boundary Value Operators

This showcases some of the different boundary conditions one could encounter. This is easily the most underrated part of PDEs in general and causes most of the problems.

  • Periodic
  • Dirichlet
  • Neumann
  • Robin
  • Open
  • Custom
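One common way to impose several of these is with a layer of ghost cells around the field. A minimal 1D sketch, assuming cell-centred values and one ghost cell per side: periodic and zero-gradient Neumann map directly onto `jnp.pad` modes, while Dirichlet sets the ghost value so that the average across the boundary face equals the prescribed value. The function names are hypothetical.

```python
import jax.numpy as jnp


def pad_periodic(u):
    """Periodic: ghost cells copy the values from the opposite end."""
    return jnp.pad(u, 1, mode="wrap")


def pad_neumann(u):
    """Zero-gradient Neumann: each ghost cell copies its neighbour."""
    return jnp.pad(u, 1, mode="edge")


def pad_dirichlet(u, value=0.0):
    """Cell-centred Dirichlet (1D): ghost = 2 * value - edge, so the face average equals `value`."""
    left = 2.0 * value - u[:1]
    right = 2.0 * value - u[-1:]
    return jnp.concatenate([left, u, right])
```

Once the field is padded, the same interior stencil can be applied everywhere, which keeps the boundary logic separate from the spatial operators.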

Math Mistakes

Comments From Thiry

  • Check potential vorticity term - $q=\hat{k}\cdot \nabla\times u$
  • Check Hyperdiffusivity term $\Delta^3$

Tutorial: Anatomy of a PDE

In this tutorial, we will walk through the different abstract components of a PDE and how they all come together using a minimal example.

  • Domain
  • State
  • Equation of Motion (RHS)
  • Parameters
  • Boundary Conditions
  • Spatial Discretization
  • Time Stepper
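A minimal sketch of how these pieces might fit together for 1D diffusion with homogeneous Dirichlet boundaries. Every name (`Domain`, `Params`, `State`, `apply_bcs`, `laplacian`, `rhs`, `euler_scan`) is a hypothetical illustration rather than the package's API, and forward Euler with `jax.lax.scan` stands in for a proper time stepper.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Domain(NamedTuple):      # domain: where the fields live
    x: jnp.ndarray
    dx: float


class Params(NamedTuple):      # parameters: physical constants
    nu: float


class State(NamedTuple):       # state: the prognostic variables
    u: jnp.ndarray


def apply_bcs(u):              # boundary conditions: u = 0 at both ends
    return u.at[0].set(0.0).at[-1].set(0.0)


def laplacian(u, dx):          # spatial discretisation: central differences, interior only
    return (u[2:] - 2.0 * u[1:-1] + u[:-2]) / dx**2


def rhs(state, params, domain):  # equation of motion: du/dt = nu * d2u/dx2
    u = apply_bcs(state.u)
    du = jnp.zeros_like(u).at[1:-1].set(params.nu * laplacian(u, domain.dx))
    return State(u=du)


def euler_scan(state, params, domain, dt, n_steps):  # time stepper: forward Euler
    def step(s, _):
        ds = rhs(s, params, domain)
        s_new = State(u=apply_bcs(s.u + dt * ds.u))
        return s_new, s_new.u

    return jax.lax.scan(step, state, None, length=n_steps)
```

Each component can be swapped independently, e.g. a different `laplacian` for the spatial discretisation or a higher-order scheme in place of `euler_scan`, without touching the others.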

Tutorial: Data Assimilation

Some simple tutorials for using the package for a data assimilation problem.

  • Backwards-Forwards Nudging
  • 4DVariational
    • Weak-Constrained (Step)
    • Strong-Constrained (Trajectory)
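The core of nudging is a relaxation term added to the model tendency, dx/dt = f(x) + K (y - H x). A minimal forward-only sketch, assuming the observation operator H is a 0/1 mask selecting the observed components and the observation is held fixed in time; backwards-forwards nudging would alternate such forward sweeps with backward-in-time sweeps, which are not shown here. The names are hypothetical.

```python
import jax.numpy as jnp


def nudged_euler_step(x, y_obs, mask, f, gain, dt):
    """One forward-Euler step of dx/dt = f(x) + gain * mask * (y_obs - x)."""
    return x + dt * (f(x) + gain * mask * (y_obs - x))


def nudged_run(x0, y_obs, mask, f, gain, dt, n_steps):
    """Forward nudging sweep towards a fixed observation vector."""
    x = x0
    for _ in range(n_steps):
        x = nudged_euler_step(x, y_obs, mask, f, gain, dt)
    return x
```

Unobserved components (mask 0) evolve under the model `f` alone; only the observed ones feel the pull towards the data.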

Add Finite Difference schemes

Arakawa C-Grid API

This is a high level API to represent the Arakawa C-grid where we have the variable of interest, the u-v velocities, and the tracer. It should be easy for people to move between the 4 grids and calculate differences. It will handle all of the ghost points and boundaries.

Tutorial: 3 APIs for PDEs

This tutorial showcases how granular/explicit one can be when defining a PDE. It will start with writing everything from scratch with loops and little by little start automating things here and there. It's a great exercise to teach people different levels of abstraction that are appropriate for different use cases.

  • Scratch - Loops w/ Numba
  • Functional
    • Spatial Discretisation - Slicing w/ FiniteDiffX
    • TimeStepper - Scan w/ JAX
  • Automated w/ Kernex
    • Spatial Discretisation w/ kmap
    • TimeStepping w/ kscan
  • Explicit
    • Spatial Discretisation w/ finitediffX
    • Time Stepping w/ diffrax
  • Implicit
    • Discretisation w/ jaxdf-like
    • TimeStepping w/ jaxdf-like
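To illustrate the step from the scratch version to the functional version of time stepping, here is the same explicit diffusion update driven once by a Python `for` loop and once by `jax.lax.scan`, which JAX can compile into a single loop under `jit`. Numba is left out to keep the sketch dependency-free; the names are hypothetical.

```python
import jax
import jax.numpy as jnp


def heat_step(u, r):
    """One explicit Euler step of 1D periodic diffusion, with r = nu * dt / dx**2."""
    return u + r * (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1))


def run_loop(u0, r, n_steps):
    """'Scratch' style: a Python loop that collects the history by hand."""
    u = u0
    history = []
    for _ in range(n_steps):
        u = heat_step(u, r)
        history.append(u)
    return u, jnp.stack(history)


def run_scan(u0, r, n_steps):
    """'Functional' style: the same loop expressed with jax.lax.scan."""
    def body(u, _):
        u_next = heat_step(u, r)
        return u_next, u_next

    return jax.lax.scan(body, u0, None, length=n_steps)


run_scan_jit = jax.jit(run_scan, static_argnums=2)  # n_steps must be static
```

Both versions produce the same trajectory; the scan version avoids Python-level dispatch per step and is the form that composes with `jit`, `grad` and `vmap`.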

I - Prebuilt Models

For just diving right in and using it!


II - Operator API

A medium level of granularity.


III - Functional API

The finest level of granularity available.

Lorenz Family

We need the obligatory Lorenz family of models.

  • Lorenz 63
  • #23
  • Lorenz 96
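The right-hand sides of the two classic members are short enough to write directly. A minimal sketch with the standard parameter values (σ = 10, ρ = 28, β = 8/3 for Lorenz 63 and forcing F = 8 for Lorenz 96); the function names are hypothetical and any time stepper can advance them.

```python
import jax.numpy as jnp


def lorenz63(x, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """dx/dt for Lorenz 63 with state x = (x1, x2, x3)."""
    x1, x2, x3 = x[0], x[1], x[2]
    return jnp.stack([
        sigma * (x2 - x1),
        x1 * (rho - x3) - x2,
        x1 * x2 - beta * x3,
    ])


def lorenz96(x, forcing=8.0):
    """dx_i/dt = (x_{i+1} - x_{i-2}) * x_{i-1} - x_i + F with cyclic indices."""
    # jnp.roll(x, -1)[i] = x[i+1], jnp.roll(x, 2)[i] = x[i-2], jnp.roll(x, 1)[i] = x[i-1]
    return (jnp.roll(x, -1) - jnp.roll(x, 2)) * jnp.roll(x, 1) - x + forcing
```

The constant state x_i = F is an equilibrium of Lorenz 96, which makes a quick sanity check.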

Add Staggered Domain

Currently, the domain is regular. It would be nice to add an optional staggering option when constructing the grid, with some helpful methods to convert between the grids.

TODO: Operator API

Grid Operators

This will make the functional grid operators compatible with fields and domains.

  • Interpolate

——
Finite Differences (Slicing)

This will make the FiniteDiffX package compatible with the Field API.

  • Difference
  • Laplacian
  • Divergence
  • Curl 2D

——
Finite Volume

This will make the functional grid operators compatible with fields and domains.

  • Difference
  • Laplacian
  • Divergence
  • Curl 2D

——
Finite Differences (Convolutions)

This will make the Convolutional finite difference API compatible with the Field API.

  • Difference
  • Laplacian - example
  • Divergence
  • Curl 2D
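The convolutional view treats a finite-difference stencil as a fixed kernel. A minimal sketch of the Laplacian with `jax.scipy.signal.convolve2d`, assuming equal spacing in both directions and keeping only the interior (`mode="valid"`); the function name is hypothetical. Because the 5-point kernel is symmetric, the kernel flip that convolution performs does not matter.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def laplacian_conv(u, dx):
    """5-point Laplacian as a 2D convolution over interior points (assumes dx == dy)."""
    kernel = jnp.array([[0.0, 1.0, 0.0],
                        [1.0, -4.0, 1.0],
                        [0.0, 1.0, 0.0]]) / dx**2
    return convolve2d(u, kernel, mode="valid")
```

The result matches the slicing version of the same stencil; the convolution form is convenient when the stencil is data (e.g. a learned kernel) rather than hand-written indexing.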

——
Spectral Differences (Convolutions)

This will make the Convolutional finite difference API compatible with the Field API.

  • Difference
  • Laplacian
  • Divergence
  • Curl 2D
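For a periodic signal, the spectral counterpart replaces the stencil with a multiplication in Fourier space: each mode is multiplied by i k. A minimal 1D sketch with `jnp.fft`, assuming a uniform periodic grid on [0, L); the Nyquist mode is zeroed for even lengths, a common convention for odd derivatives. The function name is hypothetical.

```python
import jax.numpy as jnp


def spectral_derivative(u, L):
    """d/dx of a periodic signal sampled on [0, L) via FFT (hypothetical helper)."""
    n = u.shape[-1]
    k = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=L / n)  # angular wavenumbers
    if n % 2 == 0:
        k = k.at[n // 2].set(0.0)  # drop the Nyquist mode for an odd-order derivative
    return jnp.real(jnp.fft.ifft(1j * k * jnp.fft.fft(u)))
```

For smooth periodic fields the error decays much faster than any finite-difference stencil; multiplying by `-k**2` instead gives the spectral Laplacian.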

Generic Elliptic Terms

Some generic elliptic equations that we may need to solve, e.g. Laplace and Poisson. It's 2023, so we'll stick with solvers that scale, e.g. iterative methods such as steepest descent and conjugate gradient, and fast direct methods based on the discrete sine transform.

$$ \begin{aligned} \boldsymbol{\nabla}^2\boldsymbol{u}&= 0 \\ \boldsymbol{\nabla}^2\boldsymbol{u}&= \boldsymbol{b} \end{aligned} $$

Methods From Scratch

I have implemented a few methods from scratch, but they are not well tested.

  • Steepest Descent
  • Conjugate Gradient
  • Discrete Sine Transformation
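A minimal sketch of the conjugate-gradient route using `jax.scipy.sparse.linalg.cg`, which accepts a function instead of an explicit matrix. It assumes a uniform square grid with homogeneous Dirichlet boundaries (zero ghost cells via `jnp.pad`) and solves −∇²u = f, i.e. f = −b in the notation above: CG needs a symmetric positive definite operator, and −∇² is one while ∇² is negative definite. The helper names are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def neg_laplacian(u, dx):
    """Matrix-free -(5-point Laplacian) on interior unknowns with zero Dirichlet boundaries."""
    up = jnp.pad(u, 1)  # zero ghost layer = homogeneous Dirichlet
    lap = (up[2:, 1:-1] + up[:-2, 1:-1] + up[1:-1, 2:] + up[1:-1, :-2] - 4.0 * u) / dx**2
    return -lap


def solve_poisson_cg(f, dx, tol=1e-6, maxiter=500):
    """Solve -lap(u) = f with conjugate gradients (operator passed as a callable)."""
    u, _ = cg(lambda v: neg_laplacian(v, dx), f, tol=tol, maxiter=maxiter)
    return u
```

The same matrix-free operator can be handed to other iterative solvers unchanged, which is what makes this formulation convenient.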

External Packages

There are many external packages that provide general linear solvers.

This showcases some of the ways we can do linear solvers using the package.

  • Exact Linear Solver
  • Iterative Linear Solvers w/ jaxopt (Steepest Descent, Conjugate Gradient)
  • Discrete Sine Transform
  • Staggered Discrete Sine Transform
  • Spectral

Tutorial: Geometries

An overview of the different domains we can encounter for PDEs. The easy ones are the uniform grid-like scenarios, and the more applied situations feature the spherical lat/lon/height cases. A very important improvement is the staggered grid, which will be useful for multivariate PDEs with velocities and such.

  • Uniform Grid
  • Latitude, Longitude (Meters)
  • Latitude, Longitude (Spherical)
  • Arakawa C-Grid
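Going from lat/lon coordinates to metric spacings is the first step for the non-uniform cases. A minimal sketch assuming a spherical Earth of mean radius 6371 km and 1D coordinate vectors in degrees: the zonal spacing shrinks with cos(latitude) while the meridional spacing does not. The constant and function name are assumptions.

```python
import jax.numpy as jnp

EARTH_RADIUS = 6.371e6  # mean Earth radius in metres (spherical-Earth assumption)


def latlon_spacing_meters(lon, lat):
    """Grid spacings in metres from 1D lon/lat vectors in degrees (hypothetical helper).

    Returns dx with shape (n_lat, n_lon - 1) and dy with shape (n_lat - 1,).
    """
    dlon = jnp.deg2rad(jnp.diff(lon))
    dlat = jnp.deg2rad(jnp.diff(lat))
    dx = EARTH_RADIUS * jnp.cos(jnp.deg2rad(lat))[:, None] * dlon[None, :]
    dy = EARTH_RADIUS * dlat
    return dx, dy
```

One degree of longitude is about 111 km at the equator but only about half that at 60°, which is why a single constant `dx` breaks down on regional or global lat/lon grids.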

Simple Ocean Models

Here, we showcase some examples of simple ocean models we can use within the package. These can be demonstrated with free-runs that are conditioned on real data.

Quasi-Geostrophic Equations

  • 2D Model - Gyre
    • 1 Layer
    • 2 Layer
  • 2D Model - Jet
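The nonlinear term in both QG configurations is the advection of potential vorticity by the streamfunction, written as the Jacobian J(ψ, q) = ∂ψ/∂x ∂q/∂y − ∂ψ/∂y ∂q/∂x. A minimal sketch with plain centred differences on a doubly periodic grid; this is the simple form, not Arakawa's energy- and enstrophy-conserving Jacobian, and the function name and axis ordering are assumptions.

```python
import jax.numpy as jnp


def jacobian(psi, q, dx, dy):
    """Centred-difference J(psi, q) on a doubly periodic grid (axis 0 = x, axis 1 = y)."""
    def ddx(f):
        return (jnp.roll(f, -1, axis=0) - jnp.roll(f, 1, axis=0)) / (2 * dx)

    def ddy(f):
        return (jnp.roll(f, -1, axis=1) - jnp.roll(f, 1, axis=1)) / (2 * dy)

    return ddx(psi) * ddy(q) - ddy(psi) * ddx(q)
```

A quick check is J(ψ, ψ) = 0, which this discrete form satisfies exactly; the Arakawa version additionally preserves the discrete analogues of energy and enstrophy conservation.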

Shallow Water Equations

Documentation

Main Tutorials

These tutorials showcase how the user should think about PDEs using our abstraction.
Then they showcase how to use the software to comply with this abstraction.

  • Anatomy of a PDE - #40
  • Coding a PDE in 3 Ways - #38
  • Spatial Operators - #41
  • Boundary Value Operators - #42
  • Grid Operations
  • TimeSteppers - #39
  • 12 Steps to Navier-Stokes - #11

Extended Deep Dives

These tutorials go over some of the nuances of the package that may help the user define special cases.

  • Geometries - #43
  • Modern Inversion Schemes - #24
  • Experimental Scripts w/ Hydra
  • Spherical Coordinates - #33
  • FAQs & Gotchas
    • NaNs
    • Non-Dimensionalization

Simple Ocean Models

These tutorials showcase the simplified ocean models.

  • Lorenz Family - #23
  • Quasi-Geostrophic Equation - #36
  • Shallow Water Equations - #37

Learning Schemes - #13

These tutorials introduce the user to how one can do parameter and/or state estimation using the autodiff framework.

  • Gradients & Optimization - #6
  • State Estimation
  • Parameter Estimation
  • State & Parameter Estimation
