下面就是关于PyTorch加载自己的图像数据集实例的完整攻略,该攻略分为以下三个步骤:
- 准备数据集
- 自定义dataset和dataloader
- 加载数据集
1. 准备数据集
在加载数据集之前,我们需要先准备好数据集。数据集应该按照以下文件夹结构进行组织:
dataset/
class1/
image1.jpg
image2.jpg
...
class2/
image1.jpg
image2.jpg
...
...
其中,dataset文件夹为数据集的根目录,class1、class2等文件夹为各分类的目录,每个分类目录下包含多个对应分类的图像文件。
2. 自定义dataset和dataloader
接下来,我们需要自定义dataset和dataloader。dataset是一个数据集抽象类,我们需要继承该抽象类并实现__len__和__getitem__两个方法,用于返回数据集的长度和数据。
from torch.utils.data import Dataset
from PIL import Image
import os
class CustomDataset(Dataset):
def __init__(self, root_dir):
self.root_dir = root_dir
self.img_names = []
self.labels = []
self.classes = os.listdir(root_dir)
for i, c in enumerate(self.classes):
imgs = os.listdir(os.path.join(root_dir, c))
self.img_names += [os.path.join(c, img) for img in imgs]
self.labels += [i]*len(imgs)
def __len__(self):
return len(self.img_names)
def __getitem__(self, idx):
img_path = os.path.join(self.root_dir, self.img_names[idx])
label = self.labels[idx]
img = Image.open(img_path).convert('RGB')
# 在这里可以对图像进行预处理、数据增强等操作
return img, label
然后,我们需要自定义dataloader,即数据加载器,用于从数据集中加载batch_size个数据进行训练。
from torch.utils.data import DataLoader
dataset = CustomDataset("dataset")
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
在这里我们使用了自定义的CustomDataset和PyTorch自带的DataLoader类,通过指定batch_size和shuffle参数来进行数据加载。
3. 加载数据集
最后,我们需要对dataloader进行迭代,以加载数据集并进行训练。下面是一个使用dataloader进行训练的示例。
import torch.nn as nn
import torch.optim as optim
model = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Flatten(),
nn.Linear(32*8*8, 10)
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
for epoch in range(num_epochs):
for imgs, labels in dataloader:
optimizer.zero_grad()
preds = model(imgs)
loss = criterion(preds, labels)
loss.backward()
optimizer.step()
在这里,我们使用了nn.Module实现了一个简单的卷积神经网络,并使用了CrossEntropyLoss作为损失函数,采用随机梯度下降算法对模型进行训练。每一次迭代我们使用dataloader加载batch_size个数据,并传入模型进行训练。