PyTorch报”RuntimeError: Expected object of scalar type Int but got scalar type Double for argument #2 ‘other’ “的原因以及解决办法

  • Post category:Python

问题描述:

当使用PyTorch的函数进行操作时,可能会遇到如下错误:

RuntimeError: Expected object of scalar type Int but got scalar type Double for argument #2 'other'

问题分析:

这个错误信息提示我们输入的期望数据类型为整形类型 int,但实际输入的却是浮点类型 double,这导致了类型不匹配的错误。出现这种情况通常是因为在输入数据时出现了类型错误或数据类型转换错误。

解决方案:

一种最常见的解决方式是,在进行操作时,确保输入数据的数据类型为 int。对于代码中的所有涉及数字类型的代码行,更改变量或字面量的数据类型为 int,这将确保所有数据类型的匹配。

如果输入的数据确实需要是浮点类型 double,则可以使用数据类型转换函数将 int 转换为 double。

以下是一些解决 RuntimeError: Expected object of scalar type Int but got scalar type Double for argument #2 'other' 这个错误信息的常见方案:

  1. 确保输入数据类型为 int
import torch

x = torch.tensor([1, 2, 3], dtype=torch.int)
y = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float)

# 以下是错误示例,将导致 "Expected object of scalar type Int but got scalar type Double for argument #2 'other'" 错误
z = x + y

# 修正错误的代码
z = x + y.int()
print(z)
  1. 使用 dtype 参数,明确指定数据类型
import torch

x = torch.tensor([1, 2, 3], dtype=torch.int)
y = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float)

# 使用 dtype 参数,明确指定数据类型
z = x + y.to(torch.int)
print(z)
  1. 使用数据类型转换函数进行类型转换
import torch

x = torch.tensor([1, 2, 3], dtype=torch.int)
y = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float)

# 使用数据类型转换函数进行类型转换
z = x + y.int()
print(z)

以上是 PyTorch 报 “RuntimeError: Expected object of scalar type Int but got scalar type Double for argument #2 ‘other'” 错误的一些解决方案,希望对你有所帮助。