# Linear Regression - MLE¶

Implementation of the classic linear regression model. Weights are fitted with Maximum Likelihood Estimation using PyTorch distributions.

$y \sim \mathcal{N}(\alpha x + \beta, \sigma)$

## Setting up the environment¶

In [1]:
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

cwd = os.getcwd()
if cwd.endswith('notebook'):
os.chdir('..')
cwd = os.getcwd()

In [2]:
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()

In [3]:
seed = 444
np.random.seed(seed);
torch.manual_seed(seed);


## Generate dataset¶

In [4]:
α_actual = 2.6
β_actual = 3.3
σ_actual = 0.7

In [5]:
def generate_samples(α, β, σ, min_x=-1, max_x=1, n_samples=500):
x = np.linspace(min_x, max_x, n_samples)[:, np.newaxis]
y = α * x + β
dist = torch.distributions.Normal(torch.from_numpy(y), σ)
y_sample = dist.sample()
return x, y, y_sample.detach().numpy()

In [6]:
def plot_line(x, y, y_sample):
f, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.plot(x.flatten(), y.flatten(), '-', color=palette[2], linewidth=3)
ax.scatter(x.flatten(), y_sample.flatten(), color=palette[0], alpha=0.8)
ax.set_xlabel('input x')
ax.set_ylabel('output y')
ax.set_title(r'$y \sim N(\alpha x + \beta, \sigma)$')
return f, ax

In [7]:
x, y, y_sample = generate_samples(α_actual, β_actual, σ_actual)

In [8]:
f, _ = plot_line(x, y, y_sample);

In [9]:
x.shape

Out[9]:
(500, 1)

## Define model¶

In [10]:
class LinearNormal(torch.nn.Module):

def __init__(self):
super().__init__()
self.α = torch.nn.Parameter(torch.randn(()))
self.β = torch.nn.Parameter(torch.randn(()))
self.s = torch.nn.Parameter(torch.randn(()))

@property
def sigma(self):
return F.softplus(self.s)  # ensure σ > 0

def forward(self, x):
m = self.α * x + self.β
σ = self.sigma

In [11]:
def compute_loss(model, x, y):
out_dist = model(x)
neg_log_likelihood = -out_dist.log_prob(y)

def compute_rmse(model, x_test, y_test):
model.eval()
pred = model(x_test).sample()

def predict(model, x):
model.eval()
out_dist = model(x)
return out_dist.mean, out_dist.stddev, out_dist


## Fit linear model¶

In [12]:
def train_one_step(model, optimizer, x_batch, y_batch):
model.train()
loss = compute_loss(model, x_batch, y_batch)
loss.backward()
optimizer.step()
return loss

In [13]:
def train(model, optimizer, x_train, x_val, y_train, y_val, n_epochs, batch_size=64, print_every=10):
train_losses, val_losses = [], []
for epoch in range(n_epochs):
batch_indices = sample_batch_indices(x_train, y_train, batch_size)

batch_losses_t, batch_losses_v, batch_rmse_v = [], [], []
for batch_ix in batch_indices:
b_train_loss = train_one_step(model, optimizer, x_train[batch_ix], y_train[batch_ix])

model.eval()
b_val_loss = compute_loss(model, x_val, y_val)
b_val_rmse = compute_rmse(model, x_val, y_val)

batch_losses_t.append(b_train_loss.detach().numpy())
batch_losses_v.append(b_val_loss.detach().numpy())
batch_rmse_v.append(b_val_rmse.detach().numpy())

train_loss = np.mean(batch_losses_t)
val_loss = np.mean(batch_losses_v)
val_rmse = np.mean(batch_rmse_v)

train_losses.append(train_loss)
val_losses.append(val_loss)

if epoch == 0 or (epoch + 1) % print_every == 0:
print(f'Epoch {epoch+1} | Validation loss = {val_loss:.4f} | Validation RMSE = {val_rmse:.4f}')

_, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.plot(range(1, n_epochs + 1), train_losses, label='Train loss')
ax.plot(range(1, n_epochs + 1), val_losses, label='Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Overview')
ax.legend()

return train_losses, val_losses

def sample_batch_indices(x, y, batch_size, rs=None):
if rs is None:
rs = np.random.RandomState()

train_ix = np.arange(len(x))
rs.shuffle(train_ix)

