%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
input_size = 28*28  # image size
output_size = 10    # 10 classes
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=1000, shuffle=True)
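# The Normalize constants (0.1307, 0.3081) are the MNIST training-set mean and
# std. Optional sanity check (added here for illustration, not part of the
# original notebook): recompute them from the raw pixels.
_raw = datasets.MNIST('../data', train=True, download=True,
                      transform=transforms.ToTensor())
_pixels = torch.stack([img for img, _ in _raw])
print(_pixels.mean().item(), _pixels.std().item())  # ~0.1307, ~0.3081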
# show some images
plt.figure(figsize=(16, 6))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    image, _ = train_loader.dataset[i]  # same as .__getitem__(i)
    plt.imshow(image.squeeze().numpy())
    plt.axis('off');
train_loader.dataset[0][0].squeeze()[15:20, 15:20]
tensor([[ 1.9432,  2.7960,  2.7960,  1.4850, -0.0806],
        [-0.2206,  0.7595,  2.7833,  2.7960,  1.9560],
        [-0.4242, -0.4242,  2.7451,  2.7960,  2.7451],
        [ 1.2305,  1.9051,  2.7960,  2.7960,  2.2105],
        [ 2.7960,  2.7960,  2.7960,  2.7578,  1.8923]])
class Mlinear(nn.Module):
    """
    A linear classifier: a single 28x28 convolution over a 28x28 input,
    which is equivalent to one fully connected layer.
    """
    def __init__(self, input_size, output_size):
        super(Mlinear, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=28)
        # self.fc = nn.Linear(28*28, output_size)  # the equivalent layer

    def forward(self, x, verbose=False):
        x = self.conv(x)           # (N, 10, 1, 1)
        x = x.view(-1, 10)         # (N, 10)
        x = F.log_softmax(x, dim=1)
        return x
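# Why the 28x28 convolution is "linear": sliding a kernel as large as the image
# produces exactly one output per channel, i.e. a dot product with the whole
# image. Optional check (added for illustration): both layers have
# 10*784 + 10 = 7850 parameters.
_conv = nn.Conv2d(1, 10, kernel_size=28)
_fc = nn.Linear(28 * 28, 10)
print(sum(p.numel() for p in _conv.parameters()),
      sum(p.numel() for p in _fc.parameters()))  # 7850 7850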
accuracy_list = []

def train(epoch, model):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # send to device
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for data, target in test_loader:
            # send to device
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    accuracy_list.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))
model = Mlinear(input_size, output_size)
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
print('Number of parameters: {}'.format(get_n_params(model)))  # get_n_params is defined below

for epoch in range(0, 3):
    train(epoch, model)
    test(model)
# results from an earlier run:
# Train Epoch: 0 [57600/60000 (96%)] Loss: 0.122960
# Test set: Average loss: 0.4662, Accuracy: 8823/10000 (88%)
Number of parameters: 14070
Train Epoch: 0 [0/60000 (0%)]	Loss: 2.283597
Train Epoch: 0 [6400/60000 (11%)]	Loss: 0.512199
Train Epoch: 0 [12800/60000 (21%)]	Loss: 0.525875
Train Epoch: 0 [19200/60000 (32%)]	Loss: 0.345186
Train Epoch: 0 [25600/60000 (43%)]	Loss: 0.178718
Train Epoch: 0 [32000/60000 (53%)]	Loss: 0.506424
Train Epoch: 0 [38400/60000 (64%)]	Loss: 0.202364
Train Epoch: 0 [44800/60000 (75%)]	Loss: 0.322217
Train Epoch: 0 [51200/60000 (85%)]	Loss: 0.226272
Train Epoch: 0 [57600/60000 (96%)]	Loss: 0.338484

Test set: Average loss: 0.3061, Accuracy: 9151/10000 (92%)

Train Epoch: 1 [0/60000 (0%)]	Loss: 0.218785
Train Epoch: 1 [6400/60000 (11%)]	Loss: 0.269072
Train Epoch: 1 [12800/60000 (21%)]	Loss: 0.366377
Train Epoch: 1 [19200/60000 (32%)]	Loss: 0.290509
Train Epoch: 1 [25600/60000 (43%)]	Loss: 0.251645
Train Epoch: 1 [32000/60000 (53%)]	Loss: 0.261202
Train Epoch: 1 [38400/60000 (64%)]	Loss: 0.258060
Train Epoch: 1 [44800/60000 (75%)]	Loss: 0.233534
Train Epoch: 1 [51200/60000 (85%)]	Loss: 0.273988
Train Epoch: 1 [57600/60000 (96%)]	Loss: 0.356423

Test set: Average loss: 0.2911, Accuracy: 9176/10000 (92%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.418493
Train Epoch: 2 [6400/60000 (11%)]	Loss: 0.511104
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.163205
Train Epoch: 2 [19200/60000 (32%)]	Loss: 0.419903
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.322113
Train Epoch: 2 [32000/60000 (53%)]	Loss: 0.124521
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.266031
Train Epoch: 2 [44800/60000 (75%)]	Loss: 0.305101
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.322588
Train Epoch: 2 [57600/60000 (96%)]	Loss: 0.277742

Test set: Average loss: 0.2833, Accuracy: 9195/10000 (92%)
# model parameters
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)
conv.weight tensor([...], device='cuda:0')  # shape (10, 1, 28, 28), long output truncated
conv.bias tensor([-0.0448,  0.0074,  0.0315,  0.0370,  0.0048,  0.0503, -0.0397,  0.0155,
         0.0160, -0.0418], device='cuda:0')
param.data
tensor([-0.0448, 0.0074, 0.0315, 0.0370, 0.0048, 0.0503, -0.0397, 0.0155, 0.0160, -0.0418], device='cuda:0')
h = model.conv.weight[2, 0].detach().cpu().numpy()  # the filter for class 2
plt.imshow(h, cmap='Purples_r')
<matplotlib.image.AxesImage at 0x15f891e2588>
# learned class patterns (one 28x28 filter per digit class)
plt.figure(figsize=(16, 6))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    image = model.conv.weight[i, 0].detach().cpu().numpy()
    plt.imshow(image, cmap='Purples_r')
    plt.axis('off');
class CNN(nn.Module):
    def __init__(self, input_size, n_feature, output_size):
        super(CNN, self).__init__()
        self.n_feature = n_feature
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=n_feature, kernel_size=5)
        self.conv2 = nn.Conv2d(n_feature, n_feature, kernel_size=5)
        # spatial sizes: conv1 (28->24), pool (24->12), conv2 (12->8), pool (8->4)
        self.fc1 = nn.Linear(n_feature*4*4, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x, verbose=False):
        x = self.conv1(x)                    # (N, n_feature, 24, 24)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2)   # (N, n_feature, 12, 12)
        x = self.conv2(x)                    # (N, n_feature, 8, 8)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2)   # (N, n_feature, 4, 4)
        x = x.view(-1, self.n_feature*4*4)   # equivalently torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x
def get_n_params(model):
    """
    Number of parameters in a model.
    """
    n = 0  # avoid shadowing the numpy import with a variable named `np`
    for p in list(model.parameters()):
        n += p.nelement()
    return n
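# Note (added): the same count as a one-liner:
# sum(p.nelement() for p in model.parameters())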
# Training settings
n_features = 10  # number of feature maps

model_cnn = CNN(input_size, n_features, output_size)
model_cnn.to(device)
optimizer = torch.optim.SGD(model_cnn.parameters(), lr=0.01, momentum=0.5)
print('Number of parameters: {}'.format(get_n_params(model_cnn)))

for epoch in range(0, 3):
    train(epoch, model_cnn)
    test(model_cnn)
# test results from an earlier run:
# Test set: Average loss: 0.2849, Accuracy: 9203/10000 (92%)
# Test set: Average loss: 0.1034, Accuracy: 9679/10000 (97%)
# Test set: Average loss: 0.0784, Accuracy: 9753/10000 (98%)
Number of parameters: 11330
Train Epoch: 0 [0/60000 (0%)]	Loss: 2.298711
Train Epoch: 0 [6400/60000 (11%)]	Loss: 1.123118
Train Epoch: 0 [12800/60000 (21%)]	Loss: 0.357951
Train Epoch: 0 [19200/60000 (32%)]	Loss: 0.267037
Train Epoch: 0 [25600/60000 (43%)]	Loss: 0.251017
Train Epoch: 0 [32000/60000 (53%)]	Loss: 0.254121
Train Epoch: 0 [38400/60000 (64%)]	Loss: 0.247766
Train Epoch: 0 [44800/60000 (75%)]	Loss: 0.222735
Train Epoch: 0 [51200/60000 (85%)]	Loss: 0.214110
Train Epoch: 0 [57600/60000 (96%)]	Loss: 0.292420

Test set: Average loss: 0.1361, Accuracy: 9603/10000 (96%)

Train Epoch: 1 [0/60000 (0%)]	Loss: 0.185601
Train Epoch: 1 [6400/60000 (11%)]	Loss: 0.121008
Train Epoch: 1 [12800/60000 (21%)]	Loss: 0.135417
Train Epoch: 1 [19200/60000 (32%)]	Loss: 0.052468
Train Epoch: 1 [25600/60000 (43%)]	Loss: 0.061729
Train Epoch: 1 [32000/60000 (53%)]	Loss: 0.110413
Train Epoch: 1 [38400/60000 (64%)]	Loss: 0.115863
Train Epoch: 1 [44800/60000 (75%)]	Loss: 0.076380
Train Epoch: 1 [51200/60000 (85%)]	Loss: 0.028216
Train Epoch: 1 [57600/60000 (96%)]	Loss: 0.035985

Test set: Average loss: 0.0748, Accuracy: 9772/10000 (98%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.110379
Train Epoch: 2 [6400/60000 (11%)]	Loss: 0.178392
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.233576
Train Epoch: 2 [19200/60000 (32%)]	Loss: 0.026292
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.099997
Train Epoch: 2 [32000/60000 (53%)]	Loss: 0.158976
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.063617
Train Epoch: 2 [44800/60000 (75%)]	Loss: 0.032139
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.075119
Train Epoch: 2 [57600/60000 (96%)]	Loss: 0.101296

Test set: Average loss: 0.0689, Accuracy: 9768/10000 (98%)
# show the learned first-layer filters (5x5)
plt.figure(figsize=(16, 6))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    image = model_cnn.conv1.weight[i, 0].detach().cpu().numpy()
    plt.imshow(image, cmap='Purples_r')
    plt.axis('off');
# trace the tensor shapes through convolution and pooling
f = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
x = torch.randn(1, 1, 28, 28)
print(x.shape)
x = f(x)
print(x.shape)
x = F.max_pool2d(x, kernel_size=2)
print(x.shape)
x = f(x)
print(x.shape)
x = F.max_pool2d(x, kernel_size=2)
print(x.shape)
torch.Size([1, 1, 28, 28])
torch.Size([1, 1, 26, 26])
torch.Size([1, 1, 13, 13])
torch.Size([1, 1, 11, 11])
torch.Size([1, 1, 5, 5])
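# These sizes follow the standard formula out = floor((in - k + 2p)/s) + 1 for
# a convolution and out = floor(in/k) for this pooling. A small helper (added
# for illustration) reproduces the trace above:
def conv_out(size, kernel_size, stride=1, padding=0):
    return (size - kernel_size + 2 * padding) // stride + 1

s = 28
s = conv_out(s, 3)   # 26
s = s // 2           # 13 after max_pool2d(kernel_size=2)
s = conv_out(s, 3)   # 11
s = s // 2           # 5
print(s)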
class CNN(nn.Module):
    def __init__(self, input_size, n_feature, output_size):
        super(CNN, self).__init__()
        self.n_feature = n_feature
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=n_feature, kernel_size=3)
        self.conv2 = nn.Conv2d(n_feature, n_feature, kernel_size=3)
        # spatial sizes: conv1 (28->26), pool (26->13), conv2 (13->11), pool (11->5)
        self.fc1 = nn.Linear(n_feature*5*5, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x, verbose=False):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2)
        x = x.view(-1, self.n_feature*5*5)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x
# Training settings
n_features = 10  # number of feature maps

model_cnn = CNN(input_size, n_features, output_size)
model_cnn.to(device)
optimizer = torch.optim.SGD(model_cnn.parameters(), lr=0.01, momentum=0.5)
print('Number of parameters: {}'.format(get_n_params(model_cnn)))

for epoch in range(0, 3):
    train(epoch, model_cnn)
    test(model_cnn)
Number of parameters: 14070
Train Epoch: 0 [0/60000 (0%)]	Loss: 2.331818
Train Epoch: 0 [6400/60000 (11%)]	Loss: 1.055310
Train Epoch: 0 [12800/60000 (21%)]	Loss: 0.538627
Train Epoch: 0 [19200/60000 (32%)]	Loss: 0.351977
Train Epoch: 0 [25600/60000 (43%)]	Loss: 0.215159
Train Epoch: 0 [32000/60000 (53%)]	Loss: 0.555545
Train Epoch: 0 [38400/60000 (64%)]	Loss: 0.221845
Train Epoch: 0 [44800/60000 (75%)]	Loss: 0.319424
Train Epoch: 0 [51200/60000 (85%)]	Loss: 0.169147
Train Epoch: 0 [57600/60000 (96%)]	Loss: 0.065263

Test set: Average loss: 0.1448, Accuracy: 9565/10000 (96%)

Train Epoch: 1 [0/60000 (0%)]	Loss: 0.130074
Train Epoch: 1 [6400/60000 (11%)]	Loss: 0.156131
Train Epoch: 1 [12800/60000 (21%)]	Loss: 0.127868
Train Epoch: 1 [19200/60000 (32%)]	Loss: 0.061610
Train Epoch: 1 [25600/60000 (43%)]	Loss: 0.156055
Train Epoch: 1 [32000/60000 (53%)]	Loss: 0.246336
Train Epoch: 1 [38400/60000 (64%)]	Loss: 0.159108
Train Epoch: 1 [44800/60000 (75%)]	Loss: 0.185926
Train Epoch: 1 [51200/60000 (85%)]	Loss: 0.067308
Train Epoch: 1 [57600/60000 (96%)]	Loss: 0.090610

Test set: Average loss: 0.1085, Accuracy: 9661/10000 (97%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.136817
Train Epoch: 2 [6400/60000 (11%)]	Loss: 0.069054
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.024138
Train Epoch: 2 [19200/60000 (32%)]	Loss: 0.058952
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.025060
Train Epoch: 2 [32000/60000 (53%)]	Loss: 0.162402
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.078908
Train Epoch: 2 [44800/60000 (75%)]	Loss: 0.073329
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.130581
Train Epoch: 2 [57600/60000 (96%)]	Loss: 0.072167

Test set: Average loss: 0.0972, Accuracy: 9719/10000 (97%)
# show the learned first-layer filters (3x3)
plt.figure(figsize=(16, 6))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    image = model_cnn.conv1.weight[i, 0].detach().cpu().numpy()
    plt.imshow(image, cmap='Purples_r')
    plt.axis('off');
from torchsummary import summary  # requires `pip install torchsummary`
summary(model_cnn, (1, 28, 28))
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-108-da6ed966fdfe> in <module>
----> 1 from torchsummary import summary
      2 summary(model_cnn, (1, 28, 28))

ModuleNotFoundError: No module named 'torchsummary'
for name, param in model_cnn.named_parameters():
    if param.requires_grad:
        print(name, param.data)
conv1.weight tensor([...], device='cuda:0')  # shape (10, 1, 3, 3), long output truncated
conv1.bias tensor([-0.1804,  0.4074,  0.2444, -0.0375,  0.1065,  0.1235, -0.1803,  0.0294,
         0.2389,  0.2232], device='cuda:0')
conv2.weight tensor([...], device='cuda:0')  # shape (10, 10, 3, 3), long output truncated
conv2.bias tensor([-0.0873, -0.0187, -0.0091,  0.0426, -0.0208, -0.0851, -0.0851, -0.0227,
        -0.0885, -0.1077], device='cuda:0')
fc1.weight tensor([...], device='cuda:0')  # shape (50, 250), long output truncated
fc1.bias tensor([...], device='cuda:0')  # 50 values, long output truncated
fc2.weight tensor([...], device='cuda:0')  # shape (10, 50), long output truncated
fc2.bias tensor([ 0.0657, -0.0037,  0.1450, -0.1473, -0.0027, -0.1079, -0.1483,  0.0354,
         0.0860,  0.0722], device='cuda:0')
class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 10),
            nn.Tanh(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(10, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
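# Note (added): the inputs below are normalized with mean 0.1307 and std
# 0.3081, so pixel values range from about -0.42 up to about 2.82, while the
# final Tanh can only produce values in (-1, 1). The decoder therefore cannot
# reproduce the brightest pixels, which is one reason the MSE loss plateaus
# around 0.5 in the training log below.
print((0 - 0.1307) / 0.3081, (1 - 0.1307) / 0.3081)  # ~ -0.42, ~ 2.82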
model = Autoencoder().to(device)
criterion = nn.MSELoss()
# Convert a flat vector back to a 28x28 image
def to_img(x):
    x = 0.5 * (x + 1)  # map the Tanh range (-1, 1) to (0, 1)
    x = x.view(x.size(0), 28, 28)
    return x
# Displaying routine
def display_images(in_, out, n=1):
    for N in range(n):
        if in_ is not None:
            in_pic = to_img(in_.cpu().data)
            plt.figure(figsize=(18, 6))
            for i in range(4):
                plt.subplot(1, 4, i + 1)
                plt.imshow(in_pic[i + 4 * N])
                plt.axis('off')
        out_pic = to_img(out.cpu().data)
        plt.figure(figsize=(18, 6))
        for i in range(4):
            plt.subplot(1, 4, i + 1)
            plt.imshow(out_pic[i + 4 * N])
            plt.axis('off')
learning_rate = 1e-3
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)
num_epochs = 20
# do = nn.Dropout()  # uncomment for a denoising AE

for epoch in range(num_epochs):
    for data in train_loader:
        img, _ = data
        img = img.to(device)
        img = img.view(img.size(0), -1)
        # noise = do(torch.ones(img.shape)).to(device)
        # img_bad = (img * noise).to(device)  # uncomment for a denoising AE
        # ===================forward=====================
        output = model(img)  # feed <img> (for std AE) or <img_bad> (for denoising AE)
        loss = criterion(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print(f'epoch [{epoch + 1}/{num_epochs}], loss:{loss.item():.4f}')
display_images(None, output)  # pass (None, output) for std AE, (img_bad, output) for denoising AE
epoch [1/20], loss:0.5505
epoch [2/20], loss:0.4810
epoch [3/20], loss:0.5168
epoch [4/20], loss:0.4971
epoch [5/20], loss:0.4869
epoch [6/20], loss:0.5054
epoch [7/20], loss:0.4592
epoch [8/20], loss:0.5120
epoch [9/20], loss:0.4975
epoch [10/20], loss:0.5090
epoch [11/20], loss:0.4751
epoch [12/20], loss:0.4523
epoch [13/20], loss:0.5073
epoch [14/20], loss:0.4510
epoch [15/20], loss:0.5025
epoch [16/20], loss:0.5128
epoch [17/20], loss:0.4349
epoch [18/20], loss:0.4624
epoch [19/20], loss:0.4192
epoch [20/20], loss:0.5033
# freeze the CNN weights: only the input image will be optimized below
for name, param in model_cnn.named_parameters():
    if ("bn" not in name):  # would keep batch-norm trainable (there is none here)
        param.requires_grad = False
x = torch.randn(1, 1, 28, 28, device=device, requires_grad=True)
class myDD(nn.Module):
    """
    A model whose trainable parameter is the input image itself.
    """
    def __init__(self, model, x):
        super().__init__()
        self.weights = nn.Parameter(x)
        self.model = model
        # self.do = nn.Dropout2d(0.5)

    def forward(self):
        return self.model(self.weights)
model = myDD(model_cnn, x)
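# Optional check (added): with the CNN frozen, the only tensor that still
# requires gradients should be the 1x1x28x28 input image.
print([n for n, p in model.named_parameters() if p.requires_grad])  # ['weights']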
def trainx(nepoch, model, class2learn=0):
    """
    nepoch      - number of optimization steps
    model       - the model to optimize
    class2learn - which class to optimize the image for
    """
    images = []
    target = torch.tensor([class2learn]).to(device)
    # smoothing kernel: a 3x3 box filter (all weights 1/9, zero bias)
    cnv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), stride=1, padding=1).to(device)
    cnv.weight.data = cnv.weight.data * 0 + (1.0 / 9.0)
    cnv.bias.data[0] = 0.0
    for i in range(nepoch):
        model.train()
        optimizer.zero_grad()
        output = model()
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if i % 100 == 25:
            # smooth the image to suppress high-frequency noise
            with torch.no_grad():
                model.weights.data = cnv(model.weights.data)
        if i % 100 == 0:
            print(f'i={i} loss={loss}')
            images.append(model.weights.squeeze().detach().cpu())
    return images
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
images = trainx(1000, model, class2learn=0)
i=0 loss=0.7772314548492432
i=100 loss=0.00047588348388671875
i=200 loss=0.00027561187744140625
i=300 loss=0.0002117156982421875
i=400 loss=0.000156402587890625
i=500 loss=0.000125885009765625
i=600 loss=0.00010585784912109375
i=700 loss=9.250640869140625e-05
i=800 loss=8.296966552734375e-05
i=900 loss=7.534027099609375e-05
plt.figure(figsize=(18, 6))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(images[i])
    # plt.title(f'i = {i*100}')
    plt.axis('off')
plt.imshow(x.squeeze().detach().cpu())
plt.show()
plt.imshow(model.weights.squeeze().detach().cpu())
<matplotlib.image.AxesImage at 0x28d1606fef0>
trainx(1000, model, class2learn=1)
plt.imshow(x.squeeze().detach().cpu())
plt.show()
plt.imshow(model.weights.squeeze().detach().cpu())
i=0 loss=0.00019741058349609375
i=100 loss=0.0001697540283203125
i=200 loss=0.00014209747314453125
i=300 loss=0.0001316070556640625
i=400 loss=0.0001201629638671875
i=500 loss=0.000110626220703125
i=600 loss=0.0001010894775390625
i=700 loss=8.487701416015625e-05
i=800 loss=8.106231689453125e-05
i=900 loss=7.2479248046875e-05
<matplotlib.image.AxesImage at 0x28d169ceb70>