pytorch 数据处理:定义自己的数据集合实例

  • Post category:Python

我来详细讲解一下“pytorch数据处理:定义自己的数据集合实例”的完整攻略。

定义自己的数据集合实例

在PyTorch中,我们可以通过torch.utils.data.Dataset类来定义自己的数据集合实例。为此,我们需要对原有的Dataset进行继承,并重载__len__函数和__getitem__函数。具体步骤如下:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        # 初始化数据集

    def __len__(self):
        # 返回数据集的长度

    def __getitem__(self, idx):
        # 根据数据集的索引idx,返回相应的数据样本

在具体定义MyDataset的时候,我们需要实现__init____len____getitem__三个函数。其中,__init__函数通常用于初始化数据集,__len__函数用于返回数据集的长度,__getitem__函数用于根据数据集的索引idx,返回相应的数据样本。下面我们逐一介绍一下这三个函数:

1. 定义__init__(self)函数

__init__函数中,我们一般会读取数据集,并根据需要进行预处理,例如图像数据集我们会进行图像增强和数据标准化等操作。下面是一个读取图像文件夹的__init__函数的示例:

class MyDataset(Dataset):
    def __init__(self, root, transform=None):
        """
        root: 数据集的根目录
        transform: 预处理函数
        """
        super(MyDataset, self).__init__()

        self.root = root
        self.transform = transform

        self.imgs = list(sorted(os.listdir(root)))

    def __len__(self):
        return len(self.imgs)

这里我们通过将数据集的根目录赋值给self.root,将预处理函数赋值给self.transform,并利用os.listdir函数获取数据集中所有图像的文件名,并通过list.sorted函数按照文件名的字典序进行排序,最终保存在self.imgs中,方便在__getitem__函数中通过索引获取图像数据样本。

2. 定义__len__(self)函数

实现__len__函数非常简单,只需要返回数据集的样本数量即可,例如:

class MyDataset(Dataset):
    def __init__(self, root, transform=None):
        pass

    def __len__(self):
        return self.total_samples

3. 定义__getitem__(self, idx)函数

__getitem__函数中,我们需要根据提供的索引idx获取相应的数据样本。如果我们的数据集是一个图像数据集,那么我们可以通过PIL和transforms来加载图像并进行预处理。具体代码如下:

class MyDataset(Dataset):
    def __init__(self, root, transform=None):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        # 获取图像文件名
        img_name = os.path.join(self.root, self.imgs[idx])

        # 加载图像
        img = Image.open(img_name)

        # 如果有预处理函数transform,则对图像进行预处理
        if self.transform is not None:
            img = self.transform(img)

        return img

以上是PyTorch定义自己的数据集的基本步骤,下面我们介绍两个使用示例。

案例示例一

在这个案例中,我们可以创建一个包含5个0至9之间整数的数据集,并对其进行批处理:

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = range(10)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx])

my_dataset = MyDataset()
my_dataloader = DataLoader(my_dataset,
                            batch_size=2,
                            shuffle=True,
                            num_workers=0)

for i, batch in enumerate(my_dataloader):
    print(f"Batch {i}: {batch}")

运行上述代码后,我们可以看到在控制台输出的批量数据:

Batch 0: tensor([2, 3])
Batch 1: tensor([4, 8])
Batch 2: tensor([1, 0])
Batch 3: tensor([7, 9])
Batch 4: tensor([6, 5])

案例示例二

在这个案例中,我们可以读取一个包含实际图片的文件夹,并对其进行预处理和批处理,最后我们将这批图片保存在一个Tensorboard中以便于我们进行可视化:

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

my_dataset = ImageFolder('/path/to/image/folder', transform=transform)
my_dataloader = DataLoader(my_dataset, batch_size=2, shuffle=True, num_workers=0)

images_tensor_list = []

for i, batch in enumerate(my_dataloader):
    images_tensor_list.append(batch)

    if i == 9:
        break

images_tensor = torch.cat(images_tensor_list, dim=0)

writer = SummaryWriter(log_dir='./runs')
writer.add_images('images', images_tensor)

plt.imshow(images_tensor[0].permute(1, 2, 0))

这段代码将从指定路径中读取图片并进行预处理,最后将批处理的图片数据保存在一个Tensorboard中。其中,transform变量表示对图像进行的预处理步骤,这里使用了一些常见的图像预处理方法,例如调整图像大小、将PIL.Image强制转换为张量以及归一化。之后,我们读取了含有图片的文件夹,并通过ImageFolder类将图片转化为PyTorch中的数据集对象。随后,我们通过DataLoader将数据集包裹在一起以便于它们可以作为一个生成器被调用。最后,我们迭代数据集,并将迭代过程中的图像通过TensorBoard进行可视化。回到主程序,我们通过plt.imshow函数显示了一个经过调整大小、强制转换和标准化的图像。

以上是两个PyTorch数据集的案例示例,希望能够对你有所帮助。