
pytorch_gat's Introduction

pytorch_gat

PyTorch implementation of the graph attention network (GAT).
Paper: Graph Attention Networks (Veličković et al., ICLR 2018): https://arxiv.org/abs/1710.10903
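
For reference, the attention mechanism defined in the paper computes, for each node i and each neighbor j in N_i:

e_{ij} = \mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{W} h_i \,\|\, \mathbf{W} h_j]\right)
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})}
h_i' = \sigma\left(\sum_{j \in \mathcal{N}_i} \alpha_{ij} \mathbf{W} h_j\right)

where W is the shared linear transform, a is the learnable attention vector, || denotes concatenation, and σ is a nonlinearity (ELU in the paper).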

This implementation is based on the official graph attention network code. The goal is not to match the performance of the original code, but to deepen understanding of TensorFlow and PyTorch. If you want better performance, you can refer to:

Official implementation: https://github.com/PetarV-/GAT
Another PyTorch implementation: https://github.com/Diego999/pyGAT
Keras implementation: https://github.com/danielegrattarola/keras-gat

You can learn how to convert TensorFlow code to PyTorch code here:
https://i.cnblogs.com/posts/edit;postId=13659274
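
As a small illustration of the kind of mapping involved (this example is not taken from the linked post): the official TensorFlow code applies the shared linear transform with a kernel-size-1 conv1d, which can be reproduced in PyTorch with nn.Conv1d, keeping in mind that TensorFlow's conv1d expects (batch, length, channels) while PyTorch's expects (batch, channels, length):

# TensorFlow 1.x (official GAT code):
#   seq_fts = tf.layers.conv1d(seq, out_sz, 1, use_bias=False)

# A rough PyTorch equivalent (sketch):
import torch
import torch.nn as nn

conv = nn.Conv1d(in_channels=1433, out_channels=8, kernel_size=1, bias=False)
seq = torch.randn(1, 2708, 1433)      # (batch, nodes, features), TF-style layout for Cora
out = conv(seq.transpose(1, 2))       # PyTorch wants (batch, channels, nodes)
out = out.transpose(1, 2)             # back to (batch, nodes, out_features)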

Introduction

utils.py: data reading and preprocessing.
layer.py: the attention layer (a minimal sketch of the mechanism is shown after the run command below).
model.py: the graph attention network model.
main.py: training, validation, and testing.
You can run it with:

python main.py
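
For intuition, here is a minimal single-head graph attention layer in PyTorch. This is only a sketch of the mechanism; the actual class and parameter names in layer.py may differ:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphAttentionHead(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.6, negative_slope=0.2):
        super().__init__()
        self.W = nn.Linear(in_features, out_features, bias=False)   # shared linear transform
        self.a_src = nn.Linear(out_features, 1, bias=False)         # first half of attention vector a
        self.a_dst = nn.Linear(out_features, 1, bias=False)         # second half of attention vector a
        self.leakyrelu = nn.LeakyReLU(negative_slope)
        self.dropout = dropout

    def forward(self, h, adj):
        # h: (N, in_features) node features; adj: (N, N) adjacency matrix with self-loops
        Wh = self.W(h)                                          # (N, out_features)
        e = self.leakyrelu(self.a_src(Wh) + self.a_dst(Wh).T)   # (N, N) raw attention scores
        e = e.masked_fill(adj == 0, float('-inf'))              # only attend over existing edges
        attn = F.softmax(e, dim=1)                              # normalize over each node's neighbors
        attn = F.dropout(attn, self.dropout, training=self.training)
        return attn @ Wh                                        # (N, out_features) updated node features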

Results

I did not refer to the other PyTorch implementation. To make it easier to compare my code with the TensorFlow version, the code follows the structure of the TensorFlow implementation.

The following is the result of running the official code:

Dataset: cora
----- Opt. hyperparams -----
lr: 0.005
l2_coef: 0.0005
----- Archi. hyperparams -----
nb. layers: 1
nb. units per layer: [8]
nb. attention heads: [8, 1]
residual: False
nonlinearity: <function elu at 0x7f1b7507af28>
model: <class 'models.gat.GAT'>
(2708, 2708)
(2708, 1433)
epoch:  1
Training: loss = 1.94574, acc = 0.14286 | Val: loss = 1.93655, acc = 0.13600
epoch:  2
Training: loss = 1.94598, acc = 0.15714 | Val: loss = 1.93377, acc = 0.14800
epoch:  3
Training: loss = 1.94945, acc = 0.14286 | Val: loss = 1.93257, acc = 0.19600
epoch:  4
Training: loss = 1.93438, acc = 0.24286 | Val: loss = 1.93172, acc = 0.22800
epoch:  5
Training: loss = 1.93199, acc = 0.17143 | Val: loss = 1.93013, acc = 0.36400
......
epoch:  674
Training: loss = 1.23833, acc = 0.49286 | Val: loss = 1.01357, acc = 0.81200
Early stop! Min loss:  1.010906457901001 , Max accuracy:  0.8219999074935913
Early stop model validation loss:  1.3742048740386963 , accuracy:  0.8219999074935913
Test loss: 1.3630210161209106 ; Test accuracy: 0.8219999074935913
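
For reference, the lr and l2_coef values in the log above map directly to the learning rate and weight_decay of a PyTorch optimizer; a minimal sketch, with a placeholder model standing in for the GAT:

import torch
import torch.nn as nn

model = nn.Linear(1433, 7)   # placeholder for the GAT model (Cora: 1433 features, 7 classes)
# lr: 0.005 and l2_coef: 0.0005 from the log correspond to Adam's lr and weight_decay
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005)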

The following is the result of running my code:

(2708, 2708)
(2708, 1433)
Number of training nodes: 140
Number of validation nodes: 500
Number of test nodes: 1000
epoch:001,TrainLoss:7.9040,TrainAcc:0.0000,ValLoss:7.9040,ValAcc:0.0000
epoch:002,TrainLoss:7.9040,TrainAcc:0.0000,ValLoss:7.9039,ValAcc:0.1920
epoch:003,TrainLoss:7.9039,TrainAcc:0.0714,ValLoss:7.9039,ValAcc:0.1600
epoch:004,TrainLoss:7.9038,TrainAcc:0.1000,ValLoss:7.9039,ValAcc:0.1020
......
epoch:2396,TrainLoss:7.0191,TrainAcc:0.8929,ValLoss:7.4967,ValAcc:0.7440
epoch:2397,TrainLoss:7.0400,TrainAcc:0.8786,ValLoss:7.4969,ValAcc:0.7580
epoch:2398,TrainLoss:7.0188,TrainAcc:0.8929,ValLoss:7.4974,ValAcc:0.7580
epoch:2399,TrainLoss:7.0045,TrainAcc:0.9071,ValLoss:7.4983,ValAcc:0.7620
epoch:2400,TrainLoss:7.0402,TrainAcc:0.8714,ValLoss:7.4994,ValAcc:0.7620
TestLoss:7.4805,TestAcc:0.7700

The following figures show the results:
Loss over epochs: [figure]
Accuracy over epochs: [figure]
Dimensionality reduction visualization of the test results: [figure]
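
One common way to produce such a dimensionality reduction plot is t-SNE over the learned node embeddings; a sketch assuming scikit-learn and matplotlib are available (the repository may use a different method):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Placeholder data for illustration: embeddings would be the test-node output
# features of the trained model, labels the predicted or true class ids.
embeddings = np.random.randn(1000, 8)          # (num_test_nodes, hidden_dim)
labels = np.random.randint(0, 7, size=1000)    # Cora has 7 classes

points = TSNE(n_components=2).fit_transform(embeddings)   # reduce to 2D
plt.scatter(points[:, 0], points[:, 1], c=labels, s=5, cmap='tab10')
plt.show()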


pytorch_gat's Issues

Using GAT for image segmentation

Hello, I'll just go ahead and write in Chinese.

I see that your code does not set a batch size either; it is the same when I run their code, and it seems to run out of memory very easily.

Also, why does adj[np.newaxis] add this extra dimension? My understanding is that it simply means there is only one graph. I am using GAT for image segmentation, so I effectively have 79 graphs (subjects), each graph with 10242 nodes, 8 features, and 76 classes.
What I would like to ask is: if I build adj, features, and label as 3D arrays of shapes (79, 10242, 10242), (79, 10242, 8), and (79, 10242, 76), and then split the training and test sets by graph, with 52 graphs in train and 27 in test, then y_train keeps the form (79, 10242, 76) but only 52 of the 79 slices are filled (10242, 10242) and the rest are zeros. Would train_mask then be (1, 79), of boolean type? I am not sure whether this form of train_mask is correct.
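
For reference, adj[np.newaxis] only prepends a batch dimension of size 1, i.e. it treats the whole dataset as a single graph:

import numpy as np

adj = np.zeros((2708, 2708))     # adjacency matrix of one graph (Cora)
batched = adj[np.newaxis]        # shape (1, 2708, 2708): a batch containing a single graph
print(batched.shape)             # (1, 2708, 2708)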
