pip install jeometric
NOTE: This library is still in a very early stage of development, so breaking changes may happen frequently ❤️
import jax
from jeometric.data import Data, Batch
from jeometric.gnn import GCNLayer
# Generate random node features and edges. Each random draw gets its own PRNG
# key: reusing a single key for `senders` and `receivers` (same shape, same
# bounds) would make the two arrays identical, turning every edge into a
# self-loop.
key = jax.random.PRNGKey(0)
x_key, senders_key, receivers_key, init_key = jax.random.split(key, 4)
x = jax.random.normal(x_key, (10, 5))
senders = jax.random.randint(senders_key, (10,), 0, 10)
receivers = jax.random.randint(receivers_key, (10,), 0, 10)
# create two graphs (identical here, for simplicity)
graph1 = Data(x=x, senders=senders, receivers=receivers)
graph2 = Data(x=x, senders=senders, receivers=receivers)
# batch the graphs together into a single graph
batch = Batch.from_data_list([graph1, graph2])
# create a GCN layer mapping 5 input features to 1 output feature
gcn_layer = GCNLayer(input_dim=5, output_dim=1)
# initialize the layer (with its own key) and apply it to the batch
params = gcn_layer.init(init_key, batch.x, batch.senders, batch.receivers)
out = gcn_layer.apply(params, batch.x, batch.senders, batch.receivers)
# out.shape == (20, 1): one row per node across both 10-node graphs
import jax
from flax import linen as nn
from jeometric.data import Data
from jeometric.ops import segment_sum
from jeometric.gnn import GCNLayer
from typing import List
class GraphConvolutionalNetwork(nn.Module):
    """Stack of GCN layers followed by a per-graph sum readout.

    Each hidden layer is a ``GCNLayer`` followed by a ReLU. The final
    ``GCNLayer`` projects to ``output_dim`` with no activation. Node
    embeddings are then summed per graph with ``segment_sum`` over
    ``graph.batch``.

    Attributes:
        input_dim: Number of input features per node.
        hidden_dims: Output sizes of the hidden GCN layers, in order.
        output_dim: Number of output features of the final GCN layer.
    """

    input_dim: int
    hidden_dims: List[int]
    output_dim: int

    @nn.compact
    def __call__(self, graph: Data, num_graphs: int) -> jax.Array:
        """Run the network on a (possibly batched) graph.

        Args:
            graph: Graph providing node features ``x``, edge endpoint arrays
                ``senders``/``receivers``, and a ``batch`` array used as the
                segment ids that assign each node to a graph.
            num_graphs: Number of graphs in ``graph``; passed to
                ``segment_sum`` as the number of segments. NOTE(review):
                presumably this must be a static Python int under
                ``jax.jit`` — confirm.

        Returns:
            The per-graph sums of the final node embeddings, as an array
            (not a ``Data`` object). Presumably shaped
            ``(num_graphs, output_dim)`` — TODO confirm against
            ``segment_sum``.
        """
        x, senders, receivers = graph.x, graph.senders, graph.receivers

        in_dim = self.input_dim
        for out_dim in self.hidden_dims:
            x = GCNLayer(input_dim=in_dim, output_dim=out_dim)(x, senders, receivers)
            x = jax.nn.relu(x)
            in_dim = out_dim

        # Output projection: no ReLU after the last layer.
        x = GCNLayer(input_dim=in_dim, output_dim=self.output_dim)(
            x, senders, receivers
        )
        # Sum-pool node embeddings into one row per graph.
        return segment_sum(x, graph.batch, num_graphs)
Some examples can be found in the `examples` directory.
`examples/train_molhiv.py` provides an example of training a graph convolutional network on the molhiv dataset.
`examples/benchmark_gcn_molhiv.py` provides code to benchmark the jit and non-jit versions.