%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.autograd import Variable
from tqdm import tqdm, trange
import copy
from sklearn.metrics import mean_squared_error
torch.__version__
'1.6.0+cu101'
sns.set_context("paper")
sns.set_style("ticks")
np.random.seed(101217209)
torch.manual_seed(101217209)
<torch._C.Generator at 0x7f734e3adfd8>
def get_data(N, min_x, max_x):
w, b = np.random.randn(2)
def true_model(X):
lin_factor = w * X + b
y = np.sin(lin_factor*10)*(lin_factor**2)
return y
X_true = np.arange(min_x, max_x, 0.01)
y_true = true_model(X_true)
span = (max_x - min_x)
scale = 0.2
X_obs = min_x + span*scale + np.random.rand(N)*(span - 2*scale*span)
y_obs = true_model(X_obs) + np.random.randn(N)*0.1
y_mean, y_std = y_obs.mean(), y_obs.std()
y_obs = (y_obs - y_mean)/y_std
y_true = (y_true - y_mean)/y_std
return (X_obs, y_obs, X_true, y_true), (w, b, true_model)
N = 40
min_x, max_x = -1, 1
(X_obs, y_obs, X_true, y_true), (w, b, true_model) = get_data(N, min_x, max_x)
class SimpleModel(torch.nn.Module):
def __init__(self, p=0.05, decay=0.001, non_linearity=torch.nn.ReLU):
super(SimpleModel, self).__init__()
self.dropout_p = p
self.decay = decay
self.f = torch.nn.Sequential(
torch.nn.Linear(1,20),
torch.nn.ReLU(),
torch.nn.Dropout(p=self.dropout_p),
torch.nn.Linear(20, 20),
non_linearity(),
torch.nn.Dropout(p=self.dropout_p),
torch.nn.Linear(20,1)
)
def forward(self, X):
X = Variable(torch.Tensor(X), requires_grad=False)
return self.f(X)
def ewc_criterion(model, fisher_matrix, prior_params, alpha=1):
ewc_loss = Variable(torch.zeros(1))
for p, fm, prior in zip(model.parameters(), fisher_matrix, prior_params):
param_loss = torch.pow(p - Variable(prior, requires_grad=False), 2)
fisher_scaled_loss = alpha*Variable(fm, requires_grad=False).mul(param_loss)
ewc_loss += fisher_scaled_loss.sum()
return ewc_loss[0] # Scalar
def get_criterion_optimizer(model):
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(
model.parameters(),
#lr=0.001,
#momentum=0.5,
weight_decay=model.decay)
return criterion, optimizer
def plot_predictions(model, model_simple, ax=None, uncertainty=0, N=0, l2=0.005):
if ax is None:
fig, ax = plt.subplots(1,1)
ax.plot(X_obs[:10], y_obs[:10], ls="none", marker="o", color="b", label="task 1")
ax.plot(X_obs[10:20], y_obs[10:20], ls="none", marker="d", color="cyan", label="task 2")
ax.plot(X_obs[20:], y_obs[20:], ls="none", marker="s", color="magenta", label="task 3")
ax.plot(X_true, y_true, ls="-", lw=1, color="0.5", label="true")
for model_t, model_name, color in [
(model, "ewc", "r"),
(model_simple, "simple", "g")
]:
if uncertainty > 0:
model_t.train()
y_preds = []
for i in trange(uncertainty):
y_pred = model_t(X_true[:, np.newaxis]).data.numpy()
y_preds.append(y_pred)
y_preds = np.hstack(y_preds)
y_pred = y_preds.mean(axis=1)
y_var = y_preds.var(axis=1)
tau = l2 * (model_t.dropout_p) / (2*N*model_t.decay)
y_var += (1/tau)
y_std = np.sqrt(y_var) #+ (1/tau)
for i in range(4):
ax.fill_between(
X_true,
y_pred - y_std * ((i+1)/2),
y_pred + y_std * ((i+1)/2),
color=color,
alpha=0.1**(i+1)
)
else:
model_t.eval()
y_pred = model_t(X_true[:, np.newaxis]).data.numpy()
mse = mean_squared_error(y_true, y_pred)
ax.plot(X_true, y_pred, ls="--", lw=1, color=color,
label="{} ({:.3f})".format(model_name, mse))
ax.legend()
ax.set_xlabel("X")
ax.set_ylabel("Y")
sns.despine(offset=10)
return ax
def plot_losses(losses, losses_simple, ax=None):
if ax is None:
fig, ax = plt.subplots(1,1,figsize=(10,4))
ax.plot(losses, lw=0.1, color='r', label='ewc')
ax.plot(losses_simple, lw=0.1, color='0.5', label='simple')
ax.legend()
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
sns.despine(offset=10)
return ax
def get_fisher_matrix(model, criterion, X, Y, old_fisher_matrix=None):
fisher_matrix = [torch.zeros(p.data.shape) for p in model.parameters()]
counts = 0
#print(X.shape, Y.shape)
for x,y in tqdm(zip(X[:, np.newaxis, np.newaxis], Y[:, np.newaxis])):
criterion.zero_grad()
#print(x.shape, y.shape)
y_pred = model(x)
#print(y_pred.data.shape)
loss = criterion(y_pred, Variable(torch.Tensor(y), requires_grad=False))
loss.backward()
for i, p in enumerate(model.parameters()):
grad_squared = torch.pow(p.grad.data, 2)
#print(grad_squared.shape, fisher_matrix[i].shape)
fisher_matrix[i] += grad_squared
counts += 1
#print(grad_squared.shape, fisher_matrix[i].shape)
## Calculate the means
fisher_matrix = [k/counts for k in fisher_matrix]
if old_fisher_matrix:
fisher_matrix = [k + o for k,o in zip(fisher_matrix, old_fisher_matrix)]
return fisher_matrix
def train_step(
model, X, Y, criterion, optimizer,
fisher_matrix=None, model_params=None,
ecw_alpha=0, clip=10):
y_pred = model(X[:, np.newaxis])
loss = criterion(y_pred, Variable(torch.Tensor(Y), requires_grad=False))
train_loss = loss.detach().numpy().item()
if ecw_alpha > 0:
ewc_loss = ewc_criterion(model, fisher_matrix, model_params, alpha=ecw_alpha)
loss += ewc_loss
loss.backward()
torch.nn.utils.clip_grad_norm(model.parameters(),clip)
optimizer.step()
return train_loss
def train_plot_compare(X_obs, y_obs, prefix="IID", clip=10, uncertainty=0, l2=0.01, **kwargs):
model = SimpleModel(**kwargs)
model_simple = copy.deepcopy(model)
fig, ax = plt.subplots(1,3, sharex=True, sharey=True, figsize=(18, 6))
fig_loss, ax_loss = plt.subplots(3,1, sharex=True, sharey=False, figsize=(10, 12))
## Task 1
fisher_matrix = None
model_params = None
N = 0
for i, X, Y, alpha in zip(
range(3),
[X_obs[:10], X_obs[10:20], X_obs[20:]],
[y_obs[:10], y_obs[10:20], y_obs[20:]],
[0, 1e6, 1e6] # Alpha should increase with tasks
):
title = "Task {}".format(" + ".join(str(k+1) for k in range(i+1)))
model.train()
criterion, optimizer = get_criterion_optimizer(model)
model_simple.train()
criterion_simple, optimizer_simple = get_criterion_optimizer(model_simple)
losses = []
losses_simple = []
for epoch in trange(10000):
loss = train_step(
model, X, Y,
criterion, optimizer,
fisher_matrix=fisher_matrix,
model_params=model_params,
ecw_alpha=alpha, clip=clip
)
losses.append(loss)
loss = train_step(
model_simple, X, Y, criterion_simple, optimizer_simple, clip=clip)
losses_simple.append(loss)
plot_losses(losses, losses_simple, ax=ax_loss[i])
N = X.shape[0]
plot_predictions(model, model_simple, ax=ax[i], uncertainty=uncertainty, N=N, l2=l2)
ax_loss[i].set_title(title)
ax[i].set_title(title)
print(title)
model.eval()
fisher_matrix = get_fisher_matrix(model, criterion, X, Y, old_fisher_matrix=fisher_matrix)
model_params = [copy.deepcopy(p.data) for p in model.parameters()]
fig.tight_layout()
fig_loss.tight_layout()
sns.despine(offset=10, fig=fig)
sns.despine(offset=10, fig=fig_loss)
fig.savefig("{}_EWC_predictions.pdf".format(prefix), bbox_inches="tight")
fig.savefig("{}_EWC_losses.pdf".format(prefix), bbox_inches="tight")
return model, model_simple
plt.plot(X_obs[:10], y_obs[:10], ls="none", marker="o", color="b", label="task 1")
plt.plot(X_obs[10:20], y_obs[10:20], ls="none", marker="d", color="cyan", label="task 2")
plt.plot(X_obs[20:], y_obs[20:], ls="none", marker="s", color="magenta", label="task 3")
plt.plot(X_true, y_true, ls="-", lw=1, color="0.5", label="true")
plt.legend()
sns.despine(offset=10)
model, model_simple = train_plot_compare(
X_obs, y_obs, prefix="IID",
p=0.1, decay=1e-6, non_linearity=torch.nn.ReLU)
0%| | 0/10000 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([10])) that is different to the input size (torch.Size([10, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:128: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 100%|██████████| 10000/10000 [00:24<00:00, 412.10it/s] 0it [00:00, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) 10it [00:00, 1117.08it/s] 0%| | 0/10000 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([10])) that is different to the input size (torch.Size([10, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:128: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 0%| | 30/10000 [00:00<00:34, 292.75it/s]
Task 1
100%|██████████| 10000/10000 [00:32<00:00, 311.32it/s] 0it [00:00, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) 10it [00:00, 1062.31it/s] 0%| | 0/10000 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([20])) that is different to the input size (torch.Size([20, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:128: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 0%| | 30/10000 [00:00<00:34, 292.95it/s]
Task 1 + 2
100%|██████████| 10000/10000 [00:32<00:00, 304.10it/s] 0it [00:00, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) 20it [00:00, 1407.79it/s]
Task 1 + 2 + 3
"""
model.eval()
fisher_matrix = get_fisher_matrix(model, criterion, X_obs[:10], y_obs[:10])
model_params = [copy.deepcopy(p.data) for p in model.parameters()]
fig = plt.figure(0)
ax = fig.add_subplot(111)
el = plt.matplotlib.patches.Ellipse(
xy=[
model_params[0].squeeze().numpy()[0],
model_params[1].numpy()[0]
],
width=2./fisher_matrix[0].squeeze().numpy()[0],
height=2./fisher_matrix[1].numpy()[0])
ax.add_artist(el)
#el.set_clip_box(ax.bbox)
el.set_color('0.5')
el.set_alpha(0.1)
ax.scatter(model_params[0].squeeze().numpy(), model_params[1].numpy(), s=1)
"""
"\nmodel.eval()\nfisher_matrix = get_fisher_matrix(model, criterion, X_obs[:10], y_obs[:10])\nmodel_params = [copy.deepcopy(p.data) for p in model.parameters()]\n\nfig = plt.figure(0)\nax = fig.add_subplot(111)\n\n\nel = plt.matplotlib.patches.Ellipse(\n xy=[\n model_params[0].squeeze().numpy()[0],\n model_params[1].numpy()[0]\n ],\n width=2./fisher_matrix[0].squeeze().numpy()[0],\n height=2./fisher_matrix[1].numpy()[0])\nax.add_artist(el)\n#el.set_clip_box(ax.bbox)\nel.set_color('0.5')\nel.set_alpha(0.1)\nax.scatter(model_params[0].squeeze().numpy(), model_params[1].numpy(), s=1)\n"
sorted_idx = np.argsort(X_obs)
X_obs = X_obs[sorted_idx]
y_obs = y_obs[sorted_idx]
plt.plot(X_obs[:10], y_obs[:10], ls="none", marker="o", color="b", label="task 1")
plt.plot(X_obs[10:20], y_obs[10:20], ls="none", marker="d", color="cyan", label="task 2")
plt.plot(X_obs[20:], y_obs[20:], ls="none", marker="s", color="magenta", label="task 3")
plt.plot(X_true, y_true, ls="-", lw=1, color="0.5", label="true")
plt.legend()
sns.despine(offset=10)
model, model_simple = train_plot_compare(
X_obs, y_obs, prefix="Seq",
p=0.1, decay=1e-6, non_linearity=torch.nn.ReLU)
0%| | 0/10000 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([10])) that is different to the input size (torch.Size([10, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:128: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 100%|██████████| 10000/10000 [00:24<00:00, 410.83it/s] 0it [00:00, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) 10it [00:00, 1205.02it/s] 0%| | 0/10000 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([10])) that is different to the input size (torch.Size([10, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:128: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 0%| | 30/10000 [00:00<00:33, 299.92it/s]
Task 1
100%|██████████| 10000/10000 [00:31<00:00, 313.66it/s] 0it [00:00, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) 10it [00:00, 1024.68it/s] 0%| | 0/10000 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([20])) that is different to the input size (torch.Size([20, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:128: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 0%| | 29/10000 [00:00<00:34, 284.89it/s]
Task 1 + 2
100%|██████████| 10000/10000 [00:32<00:00, 305.85it/s] 0it [00:00, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) 20it [00:00, 1322.19it/s]
Task 1 + 2 + 3
Before each forward pass
$$ \begin{equation} \sigma_{\theta} = F_{\theta}^{-1/2}\\ \epsilon \sim \mathcal{N}(0, 1)\\ \theta = \theta + \sigma*\epsilon \end{equation} $$model, model_simple = train_plot_compare(
X_obs, y_obs, prefix="Seq", uncertainty=100, l2=0.01,
p=0.1, decay=1e-6, non_linearity=torch.nn.ReLU
)
0%| | 0/10000 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([10])) that is different to the input size (torch.Size([10, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:128: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 100%|██████████| 10000/10000 [00:23<00:00, 418.13it/s] 100%|██████████| 100/100 [00:00<00:00, 2087.48it/s] 100%|██████████| 100/100 [00:00<00:00, 2137.25it/s] 0it [00:00, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) 10it [00:00, 1106.21it/s] 0%| | 0/10000 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([10])) that is different to the input size (torch.Size([10, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:128: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 0%| | 30/10000 [00:00<00:33, 295.57it/s]
Task 1
100%|██████████| 10000/10000 [00:32<00:00, 306.91it/s] 100%|██████████| 100/100 [00:00<00:00, 2181.82it/s] 100%|██████████| 100/100 [00:00<00:00, 2066.33it/s] 0it [00:00, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) 10it [00:00, 1113.58it/s] 0%| | 0/10000 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([20])) that is different to the input size (torch.Size([20, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:128: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 0%| | 30/10000 [00:00<00:33, 296.60it/s]
Task 1 + 2
100%|██████████| 10000/10000 [00:35<00:00, 280.99it/s] 100%|██████████| 100/100 [00:00<00:00, 2104.83it/s] 100%|██████████| 100/100 [00:00<00:00, 2165.10it/s] 0it [00:00, ?it/s]/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. return F.mse_loss(input, target, reduction=self.reduction) 20it [00:00, 1247.17it/s]
Task 1 + 2 + 3