当我们在使用 PyTorch 模型进行训练时,经常需要加载和保存模型的权重参数。在 PyTorch 中,我们可以通过 state_dict()
方法来获取模型的权重参数,以便在需要时将其保存或加载。但在某些情况下,直接使用 state_dict()
来复制模型参数可能会出现一些问题。下面是解决这个问题的完整攻略:
1. 了解 PyTorch 的权重参数结构
在使用 PyTorch 进行模型训练过程中,我们通常会用到 nn.Module
类来构建模型,其中包含了网络的各个层和参数。在 PyTorch 中,每个 nn.Module
实例都包含一个 state_dict()
方法,用来获取该模型的所有权重参数。其中,权重参数保存在一个由层名称和其对应的权重矩阵组成的字典中。
例如,以下是一个简单的前馈神经网络的示例,包含三个全连接层:
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 30)
self.fc3 = nn.Linear(30, 2)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
我们可以通过以下代码获取模型的权重参数:
net = Net()
state_dict = net.state_dict()
此时得到的 state_dict
结果如下:
{
'fc1.weight': tensor([[-0.1402, -0.0102, 0.1055, ..., -0.1262, -0.1681, 0.0815],
[ 0.1941, 0.1912, 0.1556, ..., 0.0311, 0.2212, -0.0408],
[-0.2455, 0.1964, -0.2239, ..., -0.0719, 0.1390, 0.0674],
...,
[-0.0196, -0.0767, 0.2292, ..., 0.0329, -0.0961, -0.2141],
[-0.1233, -0.0367, 0.0672, ..., -0.1409, 0.1846, 0.1830],
[-0.1079, 0.2058, -0.2033, ..., 0.2311, -0.1325, -0.0241]]),
'fc1.bias': tensor([ 0.1263, 0.0967, -0.2020, 0.2046, -0.0506, 0.0740, -0.2363, -0.2096,
-0.1401, 0.0091, -0.1665, -0.0123, 0.1108, 0.0463, 0.1559, -0.1228,
-0.0522, 0.0768, 0.1704, 0.1612]),
'fc2.weight': tensor([[ 1.7272e-02, 8.7735e-03, 3.1229e-02, ..., -2.7138e-02,
-9.9532e-05, -3.3799e-02],
[ 3.4929e-02, -9.9971e-03, 3.1842e-02, ..., -9.0860e-03,
1.6887e-02, -3.6512e-02],
[-2.0188e-02, -3.3823e-02, -3.3286e-02, ..., -2.0526e-02,
-3.3368e-02, -3.2620e-02],
...,
[-5.7061e-05, -1.3220e-02, 2.5787e-03, ..., -7.8213e-03,
1.4787e-02, 6.4204e-03],
[ 3.6745e-03, -2.9458e-02, -3.0426e-02, ..., 1.7976e-02,
-1.8178e-02, 2.2674e-02]]),
'fc2.bias': tensor([-0.0146, -0.0344, -0.0386, -0.0384, 0.0474, 0.0456, -0.0257, -0.0046,
-0.0378, -0.0043, 0.0353, -0.0287, -0.0053, -0.0500, -0.0193, 0.0466,
0.0214, -0.0229, 0.0408, 0.0320, 0.0007, 0.0143, -0.0303, 0.0263,
0.0266, 0.0176, 0.0426, 0.0199, 0.0073, -0.0442]),
'fc3.weight': tensor([[ 0.0006, -0.0151, -0.0343, -0.0309, 0.0401, 0.1228, 0.0821, -0.0035,
-0.1232, 0.0362, -0.0737, -0.1411, -0.1194, 0.1498, 0.0904, 0.0536,
0.0992, -0.1288, 0.0857, -0.0906, 0.0170, -0.0989, -0.0391, -0.1411,
-0.1311, -0.0112, -0.0274, -0.0604, -0.0551, 0.1116],
[ 0.0077, -0.0113, -0.1250, -0.0962, 0.0738, 0.1124, 0.0449, -0.0749,
0.1336, 0.0979, 0.0348, -0.0534, 0.1412, -0.1297, -0.0255, -0.0312,
0.1093, 0.0457, -0.0027, 0.0403, 0.0086, -0.0817, -0.1007, -0.0283,
-0.0578, 0.0855, -0.1160, -0.0316, -0.0466, 0.1133]]),
'fc3.bias': tensor([ 0.0029, 0.0511])
}
2. 使用 PyTorch 自带的拷贝函数
针对 state_dict()
拷贝带来的问题,PyTorch 提供了自带的拷贝函数 deepcopy()
,调用该函数实现深度拷贝,可以正确的拷贝模型的权重参数。使用示例如下:
import copy
net = Net()
state_dict = net.state_dict()
new_state_dict = copy.deepcopy(state_dict)
此时就可以使用 new_state_dict
来复制模型参数,而不会出现 state_dict()
方法无法直接复制的问题。
3. 使用 torch.save()
和 torch.load()
函数
另外,PyTorch 还提供了 torch.save()
和 torch.load()
函数来保存和加载模型的权重参数。这两个函数可以保证正确序列化和反序列化 PyTorch 模型,避免直接拷贝 state_dict()
方法所带来的风险。
以下是使用 torch.save()
和 torch.load()
函数来保存和加载模型的权重参数的示例:
net = Net()
# save model parameters
torch.save(net.state_dict(), 'model_weights.pth')
# load model parameters
new_net = Net()
new_net.load_state_dict(torch.load('model_weights.pth'))
通过以上三种方式,我们都可以正确的拷贝模型参数,确保在模型训练过程中能够正确的加载和保存模型权重参数,在实际开发中可以根据具体需求选择合适的方法。