详解TensorFlow的 tf.data.TFRecordDataset 函数:从 TFRecord 文件创建数据集

  • Post category:Python

TensorFlow 中的 tf.data.TFRecordDataset 函数是用于读取 TFRecord 格式数据的函数,它可以让你高效地读取大规模数据集,并使用 TensorFlow 进行训练和预测。在本篇攻略中,我们将详细讲解 TFRecord 格式数据的产生方法、TFRecordDataset 函数的使用方法以及两个相关实例。

1. TFRecord 格式数据

TFRecord 是 Tensorflow 推荐的一种高效存储数据的二进制格式。TFRecord 通过将运算不同步和变量混合到一个统一的格式中来减少 I/O 边界和数据预处理。它包含了多个样本,每个样本则由多个 feature(即 Tensorflow 中 tf.train.Feature )组成。一个 feature 可以是二进制字串、整数、浮点数等。

常用的创建 TFRecord 文件的方式是使用 tf.io.TFRecordWriter 函数,它可以将样本中的 feature 一一转换为 tf.train.Example,再将多个 tf.train.Example 存储到一个 TFRecord 文件中。

下面是一个创建 TFRecord 文件的示例代码:

import tensorflow as tf
import numpy as np

def _float_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def create_example(features, labels):
  features = {
    'features': _float_feature(features.flatten().tolist()),
    'labels': _int64_feature(labels.flatten().tolist())
  }
  return tf.train.Example(features=tf.train.Features(feature=features))

def create_dataset(num_samples, num_features, num_labels):
  features = np.random.rand(num_samples, num_features).astype(np.float32)
  labels = np.random.randint(0 , num_labels, size=(num_samples, 1)).astype(np.int64)

  with tf.io.TFRecordWriter('data.tfrecords') as writer:
    for i in range(num_samples):
      example = create_example(features[i,:], labels[i,:])
      writer.write(example.SerializeToString())

create_dataset(1000, 10, 5)

通过以上代码,将随机生成的 1000 个样本及其对应的 10 个特征和 1 个标签保存到了名为 data.tfrecords 的文件中。

2. 使用方法

TFRecordDataset 可以读取 TFRecord 文件,并且可以使用多线程异步读取数据。下面是一个 TFRecordDataset 的使用案例:

import tensorflow as tf

def parser(record):
  keys_to_features = {
    'features': tf.io.FixedLenFeature((10,), tf.float32),
    'labels': tf.io.FixedLenFeature((1,), tf.int64)
  }
  parsed = tf.io.parse_single_example(record, keys_to_features)
  return parsed['features'], parsed['labels']

dataset = tf.data.TFRecordDataset(['data.tfrecords'])
dataset = dataset.map(parser)
dataset = dataset.batch(32)

for features, labels in dataset:
  print(features.shape, labels.shape)

其中,parser 函数用于解析 TFRecord 文件中的数据,parse_single_example 函数用于解析一条记录中的 features 和 labels,返回值将被送入下一步的处理流程里。FixedLenFeature 子函数用于指定每个 feature 的大小。

在上述代码中,我们将使用 map 函数将每一条数据进行解析。解析完成后,我们使用了 batch 函数将数据集分成了大小为 32 的批次。

以上代码运行后,将输出以下结果:

(32, 10) (32, 1)
(32, 10) (32, 1)
(32, 10) (32, 1)
...

由此可见,成功解析了大小为 10、标签为 1 的数据,并且按批次进行了数据读取。

3. 实例说明

本部分将介绍两个使用 TFRecordDataset 函数的实例。

实例 1. CIFAR10 数据集读取

下面的代码演示了如何读取 CIFAR10 数据集中的数据:

import tensorflow as tf
import os

def unpickle(file):
  import _pickle as cPickle
  with open(file, 'rb') as fo:
    dict = cPickle.load(fo, encoding='bytes')
  return dict

def get_filenames(dataset_dir):
  return [os.path.join(dataset_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]

def get_parser():
  def parser(record):
    keys_to_features = {
      "data": tf.io.FixedLenFeature([], tf.string),
      "label": tf.io.FixedLenFeature([], tf.int64)
    }
    parsed = tf.io.parse_single_example(record, keys_to_features)
    image = tf.io.decode_raw(parsed["data"], tf.uint8)
    image = tf.cast(tf.reshape(image, [3,32,32]), tf.float32)
    image = tf.transpose(image, [1, 2, 0])
    label = tf.cast(parsed["label"], tf.int32)    
    return image, label
  return parser

def cifar10_input_fn(dataset_dir, batch_size, shuffle=False):
  filenames = get_filenames(dataset_dir)

  dataset = tf.data.TFRecordDataset(filenames)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=512)

  dataset = dataset.map(get_parser(), num_parallel_calls=8)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  return dataset

在上述代码中,我们首先定义了一个 get_parser 函数,函数内部使用了 parse_single_example 函数进行解析。在解析 CIFAR10 数据集的过程中,我们对图像数据进行了一些处理。在 cifar10_input_fn 函数中,我们使用 TFRecordDataset函数读取多个 TFRecord 文件,然后使用 map 函数调用解析函数。最后调用 batch 函数和 prefetch 函数对数据集进行处理,以便于训练。

以上代码中,我们可以直接使用函数 cifar10_input_fn 来获取 CIFAR10 数据集的 dataset 对象。在实际运用中,我们可以根据 dataset 的内容来设计具体的训练、测试或验证流程。

实例 2. 海量数据训练

当我们的训练数据集非常大时,我们很难将所有数据读入到内存中。我们可以使用 TFRecord 格式的文件来读取大规模数据。下面的代码演示了如何使用 TFRecordDataset 来进行海量数据的训练:

import tensorflow as tf
import os

def unpickle(file):
  import _pickle as cPickle
  with open(file, 'rb') as fo:
    dict = cPickle.load(fo, encoding='bytes')
  return dict

def get_filenames(dataset_dir):
  return [os.path.join(dataset_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]

def get_parser():
  def parser(record):
    keys_to_features = {
      "data": tf.io.FixedLenFeature([], tf.string),
      "label": tf.io.FixedLenFeature([], tf.int64)
    }
    parsed = tf.io.parse_single_example(record, keys_to_features)
    image = tf.io.decode_raw(parsed["data"], tf.uint8)
    image = tf.cast(tf.reshape(image, [3,32,32]), tf.float32)
    image = tf.transpose(image, [1, 2, 0])
    label = tf.cast(parsed["label"], tf.int32)    
    return image, label
  return parser

def cifar10_input_fn(dataset_dir, batch_size, shuffle=False):
  filenames = get_filenames(dataset_dir)

  dataset = tf.data.TFRecordDataset(filenames)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=512)

  dataset = dataset.map(get_parser(), num_parallel_calls=8)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  return dataset

def train(dataset_dir, epochs, batch_size):
  dataset = cifar10_input_fn(dataset_dir, batch_size=batch_size, shuffle=True)

  model = tf.keras.models.Sequential([ ... ])
  model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

  model.fit(dataset, epochs=epochs)

train('./cifar10/', 10, 32)

在上述代码中,我们通过 cifar10_input_fn 函数来获取 dataset 对象。在主函数 train 中,我们先通过 dataset 定义了一个 model,并使用 fit 函数来开始训练。通过这种方式,我们可以快速地完成大规模数据集的训练。

4. 总结

在本篇攻略中,我们详细讲解了 TFRecord 格式数据的产生方法、TFRecordDataset 函数的使用方法以及两个相关实例,希望对大家的 TensorFlow 学习和应用有所帮助。