当我们使用PyTorch构建神经网络时,通常需要使用tensor来保存计算结果以及模型参数。在使用tensor的过程中,为了避免梯度的传递,有时候需要区分tensor的值和梯度。而PyTorch提供了两种方式来获取不同版本的tensor值,即tensor.detach()和tensor.data。
tensor.detach()和tensor.data有何不同?
这两种方式都可以获取不需梯度的tensor值,但是它们有一些区别:
-
tensor.detach()会返回一个新的tensor,该新tensor与原始tensor共享存储空间,但是张量的梯度属性被禁用。因此,使用tensor.detach()获取的张量可以保留计算图的历史记录,即使在其后面进行了一些修改或其他计算,这个张量会一直跟踪这些修改,占用相应的内存空间。
-
tensor.data返回一个新的tensor,该新tensor与原始tensor共享存储空间,但是梯度属性也被禁用。该返回值不在计算图中,因此在实际应用中如果需要保留计算图的历史记录,建议使用tensor.detach()而非tensor.data。
使用示例
下面我们来看两个使用示例,分别是使用tensor.detach()和tensor.data来获取tensor值。
示例1
import torch
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x**2
z = y.detach()
print("x:", x)
print("y:", y)
print("z:", z)
print("x的梯度:", x.grad)
print("y的梯度:", y.grad)
print("z的梯度:", z.grad)
输出:
x: tensor([1., 2., 3.], requires_grad=True)
y: tensor([1., 4., 9.], grad_fn=<PowBackward0>)
z: tensor([1., 4., 9.])
x的梯度: None
y的梯度: None
z的梯度: None
我们可以看到,使用tensor.detach()获取的tensor值不具有梯度信息,即不会影响计算图。
示例2
import torch
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x**2
z = y.data
print("x:", x)
print("y:", y)
print("z:", z)
print("x的梯度:", x.grad)
print("y的梯度:", y.grad)
print("z的梯度:", z.grad)
输出:
x: tensor([1., 2., 3.], requires_grad=True)
y: tensor([1., 4., 9.], grad_fn=<PowBackward0>)
z: tensor([1., 4., 9.])
x的梯度: None
y的梯度: None
z的梯度: None
可以看到使用tensor.data获取tensor值同样不具有梯度信息,与上面例子相同。
综上所述,tensor.detach()会返回一个新的tensor,该新tensor与原始tensor共享存储空间,但是张量的梯度属性被禁用,可以保留计算图的历史记录;tensor.data返回一个新的tensor,该新tensor与原始tensor共享存储空间,但是梯度属性也被禁用,该返回值不在计算图中。在实际应用中,建议使用tensor.detach()而非tensor.data,以充分利用计算图的历史信息,同时避免造成不必要的内存占用。