TensorFlow的tf.gather函数用于根据指定的索引在给定的张量中收集元素。
函数的参数:
tf.gather(params, indices, axis=None, batch_dims=0, name=None)
参数解释:
- params:一个tensor类型的数据,代表要从中取值的输入的张量。
- indices:一个tensor类型的数据,代表要收集元素的索引。
- axis:默认值为0,代表是收集的维度,如何未指定,代表最外层维度。
- batch_dims:这个参数不是很常用。如果指定的话,代表要在哪个轴之前计算,具体请看官方文档的解释。
- name:op的名称,可选。
下面来看一个例子:
import tensorflow as tf
# 创建一个张量,shape为 [3, 3]
x = tf.constant([[3, 4, 5], [7, 8, 9], [12, 13, 14]], dtype=tf.int32)
# 通过gather收集张量
y = tf.gather(x, [0, 2], axis=0)
# 输出结果
with tf.Session() as sess:
print(sess.run(y))
# 输出:
# [[ 3 4 5]
# [12 13 14]]
在这个例子里,我们创建一个shape为 [3, 3] 的二维矩阵,并把它赋值给变量x。然后我们使用tf.gather函数,从x中收集了x[0]和x[2]两个行向量,得到一个新的张量y,其shape为 [2, 3]。
另外一个例子:
import tensorflow as tf
# 创建一个张量,shape为 [3, 3, 3]
x = tf.constant([[[3, 4, 5], [7, 8, 9], [12, 13, 14]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[5, 6, 7], [8, 9, 10], [11, 12, 13]]],
dtype=tf.int32)
# 使用gather函数对维度0进行切片,也就是取x[0]和x[2]这两个二维张量
y = tf.gather(x, [0, 2], axis=0)
# 将y和x进行拼接
z = tf.concat([y, x], axis=0)
# 输出结果
with tf.Session() as sess:
print(sess.run(z))
# 输出:
# [[[ 3 4 5]
# [ 7 8 9]
# [12 13 14]]
# [[ 5 6 7]
# [ 8 9 10]
# [11 12 13]]
# [[ 3 4 5]
# [ 7 8 9]
# [12 13 14]]
# [[ 1 2 3]
# [ 4 5 6]
# [ 7 8 9]]
# [[ 5 6 7]
# [ 8 9 10]
# [11 12 13]]]
在这个例子中,我们创建一个形状为 [3, 3, 3] 的三维矩阵x,并使用tf.gather函数,从维度0上对x切片,得到一个新的张量y,其shape为 [2, 3, 3]。然后我们将y和x在维度0上做拼接,得到一个形状为 [5, 3, 3] 的张量z。
以上就是tf.gather函数的作用和使用方法的完整攻略。