-- A sample script demonstrating naive (Gaussian) initialization vs. Xavier
-- initialization with the Tanh activation function.
require 'torch'
require 'nn'
Plot = require 'itorch.Plot'
-- Build a deep fully-connected network: an input layer (50 -> 500) followed
-- by 19 hidden layers (500 -> 500), each Linear followed by a Tanh
-- non-linearity (20 Linear/Tanh pairs in total).
model = nn.Sequential()
model:add(nn.Linear(50, 500))
model:add(nn.Tanh())
for layer = 2, 20 do
  model:add(nn.Linear(500, 500))
  model:add(nn.Tanh())
end
-- Naive initialization: draw every weight from N(0, 0.01), regardless of
-- the layer's size.
for _, linear in ipairs(model:findModules('nn.Linear')) do
  linear.weight:normal(0, 0.01)
end
-- Sample a batch of 10 random 50-dimensional inputs and run it through the
-- network. NOTE: the original code used torch.Tensor(10, 50), which returns
-- UNINITIALIZED memory -- the activation statistics computed below would be
-- meaningless garbage. torch.randn fills the tensor with N(0, 1) samples.
x = torch.randn(10, 50)
model:forward(x)
-- Plot an activation histogram for each of the first 10 Tanh layers.
-- In the Sequential container the Tanh modules sit at even indices
-- (Linear at odd, Tanh at even), so modules 2, 4, ..., 20 are the first
-- ten Tanh outputs.
std = torch.Tensor(10)
mean = torch.Tensor(10)
for idx = 2, 20, 2 do
  local layer = idx / 2
  out = model.modules[idx].output
  plot = Plot():histogram(out, 100, -1, 1):draw()
  plot:title(('Activation histogram: Layer %d'):format(layer)):redraw()
  plot:xaxis('Value'):yaxis('Number'):redraw()
  std[layer] = out:std()
  mean[layer] = out:mean()
end
-- Per-layer standard deviation of the activations.
plot = Plot()
plot:line(torch.range(1, 10), std)
plot:draw()
plot:title('Variance'):redraw()
plot:xaxis('Layer'):redraw()
-- Per-layer mean of the activations.
plot = Plot()
plot:line(torch.range(1, 10), mean)
plot:draw()
plot:title('Mean'):redraw()
plot:xaxis('Layer'):redraw()
-- Step 3: Use the Xavier method to initialize the weights, then plot the
-- activation of each Tanh layer.
-- Xavier initialization: weights ~ N(0, 1/fan_in), which keeps the variance
-- of the activations roughly constant from layer to layer with Tanh.
for _, linear in ipairs(model:findModules('nn.Linear')) do
  -- nn.Linear stores its weight as an (outputSize x inputSize) matrix, so
  -- the fan-in (number of inputs feeding each unit) is size(2). The
  -- original code used size(1), i.e. the fan-out, which is wrong for the
  -- first layer (500 instead of 50) and for any non-square layer.
  local fan_in = linear.weight:size(2)
  linear.weight:normal(0, 1 / math.sqrt(fan_in))
end
-- Re-run the forward pass with fresh random inputs. NOTE: the original
-- torch.Tensor(10, 50) returns UNINITIALIZED memory; torch.randn gives
-- proper N(0, 1) samples so the histograms below are meaningful.
x = torch.randn(10, 50)
model:forward(x)
-- Plot an activation histogram for each of the first 10 Tanh layers
-- (Tanh modules sit at the even indices 2, 4, ..., 20 of the container).
std = torch.Tensor(10)
mean = torch.Tensor(10)
for idx = 2, 20, 2 do
  local layer = idx / 2
  out = model.modules[idx].output
  plot = Plot():histogram(out, 100, -1, 1):draw()
  plot:title(('Activation histogram: Layer %d'):format(layer)):redraw()
  plot:xaxis('Value'):yaxis('Number'):redraw()
  std[layer] = out:std()
  mean[layer] = out:mean()
end
-- Per-layer standard deviation of the activations.
plot = Plot()
plot:line(torch.range(1, 10), std)
plot:draw()
plot:title('Variance'):redraw()
plot:xaxis('Layer'):redraw()
-- Per-layer mean of the activations.
plot = Plot()
plot:line(torch.range(1, 10), mean)
plot:draw()
plot:title('Mean'):redraw()
plot:xaxis('Layer'):redraw()