In this exercise, we investigate semi-supervised node classification using Graph Convolutional Networks on Zachary’s Karate Club dataset introduced in Example 10.2. Some time ago, a dispute between the manager and the coach of the karate club led to a split of the club into four groups.
Can we use Graph Convolutional Networks to predict the affiliation of each member given the social network of the community and the memberships of only four people?
The exercise uses spektral and networkx. If you have not yet installed both packages, do so by executing:
import sys
!{sys.executable} -m pip install spektral
!{sys.executable} -m pip install networkx
import keras
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import spektral
print("spektral", spektral.__version__)
print("keras", keras.__version__)
spektral 1.3.0
keras 2.13.1
import gdown
import os
url = "https://drive.google.com/u/0/uc?export=download&confirm=HgGH&id=1OugMZz6VVBjWy0uxsG_rrPdrYdPLklzD"
output = 'karate_club.npz'
if not os.path.exists(output):  # download the dataset only once
    gdown.download(url, output, quiet=True)
f = np.load(output)
adj, features = f["adj"], f["features"]
print("adjacency matrix\n", adj)
adjacency matrix
 [[0. 1. 1. ... 1. 0. 0.]
 [1. 0. 1. ... 0. 0. 0.]
 [1. 1. 0. ... 0. 1. 0.]
 ...
 [1. 0. 0. ... 0. 1. 1.]
 [0. 0. 1. ... 1. 0. 1.]
 [0. 0. 0. ... 1. 1. 0.]]
print("features\n", features)
features
 [[1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]]
labels_one_hot = f["labels_one_hot"]
def one_hot_to_labels(labels_one_hot):
    # Map each one-hot row to a class label in 1..4 (all-zero rows map to 0)
    return np.sum([(labels_one_hot[:, i] == 1) * (i + 1) for i in range(4)], axis=0)
labels = one_hot_to_labels(labels_one_hot)
print("labels:", labels)
labels: [2 2 3 2 1 1 1 2 4 3 1 2 2 2 4 4 1 2 4 2 4 2 4 4 3 3 4 3 3 4 4 3 4 4]
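The helper above maps each one-hot row to a class label between 1 and 4. As a quick sanity check, and assuming every row contains exactly one entry equal to 1, the same labels can be obtained with `np.argmax`. The toy array `toy_one_hot` below is hypothetical and only stands in for `labels_one_hot`.

```python
import numpy as np

def one_hot_to_labels(labels_one_hot):
    # Same helper as above: class index i is mapped to label i + 1.
    return np.sum([(labels_one_hot[:, i] == 1) * (i + 1) for i in range(4)], axis=0)

# Hypothetical toy data: 5 nodes, 4 classes (not the karate club labels).
toy_one_hot = np.eye(4)[[1, 1, 2, 0, 3]]

labels_sum = one_hot_to_labels(toy_one_hot)
labels_argmax = np.argmax(toy_one_hot, axis=-1) + 1

print(labels_sum)     # [2 2 3 1 4]
print(labels_argmax)  # [2 2 3 1 4]
```

The two variants differ for all-zero rows: the helper returns 0, whereas `np.argmax(...) + 1` returns 1. This matters when labels are masked out by multiplying the one-hot matrix with a boolean mask.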
g = nx.from_numpy_array(adj) # define nx graph
fig, _ = plt.subplots(1)
nx.draw(g, pos=nx.random_layout(g), cmap=plt.get_cmap('jet'), node_color=np.log(one_hot_to_labels(labels_one_hot)),
node_size=np.sum(200 * labels_one_hot, axis=-1) + 150)
plt.tight_layout()
Each node symbolizes one member of the Karate Club, and the edges indicate a close social relationship. The colors indicate the group affiliation of each member.
We can obtain a clearer visualization of the data by plotting the graph with spring_layout:
np.random.seed(2)
fig, _ = plt.subplots(1)
nx.draw(g, pos=nx.spring_layout(g), cmap=plt.get_cmap('jet'), node_color=np.log(one_hot_to_labels(labels_one_hot)),
node_size= np.sum(200 * labels_one_hot, axis=-1) + 150)
plt.tight_layout()
In the following, we prepare our data. Let us assume that after the split of the karate club, we only have information about 4 members, each of whom belongs to a different group.
This gives us a nice example of a weakly supervised learning task.
np.random.seed(2)
# Pick randomly one karate fighter from each class
labels_to_keep = np.array([np.random.choice(np.nonzero(labels_one_hot[:, c])[0]) for c in range(4)])
mask = np.zeros(shape=labels_one_hot.shape[0], dtype=bool)  # builtin bool; np.bool is deprecated since NumPy 1.20
mask[labels_to_keep] = True
np.random.seed(2)
fig, axes = plt.subplots(1)
nx.draw(g, cmap=plt.get_cmap('jet'), node_color="grey",
node_size=150)
np.seterr(divide = 'ignore')
np.random.seed(2)
nx.draw(g, cmap=plt.get_cmap('jet'), node_color=np.log(one_hot_to_labels(labels_one_hot * mask[:,np.newaxis])),
node_size=450, ax=axes)
np.seterr(divide = 'warn')
plt.tight_layout()
/usr/local/lib/python3.10/dist-packages/networkx/drawing/nx_pylab.py:433: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored node_collection = ax.scatter(
We will now use these data to perform semi-supervised node classification with graph convolutional networks.
In the following, we preprocess the data and create a Graph Convolutional Network to classify the nodes of the graph (that is, to determine the membership of each karate fighter). For more details, see Sec. 10.4.1 of the book.
Additionally, we create a mask that hides the memberships of all karate fighters except the four selected members (labels_to_keep) when training the GCN.
train_mask = np.zeros(shape=labels_one_hot.shape[0], dtype=bool)  # builtin bool; np.bool is deprecated since NumPy 1.20
train_mask[labels_to_keep] = True
val_mask = ~train_mask
print("val_mask:\n", val_mask)
print("\ntrain_mask:\n", train_mask)
val_mask:
 [ True True True True False True True True True True True True True True True True True True True False True True True True True True True True True True True False True False]

train_mask:
 [False False False False True False False False False False False False False False False False False False False True False False False False False False False False False False False True False True]
# Preprocessing and preparing of data
y_train = labels_one_hot * train_mask[..., np.newaxis]
y_val = labels_one_hot * val_mask[..., np.newaxis]
fltr = spektral.layers.GCNConv.preprocess(adj).astype('f4') # normalize the adjacency matrix (See Sec.10.4.1)
X = np.identity(34) # create input for the DNN (the existence of each person (one-hot encoded))
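The normalization step can be sketched in plain NumPy. The snippet assumes that `GCNConv.preprocess` applies the usual GCN renormalization $\hat{A} = \tilde{D}^{-1/2}(A + I)\tilde{D}^{-1/2}$, where $\tilde{D}$ is the degree matrix of $A + I$. The function name `normalize_adjacency` and the toy graph are hypothetical and not part of spektral.

```python
import numpy as np

def normalize_adjacency(adj):
    """Symmetric GCN normalization: D^{-1/2} (A + I) D^{-1/2}."""
    a_tilde = adj + np.eye(adj.shape[0])   # add self-loops
    deg = a_tilde.sum(axis=1)              # node degrees including self-loops
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Scale rows and columns by the inverse square root of the degrees.
    return a_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

# Hypothetical toy graph: a path 0 - 1 - 2.
toy_adj = np.array([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
a_hat = normalize_adjacency(toy_adj)
print(np.round(a_hat, 3))
```

On the real data, comparing `normalize_adjacency(adj)` with `fltr`, for example via `np.allclose(..., atol=1e-6)`, is one way to check whether this assumption matches the installed spektral version.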
To add a GCN layer to the model, use spektral.layers.GCNConv(...)([feature_input, adjacency]), where feature_input denotes the input features and adjacency the normalized (pre-processed) adjacency matrix ($\hat{A}$).
Note that the adjacency matrix has to be passed to each GCN layer.
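To make the role of the two inputs concrete, here is a minimal NumPy sketch of what a single bias-free GCN layer is assumed to compute, $H' = \sigma(\hat{A} H W)$. The weights are random, the matrices are hypothetical toy data, and the function `gcn_layer` is an illustration rather than spektral's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

n_nodes, n_out = 5, 4
features = np.identity(n_nodes)                      # one-hot node identities, like X in the exercise
a_hat = np.full((n_nodes, n_nodes), 1.0 / n_nodes)   # placeholder for a normalized adjacency matrix
weights = rng.normal(size=(n_nodes, n_out))          # kernel W (trainable in a real layer, random here)

def gcn_layer(h, a_hat, w, activation=np.tanh):
    # Aggregate neighbor features with A_hat and transform them with W.
    return activation(a_hat @ h @ w)

h1 = gcn_layer(features, a_hat, weights)
print(h1.shape)  # (5, 4)
```

Because the aggregation uses $\hat{A}$ in every layer, stacking layers lets information from the few labeled nodes propagate to nodes further away in the graph.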
F = 4 # number of features
N = adj.shape[0] # number of nodes
X_in = keras.layers.Input(shape=(N,))
fltr_in = keras.layers.Input(shape=(N,))
x = spektral.layers.GCNConv(F, activation='tanh', use_bias=False)([X_in, fltr_in])
x = keras.layers.Dropout(0.4)(x)
x = spektral.layers.GCNConv(F, activation='tanh', use_bias=False)([x, fltr_in])
x = keras.layers.Dropout(0.4)(x)
x = spektral.layers.GCNConv(2, activation='tanh', use_bias=False, name="embedding")([x, fltr_in])
x = keras.layers.Dropout(0.4)(x)
output = spektral.layers.GCNConv(4, activation='softmax', use_bias=False)([x, fltr_in])
model = keras.models.Model(inputs=[X_in, fltr_in], outputs=output)
model.summary()  # summary() prints the table itself; wrapping it in print() adds a trailing "None"
Model: "model"
__________________________________________________________________________________________________
 Layer (type)            Output Shape    Param #   Connected to
==================================================================================================
 input_1 (InputLayer)    [(None, 34)]    0         []
 input_2 (InputLayer)    [(None, 34)]    0         []
 gcn_conv (GCNConv)      (None, 4)       136       ['input_1[0][0]', 'input_2[0][0]']
 dropout (Dropout)       (None, 4)       0         ['gcn_conv[0][0]']
 gcn_conv_1 (GCNConv)    (None, 4)       16        ['dropout[0][0]', 'input_2[0][0]']
 dropout_1 (Dropout)     (None, 4)       0         ['gcn_conv_1[0][0]']
 embedding (GCNConv)     (None, 2)       8         ['dropout_1[0][0]', 'input_2[0][0]']
 dropout_2 (Dropout)     (None, 2)       0         ['embedding[0][0]']
 gcn_conv_2 (GCNConv)    (None, 4)       8         ['dropout_2[0][0]', 'input_2[0][0]']
==================================================================================================
Total params: 168 (672.00 Byte)
Trainable params: 168 (672.00 Byte)
Non-trainable params: 0 (0.00 Byte)
__________________________________________________________________________________________________
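The parameter counts in the summary can be reproduced with a few lines of arithmetic. The sketch assumes that a GCNConv layer without bias holds only a kernel of shape (input width, output width); the variable names are hypothetical.

```python
# Input width N = 34, followed by the output width of each GCNConv layer.
layer_widths = [34, 4, 4, 2, 4]

# Assumption: without bias, each layer only has a kernel with n_in * n_out entries.
params_per_layer = [n_in * n_out
                    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:])]

print(params_per_layer)       # [136, 16, 8, 8]
print(sum(params_per_layer))  # 168
```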
For training the model, you can make use of the code skeleton below. During training, the DNN predictions for nodes whose labels are unknown must be masked. This guarantees that only the predictions made for the four labeled nodes contribute to the objective.
To implement this condition for semi-supervised node classification, you can use the sample_weight argument of model.train_on_batch().
Note that in this exercise, in contrast to most other exercises, we have to train the network on a single data structure (an undirected graph). Thus, the input data are always [X, fltr], and the targets are labels_one_hot.
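The effect of sample_weight can be illustrated with a NumPy sketch of a masked categorical cross-entropy. It assumes that the per-node losses are multiplied by the mask and then averaged over all nodes; the exact reduction used by Keras may differ between versions. The function `masked_cross_entropy` and the toy arrays are hypothetical.

```python
import numpy as np

def masked_cross_entropy(y_true, y_pred, mask, eps=1e-7):
    """Cross-entropy in which only nodes with mask == True contribute."""
    per_node = -np.sum(y_true * np.log(np.clip(y_pred, eps, 1.0)), axis=-1)
    # Assumed reduction: weighted sum divided by the total number of nodes.
    return np.sum(per_node * mask) / len(mask)

# Hypothetical toy problem: 4 nodes, 2 classes, only node 0 is labeled.
y_true = np.array([[1., 0.], [0., 1.], [1., 0.], [0., 1.]])
mask = np.array([True, False, False, False])

pred_a = np.array([[0.8, 0.2], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]])
pred_b = np.array([[0.8, 0.2], [0.9, 0.1], [0.1, 0.9], [0.3, 0.7]])

loss_a = masked_cross_entropy(y_true, pred_a, mask)
loss_b = masked_cross_entropy(y_true, pred_b, mask)

# The predictions for the unlabeled nodes differ, but the masked loss is identical.
print(loss_a, loss_b)
```

Under this assumed reduction, an untrained four-class model with four labeled nodes out of 34 would give a training loss of roughly $4 \ln(4) / 34 \approx 0.163$.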
learning_rate = 0.01
epochs = 2000
model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate, weight_decay=1e-3),
loss='categorical_crossentropy',
weighted_metrics=['acc'])
history = []
for i in range(epochs):
    # Train on the full graph; only the four labeled nodes contribute to the loss
    loss, acc = model.train_on_batch([X, fltr], labels_one_hot,
                                     sample_weight=train_mask)
    # Evaluate on all remaining nodes
    val_loss, val_acc = model.test_on_batch([X, fltr], labels_one_hot, sample_weight=val_mask)
    history.append([val_loss, val_acc])
    if i % 100 == 0:
        print("iteration:", i, "val_loss:", val_loss, "val_accuracy:", val_acc)
        print("iteration:", i, "loss:", loss, "accuracy:", acc)
iteration: 0 val_loss: 1.2282094955444336 val_accuracy: 0.36666667461395264
iteration: 0 loss: 0.16408926248550415 accuracy: 0.25
iteration: 100 val_loss: 0.47007229924201965 val_accuracy: 0.9666666388511658
iteration: 100 loss: 0.06023971363902092 accuracy: 1.0
iteration: 200 val_loss: 0.2796380817890167 val_accuracy: 0.8999999761581421
iteration: 200 loss: 0.03392055630683899 accuracy: 1.0
iteration: 300 val_loss: 0.19965791702270508 val_accuracy: 0.9666666388511658
iteration: 300 loss: 0.07340727746486664 accuracy: 0.75
iteration: 400 val_loss: 0.17798921465873718 val_accuracy: 0.9666666388511658
iteration: 400 loss: 0.033457085490226746 accuracy: 1.0
iteration: 500 val_loss: 0.1554412543773651 val_accuracy: 0.9666666388511658
iteration: 500 loss: 0.036021068692207336 accuracy: 1.0
iteration: 600 val_loss: 0.15126177668571472 val_accuracy: 0.9666666388511658
iteration: 600 loss: 0.01682485267519951 accuracy: 1.0
iteration: 700 val_loss: 0.14785702526569366 val_accuracy: 0.9333333373069763
iteration: 700 loss: 0.020003214478492737 accuracy: 1.0
iteration: 800 val_loss: 0.14266347885131836 val_accuracy: 0.9666666388511658
iteration: 800 loss: 0.013497387990355492 accuracy: 1.0
iteration: 900 val_loss: 0.13860034942626953 val_accuracy: 0.9666666388511658
iteration: 900 loss: 0.043812740594148636 accuracy: 0.75
iteration: 1000 val_loss: 0.1423792541027069 val_accuracy: 0.9333333373069763
iteration: 1000 loss: 0.02671867422759533 accuracy: 1.0
iteration: 1100 val_loss: 0.11848742514848709 val_accuracy: 0.9666666388511658
iteration: 1100 loss: 0.00938910711556673 accuracy: 1.0
iteration: 1200 val_loss: 0.140360489487648 val_accuracy: 0.9333333373069763
iteration: 1200 loss: 0.010986154899001122 accuracy: 1.0
iteration: 1300 val_loss: 0.1254778951406479 val_accuracy: 0.9666666388511658
iteration: 1300 loss: 0.007421881891787052 accuracy: 1.0
iteration: 1400 val_loss: 0.1252700537443161 val_accuracy: 0.9666666388511658
iteration: 1400 loss: 0.007992210797965527 accuracy: 1.0
iteration: 1500 val_loss: 0.11396044492721558 val_accuracy: 0.9666666388511658
iteration: 1500 loss: 0.06013913080096245 accuracy: 0.75
iteration: 1600 val_loss: 0.11579329520463943 val_accuracy: 0.9666666388511658
iteration: 1600 loss: 0.009962714277207851 accuracy: 1.0
iteration: 1700 val_loss: 0.11803137511014938 val_accuracy: 0.9666666388511658
iteration: 1700 loss: 0.014591223560273647 accuracy: 1.0
iteration: 1800 val_loss: 0.11534292995929718 val_accuracy: 0.9666666388511658
iteration: 1800 loss: 0.013465136289596558 accuracy: 1.0
iteration: 1900 val_loss: 0.11544264853000641 val_accuracy: 0.9666666388511658
iteration: 1900 loss: 0.010754341259598732 accuracy: 1.0
fig, axes = plt.subplots(2, figsize=(12,8))
if isinstance(history, dict):
    loss = history["val_loss"]
    acc = history["val_acc"]
else:
    # history is a list of [val_loss, val_acc] pairs; split it into two columns
    loss, acc = np.split(np.array(history), 2, axis=-1)
x = np.arange(len(loss))
axes[0].plot(x, loss, c="navy")
axes[0].set_yscale("log")
axes[0].set_ylabel("Validation loss")
axes[1].plot(x, acc, c="firebrick")
axes[1].set_ylabel("Validation accuracy")
axes[1].set_ylim(0, 1)
axes[0].set_xlabel("Iterations")
axes[1].set_xlabel("Iterations")
plt.tight_layout()