meta-mgnn's Introduction

Few-shot Graph Learning for Molecular Property Prediction

Introduction

This is the source code and dataset for the following paper:

Few-shot Graph Learning for Molecular Property Prediction. In WWW 2021.

Contact Zhichun Guo ([email protected]) if you have any questions.

Datasets

The uploaded datasets can be downloaded and used to train our model directly.

The original datasets are downloaded from Data. We use Original_datasets/splitdata.py to split each dataset according to its molecular properties and save the splits in separate files under Original_datasets/[DatasetName]/new. When main.py is run, the datasets are automatically preprocessed by loader.py and the preprocessed results are saved in Original_datasets/[DatasetName]/new/[PropertyNumber]/processed.
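Conceptually, the per-property split can be sketched as follows (a toy sketch, not the actual splitdata.py; the column names, the 0/1 label encoding, and the use of blank fields for missing labels are assumptions):

```python
import csv
from io import StringIO

def split_by_property(csv_text, smiles_col, task_cols):
    """Group rows by task: a row belongs to a task iff its label
    for that task is non-empty (multi-task datasets leave blanks
    for molecules that are unlabeled on a given property)."""
    reader = csv.DictReader(StringIO(csv_text))
    per_task = {t: [] for t in task_cols}
    for row in reader:
        for t in task_cols:
            if row[t].strip() != "":
                per_task[t].append((row[smiles_col], int(float(row[t]))))
    return per_task

# Toy example with two tasks; blanks mean "no label for this task".
toy = "smiles,T1,T2\nCCO,1,\nCCN,0,1\nCCC,,0\n"
splits = split_by_property(toy, "smiles", ["T1", "T2"])
```

Each per-task list would then be written to its own file under Original_datasets/[DatasetName]/new.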

Usage

Installation

We used the following Python packages for development, with Python 3.6.

- torch = 1.4.0
- torch-geometric = 1.6.1
- torch-scatter = 2.0.4
- torch-sparse = 0.6.1
- scikit-learn = 0.23.2
- tqdm = 4.50.0
- rdkit

Run code

The dataset and k (for k-shot learning) can be changed in the last line of main.py.

python main.py

Performance

Meta-learning performance is not stable for some properties. For reference, we report the results of two runs, together with the iteration at which the best result was obtained.

Dataset   k   Iteration   Property   Results        k   Iteration   Property   Results
-------   -   ---------   --------   -----------    -   ---------   --------   -----------
Sider     1   307/599     Si-T1      75.08/75.74    5   561/585     Si-T1      76.16/76.47
                          Si-T2      69.44/69.34                    Si-T2      68.90/69.77
                          Si-T3      69.90/71.39                    Si-T3      72.23/72.35
                          Si-T4      71.78/73.60                    Si-T4      74.40/74.51
                          Si-T5      79.40/80.50                    Si-T5      81.71/81.87
                          Si-T6      71.59/72.35                    Si-T6      74.90/73.34
                          Ave.       72.87/73.82                    Ave.       74.74/74.70
Tox21     1   1271/1415   SR-HS      73.72/73.90    5   1061/882    SR-HS      74.85/74.74
                          SR-MMP     78.56/79.62                    SR-MMP     80.25/80.27
                          SR-p53     77.50/77.91                    SR-p53     78.86/79.14
                          Ave.       76.59/77.14                    Ave.       77.99/78.05

Acknowledgements

The code is implemented based on Strategies for Pre-training Graph Neural Networks.

Reference

@article{guo2021few,
  title={Few-Shot Graph Learning for Molecular Property Prediction},
  author={Guo, Zhichun and Zhang, Chuxu and Yu, Wenhao and Herr, John and Wiest, Olaf and Jiang, Meng and Chawla, Nitesh V},
  journal={arXiv preprint arXiv:2102.07916},
  year={2021}
}

meta-mgnn's People

Contributors

zhichunguo


meta-mgnn's Issues

After 1000 epochs, the training loss and test accuracy remain the same?

Hi @zhichunguo

I am facing a challenge with Meta-MGNN. I trained the model for 1000 epochs, but the loss and accuracy remain the same, which suggests the meta-model is not training at all. I am applying your code to my own data with 79 node features; I also tried the features from your paper and the result is similar. I am wondering whether I am making a mistake somewhere...

DEBUG:root:Epoch: 0	Train Loss: 0.8704835772514343	Test Accuracy: 43.87436384313247	
DEBUG:root:Epoch: 1	Train Loss: 0.32417869567871094	Test Accuracy: 49.572593688964844	
DEBUG:root:Epoch: 2	Train Loss: 0.09464379400014877	Test Accuracy: 42.63082504272461	
DEBUG:root:Epoch: 3	Train Loss: 0.07533042877912521	Test Accuracy: 54.545582715202784	
DEBUG:root:Epoch: 4	Train Loss: 0.11637484282255173	Test Accuracy: 44.320564943201404	
DEBUG:root:Epoch: 5	Train Loss: 0.034378036856651306	Test Accuracy: 54.640758626601276	
DEBUG:root:Epoch: 6	Train Loss: 0.03435908630490303	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 7	Train Loss: 0.03385530039668083	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 8	Train Loss: 0.03292354196310043	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 9	Train Loss: 0.0315357968211174	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 10	Train Loss: 0.03256123140454292	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 11	Train Loss: 0.03294479474425316	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 12	Train Loss: 0.032359760254621506	Test Accuracy: 50.21079287809484	
.
.
.

DEBUG:root:Epoch: 994	Train Loss: 0.031872864812612534	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 995	Train Loss: 0.03162331134080887	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 996	Train Loss: 0.03111858479678631	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 997	Train Loss: 0.03308967500925064	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 998	Train Loss: 0.031367577612400055	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 999	Train Loss: 0.03169362619519234	Test Accuracy: 50.21079287809484	
DEBUG:root:Epoch: 1000	Train Loss: 0.0316644087433815	Test Accuracy: 50.21079287809484	
INFO:root:Training has finished.

What I did includes the following:

  • I extracted 79 node features and 10 edge features from my data
  • I want to run plain MAML first, so I set

parser.add_argument('--add_similarity', type=bool, default=False)
parser.add_argument('--add_selfsupervise', type=bool, default=False)
parser.add_argument('--interact', type=bool, default=False)

and

parser.add_argument('--num_tasks', type=int, default= 33, help='number of tasks')
parser.add_argument('--num_train_tasks', type=int, default= 16, help='number of meta-training task')
parser.add_argument('--num_test_tasks', type=int, default= 17, help='number of meta-testing tasks')
parser.add_argument('--n_ways', type=int, default=2, help='n-ways of dataset')
parser.add_argument('--m_support', type=int, default=10, help='size of support set')
parser.add_argument('--k_query', type=int, default=20, help='size of query set')
parser.add_argument('--meta_lr', type=float, default=0.001,
                    help='meta learning rate (i.e., outer lr:  beta) (default: 0.001)')
parser.add_argument('--update_lr', type=float, default=0.005,
                    help='task-specific learning rate (inner lr: alpha) (default: 0.04)')
parser.add_argument('--batch_size', type=int, default= 32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default= 1000,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--lr_scale',type=float, default=0.1,
                    help='relative learning rate for feature extraction layer (default: 1)')
parser.add_argument('--decay', type=float, default=0,
                    help='weight decay (default: 0)')
parser.add_argument('--num_layer' ,type=int, default=7,
                    help='number of GNN message passing layer (default: 5)')
parser.add_argument('--emb_dim', type=int, default = 256,
                    help='embedding dimension (default: 300)', choices=[64, 300])

loader.py bond features and possible values for initializations

Can you please explain why, in line 12 of loader.py, the number of possible bond types is 6,
and why in line 43

self_loop_attr[:,0] = 4 #bond type for self-loop edge

the value is equal to 4?

According to the bond features,

   'possible_bonds' : [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
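For what it's worth, my reading of loaders in this code family (it builds on the pre-training-GNNs loader) is that the four RDKit bond types occupy indices 0-3, index 4 is reserved for self-loop edges, and index 5 for masked bonds, which would explain num_bond_type = 6. A minimal sketch under that assumption:

```python
import torch

# Indices 0-3: SINGLE, DOUBLE, TRIPLE, AROMATIC (from possible_bonds).
# Index 4 is assumed to be the self-loop edge type and index 5 the
# masked-bond token, giving num_bond_type = 6 in total.
NUM_BOND_TYPE = 6
SELF_LOOP_TYPE = 4

def add_self_loops(edge_attr, num_nodes, num_edge_features=2):
    """Append one self-loop edge attribute per node, mirroring the
    loader.py pattern self_loop_attr[:, 0] = 4."""
    self_loop_attr = torch.zeros(num_nodes, num_edge_features)
    self_loop_attr[:, 0] = SELF_LOOP_TYPE  # bond type for self-loop edge
    return torch.cat([edge_attr, self_loop_attr], dim=0)

edge_attr = torch.tensor([[0.0, 0.0], [1.0, 0.0]])  # one single, one double bond
out = add_self_loops(edge_attr, num_nodes=3)
```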

Thanks

Question about sampling dataset

Hi @zhichunguo

I am confused about the meaning of obtain_distr_list. Does it stand for [training, testing] or [query, support]?

Besides, why does the line

support_list += random.sample(range(distri_list[task][0],len(data)), m_support)

appear twice in a row, which doubles the support set?

What I think is that obtain_distr_list pre-defines the query and support sets for each task, and these do not change over the entire training process. In each training loop, the query and support datasets should then be sampled from these two sets according to the configured query and support sizes. Please let me know if I am wrong.
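My own hedged reading of these lists is that distri_list[task] holds the counts of the two classes, so indices [0, distri_list[task][0]) carry one label and [distri_list[task][0], len(data)) the other; under that assumption, an episode sampler might look like this (a sketch, not the repository's samples.py):

```python
import random

def sample_episode(distri, data_len, m_support, k_query, seed=0):
    """Sample a support/query split for one task, assuming
    distri = [n_class0, n_class1]: class 0 occupies indices
    [0, n_class0) and class 1 occupies [n_class0, data_len)."""
    rng = random.Random(seed)
    split = distri[0]
    # m_support examples from each class -> a balanced support set.
    support = rng.sample(range(0, split), m_support) \
            + rng.sample(range(split, data_len), m_support)
    # Query indices are drawn from whatever the support did not take.
    remaining = [i for i in range(data_len) if i not in set(support)]
    query = rng.sample(remaining, k_query)
    return support, query

# Using the first Sider entry [684, 743] as the class counts.
support, query = sample_episode([684, 743], 1427, m_support=5, k_query=10)
```

Running the support-sampling line twice on the same range would indeed draw 2*m_support indices, as the issue observes.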

Could you explain this part of the code in samples.py?

def obtain_distr_list(dataset):
    if dataset == "siderdo ":
        return [[684,743],[431,996],[1405,22],[551,876],[276,1151],[430,997],[129,1298],[1176,251],[403,1024],[700,727],[1051,376],[135,1292],[1104,323],[1214,213],[319,1108],[542,885],[109,1318],[1174,253],[421,1006],[367,1060],[411,1016],[516,911],[1302,125],[768,659],[439,988],[123,1304],[481,946]]
    elif dataset == "tox21":
        return [[6956,309],[6521,237],[5781,768],[5521,300],[5400,793],[6605,350],[6264,186],[4890,942],[6808,264],[6095,372],[4892,918],[6351,423]]
    elif dataset == "muv":
        return [[14814,27],[14705,29],[14698,30],[14593,30],[14873,29],[14572,29],[14614,30],[14383,28],[14807,29],[14654,28],[14662,29],[14615,29],[14637,30],[14681,30],[14622,29],[14745,29],[14722,24]]
    elif dataset == "toxcast":
        return [[1293, 438], [1441, 290], [864, 170], [995, 39], [794, 240], [738, 296], [591, 443], [977, 57], [948, 86], [960, 59], [910, 109], [908, 126], [1010, 24], [930, 89], [947, 72], [281, 22], [849, 185], [889, 130], [822, 212], [740, 279], [979, 55], [994, 40], [1018, 16], [797, 237], [788, 246], [286, 17], [967, 67], [935, 99], [842, 192], [828, 206], [262, 41], [257, 46], [252, 51], [267, 36], [251, 52], [248, 55], [247, 56], [251, 52], [251, 52], [262, 41], [263, 40], [283, 20], [274, 29], [286, 17], [284, 19], [3333, 79], [2796, 616], [3198, 214], [3379, 33], [3400, 12], [3382, 30], [3184, 228], [2975, 437], [3321, 91], [3032, 380], [3341, 71], [3386, 26], [3279, 133], [2875, 537], [3363, 49], [3061, 351], [3341, 71], [3172, 240], [2716, 696], [3357, 55], [3278, 134], [3094, 318], [3287, 125], [3390, 22], [2913, 499], [3207, 205], [2543, 869], [3386, 26], [3388, 24], [3402, 10], [2685, 727], [3081, 331], [3340, 72], [3195, 217], [3395, 17], [3320, 92], [3353, 59], [3350, 62], [3256, 156], [3370, 42], [3068, 76], [3310, 102], [3376, 36], [3380, 32], [3231, 181], [3271, 141], [3367, 45], [3395, 17], [3363, 49], [3193, 219], [3036, 376], [3388, 24], [3373, 39], [3293, 119], [3356, 56], [3367, 45], [2998, 414], [3078, 334], [3330, 82], [2947, 465], [3397, 15], [3359, 53], [3319, 93], [3397, 15], [3346, 66], [2696, 716], [3400, 12], [3338, 74], [3356, 56], [3386, 26], [3364, 48], [3370, 42], [3363, 49], [3392, 20], [3401, 11], [3299, 113], [3371, 41], [3372, 40], [3233, 179], [3365, 47], [3146, 266], [3142, 270], [3282, 130], [3265, 147], [3319, 93], [3367, 45], [2123, 1289], [3392, 20], [3208, 204], [3386, 26], [2867, 545], [3392, 20], [3026, 386], [3385, 27], [3184, 228], [3351, 61], [2484, 928], [3330, 82], [2887, 525], [3090, 322], [1769, 1643], [3400, 12], [2445, 967], [2963, 449], [3176, 236], [3344, 68], [3285, 127], [3397, 15], [3357, 55], [3364, 48], [3393, 19], [3041, 371], [3368, 44], [3393, 19], [3387, 25], [3399, 13], [3314, 98], [3324, 88], 
[2893, 519], [3379, 33], [2887, 525], [3323, 89], [3381, 31], [3389, 23], [3198, 214], [3388, 24], [3071, 341], [3357, 55], [3300, 112], [3394, 18], [3186, 226], [2958, 454], [3382, 30], [3299, 113], [3285, 127], [3384, 28], [3311, 101], [3403, 9], [2486, 926], [3398, 14], [3373, 39], [2648, 245], [3393, 19], [2960, 452], [3083, 329], [3334, 78], [1043, 396], [889, 550], [1240, 199], [1114, 325], [1062, 377], [1254, 185], [842, 597], [964, 475], [1383, 56], [1306, 133], [1124, 315], [1385, 54], [1090, 349], [1006, 433], [1026, 413], [1006, 433], [1058, 381], [1041, 398], [1423, 16], [1051, 388], [1018, 421], [1229, 210], [1098, 341], [1424, 15], [1056, 383], [1174, 265], [1060, 379], [1253, 186], [1253, 186], [1312, 127], [1150, 289], [1235, 204], [1215, 224], [1198, 241], [1244, 195], [1406, 33], [1204, 235], [1154, 285], [1235, 204], [1396, 43], [1299, 140], [1281, 158], [1361, 78], [1231, 208], [1413, 26], [1180, 259], [1423, 16], [1287, 152], [998, 441], [1417, 22], [1267, 172], [1409, 30], [1193, 246], [1371, 68], [1191, 248], [1223, 216], [1160, 279], [1407, 32], [1197, 242], [1422, 17], [1218, 221], [1147, 292], [1121, 318], [1420, 19], [1186, 253], [1419, 20], [1053, 386], [1211, 228], [1151, 288], [1119, 320], [1177, 262], [1019, 420], [1138, 301], [1423, 16], [1134, 305], [1423, 16], [1124, 315], [1414, 25], [1119, 320], [1047, 392], [1146, 293], [1349, 90], [1070, 369], [1151, 288], [1368, 71], [1208, 231], [1390, 49], [1003, 436], [1000, 439], [998, 441], [1040, 399], [1034, 405], [1398, 41], [1096, 343], [1402, 37], [1096, 343], [1212, 227], [1123, 316], [1367, 72], [877, 562], [1079, 360], [1006, 433], [1347, 92], [1382, 57], [1252, 187], [1023, 416], [1027, 412], [1149, 290], [1178, 261], [1380, 59], [1049, 390], [817, 622], [1112, 327], [1176, 263], [1032, 407], [300, 202], [318, 184], [369, 133], [365, 131], [427, 69], [470, 30], [428, 72], [459, 43], [436, 66], [411, 51], [353, 147], [387, 113], [351, 118], [358, 142], [283, 17], [279, 21], [176, 
120], [186, 109], [201, 101], [169, 133], [147, 153], [221, 81], [128, 171], [139, 161], [121, 181], [178, 114], [178, 116], [254, 42], [272, 28], [277, 22], [261, 39], [252, 50], [236, 64], [173, 200], [276, 97], [143, 175], [66, 307], [22, 31], [221, 482], [168, 71], [105, 70], [39, 134], [86, 27], [35, 101], [76, 301], [38, 187], [37, 80], [75, 85], [49, 28], [23, 31], [74, 68], [90, 21], [72, 23], [80, 90], [42, 37], [99, 31], [43, 60], [81, 80], [59, 54], [136, 29], [196, 24], [55, 44], [37, 45], [55, 35], [70, 34], [72, 21], [58, 39], [53, 26], [80, 58], [113, 67], [92, 20], [65, 31], [63, 24], [54, 25], [51, 24], [76, 32], [29, 38], [88, 26], [69, 29], [42, 21], [130, 24], [56, 84], [42, 61], [50, 49], [56, 39], [31, 84], [42, 64], [57, 71], [76, 56], [52, 54], [74, 38], [24, 31], [50, 85], [43, 77], [36, 53], [37, 28], [45, 57], [55, 91], [63, 46], [66, 89], [35, 65], [40, 120], [46, 21], [34, 84], [20, 66], [30, 61], [31, 81], [38, 57], [38, 40], [61, 25], [32, 98], [53, 72], [21, 57], [33, 57], [49, 22], [26, 57], [43, 75], [32, 70], [49, 81], [85, 79], [47, 60], [75, 114], [35, 60], [41, 70], [43, 29], [44, 48], [41, 51], [40, 53], [25, 53], [42, 23], [66, 46], [57, 28], [57, 72], [57, 65], [37, 33], [915, 27], [25, 30], [42, 57], [26, 77], [51, 40], [31, 71], [35, 54], [41, 117], [42, 25], [43, 23], [24, 26], [37, 25], [54, 30], [133, 215], [116, 217], [927, 127], [110, 75], [98, 206], [116, 112], [194, 83], [900, 228], [133, 31], [198, 59], [120, 225], [304, 72], [602, 178], [196, 85], [405, 109], [231, 29], [145, 21], [168, 55], [742, 186], [139, 131], [77, 20], [38, 107], [50, 123], [26, 51], [50, 193], [69, 160], [64, 39], [39, 39], [52, 61], [53, 49], [1635, 137], [1629, 111], [1532, 201], [1623, 125], [1575, 99], [1544, 211], [1478, 190], [1543, 201], [1497, 169], [1596, 175], [1619, 139], [1424, 311], [1549, 133], [1560, 198], [1657, 80], [6835, 352], [6564, 623], [5583, 1604], [7118, 69], [7926, 5], [7562, 369], [7540, 391], [7908, 23], [7746, 
185], [6773, 1158], [7351, 580], [7565, 366], [7100, 831], [6034, 1153], [7141, 790], [6674, 1257], [7900, 31], [7898, 33], [7899, 32], [7926, 5], [7901, 30], [7927, 4], [5151, 120], [7653, 278], [7482, 449], [7480, 451], [7650, 281], [7694, 237], [6919, 1012], [7750, 181], [6691, 1240], [7234, 697], [7110, 77], [7094, 93], [6871, 316], [6971, 216], [6843, 344], [6917, 270], [7020, 167], [6997, 190], [6243, 944], [6871, 316], [7620, 311], [7721, 210], [7448, 483], [7413, 518], [7492, 439], [7550, 381], [6909, 278], [6830, 357], [6592, 595], [7035, 152], [4425, 846], [5163, 108], [4982, 289], [6908, 279], [7100, 87], [6961, 226], [6755, 432], [6551, 636], [7084, 103], [7184, 3], [7017, 170], [7010, 177], [6761, 426], [7171, 16], [7926, 5], [7716, 215], [7596, 335], [7175, 12], [6655, 532], [7045, 142], [7912, 19], [6186, 1745], [6572, 615], [7027, 160], [7140, 47], [7134, 53], [6890, 297], [6926, 261], [7380, 551], [7479, 452], [7136, 795], [7556, 375], [7423, 508], [7149, 782], [6919, 1012], [7152, 779], [7269, 662], [7239, 692], [6991, 940], [7516, 415], [7292, 639], [7379, 552], [7001, 930], [7279, 652], [7596, 335], [7306, 625], [7066, 865], [7622, 309], [910, 111], [840, 194], [962, 59], [975, 46], [1006, 15], [945, 76], [911, 110], [914, 120], [982, 39], [903, 118], [969, 52], [979, 42], [914, 107], [985, 36], [991, 30], [966, 55], [942, 79], [895, 126]]

Hi, thank you for your work and code.
Could you tell me the meaning of this part of samples.py? Where do these numbers come from?

Code Error of the file "meta_model.py "

In line 240 of meta_model.py, the loss computation for the "add_masking" module should be based on the query loss (loss_q), not the support loss (loss).

Is this my misunderstanding, or is it a code error?
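For context, the usual MAML convention is that the inner loop adapts on the support loss while the meta-update is driven by the query loss, so an auxiliary term intended to shape the meta-update would normally be added to loss_q. A toy first-order sketch of that convention (not the repository's meta_model.py):

```python
import torch

def maml_task_loss(loss_fn, params, support_batch, query_batch, inner_lr):
    """One first-order MAML task step: adapt parameters on the support
    loss, then evaluate the QUERY loss with the adapted parameters.
    The outer optimizer steps on the returned query loss."""
    loss_s = loss_fn(params, support_batch)       # support loss (inner loop)
    grads = torch.autograd.grad(loss_s, params)   # inner-loop gradients
    adapted = [p - inner_lr * g for p, g in zip(params, grads)]
    return loss_fn(adapted, query_batch)          # query loss -> meta-update

# Toy 1-parameter regression: loss = mean((w*x - y)^2).
def loss_fn(params, batch):
    (w,), (x, y) = params, batch
    return ((w * x - y) ** 2).mean()

w = torch.tensor([0.0], requires_grad=True)
batch = (torch.tensor([1.0]), torch.tensor([2.0]))
# Inner step: grad = -4, so w adapts to 0.4; query loss = (0.4 - 2)^2.
loss_q = maml_task_loss(loss_fn, [w], batch, batch, inner_lr=0.1)
```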

pre-trained models

Hello, this is really great work! However, I ran into some issues during training. Can you give me some suggestions?
If I don't use the pretrained model for training, the loss seems to become so large that the training process shuts down with:
ValueError: Input contains NaN.
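As a general stabilizer (not the authors' recipe), skipping non-finite losses and clipping gradients often keeps training from crashing when no pretrained weights are loaded; a minimal sketch:

```python
import torch

def safe_step(loss, model, optimizer, max_norm=1.0):
    """Guarded optimizer step: skip batches whose loss is NaN/inf and
    clip gradient norms. A generic stabilizer, not the authors' fix."""
    if not torch.isfinite(loss):
        optimizer.zero_grad()
        return False  # skip this batch entirely
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return True

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(8, 4), torch.randn(8, 1)
ok = safe_step(((model(x) - y) ** 2).mean(), model, opt)       # normal batch
skipped = safe_step(torch.tensor(float("nan")), model, opt)    # NaN batch
```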

Tasks distribution: how did you divide the data?

Hi, thanks for your great contribution.

I am trying to understand your code. Can you please explain why you set these numbers in this line for the Tox21 dataset?

Also, how did you split the datasets? For example, in line 27 of splitdata.py, you set the bound of a for loop to 618. Can you explain this?

The other question is: is any SMILES sequence repeated across different tasks?
For example:

0,,0,,1,,,1,0,1,0,1,TOX25232,O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1

Is this sequence included as a sample for the four tasks SR-HSE, SR-p53, NR-ER, and SR-ARE?
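My understanding, hedged: in multi-task datasets like Tox21, one SMILES row carries labels for several tasks and blank fields mean the molecule is unlabeled for that task, so the same molecule can indeed appear in every task where its label is present. A sketch (the task-column order below is an assumption; check the real tox21.csv header):

```python
# Hypothetical Tox21 column order; verify against the actual CSV header.
TASKS = ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER",
         "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5",
         "SR-HSE", "SR-MMP", "SR-p53"]

def tasks_with_labels(row):
    """Return the tasks for which this row carries a 0/1 label
    (blank fields mean the molecule is unlabeled for that task)."""
    fields = row.split(",")
    labels = fields[:len(TASKS)]          # trailing fields: mol_id, SMILES
    return [t for t, v in zip(TASKS, labels) if v.strip() != ""]

row = "0,,0,,1,,,1,0,1,0,1,TOX25232,O=C(O)Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1"
labeled = tasks_with_labels(row)
```

Under this column ordering, the quoted row would contribute a sample to eight tasks, including the four named in the question.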

Thank you
