
pna's Introduction

Principal Neighbourhood Aggregation

Implementation of Principal Neighbourhood Aggregation for Graph Nets (arxiv.org/abs/2004.05718) in PyTorch, DGL and PyTorch Geometric.

Update: now you can find PNA directly integrated in both PyTorch Geometric and DGL!

Overview

We provide implementations of Principal Neighbourhood Aggregation (PNA) in the PyTorch, DGL and PyTorch Geometric frameworks, along with scripts to generate and run the multitask benchmark, scripts for running the real-world benchmarks, a flexible PyTorch GNN framework, and implementations of the other models used for comparison. The repository is organised as follows:

  • models contains:
    • pytorch contains the various GNN models implemented in PyTorch:
      • the implementation of the aggregators, the scalers and the PNA layer (pna)
      • the flexible GNN framework that can be used with any type of graph convolutions (gnn_framework.py)
      • implementations of the other GNN models used for comparison in the paper, namely GCN, GAT, GIN and MPNN
    • dgl contains the PNA model implemented via the DGL library: aggregators, scalers, and layer.
    • pytorch_geometric contains the PNA model implemented via the PyTorch Geometric library: aggregators, scalers, and layer.
    • layers.py contains general NN layers used by the various models
  • multi_task contains various scripts to recreate the multitask benchmark, along with the files used to train the various models. multi_task/README.md details the instructions for dataset generation and the tuned training hyperparameters.
  • real_world contains various scripts from Benchmarking GNNs to download the real-world benchmarks and train the PNA on them. real_world/README.md provides instructions for downloading the datasets and the tuned training hyperparameters.

Reference

@inproceedings{corso2020pna,
  title     = {Principal Neighbourhood Aggregation for Graph Nets},
  author    = {Corso, Gabriele and Cavalleri, Luca and Beaini, Dominique and Li\`{o}, Pietro and Veli\v{c}kovi\'{c}, Petar},
  booktitle = {Advances in Neural Information Processing Systems},
  year      = {2020}
}

License

MIT

Acknowledgements

The authors would like to thank Saro Passaro for running some of the tests presented in this repository and Giorgos Bouritsas, Fabrizio Frasca, Leonardo Cotta, Zhanghao Wu, Zhanqiu Zhang and George Watkins for pointing out some issues with the code.

pna's People

Contributors

gcorso, lukecavabarrett, michaelvll


pna's Issues

Suspected memory leak issue in the code

Hello, I have an existing project that uses several types of graph neural networks. Other networks from the official DGL implementations work normally, but the DGL version of your code leaks GPU memory. Which version of DGL are you using?

Simpler version of the PNA

On some tasks I'm working on personally, I found that the MPNN-style architecture does not perform well, no matter which aggregators or scalers are used. Even the simple sum and mean aggregators perform worse than their GIN and GCN counterparts.

For this reason, I propose adding the following simpler architecture as a variant of the PNA layer, which doesn't use the attention mechanism of the MPNN but instead aggregates the neighbours in a similar way to CNN, GCN and GIN layers. An obvious drawback is the lack of edge features, but in my personal project on a molecular dataset, edge features seemed to cause more overfitting.

I propose adding it in the file pna/models/dgl/pna_layer.py. I did not implement it in PyTorch Geometric or plain PyTorch.

import torch
import torch.nn as nn
import dgl.function as fn

# AGGREGATORS, SCALERS and MLP come from the repository's existing modules
# (models/dgl and models/layers.py); exact import paths may differ
from models.dgl.aggregators import AGGREGATORS
from models.dgl.scalers import SCALERS
from models.layers import MLP


class PNASimpleLayer(nn.Module):

    def __init__(self, in_dim, out_dim, aggregators, scalers, avg_d, dropout, batch_norm, activation,
                 posttrans_layers=1, residual=False):
        """
        A PNA layer that simply aggregates the neighbourhood (similar to GCN and GIN),
        without using the attention mechanism of the MPNN. It does not support edge features.

        :param in_dim:              size of the input per node
        :param out_dim:             size of the output per node
        :param aggregators:         space-separated aggregation function identifiers
        :param scalers:             space-separated scaling function identifiers
        :param avg_d:               average degree of nodes in the training set, used by scalers to normalize
        :param dropout:             dropout used
        :param batch_norm:          whether to use batch normalisation
        :param activation:          activation used in the post-transformation MLP
        :param posttrans_layers:    number of layers in the transformation after the aggregation
        :param residual:            whether to add a residual connection (requires in_dim == out_dim)
        """
        super().__init__()

        # retrieve the aggregator and scaler functions
        self.aggregators = [AGGREGATORS[aggr] for aggr in aggregators.split()]
        self.scalers = [SCALERS[scale] for scale in scalers.split()]

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.residual = residual and in_dim == out_dim  # residual only when dimensions match

        # batch norm, dropout and activation are all handled inside the MLP;
        # its input is the node's own features concatenated with one block
        # per (aggregator, scaler) pair
        self.posttrans = MLP(in_size=(len(self.aggregators) * len(self.scalers) + 1) * in_dim,
                             hidden_size=out_dim, out_size=out_dim, layers=posttrans_layers,
                             mid_activation=activation, last_activation=activation,
                             dropout=dropout, mid_b_norm=batch_norm, last_b_norm=batch_norm)
        self.avg_d = avg_d

    def reduce_func(self, nodes):
        # mailbox has shape (num_nodes, degree_of_this_bucket, in_dim)
        h = nodes.mailbox['m']
        D = h.shape[-2]  # in-degree of the nodes in this degree bucket
        h = torch.cat([aggregate(h) for aggregate in self.aggregators], dim=1)
        h = torch.cat([scale(h, D=D, avg_d=self.avg_d) for scale in self.scalers], dim=1)
        return {'h': h}

    def forward(self, g, h):
        h_in = h  # kept for the skip-concatenation and the residual connection
        g.ndata['h'] = h

        # aggregation: each node gathers its neighbours' features and applies
        # every combination of aggregator and scaler
        g.update_all(fn.copy_u('h', 'm'), self.reduce_func)
        h = torch.cat([h_in, g.ndata['h']], dim=1)

        # posttransformation
        h = self.posttrans(h)

        if self.residual:
            h = h_in + h
        return h

    def __repr__(self):
        return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__,
                                                            self.in_dim, self.out_dim)
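
For reference, a hypothetical usage sketch of the layer above on a toy DGL graph. The aggregator and scaler identifiers follow the repository's conventions, while the avg_d value and the 'relu' activation name are illustrative assumptions:

import dgl
import torch

# toy directed graph with 3 nodes, 3 edges and 16-dimensional node features
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
h = torch.randn(3, 16)

layer = PNASimpleLayer(in_dim=16, out_dim=32,
                       aggregators='mean max min std',
                       scalers='identity amplification attenuation',
                       avg_d={'log': 1.0},  # normally precomputed on the training set
                       dropout=0.0, batch_norm=False, activation='relu')
out = layer(g, h)  # shape: (3, 32)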

How to use the PyTorch code

Hi, thank you very much for sharing your code. Could you provide an example showing how to create the input and how to use the layer?
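
As a pointer, the repository's models/pytorch_geometric/example.py contains a full training script; below is a minimal, self-contained sketch of driving the PNAConv layer now shipped with PyTorch Geometric (the toy graph and hyperparameters are illustrative assumptions):

import torch
from torch_geometric.nn import PNAConv

# toy graph: 4 nodes with 16-dimensional features and 4 directed edges
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

# deg is an in-degree histogram over the training graphs (index i = number
# of nodes with degree i); here every node has in-degree 1
deg = torch.tensor([0, 4], dtype=torch.long)

conv = PNAConv(in_channels=16, out_channels=32,
               aggregators=['mean', 'min', 'max', 'std'],
               scalers=['identity', 'amplification', 'attenuation'],
               deg=deg)
out = conv(x, edge_index)  # shape: [4, 32]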

Learnable Scalers

Hi @lukecavabarrett @gcorso

Thank you for this great work! I have a minor question: would it be possible to use learnable scalers that adapt to different inputs, instead of the degree-based scalers? Thank you!

Best,
Yongcheng
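
For illustration only (a sketch, not part of the repository): the paper's degree scalers have the form S(d, alpha) = (log(d + 1) / delta)^alpha with alpha fixed to -1, 0 or 1, so one natural variant is to make the exponent a trainable parameter:

import torch
import torch.nn as nn

class LearnableScaler(nn.Module):
    """Degree scaler S(d, alpha) = (log(d+1)/delta)^alpha with trainable alpha."""

    def __init__(self, delta):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(1))  # alpha = 0 starts as the identity scaler
        self.delta = delta  # mean of log(d+1) over the training set

    def forward(self, h, D):
        # D: node in-degrees, broadcastable against the aggregated features h
        return h * (torch.log(D.float() + 1) / self.delta) ** self.alpha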

Bug in device check

The type of device is either torch.device("cuda") or torch.device("cpu"). Thus, device == 'cuda' will always evaluate to False, and torch.cuda.manual_seed(params['seed']) will never run.
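
A sketch of the fix: compare the device's .type attribute rather than the device object itself (params['seed'] follows the training scripts' convention):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# device == 'cuda' compares a torch.device with a str and is always False;
# device.type is the plain string 'cuda' or 'cpu'
if device.type == 'cuda':
    torch.cuda.manual_seed(params['seed'])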

Bug in argparse

It seems that boolean arguments such as args.divide_input_first will always be True, even if we set --divide_input_first=False on the command line, because argparse with type=bool treats any non-empty string as True. A similar question can be found here.
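
A sketch of the pitfall and one common fix, an explicit string-to-bool parser:

import argparse

def str2bool(v):
    # argparse passes the raw string; bool('False') would be True
    return str(v).lower() in ('true', '1', 'yes')

parser = argparse.ArgumentParser()
# buggy original (illustrative): type=bool turns every non-empty string into True
# parser.add_argument('--divide_input_first', type=bool, default=True)
parser.add_argument('--divide_input_first', type=str2bool, default=True)

args = parser.parse_args(['--divide_input_first=False'])
print(args.divide_input_first)  # False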

Wrong default scalers

Hi all! First of all, great paper, I could definitely see it boosting GNN research.

I believe there is an issue with the default parameters in models/pna/train.py, line 11, namely:

parser.add_argument('--scalers', type=str, default='identity exp log', help='Scalers to use')

exp and log are not part of the scalers. Based on the paper (and the README), I assume you mean amplification and attenuation. Currently, as the code stands, not providing the scalers argument results in a KeyError.
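
If that reading is right, the fix would simply swap the identifiers in the default (sketched here, assuming the SCALERS keys match the README):

parser.add_argument('--scalers', type=str, default='identity amplification attenuation',
                    help='Scalers to use')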

Thanks!

Potential error with average degrees in example.py

I think the calculation of average degrees contains an error when example.py is called.

The deg variable (defined in pna/models/pytorch_geometric/example.py) contains, at index i, the number of vertices in the dataset with degree i. This is passed to the init function of PNAConvSimple (in pna/models/pytorch_geometric/pna.py) to calculate the values for self.avg_deg.

But we end up calculating the average of the degree counts, not the average of the degrees. The following two (non-optimised) solutions would work:

  • Edit example.py so deg is actually a list of degrees:

    # collect the in-degree of every training node instead of a histogram
    deg = []
    for data in dataset[split_idx['train']]:
        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        deg += list(d)
    deg = torch.tensor(deg, dtype=torch.long)

  • Edit pna.py to handle binned frequency tensors (such as those returned by torch.bincount):

    # deg[i] is the number of training nodes with degree i, so weight each
    # bin by its degree before averaging (float dtype avoids integer log/exp)
    total_no_vertices = deg.sum()
    bin_degrees = torch.arange(len(deg), dtype=torch.float)
    self.avg_deg: Dict[str, float] = {
        'lin': ((bin_degrees * deg).sum() / total_no_vertices).item(),
        'log': (((bin_degrees + 1).log() * deg).sum() / total_no_vertices).item(),
        'exp': ((bin_degrees.exp() * deg).sum() / total_no_vertices).item(),
    }

Apologies if I've misunderstood what the avg_deg variable is supposed to contain. Note I haven't checked elsewhere to see if this same issue arises in other files.
