
3DInfomax's Introduction

3D Infomax improves GNNs for Molecular Property Prediction

We pre-train GNNs to understand the geometry of molecules given only their 2D molecular graphs, which they can then use for better molecular property predictions. Below is a 3-step guide for how to use the code and reproduce our results, plus a guide for creating molecular fingerprints. If you have questions, don't hesitate to open an issue or ask me via [email protected] or social media. I am happy to hear from you!

This repository additionally adapts different self-supervised learning methods to graphs, such as "Bootstrap Your Own Latent", "Barlow Twins", and "VICReg".
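For intuition, here is a minimal, self-contained sketch of one of these objectives, the Barlow Twins loss, written from the published formula rather than taken from this repository's code:

import torch

def barlow_twins_loss(z1: torch.Tensor, z2: torch.Tensor, lambd: float = 5e-3) -> torch.Tensor:
    # z1, z2: embeddings of two views of the same batch, shape [batch, dim]
    n = z1.shape[0]
    # standardize each embedding dimension across the batch
    z1 = (z1 - z1.mean(0)) / (z1.std(0) + 1e-8)
    z2 = (z2 - z2.mean(0)) / (z2.std(0) + 1e-8)
    c = (z1.T @ z2) / n  # cross-correlation matrix, shape [dim, dim]
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()               # invariance term
    off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()  # redundancy reduction
    return on_diag + lambd * off_diag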

Generating fingerprints for arbitrary SMILES

To generate fingerprints that carry 3D information, just set up the environment as in Step 1 below, then place your SMILES in the file dataset/inference_smiles.txt and run:

python inference.py --config=configs_clean/fingerprint_inference.yml

Your fingerprints are saved as a pickle file in the dataset directory.
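The exact contents of the pickle depend on the config, but as a minimal sketch you can load it like this (the file name below is a placeholder; check the dataset directory for the file the inference run actually wrote):

import pickle

with open('dataset/fingerprints.pkl', 'rb') as f:  # placeholder path
    fingerprints = pickle.load(f)
print(type(fingerprints), len(fingerprints))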

Step 1: Setup Environment

We will set up the environment using Anaconda. Clone the current repo

git clone https://github.com/HannesStark/3DInfomax

Create a new environment with all required packages using environment.yml (this can take a while). While in the project directory run:

conda env create

Activate the environment

conda activate 3DInfomax

Step 2: 3D Pre-train a model

Let's pre-train a GNN with 50 000 molecules and their structures from the QM9 dataset (you can also skip to Step 3 and use the pre-trained model weights provided in this repo). For other datasets see the Data section below.

python train.py --config=configs_clean/pre-train_QM9.yml

This will first create the processed data from dataset/QM9/qm9.csv with the 3D information in qm9_eV.npz. Then your model starts pre-training, and all logs are saved in the runs folder, which will also contain the pre-trained model as best_checkpoint.pt that can later be loaded for fine-tuning.
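If you want to peek inside such a checkpoint, here is a hedged sketch (the run directory name is generated at training time, so the path below is a placeholder, and the exact keys stored in the checkpoint may differ):

import torch

run_dir = 'runs/<your_pretraining_run>'  # placeholder: created during pre-training
checkpoint = torch.load(run_dir + '/best_checkpoint.pt', map_location='cpu')
print(list(checkpoint.keys()))  # e.g. model weights and training metadata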

You can start TensorBoard and navigate to localhost:6006 in your browser to monitor the training process:

tensorboard --logdir=runs --port=6006

Explanation:

The config files in configs_clean provide additional examples and blueprints for training different models. Each file contains a model_type for the network that should be pre-trained (the 2D network) and a model3d_type (the 3D network), where you can specify the parameters of these networks. To find out more about all the other parameters in the config file, have a look at their descriptions by running python train.py --help.
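For example, you can inspect which networks a config selects with a few lines of Python (a sketch assuming PyYAML is available, since the .yml configs are standard YAML):

import yaml

with open('configs_clean/pre-train_QM9.yml') as f:
    config = yaml.safe_load(f)
# the 2D network that gets pre-trained and the 3D network it is contrasted with
print(config['model_type'], config['model3d_type'])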

Step 3: Fine-tune a model

During pre-training, a directory is created inside the runs directory that contains the pre-trained model. We provide an example of such a directory with already pre-trained weights, runs/PNA_qmugs_NTXentMultiplePositives_620000_123_25-08_09-19-52, which we can fine-tune for predicting QM9's homo property as follows:

python train.py --config=configs_clean/tune_QM9_homo.yml

You can monitor the fine-tuning process on TensorBoard as well. In the end, the results are printed to the console and also saved in the file evaluation_test.txt inside the runs directory that was created for fine-tuning.

The model we fine-tune from is specified in configs_clean/tune_QM9_homo.yml via the parameter:

pretrain_checkpoint: runs/PNA_qmugs_NTXentMultiplePositives_620000_123_25-08_09-19-52/best_checkpoint_35epochs.pt

Multiple seeds:

This is a second fine-tuning example where we predict non-quantum properties of the OGB datasets and train multiple seeds (we always use the seeds 1, 2, 3, 4, 5, 6 in our experiments):

python train.py --config=configs_clean/tune_freesolv.yml

After all runs are done, the averaged results are saved in the runs directory of each seed, in the file multiple_seed_test_statistics.txt.
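Conceptually, this aggregation just averages each test metric over the per-seed runs; a sketch with made-up placeholder numbers:

import statistics

per_seed_metric = {1: 1.32, 2: 1.28, 3: 1.35, 4: 1.30, 5: 1.29, 6: 1.33}  # hypothetical values
mean = statistics.mean(per_seed_metric.values())
std = statistics.stdev(per_seed_metric.values())
print(f'test metric over {len(per_seed_metric)} seeds: {mean:.3f} +/- {std:.3f}')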

Data

You can pre-train or fine-tune on different datasets by specifying the dataset: parameter in a .yml file such as dataset: drugs to use GEOM-Drugs.

The QM9 dataset and the OGB datasets are already provided with this repository. The QMugs and GEOM-Drugs datasets need to be downloaded and placed in the correct location.

GEOM-Drugs: Download GEOM-Drugs here (the rdkit_folder.tar.gz file), unzip it, and place it into dataset/GEOM.

QMugs: Download QMugs here (the structures.tar and summary.csv files), extract structures.tar, and place the resulting structures folder and the summary.csv file into a new folder QMugs that you have to create NEXT TO the repository root, not in the repository root (sorry for this).
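A quick sanity check of this layout, run from the repository root (a sketch assuming the GEOM-Drugs tarball extracts to a folder named rdkit_folder):

import os

assert os.path.isdir('dataset/GEOM/rdkit_folder'), 'GEOM-Drugs goes inside the repo'
assert os.path.isdir('../QMugs/structures'), 'QMugs goes NEXT TO the repo root'
assert os.path.isfile('../QMugs/summary.csv'), 'summary.csv goes next to structures'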

Reference

📃 Paper on arXiv

@article{stark2021_3dinfomax,
  title={3D Infomax improves GNNs for Molecular Property Prediction},
  author={Hannes Stärk and Dominique Beaini and Gabriele Corso and Prudencio Tossou and Christian Dallago and Stephan Günnemann and Pietro Liò},
  journal={arXiv preprint arXiv:2110.04126},
  year={2021}
}


3DInfomax's Issues

Fine-tuning a model with `BACEGeomol` dataset - error when collating

Hello, I am trying to fine-tune the pre-trained model you provided here (runs/PNA_qmugs_NTXentMultiplePositives_620000_123_25-08_09-19-52/best_checkpoint_35epochs.pt) with a CSV of the BACE dataset. When I use the BACEGeomol class to load and process the data, I get the following error when trying to collate the data with torch_geometric.

Traceback (most recent call last):
  File "train.py", line 702, in <module>
    train(args)
  File "train.py", line 286, in train
    return train_geomol(args, device, metrics_dict)
  File "train.py", line 313, in train_geomol
    train = dataset(split='train', device=device)
  File "/home/ubuntu/code/3DInfomax/datasets/bace_geomol_feat.py", line 59, in __init__
    super(BACEGeomol, self).__init__(root, transform, pre_transform)
  File "/home/ubuntu/anaconda3/envs/3DInfomax/lib/python3.7/site-packages/torch_geometric/data/in_memory_dataset.py", line 57, in __init__
    super().__init__(root, transform, pre_transform, pre_filter)
  File "/home/ubuntu/anaconda3/envs/3DInfomax/lib/python3.7/site-packages/torch_geometric/data/dataset.py", line 88, in __init__
    self._process()
  File "/home/ubuntu/anaconda3/envs/3DInfomax/lib/python3.7/site-packages/torch_geometric/data/dataset.py", line 171, in _process
    self.process()
  File "/home/ubuntu/code/3DInfomax/datasets/bace_geomol_feat.py", line 127, in process
    data, slices = self.collate(data_list)
  File "/home/ubuntu/anaconda3/envs/3DInfomax/lib/python3.7/site-packages/torch_geometric/data/in_memory_dataset.py", line 116, in collate
    add_batch=False,
  File "/home/ubuntu/anaconda3/envs/3DInfomax/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 85, in collate
    increment)
  File "/home/ubuntu/anaconda3/envs/3DInfomax/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 179, in _collate
    key, [v[key] for v in values], data_list, stores, increment)
  File "/home/ubuntu/anaconda3/envs/3DInfomax/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 179, in <listcomp>
    key, [v[key] for v in values], data_list, stores, increment)
KeyError: 0

My questions are:

  1. I see there are a lot of different dataset classes. Can I use the BACEGeomol dataset class to fine-tune that model, or should I be using a different dataset class? I'm also not sure because I see different functions to featurize molecules in the different classes.
  2. Have you seen this error before, and do you know what might be causing the bug?

Thank you!

Help

Can someone provide the correct environment information, including the CUDA version, RDKit version, OGB version, and Python version?

ResolvePackageNotFound in Windows

Hello, when using Anaconda on Windows 10, I ran into the following error:

ResolvePackageNotFound:

  • dgl-cuda10.2
  • torchaudio
  • pytorch-geometric
  • torchvision

Does this mean I really need a computer with an NVIDIA card, so that I can install CUDA and cuDNN and then install the packages above?
Many thanks.

generating a fingerprint

Hi,
I find the code a bit difficult to read. Is there an easy way to generate the embedding as a fingerprint?

classification yml file

Could you provide another .yml file for classification problems such as BBBP or Tox21?

DglPCQM4MDataset ImportError for inference

Hello!

Hope you are well! When I run inference on the included example, I get the following error:

ImportError: cannot import name 'DglPCQM4MDataset' from 'ogb.lsc' 

ogb is installed and is version 1.3.3. Any thoughts as to why this error would occur?
Thanks.

pretrain loss with negative value

Hello everyone! I need help. When I re-run pre-training with the command

python train.py --config=configs_clean/pre-train_QM9.yml

the loss takes negative values starting from the second epoch. Here is the loss in epoch 200 and the final result:
[Epoch 199] NTXent: -1.781640 val loss: -1.781640 [Epoch 200; Iter 2/ 100] train: loss: -2.2522569 [Epoch 200; Iter 4/ 100] train: loss: -2.2169321 [Epoch 200; Iter 6/ 100] train: loss: -2.2570577 [Epoch 200; Iter 8/ 100] train: loss: -2.2113991 [Epoch 200; Iter 10/ 100] train: loss: -2.2871132 [Epoch 200; Iter 12/ 100] train: loss: -2.1956975 [Epoch 200; Iter 14/ 100] train: loss: -2.2575712 [Epoch 200; Iter 16/ 100] train: loss: -2.2456131 [Epoch 200; Iter 18/ 100] train: loss: -2.2454598 [Epoch 200; Iter 20/ 100] train: loss: -2.2255766 [Epoch 200; Iter 22/ 100] train: loss: -2.2309322 [Epoch 200; Iter 24/ 100] train: loss: -2.2573645 [Epoch 200; Iter 26/ 100] train: loss: -2.2338004 [Epoch 200; Iter 28/ 100] train: loss: -2.2719038 [Epoch 200; Iter 30/ 100] train: loss: -2.1639729 [Epoch 200; Iter 32/ 100] train: loss: -2.1668520 [Epoch 200; Iter 34/ 100] train: loss: -2.2000210 [Epoch 200; Iter 36/ 100] train: loss: -2.2143204 [Epoch 200; Iter 38/ 100] train: loss: -2.1709681 [Epoch 200; Iter 40/ 100] train: loss: -2.1966450 [Epoch 200; Iter 42/ 100] train: loss: -2.2163792 [Epoch 200; Iter 44/ 100] train: loss: -2.2385902 [Epoch 200; Iter 46/ 100] train: loss: -2.2735734 [Epoch 200; Iter 48/ 100] train: loss: -2.2653208 [Epoch 200; Iter 50/ 100] train: loss: -2.2500036 [Epoch 200; Iter 52/ 100] train: loss: -2.2324305 [Epoch 200; Iter 54/ 100] train: loss: -2.2250788 [Epoch 200; Iter 56/ 100] train: loss: -2.2158983 [Epoch 200; Iter 58/ 100] train: loss: -2.2318683 [Epoch 200; Iter 60/ 100] train: loss: -2.2518740 [Epoch 200; Iter 62/ 100] train: loss: -2.2250538 [Epoch 200; Iter 64/ 100] train: loss: -2.2269733 [Epoch 200; Iter 66/ 100] train: loss: -2.2310219 [Epoch 200; Iter 68/ 100] train: loss: -2.2161174 [Epoch 200; Iter 70/ 100] train: loss: -2.2205598 [Epoch 200; Iter 72/ 100] train: loss: -2.2224882 [Epoch 200; Iter 74/ 100] train: loss: -2.2406216 [Epoch 200; Iter 76/ 100] train: loss: -2.1987047 [Epoch 200; Iter 78/ 100] train: loss: -2.2273459 [Epoch 200; Iter 80/ 100] train: loss: -2.2188897 [Epoch 200; Iter 82/ 100] train: loss: -2.2317674 [Epoch 200; Iter 84/ 100] train: loss: -2.2398126 [Epoch 200; Iter 86/ 100] train: loss: -2.2111089 [Epoch 200; Iter 88/ 100] train: loss: -2.2469833 [Epoch 200; Iter 90/ 100] train: loss: -2.2543628 [Epoch 200; Iter 92/ 100] train: loss: -2.2504985 [Epoch 200; Iter 94/ 100] train: loss: -2.1923199 [Epoch 200; Iter 96/ 100] train: loss: -2.2006516 [Epoch 200; Iter 98/ 100] train: loss: -2.2204037 [Epoch 200; Iter 100/ 100] train: loss: -2.2411747 [Epoch 200] NTXent: -1.799788 val loss: -1.799788 Early stopping criterion based on -NTXent- that should be min reached after 200 epochs. Best model checkpoint was in epoch 165. 
Statistics on val_best_checkpoint:
  positive_similarity: 0.9688039703501595
  negative_similarity: 0.49956806831889683
  contrastive_accuracy: 0.759679623776012
  true_negative_rate: 0.5193592525190778
  true_positive_rate: 1.0
  uniformity: -4.5691987540986805
  alignment: 2.764178216457367
  batch_variance: 0.11491788489123185
  dimension_covariance: 0.0008367771305428403
  NTXent: -1.797559466626909
  mean_pred: -0.0002646076123306153
  std_pred: 0.10940126019219558
  mean_targets: 1.0975488838956457e-05
  std_targets: 0.006032964382838044

Statistics on test:
  positive_similarity: 0.9683920939763387
  negative_similarity: 0.499461951079192
  contrastive_accuracy: 0.7599385954715587
  true_negative_rate: 0.519951272893835
  true_positive_rate: 0.9999259268796002
  uniformity: -4.5703361829121905
  alignment: 2.7638528523621737
  batch_variance: 0.11491100241740544
  dimension_covariance: 0.0008471731475933835
  NTXent: -1.8462339418905753
  mean_pred: -0.0003244010140743167
  std_pred: 0.10939329227915516
  mean_targets: 5.363229907591934e-06
  std_targets: 0.006017989599732337
I just cloned the code from the repository without modification.
Can someone tell me where the mistake is?

Pretrained 3d model

Hello,
After reading the paper and part of the code (which may be difficult for me to understand), my understanding is that you have a 3D model but it is not pre-trained, right?
If I misunderstood, could you tell me how I should use the 3D pre-trained model?
Thanks for your help!

Some questions about 3DInfomax

Dear professor, I have some questions about 3DInfomax.
I want to compute evaluation metrics such as precision, so I used the functions you provide in metric.py, such as TruePositiveRate() and TrueNegativeRate(). But I tried all the OGB datasets and found that metrics such as precision, accuracy, and recall were not ideal. I hope you can reply to me as soon as possible. Thank you, professor.

Here are the HIV dataset's metrics:
Precision: 0.008995866402983665
Accuracy: 0.9988852739334106
Recall: 0.002496626228094101
F1_score: 0.003908519633114338
ROC_AUC: 0.7427065372467041
PR_AUC: 0.2141391634941101
ogbg-molhiv: 0.742706502636204
BCEWithLogitsLoss: 0.17792926660992883

Here are the BBBP dataset's metrics:
Precision: 0.44607841968536377
Accuracy: 0.6127931475639343
Recall: 0.005654983688145876
F1_score: 0.011168383993208408
ROC_AUC: 0.6745756268501282
PR_AUC: 0.6546612977981567
ogbg-molbbbp: 0.6745756172839505
BCEWithLogitsLoss: 1.1453146849359785

Here is my metric code:
import torch
import torch.nn as nn
from torch import Tensor


class Precision(nn.Module):
    def __init__(self, threshold=0.5) -> None:
        super(Precision, self).__init__()
        self.threshold = threshold

    def forward(self, x1: Tensor, x2: Tensor, pos_mask: Tensor = None) -> Tensor:
        batch_size, _ = x1.size()
        if x1.shape != x2.shape and pos_mask is None:
            x2 = x2[:batch_size]
        sim_matrix = torch.einsum('ik,jk->ij', x1, x2)

        # cosine-normalize the similarity matrix
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)
        sim_matrix = sim_matrix / torch.einsum('i,j->ij', x1_abs, x2_abs)

        preds: Tensor = (sim_matrix + 1) / 2 > self.threshold
        if pos_mask is None:  # if we are comparing global with global
            pos_mask = torch.eye(batch_size, device=x1.device)
        neg_mask = 1 - pos_mask  # moved out of the if-branch so it is always defined

        num_positives = len(x1)
        num_negatives = len(x1) * (len(x2) - 1)

        false_positives = ((preds.long() - pos_mask) * pos_mask).count_nonzero()
        true_positives = num_positives - ((preds.long() - pos_mask) * pos_mask).count_nonzero()

        false_negatives = (((~preds).long() - neg_mask) * neg_mask).count_nonzero()
        true_negatives = num_negatives - (((~preds).long() - neg_mask) * neg_mask).count_nonzero()

        pre = true_positives / (true_positives + false_positives)
        return pre
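For what it's worth, a hypothetical smoke test of the class above on random embeddings (shapes made up), just to check that it runs:

x1 = torch.randn(32, 128)
x2 = torch.randn(32, 128)
print(Precision(threshold=0.5)(x1, x2))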

Fine-tuning a model with moltox21 dataset - error

Hello, I am trying to fine-tune the pre-trained model you provided here (runs/PNA_qmugs_NTXentMultiplePositives_620000_123_25-08_09-19-52/best_checkpoint_35epochs.pt) with the Tox21 dataset. I changed the parameters in configs_clean/tune_freesolv.yml, setting the target dimension to 12, the number of classes in Tox21. However, the model loss is nan.

[Epoch 1; Iter   150/  209] train: loss: nan
[Epoch 1; Iter   180/  209] train: loss: nan
[Epoch 1; Iter   180/  209] train: loss: nan
[Epoch 1; Iter   180/  209] train: loss: nan
[Epoch 1; Iter   180/  209] train: loss: nan
[Epoch 1; Iter   180/  209] train: loss: nan
...
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').
  1. Could you provide another .yml file for classification problems such as Tox21 or SIDER?
  2. Could you provide the evaluation metrics used for classification tasks, such as ROC-AUC or F1-score?

Thank you!

Having trouble pre-training with example code

Hi,
After installing all the required packages, I followed Step 2 in the README to run the following code:

python train.py --config=configs_clean/pre-train_QM9.yml

However, I got the following error:
Traceback (most recent call last):
  File "train.py", line 699, in <module>
    train(args)
  File "train.py", line 270, in train
    return train_qm9(args, device, metrics_dict)
  File "train.py", line 562, in train_qm9
    dist_embedding=args.dist_embedding, num_radial=args.num_radial)
  File "/data2/3DInfomax/datasets/qm9_dataset.py", line 187, in __init__
    self.dist_embedder = dist_emb(num_radial=6).to(device)
  File "/data2/3DInfomax/commons/spherical_encoding.py", line 183, in __init__
    self.reset_parameters()
  File "/data2/3DInfomax/commons/spherical_encoding.py", line 186, in reset_parameters
    torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI)
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

I didn't modify the code. Any idea what's causing the aforementioned error?

having trouble training for GEOM-Mol + trained models

File "/home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/nn/functional.py", line 1753, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: CUDA out of memory. Tried to allocate 410.00 MiB (GPU 0; 11.17 GiB total capacity; 9.92 GiB already allocated; 336.44 MiB free; 10.30 GiB reserved in total by PyTorch)
Any idea?

Also, would it be possible for you to put up trained models for both QM9 and GEOM-Drugs?

Linked video freezes

Hi, this doesn't concern the repo, but the linked talk (YouTube video) freezes after ~22 minutes. Is that the complete recording, or is there a fix? Thanks.
