在 TensorFlow 中,如果我们想要训练模型中的可训练变量,我们需要对它们进行初始化。通常使用 tf.global_variables_initializer() 或者 tf.train.Saver() 的 restore() 方法。除此之外,还有一个特殊的初始化方法:tf.trainable_variables_initializer()。
tf.trainable_variables_initializer() 会初始化一个 TensorFlow 图中所有可训练的变量。这些变量通常是神经网络中的参数,例如权重和偏置。需要说明的是,只有在定义神经网络的代码中声明为可训练变量的 Tensor 才能够被该函数所初始化。
使用方法非常简单,直接调用该函数即可,例如:
init = tf.trainable_variables_initializer()
with tf.Session() as sess:
sess.run(init)
# 进行训练模型的操作
下面通过两个实例来加深理解。
示例一
假设我们要使用 TensorFlow 实现一个简单的线性回归模型,其中有两个可训练的变量 W 和 b。
import tensorflow as tf
tf.reset_default_graph()
X = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='X')
y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='y')
W = tf.Variable(initial_value=tf.zeros(shape=[1, 1]), name='W', trainable=True)
b = tf.Variable(initial_value=tf.zeros(shape=[1]), name='b', trainable=True)
y_pred = tf.matmul(X, W) + b
loss = tf.reduce_mean(tf.square(y-y_pred))
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)
init = tf.trainable_variables_initializer()
这里,我们需要对 W 和 b 变量进行初始化,以便能够在训练模型时使用。我们可以使用 init = tf.trainable_variables_initializer()
来进行初始化,然后在 session 中运行该操作即可:
with tf.Session() as sess:
sess.run(init)
# 进行训练模型的操作
示例二
如果我们在定义网络时,只将其中一个变量(如 W)声明为可训练变量,那么 tf.trainable_variables_initializer() 函数只会对该变量进行初始化。
import tensorflow as tf
tf.reset_default_graph()
X = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='X')
y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='y')
W = tf.Variable(initial_value=tf.zeros(shape=[1, 1]), name='W', trainable=True)
b = tf.Variable(initial_value=tf.zeros(shape=[1]), name='b', trainable=False)
y_pred = tf.matmul(X, W) + b
loss = tf.reduce_mean(tf.square(y-y_pred))
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)
init = tf.trainable_variables_initializer()
在这个例子中,我们将 b 声明为不可训练变量(trainable=False),这意味着它不会被 tf.trainable_variables_initializer() 函数所初始化。初始化时只有 W 会被初始化,我们可以在 session 中运行该操作进行初始化。
with tf.Session() as sess:
sess.run(init)
# 进行训练模型的操作
这就是 tf.trainable_variables_initializer() 函数的用途和使用方法的详细攻略。