Tensorflow中使用tfrecord方式读取数据的方法

  • Post category:Python

TensorFlow是目前深度学习领域非常流行的框架之一,可以对数据进行高效的处理,并且支持分布式训练。在使用TensorFlow进行训练之前,我们需要将数据进行预处理,这其中就包括将数据转换为TensorFlow所支持的数据格式,tfrecord就是其中一种格式。

使用tfrecord格式存储数据可以提高数据读取速度和训练速度,尤其适用于大规模数据训练的场景。本文将详细介绍如何使用TensorFlow的tfrecord方式读取数据。

一、TFRecord数据格式简介

TFRecord是TensorFlow推荐使用的一种数据格式,它采用二进制序列化的方法,可以方便地存储和读取不同格式的数据。在TFRecord文件中,数据以字符串形式被存储,而每个字符串又被序列化为一个二进制的实例protocol buffer,称之为Example。每个Example都是由多个Features组成的,而Features则由多个键值对组成。一个Feature由包含一个Tensor的张量列表(List)或可变长度值(BytesList)构成。

二、生成TFRecord文件

我们首先需要将原始数据转换为TFRecord格式,并且将其保存在一个文件中。在本例中,我们将生成一个存储图像和标签的数据集。

import tensorflow as tf
import numpy as np

# 图像大小
IMG_HEIGHT = 224
IMG_WIDTH = 224

# 加载数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# 计算数据集长度
n_train = x_train.shape[0]
n_test = x_test.shape[0]

# 建立类别字典
class_dict = {
    0: 'airplane',
    1: 'automobile',
    2: 'bird',
    3: 'cat',
    4: 'deer',
    5: 'dog', 
    6: 'frog',
    7: 'horse',
    8: 'ship',
    9: 'truck'
}

# 定义函数:将图片转为Bytes字符串
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 定义函数:将整数转为Int64列表
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# 定义函数:将数据集转为TFRecord格式
def write_tfrecord(tfrecord_file, data, labels):
    with tf.io.TFRecordWriter(tfrecord_file) as writer:
        for i in range(data.shape[0]):
            # 转换为Bytes字符串形式
            image_bytes = tf.compat.as_bytes(data[i].tostring())

            # 构建Feature字典
            feature_dict = {
                'image': _bytes_feature(image_bytes),
                'label': _int64_feature(int(labels[i]))
            }

            # 构建Example
            example = tf.train.Example(features=tf.train.Features(feature=feature_dict))

            # 序列化
            serialized = example.SerializeToString()

            # 写入TFRecord文件
            writer.write(serialized)

# 将训练集保存到TFRecord文件
train_tfrecord_file = 'train.tfrecord'
write_tfrecord(train_tfrecord_file, x_train, y_train)

# 将测试集保存到TFRecord文件
test_tfrecord_file = 'test.tfrecord'
write_tfrecord(test_tfrecord_file, x_test, y_test)

三、读取TFRecord文件

在模型训练时,我们需要读取TFRecord格式的数据,并将其转化为Tensor。TensorFlow的数据预处理API提供了一系列的函数,可以很方便地读取TFRecord文件并将其转换为Tensor。

1. 读取TFRecord文件

读取TFRecord文件的方法是使用tf.data.TFRecordDataset类:

# 定义函数:读取TFRecord文件
def read_tfrecord(filelist):
    # 读取TFRecord文件
    dataset = tf.data.TFRecordDataset(filelist)

    return dataset

2. 解析Example

接下来是解析Example的过程,我们需要从Example中获取每个Feature对应的Tensor:

# 定义函数:解析Example转为Tensor
def parse_example(serialized_example):
    # 定义Feature的键名及其默认值
    feature_dict = {
        'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'label': tf.io.FixedLenFeature([], tf.int64, default_value=-1)
    }

    # 解析并构建Feature字典
    feature_dict = tf.io.parse_single_example(serialized_example, feature_dict)

    # 解码图像数据,转为Tensor
    image = tf.io.decode_raw(feature_dict['image'], tf.uint8)
    image = tf.reshape(image, (IMG_HEIGHT, IMG_WIDTH, 3))
    image = tf.cast(image, tf.float32) / 255.

    # 将标签转为Tensor
    label = tf.cast(feature_dict['label'], tf.int32)

    return image, label

3. 处理数据

最后,我们需要对数据进行处理,包括打乱顺序、设置Batch size等操作:

