smsharma / point-cloud-galaxy-diffusion

Transformer-guided diffusion for galaxy clustering. Code repository associated with https://arxiv.org/abs/2311.17141

License: MIT

Topics: diffusion, variational-diffusion, sets, transformer, generative-model, diffusion-models, dark-matter, gnn, graph-neural-network, jax


A point cloud approach to field level generative modeling

Carolina Cuesta-Lazaro and Siddharth Mishra-Sharma


Abstract

We introduce a diffusion-based generative model that describes the distribution of galaxies in our Universe directly as a collection of points in 3-D space (coordinates), optionally with associated attributes (e.g., velocities and masses), without resorting to binning or voxelization. The custom diffusion model can be used both for emulation, reproducing essential summary statistics of the galaxy distribution, and for inference, by computing the conditional likelihood of a galaxy field. We demonstrate a first application to massive dark matter haloes in the Quijote simulation suite. This approach can be extended to enable a comprehensive analysis of cosmological data, circumventing limitations inherent to summary statistics-based as well as neural simulation-based inference methods.

Dependencies

The Python environment is defined in environment.yml. To create the environment, run e.g.

mamba env create --file environment.yml

For evaluation on the nbody dataset, Corrfunc (installed via pycorr) is needed:

python -m pip install git+https://github.com/cosmodesi/pycorr#egg=pycorr[corrfunc]
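
pycorr wraps Corrfunc to measure two-point clustering of point clouds such as the model's samples. Below is a minimal sketch of the kind of call involved, assuming a periodic box; the box size, binning, and positions are illustrative placeholders, and the exact arguments should be checked against the pycorr documentation:

import numpy as np
from pycorr import TwoPointCorrelationFunction

positions = np.random.uniform(0.0, 1000.0, (10_000, 3))  # illustrative halo positions, shape (N, 3)
edges = np.linspace(0.1, 50.0, 26)  # illustrative separation bins in Mpc/h

# Isotropic correlation function xi(s) in a periodic box, using the Corrfunc engine
# (analytic randoms via boxsize; no random catalog needed)
result = TwoPointCorrelationFunction(
    "s", edges,
    data_positions1=positions,
    boxsize=1000.0,
    engine="corrfunc",
    position_type="pos",
)
sep, xi = result.sep, result.corr  # bin separations and the estimated xi(s)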

Dataset

The processed dark matter halo features from the Quijote simulations used to train the model can be found here. Make sure to update the hard-coded DATA_DIR in datasets.py to point to the location of the dataset before training.
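
Concretely, this is a one-line edit; a sketch is shown below, where the path is a placeholder and the exact line in datasets.py may differ:

# In datasets.py: point the hard-coded directory at the downloaded Quijote halo files
DATA_DIR = "/path/to/quijote_halos"  # placeholder; replace with your dataset location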

Code overview

Running the code

With the dataset in place, the diffusion model can be trained via

python train.py --config ./configs/nbody.py

which is called from scripts/submit_train.sh. The config file ./configs/nbody.py (which sets diffusion, score model, and dataset configuration) can be edited accordingly. Similarly, scripts/submit_infer.sh computes the likelihood profiles for the trained model, calling infer.py.
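
The --config flag follows the ml_collections config_flags convention common in JAX codebases, in which the config file exposes a get_config() function returning a ConfigDict. A hypothetical sketch of that pattern follows; the field names below are illustrative only, not the repo's actual schema, so consult ./configs/nbody.py for the real one:

import ml_collections

def get_config():
    config = ml_collections.ConfigDict()
    # Illustrative groups only -- mirror the structure in ./configs/nbody.py
    config.data = ml_collections.ConfigDict()
    config.data.n_particles = 5000  # illustrative
    config.score = ml_collections.ConfigDict()
    config.score.d_model = 256  # illustrative
    config.training = ml_collections.ConfigDict()
    config.training.batch_size = 32  # illustrative
    return config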

The notebooks directory contains notebooks used to produce results for the paper, each linked from the respective figures.

Diffusion model basic usage

For standalone usage, the following can be used to compute the variational lower bound loss and sample from the model:

import jax
import jax.numpy as np

from flax.core import FrozenDict

from models.diffusion import VariationalDiffusionModel
from models.diffusion_utils import generate, loss_vdm

# Transformer (score model) args
score_dict = FrozenDict({"d_model": 256, "d_mlp": 512, "n_layers": 5, "n_heads": 4, "induced_attention": False, "n_inducing_points": 32})

# Instantiate model
vdm = VariationalDiffusionModel(
    gamma_min=-6.0, gamma_max=6.0,  # Min and max initial log-SNR in the noise schedule
    d_feature=4,  # Number of features per set element, e.g. 7 for (x, y, z, vx, vy, vz, m)
    score="transformer",  # Score model; "transformer" or "graph"
    score_dict=score_dict,  # Score-prediction transformer parameters
    noise_schedule="learned_linear",  # Noise schedule; "learned_linear", "learned_net" (monotonic neural network), or "linear" (fixed)
    embed_context=False,  # Whether to embed the context vector
    timesteps=0,  # Number of diffusion steps; set 0 for the continuous-time variational lower bound
    d_t_embedding=16,  # Timestep embedding dimension
    noise_scale=1e-3,  # Data noise model
    n_pos_features=3,  # Number of positional features, for graph-building
)

rng = jax.random.PRNGKey(42)

x = jax.random.normal(rng, (32, 100, 4))  # Input set, (batch_size, max_set_size, num_features)
mask = jax.random.randint(rng, (32, 100), 0, 2)  # Optional set mask, (batch_size, max_set_size); can be `None`
conditioning = jax.random.normal(rng, (32, 6))  # Optional conditioning context, (batch_size, context_size); can be `None`

# Call to get losses
(loss_diff, loss_klz, loss_recon), params = vdm.init_with_output({"sample": rng, "params": rng}, x, conditioning, mask)

# Compute full loss, accounting for masking
loss = loss_vdm(params, vdm, rng, x, conditioning, mask)

# Sample from model

mask_sample = jax.random.randint(rng, (24, 100), 0, 2)  # Mask for the generated set, (n_samples, max_set_size)
conditioning_sample = jax.random.normal(rng, (24, 6))  # Conditioning context for generation, (n_samples, context_size)

x_samples = generate(vdm, params, rng, (24, 100), conditioning_sample, mask_sample)
x_samples.mean().shape  # Mean of decoded Normal distribution -- (24, 100, 4)
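
To train this standalone model, loss_vdm can be dropped into an ordinary JAX update loop. A minimal sketch follows, assuming optax as the optimizer; the optimizer choice and learning rate are illustrative, not the repo's training settings:

import optax

optimizer = optax.adam(3e-4)  # illustrative optimizer and learning rate
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, rng, x, conditioning, mask):
    # Differentiate the masked variational lower bound w.r.t. the parameters;
    # vdm is captured from the enclosing scope as a static module.
    loss, grads = jax.value_and_grad(loss_vdm)(params, vdm, rng, x, conditioning, mask)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# One update; in practice, split a fresh rng per step with jax.random.split
params, opt_state, loss = train_step(params, opt_state, rng, x, conditioning, mask)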

Citation

If you use this code, please cite our paper:

@article{Cuesta-Lazaro:2023zuk,
    author = "Cuesta-Lazaro, Carolina and Mishra-Sharma, Siddharth",
    title = "{A point cloud approach to generative modeling for galaxy surveys at the field level}",
    eprint = "2311.17141",
    archivePrefix = "arXiv",
    primaryClass = "astro-ph.CO",
    reportNumber = "MIT-CTP/5651",
    month = "11",
    year = "2023"
}


Open issues

Tasks for inference

  • Switch to 8x dataset
  • Do a coverage test with SVI
  • Move infer.py to root, have infer_utils.py
  • Validate with HMC or NS
  • Check how Gaussian the likelihood is

Possible improvements to GNN

  • Edge updates are currently not propagated through the skip connections (only node updates are) -- should they be?
  • Revisit the attention implementation (e.g., is the softmax needed in the attention logits function?).

Improve GNN architectures

  • Fix equivariant norm in the EGNN
  • Add attention to both EGNN and GNN
  • Unify EGNN and GNN (make them the same function, or make the functions more similar)
  • Add random graphs
  • Add globals to the EGNN node updates
  • Understand how globals are treated -- treat as scalars?
