详解TensorFlow的 tf.gather 函数:根据索引获取张量的值

  • Post category:Python

TensorFlow的tf.gather函数用于根据指定的索引在给定的张量中收集元素。

函数的参数:

tf.gather(params, indices, axis=None, batch_dims=0, name=None)

参数解释:

  • params:一个tensor类型的数据,代表要从中取值的输入的张量。
  • indices:一个tensor类型的数据,代表要收集元素的索引。
  • axis:默认值为0,代表是收集的维度,如何未指定,代表最外层维度。
  • batch_dims:这个参数不是很常用。如果指定的话,代表要在哪个轴之前计算,具体请看官方文档的解释。
  • name:op的名称,可选。

下面来看一个例子:

import tensorflow as tf

# 创建一个张量,shape为 [3, 3]
x = tf.constant([[3, 4, 5], [7, 8, 9], [12, 13, 14]], dtype=tf.int32)

# 通过gather收集张量
y = tf.gather(x, [0, 2], axis=0)

# 输出结果
with tf.Session() as sess:
    print(sess.run(y))

# 输出:
# [[ 3  4  5]
#  [12 13 14]]

在这个例子里,我们创建一个shape为 [3, 3] 的二维矩阵,并把它赋值给变量x。然后我们使用tf.gather函数,从x中收集了x[0]和x[2]两个行向量,得到一个新的张量y,其shape为 [2, 3]。

另外一个例子:

import tensorflow as tf

# 创建一个张量,shape为 [3, 3, 3]
x = tf.constant([[[3, 4, 5], [7, 8, 9], [12, 13, 14]], 
                 [[1, 2, 3], [4, 5, 6], [7, 8, 9]], 
                 [[5, 6, 7], [8, 9, 10], [11, 12, 13]]], 
                dtype=tf.int32)

# 使用gather函数对维度0进行切片,也就是取x[0]和x[2]这两个二维张量
y = tf.gather(x, [0, 2], axis=0)

# 将y和x进行拼接
z = tf.concat([y, x], axis=0)

# 输出结果
with tf.Session() as sess:
    print(sess.run(z))

# 输出:
# [[[ 3  4  5]
#   [ 7  8  9]
#   [12 13 14]]

#  [[ 5  6  7]
#   [ 8  9 10]
#   [11 12 13]]

#  [[ 3  4  5]
#   [ 7  8  9]
#   [12 13 14]]

#  [[ 1  2  3]
#   [ 4  5  6]
#   [ 7  8  9]]

#  [[ 5  6  7]
#   [ 8  9 10]
#   [11 12 13]]]

在这个例子中,我们创建一个形状为 [3, 3, 3] 的三维矩阵x,并使用tf.gather函数,从维度0上对x切片,得到一个新的张量y,其shape为 [2, 3, 3]。然后我们将y和x在维度0上做拼接,得到一个形状为 [5, 3, 3] 的张量z。

以上就是tf.gather函数的作用和使用方法的完整攻略。