This notebook demonstrates how to train a graph classification model in a supervised setting using the Deep Graph Convolutional Neural Network (DGCNN) [1] algorithm.
In supervised graph classification, we are given a collection of graphs, each with an attached categorical label. For example, the PROTEINS dataset we use for this demo is a collection of graphs, each representing a protein and labelled as either an enzyme or a non-enzyme. Our goal is to train a machine learning model that uses the graph structure of the data together with any information available for the graph's nodes, e.g., the node attributes available in PROTEINS, to predict the correct label for a previously unseen graph, i.e., one that was not used for training or validating the model.
The DGCNN architecture was proposed in [1] (see Figure 5 there). It uses the graph convolutional layers from [2], but with a modified propagation rule (see [1] for details). DGCNN introduces a new SortPooling layer that generates a representation (also known as an embedding) for each given graph, using as input the node representations learned by a stack of graph convolutional layers. The output of the SortPooling layer is then fed into one-dimensional convolutional, max pooling, and dense layers that learn graph-level features suitable for predicting graph labels.
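To make the SortPooling idea concrete, here is a minimal NumPy sketch (an illustration of the mechanism, not StellarGraph's implementation): node embeddings are sorted by their last feature channel, and the resulting matrix is truncated or zero-padded to a fixed number of rows k, so that graphs of different sizes yield a fixed-size tensor.
import numpy as np

def sort_pooling(node_embeddings, k):
    # sort nodes by their last feature channel, in descending order
    order = np.argsort(-node_embeddings[:, -1])
    sorted_rows = node_embeddings[order]
    n, d = sorted_rows.shape
    if n >= k:
        return sorted_rows[:k]  # truncate graphs with more than k nodes
    # zero-pad graphs with fewer than k nodes
    return np.vstack([sorted_rows, np.zeros((k - n, d))])

# a toy graph with 5 nodes and 4 channels, pooled to k = 3 rows
print(sort_pooling(np.random.rand(5, 4), k=3).shape)  # (3, 4)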
References
[1] An End-to-End Deep Learning Architecture for Graph Classification, M. Zhang, Z. Cui, M. Neumann, and Y. Chen, AAAI 2018.
[2] Semi-supervised Classification with Graph Convolutional Networks, T. N. Kipf and M. Welling, ICLR 2017.
# install StellarGraph if running on Google Colab
import sys
if 'google.colab' in sys.modules:
    %pip install -q stellargraph[demos]==1.2.1
# verify that we're using the correct version of StellarGraph for this notebook
import stellargraph as sg
try:
    sg.utils.validate_notebook_version("1.2.1")
except AttributeError:
    raise ValueError(
        f"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed. Please see <https://github.com/stellargraph/stellargraph/issues/1172>."
    ) from None
import pandas as pd
import numpy as np
import stellargraph as sg
from stellargraph.mapper import PaddedGraphGenerator
from stellargraph.layer import DeepGraphCNN
from stellargraph import StellarGraph
from stellargraph import datasets
from sklearn import model_selection
from IPython.display import display, HTML
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, Conv1D, MaxPool1D, Dropout, Flatten
from tensorflow.keras.losses import binary_crossentropy
import tensorflow as tf
(See the "Loading from Pandas" demo for details on how data can be loaded.)
dataset = datasets.PROTEINS()
display(HTML(dataset.description))
graphs, graph_labels = dataset.load()
The graphs value is a list of many StellarGraph instances, each of which has a few node features:
print(graphs[0].info())
StellarGraph: Undirected multigraph
 Nodes: 42, Edges: 162

 Node types:
  default: [42]
    Features: float32 vector, length 4
    Edge types: default-default->default

 Edge types:
    default-default->default: [162]
        Weights: all 1 (default)
        Features: none
print(graphs[1].info())
StellarGraph: Undirected multigraph
 Nodes: 27, Edges: 92

 Node types:
  default: [27]
    Features: float32 vector, length 4
    Edge types: default-default->default

 Edge types:
    default-default->default: [92]
        Weights: all 1 (default)
        Features: none
Summary statistics of the sizes of the graphs:
summary = pd.DataFrame(
    [(g.number_of_nodes(), g.number_of_edges()) for g in graphs],
    columns=["nodes", "edges"],
)
summary.describe().round(1)
|  | nodes | edges |
|---|---|---|
| count | 1113.0 | 1113.0 |
| mean | 39.1 | 145.6 |
| std | 45.8 | 169.3 |
| min | 4.0 | 10.0 |
| 25% | 15.0 | 56.0 |
| 50% | 26.0 | 98.0 |
| 75% | 45.0 | 174.0 |
| max | 620.0 | 2098.0 |
The labels are 1 or 2:
graph_labels.value_counts().to_frame()
|  | label |
|---|---|
| 1 | 663 |
| 2 | 450 |
graph_labels = pd.get_dummies(graph_labels, drop_first=True)
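After this transformation the label column is binary: class 1 maps to 0 and class 2 maps to 1, matching the single sigmoid output of the model built below. As a quick sanity check, the encoded counts should match the table above:
# the remaining dummy column encodes class 2 as 1 and class 1 as 0
print(graph_labels.iloc[:, 0].value_counts())  # expect 663 zeros and 450 ones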
To feed data to the tf.Keras model that we will create later, we need a data generator. For supervised graph classification, we create an instance of StellarGraph's PaddedGraphGenerator class.
generator = PaddedGraphGenerator(graphs=graphs)
We are now ready to create a tf.Keras graph classification model using StellarGraph's DeepGraphCNN class together with the standard tf.Keras layers Conv1D, MaxPool1D, Dropout, and Dense.
The model's input is the graph represented by its adjacency and node feature matrices. The first four layers are graph convolutional as in [2], but using the adjacency normalisation from [1], $D^{-1}A$, where $A$ is the adjacency matrix with self-loops added and $D$ is the corresponding degree matrix. The graph convolutional layers have 32, 32, 32, and 1 units respectively, with tanh activations.
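As a small worked example of this normalisation (plain NumPy, independent of the library code), note that each row of $D^{-1}A$ sums to one, so each node's new representation is the mean over its own and its neighbours' features:
import numpy as np

# toy path graph 0-1-2
A = np.array([[0., 1., 0.],
              [1., 0., 1.],
              [0., 1., 0.]])
A_tilde = A + np.eye(3)  # add self-loops
D_inv = np.diag(1.0 / A_tilde.sum(axis=1))  # inverse degree matrix
print(D_inv @ A_tilde)  # every row sums to 1: a random-walk style normalisation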
The next layer is a one-dimensional convolutional layer, Conv1D, followed by a max pooling layer, MaxPool1D. Next is a second Conv1D layer, followed by two Dense layers, the second of which is used for binary classification. The convolutional and dense layers use relu activations, except for the last dense layer, which uses sigmoid for classification. As described in [1], we add a Dropout layer after the first Dense layer.
First we create the base DGCNN model that includes the graph convolutional and SortPooling layers.
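The parameter k below controls how many node rows the SortPooling layer keeps per graph. In [1], k is chosen from the graph-size distribution so that most graphs are not heavily truncated; we can check k = 35 against the summary statistics computed earlier (an illustrative check, not part of the model code):
# fraction of graphs with more than 35 nodes, i.e. graphs SortPooling will truncate
print((summary["nodes"] > 35).mean())  # between 0.25 and 0.5, given the quartiles above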
k = 35 # the number of rows for the output tensor
layer_sizes = [32, 32, 32, 1]
dgcnn_model = DeepGraphCNN(
    layer_sizes=layer_sizes,
    activations=["tanh", "tanh", "tanh", "tanh"],
    k=k,
    bias=False,
    generator=generator,
)
x_inp, x_out = dgcnn_model.in_out_tensors()
Next, we add the convolutional, max pooling, and dense layers.
x_out = Conv1D(filters=16, kernel_size=sum(layer_sizes), strides=sum(layer_sizes))(x_out)
x_out = MaxPool1D(pool_size=2)(x_out)
x_out = Conv1D(filters=32, kernel_size=5, strides=1)(x_out)
x_out = Flatten()(x_out)
x_out = Dense(units=128, activation="relu")(x_out)
x_out = Dropout(rate=0.5)(x_out)
predictions = Dense(units=1, activation="sigmoid")(x_out)
Finally, we create the Keras model and prepare it for training by specifying the loss function and optimisation algorithm.
model = Model(inputs=x_inp, outputs=predictions)
model.compile(
    optimizer=Adam(learning_rate=0.0001), loss=binary_crossentropy, metrics=["acc"],
)
Before we can train the model, we need to split our data into training and test sets. We are going to use 90% of the data for training and the remaining 10% for testing. This 90/10 split is the equivalent of a single fold in the 10-fold cross-validation scheme used in [1].
train_graphs, test_graphs = model_selection.train_test_split(
    graph_labels, train_size=0.9, test_size=None, stratify=graph_labels,
)
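To reproduce the full 10-fold evaluation protocol of [1], this split would be repeated over all ten folds, for example with scikit-learn's StratifiedKFold. The sketch below is illustrative only: build_and_train is a hypothetical helper standing in for the model construction and fitting code in this notebook, returning the fold's test accuracy.
from sklearn.model_selection import StratifiedKFold

labels = graph_labels.values.ravel()
folds = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
accuracies = []
for train_idx, test_idx in folds.split(labels, labels):
    fold_train_gen = generator.flow(
        train_idx, targets=labels[train_idx], batch_size=50, symmetric_normalization=False
    )
    fold_test_gen = generator.flow(
        test_idx, targets=labels[test_idx], batch_size=1, symmetric_normalization=False
    )
    # build_and_train (hypothetical) rebuilds the model and fits it on this fold
    accuracies.append(build_and_train(fold_train_gen, fold_test_gen))
print("mean 10-fold accuracy:", np.mean(accuracies))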
Given the data split into train and test sets, we create a StellarGraph PaddedGraphGenerator object that prepares the data for training. We then create data generators suitable for training a tf.keras model by calling this generator's flow method with the train and test data.
gen = PaddedGraphGenerator(graphs=graphs)
train_gen = gen.flow(
    list(train_graphs.index - 1),
    targets=train_graphs.values,
    batch_size=50,
    symmetric_normalization=False,
)
test_gen = gen.flow(
    list(test_graphs.index - 1),
    targets=test_graphs.values,
    batch_size=1,
    symmetric_normalization=False,
)
Note: We set the number of epochs to a large value, so the call to model.fit(...) below may take a long time to complete. For faster performance, set epochs to a smaller value; but if you do, the accuracy of the trained model may be lower.
epochs = 100
We can now train the model by calling its fit method.
history = model.fit(
    train_gen, epochs=epochs, verbose=1, validation_data=test_gen, shuffle=True,
)
Train for 21 steps, validate for 112 steps
Epoch 1/100
21/21 [==============================] - 3s 139ms/step - loss: 0.6640 - acc: 0.5824 - val_loss: 0.6188 - val_acc: 0.5982
Epoch 2/100
21/21 [==============================] - 2s 74ms/step - loss: 0.6526 - acc: 0.6234 - val_loss: 0.6003 - val_acc: 0.6429
Epoch 3/100
21/21 [==============================] - 2s 86ms/step - loss: 0.6468 - acc: 0.6643 - val_loss: 0.5987 - val_acc: 0.7411
...
Epoch 99/100
21/21 [==============================] - 2s 99ms/step - loss: 0.5315 - acc: 0.7413 - val_loss: 0.5080 - val_acc: 0.7679
Epoch 100/100
21/21 [==============================] - 2s 93ms/step - loss: 0.5292 - acc: 0.7313 - val_loss: 0.5223 - val_acc: 0.7589
Let us plot the training history (losses and accuracies for the train and test data).
sg.utils.plot_history(history)
Finally, let us calculate the performance of the trained model on the test data.
test_metrics = model.evaluate(test_gen)
print("\nTest Set Metrics:")
for name, val in zip(model.metrics_names, test_metrics):
    print("\t{}: {:0.4f}".format(name, val))
112/112 [==============================] - 0s 1ms/step - loss: 0.5223 - acc: 0.7589

Test Set Metrics:
	loss: 0.5223
	acc: 0.7589
We demonstrated the use of StellarGraph's DeepGraphCNN implementation for supervised graph classification. More specifically, we showed how to predict whether a protein represented as a graph is an enzyme or not.
Performance is similar to that reported in [1], but a small difference does exist. This difference can be attributed to a small number of factors listed below: