!date
Thu Jan 5 13:34:48 UTC 2023
This is an example for how to use skorch with torch geometric. The code is based on the introduction example but modified to have a proper train/valid/test split. This example is showcasing a quite small data set that does not need to employ batching to be trained efficiently. How to do batching with skorch + torch geometric will not be handled here since it is non-trivial and quite dataset specific - if you need this and are stuck, feel free to open an issue so that we can support you the best we can.
Dependencies of this notebook besides skorch base installation:
It is recommended to install the dependencies as documented by pytorch geometric.
import subprocess
# Installation on Google Colab
try:
import google.colab
import torch
subprocess.run(['python', '-m', 'pip', 'install', 'skorch' , 'torch_geometric'])
subprocess.run(['python', '-m', 'pip', 'install', 'torch-sparse' , '-f', f'https://data.pyg.org/whl/torch-{torch.__version__}.html'])
subprocess.run(['python', '-m', 'pip', 'install', 'torch-scatter' , '-f', f'https://data.pyg.org/whl/torch-{torch.__version__}.html'])
except ImportError:
pass
import skorch
import torch
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index Processing... Done!
dataset.data, dataset.num_classes
(Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708]), 7)
In order to use pytorch geometric / the cora dataset with skorch we need to address the following things:
train_mask
, val_mask
, test_mask
)To deal with (1) we will split the data into three datasets, creating three sub-graphs in the process; these complete sub-graphs can then be convolved over without errors. We use the masks mentioned in (2) to identify the nodes and edges of the subgraphs.
(3) will be handled by specifying our own XYDataset
which will just have length 1 and return the dataset and the respective y values. We will therefore basically simulate a batch_size=1
scenario.
from torch_geometric.data import Data
# simulating batch_size=1 by returning the whole dataset and the
# y-values. this way, the data loader can iterate over the 'batches'
# and produce X/y values for us.
class XYDataset(torch.utils.data.Dataset):
def __init__(self, data: Data, y: torch.tensor):
self.data = data
self.y = y
def __len__(self):
return 1
def __getitem__(self, i):
return self.data, self.y
Split the graph into train, validation and test sub-graphs. This ensures that there will be no leakage between steps when we apply graph convolution operators on the graph since each split has its own sub-graph.
We use relabel_nodes=True
to make the node indices in the edge tensor
zero-based for each sub-graph. If we would not do this the node subsets
(now zero-based after applying the mask) would not match the indices in the
edge tensor.
from torch_geometric.utils import subgraph
data = dataset[0]
edge_index_train, _ = subgraph(
subset=data.train_mask,
edge_index=data.edge_index,
relabel_nodes=True
)
ds_train = XYDataset(
Data(x=data.x[data.train_mask], edge_index=edge_index_train),
data.y[data.train_mask],
)
edge_index_valid, _ = subgraph(
subset=data.val_mask,
edge_index=data.edge_index,
relabel_nodes=True
)
ds_valid = XYDataset(
Data(x=data.x[data.val_mask], edge_index=edge_index_valid),
data.y[data.val_mask],
)
edge_index_test, _ = subgraph(
subset=data.test_mask,
edge_index=data.edge_index,
relabel_nodes=True
)
ds_test = XYDataset(
Data(x=data.x[data.test_mask], edge_index=edge_index_test),
data.y[data.test_mask],
)
Our "batch" consists of the whole dataset so if we unpack the
batch into (X, y)
we will have X = Data(...)
and y = [y_true]
.
The DataLoader
does not modify X
but y
gets a new batch dimension.
This will lead to a shape mismatch as y.shape
would then be (1, #num_samples)
. Therefore, we need our own loader that strips the first dimension to
match the predicted y
and the labelled y
in length.
Note: It is possible to avoid this by stripping this dimension by overriding get_loss
in the NeuralNet
class. For brevity we won't do this in this example. It is possible to use one of the many DataLoader
classes provided by torch geometric using the approach outlined below (just base the RawDataloader
on one of the other classes) - chances are, though, that if you are doing this you need to deal with batching anyway which is a topic that is not handled here since it is not trivial.
from torch_geometric.loader import DataLoader
class RawLoader(DataLoader):
def __iter__(self):
it = super().__iter__()
for X, y in it:
yield X, y[0]
This is the CORA example module as seen in the torch geometric introduction.
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.softmax(x, dim=1)
from skorch.helper import predefined_split
torch.manual_seed(42)
net = skorch.NeuralNetClassifier(
module=GCN,
lr=0.1,
optimizer__weight_decay=5e-4,
max_epochs=200,
train_split=skorch.helper.predefined_split(ds_valid),
batch_size=1,
iterator_train=RawLoader,
iterator_valid=RawLoader,
)
net.fit(ds_train, None)
epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ ------ 1 1.9724 0.1680 1.9398 0.0963 2 1.9625 0.1740 1.9376 0.0091 3 1.9327 0.1720 1.9342 0.0115 4 1.9321 0.1760 1.9324 0.0096 5 1.9142 0.1800 1.9307 0.0069 6 1.8923 0.1800 1.9290 0.0090 7 1.8848 0.1880 1.9269 0.0092 8 1.8936 0.1920 1.9247 0.0152 9 1.8783 0.1960 1.9219 0.0082 10 1.8737 0.2040 1.9192 0.0082 11 1.8542 0.2060 1.9176 0.0070 12 1.8489 0.2100 1.9156 0.0078 13 1.8314 0.2120 1.9121 0.0061 14 1.8334 0.2220 1.9100 0.0076 15 1.8041 0.2240 1.9085 0.0080 16 1.8089 0.2220 1.9065 0.0101 17 1.8082 0.2200 1.9043 0.0078 18 1.7759 0.2220 1.9023 0.0064 19 1.7745 0.2220 1.8992 0.0076 20 1.7630 0.2280 1.8970 0.0089 21 1.7411 0.2340 1.8946 0.0087 22 1.7732 0.2360 1.8921 0.0086 23 1.7407 0.2420 1.8893 0.0077 24 1.7259 0.2420 1.8857 0.0092 25 1.6920 0.2520 1.8836 0.0069 26 1.7033 0.2540 1.8805 0.0133 27 1.7080 0.2580 1.8767 0.0090 28 1.6924 0.2620 1.8741 0.0072 29 1.6882 0.2620 1.8703 0.0075 30 1.6850 0.2660 1.8679 0.0074 31 1.6438 0.2580 1.8650 0.0073 32 1.6345 0.2680 1.8618 0.0083 33 1.6816 0.2660 1.8579 0.0084 34 1.6169 0.2620 1.8559 0.0085 35 1.6373 0.2720 1.8522 0.0093 36 1.6107 0.2700 1.8491 0.0084 37 1.6035 0.2800 1.8449 0.0080 38 1.6060 0.2840 1.8421 0.0079 39 1.5604 0.2960 1.8389 0.0080 40 1.5724 0.3060 1.8354 0.0102 41 1.5371 0.3160 1.8319 0.0071 42 1.5246 0.3240 1.8281 0.0073 43 1.5524 0.3200 1.8241 0.0076 44 1.5282 0.3300 1.8211 0.0098 45 1.5356 0.3380 1.8169 0.0109 46 1.5079 0.3440 1.8137 0.0082 47 1.5192 0.3500 1.8090 0.0077 48 1.4991 0.3540 1.8063 0.0078 49 1.4949 0.3460 1.8036 0.0086 50 1.4892 0.3640 1.8000 0.0100 51 1.5165 0.3760 1.7968 0.0082 52 1.4367 0.3740 1.7931 0.0081 53 1.4473 0.3700 1.7894 0.0081 54 1.4387 0.3840 1.7855 0.0100 55 1.4261 0.3840 1.7825 0.0081 56 1.4355 0.4040 1.7768 0.0085 57 1.4270 0.3900 1.7749 0.0082 58 1.4029 0.4000 1.7714 0.0085 59 1.3793 0.4040 1.7679 0.0085 60 1.3493 0.4020 1.7629 0.0086 61 1.3624 0.4160 1.7597 0.0081 62 1.3970 0.4180 1.7562 0.0085 63 1.3552 0.4220 1.7516 0.0110 64 1.3745 0.4240 1.7480 0.0064 65 1.4002 0.4260 1.7448 0.0087 66 1.2924 0.4280 1.7405 0.0083 67 1.2954 0.4300 1.7375 0.0106 68 1.2785 0.4320 1.7319 0.0074 69 1.3192 0.4300 1.7290 0.0089 70 1.3049 0.4360 1.7246 0.0063 71 1.2504 0.4420 1.7198 0.0094 72 1.2841 0.4340 1.7165 0.0085 73 1.2304 0.4460 1.7120 0.0071 74 1.2414 0.4540 1.7070 0.0062 75 1.1753 0.4520 1.7020 0.0089 76 1.2608 0.4580 1.6981 0.0086 77 1.2053 0.4580 1.6935 0.0083 78 1.2640 0.4600 1.6910 0.0077 79 1.2251 0.4700 1.6845 0.0082 80 1.2221 0.4780 1.6801 0.0063 81 1.1499 0.4760 1.6761 0.0076 82 1.1761 0.4820 1.6727 0.0081 83 1.1286 0.4880 1.6673 0.0073 84 1.1338 0.4920 1.6634 0.0074 85 1.1273 0.4940 1.6593 0.0076 86 1.1289 0.4900 1.6548 0.0085 87 1.1618 0.4960 1.6512 0.0083 88 1.1306 0.4980 1.6474 0.0085 89 1.1436 0.5000 1.6438 0.0102 90 1.0675 0.5020 1.6397 0.0087 91 1.0798 0.5000 1.6360 0.0086 92 1.1148 0.4980 1.6330 0.0079 93 1.0830 0.5040 1.6276 0.0077 94 1.1569 0.5020 1.6246 0.0076 95 1.0338 0.5020 1.6197 0.0075 96 1.0800 0.5100 1.6139 0.0085 97 1.0869 0.5080 1.6121 0.0085 98 1.1144 0.5100 1.6074 0.0087 99 1.0271 0.5060 1.6045 0.0083 100 1.0465 0.5100 1.6020 0.0141 101 1.0348 0.5200 1.5963 0.0105 102 1.0045 0.5220 1.5936 0.0086 103 1.0307 0.5260 1.5895 0.0103 104 0.9970 0.5300 1.5839 0.0097 105 0.9644 0.5300 1.5814 0.0093 106 0.9879 0.5320 1.5770 0.0094 107 0.9986 0.5320 1.5732 0.0094 108 0.9234 0.5320 1.5692 0.0096 109 0.9704 0.5300 1.5642 0.0064 110 1.0256 0.5300 1.5621 0.0060 111 0.9590 0.5240 1.5585 0.0078 112 1.0168 0.5300 1.5565 0.0087 113 0.9994 0.5320 1.5534 0.0078 114 0.9635 0.5320 1.5492 0.0113 115 0.9872 0.5340 1.5452 0.0079 116 0.9749 0.5340 1.5411 0.0126 117 0.9667 0.5340 1.5392 0.0104 118 0.8757 0.5300 1.5351 0.0087 119 0.9306 0.5340 1.5340 0.0087 120 0.8284 0.5380 1.5300 0.0084 121 0.8389 0.5400 1.5254 0.0127 122 0.9347 0.5440 1.5226 0.0165 123 0.8502 0.5340 1.5207 0.0126 124 0.8519 0.5480 1.5163 0.0096 125 0.8536 0.5460 1.5127 0.0094 126 0.8926 0.5480 1.5082 0.0112 127 0.8605 0.5480 1.5050 0.0105 128 0.8853 0.5540 1.5037 0.0120 129 0.8483 0.5540 1.4992 0.0102 130 0.8745 0.5540 1.4954 0.0103 131 0.7866 0.5500 1.4940 0.0107 132 0.8322 0.5520 1.4901 0.0108 133 0.8019 0.5520 1.4864 0.0105 134 0.8829 0.5540 1.4846 0.0103 135 0.8545 0.5560 1.4829 0.0088 136 0.9028 0.5560 1.4802 0.0090 137 0.8797 0.5540 1.4794 0.0102 138 0.7967 0.5560 1.4744 0.0087 139 0.7614 0.5560 1.4723 0.0088 140 0.8399 0.5580 1.4696 0.0097 141 0.8502 0.5580 1.4671 0.0094 142 0.7301 0.5580 1.4648 0.0087 143 0.7543 0.5640 1.4617 0.0102 144 0.7023 0.5620 1.4580 0.0090 145 0.7329 0.5640 1.4561 0.0090 146 0.7820 0.5640 1.4526 0.0100 147 0.8137 0.5640 1.4521 0.0079 148 0.7950 0.5600 1.4489 0.0083 149 0.7702 0.5600 1.4468 0.0082 150 0.7851 0.5580 1.4469 0.0079 151 0.7881 0.5600 1.4456 0.0078 152 0.7375 0.5600 1.4405 0.0102 153 0.7888 0.5580 1.4401 0.0079 154 0.8128 0.5580 1.4376 0.0082 155 0.6960 0.5600 1.4345 0.0085 156 0.7073 0.5600 1.4328 0.0081 157 0.7129 0.5620 1.4301 0.0083 158 0.7282 0.5620 1.4283 0.0078 159 0.7855 0.5580 1.4263 0.0083 160 0.7444 0.5620 1.4237 0.0081 161 0.7081 0.5620 1.4204 0.0095 162 0.6947 0.5620 1.4188 0.0085 163 0.7121 0.5620 1.4152 0.0067 164 0.7374 0.5620 1.4139 0.0095 165 0.6866 0.5600 1.4113 0.0090 166 0.7360 0.5640 1.4102 0.0083 167 0.6596 0.5600 1.4073 0.0085 168 0.6245 0.5600 1.4048 0.0145 169 0.6546 0.5600 1.4032 0.0069 170 0.6534 0.5640 1.4001 0.0074 171 0.7578 0.5640 1.3992 0.0082 172 0.5842 0.5640 1.3976 0.0075 173 0.5937 0.5600 1.3950 0.0068 174 0.6404 0.5620 1.3937 0.0088 175 0.6674 0.5700 1.3930 0.0090 176 0.6230 0.5660 1.3901 0.0082 177 0.7345 0.5660 1.3891 0.0083 178 0.6696 0.5700 1.3872 0.0084 179 0.5714 0.5720 1.3838 0.0070 180 0.5873 0.5720 1.3810 0.0098 181 0.6660 0.5780 1.3787 0.0095 182 0.6348 0.5760 1.3771 0.0083 183 0.6988 0.5780 1.3746 0.0076 184 0.5842 0.5760 1.3719 0.0074 185 0.5969 0.5760 1.3698 0.0075 186 0.6382 0.5780 1.3682 0.0077 187 0.5822 0.5780 1.3693 0.0063 188 0.6263 0.5760 1.3677 0.0077 189 0.6212 0.5740 1.3666 0.0077 190 0.6059 0.5740 1.3641 0.0115 191 0.5797 0.5760 1.3610 0.0080 192 0.6375 0.5780 1.3598 0.0088 193 0.6777 0.5740 1.3610 0.0103 194 0.6331 0.5700 1.3609 0.0083 195 0.6305 0.5720 1.3601 0.0075 196 0.5502 0.5700 1.3568 0.0080 197 0.7014 0.5720 1.3562 0.0079 198 0.5605 0.5720 1.3551 0.0119 199 0.5541 0.5760 1.3557 0.0071 200 0.5496 0.5720 1.3506 0.0074
<class 'skorch.classifier.NeuralNetClassifier'>[initialized]( module_=GCN( (conv1): GCNConv(1433, 16) (conv2): GCNConv(16, 7) ), )
from sklearn.metrics import accuracy_score
accuracy_score(ds_test.y, net.predict(ds_test))
0.682
In conclusion this example showed you how to use a basic data graph dataset using pytorch geometric in conjunction with skorch. The final test score is lower than the ~80% accuracy in the introduction example which can be explained by the reduced leakage between train and validation sets due to our splitting the data into subgraphs beforehand.
The model is now incorporated into the sklearn world (as you could already see, you can simply use sklearn metrics to evaluate the model). Thus, tools like grid and random search are available to you and it is easily possible to include a graph neural net as a feature transformer in your next ML pipeline!