详解TensorFlow的 tf.data.Iterator 函数:创建数据集迭代器

  • Post category:Python

TensorFlow是一款非常流行的机器学习框架,但在训练过程中,数据处理和输入是非常关键的环节。为了解决这个问题,TensorFlow提供了 tf.data 得到了广泛的应用。tf.data.Iterator 是 tf.data 模块其中一个非常重要的函数,接下来我们详细介绍它的作用和使用方法。

Iterator的作用

tf.data.Iterator 可以用来迭代数据集,它是其它 tf.data 中的函数的组合方式,把不同的数据源组合在一起。通过 Iterator,可以实现数据的随机读取、重复使用等功能。在训练模型时,经常需要访问数据集中的每一个元素,Iterator 可以让我们实现这一需求。

Iterator的使用方法

使用 Iterator 通常需要经过以下三个步骤:

1.定义数据集:首先,需要定义一个数据集,并使用 dataset.make_one_shot_iterator() 方法创建一个 Iterator 对象。例如,我们可以定义一个包含 MNIST 手写数字数据集的数据集:

import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 创建数据集
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# 创建迭代器
iterator = dataset.make_one_shot_iterator()

2.获取数据:通过调用这个 Iterator 对象的 get_next() 方法,可以从数据集中获取一个数据。

# 读取一个数据
x, y = iterator.get_next()

3.使用数据:从 Iterator 中获取数据后,可以用来训练模型。

# 模型训练
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(iterator, epochs=5, steps_per_epoch=60000)

tf.data.Iterator的实例

下面提供两个使用 tf.data.Iterator 的详细实例:

实例一:DNN模型训练

该实例展示了如何在 DNN 模型的训练过程中使用 Iterator 来管理数据的输入。

import tensorflow as tf
import numpy as np

# 读取数据
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# 数据预处理
train_images = train_images / 255
test_images = test_images / 255

# 定义dataset
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(100).batch(32)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(32)

# 定义Iterator
train_iterator = train_dataset.make_initializable_iterator()
test_iterator = test_dataset.make_initializable_iterator()

# 定义计算图
x = tf.placeholder(tf.float32, shape=[None, 28, 28])
y_ = tf.placeholder(tf.int64, shape=[None])

with tf.name_scope('DNN'):
    X = tf.reshape(x, [-1, 784], name='X')
    h1 = tf.layers.dense(X, units=256, activation='relu', name='h1')
    h2 = tf.layers.dense(h1, units=128, activation='relu', name='h2')
    h3 = tf.layers.dense(h2, units=10, activation='softmax', name='h3')

cross_entropy = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=h3))

train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(h3, 1), y_), tf.float32))

# 模型训练
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(10):
        # 初始化Iterator,并读取训练数据
        sess.run(train_iterator.initializer)
        while True:
            try:
                batch_x, batch_y = sess.run(train_iterator.get_next())
                sess.run(train_step, feed_dict={x: batch_x, y_: batch_y})
            except tf.errors.OutOfRangeError:
                break

        # 初始化Iterator,并读取测试数据
        sess.run(test_iterator.initializer)
        test_acc = []
        while True:
            try:
                batch_x, batch_y = sess.run(test_iterator.get_next())
                acc = sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y})
                test_acc.append(acc)
            except tf.errors.OutOfRangeError:
                break
        print('epoch:' + str(i) + ", test accuracy=" + str(np.mean(test_acc)))

实例二:使用TFRecord格式的数据进行训练

该示例演示了如何使用 tf.data 和 Iterator 读取并使用 TFRecord 格式的文件。TFRecord 是一种二进制文件格式,可以存储大型数据集并进行快速读写。

import tensorflow as tf

# 将数据写入TFRecord文件
def write_to_tfrecord(image_filenames, labels, tfrecord_filename):
    writer = tf.python_io.TFRecordWriter(tfrecord_filename)
    for i in range(len(image_filenames)):
        image_filename = image_filenames[i]
        label = labels[i]
        # 读取图片
        with tf.gfile.FastGFile(image_filename, 'rb') as f:
            image_data = f.read()
        # 创建Example对象
        example = tf.train.Example(features=tf.train.Features(feature={
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
        }))
        writer.write(example.SerializeToString())
    writer.close()

# 读取TFRecord文件,并创建dataset及Iterator
def read_from_tfrecord(tfrecord_filename, batch_size):
    dataset = tf.data.TFRecordDataset(tfrecord_filename)

    def parser(record):
        keys_to_features = {
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
        # 将二进制的image转化为图像格式
        image = tf.decode_raw(parsed['image'], tf.uint8)
        # 将图像数据的类型转化为float32
        image = tf.cast(image, tf.float32)
        image = tf.reshape(image, [28, 28, 1])

        return image, parsed['label']

    dataset = dataset.map(parser).batch(batch_size)

    iterator = dataset.make_initializable_iterator()
    return iterator

# 模型训练
batch_size = 32
train_record_file = 'train.tfrecords'
test_record_file = 'test.tfrecords'
train_filenames = [...]
train_labels = [...]
test_filenames = [...]
test_labels = [...]

write_to_tfrecord(train_filenames, train_labels, train_record_file)
write_to_tfrecord(test_filenames, test_labels, test_record_file)

train_iterator = read_from_tfrecord(train_record_file, batch_size)
test_iterator = read_from_tfrecord(test_record_file, batch_size)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(10):
        print('epoch ' + str(i))
        # 初始化Iterator,并读取训练数据
        sess.run(train_iterator.initializer)
        while True:
            try:
                batch_x, batch_y = sess.run(train_iterator.get_next())
                sess.run(train_step, feed_dict={x: batch_x, y_: batch_y})
            except tf.errors.OutOfRangeError:
                break

        # 初始化Iterator,并读取测试数据
        sess.run(test_iterator.initializer)
        test_acc = []
        while True:
            try:
                batch_x, batch_y = sess.run(test_iterator.get_next())
                acc = sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y})
                test_acc.append(acc)
            except tf.errors.OutOfRangeError:
                break
        print('epoch:' + str(i) + ", test accuracy=" + str(np.mean(test_acc)))

以上是关于 TensorFlow tf.data.Iterator 的使用方法和示例,希望对你有所帮助。