TensorFlow 模型载入方法汇总(小结)

  • Post category:Python

TensorFlow模型载入方法汇总(小结)

本文将总结几种载入TensorFlow模型的方法,并提供相应的代码示例。

方法一:使用tf.train.Saver

这是TensorFlow中最基本也最常用的载入模型方法。它生成一个Checkpoint文件,保存模型的变量信息,再通过这个文件载入模型的参数。具体步骤如下:

  1. 定义模型并训练得到模型参数。
import tensorflow as tf

# 定义模型
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
w = tf.Variable(0.)
b = tf.Variable(0.)
loss = tf.reduce_mean(tf.square(y - w * x - b))

# 训练模型并保存模型参数
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1000):
        sess.run(train_op, feed_dict={x: [1, 2, 3], y: [2, 4, 6]})
    saver.save(sess, './model.ckpt')
  1. 载入模型并使用模型参数。
# 载入模型并使用模型参数
with tf.Session() as sess:
    saver.restore(sess, './model.ckpt')
    print(sess.run(w))  # 输出2.0

方法二:使用tf.saved_model

tf.saved_model是TensorFlow官方提供的模型保存和载入的方法,它保存的是一个完整的模型,包括计算图、变量等信息,并支持对模型进行版本管理。

  1. 定义模型并训练得到模型参数。
import tensorflow as tf

# 定义模型
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
w = tf.Variable(0.)
b = tf.Variable(0.)
loss = tf.reduce_mean(tf.square(y - w * x - b))

# 训练模型并保存为saved_model
trainer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
builder = tf.saved_model.builder.SavedModelBuilder('./saved_model')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1000):
        sess.run(trainer, feed_dict={x: [1, 2, 3], y: [2, 4, 6]})
    builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING])
    builder.save()
  1. 载入模型并使用模型参数。
# 载入saved_model并使用模型参数
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], './saved_model')
    w = sess.graph.get_tensor_by_name('Variable:0')
    print(sess.run(w))  # 输出2.0

以上是两种常用的TensorFlow模型载入方法,根据不同的需求可以灵活运用。