解决pytorch 的state_dict()拷贝问题

  • Post category:Python

当我们在使用 PyTorch 模型进行训练时,经常需要加载和保存模型的权重参数。在 PyTorch 中,我们可以通过 state_dict() 方法来获取模型的权重参数,以便在需要时将其保存或加载。但在某些情况下,直接使用 state_dict() 来复制模型参数可能会出现一些问题。下面是解决这个问题的完整攻略:

1. 了解 PyTorch 的权重参数结构

在使用 PyTorch 进行模型训练过程中,我们通常会用到 nn.Module 类来构建模型,其中包含了网络的各个层和参数。在 PyTorch 中,每个 nn.Module 实例都包含一个 state_dict() 方法,用来获取该模型的所有权重参数。其中,权重参数保存在一个由层名称和其对应的权重矩阵组成的字典中。

例如,以下是一个简单的前馈神经网络的示例,包含三个全连接层:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

我们可以通过以下代码获取模型的权重参数:

net = Net()
state_dict = net.state_dict()

此时得到的 state_dict 结果如下:

{
  'fc1.weight': tensor([[-0.1402, -0.0102,  0.1055,  ..., -0.1262, -0.1681,  0.0815],
                        [ 0.1941,  0.1912,  0.1556,  ...,  0.0311,  0.2212, -0.0408],
                        [-0.2455,  0.1964, -0.2239,  ..., -0.0719,  0.1390,  0.0674],
                        ...,
                        [-0.0196, -0.0767,  0.2292,  ...,  0.0329, -0.0961, -0.2141],
                        [-0.1233, -0.0367,  0.0672,  ..., -0.1409,  0.1846,  0.1830],
                        [-0.1079,  0.2058, -0.2033,  ...,  0.2311, -0.1325, -0.0241]]),
  'fc1.bias': tensor([ 0.1263,  0.0967, -0.2020,  0.2046, -0.0506,  0.0740, -0.2363, -0.2096,
                        -0.1401,  0.0091, -0.1665, -0.0123,  0.1108,  0.0463,  0.1559, -0.1228,
                        -0.0522,  0.0768,  0.1704,  0.1612]),
  'fc2.weight': tensor([[ 1.7272e-02,  8.7735e-03,  3.1229e-02,  ..., -2.7138e-02,
                         -9.9532e-05, -3.3799e-02],
                        [ 3.4929e-02, -9.9971e-03,  3.1842e-02,  ..., -9.0860e-03,
                          1.6887e-02, -3.6512e-02],
                        [-2.0188e-02, -3.3823e-02, -3.3286e-02,  ..., -2.0526e-02,
                         -3.3368e-02, -3.2620e-02],
                        ...,
                        [-5.7061e-05, -1.3220e-02,  2.5787e-03,  ..., -7.8213e-03,
                          1.4787e-02,  6.4204e-03],
                        [ 3.6745e-03, -2.9458e-02, -3.0426e-02,  ...,  1.7976e-02,
                         -1.8178e-02,  2.2674e-02]]),
  'fc2.bias': tensor([-0.0146, -0.0344, -0.0386, -0.0384,  0.0474,  0.0456, -0.0257, -0.0046,
                      -0.0378, -0.0043,  0.0353, -0.0287, -0.0053, -0.0500, -0.0193,  0.0466,
                       0.0214, -0.0229,  0.0408,  0.0320,  0.0007,  0.0143, -0.0303,  0.0263,
                       0.0266,  0.0176,  0.0426,  0.0199,  0.0073, -0.0442]),
  'fc3.weight': tensor([[ 0.0006, -0.0151, -0.0343, -0.0309,  0.0401,  0.1228,  0.0821, -0.0035,
                         -0.1232,  0.0362, -0.0737, -0.1411, -0.1194,  0.1498,  0.0904,  0.0536,
                          0.0992, -0.1288,  0.0857, -0.0906,  0.0170, -0.0989, -0.0391, -0.1411,
                         -0.1311, -0.0112, -0.0274, -0.0604, -0.0551,  0.1116],
                        [ 0.0077, -0.0113, -0.1250, -0.0962,  0.0738,  0.1124,  0.0449, -0.0749,
                          0.1336,  0.0979,  0.0348, -0.0534,  0.1412, -0.1297, -0.0255, -0.0312,
                          0.1093,  0.0457, -0.0027,  0.0403,  0.0086, -0.0817, -0.1007, -0.0283,
                         -0.0578,  0.0855, -0.1160, -0.0316, -0.0466,  0.1133]]),
  'fc3.bias': tensor([ 0.0029,  0.0511])
}

2. 使用 PyTorch 自带的拷贝函数

针对 state_dict() 拷贝带来的问题,PyTorch 提供了自带的拷贝函数 deepcopy(),调用该函数实现深度拷贝,可以正确的拷贝模型的权重参数。使用示例如下:

import copy

net = Net()
state_dict = net.state_dict()

new_state_dict = copy.deepcopy(state_dict)

此时就可以使用 new_state_dict 来复制模型参数,而不会出现 state_dict() 方法无法直接复制的问题。

3. 使用 torch.save()torch.load() 函数

另外,PyTorch 还提供了 torch.save()torch.load() 函数来保存和加载模型的权重参数。这两个函数可以保证正确序列化和反序列化 PyTorch 模型,避免直接拷贝 state_dict() 方法所带来的风险。

以下是使用 torch.save()torch.load() 函数来保存和加载模型的权重参数的示例:

net = Net()

# save model parameters
torch.save(net.state_dict(), 'model_weights.pth')

# load model parameters
new_net = Net()
new_net.load_state_dict(torch.load('model_weights.pth'))

通过以上三种方式,我们都可以正确的拷贝模型参数,确保在模型训练过程中能够正确的加载和保存模型权重参数,在实际开发中可以根据具体需求选择合适的方法。