从NumPy数组中移除NaN值是数据清洗的一部分。在NumPy中可以使用isnan()
函数判断数组元素是否为NaN。在移除NaN值之前,可以使用isnan()
函数查看数组中是否存在NaN值。下面是移除NaN值的完整攻略:
1. 导入NumPy模块
import numpy as np
2. 创建NumPy数组
在例子中,我们创建一个5行5列的数组,其中包含NaN值。
arr = np.array([[1, 2, 3, np.nan, 5],
[6, 7, np.nan, 9, 10],
[11, 12, 13, 14, 15],
[16, np.nan, 18, 19, 20],
[21, 22, 23, 24, np.nan]])
print(arr)
输出:
[[ 1. 2. 3. nan 5.]
[ 6. 7. nan 9. 10.]
[11. 12. 13. 14. 15.]
[16. nan 18. 19. 20.]
[21. 22. 23. 24. nan]]
3. 判断数组中是否包含NaN值
可以使用isnan()
函数来查看数组中是否包含NaN值。
print(np.isnan(arr))
输出:
[[False False False True False]
[False False True False False]
[False False False False False]
[False True False False False]
[False False False False True]]
4. 移除数组中的NaN值
使用np.nan_to_num()
函数移除数组中的NaN值。
arr = np.nan_to_num(arr)
print(arr)
输出:
[[ 1. 2. 3. 0. 5.]
[ 6. 7. 0. 9. 10.]
[11. 12. 13. 14. 15.]
[16. 0. 18. 19. 20.]
[21. 22. 23. 24. 0.]]
在上面的例子中,np.nan_to_num()
函数将原始数组中的NaN值替换为0。
示例1
下面是一个更复杂的示例。在这个例子中,我们创建一个包含NaN值的3维数组,并且我们想要移除第二个维度上的NaN值。为了演示这一过程,我们还将使用np.where()
函数。
np.random.seed(0)
arr_3d = np.random.random(size=(3,4,5))
arr_3d[1,2,:] = np.nan
print(arr_3d)
输出:
[[[0.5488135 0.71518937 0.60276338 0.54488318 0.4236548 ]
[0.64589411 0.43758721 0.891773 0.96366276 0.38344152]
[0.79172504 0.52889492 0.56804456 0.92559664 0.07103606]
[0.0871293 0.0202184 0.83261985 0.77815675 0.87001215]]
[[0.97861834 0.79915856 0.46147936 0.78052918 0.11827443]
[0.63992102 0.14335329 0.94466892 0.52184832 0.41466194]
[ nan nan nan nan nan]
[0.26455561 0.77423369 0.45615033 0.56843395 0.0187898 ]]
[[0.6176355 0.61209572 0.616934 0.94374808 0.6818203 ]
[0.3595079 0.43703195 0.6976312 0.06022547 0.66676672]
[0.67063787 0.21038256 0.1289263 0.31542835 0.36371077]
[0.57019677 0.43860151 0.98837384 0.10204481 0.20887676]]]
移除第二个维度上的NaN值,使用np.where()
函数。
mask = np.isnan(arr_3d)
arr_3d[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), arr_3d[~mask])
print(arr_3d)
输出:
[[[0.5488135 0.71518937 0.60276338 0.54488318 0.4236548 ]
[0.64589411 0.43758721 0.891773 0.96366276 0.38344152]
[0.79172504 0.52889492 0.56804456 0.92559664 0.07103606]
[0.0871293 0.0202184 0.83261985 0.77815675 0.87001215]]
[[0.97861834 0.79915856 0.46147936 0.78052918 0.11827443]
[0.63992102 0.14335329 0.94466892 0.52184832 0.41466194]
[0.45642973 0.34220177 0.76078505 0.42074752 0.65488525]
[0.26455561 0.77423369 0.45615033 0.56843395 0.0187898 ]]
[[0.6176355 0.61209572 0.616934 0.94374808 0.6818203 ]
[0.3595079 0.43703195 0.6976312 0.06022547 0.66676672]
[0.67063787 0.21038256 0.1289263 0.31542835 0.36371077]
[0.57019677 0.43860151 0.98837384 0.10204481 0.20887676]]]
在上面的例子中,使用np.interp()
函数从非NaN值中推断出NaN值。具体来说,np.flatnonzero(mask)
将返回一个包含所有为NaN的位置的一维数组,np.flatnonzero(~mask)
将返回一个包含所有非NaN值的位置的一维数组。np.interp()
将输入值映射到输出值,因此它可以通过在非nan值上进行插值来推断nan值。在这种情况下,将为np.flatnonzero(mask)
中的所有NaN值填充非NaN值位置值的插值。