import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as D
import tables as tb
from sklearn.metrics import (matthews_corrcoef,
confusion_matrix,
f1_score,
roc_auc_score,
accuracy_score,
roc_auc_score)
# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Input / output file locations.
MAIN_PATH = '.'
DATA_FILE = 'mt_data.h5'
MODEL_FILE = 'chembl_mt.model'

# DataLoader worker processes: they prefetch data in parallel so the next batch
# is ready as soon as the model finishes training on the current one.
N_WORKERS = 8
# Batch size, see https://twitter.com/ylecun/status/989610208497360896?lang=es
BATCH_SIZE = 32
# Learning rate. Deliberately large because of the way the targets are weighted.
LR = 2
# Number of passes over the training set. You should train longer!!!
N_EPOCHS = 2
A simple 80/20 train/test split for this example.
class ChEMBLDataset(D.Dataset):
    """Map-style dataset over a PyTables HDF5 file of fingerprints and labels.

    The file must contain two arrays under its root node, ``fps`` and
    ``labels``, indexed by the same first (row) axis. Item ``i`` is the pair
    ``(fps[i], labels[i])``.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # Only the shapes are read here; the data itself is read per item.
        with tb.open_file(self.file_path, mode='r') as h5:
            root = h5.root
            self.length = root.fps.shape[0]
            self.n_targets = root.labels.shape[1]

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # The file is opened and closed on every access rather than kept open.
        # NOTE(review): presumably so no HDF5 handle is shared between
        # DataLoader worker processes; confirm before caching a handle.
        with tb.open_file(self.file_path, mode='r') as h5:
            root = h5.root
            return root.fps[index], root.labels[index]
dataset = ChEMBLDataset(f"{MAIN_PATH}/{DATA_FILE}")

# Hold out a fraction of the molecules as the test set. The split is
# reproducible because numpy's global RNG is seeded right before shuffling.
validation_split = .2
random_seed = 42
dataset_size = len(dataset)
split = int(np.floor(dataset_size * validation_split))
indices = list(range(dataset_size))
np.random.seed(random_seed)
np.random.shuffle(indices)
test_indices = indices[:split]
train_indices = indices[split:]

train_sampler = D.sampler.SubsetRandomSampler(train_indices)
test_sampler = D.sampler.SubsetRandomSampler(test_indices)

# With N_WORKERS > 0 the loaders prefetch upcoming batches in worker processes
# while the model is training on the current one.
train_loader = D.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=N_WORKERS,
    sampler=train_sampler,
)
test_loader = D.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=N_WORKERS,
    sampler=test_sampler,
)
class ChEMBLMultiTask(nn.Module):
    """
    Multi-task feed-forward network with one sigmoid output head per task.

    Architecture borrowed from: https://arxiv.org/abs/1502.02072

    Args:
        n_tasks: number of independent binary tasks (one output head each).
        n_features: length of the input feature vector (default 1024).
        hidden_sizes: widths of the two shared hidden layers (default (2000, 100)).
        dropout: dropout probability applied after the first hidden layer
            (default 0.25).
    """

    def __init__(self, n_tasks, n_features=1024, hidden_sizes=(2000, 100), dropout=0.25):
        super().__init__()
        h1_size, h2_size = hidden_sizes
        self.n_tasks = n_tasks
        self.fc1 = nn.Linear(n_features, h1_size)
        self.fc2 = nn.Linear(h1_size, h2_size)
        self.dropout = nn.Dropout(dropout)
        # Add an independent output layer for each task. The "y{n}o" names are
        # kept so state_dict keys stay compatible with previously saved weights.
        for n_m in range(self.n_tasks):
            self.add_module(f"y{n_m}o", nn.Linear(h2_size, 1))

    def forward(self, x):
        """Return a list of ``n_tasks`` probability tensors, each of shape (..., 1)."""
        h1 = self.dropout(F.relu(self.fc1(x)))
        h2 = F.relu(self.fc2(h1))
        return [torch.sigmoid(getattr(self, f"y{n_m}o")(h2)) for n_m in range(self.n_tasks)]
# Build the network and move its parameters to the GPU when one is available.
model = ChEMBLMultiTask(dataset.n_targets).to(device)

# Loss: one binary cross-entropy criterion per task, each scaled by that
# task's entry in the `weights` array stored in the data file. Following
# http://www.bioinf.at/publications/2014/NIPS2014a.pdf the weights are meant to
# be inversely proportional to each task's number of datapoints.
with tb.open_file(f"{MAIN_PATH}/{DATA_FILE}", mode='r') as t_file:
    weights = torch.tensor(t_file.root.weights[:])
weights = weights.to(device)
# Only the first n_targets weights are used, one per task/output head.
criterion = [nn.BCELoss(weight=w) for w in weights.float()[:dataset.n_targets]]

# Plain stochastic gradient descent as the optimiser.
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
Given the extremely sparse nature of the dataset, it is difficult to see clearly how the loss improves after every batch. The trend becomes clearer after several epochs, and much clearer when testing :)
# nn.Module starts in train mode; set it explicitly anyway, since training can be
# resumed after .eval() only once the model is switched back with .train().
model.train()
for ep in range(N_EPOCHS):
    for i, (fps, labels) in enumerate(train_loader):
        # move the batch to the GPU if available
        fps, labels = fps.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(fps)
        # The batch loss is the sum of the per-task losses; for each task only
        # the labeled molecules (labels >= 0) contribute.
        loss = torch.zeros((), device=device)
        for j, crit in enumerate(criterion):
            mask = labels[:, j] >= 0.0
            if mask.any():
                loss = loss + crit(outputs[j][mask], labels[:, j][mask].view(-1, 1))
        # If no task had a labeled sample in this batch, `loss` is a constant with
        # no autograd graph and backward() would raise, so skip the update.
        if loss.requires_grad:
            loss.backward()
            optimizer.step()
        if (i + 1) % 500 == 0:
            # len(train_loader) is the real number of batches, including a final
            # partial one.
            print(f"Epoch: [{ep+1}/{N_EPOCHS}], Step: [{i+1}/{len(train_loader)}], Loss: {loss.item()}")
Epoch: [1/2], Step: [500/17789], Loss: 0.01780553348362446 Epoch: [1/2], Step: [1000/17789], Loss: 0.01136045902967453 Epoch: [1/2], Step: [1500/17789], Loss: 0.018664617091417313 Epoch: [1/2], Step: [2000/17789], Loss: 0.013626799918711185 Epoch: [1/2], Step: [2500/17789], Loss: 0.012855792418122292 Epoch: [1/2], Step: [3000/17789], Loss: 0.013796127401292324 Epoch: [1/2], Step: [3500/17789], Loss: 0.021601887419819832 Epoch: [1/2], Step: [4000/17789], Loss: 0.00950919184833765 Epoch: [1/2], Step: [4500/17789], Loss: 0.02028888650238514 Epoch: [1/2], Step: [5000/17789], Loss: 0.013251284137368202 Epoch: [1/2], Step: [5500/17789], Loss: 0.008788244798779488 Epoch: [1/2], Step: [6000/17789], Loss: 0.012066680938005447 Epoch: [1/2], Step: [6500/17789], Loss: 0.013928443193435669 Epoch: [1/2], Step: [7000/17789], Loss: 0.011484757997095585 Epoch: [1/2], Step: [7500/17789], Loss: 0.0071386718191206455 Epoch: [1/2], Step: [8000/17789], Loss: 0.014712771400809288 Epoch: [1/2], Step: [8500/17789], Loss: 0.010457032360136509 Epoch: [1/2], Step: [9000/17789], Loss: 0.00854165107011795 Epoch: [1/2], Step: [9500/17789], Loss: 0.009312299080193043 Epoch: [1/2], Step: [10000/17789], Loss: 0.010153095237910748 Epoch: [1/2], Step: [10500/17789], Loss: 0.006983090192079544 Epoch: [1/2], Step: [11000/17789], Loss: 0.010238541290163994 Epoch: [1/2], Step: [11500/17789], Loss: 0.012679124251008034 Epoch: [1/2], Step: [12000/17789], Loss: 0.01116170920431614 Epoch: [1/2], Step: [12500/17789], Loss: 0.011749005876481533 Epoch: [1/2], Step: [13000/17789], Loss: 0.015176426619291306 Epoch: [1/2], Step: [13500/17789], Loss: 0.013586488552391529 Epoch: [1/2], Step: [14000/17789], Loss: 0.012365413829684258 Epoch: [1/2], Step: [14500/17789], Loss: 0.009591283276677132 Epoch: [1/2], Step: [15000/17789], Loss: 0.01857740990817547 Epoch: [1/2], Step: [15500/17789], Loss: 0.009823130443692207 Epoch: [1/2], Step: [16000/17789], Loss: 0.01805167831480503 Epoch: [1/2], Step: [16500/17789], Loss: 
0.011896809563040733 Epoch: [1/2], Step: [17000/17789], Loss: 0.008349821902811527 Epoch: [1/2], Step: [17500/17789], Loss: 0.013517800718545914 Epoch: [2/2], Step: [500/17789], Loss: 0.007128629367798567 Epoch: [2/2], Step: [1000/17789], Loss: 0.01153416559100151 Epoch: [2/2], Step: [1500/17789], Loss: 0.02041609212756157 Epoch: [2/2], Step: [2000/17789], Loss: 0.0165218748152256 Epoch: [2/2], Step: [2500/17789], Loss: 0.011772445403039455 Epoch: [2/2], Step: [3000/17789], Loss: 0.011200090870261192 Epoch: [2/2], Step: [3500/17789], Loss: 0.012209323234856129 Epoch: [2/2], Step: [4000/17789], Loss: 0.007769708056002855 Epoch: [2/2], Step: [4500/17789], Loss: 0.012243629433214664 Epoch: [2/2], Step: [5000/17789], Loss: 0.018942933529615402 Epoch: [2/2], Step: [5500/17789], Loss: 0.013197326101362705 Epoch: [2/2], Step: [6000/17789], Loss: 0.011520257219672203 Epoch: [2/2], Step: [6500/17789], Loss: 0.020596494898200035 Epoch: [2/2], Step: [7000/17789], Loss: 0.018161792308092117 Epoch: [2/2], Step: [7500/17789], Loss: 0.01610906422138214 Epoch: [2/2], Step: [8000/17789], Loss: 0.004183729644864798 Epoch: [2/2], Step: [8500/17789], Loss: 0.01284581795334816 Epoch: [2/2], Step: [9000/17789], Loss: 0.014269811101257801 Epoch: [2/2], Step: [9500/17789], Loss: 0.009626287035644054 Epoch: [2/2], Step: [10000/17789], Loss: 0.008639814332127571 Epoch: [2/2], Step: [10500/17789], Loss: 0.011639382690191269 Epoch: [2/2], Step: [11000/17789], Loss: 0.005331861320883036 Epoch: [2/2], Step: [11500/17789], Loss: 0.011540957726538181 Epoch: [2/2], Step: [12000/17789], Loss: 0.010148015804588795 Epoch: [2/2], Step: [12500/17789], Loss: 0.011556670069694519 Epoch: [2/2], Step: [13000/17789], Loss: 0.0069694253616034985 Epoch: [2/2], Step: [13500/17789], Loss: 0.008971192874014378 Epoch: [2/2], Step: [14000/17789], Loss: 0.02061212807893753 Epoch: [2/2], Step: [14500/17789], Loss: 0.013362124562263489 Epoch: [2/2], Step: [15000/17789], Loss: 0.00966110359877348 Epoch: [2/2], Step: 
[15500/17789], Loss: 0.017838571220636368 Epoch: [2/2], Step: [16000/17789], Loss: 0.007174369413405657 Epoch: [2/2], Step: [16500/17789], Loss: 0.0074622794054448605 Epoch: [2/2], Step: [17000/17789], Loss: 0.015448285266757011 Epoch: [2/2], Step: [17500/17789], Loss: 0.011626753024756908
# For every labeled (molecule, task) pair in the test set, collect the true label,
# the hard prediction (probability > 0.5) and the predicted probability.
y_trues = []
y_preds = []
y_preds_proba = []
# Eval mode disables the dropout layer; set it once, before iterating.
model.eval()
# do not track history
with torch.no_grad():
    for fps, labels in test_loader:
        # move the batch to the GPU if available
        fps, labels = fps.to(device), labels.to(device)
        outputs = model(fps)
        # Stack the per-task (batch, 1) outputs and transpose to (n_tasks, batch);
        # transpose the labels the same way so boolean indexing walks the pairs
        # task by task. Thresholding stays on the same device as the outputs.
        probs = torch.cat(outputs, dim=1).t()
        task_labels = labels[:, :len(outputs)].t()
        mask = task_labels >= 0.0  # keep labeled entries only
        batch_probs = probs[mask]
        # One .tolist() per array: a single device->host transfer each per batch.
        y_trues.extend(task_labels[mask].long().tolist())
        y_preds.extend((batch_probs > 0.5).long().tolist())
        y_preds_proba.extend(batch_probs.tolist())

# labels=[0, 1] guarantees a 2x2 matrix (and a valid 4-way unpack) even if one
# class is missing from the predictions or the ground truth.
tn, fp, fn, tp = confusion_matrix(y_trues, y_preds, labels=[0, 1]).ravel()
sens = tp / (tp + fn)
spec = tn / (tn + fp)
prec = tp / (tp + fp)
f1 = f1_score(y_trues, y_preds)
acc = accuracy_score(y_trues, y_preds)
mcc = matthews_corrcoef(y_trues, y_preds)
auc = roc_auc_score(y_trues, y_preds_proba)
print(f"accuracy: {acc}, auc: {auc}, sens: {sens}, spec: {spec}, prec: {prec}, mcc: {mcc}, f1: {f1}")
print(f"Not bad for only {N_EPOCHS} epochs!")
accuracy: 0.8371918235997756, auc: 0.8942389411754185, sens: 0.7053822792666977, spec: 0.8987519347341067, prec: 0.7649158653846154, mcc: 0.6179805824644773, f1: 0.733943790291889 Not bad for only 2 epochs!
# Persist only the learned parameters (state_dict), not the whole module object.
torch.save(model.state_dict(), f"./{MODEL_FILE}")

# Reload into a fresh model. map_location="cpu" lets weights saved from a GPU run
# be loaded on a CPU-only machine; the new model is created on the CPU anyway.
state_dict = torch.load(f"./{MODEL_FILE}", map_location="cpu")
# Recover the number of tasks from the saved per-task heads ("y{n}o.weight")
# instead of hard-coding it, so the snippet works for any dataset.
n_tasks = sum(1 for key in state_dict if key.startswith("y") and key.endswith("o.weight"))
model = ChEMBLMultiTask(n_tasks)
model.load_state_dict(state_dict)
model.eval()
ChEMBLMultiTask( (fc1): Linear(in_features=1024, out_features=2000, bias=True) (fc2): Linear(in_features=2000, out_features=100, bias=True) (dropout): Dropout(p=0.25) (y0o): Linear(in_features=100, out_features=1, bias=True) (y1o): Linear(in_features=100, out_features=1, bias=True) (y2o): Linear(in_features=100, out_features=1, bias=True) (y3o): Linear(in_features=100, out_features=1, bias=True) (y4o): Linear(in_features=100, out_features=1, bias=True) (y5o): Linear(in_features=100, out_features=1, bias=True) (y6o): Linear(in_features=100, out_features=1, bias=True) (y7o): Linear(in_features=100, out_features=1, bias=True) (y8o): Linear(in_features=100, out_features=1, bias=True) (y9o): Linear(in_features=100, out_features=1, bias=True) (y10o): Linear(in_features=100, out_features=1, bias=True) (y11o): Linear(in_features=100, out_features=1, bias=True) (y12o): Linear(in_features=100, out_features=1, bias=True) (y13o): Linear(in_features=100, out_features=1, bias=True) (y14o): Linear(in_features=100, out_features=1, bias=True) (y15o): Linear(in_features=100, out_features=1, bias=True) (y16o): Linear(in_features=100, out_features=1, bias=True) (y17o): Linear(in_features=100, out_features=1, bias=True) (y18o): Linear(in_features=100, out_features=1, bias=True) (y19o): Linear(in_features=100, out_features=1, bias=True) (y20o): Linear(in_features=100, out_features=1, bias=True) (y21o): Linear(in_features=100, out_features=1, bias=True) (y22o): Linear(in_features=100, out_features=1, bias=True) (y23o): Linear(in_features=100, out_features=1, bias=True) (y24o): Linear(in_features=100, out_features=1, bias=True) (y25o): Linear(in_features=100, out_features=1, bias=True) (y26o): Linear(in_features=100, out_features=1, bias=True) (y27o): Linear(in_features=100, out_features=1, bias=True) (y28o): Linear(in_features=100, out_features=1, bias=True) (y29o): Linear(in_features=100, out_features=1, bias=True) (y30o): Linear(in_features=100, out_features=1, bias=True) (y31o): 
Linear(in_features=100, out_features=1, bias=True) (y32o): Linear(in_features=100, out_features=1, bias=True) (y33o): Linear(in_features=100, out_features=1, bias=True) (y34o): Linear(in_features=100, out_features=1, bias=True) (y35o): Linear(in_features=100, out_features=1, bias=True) (y36o): Linear(in_features=100, out_features=1, bias=True) (y37o): Linear(in_features=100, out_features=1, bias=True) (y38o): Linear(in_features=100, out_features=1, bias=True) (y39o): Linear(in_features=100, out_features=1, bias=True) (y40o): Linear(in_features=100, out_features=1, bias=True) (y41o): Linear(in_features=100, out_features=1, bias=True) (y42o): Linear(in_features=100, out_features=1, bias=True) (y43o): Linear(in_features=100, out_features=1, bias=True) (y44o): Linear(in_features=100, out_features=1, bias=True) (y45o): Linear(in_features=100, out_features=1, bias=True) (y46o): Linear(in_features=100, out_features=1, bias=True) (y47o): Linear(in_features=100, out_features=1, bias=True) (y48o): Linear(in_features=100, out_features=1, bias=True) (y49o): Linear(in_features=100, out_features=1, bias=True) (y50o): Linear(in_features=100, out_features=1, bias=True) (y51o): Linear(in_features=100, out_features=1, bias=True) (y52o): Linear(in_features=100, out_features=1, bias=True) (y53o): Linear(in_features=100, out_features=1, bias=True) (y54o): Linear(in_features=100, out_features=1, bias=True) (y55o): Linear(in_features=100, out_features=1, bias=True) (y56o): Linear(in_features=100, out_features=1, bias=True) (y57o): Linear(in_features=100, out_features=1, bias=True) (y58o): Linear(in_features=100, out_features=1, bias=True) (y59o): Linear(in_features=100, out_features=1, bias=True) (y60o): Linear(in_features=100, out_features=1, bias=True) (y61o): Linear(in_features=100, out_features=1, bias=True) (y62o): Linear(in_features=100, out_features=1, bias=True) (y63o): Linear(in_features=100, out_features=1, bias=True) (y64o): Linear(in_features=100, out_features=1, bias=True) 
(y65o): Linear(in_features=100, out_features=1, bias=True) (y66o): Linear(in_features=100, out_features=1, bias=True) (y67o): Linear(in_features=100, out_features=1, bias=True) (y68o): Linear(in_features=100, out_features=1, bias=True) (y69o): Linear(in_features=100, out_features=1, bias=True) (y70o): Linear(in_features=100, out_features=1, bias=True) (y71o): Linear(in_features=100, out_features=1, bias=True) (y72o): Linear(in_features=100, out_features=1, bias=True) (y73o): Linear(in_features=100, out_features=1, bias=True) (y74o): Linear(in_features=100, out_features=1, bias=True) (y75o): Linear(in_features=100, out_features=1, bias=True) (y76o): Linear(in_features=100, out_features=1, bias=True) (y77o): Linear(in_features=100, out_features=1, bias=True) (y78o): Linear(in_features=100, out_features=1, bias=True) (y79o): Linear(in_features=100, out_features=1, bias=True) (y80o): Linear(in_features=100, out_features=1, bias=True) (y81o): Linear(in_features=100, out_features=1, bias=True) (y82o): Linear(in_features=100, out_features=1, bias=True) (y83o): Linear(in_features=100, out_features=1, bias=True) (y84o): Linear(in_features=100, out_features=1, bias=True) (y85o): Linear(in_features=100, out_features=1, bias=True) (y86o): Linear(in_features=100, out_features=1, bias=True) (y87o): Linear(in_features=100, out_features=1, bias=True) (y88o): Linear(in_features=100, out_features=1, bias=True) (y89o): Linear(in_features=100, out_features=1, bias=True) (y90o): Linear(in_features=100, out_features=1, bias=True) (y91o): Linear(in_features=100, out_features=1, bias=True) (y92o): Linear(in_features=100, out_features=1, bias=True) (y93o): Linear(in_features=100, out_features=1, bias=True) (y94o): Linear(in_features=100, out_features=1, bias=True) (y95o): Linear(in_features=100, out_features=1, bias=True) (y96o): Linear(in_features=100, out_features=1, bias=True) (y97o): Linear(in_features=100, out_features=1, bias=True) (y98o): Linear(in_features=100, out_features=1, 
bias=True) (y99o): Linear(in_features=100, out_features=1, bias=True) (y100o): Linear(in_features=100, out_features=1, bias=True) (y101o): Linear(in_features=100, out_features=1, bias=True) (y102o): Linear(in_features=100, out_features=1, bias=True) (y103o): Linear(in_features=100, out_features=1, bias=True) (y104o): Linear(in_features=100, out_features=1, bias=True) (y105o): Linear(in_features=100, out_features=1, bias=True) (y106o): Linear(in_features=100, out_features=1, bias=True) (y107o): Linear(in_features=100, out_features=1, bias=True) (y108o): Linear(in_features=100, out_features=1, bias=True) (y109o): Linear(in_features=100, out_features=1, bias=True) (y110o): Linear(in_features=100, out_features=1, bias=True) (y111o): Linear(in_features=100, out_features=1, bias=True) (y112o): Linear(in_features=100, out_features=1, bias=True) (y113o): Linear(in_features=100, out_features=1, bias=True) (y114o): Linear(in_features=100, out_features=1, bias=True) (y115o): Linear(in_features=100, out_features=1, bias=True) (y116o): Linear(in_features=100, out_features=1, bias=True) (y117o): Linear(in_features=100, out_features=1, bias=True) (y118o): Linear(in_features=100, out_features=1, bias=True) (y119o): Linear(in_features=100, out_features=1, bias=True) (y120o): Linear(in_features=100, out_features=1, bias=True) (y121o): Linear(in_features=100, out_features=1, bias=True) (y122o): Linear(in_features=100, out_features=1, bias=True) (y123o): Linear(in_features=100, out_features=1, bias=True) (y124o): Linear(in_features=100, out_features=1, bias=True) (y125o): Linear(in_features=100, out_features=1, bias=True) (y126o): Linear(in_features=100, out_features=1, bias=True) (y127o): Linear(in_features=100, out_features=1, bias=True) (y128o): Linear(in_features=100, out_features=1, bias=True) (y129o): Linear(in_features=100, out_features=1, bias=True) (y130o): Linear(in_features=100, out_features=1, bias=True) (y131o): Linear(in_features=100, out_features=1, bias=True) (y132o): 
Linear(in_features=100, out_features=1, bias=True) (y133o): Linear(in_features=100, out_features=1, bias=True) (y134o): Linear(in_features=100, out_features=1, bias=True) (y135o): Linear(in_features=100, out_features=1, bias=True) (y136o): Linear(in_features=100, out_features=1, bias=True) (y137o): Linear(in_features=100, out_features=1, bias=True) (y138o): Linear(in_features=100, out_features=1, bias=True) (y139o): Linear(in_features=100, out_features=1, bias=True) (y140o): Linear(in_features=100, out_features=1, bias=True) (y141o): Linear(in_features=100, out_features=1, bias=True) (y142o): Linear(in_features=100, out_features=1, bias=True) (y143o): Linear(in_features=100, out_features=1, bias=True) (y144o): Linear(in_features=100, out_features=1, bias=True) (y145o): Linear(in_features=100, out_features=1, bias=True) (y146o): Linear(in_features=100, out_features=1, bias=True) (y147o): Linear(in_features=100, out_features=1, bias=True) (y148o): Linear(in_features=100, out_features=1, bias=True) (y149o): Linear(in_features=100, out_features=1, bias=True) (y150o): Linear(in_features=100, out_features=1, bias=True) (y151o): Linear(in_features=100, out_features=1, bias=True) (y152o): Linear(in_features=100, out_features=1, bias=True) (y153o): Linear(in_features=100, out_features=1, bias=True) (y154o): Linear(in_features=100, out_features=1, bias=True) (y155o): Linear(in_features=100, out_features=1, bias=True) (y156o): Linear(in_features=100, out_features=1, bias=True) (y157o): Linear(in_features=100, out_features=1, bias=True) (y158o): Linear(in_features=100, out_features=1, bias=True) (y159o): Linear(in_features=100, out_features=1, bias=True) (y160o): Linear(in_features=100, out_features=1, bias=True) (y161o): Linear(in_features=100, out_features=1, bias=True) (y162o): Linear(in_features=100, out_features=1, bias=True) (y163o): Linear(in_features=100, out_features=1, bias=True) (y164o): Linear(in_features=100, out_features=1, bias=True) (y165o): 
Linear(in_features=100, out_features=1, bias=True) (y166o): Linear(in_features=100, out_features=1, bias=True) (y167o): Linear(in_features=100, out_features=1, bias=True) (y168o): Linear(in_features=100, out_features=1, bias=True) (y169o): Linear(in_features=100, out_features=1, bias=True) (y170o): Linear(in_features=100, out_features=1, bias=True) (y171o): Linear(in_features=100, out_features=1, bias=True) (y172o): Linear(in_features=100, out_features=1, bias=True) (y173o): Linear(in_features=100, out_features=1, bias=True) (y174o): Linear(in_features=100, out_features=1, bias=True) (y175o): Linear(in_features=100, out_features=1, bias=True) (y176o): Linear(in_features=100, out_features=1, bias=True) (y177o): Linear(in_features=100, out_features=1, bias=True) (y178o): Linear(in_features=100, out_features=1, bias=True) (y179o): Linear(in_features=100, out_features=1, bias=True) (y180o): Linear(in_features=100, out_features=1, bias=True) (y181o): Linear(in_features=100, out_features=1, bias=True) (y182o): Linear(in_features=100, out_features=1, bias=True) (y183o): Linear(in_features=100, out_features=1, bias=True) (y184o): Linear(in_features=100, out_features=1, bias=True) (y185o): Linear(in_features=100, out_features=1, bias=True) (y186o): Linear(in_features=100, out_features=1, bias=True) (y187o): Linear(in_features=100, out_features=1, bias=True) (y188o): Linear(in_features=100, out_features=1, bias=True) (y189o): Linear(in_features=100, out_features=1, bias=True) (y190o): Linear(in_features=100, out_features=1, bias=True) (y191o): Linear(in_features=100, out_features=1, bias=True) (y192o): Linear(in_features=100, out_features=1, bias=True) (y193o): Linear(in_features=100, out_features=1, bias=True) (y194o): Linear(in_features=100, out_features=1, bias=True) (y195o): Linear(in_features=100, out_features=1, bias=True) (y196o): Linear(in_features=100, out_features=1, bias=True) (y197o): Linear(in_features=100, out_features=1, bias=True) (y198o): 
Linear(in_features=100, out_features=1, bias=True) (y199o): Linear(in_features=100, out_features=1, bias=True) (y200o): Linear(in_features=100, out_features=1, bias=True) (y201o): Linear(in_features=100, out_features=1, bias=True) (y202o): Linear(in_features=100, out_features=1, bias=True) (y203o): Linear(in_features=100, out_features=1, bias=True) (y204o): Linear(in_features=100, out_features=1, bias=True) (y205o): Linear(in_features=100, out_features=1, bias=True) (y206o): Linear(in_features=100, out_features=1, bias=True) (y207o): Linear(in_features=100, out_features=1, bias=True) (y208o): Linear(in_features=100, out_features=1, bias=True) (y209o): Linear(in_features=100, out_features=1, bias=True) (y210o): Linear(in_features=100, out_features=1, bias=True) (y211o): Linear(in_features=100, out_features=1, bias=True) (y212o): Linear(in_features=100, out_features=1, bias=True) (y213o): Linear(in_features=100, out_features=1, bias=True) (y214o): Linear(in_features=100, out_features=1, bias=True) (y215o): Linear(in_features=100, out_features=1, bias=True) (y216o): Linear(in_features=100, out_features=1, bias=True) (y217o): Linear(in_features=100, out_features=1, bias=True) (y218o): Linear(in_features=100, out_features=1, bias=True) (y219o): Linear(in_features=100, out_features=1, bias=True) (y220o): Linear(in_features=100, out_features=1, bias=True) (y221o): Linear(in_features=100, out_features=1, bias=True) (y222o): Linear(in_features=100, out_features=1, bias=True) (y223o): Linear(in_features=100, out_features=1, bias=True) (y224o): Linear(in_features=100, out_features=1, bias=True) (y225o): Linear(in_features=100, out_features=1, bias=True) (y226o): Linear(in_features=100, out_features=1, bias=True) (y227o): Linear(in_features=100, out_features=1, bias=True) (y228o): Linear(in_features=100, out_features=1, bias=True) (y229o): Linear(in_features=100, out_features=1, bias=True) (y230o): Linear(in_features=100, out_features=1, bias=True) (y231o): 
Linear(in_features=100, out_features=1, bias=True) (y232o): Linear(in_features=100, out_features=1, bias=True) (y233o): Linear(in_features=100, out_features=1, bias=True) (y234o): Linear(in_features=100, out_features=1, bias=True) (y235o): Linear(in_features=100, out_features=1, bias=True) (y236o): Linear(in_features=100, out_features=1, bias=True) (y237o): Linear(in_features=100, out_features=1, bias=True) (y238o): Linear(in_features=100, out_features=1, bias=True) (y239o): Linear(in_features=100, out_features=1, bias=True) (y240o): Linear(in_features=100, out_features=1, bias=True) (y241o): Linear(in_features=100, out_features=1, bias=True) (y242o): Linear(in_features=100, out_features=1, bias=True) (y243o): Linear(in_features=100, out_features=1, bias=True) (y244o): Linear(in_features=100, out_features=1, bias=True) (y245o): Linear(in_features=100, out_features=1, bias=True) (y246o): Linear(in_features=100, out_features=1, bias=True) (y247o): Linear(in_features=100, out_features=1, bias=True) (y248o): Linear(in_features=100, out_features=1, bias=True) (y249o): Linear(in_features=100, out_features=1, bias=True) (y250o): Linear(in_features=100, out_features=1, bias=True) (y251o): Linear(in_features=100, out_features=1, bias=True) (y252o): Linear(in_features=100, out_features=1, bias=True) (y253o): Linear(in_features=100, out_features=1, bias=True) (y254o): Linear(in_features=100, out_features=1, bias=True) (y255o): Linear(in_features=100, out_features=1, bias=True) (y256o): Linear(in_features=100, out_features=1, bias=True) (y257o): Linear(in_features=100, out_features=1, bias=True) (y258o): Linear(in_features=100, out_features=1, bias=True) (y259o): Linear(in_features=100, out_features=1, bias=True) (y260o): Linear(in_features=100, out_features=1, bias=True) (y261o): Linear(in_features=100, out_features=1, bias=True) (y262o): Linear(in_features=100, out_features=1, bias=True) (y263o): Linear(in_features=100, out_features=1, bias=True) (y264o): 
Linear(in_features=100, out_features=1, bias=True) (y265o): Linear(in_features=100, out_features=1, bias=True) (y266o): Linear(in_features=100, out_features=1, bias=True) (y267o): Linear(in_features=100, out_features=1, bias=True) (y268o): Linear(in_features=100, out_features=1, bias=True) (y269o): Linear(in_features=100, out_features=1, bias=True) (y270o): Linear(in_features=100, out_features=1, bias=True) (y271o): Linear(in_features=100, out_features=1, bias=True) (y272o): Linear(in_features=100, out_features=1, bias=True) (y273o): Linear(in_features=100, out_features=1, bias=True) (y274o): Linear(in_features=100, out_features=1, bias=True) (y275o): Linear(in_features=100, out_features=1, bias=True) (y276o): Linear(in_features=100, out_features=1, bias=True) (y277o): Linear(in_features=100, out_features=1, bias=True) (y278o): Linear(in_features=100, out_features=1, bias=True) (y279o): Linear(in_features=100, out_features=1, bias=True) (y280o): Linear(in_features=100, out_features=1, bias=True) (y281o): Linear(in_features=100, out_features=1, bias=True) (y282o): Linear(in_features=100, out_features=1, bias=True) (y283o): Linear(in_features=100, out_features=1, bias=True) (y284o): Linear(in_features=100, out_features=1, bias=True) (y285o): Linear(in_features=100, out_features=1, bias=True) (y286o): Linear(in_features=100, out_features=1, bias=True) (y287o): Linear(in_features=100, out_features=1, bias=True) (y288o): Linear(in_features=100, out_features=1, bias=True) (y289o): Linear(in_features=100, out_features=1, bias=True) (y290o): Linear(in_features=100, out_features=1, bias=True) (y291o): Linear(in_features=100, out_features=1, bias=True) (y292o): Linear(in_features=100, out_features=1, bias=True) (y293o): Linear(in_features=100, out_features=1, bias=True) (y294o): Linear(in_features=100, out_features=1, bias=True) (y295o): Linear(in_features=100, out_features=1, bias=True) (y296o): Linear(in_features=100, out_features=1, bias=True) (y297o): 
Linear(in_features=100, out_features=1, bias=True) (y298o): Linear(in_features=100, out_features=1, bias=True) (y299o): Linear(in_features=100, out_features=1, bias=True) (y300o): Linear(in_features=100, out_features=1, bias=True) (y301o): Linear(in_features=100, out_features=1, bias=True) (y302o): Linear(in_features=100, out_features=1, bias=True) (y303o): Linear(in_features=100, out_features=1, bias=True) (y304o): Linear(in_features=100, out_features=1, bias=True) (y305o): Linear(in_features=100, out_features=1, bias=True) (y306o): Linear(in_features=100, out_features=1, bias=True) (y307o): Linear(in_features=100, out_features=1, bias=True) (y308o): Linear(in_features=100, out_features=1, bias=True) (y309o): Linear(in_features=100, out_features=1, bias=True) (y310o): Linear(in_features=100, out_features=1, bias=True) (y311o): Linear(in_features=100, out_features=1, bias=True) (y312o): Linear(in_features=100, out_features=1, bias=True) (y313o): Linear(in_features=100, out_features=1, bias=True) (y314o): Linear(in_features=100, out_features=1, bias=True) (y315o): Linear(in_features=100, out_features=1, bias=True) (y316o): Linear(in_features=100, out_features=1, bias=True) (y317o): Linear(in_features=100, out_features=1, bias=True) (y318o): Linear(in_features=100, out_features=1, bias=True) (y319o): Linear(in_features=100, out_features=1, bias=True) (y320o): Linear(in_features=100, out_features=1, bias=True) (y321o): Linear(in_features=100, out_features=1, bias=True) (y322o): Linear(in_features=100, out_features=1, bias=True) (y323o): Linear(in_features=100, out_features=1, bias=True) (y324o): Linear(in_features=100, out_features=1, bias=True) (y325o): Linear(in_features=100, out_features=1, bias=True) (y326o): Linear(in_features=100, out_features=1, bias=True) (y327o): Linear(in_features=100, out_features=1, bias=True) (y328o): Linear(in_features=100, out_features=1, bias=True) (y329o): Linear(in_features=100, out_features=1, bias=True) (y330o): 
Linear(in_features=100, out_features=1, bias=True) (y331o): Linear(in_features=100, out_features=1, bias=True) (y332o): Linear(in_features=100, out_features=1, bias=True) (y333o): Linear(in_features=100, out_features=1, bias=True) (y334o): Linear(in_features=100, out_features=1, bias=True) (y335o): Linear(in_features=100, out_features=1, bias=True) (y336o): Linear(in_features=100, out_features=1, bias=True) (y337o): Linear(in_features=100, out_features=1, bias=True) (y338o): Linear(in_features=100, out_features=1, bias=True) (y339o): Linear(in_features=100, out_features=1, bias=True) (y340o): Linear(in_features=100, out_features=1, bias=True) (y341o): Linear(in_features=100, out_features=1, bias=True) (y342o): Linear(in_features=100, out_features=1, bias=True) (y343o): Linear(in_features=100, out_features=1, bias=True) (y344o): Linear(in_features=100, out_features=1, bias=True) (y345o): Linear(in_features=100, out_features=1, bias=True) (y346o): Linear(in_features=100, out_features=1, bias=True) (y347o): Linear(in_features=100, out_features=1, bias=True) (y348o): Linear(in_features=100, out_features=1, bias=True) (y349o): Linear(in_features=100, out_features=1, bias=True) (y350o): Linear(in_features=100, out_features=1, bias=True) (y351o): Linear(in_features=100, out_features=1, bias=True) (y352o): Linear(in_features=100, out_features=1, bias=True) (y353o): Linear(in_features=100, out_features=1, bias=True) (y354o): Linear(in_features=100, out_features=1, bias=True) (y355o): Linear(in_features=100, out_features=1, bias=True) (y356o): Linear(in_features=100, out_features=1, bias=True) (y357o): Linear(in_features=100, out_features=1, bias=True) (y358o): Linear(in_features=100, out_features=1, bias=True) (y359o): Linear(in_features=100, out_features=1, bias=True) (y360o): Linear(in_features=100, out_features=1, bias=True) (y361o): Linear(in_features=100, out_features=1, bias=True) (y362o): Linear(in_features=100, out_features=1, bias=True) (y363o): 
Linear(in_features=100, out_features=1, bias=True) (y364o): Linear(in_features=100, out_features=1, bias=True) (y365o): Linear(in_features=100, out_features=1, bias=True) (y366o): Linear(in_features=100, out_features=1, bias=True) (y367o): Linear(in_features=100, out_features=1, bias=True) (y368o): Linear(in_features=100, out_features=1, bias=True) (y369o): Linear(in_features=100, out_features=1, bias=True) (y370o): Linear(in_features=100, out_features=1, bias=True) (y371o): Linear(in_features=100, out_features=1, bias=True) (y372o): Linear(in_features=100, out_features=1, bias=True) (y373o): Linear(in_features=100, out_features=1, bias=True) (y374o): Linear(in_features=100, out_features=1, bias=True) (y375o): Linear(in_features=100, out_features=1, bias=True) (y376o): Linear(in_features=100, out_features=1, bias=True) (y377o): Linear(in_features=100, out_features=1, bias=True) (y378o): Linear(in_features=100, out_features=1, bias=True) (y379o): Linear(in_features=100, out_features=1, bias=True) (y380o): Linear(in_features=100, out_features=1, bias=True) (y381o): Linear(in_features=100, out_features=1, bias=True) (y382o): Linear(in_features=100, out_features=1, bias=True) (y383o): Linear(in_features=100, out_features=1, bias=True) (y384o): Linear(in_features=100, out_features=1, bias=True) (y385o): Linear(in_features=100, out_features=1, bias=True) (y386o): Linear(in_features=100, out_features=1, bias=True) (y387o): Linear(in_features=100, out_features=1, bias=True) (y388o): Linear(in_features=100, out_features=1, bias=True) (y389o): Linear(in_features=100, out_features=1, bias=True) (y390o): Linear(in_features=100, out_features=1, bias=True) (y391o): Linear(in_features=100, out_features=1, bias=True) (y392o): Linear(in_features=100, out_features=1, bias=True) (y393o): Linear(in_features=100, out_features=1, bias=True) (y394o): Linear(in_features=100, out_features=1, bias=True) (y395o): Linear(in_features=100, out_features=1, bias=True) (y396o): 
Linear(in_features=100, out_features=1, bias=True) (y397o): Linear(in_features=100, out_features=1, bias=True) (y398o): Linear(in_features=100, out_features=1, bias=True) (y399o): Linear(in_features=100, out_features=1, bias=True) (y400o): Linear(in_features=100, out_features=1, bias=True) (y401o): Linear(in_features=100, out_features=1, bias=True) (y402o): Linear(in_features=100, out_features=1, bias=True) (y403o): Linear(in_features=100, out_features=1, bias=True) (y404o): Linear(in_features=100, out_features=1, bias=True) (y405o): Linear(in_features=100, out_features=1, bias=True) (y406o): Linear(in_features=100, out_features=1, bias=True) (y407o): Linear(in_features=100, out_features=1, bias=True) (y408o): Linear(in_features=100, out_features=1, bias=True) (y409o): Linear(in_features=100, out_features=1, bias=True) (y410o): Linear(in_features=100, out_features=1, bias=True) (y411o): Linear(in_features=100, out_features=1, bias=True) (y412o): Linear(in_features=100, out_features=1, bias=True) (y413o): Linear(in_features=100, out_features=1, bias=True) (y414o): Linear(in_features=100, out_features=1, bias=True) (y415o): Linear(in_features=100, out_features=1, bias=True) (y416o): Linear(in_features=100, out_features=1, bias=True) (y417o): Linear(in_features=100, out_features=1, bias=True) (y418o): Linear(in_features=100, out_features=1, bias=True) (y419o): Linear(in_features=100, out_features=1, bias=True) (y420o): Linear(in_features=100, out_features=1, bias=True) (y421o): Linear(in_features=100, out_features=1, bias=True) (y422o): Linear(in_features=100, out_features=1, bias=True) (y423o): Linear(in_features=100, out_features=1, bias=True) (y424o): Linear(in_features=100, out_features=1, bias=True) (y425o): Linear(in_features=100, out_features=1, bias=True) (y426o): Linear(in_features=100, out_features=1, bias=True) (y427o): Linear(in_features=100, out_features=1, bias=True) (y428o): Linear(in_features=100, out_features=1, bias=True) (y429o): 
Linear(in_features=100, out_features=1, bias=True) (y430o): Linear(in_features=100, out_features=1, bias=True) (y431o): Linear(in_features=100, out_features=1, bias=True) (y432o): Linear(in_features=100, out_features=1, bias=True) (y433o): Linear(in_features=100, out_features=1, bias=True) (y434o): Linear(in_features=100, out_features=1, bias=True) (y435o): Linear(in_features=100, out_features=1, bias=True) (y436o): Linear(in_features=100, out_features=1, bias=True) (y437o): Linear(in_features=100, out_features=1, bias=True) (y438o): Linear(in_features=100, out_features=1, bias=True) (y439o): Linear(in_features=100, out_features=1, bias=True) (y440o): Linear(in_features=100, out_features=1, bias=True) (y441o): Linear(in_features=100, out_features=1, bias=True) (y442o): Linear(in_features=100, out_features=1, bias=True) (y443o): Linear(in_features=100, out_features=1, bias=True) (y444o): Linear(in_features=100, out_features=1, bias=True) (y445o): Linear(in_features=100, out_features=1, bias=True) (y446o): Linear(in_features=100, out_features=1, bias=True) (y447o): Linear(in_features=100, out_features=1, bias=True) (y448o): Linear(in_features=100, out_features=1, bias=True) (y449o): Linear(in_features=100, out_features=1, bias=True) (y450o): Linear(in_features=100, out_features=1, bias=True) (y451o): Linear(in_features=100, out_features=1, bias=True) (y452o): Linear(in_features=100, out_features=1, bias=True) (y453o): Linear(in_features=100, out_features=1, bias=True) (y454o): Linear(in_features=100, out_features=1, bias=True) (y455o): Linear(in_features=100, out_features=1, bias=True) (y456o): Linear(in_features=100, out_features=1, bias=True) (y457o): Linear(in_features=100, out_features=1, bias=True) (y458o): Linear(in_features=100, out_features=1, bias=True) (y459o): Linear(in_features=100, out_features=1, bias=True) (y460o): Linear(in_features=100, out_features=1, bias=True) (y461o): Linear(in_features=100, out_features=1, bias=True) (y462o): 
Linear(in_features=100, out_features=1, bias=True) (y463o): Linear(in_features=100, out_features=1, bias=True) (y464o): Linear(in_features=100, out_features=1, bias=True) (y465o): Linear(in_features=100, out_features=1, bias=True) (y466o): Linear(in_features=100, out_features=1, bias=True) (y467o): Linear(in_features=100, out_features=1, bias=True) (y468o): Linear(in_features=100, out_features=1, bias=True) (y469o): Linear(in_features=100, out_features=1, bias=True) (y470o): Linear(in_features=100, out_features=1, bias=True) (y471o): Linear(in_features=100, out_features=1, bias=True) (y472o): Linear(in_features=100, out_features=1, bias=True) (y473o): Linear(in_features=100, out_features=1, bias=True) (y474o): Linear(in_features=100, out_features=1, bias=True) (y475o): Linear(in_features=100, out_features=1, bias=True) (y476o): Linear(in_features=100, out_features=1, bias=True) (y477o): Linear(in_features=100, out_features=1, bias=True) (y478o): Linear(in_features=100, out_features=1, bias=True) (y479o): Linear(in_features=100, out_features=1, bias=True) (y480o): Linear(in_features=100, out_features=1, bias=True) (y481o): Linear(in_features=100, out_features=1, bias=True) (y482o): Linear(in_features=100, out_features=1, bias=True) (y483o): Linear(in_features=100, out_features=1, bias=True) (y484o): Linear(in_features=100, out_features=1, bias=True) (y485o): Linear(in_features=100, out_features=1, bias=True) (y486o): Linear(in_features=100, out_features=1, bias=True) (y487o): Linear(in_features=100, out_features=1, bias=True) (y488o): Linear(in_features=100, out_features=1, bias=True) (y489o): Linear(in_features=100, out_features=1, bias=True) (y490o): Linear(in_features=100, out_features=1, bias=True) (y491o): Linear(in_features=100, out_features=1, bias=True) (y492o): Linear(in_features=100, out_features=1, bias=True) (y493o): Linear(in_features=100, out_features=1, bias=True) (y494o): Linear(in_features=100, out_features=1, bias=True) (y495o): 
Linear(in_features=100, out_features=1, bias=True) (y496o): Linear(in_features=100, out_features=1, bias=True) (y497o): Linear(in_features=100, out_features=1, bias=True) (y498o): Linear(in_features=100, out_features=1, bias=True) (y499o): Linear(in_features=100, out_features=1, bias=True) (y500o): Linear(in_features=100, out_features=1, bias=True) (y501o): Linear(in_features=100, out_features=1, bias=True) (y502o): Linear(in_features=100, out_features=1, bias=True) (y503o): Linear(in_features=100, out_features=1, bias=True) (y504o): Linear(in_features=100, out_features=1, bias=True) (y505o): Linear(in_features=100, out_features=1, bias=True) (y506o): Linear(in_features=100, out_features=1, bias=True) (y507o): Linear(in_features=100, out_features=1, bias=True) (y508o): Linear(in_features=100, out_features=1, bias=True) (y509o): Linear(in_features=100, out_features=1, bias=True) (y510o): Linear(in_features=100, out_features=1, bias=True) (y511o): Linear(in_features=100, out_features=1, bias=True) (y512o): Linear(in_features=100, out_features=1, bias=True) (y513o): Linear(in_features=100, out_features=1, bias=True) (y514o): Linear(in_features=100, out_features=1, bias=True) (y515o): Linear(in_features=100, out_features=1, bias=True) (y516o): Linear(in_features=100, out_features=1, bias=True) (y517o): Linear(in_features=100, out_features=1, bias=True) (y518o): Linear(in_features=100, out_features=1, bias=True) (y519o): Linear(in_features=100, out_features=1, bias=True) (y520o): Linear(in_features=100, out_features=1, bias=True) (y521o): Linear(in_features=100, out_features=1, bias=True) (y522o): Linear(in_features=100, out_features=1, bias=True) (y523o): Linear(in_features=100, out_features=1, bias=True) (y524o): Linear(in_features=100, out_features=1, bias=True) (y525o): Linear(in_features=100, out_features=1, bias=True) (y526o): Linear(in_features=100, out_features=1, bias=True) (y527o): Linear(in_features=100, out_features=1, bias=True) (y528o): 
Linear(in_features=100, out_features=1, bias=True) (y529o): Linear(in_features=100, out_features=1, bias=True) (y530o): Linear(in_features=100, out_features=1, bias=True) (y531o): Linear(in_features=100, out_features=1, bias=True) (y532o): Linear(in_features=100, out_features=1, bias=True) (y533o): Linear(in_features=100, out_features=1, bias=True) (y534o): Linear(in_features=100, out_features=1, bias=True) (y535o): Linear(in_features=100, out_features=1, bias=True) (y536o): Linear(in_features=100, out_features=1, bias=True) (y537o): Linear(in_features=100, out_features=1, bias=True) (y538o): Linear(in_features=100, out_features=1, bias=True) (y539o): Linear(in_features=100, out_features=1, bias=True) (y540o): Linear(in_features=100, out_features=1, bias=True) (y541o): Linear(in_features=100, out_features=1, bias=True) (y542o): Linear(in_features=100, out_features=1, bias=True) (y543o): Linear(in_features=100, out_features=1, bias=True) (y544o): Linear(in_features=100, out_features=1, bias=True) (y545o): Linear(in_features=100, out_features=1, bias=True) (y546o): Linear(in_features=100, out_features=1, bias=True) (y547o): Linear(in_features=100, out_features=1, bias=True) (y548o): Linear(in_features=100, out_features=1, bias=True) (y549o): Linear(in_features=100, out_features=1, bias=True) (y550o): Linear(in_features=100, out_features=1, bias=True) (y551o): Linear(in_features=100, out_features=1, bias=True) (y552o): Linear(in_features=100, out_features=1, bias=True) (y553o): Linear(in_features=100, out_features=1, bias=True) (y554o): Linear(in_features=100, out_features=1, bias=True) (y555o): Linear(in_features=100, out_features=1, bias=True) (y556o): Linear(in_features=100, out_features=1, bias=True) (y557o): Linear(in_features=100, out_features=1, bias=True) (y558o): Linear(in_features=100, out_features=1, bias=True) (y559o): Linear(in_features=100, out_features=1, bias=True) )