en_flows's Introduction

E(n) Equivariant Normalizing Flows

Official implementation (PyTorch 1.7.1) of:

E(n) Equivariant Normalizing Flows
Victor Garcia Satorras*, Emiel Hoogeboom*, Fabian Fuchs, Ingmar Posner, Max Welling
https://arxiv.org/abs/2105.09016

Abstract: This paper introduces a generative model equivariant to Euclidean symmetries: E(n) Equivariant Normalizing Flows (E-NFs). To construct E-NFs, we take the discriminative E(n) graph neural networks and integrate them as a differential equation to obtain an invertible equivariant function: a continuous-time normalizing flow. We demonstrate that E-NFs considerably outperform baselines and existing methods from the literature on particle systems such as DW4 and LJ13, and on molecules from QM9 in terms of log-likelihood. To the best of our knowledge, this is the first flow that jointly generates molecule features and positions in 3D.
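As a rough illustration of the idea in the abstract, integrating an equivariant vector field over time gives an equivariant map. The sketch below is not the EGNN dynamics from the paper: `velocity` is a hypothetical toy field built from relative positions, and `flow` uses a fixed-step Euler integrator chosen only for simplicity. It checks numerically that rotating and translating the input commutes with the flow.

```python
import numpy as np


def velocity(x):
    """Toy E(n)-equivariant vector field (an assumption, not the paper's EGNN).

    x has shape (n_particles, n_dims). Each particle moves along the relative
    vectors x_i - x_j, weighted by a function of the squared distance.
    """
    diff = x[:, None, :] - x[None, :, :]                 # (n, n, d) relative positions
    dist2 = np.sum(diff ** 2, axis=-1, keepdims=True)    # (n, n, 1) squared distances
    weight = np.exp(-dist2)                              # invariant scalar weights
    return np.sum(weight * diff, axis=1)                 # (n, d) velocities


def flow(x, n_steps=100, t1=1.0):
    """Integrate dz/dt = velocity(z) from t=0 to t=t1 with explicit Euler steps."""
    h = t1 / n_steps
    z = x.copy()
    for _ in range(n_steps):
        z = z + h * velocity(z)
    return z


def random_orthogonal(d, rng):
    """Random orthogonal matrix (rotation or reflection) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(d, d)))
    return q * np.sign(np.diag(r))


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))      # 4 particles in 2 dimensions
R = random_orthogonal(2, rng)
t = rng.normal(size=(1, 2))      # one translation shared by all particles

transformed_then_flowed = flow(x @ R.T + t)
flowed_then_transformed = flow(x) @ R.T + t
print(np.allclose(transformed_then_flowed, flowed_then_transformed))
```

The field depends only on relative positions and distances, which is why the check holds at every Euler step and not just in the continuous-time limit.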

Requirements

  • PyTorch 1.7.1
  • wandb (Weights and Biases)
  • rdkit-pypi (only for the evaluation scripts)
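A quick way to see whether the packages above are present is to query the installed distributions. This is only a sketch: the function name `check_requirements` is hypothetical, and it assumes the PyPI distribution names `torch`, `wandb` and `rdkit-pypi`.

```python
from importlib import metadata

# PyPI distribution names; "torch" is the PyPI name of PyTorch.
# A value of None means the list above does not pin a version.
REQUIRED = {"torch": "1.7.1", "wandb": None, "rdkit-pypi": None}


def check_requirements(required=REQUIRED, get_version=metadata.version):
    """Return {name: status}, where status is 'ok', 'missing' or 'found <version>'."""
    report = {}
    for name, wanted in required.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        # Builds such as "1.7.1+cu110" carry a local suffix, so compare prefixes.
        if wanted is None or found.startswith(wanted):
            report[name] = "ok"
        else:
            report[name] = f"found {found}"
    return report


if __name__ == "__main__":
    for name, status in check_requirements().items():
        print(f"{name}: {status}")
```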

To log runs under your Weights & Biases account, add --wandb_usr <your username> to any of the experiment commands below.

DW4 commands

GNN

EXP_NAME=exp_dw4_gnn_noatt_noaug python -u main_dw4_lj13.py --exp_name $EXP_NAME --data dw4 --model gnn_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation False --attention False

GNN attention

EXP_NAME=exp_dw4_gnn_att_noaug python -u main_dw4_lj13.py --exp_name $EXP_NAME --data dw4 --model gnn_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation False

GNN attention and augmentation

EXP_NAME=exp_dw4_gnn_att_aug python -u main_dw4_lj13.py --exp_name $EXP_NAME --data dw4 --model gnn_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation True

Simple dynamics

EXP_NAME=exp_dw4_simple_dynamics python -u main_dw4_lj13.py --exp_name $EXP_NAME --data dw4 --model simple_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation False

Kernel dynamics

EXP_NAME=exp_dw4_kernel_dynamics python -u main_dw4_lj13.py --exp_name $EXP_NAME --data dw4 --model kernel_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation False

E-NF

EXP_NAME=exp_dw4_egnn python -u main_dw4_lj13.py --exp_name $EXP_NAME --data dw4 --model egnn_dynamics --sweep_n_data True --lr 5e-4 --n_layers 3 --nf 32

LJ13 dataset

GNN

EXP_NAME=exp_lj13_gnn_noatt_noaug python -u main_dw4_lj13.py --exp_name $EXP_NAME --data lj13 --model gnn_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation False --attention False

GNN attention

EXP_NAME=exp_lj13_gnn_att python -u main_dw4_lj13.py --exp_name $EXP_NAME --data lj13 --model gnn_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation False

GNN attention and augmentation

EXP_NAME=exp_lj13_gnn_att_aug_a python -u main_dw4_lj13.py --exp_name $EXP_NAME --data lj13 --model gnn_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation True

Simple dynamics

EXP_NAME=exp_lj13_simple_dynamics python -u main_dw4_lj13.py --exp_name $EXP_NAME --data lj13 --model simple_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation False

Kernel dynamics

EXP_NAME=exp_lj13_kernel_dynamics python -u main_dw4_lj13.py --exp_name $EXP_NAME --data lj13 --model kernel_dynamics --sweep_n_data True --lr 1e-3 --n_layers 3 --nf 32 --data_augmentation False

E-NF

EXP_NAME=exp_lj13_egnn python -u main_dw4_lj13.py --exp_name $EXP_NAME --data lj13 --model egnn_dynamics --sweep_n_data True --lr 5e-4 --n_layers 3 --nf 32
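The DW4 and LJ13 commands above differ only in --data, --model, --lr and the two optional boolean flags, so they can be generated from a small table. The sketch below only assembles and prints command strings from flags that appear above. The variant labels and the `exp_<data>_<variant>` experiment names are hypothetical and do not match every name used above.

```python
# Per-variant settings copied from the commands above.
# The dictionary keys are hypothetical labels, not repository names.
VARIANTS = {
    "gnn_noatt_noaug": ("gnn_dynamics", "1e-3",
                        ["--data_augmentation", "False", "--attention", "False"]),
    "gnn_att_noaug": ("gnn_dynamics", "1e-3", ["--data_augmentation", "False"]),
    "gnn_att_aug": ("gnn_dynamics", "1e-3", ["--data_augmentation", "True"]),
    "simple_dynamics": ("simple_dynamics", "1e-3", ["--data_augmentation", "False"]),
    "kernel_dynamics": ("kernel_dynamics", "1e-3", ["--data_augmentation", "False"]),
    "egnn": ("egnn_dynamics", "5e-4", []),
}


def build_command(data, variant):
    """Assemble one training command for data in {'dw4', 'lj13'}."""
    if data not in ("dw4", "lj13"):
        raise ValueError(f"unknown dataset: {data}")
    model, lr, extra = VARIANTS[variant]
    exp_name = f"exp_{data}_{variant}"  # hypothetical naming scheme
    args = [
        "python", "-u", "main_dw4_lj13.py",
        "--exp_name", exp_name,
        "--data", data,
        "--model", model,
        "--sweep_n_data", "True",
        "--lr", lr,
        "--n_layers", "3",
        "--nf", "32",
        *extra,
    ]
    return " ".join(args)


if __name__ == "__main__":
    for data in ("dw4", "lj13"):
        for variant in VARIANTS:
            print(build_command(data, variant))
```

Printing the commands instead of launching them keeps the sketch side-effect free; each printed line can be pasted into a shell as is.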

QM9 Positional

E-NF

EXP_NAME=qm9pos_exp_27_egnn_nf64 python -u main_qm9_pos.py --exp_name $EXP_NAME --model egnn_dynamics --nf 64

GNN

EXP_NAME=qm9pos_exp_27_gnn_noatt_nf64 python -u main_qm9_pos.py --exp_name $EXP_NAME --model gnn_dynamics --nf 64 --lr 5e-4 --attention False 

GNN attention

EXP_NAME=qm9pos_exp_27_gnn_att_nf64 python -u main_qm9_pos.py --exp_name $EXP_NAME --model gnn_dynamics --nf 64 --lr 5e-4 --save_model True 

GNN attention and augmentation

EXP_NAME=qm9pos_exp_27_gnn_att_aug_nf64 python -u main_qm9_pos.py --exp_name $EXP_NAME --model gnn_dynamics --nf 64 --lr 5e-4 --save_model True

Kernel dynamics

EXP_NAME=qm9pos_exp_27_kernel_dynamics python -u main_qm9_pos.py --exp_name $EXP_NAME --model kernel_dynamics --lr 5e-4

Simple dynamics

EXP_NAME=qm9pos_exp_27_simple_dynamics_lr2e4 python -u main_qm9_pos.py --exp_name $EXP_NAME --model simple_dynamics --lr 2e-4 

QM9

To train E-NF on QM9

python main_qm9.py --nf 64 --ode_regularization 0.001 --batch_size 128 --dequantizer argmax_variational --lr 5e-4 --exp_name qm9_enf --model egnn_dynamics

To train the best baseline on QM9 (GNN attention and augmentation)

python main_qm9.py --nf 64 --ode_regularization 0.001 --batch_size 128 --dequantizer argmax_variational --lr 5e-4 --exp_name qm9_baseline_augmentation --model gnn_dynamics --data_augmentation True

To resume training on QM9 in case it diverges, use the same commands as before with the two additional arguments below. You will get experiments/<exp_name>_resume_resume

--resume experiments/<exp_name> --start_epoch <last_saved_epoch>

To evaluate a pre-trained model on QM9 (the numbers reported in the paper use n_samples=10000)

python -u eval_analyze_qm9.py --model_path outputs/en_flows_pretrained --n_samples 1000 --val_novel_unique False

To generate samples from a pre-trained model, run

python -u eval_sample_qm9.py --model_path outputs/en_flows_pretrained

Acknowledgements

The Robert Bosch GmbH is acknowledged for financial support.

en_flows's Issues

Why 'log_px = (log_pz + delta_logp - log_qh_x + log_pN)'?

In the loss computation code you define log_px = (log_pz + delta_logp - log_qh_x + log_pN), but the original FFJORD paper has 'log_pz = log_pz - delta_logp'.

Is there a mistake, or did I miss something in the loss computation code?
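One possible reading, offered as a sketch and not as a confirmed answer from the authors: the two expressions can agree if delta_logp is accumulated with opposite signs. For a continuous-time flow with dynamics f that maps a data point x at t = 0 to a latent z(1) at t = 1, the instantaneous change of variables formula gives

```latex
\begin{aligned}
\frac{dz(t)}{dt} &= f\big(z(t), t\big), \qquad z(0) = x, \\
\frac{d \log p\big(z(t)\big)}{dt} &= -\operatorname{Tr}\!\left(\frac{\partial f}{\partial z(t)}\right), \\
\log p_x(x) &= \log p_z\big(z(1)\big) + \int_0^1 \operatorname{Tr}\!\left(\frac{\partial f}{\partial z(t)}\right) dt .
\end{aligned}
```

If delta_logp stores the integral of the trace, it is added to log_pz. If it stores the negated integral, which is how the reference FFJORD code appears to accumulate it, it is subtracted. Which convention this repository follows is an assumption here and should be checked in its ODE function. The terms log_qh_x and log_pN are outside this identity and are not covered by it.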
