
jax-cfd's Issues

Can't run spectral_forced_turbulence due to jnp.linalg.norm error

I have copied the code from spectral_forced_turbulence.ipynb, but it gives the following error:

---------------------------------------------------------------------------

UnfilteredStackTrace                      Traceback (most recent call last)

<timed exec> in <module>

/usr/local/lib/python3.8/dist-packages/jax_cfd/base/initial_conditions.py in filtered_velocity_field(rng_key, grid, maximum_velocity, peak_wavenumber, iterations)
    110   # specified maximum velocity.
--> 111   return funcutils.repeated(project_and_normalize, iterations)(velocity)
    112 


UnfilteredStackTrace: TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------


The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)

<timed exec> in <module>

/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/util.py in _check_arraylike(fun_name, *args)
    343                     if not _arraylike(arg))
    344     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 345     raise TypeError(msg.format(fun_name, type(arg), pos))
    346 
    347 

TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.

You can reproduce this with this Google Colab. It is also confusing that the Google Colab provided in the README does not install jax-cfd. What is the intended way of running them?

Edit: it seems that the jax-cfd version on PyPi is outdated. Downloading the source from the repository works.
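The underlying JAX behavior can be reproduced in isolation: jax.numpy functions reject plain Python lists, so passing a list of velocity components to jnp.linalg.norm fails as in the traceback above. A minimal sketch, assuming the components are same-shaped arrays. Stacking them first is one possible workaround, not necessarily how the jax-cfd source version fixed it.

```python
import jax.numpy as jnp

# Two hypothetical velocity components, e.g. u and v on a 4x4 grid.
components = [jnp.ones((4, 4)), jnp.zeros((4, 4))]

# jax.numpy functions reject plain Python lists with a TypeError.
try:
  jnp.linalg.norm(components)
  list_rejected = False
except TypeError:
  list_rejected = True

# Stacking the components gives an array that jnp.linalg.norm accepts;
# with axis=None it computes the 2-norm over all entries.
total_norm = jnp.linalg.norm(jnp.stack(components))
```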

AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'

Hello,

I was trying to install jax-cfd locally using conda. Below is what I did to set up the environment; when I try to import the package in a script, I get the error in the title. Is there anything in particular I am missing about how to build the library?

(base) user@host:~/$ conda create -y -n jax-cfd-trial python=3.10 
Channels:
 - defaults
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /home/user/miniconda3/envs/jax-cfd-trial

  added / updated specs:
    - python=3.10


The following NEW packages will be INSTALLED:

  _libgcc_mutex      pkgs/main/linux-64::_libgcc_mutex-0.1-main 
  _openmp_mutex      pkgs/main/linux-64::_openmp_mutex-5.1-1_gnu 
  bzip2              pkgs/main/linux-64::bzip2-1.0.8-h5eee18b_5 
  ca-certificates    pkgs/main/linux-64::ca-certificates-2024.3.11-h06a4308_0 
  ld_impl_linux-64   pkgs/main/linux-64::ld_impl_linux-64-2.38-h1181459_1 
  libffi             pkgs/main/linux-64::libffi-3.4.4-h6a678d5_0 
  libgcc-ng          pkgs/main/linux-64::libgcc-ng-11.2.0-h1234567_1 
  libgomp            pkgs/main/linux-64::libgomp-11.2.0-h1234567_1 
  libstdcxx-ng       pkgs/main/linux-64::libstdcxx-ng-11.2.0-h1234567_1 
  libuuid            pkgs/main/linux-64::libuuid-1.41.5-h5eee18b_0 
  ncurses            pkgs/main/linux-64::ncurses-6.4-h6a678d5_0 
  openssl            pkgs/main/linux-64::openssl-3.0.13-h7f8727e_0 
  pip                pkgs/main/linux-64::pip-23.3.1-py310h06a4308_0 
  python             pkgs/main/linux-64::python-3.10.14-h955ad1f_0 
  readline           pkgs/main/linux-64::readline-8.2-h5eee18b_0 
  setuptools         pkgs/main/linux-64::setuptools-68.2.2-py310h06a4308_0 
  sqlite             pkgs/main/linux-64::sqlite-3.41.2-h5eee18b_0 
  tk                 pkgs/main/linux-64::tk-8.6.12-h1ccaba5_0 
  tzdata             pkgs/main/noarch::tzdata-2024a-h04d1e81_0 
  wheel              pkgs/main/linux-64::wheel-0.41.2-py310h06a4308_0 
  xz                 pkgs/main/linux-64::xz-5.4.6-h5eee18b_0 
  zlib               pkgs/main/linux-64::zlib-1.2.13-h5eee18b_0 



Downloading and Extracting Packages:

Preparing transaction: done
Verifying transaction: done
Executing transaction: done
#
# To activate this environment, use
#
#     $ conda activate jax-cfd-trial
#
# To deactivate an active environment, use
#
#     $ conda deactivate

(base) user@host:~/$ conda activate jax-cfd-trial

(jax-cfd-trial) user@host:~/$ pip install jaxlib jax jax-cfd
Collecting jaxlib
  Using cached jaxlib-0.4.26-cp310-cp310-manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting jax
  Using cached jax-0.4.26-py3-none-any.whl.metadata (23 kB)
Collecting jax-cfd
  Using cached jax_cfd-0.2.0-py3-none-any.whl.metadata (1.4 kB)
Collecting scipy>=1.9 (from jaxlib)
  Using cached scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting numpy>=1.22 (from jaxlib)
  Using cached numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting ml-dtypes>=0.2.0 (from jaxlib)
  Using cached ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting opt-einsum (from jax)
  Using cached opt_einsum-3.3.0-py3-none-any.whl.metadata (6.5 kB)
Collecting tree-math (from jax-cfd)
  Using cached tree_math-0.2.1-py3-none-any.whl.metadata (477 bytes)
Using cached jaxlib-0.4.26-cp310-cp310-manylinux2014_x86_64.whl (78.8 MB)
Using cached jax-0.4.26-py3-none-any.whl (1.9 MB)
Using cached jax_cfd-0.2.0-py3-none-any.whl (197 kB)
Using cached ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
Using cached numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
Using cached scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.6 MB)
Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Using cached tree_math-0.2.1-py3-none-any.whl (21 kB)
Installing collected packages: numpy, scipy, opt-einsum, ml-dtypes, jaxlib, jax, tree-math, jax-cfd
Successfully installed jax-0.4.26 jax-cfd-0.2.0 jaxlib-0.4.26 ml-dtypes-0.4.0 numpy-1.26.4 opt-einsum-3.3.0 scipy-1.13.0 tree-math-0.2.1

(jax-cfd-trial) user@host:~/$ conda env export
name: jax-cfd-trial
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h5eee18b_5
  - ca-certificates=2024.3.11=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.13=h7f8727e_0
  - pip=23.3.1=py310h06a4308_0
  - python=3.10.14=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.2.2=py310h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - tzdata=2024a=h04d1e81_0
  - wheel=0.41.2=py310h06a4308_0
  - xz=5.4.6=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
      - jax==0.4.26
      - jax-cfd==0.2.0
      - jaxlib==0.4.26
      - ml-dtypes==0.4.0
      - numpy==1.26.4
      - opt-einsum==3.3.0
      - scipy==1.13.0
      - tree-math==0.2.1
prefix: /home/user/miniconda3/envs/jax-cfd-trial


(jax-cfd-trial) user@host:~/$ python -c "import jax_cfd"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/__init__.py", line 19, in <module>
    import jax_cfd.base
  File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/__init__.py", line 17, in <module>
    import jax_cfd.base.advection
  File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/advection.py", line 20, in <module>
    from jax_cfd.base import boundaries
  File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/boundaries.py", line 20, in <module>
    from jax_cfd.base import grids
  File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/grids.py", line 25, in <module>
    from jax_cfd.base import array_utils
  File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax_cfd/base/array_utils.py", line 30, in <module>
    Array = Union[np.ndarray, jnp.DeviceArray]
  File "/home/user/miniconda3/envs/jax-cfd-trial/lib/python3.10/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'

It seems the package relies on a JAX API that no longer exists (jnp.DeviceArray). Should I just downgrade to an older version?

Thanks for the help!
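The failing line is `Array = Union[np.ndarray, jnp.DeviceArray]` in jax_cfd/base/array_utils.py. Two workarounds consistent with this thread are installing jax-cfd from the GitHub source instead of PyPI (which may already contain a fix), or pinning an older JAX release that still provides jnp.DeviceArray. For reference, a sketch of an equivalent alias using `jax.Array`, the array type in newer JAX releases; `is_array` is a hypothetical helper, not part of jax-cfd.

```python
from typing import Union

import jax
import numpy as np

# Equivalent of jax-cfd's Array alias without the removed jnp.DeviceArray:
# jax.Array is the array type used by newer JAX releases.
Array = Union[np.ndarray, jax.Array]


def is_array(x) -> bool:
  """Returns True for NumPy arrays and JAX arrays, False otherwise."""
  return isinstance(x, (np.ndarray, jax.Array))
```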

Modelling turbulent channel flow

Starting from the 2D channel flow notebook, I have tried making some changes, beginning with the Reynolds number (Re) of the flow, but the flow remains laminar. Is there a setting that prevents the transition to turbulence that I am missing?

Thank you
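A common reason a simulation stays laminar at higher Re is that the initial condition contains no disturbances. The parabolic (Poiseuille) profile is itself a steady solution, so with nothing to amplify, the flow can remain laminar. Below is a sketch of seeding the initial velocity with small random disturbances, plus a check that the viscosity actually matches the intended Re. Both helpers are hypothetical and work on plain NumPy arrays. They assume a unit-height channel with walls at y = 0 and y = 1, and they leave out conversion into jax-cfd GridVariables.

```python
import numpy as np


def viscosity_for_reynolds(velocity, length, reynolds):
  """Kinematic viscosity nu such that Re = velocity * length / nu."""
  return velocity * length / reynolds


def perturbed_channel_profile(nx, ny, u_max, amplitude, seed=0):
  """Hypothetical helper: parabolic channel profile plus random disturbances.

  Walls are at y = 0 and y = 1; values are sampled at cell centers.
  Returns (u, v), each with shape (nx, ny).
  """
  y = (np.arange(ny) + 0.5) / ny
  envelope = 4.0 * y * (1.0 - y)  # close to 1 mid-channel, -> 0 at the walls
  u_base = np.tile(u_max * envelope, (nx, 1))
  rng = np.random.default_rng(seed)
  # Disturbances are scaled by the envelope so they vanish toward the walls.
  du = amplitude * u_max * envelope * rng.standard_normal((nx, ny))
  dv = amplitude * u_max * envelope * rng.standard_normal((nx, ny))
  return u_base + du, dv
```

The random disturbances are not divergence-free, so they rely on a pressure projection (such as the one the solver applies every step) to become an incompressible field.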

trajectory_fn uses target_trajectory as input in loss_and_gradient

Hi,

I'm trying to understand the logic behind the various functions in ml.train_utils.

In the definition of loss_and_gradient, the returned _loss function uses the target trajectory as part of the input to trajectory_fn, and I cannot make sense of it. The function's docstring states that trajectory_fn should accept params and initial_velocity, which makes sense to me, but I don't understand why we would want to use target_trajectory as the initial_velocity (presumably the two don't even have the same shape, since one is a trajectory and the other is the velocity at a single time step).

I'm pretty much stuck here because I now don't know how to use this in the context of defining a train_step function etc.

Best regards,

Gwen
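I can't confirm what ml.train_utils does internally, but a common way to set up a trajectory-matching loss is to take the first frame of the target trajectory as the initial condition and compare the model's rollout against the remaining frames. A generic sketch, assuming the target is stacked along a leading time axis; `trajectory_fn`, `loss_fn`, and the toy "decay" model are hypothetical and are not jax-cfd's implementation.

```python
import jax
import jax.numpy as jnp


def trajectory_fn(params, initial_state, steps):
  """Hypothetical model: each step multiplies the state by params["decay"]."""
  def step(state, _):
    new_state = params["decay"] * state
    return new_state, new_state  # (carry, per-step output)
  _, trajectory = jax.lax.scan(step, initial_state, xs=None, length=steps)
  return trajectory


def loss_fn(params, target_trajectory):
  # Use the first target frame as the initial condition, then compare the
  # rollout against the remaining frames.
  initial_state = target_trajectory[0]
  num_steps = target_trajectory.shape[0] - 1
  predicted = trajectory_fn(params, initial_state, num_steps)
  return jnp.mean((predicted - target_trajectory[1:]) ** 2)


loss_and_grad_fn = jax.value_and_grad(loss_fn)
```

A train step would then call `loss_and_grad_fn(params, batch)` and pass the gradients to an optimizer.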

I need the jaxlib source

Hi, I'm very excited to see your work! I want to run your project on my own platform, but the jaxlib source code is not available on the official Python package index (PyPI), and the precompiled jaxlib does not support my platform. I would be very grateful if you could provide the jaxlib source code.

Input data and training process

I have run your notebooks, but I can't find where the input data comes from. Are the .nc and .pkl files downloaded with gsutil? I would like to try some other simple examples with your method. Also, is the machine learning part implemented in base/ml? I see that ml_model_inference_demo.ipynb uses ml. I hope you can help; thank you very much.

AttributeError: AxesImage.set() got an unexpected keyword argument 'msg'

# Plot the x-velocity
ds.pipe(lambda ds: ds.u).plot.imshow(
    'x', 'y', col='time', cmap=seaborn.cm.rocket, robust=True, col_wrap=4, aspect=2);

Solution: pass x and y as keyword arguments.

# Plot the x-velocity
ds.pipe(lambda ds: ds.u).plot.imshow(
    x='x', y='y', col='time', cmap=seaborn.cm.rocket, robust=True, col_wrap=4, aspect=2);

3D modeling?

I haven't read through the paper yet, but I figured I'd ask: is this research limited in scope to 2D, or are tensor-based calculations appropriate and practical for 3D modeling as well?

Tests fail upon fresh install.

Upon following the instructions to set up a development environment, i.e.:

git clone https://github.com/google/jax-cfd; cd jax-cfd
python -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install jaxlib
pip install -e ".[complete]"

and running the test suite, as per:

python -m pytest -n auto jax_cfd --dist=loadfile --ignore=jax_cfd/base/validation_test.py

the test-suite fails with:

11 failed, 447 passed, 242 warnings, 4 errors 

I've attached a full output of the tests here.

Quick note on running tests: the test command in README.md doesn't work if you set up the environment as I did above; it has to be prefixed with python -m.

Are there any plans to fix these to ensure that all tests pass?

NameError: name 'Grid' is not defined

Hi! I have a problem and sincerely hope you can help!
Problem description: I wrote my own Python program to implement some of the project's functions, but I ran into a problem at the first step.
In grids.py, line 66 (grid: Grid), I couldn't find where Grid is defined or imported. I then reimplemented the velocity_trajectory_to_xarray() function from xarray_utils.py in the data submodule in my own Python file, and it raised the error "NameError: name 'Grid' is not defined". What should I do about this?
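In your own file, the usual fix is to import the class before using it in an annotation, e.g. `from jax_cfd.base.grids import Grid`, or `from jax_cfd.base import grids` and annotate with `grids.Grid`. The sketch below illustrates the underlying Python rule with a hypothetical stand-in class. It does not describe how grids.py itself is organized. Quoting the annotation, or putting `from __future__ import annotations` as the first line of the module, postpones the name lookup.

```python
import dataclasses


# A bare `grid: Grid` annotation is evaluated when the `def` statement runs
# (unless annotation evaluation is postponed), so Grid must already be
# defined or imported. Quoting the name defers that lookup.
def describe(grid: "Grid") -> str:
  return f"grid with shape {grid.shape}"


@dataclasses.dataclass
class Grid:
  """Hypothetical stand-in for jax_cfd.base.grids.Grid."""
  shape: tuple
```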

[Installation] Failed to import spectral module

Hi,

Thanks for this amazing library!
I installed the module using the command given in the README pip install jax-cfd[complete] to have access to all submodules.

When I try to import the spectral module via import jax_cfd.spectral as spectral I get the following error message:

import jax_cfd.spectral as spectral
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ModuleNotFoundError: No module named 'jax_cfd.spectral'

Some more information about my system and the environment:

  • Python version: Python 3.9.7 (default, Sep 16 2021, 13:09:58)
  • JAX version: 0.2.21
  • JAX CFD version: 0.1.0
  • OS: Linux 5.10.16.3-microsoft-standard-WSL2

Is there any way that I can resolve this issue?

Thanks in advance!
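The reported JAX CFD version (0.1.0) likely predates the spectral submodule. A quick standard-library check of which jax-cfd release is installed and whether it ships a given submodule is sketched below. `has_module` and `installed_version` are hypothetical helper names. If the submodule is missing, installing from the GitHub source (pip install git+https://github.com/google/jax-cfd.git) is one option.

```python
import importlib.util
from importlib import metadata


def has_module(name: str) -> bool:
  """Returns True if `name` (possibly a dotted submodule) can be found."""
  try:
    return importlib.util.find_spec(name) is not None
  except ModuleNotFoundError:
    # Raised when a parent package of a dotted name is not installed.
    return False


def installed_version(distribution: str):
  """Returns the installed version string, or None if not installed."""
  try:
    return metadata.version(distribution)
  except metadata.PackageNotFoundError:
    return None


# Usage: has_module("jax_cfd.spectral"), installed_version("jax-cfd")
```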

Block-Structured NS LES/WMLES code

Hello!

First, thanks for sharing this amazing project. Second, sorry if this isn’t the right place to ask this question.

I would like to know whether there is an ongoing effort to develop a block-structured NS-FV LES/WMLES code able to simulate academically relevant flows, e.g., turbulent channel flow, the NASA hump, etc.

If yes, could someone please point me to the right direction? If no, do you think the current code could be used as a starting point?

Best,
Eduardo

JaxNumPy functions for GridArrays/GridVariables

Hi! This is a great project, and I'm a big fan of both the machine learning applications here and also some of the smaller, helpful structures, in particular base.grids.

Currently, it is possible to add two GridArrays, but it is not possible to add two GridVariables. So this works fine:

import jax_cfd.base.grids as gd
import jax.numpy as jnp

grid = gd.Grid([4,], domain = [(0, 1),])

array_of_values = jnp.array([2.0, 2.0, 3.0, 4.0])

centered_array = grid.center(array_of_values)

print(centered_array + centered_array)

But this throws an exception:

bc = gd.BoundaryConditions((gd.PERIODIC,))

centered_variable = gd.GridVariable(centered_array, bc)

print(centered_variable + centered_variable)

I'm happy to have a go at implementing this myself, if someone isn't already working on it.

Also, am I correct in thinking that the way to use a JaxNumPy function on a GridArray is to call it via NumPy? For example, this throws an exception:

print(jnp.abs(centered_array))

But this works:

import numpy as np
print(np.abs(centered_array))

I assume it's implemented this way because NumPy provides a dispatch mechanism (__array_ufunc__, plus a mix-in class) that we can use to funnel calls to the appropriate JaxNumPy function, whereas JaxNumPy has no equivalent.
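A minimal sketch of the NumPy mechanism in question: the `__array_ufunc__` protocol combined with `numpy.lib.mixins.NDArrayOperatorsMixin` lets a wrapper class support both `np.abs(x)` and `x + y` by forwarding to the wrapped data. jax.numpy functions do not consult `__array_ufunc__`, which would explain why `jnp.abs(centered_array)` fails. `Wrapped` is a hypothetical class; it is not how jax-cfd implements GridArray.

```python
import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin


class Wrapped(NDArrayOperatorsMixin):
  """Hypothetical array wrapper, loosely analogous to a GridArray."""

  def __init__(self, data):
    self.data = np.asarray(data)

  def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    # Only plain calls such as np.abs(x) or x + y are handled here.
    if method != "__call__":
      return NotImplemented
    # Unwrap Wrapped inputs, apply the ufunc, and re-wrap the result.
    unwrapped = [x.data if isinstance(x, Wrapped) else x for x in inputs]
    return Wrapped(ufunc(*unwrapped, **kwargs))


w = Wrapped([-1.0, 2.0])
absolute = np.abs(w)  # dispatched through __array_ufunc__
doubled = w + w       # `+` is provided by NDArrayOperatorsMixin
```

The same pattern, applied to a GridVariable that forwards to its underlying array, is one way the requested `+` support could be added.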

Request: Example notebook showing training of a hybrid model

The README says that an example notebook showing how to train a simple hybrid ML + CFD model is being prepared. Is there an ETA for this?
I'm trying to get it to work using JAX-CFD and Flax but I'm running into problems. An example notebook would help me a lot!

Colab

Any plans to move to Colab?

Add extra field variables as output for the step function

Current

  @jax.named_call
  def navier_stokes_step(v: GridVariableVector) -> GridVariableVector:
    """Computes state at time `t + dt` using first order time integration."""
    # Collect the acceleration terms
    convection = convect(v)
    accelerations = [convection]
    if viscosity is not None:
      diffusion_ = tuple(diffuse(u, viscosity / density) for u in v)
      accelerations.append(diffusion_)
    if forcing is not None:
      # TODO(shoyer): include time in state?
      force = forcing(v)
      accelerations.append(tuple(f / density for f in force))
    dvdt = sum_fields(*accelerations)
    # Update v by taking a time step
    v = tuple(
        grids.GridVariable(u.array + dudt * dt, u.bc)
        for u, dudt in zip(v, dvdt))
    # Pressure projection to incompressible velocity field
    v = pressure_projection(v, pressure_solve)
    return v
  return navier_stokes_step

In this function, only the velocity is returned. The function is called like this:

step_fn = funcutils.repeated(
    collocated.equations.semi_implicit_navier_stokes(
        density=density, viscosity=viscosity, dt=dt, grid=grid),
    steps=inner_steps)
rollout_fn = jax.jit(funcutils.trajectory(step_fn, outer_steps))
_, trajectory = jax.device_get(rollout_fn(v0))

where v0 is a GridVariableVector.

My temp hack

Say we also want the time derivative as an output. My ugly workaround is to concatenate dummy variables after the velocity:

  @jax.named_call
  def navier_stokes_step(v: GridVariableVector) -> GridVariableVector:
    v, _ = (v[0], v[1]), (v[2], v[3])
    ...
    dvdt = sum_fields(*accelerations)
    ...
    v = pressure_projection(v, pressure_solve)
    dvdt = tuple(grids.GridVariable(dudt, u.bc) for u, dudt in zip(v, dvdt))
    return v + dvdt
  return navier_stokes_step

Then I append dummy GridVariables after v0 and call the same step function in the rollout:

vt0 = tuple(grids.GridVariable(grids.GridArray(jnp.zeros_like(u.data), u.offset, grid),
                         u.bc) for u in v0)
v0 += vt0

Question

Is there a standard pattern in JAX for doing this?
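One common JAX pattern, sketched here rather than offered as a jax-cfd API: keep auxiliary quantities out of the carried state and return them as the per-step output of `jax.lax.scan`, which stacks them along a leading time axis. This avoids padding the state with dummy variables. `step_with_aux` and `rollout` are hypothetical, and the toy damping dynamics stand in for the Navier-Stokes terms.

```python
import jax
import jax.numpy as jnp


def step_with_aux(v, dt=0.1):
  """Hypothetical step: returns the new state and dv/dt as auxiliary output."""
  dvdt = -v  # toy linear damping in place of the Navier-Stokes terms
  return v + dt * dvdt, dvdt


def rollout(v0, steps):
  def body(v, _):
    v_next, dvdt = step_with_aux(v)
    # The carry holds only the state; the second element is stacked over time.
    return v_next, (v_next, dvdt)
  _, (trajectory, dvdt_trajectory) = jax.lax.scan(
      body, v0, xs=None, length=steps)
  return trajectory, dvdt_trajectory


trajectory, dvdt_trajectory = jax.jit(rollout, static_argnums=1)(
    jnp.ones(3), 5)
```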

[Error] Channel flow demo

I'm trying to run channel_flow_demo.ipynb, but I get the following error when I call the step_fn :

(massive traceback truncated)

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/advection.py in (.0)
107 for u, target_offset in zip(v, target_offsets))
108 aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt)
--> 109 for target_offset in target_offsets)
110 return _advect_aligned(aligned_c, aligned_v)
111

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/interpolation.py in tvd_interpolation(c, offset, v, dt)
282 c_left = c.shift(-1, axis)
283 c_right = c.shift(1, axis)
--> 284 c_next_right = c.shift(2, axis)
285 # Velocities of different sign are evaluated with limiters at different
286 # points. See equations (4.34) -- (4.39) from the reference above.

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/grids.py in shift(self, offset, axis)
269 GridArray has offset u.offset + offset.
270 """
--> 271 return self.bc.shift(self.array, offset, axis)
272
273 def _interior_grid(self) -> Grid:

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/boundaries.py in shift(self, u, offset, axis)
77 u.offset + offset.
78 """
---> 79 padded = self._pad(u, offset, axis)
80 trimmed = self._trim(padded, -offset, axis)
81 return trimmed

/usr/local/lib/python3.7/dist-packages/jax_cfd/base/boundaries.py in _pad(self, u, width, axis)
116 if bc_type != BCType.PERIODIC and abs(width) > 1:
117 raise ValueError(
--> 118 'Padding past 1 ghost cell is not defined in nonperiodic case.')
119
120 if bc_type == BCType.PERIODIC:

ValueError: Padding past 1 ghost cell is not defined in nonperiodic case.

I installed JAX-CFD on a fresh Colab notebook using pip install git+https://github.com/google/jax-cfd.git

Please help me fix this. Let me know if I'm using the wrong version (since the repo is constantly being developed).

Thanks!

Question about the random initial velocity in notebook spectral_forced_turbulence.ipynb

Hello,

This is a very nice JAX implementation for CFD! I have a few questions about your recently released notebook spectral_forced_turbulence.ipynb.

  1. What exactly are the NS equations that the notebook aims to solve? Would it be possible to provide detailed expressions for the NS equations and boundary conditions, or some references? I think it would be very helpful to people like me who are more familiar with machine learning and know less about CFD.

  2. If I understand correctly, the code aims to solve the NS equations (vorticity-velocity form) using the pseudo-spectral method, so periodic boundary conditions must be enforced for the velocity (or vorticity)? However, when I run the following piece of code (which should give me a random velocity field and visualize it at the boundary):

import jax
import jax_cfd.base as cfd
import matplotlib.pyplot as plt

v0 = cfd.initial_conditions.filtered_velocity_field(jax.random.PRNGKey(0), grid, max_velocity, 4)
u = v0[0].data
v = v0[1].data

plt.plot(u[:, 0])
plt.plot(u[:, -1])
plt.show()

plt.plot(u[0, :])
plt.plot(u[-1, :])
plt.show()

I got the following figures below. It seems that the periodic boundary condition is not fully imposed. Could you please take a look or correct me if I am wrong?

[Figures omitted: bc, bc2]
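One possible explanation, sketched rather than verified against the notebook: on a periodic grid with n cells, the stored samples cover [0, L), and the point x = L (which is the same as x = 0) is not stored. So u[:, 0] and u[:, -1] are neighbors across the boundary rather than copies of each other, and a smooth periodic field need not have matching first and last samples. A more telling check is whether the jump across the wrap looks like an ordinary neighbor difference. The 1D sine field below is a hypothetical example.

```python
import numpy as np

n, length = 64, 2 * np.pi
x = np.arange(n) * length / n  # periodic grid: x = length is not stored
u = np.sin(x)                  # a smooth, exactly periodic field

# The first and last stored samples differ...
edge_gap = u[-1] - u[0]

# ...but the jump across the periodic wrap matches an interior jump,
# which is what periodicity requires (a sample at x = length would equal u[0]).
wrap_jump = u[0] - u[-1]
interior_jump = u[1] - u[0]
```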

Once again, this is really a nice project! Looking forward to your reply.
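As a general reference point for question 1, not a statement of exactly what the notebook implements: the 2D incompressible Navier-Stokes equations in vorticity-streamfunction form on a doubly periodic domain are often written as follows. Here ω = ∂v/∂x − ∂u/∂y is the vorticity, ψ the streamfunction, ν the kinematic viscosity, α an optional linear drag coefficient (α = 0 means no drag), and f the curl of the body forcing. Whether the notebook uses drag and which forcing it applies should be checked against its solver setup. Periodicity in both x and y is what makes a Fourier (pseudo-spectral) discretization natural.

```latex
\frac{\partial \omega}{\partial t} + \mathbf{u}\cdot\nabla\omega
  = \nu \nabla^{2}\omega - \alpha\,\omega + f,
\qquad
\nabla^{2}\psi = -\omega,
\qquad
\mathbf{u} = (u, v) = \left(\frac{\partial \psi}{\partial y},\; -\frac{\partial \psi}{\partial x}\right)
```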

Boundary Conditions

Hi,

I tried to define some boundary conditions. Until recently, everything went perfectly fine with

import jax_cfd.base.grids as grids
bc = grids.BoundaryConditions((grids.PERIODIC, grids.PERIODIC))
Traceback (most recent call last):
  File "<string>", line 1, in <module>
AttributeError: module 'jax_cfd.base.grids' has no attribute 'BoundaryConditions'

I saw that you moved the boundary conditions to a new module called boundaries.py. I tried then to access them accordingly from the new module via:

import jax_cfd.base as cfd
bc = cfd.boundaries.periodic_boundary_conditions(ndim=2)
Traceback (most recent call last):
  File "<string>", line 1, in <module>
AttributeError: module 'jax_cfd.base' has no attribute 'boundaries'

But now neither of these works (see the error messages above), and I wonder how I can define boundary conditions.
I tried installing JAX-CFD according to the documentation and according to #90, but neither of them gave me a solution to my problem.

My python version is 3.8.

Python 3.8.12 (default, Oct 12 2021, 13:49:34) 
[GCC 7.5.0] :: Anaconda, Inc. on linux
