如何获得一个3D NumPy数组的所有2D对角线

  • Post category:Python

获取一个3D NumPy数组的所有2D对角线,可以通过以下步骤实现:

1.导入NumPy库

import numpy as np

2.创建一个3D数组

以形状为(2, 3, 4)的数组为例:

arr = np.arange(24).reshape((2, 3, 4))

此时arr的内容为:

array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

3.获取数组的对角线

NumPy提供了函数diagonal()来获取数组的对角线。其中,如果是3D数组,可以通过传入两个参数来获取对应的2D对角线。

  • axis1:指定需要获取的对角线在第几个轴上。对于一个shape为(m, n, p)的数组,可以输入0、1或2。分别表示获取第一维、第二维或第三维的对角线。
  • axis2:指定需要获取的对角线在第二个轴上。一般来说,axis2的默认值为-2,意味着获取矩阵沿第二个轴的对角线。而对于3D数组,通常需要指定axis2的值为0或1,才能获取所有的2D对角线。

以获取arr中第一个轴的对角线为例:

diag = np.diagonal(arr, axis1=0, axis2=1)

得到的diag数组内容为:

array([[ 0,  5, 10],
       [12, 17, 22]])

diag数组中的每一行就代表着原数组中对应矩阵的对角线。

为了获取所有2D对角线,可以借助for循环,对每个矩阵进行遍历,然后调用diagonal()函数获取对应的对角线。

以下为在第二个轴上获取所有2D对角线的示例:

diags = []
for i, mat in enumerate(arr):
    diag = np.diagonal(mat, axis1=0, axis2=1)
    diags.append(diag)
diags = np.array(diags)

得到的diags数组内容为:

array([[[ 0,  5, 10],
        [ 1,  6, 11]],

       [[12, 17, 22],
        [13, 18, 23]]])

diags数组中的每个子数组代表着原数组每个矩阵的所有对角线。