
ChemProp for atomic and bond property predictions

This repository contains multitask constrained message passing neural networks for atomic/bond property predictions, as described in the paper Regio-Selectivity Prediction with a Machine-Learned Reaction Representation and On-the-Fly Quantum Mechanical Descriptors. The network is modeled after ChemProp, as described in the paper Analyzing Learned Molecular Representations for Property Prediction.


Requirements

While it is possible to run all of the code on a CPU-only machine, GPUs make training significantly faster. To run with GPUs, you will need:

  • CUDA >= 8.0
  • cuDNN

Installation

Option 1: Conda

The easiest way to install the chemprop dependencies is via conda. Here are the steps:

  1. Install Miniconda from https://conda.io/miniconda.html
  2. cd /path/to/chemprop
  3. conda env create -f environment.yml
  4. source activate chemprop (or conda activate chemprop for newer versions of conda)
  5. (Optional) pip install git+https://github.com/bp-kelley/descriptastorus

The optional descriptastorus package is only necessary if you plan to incorporate computed RDKit features into your model (see Additional Features). The addition of these features improves model performance on some datasets but is not necessary for the base model.

Note that on machines with GPUs, you may need to manually install a GPU-enabled version of PyTorch by following the installation instructions on the PyTorch website.

Data

To train the model, you must provide training data containing molecules (as SMILES strings) and known atomic/bond target values. Targets must be NumPy arrays holding the corresponding atomic/bond properties. The model can be trained on any number of atomic and bond properties.

The data file must be a pickle file containing a pandas DataFrame whose column names serve as the header. For example:

                              smiles                                  hirshfeld_charges  ...                                 bond_length_matrix                                  bond_index_matrix
0     CNC(=S)N/N=C/c1c(O)ccc2ccccc12  [-0.026644, -0.075508, 0.096217, -0.287798, -0...  ...  [[0.0, 1.4372890960937539, 2.4525543850909814,...  [[0.0, 0.9595, 0.0158, 0.0162, 0.0103, 0.0008,...
2      O=C(NCCn1cccc1)c1cccc2ccccc12  [-0.292411, 0.170263, -0.085754, 0.002736, 0.0...  ...  [[0.0, 1.2158509801073485, 2.2520730233154076,...  [[0.0, 1.6334, 0.1799, 0.0086, 0.0068, 0.0002,...
3  C=C(C)[C@H]1C[C@@H]2OO[C@H]1C=C2C  [-0.101749, 0.012339, -0.07947, -0.020027, -0....  ...  [[0.0, 1.3223632546838255, 2.468055985361353, ...  [[0.0, 1.9083, 0.0179, 0.016, 0.0236, 0.001, 0...
4                     OCCCc1cc[nH]n1  [-0.268379, 0.027614, -0.050745, -0.045047, 0....  ...  [[0.0, 1.4018301850170725, 2.4667588956616737,...  [[0.0, 0.9446, 0.0311, 0.002, 0.005, 0.0007, 0...
5      CC(=N)NCc1cccc(CNCc2ccncc2)c1  [-0.083162, 0.114954, -0.274544, -0.100369, 0....  ...  [[0.0, 1.5137126697008916, 2.4882198180715465,...  [[0.0, 1.0036, 0.0437, 0.0108, 0.0134, 0.0004,......

where atomic properties (e.g. hirshfeld_charges) must be 1D NumPy arrays whose entries follow the same order as the atoms in the SMILES string, and bond properties (e.g. bond_length_matrix) must be 2D NumPy arrays of shape (number_of_atoms × number_of_atoms).
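Before training, it can help to check that each row obeys these shape rules. The helper below is a hypothetical sketch (check_molecule_targets is not part of this repository): it only verifies that every atomic target is 1D, every bond target is square 2D, and that all of them agree on the number of atoms. It does not compare against the SMILES atom count, which would require RDKit and depends on whether hydrogens are included (see --explicit_Hs below).

```python
import numpy as np


def check_molecule_targets(atom_targets, bond_targets):
    """Check one molecule's targets against the README's shape rules.

    Hypothetical helper, not part of this repository.
    atom_targets: dict of column name -> per-atom values (must be 1D)
    bond_targets: dict of column name -> atom-by-atom values (must be square 2D)
    Returns the number of atoms implied by the arrays.
    """
    n_atoms = None
    for name, values in atom_targets.items():
        arr = np.asarray(values, dtype=float)
        if arr.ndim != 1:
            raise ValueError(f"{name}: atomic target must be 1D, got shape {arr.shape}")
        if n_atoms is None:
            n_atoms = arr.shape[0]
        elif arr.shape[0] != n_atoms:
            raise ValueError(f"{name}: expected {n_atoms} atoms, got {arr.shape[0]}")
    for name, values in bond_targets.items():
        arr = np.asarray(values, dtype=float)
        if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
            raise ValueError(f"{name}: bond target must be square 2D, got shape {arr.shape}")
        if n_atoms is None:
            n_atoms = arr.shape[0]
        elif arr.shape[0] != n_atoms:
            raise ValueError(f"{name}: expected {n_atoms}x{n_atoms}, got {arr.shape}")
    return n_atoms


# Placeholder values for the three heavy atoms of ethanol ("CCO"); not real QM data.
n = check_molecule_targets(
    {"hirshfeld_charges": [-0.1, 0.05, -0.3]},
    {"bond_length_matrix": np.zeros((3, 3))},
)
print(n)  # 3
```

To check a whole dataset, one could load the file with pandas.read_pickle and call the helper on the target columns of each row.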

Training

To train a model, run:

python train.py --data_path <path> --atom_targets <atom targets> --bond_targets <bond targets>

where <path> is the path to the pickle file containing the dataset, <atom targets> is a list of atomic targets to train on, whose names must match column names in the pickle file, and <bond targets> is a list of bond targets to train on.

For example:

CUDA_VISIBLE_DEVICES=1 python train.py --log_freq 200 --hidden_size 600 --batch_size 50 --epochs 100 --depth 6 --atom_targets hirshfeld_charges hirshfeld_fukui_neu hirshfeld_fukui_elec NMR --atom_constraints 0 1 1 --bond_targets bond_index_matrix bond_length_matrix --save_smiles_splits --save_dir QM_137k_fukui_out_scope --loss_weights 1 1 1 0.00001 1 1 --data_path data/QM_137k_fukui_scope.pickle --explicit_Hs
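In the example above, --loss_weights lists six values for four atomic plus two bond targets, which suggests one weight per target in the order atom targets, then bond targets; the tiny weight 0.00001 falls on NMR. The sketch below assumes that ordering and assumes a mean-squared-error loss per task. It illustrates how per-task weights rebalance a multitask loss; it is not necessarily how train.py computes its loss.

```python
import numpy as np


def weighted_multitask_loss(preds, targets, weights):
    """Weighted sum of per-task mean squared errors (illustrative sketch).

    preds, targets: one array per task, ASSUMED ordered as atom targets then
    bond targets, matching the order of --loss_weights in the example.
    weights: one float per task.
    """
    if not len(preds) == len(targets) == len(weights):
        raise ValueError("need one prediction array, target array and weight per task")
    total = 0.0
    for pred, target, weight in zip(preds, targets, weights):
        diff = np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)
        total += weight * float(np.mean(diff ** 2))  # weighted MSE of this task
    return total


# Hypothetical errors: 0.1 on a charge-like target vs. 100 on a target
# whose values are in the hundreds (e.g. NMR shifts in ppm).
unweighted = weighted_multitask_loss([[0.1], [100.0]], [[0.0], [0.0]], [1.0, 1.0])
weighted = weighted_multitask_loss([[0.1], [100.0]], [[0.0], [0.0]], [1.0, 1e-5])
```

Without weights the large-scale task contributes about 10000 against 0.01 and dominates training; with a weight of 1e-5 both terms land on a comparable scale (0.01 and 0.1).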

Notes:

  • The model supports multi-task constraints on different atomic properties, specified with --atom_constraints.
  • --explicit_Hs trains/predicts on the all-atom molecular graph (including H).
  • When different properties have drastically different scales, the --loss_weights flag is recommended; it scales the loss for each target so that all losses are of similar magnitude.
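In the example command, --atom_constraints 0 1 1 supplies three values, presumably for the first three atom targets (NMR appears unconstrained). One plausible reading, which is an assumption and not stated in this README, is that each value is the required per-molecule sum of that property: Hirshfeld charges of a neutral molecule sum to 0, and charge-based condensed Fukui functions sum to 1. The sketch below enforces such a sum by sharing the deficit across atoms according to per-atom weights; the uniform default is purely illustrative, and the model described in the paper may distribute the correction differently.

```python
import numpy as np


def apply_sum_constraint(raw, total, weights=None):
    """Shift per-atom predictions so that they sum exactly to `total`.

    raw: 1D array of unconstrained per-atom predictions for one molecule.
    total: required sum (e.g. 0 for the charges of a neutral molecule).
    weights: non-negative per-atom shares of the correction, summing to 1;
             uniform if omitted (an illustrative choice, not the repo's method).
    """
    raw = np.asarray(raw, dtype=float)
    if weights is None:
        weights = np.full(raw.shape, 1.0 / raw.size)
    weights = np.asarray(weights, dtype=float)
    if weights.shape != raw.shape or not np.isclose(weights.sum(), 1.0):
        raise ValueError("weights must match raw in shape and sum to 1")
    deficit = total - raw.sum()
    # Each atom absorbs its share of the deficit, so the new sum equals `total`.
    return raw + weights * deficit


charges = apply_sum_constraint([0.2, -0.1, 0.05], 0.0)
```

Because the weights sum to 1, the adjusted values always sum to the requested total, whatever the raw predictions were.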

Train/Validation/Test Splits

Our code supports randomly splitting the data into train, validation, and test sets.

Random: By default, the data will be split randomly into train, validation, and test sets.

Note: By default, the random split assigns 80% of the data to train, 10% to validation, and 10% to test. This can be changed with --split_sizes <train_frac> <val_frac> <test_frac>; the default setting corresponds to --split_sizes 0.8 0.1 0.1. The split involves a random component and can be seeded with --seed <seed>. The default setting is --seed 0.
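A seeded random split of this kind can be sketched as follows. This follows the general shuffle-then-cut approach with the default 0.8/0.1/0.1 fractions; random_split is a hypothetical name, and the sketch is not guaranteed to reproduce the exact indices the repository produces for a given --seed.

```python
import random


def random_split(items, sizes=(0.8, 0.1, 0.1), seed=0):
    """Shuffle a copy of `items` with a seeded RNG and cut it into train/val/test."""
    if len(sizes) != 3 or abs(sum(sizes) - 1.0) > 1e-6:
        raise ValueError("expected three split fractions summing to 1")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # same seed -> same order
    train_end = int(sizes[0] * len(shuffled))
    val_end = int((sizes[0] + sizes[1]) * len(shuffled))
    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]


train_set, val_set, test_set = random_split(range(100), seed=0)
```

Using a dedicated random.Random instance keeps the split reproducible without touching Python's global random state.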

Predicting

To load a trained model and make predictions, run predict.py and specify:

  • --test_path <path> Path to the data to predict on.
  • --checkpoint_path <path> Path to a model checkpoint file (.pt file).
  • --preds_path <path> Path where a pickle file containing the predictions will be saved.

For example:

python predict.py --test_path predict.csv --checkpoint_path trained_model/QM_137k.pt

The predict.csv contains SMILES strings to be predicted. For example:

,smiles,compounds_ID
0,CCC,0
1,CCCCC,1
2,CCCCCC,1
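An input file with this layout (an unnamed index column, then smiles and compounds_ID) can be produced with the standard csv module. The function name below is hypothetical, and the sketch simply mirrors the example; this README does not say whether predict.py requires the compounds_ID column.

```python
import csv
import io


def format_prediction_input(smiles_list, compound_ids):
    """Return CSV text laid out like the predict.csv example above.

    Hypothetical helper: unnamed index column, then smiles and compounds_ID.
    """
    if len(smiles_list) != len(compound_ids):
        raise ValueError("need one compound ID per SMILES string")
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerow(["", "smiles", "compounds_ID"])  # empty first header cell
    for index, (smiles, compound_id) in enumerate(zip(smiles_list, compound_ids)):
        writer.writerow([index, smiles, compound_id])
    return buffer.getvalue()


text = format_prediction_input(["CCC", "CCCCC"], [0, 1])
```

The resulting string can be written to predict.csv and passed to predict.py via --test_path.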

Trained model

We provide the checkpoint file for the model trained on 137k molecules curated from PubChem and Pistachio, as described in our paper, in this repo (trained_model/QM_137k.pt). The trained model can be used either through the predict.py script discussed in the Predicting section, or through the API/torchserve provided in the torchserve folder.

API

The prediction function is wrapped in the torchserve/handler.py script. Please refer to the __name__ == '__main__' section of handler.py for details.

torchserve

The trained model is also accessible through TorchServe. A .mar file containing everything needed to start TorchServe is provided at torchserve/model_store/descriptors.mar. It was generated with:

torch-model-archiver --model-name descriptors --version 1.0 --serialized-file QM_137k.pt --handler handler.py --export-path model_store --extra-files model.py,mpn.py,ffn.py,featurization,nn_utils.py

To start TorchServe:

torchserve --start --ncs --model-store model_store --models descriptors.mar
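Once TorchServe is running, clients talk to its REST APIs, which by default listen on port 8080 (inference) and 8081 (management); the model registered from descriptors.mar is named descriptors. The sketch below only builds the requests. The payload format accepted by handler.py is not documented in this README, so the inference body is left to the caller; check the handler's preprocessing code before sending real data. The function names are hypothetical.

```python
import urllib.request

# TorchServe default addresses; adjust if config.properties overrides them.
INFERENCE_URL = "http://localhost:8080/predictions/descriptors"
MANAGEMENT_URL = "http://localhost:8081/models"


def build_inference_request(payload: bytes) -> urllib.request.Request:
    """POST request for the `descriptors` model.

    The payload must match whatever torchserve/handler.py expects;
    its format is an open assumption here.
    """
    return urllib.request.Request(INFERENCE_URL, data=payload, method="POST")


def build_list_models_request() -> urllib.request.Request:
    """GET request that lists models registered with the management API."""
    return urllib.request.Request(MANAGEMENT_URL, method="GET")


inference_request = build_inference_request(b"...")  # placeholder body
```

With the server up, urllib.request.urlopen(build_list_models_request()) should return a JSON listing of the registered models, which is a quick way to confirm that descriptors loaded before sending inference requests.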

Contributors

lhirschfeld, mbergins, swansonk14, wengong-jin, yanfeiguan, yangkevin2
