介绍
tf.metrics.accuracy
是 TensorFlow 中用于计算准确率的函数。该函数可以使用在分类任务如图像分类、文本分类等场景中。
用法
tf.metrics.accuracy
的基本用法如下:
metrics = tf.metrics.accuracy(labels, predictions)
其中, labels
是真实的标签值(ground-truth),predictions
是模型预测的标签值。
该函数的返回值是一个元组,第一个值是 accuracy
,即准确率,第二个值是 update_op
,用于更新准确率统计。
你可以定义一个 TensorFlow session 执行 update_op 操作,然后获取 accuracy 的值。由于 accuracy 是 running metric,因此需要在训练可迭代的 batch 上执行 update_op 来跟踪每个 batch 上的准确率,最终计算出整个训练集上的准确率。
sess.run(tf.global_variables_initializer())
for epoch in range(num_epochs):
for batch_data in dataset:
feed_dict = {input: batch_data[0], labels: batch_data[1]}
sess.run(update_op, feed_dict=feed_dict)
train_accuracy = sess.run(metrics)
print("Epoch: {}, Train accuracy: {}".format(epoch, train_accuracy))
示例
现在我们以 MNIST 数据集为例,展示如何使用 tf.metrics.accuracy
函数计算训练集和测试集上的准确率。
首先,我们需要加载 MNIST 数据集:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
然后我们定义输入数据的占位符,从数据集中获取 batch,定义一个简单的线性分类器,使用 softmax 作为激活函数:
import tensorflow as tf
input = tf.placeholder(tf.float32, [None, 28 * 28])
labels = tf.placeholder(tf.float32, [None, 10])
W = tf.Variable(tf.random_normal([28 * 28, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(input, W) + b
predictions = tf.nn.softmax(logits)
接着,我们可以使用 tf.metrics.accuracy
函数计算训练和测试数据集的准确率:
# 训练集上的准确率
train_metrics = tf.metrics.accuracy(tf.argmax(labels, 1), tf.argmax(predictions, 1))
# 测试集上的准确率
test_logits = tf.matmul(input, W) + b
test_predictions = tf.nn.softmax(test_logits)
test_metrics = tf.metrics.accuracy(tf.argmax(labels, 1), tf.argmax(test_predictions, 1))
最后我们需要在 session 中执行 train_metrics
和 test_metrics
的 update_op
统计数据,然后获取准确率。
num_epochs = 10
batch_size = 32
train_init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
sess.run(train_init_op)
for epoch in range(num_epochs):
for step in range(mnist.train.num_examples // batch_size):
batch_input, batch_labels = mnist.train.next_batch(batch_size)
feed_dict = {input: batch_input, labels: batch_labels}
sess.run(train_metrics[1], feed_dict=feed_dict)
train_accuracy = sess.run(train_metrics[0])
print('Epoch: {}, train accuracy: {}'.format(epoch, train_accuracy))
sess.run(test_metrics[1], feed_dict={input: mnist.test.images, labels: mnist.test.labels})
test_accuracy = sess.run(test_metrics[0])
print('Epoch: {}, test accuracy: {}'.format(epoch, test_accuracy))
总结
在分类任务(如图像分类)中,准确率是一项重要的指标。在 TensorFlow 中,可以使用 tf.metrics.accuracy
函数计算准确率。需要注意的是,该函数是 running metric,因此需要在训练 dataset 的可迭代的 batch 上执行 update_op,来跟踪每个 batch 上的准确率,最终计算出整个训练集上的准确率。