pytorch查看网络参数显存占用量等操作

  • Post category:Python

PyTorch 是一个基于 Python 的科学计算库,它支持两种高级操作:构建动态计算图和自动求导。在深度学习模型的训练中,了解网络的参数显存占用情况以及调试过程中查看网络参数的具体信息都是非常有必要的。本文将介绍如何在 PyTorch 中查看网络参数显存占用量等操作。

1. 查看网络参数显存占用量

PyTorch 中的 torch.cuda 模块提供了一系列 API 来操作显存。下面我们将展示如何查看 PyTorch 模型在显存中的占用情况:

import torch

# 构建模型
model = torch.nn.Linear(3, 1).cuda()

# 在cuda上运行
input = torch.randn((2, 3)).cuda()

# 前向计算
_ = model(input)

# 打印模型显存占用情况
print("显存占用量:", torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, "MB")

在运行完上述代码后,将会输出模型显存占用情况。

2. 查看网络参数具体信息

在 PyTorch 中,可以通过 model.parameters() 方法来访问网络的参数,然后可以使用 getattr() 函数或者直接遍历获取每个参数的具体信息。

import torch

# 构建模型
class CustomModel(torch.nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.fc1 = torch.nn.Linear(3, 1)
        self.fc2 = torch.nn.Linear(4, 3)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


model = CustomModel()

# 打印第一个线性层的权重和偏置
print(model.fc1.weight)
print(model.fc1.bias)

# 遍历模型的参数
for name, param in model.named_parameters():
    print(name, param.shape)

在运行完上述代码后,将会输出模型参数的具体信息,包括参数名称和形状等。