n_batches = int(np.ceil(len(x) / batch_size))

batch_indices = []
for i in range(n_batches):
start = i + batch_size
end = start + batch_size
batch_indices.append(
train_ix[start:end].tolist()
)

return batch_indices

In [14]:
def compute_train_test_split(x, y, test_size):
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size)
return (
torch.from_numpy(x_train),
torch.from_numpy(x_test),
torch.from_numpy(y_train),
torch.from_numpy(y_test),
)

In [15]:
model = LinearNormal()

optimizer = torch.optim.SGD(model.parameters(), lr=5e-3)

In [16]:
x_train, x_test, y_train, y_test = compute_train_test_split(x, y_sample, test_size=0.2)

In [17]:
train(model, optimizer, x_train, x_test, y_train, y_test, n_epochs=2000, print_every=200);

Epoch 1 | Validation loss = 2.9909 | Validation RMSE = 3.8504
Epoch 200 | Validation loss = 2.0633 | Validation RMSE = 2.9213
Epoch 400 | Validation loss = 1.5048 | Validation RMSE = 1.6410
Epoch 600 | Validation loss = 1.1390 | Validation RMSE = 1.0149
Epoch 800 | Validation loss = 1.1402 | Validation RMSE = 1.0237
Epoch 1000 | Validation loss = 1.1451 | Validation RMSE = 0.9709
Epoch 1200 | Validation loss = 1.1471 | Validation RMSE = 1.0116
Epoch 1400 | Validation loss = 1.1419 | Validation RMSE = 1.0203
Epoch 1600 | Validation loss = 1.1433 | Validation RMSE = 1.0337
Epoch 1800 | Validation loss = 1.1448 | Validation RMSE = 1.0112
Epoch 2000 | Validation loss = 1.1406 | Validation RMSE = 1.0301


## Validation¶

In [18]:
y_pred, _, y_dist = predict(model, x_test)
y_pred_sample = y_dist.sample()

In [19]:
val_rmse = float(compute_rmse(model, x_test, y_test).detach().numpy())
print(f'Validation RMSE = {val_rmse}')

Validation RMSE = 0.9420654411386103

In [20]:
val_r2 = r2_score(y_test.detach().numpy(), y_pred_sample.detach().numpy())
print(f'Validation R squared = {val_r2}')

Validation R squared = 0.6884557510131627


## Plot results¶

In [21]:
def plot_results(x, y, y_sample, y_pred, std):
f, ax = plt.subplots(1, 1, figsize=(12, 6))

x_arg_sort = x.flatten().argsort()
xx = x.flatten()[x_arg_sort]
v = y_pred.flatten()[x_arg_sort]
std_val = std.flatten()[0]
v_err_min = v - 2 * std_val
v_err_max = v + 2 * std_val

ax.fill_between(
xx,
v_err_min ,
v_err_max,
color=palette[1],
alpha=0.2,
label='2$\sigma$ error',
)

ax.plot(xx, y.flatten()[x_arg_sort], '-', color=palette[0], linewidth=2, label='Actual')
ax.scatter(xx, y_sample.flatten()[x_arg_sort], color=palette[0])

ax.plot(xx, v, '-', color=palette[1], linewidth=2, label='Predicted')

ax.set_title('Predictions on the validation set')
ax.set_xlabel('input x')
ax.set_ylabel('output y')
ax.legend()
return f, ax

In [22]:
f, ax = plot_results(
x_test.detach().numpy(),
α_actual * x_test.detach().numpy() + β_actual,
y_test.detach().numpy(),
y_pred.detach().numpy(),
y_dist.stddev.detach().numpy(),
)

In [23]:
α_hat = float(model.α.detach().numpy())
β_hat = float(model.β.detach().numpy())
σ_hat = float(y_dist.stddev.detach().numpy()[0])

print(f'Actual α = {α_actual:.1f} | Predicted α = {α_hat:.1f}')
print(f'Actual β = {β_actual:.1f} | Predicted β = {β_hat:.1f}')
print(f'Actual σ = {σ_actual:.1f} | Predicted σ = {σ_hat:.1f}')

Actual α = 2.6 | Predicted α = 2.6
Actual β = 3.3 | Predicted β = 3.3
Actual σ = 0.7 | Predicted σ = 0.7

In [ ]: