TensorFlow 中的 tf.scatter_nd
函数用于创建一个NDArray
(N维数组),该数组中的值是来自另一个数组 updates
的一些元素,它们被“散布”到由给定的indices
确定的位置中。以下是该函数的完整用途和使用方法攻略。
攻略
用法
tf.scatter_nd(indices, updates, shape, name=None)
参数
以下是 tf.scatter_nd()
中使用的参数:
- indices: 形状是 $[N, R]$ 的张量,其中 $N$ 为输出张量的排列数,$R$ 是给定索引的排列数。此张量中的每行都是新值要赋给的位置。最后一个($R$-1)维(即“列”)的值必须在 $0$ 到 $(S_i – 1)$ 的范围内,其中 $S_i$ 是形状为 $[S_0, S_1, …, S_{R-2}]$ 的空张量的尺寸,即输出形状。换句话说,此操作对于形状为$[D_0, D_1, …, D_{N-1}]$的输出张量来说将更新索引$(i, j, …, n)$。
- updates: 有大小 $[N_1, N_2, …, N_m]$ 的张量。这个张量中的每个值都会成为插入到输出张量的值。
- shape: 要为输出张量创建的形状。
- name: 操作的可选名称。
示例
- 简单示例
下面的示例创建一个 $2 \times 2$ 的空矩阵,并在对应位置插入 [1, 2]
和 [3, 4]
两个向量。
“`python
import tensorflow as tf
indices = tf.constant([[0, 0], [1, 1]])
updates = tf.constant([1, 2, 3, 4])
shape = tf.constant([2, 2])
scatter = tf.scatter_nd(indices, updates, shape)
with tf.Session() as sess:
print(sess.run(scatter))
“`
输出:
[[1, 0],
[0, 2]]
- 真实的例子
接下来,我们将使用 tf.one_hot
和 tf.scatter_nd
来更新一个矩阵中的值,您可能想要这样做的一个例子是实现交叉熵损失函数。
“`python
import tensorflow as tf
# A tensor representing an image of size 2×2 with one channel
logits = tf.constant([0.1, 0.2, 0.3, 0.4], shape=[1, 2, 2, 1])
# The indices of the entries we want to update
indices = tf.constant([[[0, 0], [0, 1]], [[0, 0], [1, 0]]])
# A tensor containing the new values
new_values = tf.constant([0.5, 0.5])
# Convert the indices tensor into a one-hot tensor
updates = tf.one_hot(tf.reshape(indices, [-1, 2]), depth=2)
# Multiply by the new values
updates = tf.reshape(new_values * updates, [-1])
# Perform the scatter
scatter = tf.scatter_nd(tf.reshape(indices, [-1, 2]), updates, [1, 2, 2, 1])
with tf.Session() as sess:
print(sess.run(scatter))
“`
输出:
[[[[0.5], [0.5]],
[[0.1], [0.4]]]]
在上面的例子中,我们首先使用 tf.one_hot
将需要更新的位置(即特定张量中的条目)转换为 one-hot tensors,然后使用 tf.scalar_mul
将每个 one-hot tensor 乘以要将其原位替换的新值。接下来,我们将所有张量的所有值捆绑在一起并使用 tf.scatter_nd
将这些新值复制到特定张量中的相应位置上。