

MPGAN & GAPT


Code for models in:

[1] Kansal et al., Graph Generative Adversarial Networks for Sparse Data Generation in High Energy Physics, ML4PS @ NeurIPS 2020, arXiv:2012.00173
[2] Kansal et al., Particle Cloud Generation with Message Passing Generative Adversarial Networks, NeurIPS 2021, arXiv:2106.11535
[3] Kansal et al., On the Evaluation of Generative Models in High Energy Physics, arXiv:2211.10295

Overview

This repository contains PyTorch code for the message-passing GAN (MPGAN) and generative adversarial particle transformer (GAPT) models. It also includes scripts for training the models from scratch and for generating and plotting particle clouds. We also include the weights of the fully trained models discussed in [2].

Additionally, we release the standalone JetNet library. It provides a PyTorch Dataset class for our JetNet dataset, implementations of the evaluation metrics discussed in the papers above, and other utilities for machine learning development with jets.
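Several of the metrics used in [2] are Wasserstein-1 (W1) distances between real and generated distributions of a feature, such as the jet mass. The function below is a minimal sketch of what a 1D W1 distance computes. It assumes two samples of equal size. The function name is hypothetical, and this is not JetNet's implementation.

```python
import numpy as np


def wasserstein_1d(real, generated):
    """Wasserstein-1 distance between two equal-sized 1D samples.

    For two empirical distributions with the same number of points, W1
    reduces to the mean absolute difference between the sorted samples.
    """
    real = np.sort(np.asarray(real, dtype=np.float64).ravel())
    generated = np.sort(np.asarray(generated, dtype=np.float64).ravel())
    if real.shape != generated.shape:
        raise ValueError("this sketch only handles samples of equal size")
    return float(np.mean(np.abs(real - generated)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = rng.normal(0.0, 1.0, size=10_000)
    # shifting a sample by a constant c moves it exactly c away in W1
    print(wasserstein_1d(a, a + 0.5))
```

`scipy.stats.wasserstein_distance` handles unequal sample sizes and weights. For numbers comparable to the papers, use the JetNet metric implementations rather than this sketch.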

For the exact code and scripts used for [2], please see the neurips21 branch.

Talks

A complete list of talks can be found here.

Dependencies

MPGAN and GAPT Models

  • torch >= 1.8.0

Training, Plotting, Evaluation

  • torch >= 1.8.0
  • jetnet >= 0.2.1
  • numpy >= 1.21.0
  • matplotlib
  • mplhep

These can be installed with pip install -r requirements.txt.
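The script below is a quick, standard-library-only sketch for checking that an environment meets the minimum versions listed above. The helper names are hypothetical and not part of the repository. The version parsing is deliberately simple: it ignores local tags such as "+cu118" and pre-release suffixes.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the "Training, Plotting, Evaluation" list above;
# None means no minimum is stated.
REQUIREMENTS = {
    "torch": "1.8.0",
    "jetnet": "0.2.1",
    "numpy": "1.21.0",
    "matplotlib": None,
    "mplhep": None,
}


def parse_version(text):
    """Turn e.g. '2.0.1+cu118' into (2, 0, 1); stops at the first non-numeric part."""
    parts = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare release numbers, padding the shorter tuple with zeros."""
    if minimum is None:
        return True
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_requirements(requirements=REQUIREMENTS):
    """Return {package: (installed version or None, requirement satisfied)}."""
    report = {}
    for name, minimum in requirements.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_requirements().items():
        print(f"{name:12s} {installed or 'missing':12s} {'OK' if ok else 'CHECK'}")
```

For anything beyond a sanity check, `packaging.version.Version` implements the full PEP 440 comparison rules.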

External models also require

  • torch
  • torch_geometric

A Docker image containing all necessary libraries can be found here (Dockerfile).

Training

Start training with:

python train.py --name test_model --model [model] --jets [jets] [args]

where model can be either mpgan or gapt, and jets can be any of ['g', 't', 'q', 'w', 'z'].
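To train one model per jet class, the command above can be scripted. The sketch below uses only the --name, --model and --jets flags shown above. The "<model>_<jets>" run-naming scheme and the helper names are hypothetical. By default it only prints the commands (dry run).

```python
import subprocess
import sys

JET_TYPES = ["g", "t", "q", "w", "z"]  # jet classes accepted by --jets
MODELS = ["mpgan", "gapt"]  # values accepted by --model


def build_train_command(model, jets, name=None, extra_args=()):
    """Assemble the argv for a single train.py run."""
    if model not in MODELS:
        raise ValueError(f"unknown model {model!r}; expected one of {MODELS}")
    if jets not in JET_TYPES:
        raise ValueError(f"unknown jet type {jets!r}; expected one of {JET_TYPES}")
    name = name or f"{model}_{jets}"  # hypothetical naming scheme for outputs/[name]
    return [sys.executable, "train.py", "--name", name,
            "--model", model, "--jets", jets, *extra_args]


def run_sweep(model="mpgan", jet_types=JET_TYPES, dry_run=True):
    """Run one training job per jet type, sequentially; dry_run only prints them."""
    commands = [build_train_command(model, jets) for jets in jet_types]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # abort the sweep on the first failure
    return commands


if __name__ == "__main__":
    run_sweep("gapt")
```

Running the jobs one after another keeps a single GPU from being oversubscribed. Parallel runs would need separate devices.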

By default, model parameters, figures of particle and jet features, and plots of the training losses and evaluation metrics over time will be saved every five epochs in an automatically created outputs/[name] directory.
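Jet-level features such as the jet mass are derived from the particle-level features. The sketch below shows that computation. It assumes each particle carries (η_rel, φ_rel, p_T,rel), as in the JetNet dataset, and that every particle is massless. With relative coordinates, the result is a relative jet mass. The function name is hypothetical, and this is not necessarily how the repository or JetNet computes it.

```python
import numpy as np


def relative_jet_mass(particles):
    """Jet mass from an array of shape (num_jets, num_particles, 3).

    Columns are assumed to be (eta_rel, phi_rel, pt_rel). Zero-padded
    particles (pt_rel == 0) contribute nothing to the four-momentum sums.
    """
    eta, phi, pt = particles[..., 0], particles[..., 1], particles[..., 2]
    # massless four-momenta: E = pT cosh(eta), p = (pT cos(phi), pT sin(phi), pT sinh(eta))
    energy = (pt * np.cosh(eta)).sum(axis=-1)
    px = (pt * np.cos(phi)).sum(axis=-1)
    py = (pt * np.sin(phi)).sum(axis=-1)
    pz = (pt * np.sinh(eta)).sum(axis=-1)
    mass_sq = energy**2 - px**2 - py**2 - pz**2
    # clip small negative values caused by floating-point round-off
    return np.sqrt(np.clip(mass_sq, 0.0, None))
```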

Some notes:

  • Training runs on a GPU by default, if one is available.
  • The default arguments correspond to the final model architecture and training configuration used in the paper.
  • Run python train.py --help or look at setup_training.py for a full list of arguments.
  • For prototyping purposes, models can also be trained on a 'sparsified' MNIST point cloud dataset, as in [1], using train_mnist.py. The standard MNIST training and testing CSV files must be downloaded separately, and their location passed with the --datasets-path argument.
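One way to turn an MNIST image into a sparse point cloud is to keep only its brightest pixels as (x, y, intensity) points. The sketch below does this under several assumptions. It assumes each CSV row is "label, p0, ..., p783" with 0-255 pixel values, and that the cloud keeps a fixed number of highest-intensity pixels. The default of 100 points is illustrative, not necessarily the value used in [1]. The helper names are hypothetical, and train_mnist.py may preprocess the data differently.

```python
import csv

import numpy as np


def sparsify_mnist_row(pixels, num_points=100):
    """Turn one flattened 28x28 image into its num_points brightest pixels.

    Returns shape (num_points, 3) with columns (x, y, intensity): x and y are
    pixel coordinates scaled to [0, 1], intensity is scaled to [0, 1].
    Rows drawn from empty (zero) pixels are zeroed out as padding.
    """
    flat = np.asarray(pixels, dtype=np.float64).reshape(-1) / 255.0
    # indices of the brightest pixels, highest first (stable ordering for ties)
    order = np.argsort(-flat, kind="stable")[:num_points]
    rows, cols = np.divmod(order, 28)
    cloud = np.stack([cols / 27.0, rows / 27.0, flat[order]], axis=1)
    cloud[flat[order] == 0] = 0.0
    return cloud


def load_sparse_mnist(csv_path, num_points=100):
    """Read an MNIST CSV (assumed layout: label then 784 pixels per row)."""
    labels, clouds = [], []
    with open(csv_path, newline="") as f:
        for row in csv.reader(f):
            if not row or not row[0].strip().isdigit():
                continue  # skip a header line, if present
            labels.append(int(row[0]))
            clouds.append(sparsify_mnist_row([float(v) for v in row[1:785]], num_points))
    return np.array(labels), np.stack(clouds)
```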

Generation

Pre-trained generators with saved state dictionaries and arguments can be used to generate samples with, for example:

python gen.py --G-state-dict trained_models/mp_g/G_best_epoch.pt --G-args trained_models/mp_g/args.txt --num-samples 50000 --output-file trained_models/mp_g/gen_jets.npy
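The samples are written as a NumPy .npy file and can be inspected with np.load. The sketch below makes two assumptions about the output of gen.py. First, the array has shape (num_jets, num_particles, num_features). Second, the relative p_T sits in column 2, with zero-p_T particles acting as padding. The helper names are hypothetical.

```python
import numpy as np


def load_generated(path):
    """Load a generated-jet array saved with np.save (e.g. gen_jets.npy)."""
    return np.load(path)


def summarize_jets(jets, pt_index=2):
    """Basic shape and multiplicity statistics for a (jets, particles, features) array."""
    jets = np.asarray(jets)
    if jets.ndim != 3:
        raise ValueError(f"expected a 3D array, got shape {jets.shape}")
    # particles with zero pT are treated as zero-padding
    num_real = (jets[..., pt_index] > 0).sum(axis=1)
    return {
        "num_jets": jets.shape[0],
        "max_particles": jets.shape[1],
        "num_features": jets.shape[2],
        "mean_particles_per_jet": float(num_real.mean()),
    }


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        print(summarize_jets(load_generated(sys.argv[1])))
```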
