详解TensorFlow的 tf.expand_dims 函数:在指定位置增加维度

  • Post category:Python

tf.expand_dims 函数是 TensorFlow 中的一个重要函数,可以用于增加张量的维度。它可以在给定的位置上,向张量中添加一个指定大小的新的维度。具体的用法和实例如下:

一、参数解释

tf.expand_dims(tensor, axis)

该函数的参数分别为:

  • tensor: 需要扩充维度的张量,它的 rank 至少为 1。
  • axis: 需要添加的新维度的位置,新维度索引为 axis。axis 不能大于张量的 rank+1。

二、使用方法

  1. 将一个二维张量变为三维张量

解析:下面我们来看一个例子,使用 tf.expand_dims 可以将二维张量变为三维张量。比如一共有 $4$ 个样本,每个样本有 $28$ 个特征,使用 tf.expand_dims 可以将二维张量[n_samples, n_features]变为三维张量[n_samples, 1, n_features]。

import tensorflow as tf
import numpy as np

tensor_2d = tf.constant(np.random.rand(4, 28),dtype=tf.float32)

#在维度1添加1个新的维度
tensor_3d = tf.expand_dims(tensor_2d, axis=1)

print("tensor_2d.shape:{}".format(tensor_2d.shape))
print("tensor_3d.shape:{}".format(tensor_3d.shape))
  • 输出:
tensor_2d.shape: (4, 28)
tensor_3d.shape: (4, 1, 28)
  • 解析:我们使用 numpy 生成的一个 $4 × 28$ 的二维张量,然后使用 tf.expand_dims 在第2个位置上添加了一个新的维度,这时它变成了具有三个维度的张量,第一个维度为 $4$ 表示样本数量,第二个维度为 $1$ 表示我们需要增加的新的维度,而第三个维度为 $28$ 表示每个样本的特征数量。

  • 将一个标量变为一个张量

解析:除了用于增加新的维度,还可以使用 tf.expand_dims 将一个标量变成一个张量,比如将 $1$ 变成一个维度为 $[1]$ 的张量。

tensor_scalar = tf.constant(1)

#在维度0添加1个新的维度
tensor_1d = tf.expand_dims(tensor_scalar, axis=0)

print("tensor_scalar.shape:{}".format(tensor_scalar.shape))
print("tensor_1d.shape:{}".format(tensor_1d.shape))
  • 输出:
tensor_scalar.shape: ()
tensor_1d.shape: (1,)
  • 解析: 上面代码中 tensor_scalar 值为 $1$,它原来是一个标量,使用 tf.expand_dims 可以将它变成一个维度为 $[1]$ 的张量。 显然,这时张量的第 $0$ 个维度为 $1$。

总结

因此,通过上面的例子,我们了解了使用 tf.expand_dims 函数,可以添加新的维度,使处理张量数据更加方便和高效。