
torch_qg's Introduction

torch_qg

[animation: pyqg simulation]

Modelling of a two-layer quasi-geostrophic system in PyTorch (we use PyTorch 2.0.0). This repo is based on pyqg, with a few changes. We implement all components of the numerical scheme in PyTorch, so that the simulation is end-to-end differentiable. Additionally, whilst pyqg uses a pseudo-spectral method to evolve the system forward in time, we include a real-space time-stepper with an Arakawa advection scheme.

Core dependencies can be found in tests/requirements.txt. After cloning the repo, install by running

pip install .

or, to install in editable mode, run

pip install -e .
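
A minimal usage sketch (the constructor call below also appears in the issue transcripts further down; stepping and diagnostic calls are omitted, as they depend on the model API):

    import torch
    import torch_qg.model as torch_model

    ## instantiate the differentiable pseudo-spectral model on a 256x256 grid
    m = torch_model.PseudoSpectralModel(nx=256, dtype=torch.float64)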


torch_qg's Issues

Seg faults when running on slurm

Getting this quite commonly when running on Slurm, but not through OOD. It can often be a memory issue, but this latest one happened right at script start:

WARNING: Could not find any nv files on this host!
/ext3/miniconda3/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/bin/bash: line 1: 1999515 Segmentation fault      (core dumped) python3 /home/cp3759/Projects/torch_qg/scripts/gen_L_sims.py

Then another error for the Smagorinsky sims:

WARNING: Could not find any nv files on this host!
/ext3/miniconda3/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
corrupted double-linked list
/bin/bash: line 1: 2532150 Aborted                 (core dumped) python3 /home/cp3759/Projects/torch_qg/scripts/gen_smag_sims.py

Low res simulations dominated by zonal flow

High-resolution simulations visually look very similar to pyqg. However, low-res looks different: the eddies are fairly static and the flow is dominated by zonal jets. Perhaps there is a bug in the code making mean-flow effects stronger, or is this a consequence of the different numerical scheme? Either way, we need to understand this. Start by verifying that the forcing terms are all implemented correctly (a sketch of such a check is below); we have already verified that the advection is correct and have tests for it.
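
A hedged sketch of that check; set_q and tendency below are hypothetical placeholders for whatever the two codebases actually expose for setting the PV field and extracting a single-step tendency:

    import numpy as np
    import pyqg
    import torch
    import torch_qg.model as torch_model

    nx = 64
    m_ref = pyqg.QGModel(nx=nx)
    m_tqg = torch_model.PseudoSpectralModel(nx=nx, dtype=torch.float64)

    ## hand both models the same PV field, then compare one-step tendencies;
    ## set_q() and tendency() are hypothetical names, not real API calls
    q0 = m_ref.q.copy()
    m_tqg.set_q(torch.from_numpy(q0))
    print(np.abs(tendency(m_ref) - tendency(m_tqg).numpy()).max())

Since the advection terms already agree, any remaining discrepancy should come from the forcing terms or the filtering.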

Import order can crash kernel

Importing pyqg before torch, or importing xarray before torch, crashes the OOD kernel. When running on the command line I get:

Singularity> python3
Python 3.9.12 (main, Apr  5 2022, 06:56:58) 
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pyqg
>>> import torch_qg.model as torch_model
>>> import torch
>>> m32=torch_model.PseudoSpectralModel(nx=256,dtype=torch.float32)
CUDA Not Available, using CPU
/ext3/miniconda3/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
python3: symbol lookup error: /ext3/miniconda3/lib/python3.9/site-packages/mkl/../../../libmkl_intel_thread.so.1: undefined symbol: __kmpc_global_thread_num

I think it's due to an xarray/torch conflict somewhere:

Singularity> python3
Python 3.9.12 (main, Apr  5 2022, 06:56:58) 
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import xarray as xr
>>> import torch
>>> quit()
Singularity> python3
Python 3.9.12 (main, Apr  5 2022, 06:56:58) 
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import xarray as xr
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/ext3/miniconda3/lib/python3.9/site-packages/xarray/__init__.py", line 3, in <module>
    from xarray import testing, tutorial
  File "/ext3/miniconda3/lib/python3.9/site-packages/xarray/testing.py", line 8, in <module>
    import pandas as pd
  File "/ext3/miniconda3/lib/python3.9/site-packages/pandas/__init__.py", line 48, in <module>
    from pandas.core.api import (
  File "/ext3/miniconda3/lib/python3.9/site-packages/pandas/core/api.py", line 48, in <module>
    from pandas.core.groupby import (
  File "/ext3/miniconda3/lib/python3.9/site-packages/pandas/core/groupby/__init__.py", line 1, in <module>
    from pandas.core.groupby.generic import (
  File "/ext3/miniconda3/lib/python3.9/site-packages/pandas/core/groupby/generic.py", line 70, in <module>
    from pandas.core.frame import DataFrame
  File "/ext3/miniconda3/lib/python3.9/site-packages/pandas/core/frame.py", line 157, in <module>
    from pandas.core.generic import NDFrame
  File "/ext3/miniconda3/lib/python3.9/site-packages/pandas/core/generic.py", line 152, in <module>
    from pandas.core.window import (
  File "/ext3/miniconda3/lib/python3.9/site-packages/pandas/core/window/__init__.py", line 1, in <module>
    from pandas.core.window.ewm import (  # noqa:F401
  File "/ext3/miniconda3/lib/python3.9/site-packages/pandas/core/window/ewm.py", line 12, in <module>
    import pandas._libs.window.aggregations as window_aggregations
ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found (required by /ext3/miniconda3/lib/python3.9/site-packages/pandas/_libs/window/aggregations.cpython-39-x86_64-linux-gnu.so)

Speed up Smagorinsky parameterisation

We have Smagorinsky implemented and working. However, we are taking a large number of FFTs:

        ## velocities from the spectral streamfunction ph
        uh = -il*ph
        vh = ik*ph
        ## strain-rate tensor components in real space
        Sxx = torch.fft.irfftn(uh*ik,dim=(1,2))
        Syy = torch.fft.irfftn(vh*il,dim=(1,2))
        Sxy = 0.5 * torch.fft.irfftn(uh * il + vh * ik,dim=(1,2))
        ## Smagorinsky eddy viscosity
        nu = (self.constant * dx)**2 * torch.sqrt(2 * (Sxx**2 + Syy**2 + 2 * Sxy**2))
        nu_Sxxh = torch.fft.rfftn(nu * Sxx,dim=(1,2))
        nu_Sxyh = torch.fft.rfftn(nu * Sxy,dim=(1,2))
        nu_Syyh = torch.fft.rfftn(nu * Syy,dim=(1,2))
        ## divergence of the stress tensor gives the u, v forcing
        du = 2 * (torch.fft.irfftn(nu_Sxxh * ik,dim=(1,2)) + torch.fft.irfftn(nu_Sxyh * il,dim=(1,2)))
        dv = 2 * (torch.fft.irfftn(nu_Sxyh * ik,dim=(1,2)) + torch.fft.irfftn(nu_Syyh * il,dim=(1,2)))
        ## take curl to convert u, v forcing to potential vorticity forcing
        dq = -torch.fft.irfftn(il*torch.fft.rfftn(du,dim=(1,2)),dim=(1,2))+torch.fft.irfftn(ik*torch.fft.rfftn(dv,dim=(1,2)),dim=(1,2))

Given the relationship between the velocities and the streamfunction, and the fact that we are mostly just taking derivatives here, some speedups should be possible. And given that we need this Smagorinsky term even for our "DNS", we should make it as efficient as possible; one possible reduction is sketched below.
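
A sketch of one such reduction, assuming (as in the snippet above) that u and v derive from the single streamfunction ph, so the flow is non-divergent and Syy = -Sxx, and using the linearity of the FFT to assemble du, dv, and the final curl in spectral space. This cuts the FFT count from 14 to 5:

    import torch

    def smagorinsky_dq(ph, ik, il, dx, constant):
        ## strain components directly from ph (2 inverse FFTs);
        ## Syy = -Sxx by incompressibility, so it is never formed
        Sxx = torch.fft.irfftn(-il * ik * ph, dim=(1, 2))
        Sxy = torch.fft.irfftn(0.5 * (ik**2 - il**2) * ph, dim=(1, 2))

        ## with Syy = -Sxx, 2*(Sxx**2 + Syy**2 + 2*Sxy**2) = 4*(Sxx**2 + Sxy**2)
        nu = (constant * dx)**2 * 2.0 * torch.sqrt(Sxx**2 + Sxy**2)

        ## the only products that must happen in real space (2 forward FFTs)
        nu_Sxxh = torch.fft.rfftn(nu * Sxx, dim=(1, 2))
        nu_Sxyh = torch.fft.rfftn(nu * Sxy, dim=(1, 2))

        ## assemble du, dv spectrally (nu_Syyh = -nu_Sxxh), then take the
        ## curl spectrally and do a single final inverse FFT
        duh = 2.0 * (nu_Sxxh * ik + nu_Sxyh * il)
        dvh = 2.0 * (nu_Sxyh * ik - nu_Sxxh * il)
        return torch.fft.irfftn(-il * duh + ik * dvh, dim=(1, 2))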

Spin-up phase doesn't form eddies in Arakawa scheme

May be related to #8; there's some bug somewhere in the solver. When starting from the same initialisation procedure as pyqg, with vertical bands plus small random noise, the system never forms eddies no matter how long we evolve it: the dynamics just keep varying the PV in vertical bands. This is only the case with the Arakawa advection scheme; the spectral solver does form eddies (although they explode, but that's another issue).

Add 2/3 dealiasing filter for pseudospectral system

We see a significant difference compared to pyqg for low-res sims. We have only changed the advection scheme and removed the exponential filter, so it would be nice to know which of these two modifications caused this (particularly now that the backscatter is captured even with nx=48 simulations, which is a dramatic difference from pyqg).

A pseudospectral method with 2/3 dealiasing should be equivalent to Arakawa, since this scheme also conserves PV. We would need to add Smagorinsky back in here, in place of the exponential filter, but it would be interesting to see whether that system produces spectral characteristics similar to Arakawa + Smagorinsky. Pavel put up a PR for pyqg here. A sketch of the filter is below.
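
A minimal sketch of a 2/3-rule mask, assuming fields are transformed with torch.fft.rfftn over dims (1, 2) on a square nx-by-nx grid as in the snippets above; the mask would be applied after every nonlinear product:

    import torch

    def two_thirds_mask(nx: int) -> torch.Tensor:
        ## integer wavenumbers: l spans the full plane, k the rfft half-plane
        l = torch.fft.fftfreq(nx, d=1.0 / nx)
        k = torch.fft.rfftfreq(nx, d=1.0 / nx)
        cut = (2.0 / 3.0) * (nx // 2)
        ## zero every mode with |k| or |l| beyond 2/3 of Nyquist
        return (l.abs()[:, None] <= cut) & (k.abs()[None, :] <= cut)

    ## usage sketch: qh = qh * two_thirds_mask(nx) after computing the
    ## advection term, broadcasting over the layer dimension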

Test Euler vs AB3 over short trajectories

This came up in the NYU-ML meeting: when Karl was running his short trajectories, the low-res sim had to be initialised from a fresh set of potential vorticities, which means the first step is Euler, then AB2, then AB3. Test how much of a difference that makes, and also whether we can just run Euler for ~25 timesteps or so (this would give a significantly cheaper computational graph for online training). The two steppers are sketched below.
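
A sketch of the two rollouts being compared, with a generic tendency function rhs(q) = dq/dt standing in for the model's actual right-hand side (a hypothetical callable, not a real API):

    import torch

    def ab3_rollout(q, rhs, dt, n_steps):
        ## AB3 with the Euler/AB2 bootstrap used on a cold start
        hist = []
        for _ in range(n_steps):
            f = rhs(q)
            if len(hist) == 0:    # first step: forward Euler
                q = q + dt * f
            elif len(hist) == 1:  # second step: AB2
                q = q + dt * (1.5 * f - 0.5 * hist[-1])
            else:                 # AB3 thereafter
                q = q + (dt / 12.0) * (23.0 * f - 16.0 * hist[-1] + 5.0 * hist[-2])
            hist = (hist + [f])[-2:]  # keep the last two tendencies
        return q

    def euler_rollout(q, rhs, dt, n_steps):
        ## pure Euler: first order, but a much cheaper computational graph
        for _ in range(n_steps):
            q = q + dt * rhs(q)
        return q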

Add typing

Once the structure is a bit more established, add type hints to class/function arguments to improve stability.
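
An illustrative example of the kind of annotation meant here (the class and method names are hypothetical):

    import torch

    class ArakawaModel:
        ## illustrative only: method name and arguments are hypothetical
        def step(self, q: torch.Tensor, dt: float) -> torch.Tensor:
            """Advance the PV field q by one timestep of length dt."""
            ...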

Float32 vs 64

When it comes to running on GPU, we'd be a lot more efficient with float32. In the current main branch, the wavenumbers (used to calculate the determinant of the elliptic-equation inversion) are float32, and changing this doesn't affect the spectra/statistics of high-res fields, despite having some impact on the matrix inversion:

[screenshots: float32 vs float64 comparison]

Since this is the part of the code most sensitive to numerical precision, if it isn't affected, maybe we can just run in float32? Let's compare pyqg (float64), torch_qg (float64), and torch_qg (float32) and see whether the spectra differ.
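
A sketch of that comparison; run_to_equilibrium and ke_spectrum are hypothetical stand-ins for whatever rollout and diagnostic utilities the repo provides:

    import torch
    import torch_qg.model as torch_model

    ## run the same configuration at both precisions and compare spectra;
    ## run_to_equilibrium() and ke_spectrum() are hypothetical helpers
    for dtype in (torch.float64, torch.float32):
        m = torch_model.PseudoSpectralModel(nx=256, dtype=dtype)
        run_to_equilibrium(m)
        print(dtype, ke_spectrum(m))
    ## then compare both against a pyqg float64 run of the same setup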

Add online training

Made a start on this in #35. One consideration will be the renormalisation that goes on in the CNN parameterisation, and whether it will lead to numerical issues in the gradient flow. Let's just try the most straightforward setup first and see whether we get some kind of convergence; if not, this would be one place to look. The objective is sketched below.
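
A minimal sketch of the online objective, assuming a differentiable model.step and a CNN param_net whose output is added as a PV forcing; both call signatures here are hypothetical:

    import torch

    def online_loss(model, param_net, q0, q_targets):
        ## roll the low-res model forward with the learned forcing and
        ## penalise deviation from the high-res trajectory at every step;
        ## model.step() and its dq_param argument are hypothetical
        q, loss = q0, 0.0
        for q_target in q_targets:
            q = model.step(q, dq_param=param_net(q))
            loss = loss + torch.mean((q - q_target) ** 2)
        return loss / len(q_targets)

    ## loss.backward() then differentiates through the whole rollout,
    ## including any renormalisation inside param_net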
