详解TensorFlow的 tf.metrics.accuracy 函数:计算准确率

  • Post category:Python

介绍

tf.metrics.accuracy 是 TensorFlow 中用于计算准确率的函数。该函数可以使用在分类任务如图像分类、文本分类等场景中。

用法

tf.metrics.accuracy 的基本用法如下:

metrics = tf.metrics.accuracy(labels, predictions)

其中, labels 是真实的标签值(ground-truth),predictions 是模型预测的标签值。

该函数的返回值是一个元组,第一个值是 accuracy,即准确率,第二个值是 update_op,用于更新准确率统计。

你可以定义一个 TensorFlow session 执行 update_op 操作,然后获取 accuracy 的值。由于 accuracy 是 running metric,因此需要在训练可迭代的 batch 上执行 update_op 来跟踪每个 batch 上的准确率,最终计算出整个训练集上的准确率。

sess.run(tf.global_variables_initializer())

for epoch in range(num_epochs):
    for batch_data in dataset:
        feed_dict = {input: batch_data[0], labels: batch_data[1]}
        sess.run(update_op, feed_dict=feed_dict)
    train_accuracy = sess.run(metrics)
    print("Epoch: {}, Train accuracy: {}".format(epoch, train_accuracy))

示例

现在我们以 MNIST 数据集为例,展示如何使用 tf.metrics.accuracy 函数计算训练集和测试集上的准确率。

首先,我们需要加载 MNIST 数据集:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

然后我们定义输入数据的占位符,从数据集中获取 batch,定义一个简单的线性分类器,使用 softmax 作为激活函数:

import tensorflow as tf

input = tf.placeholder(tf.float32, [None, 28 * 28])
labels = tf.placeholder(tf.float32, [None, 10])

W = tf.Variable(tf.random_normal([28 * 28, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(input, W) + b
predictions = tf.nn.softmax(logits)

接着,我们可以使用 tf.metrics.accuracy 函数计算训练和测试数据集的准确率:

# 训练集上的准确率
train_metrics = tf.metrics.accuracy(tf.argmax(labels, 1), tf.argmax(predictions, 1))

# 测试集上的准确率
test_logits = tf.matmul(input, W) + b
test_predictions = tf.nn.softmax(test_logits)
test_metrics = tf.metrics.accuracy(tf.argmax(labels, 1), tf.argmax(test_predictions, 1))

最后我们需要在 session 中执行 train_metricstest_metricsupdate_op 统计数据,然后获取准确率。

num_epochs = 10
batch_size = 32

train_init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

with tf.Session() as sess:
    sess.run(train_init_op)
    for epoch in range(num_epochs):
        for step in range(mnist.train.num_examples // batch_size):
            batch_input, batch_labels = mnist.train.next_batch(batch_size)
            feed_dict = {input: batch_input, labels: batch_labels}
            sess.run(train_metrics[1], feed_dict=feed_dict)

        train_accuracy = sess.run(train_metrics[0])
        print('Epoch: {}, train accuracy: {}'.format(epoch, train_accuracy))

        sess.run(test_metrics[1], feed_dict={input: mnist.test.images, labels: mnist.test.labels})
        test_accuracy = sess.run(test_metrics[0])
        print('Epoch: {}, test accuracy: {}'.format(epoch, test_accuracy))

总结

在分类任务(如图像分类)中,准确率是一项重要的指标。在 TensorFlow 中,可以使用 tf.metrics.accuracy 函数计算准确率。需要注意的是,该函数是 running metric,因此需要在训练 dataset 的可迭代的 batch 上执行 update_op,来跟踪每个 batch 上的准确率,最终计算出整个训练集上的准确率。