# 定义函数:处理数据
def process_data(filelist, batch_size, shuffle):
    # 读取TFRecord文件
    dataset = read_tfrecord(filelist)

    # 解析Example
    dataset = dataset.map(parse_example)

    # 打乱顺序
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1024)

    # 设置Batch size
    dataset = dataset.batch(batch_size)

    return dataset

我们可以调用这个函数来处理数据:

# 读取训练集并处理
train_dataset = process_data(train_tfrecord_file, batch_size=32, shuffle=True)

# 读取测试集并处理
test_dataset = process_data(test_tfrecord_file, batch_size=32, shuffle=False)

四、完整代码

最后,我们将整个代码集于此处。

import tensorflow as tf
import numpy as np

# 图像大小
IMG_HEIGHT = 224
IMG_WIDTH = 224

# 加载数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# 计算数据集长度
n_train = x_train.shape[0]
n_test = x_test.shape[0]

# 建立类别字典
class_dict = {
    0: 'airplane',
    1: 'automobile',
    2: 'bird',
    3: 'cat',
    4: 'deer',
    5: 'dog', 
    6: 'frog',
    7: 'horse',
    8: 'ship',
    9: 'truck'
}

# 定义函数:将图片转为Bytes字符串
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 定义函数:将整数转为Int64列表
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# 定义函数:将数据集转为TFRecord格式
def write_tfrecord(tfrecord_file, data, labels):
    with tf.io.TFRecordWriter(tfrecord_file) as writer:
        for i in range(data.shape[0]):
            # 转换为Bytes字符串形式
            image_bytes = tf.compat.as_bytes(data[i].tostring())

            # 构建Feature字典
            feature_dict = {
                'image': _bytes_feature(image_bytes),
                'label': _int64_feature(int(labels[i]))
            }

            # 构建Example
            example = tf.train.Example(features=tf.train.Features(feature=feature_dict))

            # 序列化
            serialized = example.SerializeToString()

            # 写入TFRecord文件
            writer.write(serialized)

# 将训练集保存到TFRecord文件
train_tfrecord_file = 'train.tfrecord'
write_tfrecord(train_tfrecord_file, x_train, y_train)

# 将测试集保存到TFRecord文件
test_tfrecord_file = 'test.tfrecord'
write_tfrecord(test_tfrecord_file, x_test, y_test)

# 定义函数:读取TFRecord文件
def read_tfrecord(filelist):
    # 读取TFRecord文件
    dataset = tf.data.TFRecordDataset(filelist)

    return dataset

# 定义函数:解析Example转为Tensor
def parse_example(serialized_example):
    # 定义Feature的键名及其默认值
    feature_dict = {
        'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'label': tf.io.FixedLenFeature([], tf.int64, default_value=-1)
    }

    # 解析并构建Feature字典
    feature_dict = tf.io.parse_single_example(serialized_example, feature_dict)

    # 解码图像数据,转为Tensor
    image = tf.io.decode_raw(feature_dict['image'], tf.uint8)
    image = tf.reshape(image, (IMG_HEIGHT, IMG_WIDTH, 3))
    image = tf.cast(image, tf.float32) / 255.

    # 将标签转为Tensor
    label = tf.cast(feature_dict['label'], tf.int32)

    return image, label

# 定义函数:处理数据
def process_data(filelist, batch_size, shuffle):
    # 读取TFRecord文件
    dataset = read_tfrecord(filelist)

    # 解析Example
    dataset = dataset.map(parse_example)

    # 打乱顺序
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1024)

    # 设置Batch size
    dataset = dataset.batch(batch_size)

    return dataset

# 读取训练集并处理
train_dataset = process_data(train_tfrecord_file, batch_size=32, shuffle=True)

# 读取测试集并处理
test_dataset = process_data(test_tfrecord_file, batch_size=32, shuffle=False)

五、示例说明

1. TensorFlow Object Detection API

我们可以使用TensorFlow Object Detection API训练目标检测模型,在这个过程中涉及到处理大量的原始图像数据。我们可以使用tfrecord格式来存储和训练数据,从而加快模型的训练速度。

2. 场景文本识别

在场景文本识别任务中,通常需要处理大量的文本和图像数据。我们可以使用tfrecord格式存储这些数据,通过TensorFlow读取并预处理数据,然后使用适当的深度学习算法来进行训练,从而提高识别准确率和速度。