import torch
from torch import nn
x = torch.ones(3)
torch.save(x, 'x')
然后我们将数据从存储的文件读回内存。
x2 = torch.load('x')
x2
tensor([1., 1., 1.])
我们还可以存储一列Tensor
并读回内存。
y = torch.zeros(4)
torch.save([x, y], 'xy')
x2, y2 = torch.load('xy')
(x2, y2)
(tensor([1., 1., 1.]), tensor([0., 0., 0., 0.]))
我们甚至可以存储并读取一个从字符串映射到Tensor
的字典。
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}
除Tensor
以外,我们还可以读写模型的参数。我们可以使用save
方法来保存模型的state_dict
,Module
类提供了load_state_dict
函数来读取模型参数。为了演示方便,我们先创建一个多层感知机,并将其初始化。
class MLP(nn.Module):
def __init__(self, **kwargs):
super(MLP, self).__init__(**kwargs)
self.hidden = nn.Linear(20, 256)
self.activation = nn.ReLU()
self.output = nn.Linear(256, 10)
def forward(self, x):
return self.output(self.activation(self.hidden(x)))
net = MLP()
X = torch.rand(2, 20)
Y = net(X)
下面把该模型的参数存成文件,文件名为mlp.params。
filename = 'mlp.params'
torch.save(net.state_dict(), filename)
接下来,我们再实例化一次定义好的多层感知机。与随机初始化模型参数不同,我们在这里直接读取保存在文件里的参数。
net2 = MLP()
net2.load_state_dict(torch.load(filename))
<All keys matched successfully>
因为这两个实例都有同样的模型参数,那么对同一个输入X
的计算结果将会是一样的。我们来验证一下。
Y2 = net2(X)
Y2 == Y
tensor([[True, True, True, True, True, True, True, True, True, True], [True, True, True, True, True, True, True, True, True, True]])