jax-cfd's Introduction

JAX-CFD: Computational Fluid Dynamics in JAX

Authors: Dmitrii Kochkov, Jamie A. Smith, Peter Norgaard, Gideon Dresdner, Ayya Alieva, Stephan Hoyer

JAX-CFD is an experimental research project for exploring the potential of machine learning, automatic differentiation and hardware accelerators (GPU/TPU) for computational fluid dynamics. It is implemented in JAX.

To learn more about our general approach, read our paper Machine learning accelerated computational fluid dynamics (PNAS 2021).

Getting started

The "notebooks" directory contains several demonstrations of using the JAX-CFD code.

Organization

JAX-CFD is organized around sub-modules:

  • jax_cfd.base: core finite volume/difference methods for CFD, written in JAX.
  • jax_cfd.spectral: core pseudospectral methods for CFD, written in JAX.
  • jax_cfd.ml: machine learning augmented models for CFD, written in JAX and Haiku.
  • jax_cfd.data: data processing utilities for preparing, evaluating and post-processing data created with JAX-CFD, written in Xarray and Pillow.

A base install with pip install jax-cfd only requires NumPy, SciPy and JAX. To install dependencies for the other submodules, use pip install jax-cfd[ml], pip install jax-cfd[data] or pip install jax-cfd[complete].

Numerics

JAX-CFD is currently focused on unsteady turbulent flows:

  • Spatial discretization:
    • Finite volume/difference methods on a staggered grid (the "Arakawa C" or "MAC" grid) with pressure at the center of each cell and velocity components defined on corresponding faces.
    • Pseudospectral methods for vorticity which use anti-aliasing filtering techniques for non-linear terms to maintain stability.
  • Temporal discretization: Currently only first-order temporal discretization, using explicit time-stepping for advection and either implicit or explicit time-stepping for diffusion.
  • Pressure solves: Either CG or fast diagonalization with real-valued FFTs (suitable for periodic boundary conditions).
  • Boundary conditions: Currently only periodic boundary conditions are supported.
  • Advection: We implement 2nd order accurate "Van Leer" schemes.
  • Closures: We currently implement Smagorinsky eddy-viscosity models.
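To make the advection and time-stepping items above concrete, here is a minimal 1D sketch of a Van Leer-limited advection step with explicit first-order time-stepping and periodic boundaries. It is not taken from the JAX-CFD source: the function names are hypothetical, it uses plain Python lists instead of JAX arrays, and it assumes a constant positive velocity.

```python
def van_leer_limiter(r):
    """Van Leer flux limiter: phi(r) = (r + |r|) / (1 + |r|)."""
    return (r + abs(r)) / (1.0 + abs(r))


def advect_step(c, velocity, dx, dt):
    """One forward-Euler step of dc/dt + u * dc/dx = 0 on a periodic 1D grid.

    Assumes velocity > 0 and a Courant number velocity * dt / dx <= 1.
    """
    n = len(c)
    courant = velocity * dt / dx
    flux = []  # flux[i] is the flux through the right face of cell i
    for i in range(n):
        # Negative and wrapped indices implement the periodic boundary.
        left, mid, right = c[i - 1], c[i], c[(i + 1) % n]
        jump = right - mid
        # Ratio of consecutive differences; the value is irrelevant when
        # jump == 0 because the correction below is multiplied by jump.
        r = (mid - left) / jump if jump != 0 else 0.0
        # Upwind value plus a limited second-order correction.
        face_value = mid + 0.5 * van_leer_limiter(r) * (1.0 - courant) * jump
        flux.append(velocity * face_value)
    # Conservative update: each cell changes by the difference of its face fluxes.
    return [c[i] - dt / dx * (flux[i] - flux[i - 1]) for i in range(n)]


square_wave = [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
advected = advect_step(square_wave, velocity=1.0, dx=1.0, dt=0.5)
```

Because the update is written in flux form, the sum of the cell values is conserved, and the limiter switches the second-order correction off near the jumps of the square wave so that no new extrema appear.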

TODO: add a notebook explaining our numerical models in more depth.

In the long term, we're interested in expanding JAX-CFD to implement methods relevant for related research, e.g.,

  • Colocated grids
  • Alternative boundary conditions (e.g., non-periodic boundaries and immersed boundary methods)
  • Higher order time-stepping
  • Geometric multigrid
  • Steady state simulation (e.g., RANS)
  • Distributed simulations across multiple TPUs/GPUs

We would welcome collaboration on any of these! Please reach out (either on GitHub or by email) to coordinate before starting significant work.

Projects using JAX-CFD

Other awesome projects

Other differentiable CFD codes compatible with deep learning:

JAX for science:

Did we miss something? Please let us know!

Citation

If you use our finite volume method (FVM) or ML models, please cite:

@article{Kochkov2021-ML-CFD,
  author = {Kochkov, Dmitrii and Smith, Jamie A. and Alieva, Ayya and Wang, Qing and Brenner, Michael P. and Hoyer, Stephan},
  title = {Machine learning{\textendash}accelerated computational fluid dynamics},
  volume = {118},
  number = {21},
  elocation-id = {e2101784118},
  year = {2021},
  doi = {10.1073/pnas.2101784118},
  publisher = {National Academy of Sciences},
  issn = {0027-8424},
  URL = {https://www.pnas.org/content/118/21/e2101784118},
  eprint = {https://www.pnas.org/content/118/21/e2101784118.full.pdf},
  journal = {Proceedings of the National Academy of Sciences}
}

If you use our spectral code, please cite:

@article{Dresdner2022-Spectral-ML,
  doi = {10.48550/ARXIV.2207.00556},
  url = {https://arxiv.org/abs/2207.00556},
  author = {Dresdner, Gideon and Kochkov, Dmitrii and Norgaard, Peter and Zepeda-Núñez, Leonardo and Smith, Jamie A. and Brenner, Michael P. and Hoyer, Stephan},
  title = {Learning to correct spectral methods for simulating turbulent flows},
  publisher = {arXiv},
  year = {2022},
  copyright = {arXiv.org perpetual, non-exclusive license}
}

Local development

To locally install for development:

git clone https://github.com/google/jax-cfd.git
cd jax-cfd
pip install jaxlib
pip install -e ".[complete]"

Then to manually run the test suite:

pytest -n auto jax_cfd --dist=loadfile --ignore=jax_cfd/base/validation_test.py

jax-cfd's People

Contributors

agrue, akasharidas, dionhaefner, froystig, gideonite, hawkinsp, jamieas, kochkov92, langmore, lenamartens, lukegb, pnorgaard, rchen152, shoyer, superbobry, yashk2810

jax-cfd's Issues

Block-Structured NS LES/WMLES code

Hello!

First, thanks for sharing this amazing project. Second, sorry if this isn’t the right place to ask this question.

I would like to know if there is an ongoing effort to develop a block-structured NS-FV LES/WMLES code able to simulate academically relevant flows, e.g., turbulent channel flow, NASA Hump, etc.

If yes, could someone please point me in the right direction? If no, do you think the current code could be used as a starting point?

Best,
Eduardo

Tests fail upon fresh install.

After following the instructions to set up a development environment, i.e.:

git clone https://github.com/google/jax-cfd; cd jax-cfd
python -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install jaxlib
pip install -e ".[complete]"

and running the test suite, as per:

python -m pytest -n auto jax_cfd --dist=loadfile --ignore=jax_cfd/base/validation_test.py

the test-suite fails with:

11 failed, 447 passed, 242 warnings, 4 errors 

I've attached a full output of the tests here.

Quick note on running tests: the instructions in the README.md don't run if you set the environment up like I did above. The modification with python -m prefixing the command is required.

Are there any plans to fix these to ensure that all tests pass?
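Regarding the note about prefixing the command with python -m: one possible explanation, not confirmed in this thread, is that a bare pytest on PATH belongs to a different Python installation than the active virtual environment. The sketch below checks this with the standard library only; the helper name is hypothetical.

```python
import os
import shutil
import sys


def tool_in_active_environment(tool):
    """Return True if `tool` found on PATH sits next to the running interpreter."""
    tool_path = shutil.which(tool)
    if tool_path is None:
        # The tool is not on PATH at all.
        return False
    interpreter_dir = os.path.dirname(os.path.abspath(sys.executable or ""))
    return os.path.dirname(os.path.abspath(tool_path)) == interpreter_dir


if __name__ == "__main__":
    print("interpreter:", sys.executable)
    print("pytest on PATH:", shutil.which("pytest"))
    print("same environment:", tool_in_active_environment("pytest"))
```

If the last line prints False, running python -m pytest ensures that the test runner uses the interpreter (and installed packages) of the active environment.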

Can't run spectral_forced_turbulence due to jnp.linalg.norm error

I have copied the code from spectral_forced_turbulence.ipynb, but it gives the following error:

---------------------------------------------------------------------------

UnfilteredStackTrace                      Traceback (most recent call last)

<timed exec> in <module>

/usr/local/lib/python3.8/dist-packages/jax_cfd/base/initial_conditions.py in filtered_velocity_field(rng_key, grid, maximum_velocity, peak_wavenumber, iterations)
    110   # specified maximum velocity.
--> 111   return funcutils.repeated(project_and_normalize, iterations)(velocity)
    112 

28 frames

UnfilteredStackTrace: TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------


The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)

<timed exec> in <module>

/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/util.py in _check_arraylike(fun_name, *args)
    343                     if not _arraylike(arg))
    344     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 345     raise TypeError(msg.format(fun_name, type(arg), pos))
    346 
    347 

TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.

You can reproduce this with this Google Colab. It is also confusing that the Google Colab provided in the README does not install jax-cfd. What is the intended way of running them?

Edit: it seems that the jax-cfd version on PyPI is outdated. Downloading the source from the repository works.

Modelling turbulent channel flow

Using the structure of the 2D channel flow notebook, I tried making some changes, starting with the Re of the flow, but the flow remains laminar. Is there a setting I am missing that prevents the transition to turbulence?

Thank you

[Installation] Failed to import spectral module

Hi,

Thanks for this amazing library!
I installed the module using the command given in the README, pip install jax-cfd[complete], to have access to all submodules.

When I try to import the spectral module via import jax_cfd.spectral as spectral I get the following error message:

import jax_cfd.spectral as spectral
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ModuleNotFoundError: No module named 'jax_cfd.spectral'

Some more information about my system and the environment:

  • Python version: Python 3.9.7 (default, Sep 16 2021, 13:09:58)
  • JAX version: 0.2.21
  • JAX CFD version: 0.1.0
  • OS: Linux 5.10.16.3-microsoft-standard-WSL2

Is there any way that I can resolve this issue?

Thanks in advance!
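A quick way to gather the information listed above is to query the installed distribution version and probe which submodules are importable. This is a minimal sketch using only the standard library; the helper name describe_install is hypothetical, and the module names are the ones listed in the README.

```python
import importlib.util
from importlib import metadata


def describe_install(dist_name, submodules):
    """Return (installed version or None, {module name: importable?})."""
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        version = None
    available = {}
    for name in submodules:
        try:
            # find_spec imports the parent package of a dotted name, so a
            # missing or broken parent raises instead of returning None.
            available[name] = importlib.util.find_spec(name) is not None
        except ImportError:
            available[name] = False
    return version, available


if __name__ == "__main__":
    modules = ["jax_cfd.base", "jax_cfd.spectral", "jax_cfd.ml", "jax_cfd.data"]
    print(describe_install("jax-cfd", modules))
```

If the reported version lacks jax_cfd.spectral, the installed release may simply predate that submodule; another issue in this list reports that the PyPI release was outdated and that installing from the repository source worked.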

AttributeError: AxesImage.set() got an unexpected keyword argument 'msg'

# Plot the x-velocity
ds.pipe(lambda ds: ds.u).plot.imshow(
    'x', 'y', col='time', cmap=seaborn.cm.rocket, robust=True, col_wrap=4, aspect=2);

Solution: pass x and y as keyword arguments.

# Plot the x-velocity
ds.pipe(lambda ds: ds.u).plot.imshow(x='x', y='y', col='time', cmap=seaborn.cm.rocket, robust=True, col_wrap=4, aspect=2);

[Error] Channel flow demo

I'm trying to run channel_flow_demo.ipynb, but I get the following error when I call the step_fn :

(massive traceback truncated)

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/advection.py in <genexpr>(.0)
107 for u, target_offset in zip(v, target_offsets))
108 aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt)
--> 109 for target_offset in target_offsets)
110 return _advect_aligned(aligned_c, aligned_v)
111

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/interpolation.py in tvd_interpolation(c, offset, v, dt)
282 c_left = c.shift(-1, axis)
283 c_right = c.shift(1, axis)
--> 284 c_next_right = c.shift(2, axis)
285 # Velocities of different sign are evaluated with limiters at different
286 # points. See equations (4.34) -- (4.39) from the reference above.

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/grids.py in shift(self, offset, axis)
269 GridArray has offset u.offset + offset.
270 """
--> 271 return self.bc.shift(self.array, offset, axis)
272
273 def _interior_grid(self) -> Grid:

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/boundaries.py in shift(self, u, offset, axis)
77 u.offset + offset.
78 """
---> 79 padded = self._pad(u, offset, axis)
80 trimmed = self._trim(padded, -offset, axis)
81 return trimmed

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/boundaries.py in _pad(self, u, width, axis)
116 if bc_type != BCType.PERIODIC and abs(width) > 1:
117 raise ValueError(
--> 118 'Padding past 1 ghost cell is not defined in nonperiodic case.')
119
120 if bc_type == BCType.PERIODIC:

ValueError: Padding past 1 ghost cell is not defined in nonperiodic case.

I installed JAX-CFD on a fresh Colab notebook using pip install git+https://github.com/google/jax-cfd.git

Please help me fix this. Let me know if I'm using the wrong version (since the repo is constantly being developed).

Thanks!
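The traceback shows tvd_interpolation calling c.shift(2, axis), while the boundary code rejects padding beyond one ghost cell for non-periodic boundaries. The following pure-Python sketch illustrates that distinction only; the function names are hypothetical and this is not the jax_cfd.base.boundaries implementation.

```python
def shift_periodic(values, offset):
    """Shift so that result[i] == values[(i + offset) % n]; any offset works."""
    n = len(values)
    return [values[(i + offset) % n] for i in range(n)]


def shift_with_ghost_cell(values, offset, ghost_value):
    """Shift using a single ghost cell per side (non-periodic boundary).

    Only |offset| <= 1 is defined, because the boundary condition supplies
    just one value beyond each end of the domain.
    """
    if abs(offset) > 1:
        raise ValueError("only one ghost cell is available on each side")
    padded = [ghost_value] + list(values) + [ghost_value]
    start = 1 + offset
    return padded[start:start + len(values)]


cells = [1, 2, 3, 4]
two_right_periodic = shift_periodic(cells, 2)
one_right_walled = shift_with_ghost_cell(cells, 1, ghost_value=0)
```

A stencil that reaches two cells away, as the limited interpolation in the traceback does, is therefore well defined with periodic wrapping but needs extra boundary information in the non-periodic case.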

Boundary Conditions

Hi,

I tried to define some boundary conditions. Until recently, everything went perfectly fine with

import jax_cfd.base.grids as grids
bc = grids.BoundaryConditions((grids.PERIODIC, grids.PERIODIC))
Traceback (most recent call last):
  File "<string>", line 1, in <module>
AttributeError: module 'jax_cfd.base.grids' has no attribute 'BoundaryConditions'

I saw that you moved the boundary conditions to a new module called boundaries.py. I tried then to access them accordingly from the new module via:

import jax_cfd.base as cfd
bc = cfd.boundaries.periodic_boundary_conditions(ndim=2)
Traceback (most recent call last):
  File "<string>", line 1, in <module>
AttributeError: module 'jax_cfd.base' has no attribute 'boundaries'

But now, neither of these work anymore (see the appended error messages) and I wonder how I can define boundary conditions.
I tried installing JAX-CFD according to the documentation and according to #90, but neither of them gave me a solution to my problem.

My python version is 3.8.

Python 3.8.12 (default, Oct 12 2021, 13:49:34) 
[GCC 7.5.0] :: Anaconda, Inc. on linux

Input data and training process

I have run your program through your ipynb notebooks, but I can't find where the input data comes from. Are the nc and kpl files downloaded through gsutil? I would like to try some other simple examples with your method. In addition, is the machine learning method located in base/ml? I see that ml_model_inference_demo.ipynb uses ml. I hope to get your help, thank you very much.

Question about the random initial velocity in notebook spectral_forced_turbulence.ipynb

Hello,

This is a very nice JAX implementation for CFDs! I have a few questions about your latest released notebook spectral_forced_turbulence.ipynb.

  1. What is the exact NS equation that the notebook aims to solve? Would it be possible to provide detailed expressions for the NS equations and boundary conditions, or some references? I think it would be super helpful to people like me who are more familiar with machine learning and know less about CFD.

  2. If I understand correctly, the code aims to solve the NS equation (vorticity-velocity form) using the pseudo-spectral method, so we need to enforce periodic boundary conditions for the velocity (or vorticity?). However, when I try to run the following piece of code (which should give me a random velocity field and visualize it at the boundary)

v0 = cfd.initial_conditions.filtered_velocity_field(jax.random.PRNGKey(0), grid, max_velocity, 4)
u = v0[0].data
v = v0[1].data

plt.plot(u[:, 0])
plt.plot(u[:, -1])
plt.show()

plt.plot(u[0, :])
plt.plot(u[-1, :])
plt.show()

I got the figures below. It seems that the periodic boundary condition is not fully imposed. Could you please take a look or correct me if I am wrong?

[Figures: bc, bc2]

Once again, this is really a project! Looking forward to your reply.
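One possible reading of those plots, offered as an assumption rather than a confirmed answer from this thread: if a periodic grid stores N samples per period without repeating the endpoint, the first and last samples are neighbours across the boundary, not the same point, so u[:, 0] and u[:, -1] are not expected to coincide. A minimal 1D sketch with a plain sine wave (names are hypothetical):

```python
import math


def sample_periodic(n, length=2 * math.pi):
    """Sample sin(x) at x_i = i * dx for i = 0..n-1, endpoint not repeated."""
    dx = length / n
    return [math.sin(i * dx) for i in range(n)]


n = 8
values = sample_periodic(n)

# First and last stored samples sit one grid spacing apart across the boundary.
boundary_gap = abs(values[0] - values[-1])

# Periodicity means the sample one step past the end equals the first one.
wrapped_sample = math.sin(n * (2 * math.pi / n))
wrap_error = abs(wrapped_sample - values[0])
```

Here boundary_gap is clearly nonzero even though the function is exactly periodic, while wrap_error is at round-off level.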

3d modeling?

I haven't read through the paper yet, but I figured I'd inquire: is this research limited in scope to just 2D, or are tensor-based calculations appropriate and practical for 3D modeling as well?

NameError: name 'Grid' is not defined

Hi! I have a problem and sincerely hope to get your help!
Problem description: I wrote a Python program myself to reuse some functions of the project, but I ran into a problem at the first step.
In grids.py, line 66, there is grid: Grid, but I didn't see the definition or import of Grid. I then implemented the velocity_trajectory_to_xarray() function from xarray_utils.py (data submodule) in my own Python file, and it raised the error "NameError: name 'Grid' is not defined". What should I do about this?

trajectory function use target_trajectory as input in loss_and_gradient

Hi,

I'm trying to understand the logic behind the various functions in ml.train_utils.

In the definition of loss_and_gradient, the returned _loss function uses the target trajectory as part of the input of the trajectory_fn, and I cannot make sense of it. The description of the function states that the trajectory_fn should accept params and initial_velocity, which makes sense to me, but I don't understand why we would want to use the target_trajectory as the initial_velocity (which presumably doesn't have the same shape as initial_velocity, since one is a trajectory and the other is the velocity at a single time step).

I'm pretty much stuck here because I now don't know how to use this in the context of defining a train_step function etc.

Best regards,

Gwen

JaxNumPy functions for GridArrays/GridVariables

Hi! This is a great project, and I'm a big fan of both the machine learning applications here and also some of the smaller, helpful structures, in particular base.grids.

Currently, it is possible to add two GridArrays, but it is not possible to add two GridVariables. So this works fine:

import jax_cfd.base.grids as gd
import jax.numpy as jnp

grid = gd.Grid([4,], domain = [(0, 1),])

array_of_values = jnp.array([2.0, 2.0, 3.0, 4.0])

centered_array = grid.center(array_of_values)

print(centered_array + centered_array)

But this throws an exception:

bc = gd.BoundaryConditions((gd.PERIODIC,))

centered_variable = gd.GridVariable(centered_array, bc)

print(centered_variable + centered_variable)

I'm happy to have a go at implementing this myself, if someone isn't already working on it.
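As a rough illustration of what such an addition could look like, here is a pure-Python sketch. The Variable class is a hypothetical stand-in, not the GridVariable class from jax_cfd.base.grids, and the rule that both operands must share the same boundary condition is an assumption about the desired semantics.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Variable:
    """Hypothetical stand-in for a grid variable: values plus a boundary tag."""
    values: Tuple[float, ...]
    bc: str

    def __add__(self, other):
        if not isinstance(other, Variable):
            # Let Python try the reflected operation or raise TypeError.
            return NotImplemented
        if self.bc != other.bc:
            raise ValueError("cannot add variables with different boundary conditions")
        summed = tuple(a + b for a, b in zip(self.values, other.values))
        return Variable(summed, self.bc)


centered = Variable((2.0, 2.0, 3.0, 4.0), bc="periodic")
doubled = centered + centered
```

The main design question is what the result's boundary condition should be; requiring identical boundary conditions is the simplest choice, since a sum of two periodic fields is again periodic.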

Also, am I correct in thinking that the way to use a JaxNumPy function on a GridArray is to call it via NumPy? For example, this throws an exception:

print(jnp.abs(centered_array))

But this works:

import numpy as np
print(np.abs(centered_array))

I assume it's implemented this way because NumPy has an automatic mix-in that we can employ to funnel things to the appropriate JaxNumPy function, but JaxNumPy does not.

Request: Example notebook showing training of a hybrid model

The readme says that an example notebook showing how to train a simple hybrid ML + CFD model is being prepared. Is there an ETA on this?
I'm trying to get it to work using JAX-CFD and Flax but I'm running into problems. An example notebook would help me a lot!

Colab

Any plan to move to colab?

i need jaxlib source

Hi, I'm so excited to see your work, it's so exciting! I want to implement your project on my own platform system, but since there is no jaxlib source code on the python official website, the compiled jaxlib is not suitable for my working platform. I would be very grateful if you can provide the source code of jaxlib
