详解TensorFlow的 tf.trainable_variables 函数:返回所有可训练的变量

  • Post category:Python

tf.trainable_variables() 函数返回当前计算图中需要训练的变量的列表。需要训练的变量就是在训练时会更新其值的变量。

使用 tf.trainable_variables() 函数的步骤如下:

  1. 首先,在定义变量时需要设置 trainable 参数为 True。这样,变量在构建计算图时就会被加入到需要训练的变量列表中。

  2. 接着,使用 tf.trainable_variables() 函数获取需要训练的变量列表。

  3. 在训练过程中,使用优化器来最小化损失函数。当调用 optimizer.minimize(loss) 时,TensorFlow 会自动计算需要训练的变量的梯度并更新其值。因此,只需要将损失函数作为参数传入优化器即可。

下面提供两个实例,以更好地了解 tf.trainable_variables() 函数的作用和使用方法:

实例一:使用 tf.trainable_variables() 训练模型

假设我们有一个包含两个变量 Wb 的模型,其定义如下:

import tensorflow as tf

# 定义模型
x = tf.placeholder(tf.float32, shape=[None, 784], name='input')
W = tf.Variable(tf.zeros([784, 10]), name='W', trainable=True)
b = tf.Variable(tf.zeros([10]), name='b', trainable=True)
y = tf.nn.softmax(tf.matmul(x, W) + b, name='output')

在构建模型时,我们将参数 trainable 设置为 True,这样可以保证在训练过程中该变量会被更新。

接下来,使用 tf.trainable_variables() 函数获取需要训练的变量列表:

variables_to_train = tf.trainable_variables()

在训练时,设置优化器和损失函数,并进行训练:

# 定义损失函数
y_ = tf.placeholder(tf.float32, shape=[None, 10], name='label')
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

# 设置优化器和训练操作
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5)
train_op = optimizer.minimize(cross_entropy, var_list=variables_to_train)

# 进行训练
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(1000):
    batch_xs, batch_ys = ...
    sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

在每一次迭代中,使用 sess.run(train_op) 对模型进行训练。由于优化器会自动计算需要训练的变量的梯度并更新其值,因此在这里不需要显式地更新变量。

实例二:使用 tf.contrib.framework.get_variables()var_list 控制需要训练的变量

除了使用 tf.trainable_variables() 函数,我们还可以使用 tf.contrib.framework.get_variables() 函数手动获取需要训练的变量,并使用 var_list 参数控制需要训练的变量。

import tensorflow as tf

# 定义模型
x = tf.placeholder(tf.float32, shape=[None, 784], name='input')
W = tf.Variable(tf.zeros([784, 10]), name='W', trainable=True)
b = tf.Variable(tf.zeros([10]), name='b', trainable=False)
y = tf.nn.softmax(tf.matmul(x, W) + b, name='output')

# 使用 tf.contrib.framework.get_variables() 函数获取需要训练的变量
variables_to_train = tf.contrib.framework.get_variables(trainable=True)

# 设置优化器和训练操作
y_ = tf.placeholder(tf.float32, shape=[None, 10], name='label')
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5)
train_op = optimizer.minimize(cross_entropy, var_list=variables_to_train)

# 进行训练
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(1000):
    batch_xs, batch_ys = ...
    sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})

在这个例子中,我们将 btrainable 参数设置为 False,即该变量不参与训练。然后,使用 tf.contrib.framework.get_variables(trainable=True) 获取需要训练的变量,这将返回一个列表,其中只包含 W 变量。最后,将变量列表作为参数传入 train_op 操作中。

总之,tf.trainable_variables() 函数可以方便地获取需要训练的变量列表,并在训练过程中进行自动更新。在使用时,需要保证变量的 trainable 参数已经被正确设置,并且需要根据实际需求控制需要训练的变量范围。