TensorFlow是一款非常流行的机器学习框架,但在训练过程中,数据处理和输入是非常关键的环节。为了解决这个问题,TensorFlow提供了 tf.data 得到了广泛的应用。tf.data.Iterator 是 tf.data 模块其中一个非常重要的函数,接下来我们详细介绍它的作用和使用方法。
Iterator的作用
tf.data.Iterator 可以用来迭代数据集,它是其它 tf.data 中的函数的组合方式,把不同的数据源组合在一起。通过 Iterator,可以实现数据的随机读取、重复使用等功能。在训练模型时,经常需要访问数据集中的每一个元素,Iterator 可以让我们实现这一需求。
Iterator的使用方法
使用 Iterator 通常需要经过以下三个步骤:
1.定义数据集:首先,需要定义一个数据集,并使用 dataset.make_one_shot_iterator() 方法创建一个 Iterator 对象。例如,我们可以定义一个包含 MNIST 手写数字数据集的数据集:
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 创建数据集
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# 创建迭代器
iterator = dataset.make_one_shot_iterator()
2.获取数据:通过调用这个 Iterator 对象的 get_next() 方法,可以从数据集中获取一个数据。
# 读取一个数据
x, y = iterator.get_next()
3.使用数据:从 Iterator 中获取数据后,可以用来训练模型。
# 模型训练
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(iterator, epochs=5, steps_per_epoch=60000)
tf.data.Iterator的实例
下面提供两个使用 tf.data.Iterator 的详细实例:
实例一:DNN模型训练
该实例展示了如何在 DNN 模型的训练过程中使用 Iterator 来管理数据的输入。
import tensorflow as tf
import numpy as np
# 读取数据
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# 数据预处理
train_images = train_images / 255
test_images = test_images / 255
# 定义dataset
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(100).batch(32)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(32)
# 定义Iterator
train_iterator = train_dataset.make_initializable_iterator()
test_iterator = test_dataset.make_initializable_iterator()
# 定义计算图
x = tf.placeholder(tf.float32, shape=[None, 28, 28])
y_ = tf.placeholder(tf.int64, shape=[None])
with tf.name_scope('DNN'):
X = tf.reshape(x, [-1, 784], name='X')
h1 = tf.layers.dense(X, units=256, activation='relu', name='h1')
h2 = tf.layers.dense(h1, units=128, activation='relu', name='h2')
h3 = tf.layers.dense(h2, units=10, activation='softmax', name='h3')
cross_entropy = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=h3))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(h3, 1), y_), tf.float32))
# 模型训练
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(10):
# 初始化Iterator,并读取训练数据
sess.run(train_iterator.initializer)
while True:
try:
batch_x, batch_y = sess.run(train_iterator.get_next())
sess.run(train_step, feed_dict={x: batch_x, y_: batch_y})
except tf.errors.OutOfRangeError:
break
# 初始化Iterator,并读取测试数据
sess.run(test_iterator.initializer)
test_acc = []
while True:
try:
batch_x, batch_y = sess.run(test_iterator.get_next())
acc = sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y})
test_acc.append(acc)
except tf.errors.OutOfRangeError:
break
print('epoch:' + str(i) + ", test accuracy=" + str(np.mean(test_acc)))
实例二:使用TFRecord格式的数据进行训练
该示例演示了如何使用 tf.data 和 Iterator 读取并使用 TFRecord 格式的文件。TFRecord 是一种二进制文件格式,可以存储大型数据集并进行快速读写。
import tensorflow as tf
# 将数据写入TFRecord文件
def write_to_tfrecord(image_filenames, labels, tfrecord_filename):
writer = tf.python_io.TFRecordWriter(tfrecord_filename)
for i in range(len(image_filenames)):
image_filename = image_filenames[i]
label = labels[i]
# 读取图片
with tf.gfile.FastGFile(image_filename, 'rb') as f:
image_data = f.read()
# 创建Example对象
example = tf.train.Example(features=tf.train.Features(feature={
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
writer.write(example.SerializeToString())
writer.close()
# 读取TFRecord文件,并创建dataset及Iterator
def read_from_tfrecord(tfrecord_filename, batch_size):
dataset = tf.data.TFRecordDataset(tfrecord_filename)
def parser(record):
keys_to_features = {
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
}
parsed = tf.parse_single_example(record, keys_to_features)
# 将二进制的image转化为图像格式
image = tf.decode_raw(parsed['image'], tf.uint8)
# 将图像数据的类型转化为float32
image = tf.cast(image, tf.float32)
image = tf.reshape(image, [28, 28, 1])
return image, parsed['label']
dataset = dataset.map(parser).batch(batch_size)
iterator = dataset.make_initializable_iterator()
return iterator
# 模型训练
batch_size = 32
train_record_file = 'train.tfrecords'
test_record_file = 'test.tfrecords'
train_filenames = [...]
train_labels = [...]
test_filenames = [...]
test_labels = [...]
write_to_tfrecord(train_filenames, train_labels, train_record_file)
write_to_tfrecord(test_filenames, test_labels, test_record_file)
train_iterator = read_from_tfrecord(train_record_file, batch_size)
test_iterator = read_from_tfrecord(test_record_file, batch_size)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(10):
print('epoch ' + str(i))
# 初始化Iterator,并读取训练数据
sess.run(train_iterator.initializer)
while True:
try:
batch_x, batch_y = sess.run(train_iterator.get_next())
sess.run(train_step, feed_dict={x: batch_x, y_: batch_y})
except tf.errors.OutOfRangeError:
break
# 初始化Iterator,并读取测试数据
sess.run(test_iterator.initializer)
test_acc = []
while True:
try:
batch_x, batch_y = sess.run(test_iterator.get_next())
acc = sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y})
test_acc.append(acc)
except tf.errors.OutOfRangeError:
break
print('epoch:' + str(i) + ", test accuracy=" + str(np.mean(test_acc)))
以上是关于 TensorFlow tf.data.Iterator 的使用方法和示例,希望对你有所帮助。