In this notebook, we parametrize a PyTorch distribution with a neural network for the purpose of uncertainty quantification.
We'll consider the OLS Regression Challenge, which aims at predicting cancer mortality rates for US counties.
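The core idea, as a minimal standalone sketch (the toy tensors below are hypothetical, not part of the model we build later): a network outputs the parameters of a torch.distributions.Normal, and training minimizes the negative log-likelihood of the observed targets under that distribution.
import torch
import torch.nn.functional as F
# Toy stand-ins for network outputs (hypothetical values)
mu = torch.tensor([0.5])                 # predicted mean
sigma = F.softplus(torch.tensor([0.1]))  # softplus keeps the std strictly positive
dist = torch.distributions.Normal(mu, sigma)
y = torch.tensor([0.7])                  # observed target
loss = -dist.log_prob(y).mean()          # negative log-likelihood to minimize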
import os
from os.path import join
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from scipy.stats import binned_statistic
cwd = os.getcwd()
if cwd.endswith('notebook'):
os.chdir('..')
cwd = os.getcwd()
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()
seed = 456
np.random.seed(seed);
torch.manual_seed(seed);
torch.set_default_dtype(torch.float64)
df = pd.read_csv(join(cwd, 'data/cancer_reg.csv'))
df.head()
 | avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | binnedInc | MedianAge | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1397.0 | 469 | 164.9 | 489.8 | 61898 | 260131 | 11.2 | 499.748204 | (61494.5, 125635] | 39.3 | ... | NaN | 41.6 | 32.9 | 14.0 | 81.780529 | 2.594728 | 4.821857 | 1.843479 | 52.856076 | 6.118831 |
1 | 173.0 | 70 | 161.3 | 411.6 | 48127 | 43269 | 18.6 | 23.111234 | (48021.6, 51046.4] | 33.0 | ... | 53.8 | 43.6 | 31.1 | 15.3 | 89.228509 | 0.969102 | 2.246233 | 3.741352 | 45.372500 | 4.333096 |
2 | 102.0 | 50 | 174.7 | 349.7 | 49348 | 21026 | 14.6 | 47.560164 | (48021.6, 51046.4] | 45.0 | ... | 43.5 | 34.9 | 42.1 | 21.1 | 90.922190 | 0.739673 | 0.465898 | 2.747358 | 54.444868 | 3.729488 |
3 | 427.0 | 202 | 194.8 | 430.4 | 44243 | 75882 | 17.1 | 342.637253 | (42724.4, 45201] | 42.8 | ... | 40.3 | 35.0 | 45.3 | 25.0 | 91.744686 | 0.782626 | 1.161359 | 1.362643 | 51.021514 | 4.603841 |
4 | 57.0 | 26 | 144.4 | 350.1 | 49955 | 10321 | 12.5 | 0.000000 | (48021.6, 51046.4] | 48.3 | ... | 43.9 | 35.1 | 44.0 | 22.7 | 94.104024 | 0.270192 | 0.665830 | 0.492135 | 54.027460 | 6.796657 |
5 rows × 34 columns
df.describe()
 | avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | MedianAge | MedianAgeMale | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3.047000e+03 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | ... | 2438.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 |
mean | 606.338544 | 185.965868 | 178.664063 | 448.268586 | 47063.281917 | 1.026374e+05 | 16.878175 | 155.399415 | 45.272333 | 39.570725 | ... | 48.453774 | 41.196324 | 36.252642 | 19.240072 | 83.645286 | 9.107978 | 1.253965 | 1.983523 | 51.243872 | 5.640306 |
std | 1416.356223 | 504.134286 | 27.751511 | 54.560733 | 12040.090836 | 3.290592e+05 | 6.409087 | 529.628366 | 45.304480 | 5.226017 | ... | 10.083006 | 9.447687 | 7.841741 | 6.113041 | 16.380025 | 14.534538 | 2.610276 | 3.517710 | 6.572814 | 1.985816 |
min | 6.000000 | 3.000000 | 59.700000 | 201.300000 | 22640.000000 | 8.270000e+02 | 3.200000 | 0.000000 | 22.300000 | 22.400000 | ... | 15.700000 | 13.500000 | 11.200000 | 2.600000 | 10.199155 | 0.000000 | 0.000000 | 0.000000 | 22.992490 | 0.000000 |
25% | 76.000000 | 28.000000 | 161.200000 | 420.300000 | 38882.500000 | 1.168400e+04 | 12.150000 | 0.000000 | 37.700000 | 36.350000 | ... | 41.000000 | 34.500000 | 30.900000 | 14.850000 | 77.296180 | 0.620675 | 0.254199 | 0.295172 | 47.763063 | 4.521419 |
50% | 171.000000 | 61.000000 | 178.100000 | 453.549422 | 45207.000000 | 2.664300e+04 | 15.900000 | 0.000000 | 41.000000 | 39.600000 | ... | 48.700000 | 41.100000 | 36.300000 | 18.800000 | 90.059774 | 2.247576 | 0.549812 | 0.826185 | 51.669941 | 5.381478 |
75% | 518.000000 | 149.000000 | 195.200000 | 480.850000 | 52492.000000 | 6.867100e+04 | 20.400000 | 83.650776 | 44.000000 | 42.500000 | ... | 55.600000 | 47.700000 | 41.550000 | 23.100000 | 95.451693 | 10.509732 | 1.221037 | 2.177960 | 55.395132 | 6.493677 |
max | 38150.000000 | 14010.000000 | 362.800000 | 1206.900000 | 125635.000000 | 1.017029e+07 | 47.400000 | 9762.308998 | 624.000000 | 64.700000 | ... | 78.900000 | 70.700000 | 65.100000 | 46.600000 | 100.000000 | 85.947799 | 42.619425 | 41.930251 | 78.075397 | 21.326165 |
8 rows × 32 columns
_, ax = plt.subplots(1, 1, figsize=(10, 5))
df['TARGET_deathRate'].hist(bins=50, ax=ax);
ax.set_title('Distribution of cancer death rate per 100,000 people');
ax.set_xlabel('Cancer death rate in county (per 100,000 people)');
ax.set_ylabel('Count');
target = 'TARGET_deathRate'
features = [
col for col in df.columns
if col not in [
target,
'Geography', # Label describing the county - each row has a different one
'binnedInc', # Redundant with median income?
'PctSomeCol18_24', # contains null values - ignoring for now
'PctEmployed16_Over', # contains null values - ignoring for now
'PctPrivateCoverageAlone', # contains null values - ignoring for now
]
]
print(len(features), 'features')
28 features
x = df[features].values
y = df[[target]].values
print(x.shape, y.shape)
(3047, 28) (3047, 1)
class DeepNormalModel(torch.nn.Module):
def __init__(
self,
n_inputs,
n_hidden,
x_scaler,
y_scaler,
):
super().__init__()
self.x_scaler = x_scaler
self.y_scaler = y_scaler
self.jitter = 1e-6
self.shared = torch.nn.Linear(n_inputs, n_hidden)
self.mean_hidden = torch.nn.Linear(n_hidden, n_hidden)
self.mean_linear = torch.nn.Linear(n_hidden, 1)
self.std_hidden = torch.nn.Linear(n_hidden, n_hidden)
self.std_linear = torch.nn.Linear(n_hidden, 1)
self.dropout = torch.nn.Dropout()
def forward(self, x):
# Normalization
shared = self.x_scaler(x)
# Shared layer
shared = self.shared(shared)
shared = F.relu(shared)
shared = self.dropout(shared)
# Parametrization of the mean
mean_hidden = self.mean_hidden(shared)
mean_hidden = F.relu(mean_hidden)
mean_hidden = self.dropout(mean_hidden)
mean = self.mean_linear(mean_hidden)
        # Parametrization of the standard deviation
std_hidden = self.std_hidden(shared)
std_hidden = F.relu(std_hidden)
std_hidden = self.dropout(std_hidden)
std = F.softplus(self.std_linear(std_hidden)) + self.jitter
return torch.distributions.Normal(mean, std)
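Because forward returns a torch.distributions.Normal rather than a plain tensor, the predicted mean, standard deviation, and log-likelihood all come from the same object. A hypothetical usage sketch (assuming a fitted model and tensors x_batch and y_scaled):
dist = model(x_batch)    # a torch.distributions.Normal
dist.mean                # predicted mean, in scaled target space
dist.stddev              # predicted standard deviation, also in scaled space
dist.log_prob(y_scaled)  # per-sample log-likelihood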
def compute_loss(model, x, y):
y_scaled = model.y_scaler(y)
y_hat = model(x)
neg_log_likelihood = -y_hat.log_prob(y_scaled)
return torch.mean(neg_log_likelihood)
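Minimizing this loss is Gaussian maximum likelihood: for a Normal with mean mu and standard deviation sigma, -log p(y) = 0.5 * log(2 * pi * sigma^2) + (y - mu)^2 / (2 * sigma^2). A quick sanity check of that identity against log_prob, with made-up numbers:
import math
d = torch.distributions.Normal(torch.tensor(1.0), torch.tensor(2.0))
manual_nll = 0.5 * math.log(2 * math.pi * 2.0**2) + (0.5 - 1.0)**2 / (2 * 2.0**2)
assert torch.isclose(-d.log_prob(torch.tensor(0.5)), torch.tensor(manual_nll))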
def compute_rmse(model, x_test, y_test):
model.eval()
y_hat = model(x_test)
pred = model.y_scaler.inverse_transform(y_hat.mean)
return torch.sqrt(torch.mean((pred - y_test)**2))
def train_one_step(model, optimizer, x_batch, y_batch):
model.train()
optimizer.zero_grad()
loss = compute_loss(model, x_batch, y_batch)
loss.backward()
optimizer.step()
return loss
def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size, scheduler=None, print_every=10):
train_losses, val_losses = [], []
for epoch in range(n_epochs):
batch_indices = sample_batch_indices(x_train, y_train, batch_size)
batch_losses_t, batch_losses_v, batch_rmse_v = [], [], []
for batch_ix in batch_indices:
b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])
model.eval()
b_val_loss = compute_loss(model, x_val, y_val)
b_val_rmse = compute_rmse(model, x_val, y_val)
batch_losses_t.append(b_train_loss.detach().numpy())
batch_losses_v.append(b_val_loss.detach().numpy())
batch_rmse_v.append(b_val_rmse.detach().numpy())
if scheduler is not None:
scheduler.step()
train_loss = np.mean(batch_losses_t)
val_loss = np.mean(batch_losses_v)
val_rmse = np.mean(batch_rmse_v)
train_losses.append(train_loss)
val_losses.append(val_loss)
if epoch == 0 or (epoch + 1) % print_every == 0:
print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')
_, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss')
ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Overview')
ax.legend()
return train_losses, val_losses
def sample_batch_indices(x, y, batch_size, rs=None):
if rs is None:
rs = np.random.RandomState()
train_ix = np.arange(len(x))
rs.shuffle(train_ix)
n_batches = int(np.ceil(len(x) / batch_size))
batch_indices = []
for i in range(n_batches):
        start = i * batch_size
end = start + batch_size
batch_indices.append(
train_ix[start:end].tolist()
)
return batch_indices
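For example, with 10 samples and a batch size of 4 this yields shuffled index lists of sizes 4, 4, and 2 (the toy arrays below are made up for illustration):
demo_x = np.zeros((10, 2))
demo_y = np.zeros((10, 1))
demo_batches = sample_batch_indices(demo_x, demo_y, batch_size=4, rs=np.random.RandomState(0))
print([len(b) for b in demo_batches])  # [4, 4, 2]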
class StandardScaler(object):
"""
Standardize data by removing the mean and scaling to unit variance.
"""
def __init__(self):
self.mean = None
self.scale = None
def fit(self, sample):
self.mean = sample.mean(0, keepdim=True)
self.scale = sample.std(0, unbiased=False, keepdim=True)
return self
def __call__(self, sample):
return self.transform(sample)
def transform(self, sample):
return (sample - self.mean) / self.scale
def inverse_transform(self, sample):
return sample * self.scale + self.mean
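Since transform and inverse_transform are exact affine inverses, a quick round-trip check on a toy tensor (hypothetical values):
t = torch.randn(5, 3)
demo_scaler = StandardScaler().fit(t)
assert torch.allclose(demo_scaler.inverse_transform(demo_scaler.transform(t)), t)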
def compute_train_test_split(x, y, test_size):
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size)
return (
torch.from_numpy(x_train),
torch.from_numpy(x_test),
torch.from_numpy(y_train),
torch.from_numpy(y_test),
)
x_train, x_val, y_train, y_val = compute_train_test_split(x, y, test_size=0.2)
print(x_train.shape, y_train.shape)
print(x_val.shape, y_val.shape)
torch.Size([2437, 28]) torch.Size([2437, 1])
torch.Size([610, 28]) torch.Size([610, 1])
x_scaler = StandardScaler().fit(torch.from_numpy(x))
y_scaler = StandardScaler().fit(torch.from_numpy(y))
%%time
learning_rate = 1e-3
momentum = 0.9
weight_decay = 1e-5
n_epochs = 300
batch_size = 64
print_every = 50
n_hidden = 100
model = DeepNormalModel(
n_inputs=x.shape[1],
n_hidden=n_hidden,
x_scaler=x_scaler,
y_scaler=y_scaler,
)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{pytorch_total_params:,} trainable parameters')
print()
optimizer = torch.optim.SGD(
model.parameters(),
lr=learning_rate,
momentum=momentum,
nesterov=True,
weight_decay=weight_decay,
)
scheduler = None
train_losses, val_losses = train(
model,
optimizer,
x_train,
x_val,
y_train,
y_val,
n_epochs=n_epochs,
batch_size=batch_size,
scheduler=scheduler,
print_every=print_every,
)
23,302 trainable parameters
Epoch 1 | Validation loss = 1.3502 | Validation RMSE = 25.9476
Epoch 50 | Validation loss = 0.9483 | Validation RMSE = 18.4543
Epoch 100 | Validation loss = 0.9232 | Validation RMSE = 18.0808
Epoch 150 | Validation loss = 0.9703 | Validation RMSE = 18.5300
Epoch 200 | Validation loss = 0.8858 | Validation RMSE = 17.6373
Epoch 250 | Validation loss = 0.9364 | Validation RMSE = 18.0332
Epoch 300 | Validation loss = 0.9045 | Validation RMSE = 17.4502
CPU times: user 57.7 s, sys: 6.7 s, total: 1min 4s
Wall time: 1min
y_dist = model(x_val)
y_hat = model.y_scaler.inverse_transform(y_dist.mean)
val_rmse = float(compute_rmse(model, x_val, y_val).detach().numpy())
print(f'Validation RMSE = {val_rmse:.2f}')
Validation RMSE = 17.46
val_r2 = r2_score(
y_val.detach().numpy().flatten(),
y_hat.detach().numpy().flatten(),
)
print(f'Validation $R^2$ = {val_r2:.2f}')
Validation $R^2$ = 0.56
def plot_results(y_true, y_pred):
f, ax = plt.subplots(1, 1, figsize=(7, 7))
palette = sns.color_palette()
min_value = min(np.amin(y_true), np.amin(y_pred))
max_value = max(np.amax(y_true), np.amax(y_pred))
y_mid = np.linspace(min_value, max_value)
ax.plot(y_mid, y_mid, '--', color=palette[1])
ax.scatter(y_true, y_pred, color=palette[0], alpha=0.5);
return f, ax
f, ax = plot_results(
y_val.detach().numpy().flatten(),
y_hat.detach().numpy().flatten(),
);
ax.text(225, 95, f'$R^2 = {val_r2:.2f}$')
ax.text(225, 80, f'$RMSE = {val_rmse:.2f}$')
ax.set_xlabel('Actuals');
ax.set_ylabel('Predictions');
ax.set_title('Regression results on validation set');
def make_predictions(model, x):
    model.eval()  # disable dropout so predictions are deterministic
    dist = model(x)
    inv_tr = model.y_scaler.inverse_transform
    y_hat = inv_tr(dist.mean)
    # Recover the standard deviation's original scale
    std = inv_tr(dist.mean + dist.stddev) - y_hat
    return y_hat, std
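The subtraction works because inverse_transform is affine: inv_tr(mean + std) - inv_tr(mean) = std * scale, i.e. the predicted standard deviation simply gets rescaled into the original target units. A quick check of that identity against the scaler (reusing the fitted model and x_val):
dist = model(x_val)
lhs = model.y_scaler.inverse_transform(dist.mean + dist.stddev) - model.y_scaler.inverse_transform(dist.mean)
assert torch.allclose(lhs, dist.stddev * model.y_scaler.scale)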
y_hat, std = make_predictions(model, x_val)
plt.hist(y_hat.detach().numpy().flatten(), bins=50);
plt.hist(std.detach().numpy().flatten(), bins=50);
absolute_errors = torch.abs(y_hat - y_val).detach().numpy().flatten()
stds = std.detach().numpy().flatten()
err_df = pd.DataFrame.from_dict({  # renamed so we don't clobber the dataset df
    'error': absolute_errors,
    'uncertainty': 2 * stds,  # two standard deviations
})
err_df.corr()
 | error | uncertainty
---|---|---
error | 1.000000 | 0.332653
uncertainty | 0.332653 | 1.000000
def plot_absolute_error_vs_uncertainty(y_val, y_hat, std, binned):
absolute_errors = torch.abs(y_hat - y_val).detach().numpy().flatten()
stds = std.detach().numpy().flatten()
if binned:
ret = binned_statistic(absolute_errors, stds, bins=100)
a = ret.bin_edges[1:]
b = ret.statistic
alpha = 1.0
else:
a = absolute_errors
b = stds
alpha = 0.3
f, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.scatter(a, b, alpha=alpha)
return f, ax
f, ax = plot_absolute_error_vs_uncertainty(y_val, y_hat, std, binned=True)
ax.set_title('Absolute error vs uncertainty (binned & averaged)');
ax.set_xlabel('Absolute error (binned)');
ax.set_ylabel('Uncertainty (average within bin)');
def plot_results_with_uncertainty(y_true, y_pred, y_pred_std, subset):
_, ax = plt.subplots(1, 1, figsize=(7, 7))
palette = sns.color_palette()
min_value = min(np.amin(y_true), np.amin(y_pred))
max_value = max(np.amax(y_true), np.amax(y_pred))
y_mid = np.linspace(min_value, max_value)
ax.plot(y_mid, y_mid, '--', color=palette[1])
ax.scatter(y_true, y_pred, color=palette[0], alpha=0.1);
ax.errorbar(y_true[subset], y_pred[subset], yerr=y_pred_std[subset], color=palette[0], fmt='o')
return ax
ix = [i for i, v in enumerate(stds) if v > stds.mean() + 3 * stds.std()]
plot_results_with_uncertainty(
y_val.detach().numpy().flatten(),
y_hat.detach().numpy().flatten(),
std.detach().numpy().flatten(),
subset=ix,
);
To probe the model on out-of-distribution data, we'll make up an input set to twice the maximum of each feature across the validation set.
# Twice the per-feature maximum over the validation set
random_input = torch.amax(x_val, dim=0, keepdim=True) * 2
random_input.shape
torch.Size([1, 28])
_, std = make_predictions(model, random_input)
uncertainty = float(std.detach().numpy())
print(f'Uncertainty on made up input: {uncertainty:.2f}')
Uncertainty on made up input: 19.59
Uncertainty is high on this out-of-distribution input, which is exactly what we want!
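For context (reusing the validation-set standard deviations stored in stds above), this value can be compared with the average predicted uncertainty on in-distribution data, which it should comfortably exceed:
print(f'Mean predicted std on validation set: {stds.mean():.2f}')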