google / jax-cfd
Computational Fluid Dynamics in JAX
License: Apache License 2.0
I have copied the code from spectral_forced_turbulence.ipynb, but it gives the following error:
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<timed exec> in <module>
/usr/local/lib/python3.8/dist-packages/jax_cfd/base/initial_conditions.py in filtered_velocity_field(rng_key, grid, maximum_velocity, peak_wavenumber, iterations)
110 # specified maximum velocity.
--> 111 return funcutils.repeated(project_and_normalize, iterations)(velocity)
112
28 frames
UnfilteredStackTrace: TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<timed exec> in <module>
/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/util.py in _check_arraylike(fun_name, *args)
343 if not _arraylike(arg))
344 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 345 raise TypeError(msg.format(fun_name, type(arg), pos))
346
347
TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
You can reproduce this with this Google Colab. It is also confusing that the Google Colab provided in the README does not install jax-cfd. What is the intended way of running them?
Edit: it seems that the jax-cfd version on PyPi is outdated. Downloading the source from the repository works.
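The traceback suggests that the outdated PyPI release passes a plain Python list to jnp.linalg.norm, which current JAX versions reject. A minimal illustration of that failure and of the usual workaround (stack the components into one array first), using made-up velocity components rather than the library's actual code:

```python
import jax.numpy as jnp

# Hypothetical velocity components on a 4x4 grid (not taken from jax-cfd).
u = jnp.full((4, 4), 3.0)
v = jnp.full((4, 4), 4.0)

try:
    jnp.linalg.norm([u, v], axis=0)  # a Python list is rejected by recent JAX
except TypeError as err:
    print("TypeError:", err)

# Stacking gives one (2, 4, 4) array; the norm over axis 0 is the pointwise
# speed sqrt(u**2 + v**2).
speed = jnp.linalg.norm(jnp.stack([u, v]), axis=0)
print(float(speed.max()))  # 5.0
```

Installing from the repository, as noted in the edit above, avoids having to patch anything locally.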
Hello,
I was trying to install jax-cfd locally using conda. Below is what I did to set up the environment; when I then try to import the package in a script, I get the error shown at the end of the log. Is there anything in particular that I am missing about how to build the library?
(base) user@host:~/$ conda create -y -n jax-cfd-trial python=3.10
Channels:
- defaults
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done
## Package Plan ##
environment location: /home/user/miniconda3/envs/jax-cfd-trial
added / updated specs:
- python=3.10
The following NEW packages will be INSTALLED:
_libgcc_mutex pkgs/main/linux-64::_libgcc_mutex-0.1-main
_openmp_mutex pkgs/main/linux-64::_openmp_mutex-5.1-1_gnu
bzip2 pkgs/main/linux-64::bzip2-1.0.8-h5eee18b_5
ca-certificates pkgs/main/linux-64::ca-certificates-2024.3.11-h06a4308_0
ld_impl_linux-64 pkgs/main/linux-64::ld_impl_linux-64-2.38-h1181459_1
libffi pkgs/main/linux-64::libffi-3.4.4-h6a678d5_0
libgcc-ng pkgs/main/linux-64::libgcc-ng-11.2.0-h1234567_1
libgomp pkgs/main/linux-64::libgomp-11.2.0-h1234567_1
libstdcxx-ng pkgs/main/linux-64::libstdcxx-ng-11.2.0-h1234567_1
libuuid pkgs/main/linux-64::libuuid-1.41.5-h5eee18b_0
ncurses pkgs/main/linux-64::ncurses-6.4-h6a678d5_0
openssl pkgs/main/linux-64::openssl-3.0.13-h7f8727e_0
pip pkgs/main/linux-64::pip-23.3.1-py310h06a4308_0
python pkgs/main/linux-64::python-3.10.14-h955ad1f_0
readline pkgs/main/linux-64::readline-8.2-h5eee18b_0
setuptools pkgs/main/linux-64::setuptools-68.2.2-py310h06a4308_0
sqlite pkgs/main/linux-64::sqlite-3.41.2-h5eee18b_0
tk pkgs/main/linux-64::tk-8.6.12-h1ccaba5_0
tzdata pkgs/main/noarch::tzdata-2024a-h04d1e81_0
wheel pkgs/main/linux-64::wheel-0.41.2-py310h06a4308_0
xz pkgs/main/linux-64::xz-5.4.6-h5eee18b_0
zlib pkgs/main/linux-64::zlib-1.2.13-h5eee18b_0
Downloading and Extracting Packages:
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
#
# To activate this environment, use
#
# $ conda activate jax-cfd-trial
#
# To deactivate an active environment, use
#
# $ conda deactivate
(base) user@host:~/$ conda activate jax-cfd-trial
(jax-cfd-trial) user@host:~/$ pip install jaxlib jax jax-cfd
Collecting jaxlib
Using cached jaxlib-0.4.26-cp310-cp310-manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting jax
Using cached jax-0.4.26-py3-none-any.whl.metadata (23 kB)
Collecting jax-cfd
Using cached jax_cfd-0.2.0-py3-none-any.whl.metadata (1.4 kB)
Collecting scipy>=1.9 (from jaxlib)
Using cached scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting numpy>=1.22 (from jaxlib)
Using cached numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting ml-dtypes>=0.2.0 (from jaxlib)
Using cached ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting opt-einsum (from jax)
Using cached opt_einsum-3.3.0-py3-none-any.whl.metadata (6.5 kB)
Collecting tree-math (from jax-cfd)
Using cached tree_math-0.2.1-py3-none-any.whl.metadata (477 bytes)
Using cached jaxlib-0.4.26-cp310-cp310-manylinux2014_x86_64.whl (78.8 MB)
Using cached jax-0.4.26-py3-none-any.whl (1.9 MB)
Using cached jax_cfd-0.2.0-py3-none-any.whl (197 kB)
Using cached ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
Using cached numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
Using cached scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.6 MB)
Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Using cached tree_math-0.2.1-py3-none-any.whl (21 kB)
Installing collected packages: numpy, scipy, opt-einsum, ml-dtypes, jaxlib, jax, tree-math, jax-cfd
Successfully installed jax-0.4.26 jax-cfd-0.2.0 jaxlib-0.4.26 ml-dtypes-0.4.0 numpy-1.26.4 opt-einsum-3.3.0 scipy-1.13.0 tree-math-0.2.1
(jax-cfd-trial) user@host:~/$ conda env export
name: jax-cfd-trial
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- bzip2=1.0.8=h5eee18b_5
- ca-certificates=2024.3.11=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- ncurses=6.4=h6a678d5_0
- openssl=3.0.13=h7f8727e_0
- pip=23.3.1=py310h06a4308_0
- python=3.10.14=h955ad1f_0
- readline=8.2=h5eee18b_0
- setuptools=68.2.2=py310h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- tzdata=2024a=h04d1e81_0
- wheel=0.41.2=py310h06a4308_0
- xz=5.4.6=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- jax==0.4.26
- jax-cfd==0.2.0
- jaxlib==0.4.26
- ml-dtypes==0.4.0
- numpy==1.26.4
- opt-einsum==3.3.0
- scipy==1.13.0
- tree-math==0.2.1
prefix: /home/user/miniconda3/envs/jax-cfd-trial
(jax-cfd-trial) user@host:~/$ python -c "import jax_cfd"
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/__init__.py", line 19, in <module>
import jax_cfd.base
File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/__init__.py", line 17, in <module>
import jax_cfd.base.advection
File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/advection.py", line 20, in <module>
from jax_cfd.base import boundaries
File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/boundaries.py", line 20, in <module>
from jax_cfd.base import grids
File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/grids.py", line 25, in <module>
from jax_cfd.base import array_utils
File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/array_utils.py", line 30, in <module>
Array = Union[np.ndarray, jnp.DeviceArray]
File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax/_src/deprecations.py", line 54, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'
It seems the package relies on an outdated JAX API. Should I just downgrade the package to an older version?
Thanks for the help!
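For reference, a minimal sketch of how the failing type alias in array_utils.py could be written against the current JAX API, assuming JAX 0.4.1 or newer, where jax.Array is the public array type. This is only a local patch idea, not the upstream fix:

```python
from typing import Union

import jax
import jax.numpy as jnp
import numpy as np

# Old alias (raises AttributeError on recent JAX):
#   Array = Union[np.ndarray, jnp.DeviceArray]
Array = Union[np.ndarray, jax.Array]


def is_array(x) -> bool:
    # jax.Array can be used directly in isinstance checks.
    return isinstance(x, (np.ndarray, jax.Array))


print(is_array(jnp.zeros(3)), is_array(np.zeros(3)), is_array([0.0]))  # True True False
```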
Using the 2D channel flow notebook as a template, I have tried making some changes, starting with the Reynolds number (Re) of the flow, but the flow remains laminar. Is there a setting preventing the transition to turbulence that I am missing?
Thank you
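Two general points that may be relevant, stated as assumptions rather than a diagnosis of the notebook: Re is usually not a separate input but enters through the viscosity (Re = U L / ν), and a smooth, symmetric laminar initial state can stay laminar unless it is perturbed. A small NumPy sketch of both; viscosity_for_reynolds and perturbed_channel_velocity are hypothetical helpers, and the layout (x along axis 0, y along axis 1, walls at y = 0 and y = H) is assumed:

```python
import numpy as np


def viscosity_for_reynolds(velocity_scale, length_scale, reynolds):
    # Re = U * L / nu, so nu = U * L / Re (kinematic viscosity).
    return velocity_scale * length_scale / reynolds


def perturbed_channel_velocity(nx, ny, height, u_max, amplitude, seed=0):
    # Laminar Poiseuille profile u(y) = 4 u_max y (H - y) / H^2 at cell centres.
    y = (np.arange(ny) + 0.5) * height / ny
    profile = 4.0 * u_max * y * (height - y) / height**2
    u = np.tile(profile, (nx, 1))  # shape (nx, ny)
    rng = np.random.default_rng(seed)
    # Random noise scaled by the profile shape, so it fades out towards the walls.
    noise = amplitude * rng.standard_normal((nx, ny)) * (profile / u_max)
    return u + noise


nu = viscosity_for_reynolds(velocity_scale=1.0, length_scale=2.0, reynolds=1000.0)
u0 = perturbed_channel_velocity(nx=64, ny=32, height=2.0, u_max=1.0, amplitude=0.05)
```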
Hi,
I'm trying to understand the logic behind the various functions in ml.train_utils.
In the definition of loss_and_gradient, the returned _loss function uses the target trajectory as part of the input to trajectory_fn. I cannot make sense of this. The function's description states that trajectory_fn should accept params and initial_velocity, which makes sense to me, but then I don't understand why we would want to use target_trajectory as initial_velocity (presumably they don't even have the same shape, since one is a trajectory and the other is the velocity at a single time step).
I'm pretty much stuck here, because I now don't know how to use this when defining a train_step function, etc.
Best regards,
Gwen
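One plausible reading, which is an assumption and not confirmed against train_utils: the loss slices the first frame off the target trajectory and uses it as the initial condition, then compares the unrolled prediction with the remaining frames. A self-contained JAX sketch of that pattern, with a hypothetical one-parameter "solver" in place of a real model:

```python
import jax
import jax.numpy as jnp


def make_trajectory_fn(num_steps):
    # Hypothetical stand-in for a learned solver: each step multiplies the state by params.
    def trajectory_fn(params, initial_state):
        def step(state, _):
            next_state = params * state
            return next_state, next_state
        _, traj = jax.lax.scan(step, initial_state, xs=None, length=num_steps)
        return traj
    return trajectory_fn


def trajectory_loss(params, target_trajectory, trajectory_fn):
    # The first frame of the target trajectory serves as the initial condition...
    initial_state = target_trajectory[0]
    predicted = trajectory_fn(params, initial_state)
    # ...and the unrolled prediction is compared with the remaining frames.
    return jnp.mean((predicted - target_trajectory[1:]) ** 2)


target = (0.9 ** jnp.arange(5))[:, None] * jnp.ones((5, 3))  # frames 0..4
trajectory_fn = make_trajectory_fn(num_steps=4)
loss_and_grad = jax.value_and_grad(lambda p: trajectory_loss(p, target, trajectory_fn))
loss, grad = loss_and_grad(0.5)
```

A train_step built on this would then just apply an optimizer update using grad.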
Hi, I'm very excited to see your work! I want to run your project on my own platform, but since the jaxlib source code is not available on the official Python package site, the prebuilt jaxlib is not suitable for my platform. I would be very grateful if you could provide the jaxlib source code.
I have run your program through your ipynb notebooks, but I can't find where the input data comes from. Are the .nc and .pkl files downloaded with gsutil? I would like to try some other simple examples with your method. In addition, is the machine learning part located in base/ml? I see that ml_model_inference_demo.ipynb uses ml. I hope you can help; thank you very much.
ds.pipe(lambda ds: ds.u).plot.imshow(
'x', 'y', col='time',cmap=seaborn.cm.rocket, robust=True, col_wrap=4, aspect=2);
Solution:
ds.pipe(lambda ds: ds.u).plot.imshow(x='x', y='y', col='time', cmap=seaborn.cm.rocket, robust=True, col_wrap=4, aspect=2);
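Recent xarray versions expect x and y as keyword arguments to their plotting methods, which is presumably why the keyword form above works. A self-contained sketch with synthetic data standing in for the notebook's ds.u; the field values, coordinates and the "viridis" colormap (in place of seaborn's rocket) are assumptions:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import numpy as np
import xarray

# Synthetic field with dims (time, x, y), standing in for the notebook's ds.u.
u = xarray.DataArray(
    np.random.default_rng(0).standard_normal((4, 8, 16)),
    dims=("time", "x", "y"),
    coords={"time": np.arange(4), "x": np.linspace(0, 1, 8), "y": np.linspace(0, 2, 16)},
    name="u",
)
ds = xarray.Dataset({"u": u})

# x and y passed by keyword; one panel per time step, two panels per row.
g = ds.pipe(lambda ds: ds.u).plot.imshow(
    x="x", y="y", col="time", cmap="viridis", robust=True, col_wrap=2, aspect=2)
```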
I haven't read through the paper yet, but I figured I'd ask: is this research limited in scope to 2D, or are tensor-based calculations appropriate and practical for 3D modeling as well?
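As a small illustration that array-based stencil code carries over from 2D to 3D with only an extra axis, here is a periodic second-order central-difference derivative in plain jax.numpy. This is a generic sketch, not jax-cfd's own finite-difference code:

```python
import jax.numpy as jnp


def periodic_central_diff(f, dx, axis):
    # Second-order central difference with periodic wrap-around along `axis`.
    return (jnp.roll(f, -1, axis=axis) - jnp.roll(f, 1, axis=axis)) / (2 * dx)


n = 32
length = 2 * jnp.pi
dx = length / n
x = jnp.arange(n) * dx
X, Y, Z = jnp.meshgrid(x, x, x, indexing="ij")  # a 32^3 periodic grid

f = jnp.sin(X) * jnp.cos(Y) * jnp.sin(Z)
dfdx = periodic_central_diff(f, dx, axis=0)
exact = jnp.cos(X) * jnp.cos(Y) * jnp.sin(Z)
err = jnp.max(jnp.abs(dfdx - exact))  # roughly dx**2 / 6
```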
After following the instructions to set up a development environment, i.e.:
git clone https://github.com/google/jax-cfd; cd jax-cfd
python -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install jaxlib
pip install -e ".[complete]"
and running the test suite, as per:
python -m pytest -n auto jax_cfd --dist=loadfile --ignore=jax_cfd/base/validation_test.py
the test suite fails with:
11 failed, 447 passed, 242 warnings, 4 errors
I've attached a full output of the tests here.
Quick note on running tests: the command in README.md doesn't work if you set up the environment as I did above; it has to be prefixed with python -m.
Are there any plans to fix these to ensure that all tests pass?
jax.DeviceArray is no longer available in recent JAX releases.
Hi! I have a problem and sincerely hope you can help!
Problem description: I wrote my own Python program to implement some functions of the project, but I ran into a problem at the very first step.
In grids.py, line 66, there is the annotation grid: Grid, but I didn't see a definition or import of Grid. I then reimplemented the velocity_trajectory_to_xarray() function from xarray_utils.py (in the data submodule) in my own Python file, and it raised NameError: name 'Grid' is not defined. What should I do about this?
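The NameError comes from evaluating the annotation: in Python versions before 3.14, annotations in a def or class statement are evaluated when it runs, so a copied function whose signature mentions Grid needs that name in scope. Either import it (from jax_cfd.base.grids import Grid) or defer evaluation with a quoted annotation or from __future__ import annotations. A sketch of the quoted form, with a hypothetical describe_field function and a stand-in Grid class:

```python
import dataclasses


def describe_field(values, grid: "Grid") -> str:
    # "Grid" is quoted, so Python stores the string instead of looking up the
    # name Grid when this def runs; Grid can be defined (or imported) later.
    return f"{len(values)} values on a grid of shape {grid.shape}"


@dataclasses.dataclass(frozen=True)
class Grid:
    # Hypothetical stand-in for jax_cfd.base.grids.Grid, just enough for the demo.
    shape: tuple


print(describe_field([0.0, 1.0], Grid(shape=(2,))))  # 2 values on a grid of shape (2,)
```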
Hi,
Thanks for this amazing library!
I installed the module with the command given in the README, pip install jax-cfd[complete], to have access to all submodules.
When I try to import the spectral module via import jax_cfd.spectral as spectral, I get the following error message:
import jax_cfd.spectral as spectral
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ModuleNotFoundError: No module named 'jax_cfd.spectral'
Some more information about my system and the environment:
Python 3.9.7 (default, Sep 16 2021, 13:09:58)
0.2.21
0.1.0
Linux 5.10.16.3-microsoft-standard-WSL2
Is there any way that I can resolve this issue?
Thanks in advance!
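A quick way to check whether the installed distribution actually ships a given submodule (an older PyPI release may predate it), using only the standard library; has_module is a hypothetical helper name:

```python
import importlib.util


def has_module(dotted_name):
    # find_spec imports the parent packages first: a missing parent raises
    # ModuleNotFoundError, while a missing leaf module simply returns None.
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except ModuleNotFoundError:
        return False


print(has_module("json.decoder"))  # True
# In the environment above one would check, e.g., has_module("jax_cfd.spectral").
```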
When running layers_util_test.py with jax[CPU] versions later than 0.4.1, it raises KeyError: 'experimental_xmap_spmd_lowering'.
Hello!
First, thanks for sharing this amazing project. Second, sorry if this isn't the right place to ask this question.
I would like to know whether there is an ongoing effort to develop a block-structured NS finite-volume LES/WMLES code able to simulate academically relevant flows, e.g., turbulent channel flow, the NASA hump, etc.
If yes, could someone please point me to the right direction? If no, do you think the current code could be used as a starting point?
Best,
Eduardo
Hi! This is a great project, and I'm a big fan of both the machine learning applications here and also some of the smaller, helpful structures, in particular base.grids.
Currently, it is possible to add two GridArrays, but it is not possible to add two GridVariables. So this works fine:
import jax_cfd.base.grids as gd
import jax.numpy as jnp
grid = gd.Grid((4,), domain=((0, 1),))
array_of_values = jnp.array([2.0, 2.0, 3.0, 4.0])
centered_array = grid.center(array_of_values)
print(centered_array + centered_array)
But this throws an exception:
bc = gd.BoundaryConditions((gd.PERIODIC,))
centered_variable = gd.GridVariable(centered_array, bc)
print(centered_variable + centered_variable)
I'm happy to have a go at implementing this myself, if someone isn't already working on it.
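A minimal sketch of what such an __add__ could look like, using a hypothetical, simplified ToyGridVariable (data plus a boundary-condition label) rather than the real grids.GridVariable. The main design choice is to refuse to combine variables whose boundary conditions differ, since the boundary condition of the result would be ambiguous:

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyGridVariable:
    # Hypothetical stand-in: the real GridVariable wraps a GridArray and a
    # BoundaryConditions object; here we keep raw data and a string label.
    data: jnp.ndarray
    bc: str

    def __add__(self, other):
        if not isinstance(other, ToyGridVariable):
            return NotImplemented
        # Adding fields with different boundary conditions is ambiguous, so refuse.
        if self.bc != other.bc:
            raise ValueError(f"incompatible boundary conditions: {self.bc} vs {other.bc}")
        return ToyGridVariable(self.data + other.data, self.bc)


v = ToyGridVariable(jnp.array([2.0, 2.0, 3.0, 4.0]), "periodic")
print((v + v).data)  # [4. 4. 6. 8.]
```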
Also, am I correct in thinking that the way to apply a jax.numpy function to a GridArray is to call it via NumPy? For example, this throws an exception:
print(jnp.abs(centered_array))
But this works:
import numpy as np
print(np.abs(centered_array))
I assume it's implemented this way because NumPy provides a mix-in (np.lib.mixins.NDArrayOperatorsMixin) that we can use to funnel calls to the appropriate jax.numpy function, whereas jax.numpy has nothing equivalent.
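The mechanism behind this: NumPy ufuncs such as np.abs check their arguments for an __array_ufunc__ method and hand the call over to it, and NDArrayOperatorsMixin implements +, -, * and friends on top of the same hook. jax.numpy functions do not consult __array_ufunc__. A sketch with a hypothetical ToyGridArray that forwards each ufunc to the jax.numpy function of the same name (it ignores offset mismatches, which a real implementation would have to check):

```python
import jax.numpy as jnp
import numpy as np


class ToyGridArray(np.lib.mixins.NDArrayOperatorsMixin):
    # Hypothetical wrapper: a jax array plus an `offset` label.
    def __init__(self, data, offset):
        self.data = jnp.asarray(data)
        self.offset = offset

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != "__call__" or kwargs:
            return NotImplemented
        # e.g. np.absolute -> jnp.absolute, np.add -> jnp.add
        jnp_fn = getattr(jnp, ufunc.__name__, None)
        if jnp_fn is None:
            return NotImplemented
        args = [x.data if isinstance(x, ToyGridArray) else x for x in inputs]
        return ToyGridArray(jnp_fn(*args), self.offset)


a = ToyGridArray([-1.0, 2.0], offset=(0.5,))
print(np.abs(a).data, (a + a).data)  # [1. 2.] [-2.  4.]
```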
The README says that an example notebook showing how to train a simple hybrid ML + CFD model is being prepared. Is there an ETA for it?
I'm trying to get this working with JAX-CFD and Flax, but I'm running into problems. An example notebook would help me a lot!
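Not the planned notebook, but a minimal sketch of the hybrid idea in plain JAX (no Flax), under toy assumptions: a 1D periodic diffusion step with a deliberately wrong viscosity plays the role of the coarse CFD solver, and a single trainable coefficient plays the role of the neural network. All names are hypothetical; training differentiates through an unrolled rollout:

```python
import jax
import jax.numpy as jnp


def laplacian(u, dx):
    # Second-order periodic finite-difference Laplacian in 1D.
    return (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2


def physics_step(u, dt, dx, nu=0.05):
    # Coarse "solver": explicit diffusion step with a deliberately wrong viscosity.
    return u + dt * nu * laplacian(u, dx)


def hybrid_step(params, u, dt, dx):
    # Learned correction: one trainable coefficient scaling the Laplacian.
    return physics_step(u, dt, dx) + dt * params["c"] * laplacian(u, dx)


def rollout(params, u0, dt, dx, steps):
    def body(u, _):
        u = hybrid_step(params, u, dt, dx)
        return u, u
    _, traj = jax.lax.scan(body, u0, xs=None, length=steps)
    return traj


def loss_fn(params, u0, target, dt, dx):
    # Unrolled loss: compare the whole predicted trajectory with the reference.
    pred = rollout(params, u0, dt, dx, target.shape[0])
    return jnp.mean((pred - target) ** 2)


n, dt = 32, 0.05
dx = 2 * jnp.pi / n
u0 = jnp.sin(jnp.arange(n) * dx)
# Reference data with total viscosity 0.1 (0.05 physics + 0.05 correction).
target = rollout({"c": 0.05}, u0, dt, dx, steps=20)

params = {"c": 0.0}
grad_fn = jax.jit(jax.grad(loss_fn))
for _ in range(50):
    grads = grad_fn(params, u0, target, dt, dx)
    params = jax.tree_util.tree_map(lambda p, g: p - 2.0 * g, params, grads)
```

In a real hybrid model the coefficient would be replaced by a network (for example a Flax module) and the plain gradient step by an optimizer.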
Any plan to move to colab?
jax-cfd/jax_cfd/collocated/equations.py
Lines 64 to 85 in d215f13
For example, this function only returns the velocity variable; it is called like this:
step_fn = funcutils.repeated(
collocated.equations.semi_implicit_navier_stokes(
density=density, viscosity=viscosity, dt=dt, grid=grid),
steps=inner_steps)
rollout_fn = jax.jit(funcutils.trajectory(step_fn, outer_steps))
_, trajectory = jax.device_get(rollout_fn(v0))
where v0 is a GridVariableVector.
Let's say we want the time derivative to be an output as well. This is my ugly way of doing it, which basically concatenates a dummy tensor after the velocity:
@jax.named_call
def navier_stokes_step(v: GridVariableVector) -> GridVariableVector:
v, _ = (v[0], v[1]), (v[2], v[3])
...
dvdt = sum_fields(*accelerations)
...
v = pressure_projection(v, pressure_solve)
dvdt = tuple(grids.GridVariable(dudt, u.bc) for u, dudt in zip(v, dvdt))
return v+dvdt
return navier_stokes_step
then append a dummy GridVariable after v0 and call the same step function in the rollout:
vt0 = tuple(grids.GridVariable(grids.GridArray(jnp.zeros_like(u.data), u.offset, grid),
u.bc) for u in v0)
v0 += vt0
Is there any template in JAX for doing this?
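Instead of concatenating a dummy tensor onto the state, one generic JAX pattern is to carry only the state and emit the derivative as an auxiliary per-step output of jax.lax.scan, which stacks it along a new leading time axis. A toy sketch (forward Euler for dv/dt = -v on a tuple of components; the names are hypothetical and this is not jax-cfd's funcutils):

```python
import jax
import jax.numpy as jnp


def step_with_derivative(v, dt=0.1):
    # Toy forward-Euler step for dv/dt = -v, applied to every leaf of the
    # pytree v (e.g. a tuple of velocity components).
    dvdt = jax.tree_util.tree_map(lambda u: -u, v)
    v_next = jax.tree_util.tree_map(lambda u, d: u + dt * d, v, dvdt)
    return v_next, dvdt


def rollout_with_derivatives(v0, steps):
    def body(v, _):
        v_next, dvdt = step_with_derivative(v)
        # Carry only the state; (state, derivative) is emitted and stacked by scan.
        return v_next, (v_next, dvdt)
    _, (trajectory, derivatives) = jax.lax.scan(body, v0, xs=None, length=steps)
    return trajectory, derivatives


v0 = (jnp.ones(3), 2.0 * jnp.ones(3))
trajectory, derivatives = jax.jit(rollout_with_derivatives, static_argnums=1)(v0, 5)
```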
I'm trying to run channel_flow_demo.ipynb, but I get the following error when I call step_fn:
(massive traceback truncated)
/usr/local/lib/python3.7/dist-packages/jax_cfd/base/advection.py in <genexpr>(.0)
107 for u, target_offset in zip(v, target_offsets))
108 aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt)
--> 109 for target_offset in target_offsets)
110 return _advect_aligned(aligned_c, aligned_v)
111
/usr/local/lib/python3.7/dist-packages/jax_cfd/base/interpolation.py in tvd_interpolation(c, offset, v, dt)
282 c_left = c.shift(-1, axis)
283 c_right = c.shift(1, axis)
--> 284 c_next_right = c.shift(2, axis)
285 # Velocities of different sign are evaluated with limiters at different
286 # points. See equations (4.34) -- (4.39) from the reference above.
/usr/local/lib/python3.7/dist-packages/jax_cfd/base/grids.py in shift(self, offset, axis)
269 GridArray has offset u.offset + offset.
270 """
--> 271 return self.bc.shift(self.array, offset, axis)
272
273 def _interior_grid(self) -> Grid:
/usr/local/lib/python3.7/dist-packages/jax_cfd/base/boundaries.py in shift(self, u, offset, axis)
77 u.offset + offset.
78 """
---> 79 padded = self._pad(u, offset, axis)
80 trimmed = self._trim(padded, -offset, axis)
81 return trimmed
/usr/local/lib/python3.7/dist-packages/jax_cfd/base/boundaries.py in _pad(self, u, width, axis)
116 if bc_type != BCType.PERIODIC and abs(width) > 1:
117 raise ValueError(
--> 118 'Padding past 1 ghost cell is not defined in nonperiodic case.')
119
120 if bc_type == BCType.PERIODIC:
ValueError: Padding past 1 ghost cell is not defined in nonperiodic case.
I installed JAX-CFD on a fresh Colab notebook using pip install git+https://github.com/google/jax-cfd.git
Please help me fix this. Let me know if I'm using the wrong version (since the repo is constantly being developed).
Thanks!
Hello,
This is a very nice JAX implementation for CFDs! I have a few questions about your latest released notebook spectral_forced_turbulence.ipynb.
What exact NS equations does the notebook aim to solve? Would it be possible to give detailed expressions for the NS equations and boundary conditions, or to provide some references? I think that would be super helpful to people like me who are more familiar with machine learning and know less about CFD.
If I understand correctly, the code solves the NS equations (in vorticity-velocity form) with a pseudo-spectral method, so periodic boundary conditions must be enforced for the velocity (or the vorticity?). However, when I run the following piece of code, which should generate a random velocity field and plot it along the boundaries:
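import jax
import jax_cfd.base as cfd
import matplotlib.pyplot as plt
# grid and max_velocity are defined earlier in the notebook.
v0 = cfd.initial_conditions.filtered_velocity_field(jax.random.PRNGKey(0), grid, max_velocity, 4)
u = v0[0].data
v = v0[1].data
# Values at the first and last index along axis 1.
plt.plot(u[:, 0])
plt.plot(u[:, -1])
plt.show()
# Values at the first and last index along axis 0.
plt.plot(u[0, :])
plt.plot(u[-1, :])
plt.show()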
I got the following figures. It seems that the periodic boundary condition is not fully imposed. Could you please take a look, or correct me if I am wrong?
Once again, this is really a great project! Looking forward to your reply.
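One thing worth checking before concluding that periodicity is violated (a general property of periodic grids, not a statement about the notebook's internals): with n samples covering one period, the first and last samples are neighbours, not two copies of the same point, so u[:, 0] and u[:, -1] are not expected to coincide. A small NumPy illustration:

```python
import numpy as np

n = 16
L = 2 * np.pi
dx = L / n
x = np.arange(n) * dx      # samples at 0, dx, ..., L - dx; x = L itself is not stored
u = np.sin(3 * x)          # a smooth, exactly periodic field

# u[0] and u[-1] are neighbouring samples (x = 0 and x = L - dx), so they
# generally differ...
edge_gap = abs(u[-1] - u[0])
# ...but by no more than neighbouring samples elsewhere in the array.
interior_gaps = np.abs(np.diff(u))
# Wrapping around: the field evaluated at x = L equals u[0].
u_at_L = np.sin(3 * L)
print(edge_gap <= interior_gaps.max(), np.isclose(u_at_L, u[0]))  # True True
```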
Hi,
I tried to define some boundary conditions. Until recently, everything went perfectly fine with
import jax_cfd.base.grids as grids
bc = grids.BoundaryConditions((grids.PERIODIC, grids.PERIODIC))
Traceback (most recent call last):
File "<string>", line 1, in <module>
AttributeError: module 'jax_cfd.base.grids' has no attribute 'BoundaryConditions'
I saw that you moved the boundary conditions to a new module called boundaries.py, so I tried to access them from the new module via:
import jax_cfd.base as cfd
bc = cfd.boundaries.periodic_boundary_conditions(ndim=2)
Traceback (most recent call last):
File "<string>", line 1, in <module>
AttributeError: module 'jax_cfd.base' has no attribute 'boundaries'
But now neither of these works (see the error messages above), and I wonder how I can define boundary conditions.
I tried installing JAX-CFD according to the documentation and according to #90, but neither of them gave me a solution to my problem.
My python version is 3.8.
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
[GCC 7.5.0] :: Anaconda, Inc. on linux
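When attributes disappear between versions like this, it can help to confirm which copy of the package Python actually imports and which distribution version is installed. A standard-library sketch (Python 3.8+); describe_install is a hypothetical helper:

```python
import importlib
import importlib.metadata


def describe_install(module_name, dist_name=None):
    # Where is the module imported from, and which distribution version is installed?
    module = importlib.import_module(module_name)
    try:
        version = importlib.metadata.version(dist_name or module_name)
    except importlib.metadata.PackageNotFoundError:
        version = None  # e.g. standard-library modules, or a bare checkout on sys.path
    return {"file": getattr(module, "__file__", None), "version": version}


print(describe_install("json"))
# In the environment above: describe_install("jax_cfd", "jax-cfd")
```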