python np.split函数

  • Post category:Python

np.split()函数是numpy库中用来将一个数组分裂成多个子数组的函数。它的语法格式为:

numpy.split(array, indices_or_sections, axis=0)

其中,array表示待分裂的数组,indices_or_sections可以是整数、数组或者是一个表示分裂点的序列,axis是设定分裂的轴方向,默认为0(即按照第一个维度进行分裂)。

当indices_or_sections为整数n时,表示将数组沿着设定的轴方向,分裂成n个相等的子数组。当indices_or_sections是一个列表时,表示将数组沿着设定的轴方向,分裂成列表中元素的数量+1个子数组。

下面分别给出两条示例代码,说明如何使用np.split()函数:

第一条代码示例:

import numpy as np

arr = np.arange(10)
splitted_arr = np.split(arr, 2)

print(splitted_arr)

输出结果:

[array([0, 1, 2, 3, 4]), array([5, 6, 7, 8, 9])]

解释:以上代码中,生成一个从0到9的一维数组arr。然后使用np.split()函数在沿着第一个维度(即axis=0)的方向上将数组分成了两个大小相等的子数组。分裂后得到的结果存储在一个列表中,分别输出这两个子数组(splitted_arr[0]和splitted_arr[1])。

第二条代码示例:

import numpy as np

arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
splitted_arr = np.split(arr, [1, 2], axis=1)

print(splitted_arr)

输出结果:

[array([[1],
       [4],
       [7]]), array([[2],
       [5],
       [8]]), array([[3],
       [6],
       [9]])]

解释:以上代码中,生成一个大小为3×3的二维数组arr。使用np.split()函数沿着第二个维度(即axis=1)将数组在第1个和第2个位置进行分裂。分裂后得到的结果存储在一个列表中,分别输出这三个子数组(splitted_arr[0]、splitted_arr[1]、splitted_arr[2])。第一个子数组包含了原始数组中的第1列,第二个子数组包含了原始数组中的第2列,第三个子数组包含了原始数组中的第3列。