PyTorch中 tensor.detach() 和 tensor.data 的区别解析

  • Post category:Python

当我们使用PyTorch构建神经网络时,通常需要使用tensor来保存计算结果以及模型参数。在使用tensor的过程中,为了避免梯度的传递,有时候需要区分tensor的值和梯度。而PyTorch提供了两种方式来获取不同版本的tensor值,即tensor.detach()tensor.data

tensor.detach()和tensor.data有何不同?

这两种方式都可以获取不需梯度的tensor值,但是它们有一些区别:

  • tensor.detach()会返回一个新的tensor,该新tensor与原始tensor共享存储空间,但是张量的梯度属性被禁用。因此,使用tensor.detach()获取的张量可以保留计算图的历史记录,即使在其后面进行了一些修改或其他计算,这个张量会一直跟踪这些修改,占用相应的内存空间。

  • tensor.data返回一个新的tensor,该新tensor与原始tensor共享存储空间,但是梯度属性也被禁用。该返回值不在计算图中,因此在实际应用中如果需要保留计算图的历史记录,建议使用tensor.detach()而非tensor.data

使用示例

下面我们来看两个使用示例,分别是使用tensor.detach()tensor.data来获取tensor值。

示例1

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

y = x**2

z = y.detach()

print("x:", x)
print("y:", y)
print("z:", z)

print("x的梯度:", x.grad)
print("y的梯度:", y.grad)
print("z的梯度:", z.grad)

输出:

x: tensor([1., 2., 3.], requires_grad=True)
y: tensor([1., 4., 9.], grad_fn=<PowBackward0>)
z: tensor([1., 4., 9.])
x的梯度: None
y的梯度: None
z的梯度: None

我们可以看到,使用tensor.detach()获取的tensor值不具有梯度信息,即不会影响计算图。

示例2

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

y = x**2

z = y.data

print("x:", x)
print("y:", y)
print("z:", z)

print("x的梯度:", x.grad)
print("y的梯度:", y.grad)
print("z的梯度:", z.grad)

输出:

x: tensor([1., 2., 3.], requires_grad=True)
y: tensor([1., 4., 9.], grad_fn=<PowBackward0>)
z: tensor([1., 4., 9.])
x的梯度: None
y的梯度: None
z的梯度: None

可以看到使用tensor.data获取tensor值同样不具有梯度信息,与上面例子相同。

综上所述,tensor.detach()会返回一个新的tensor,该新tensor与原始tensor共享存储空间,但是张量的梯度属性被禁用,可以保留计算图的历史记录;tensor.data返回一个新的tensor,该新tensor与原始tensor共享存储空间,但是梯度属性也被禁用,该返回值不在计算图中。在实际应用中,建议使用tensor.detach()而非tensor.data,以充分利用计算图的历史信息,同时避免造成不必要的内存占用。