Pytorch使用技巧之Dataloader中的collate_fn参数详析
在PyTorch中,DataLoader
是一个非常常用的数据加载模块,用于从数据集中批量加载数据。作为一个优秀的数据加载模块,DataLoader
可以并行加载数据、支持自定义数据读取器等,能够极大地提高数据加载效率和准确性。
然而,当我们面对一些非常规的数据集时(比如长度不一致的样本、带有多个标签的样本等),DataLoader
的默认行为可能无法满足我们的需求。在这种情况下,我们就需要使用collate_fn
参数来自定义样本的组合方式,从而满足我们的需求。
collate_fn参数的作用
collate_fn
是DataLoader
的一个参数,用于指定如何将多个样本组合成一个batch。具体来说,collate_fn
的作用是将多个样本合并成一个batch,并将每个样本的特征以及标签分别组合成一个tensor返回给神经网络模型。
如果我们不显式地指定collate_fn
,那么DataLoader
的默认行为是将多个样本按第一维度堆叠起来,并将样本的特征和标签分别组成一个tensor返回给模型。但是,如果我们要在一个batch中使用长度不一致的样本,或者要在一个batch中组合多个标签,那么默认的collate_fn
显然无法满足我们的需求。
因此,我们需要自定义collate_fn
来满足我们的需求。下面,我们将通过两个实例来具体讲解collate_fn
的使用方式。
示例1:使用collate_fn组合长度不一致的样本
在一些任务中,我们的样本长度是不一致的,如文本分类中,一个文本可能有不同的句子数量。这时,我们需要使用collate_fn
参数,将不同长度的样本组合到一个batch中,便于网络同时处理多个样本。
下面是一个使用collate_fn
组合长度不一致的样本的实例:
import torch
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
return sample
def collate_fn(batch):
seqs = [torch.LongTensor(item) for item in batch]
lengths = torch.LongTensor([len(item) for item in seqs])
seqs = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
return seqs, lengths
text_data = [[1, 2, 3], [4, 5], [6], [7, 8, 9, 10]]
text_dataset = TextDataset(text_data)
text_dataloader = DataLoader(text_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
for batch in text_dataloader:
print(batch)
在这个实例中,我们的数据集是一个长度不一致的文本序列,定义了一个TextDataset
数据集。collate_fn
函数中,我们使用了pad_sequence
函数将不同长度的文本序列对其到同一长度,并用0进行padding。每个文本序列和其长度组成一个tuple,同时作为一个batch返回给网络,方便进行训练。
示例2:使用collate_fn组合多个标签
在一些任务中,我们的样本可能需要有多个标签,比如一个图片可能对应多个标签,比如人脸识别中的性别、年龄等。这时,我们需要使用collate_fn
参数,将不同标签的数据组合到一个tensor中,便于网络处理多个标签。
下面是一个使用collate_fn
组合多个标签的实例:
import torch
from torch.utils.data import Dataset, DataLoader
class ImageDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
target = self.targets[index]
return sample, target['gender'], target['age']
def collate_fn(batch):
imgs = [item[0] for item in batch]
genders = torch.LongTensor([item[1] for item in batch])
ages = torch.LongTensor([item[2] for item in batch])
return torch.stack(imgs, 0), genders, ages
fake_data = torch.rand(4, 3, 32, 32)
fake_targets = [{'gender': 0, 'age': 20}, {'gender': 1, 'age': 30}, {'gender': 0, 'age': 25}, {'gender': 1, 'age': 35}]
image_dataset = ImageDataset(fake_data, fake_targets)
image_dataloader = DataLoader(image_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
for batch in image_dataloader:
print(batch)
在这个实例中,我们的数据集是一个图片数据集,定义了一个ImageDataset
数据集,其实可以用torchvision自带的数据集,这里为了方便自定义了一个。其中我们使用了一个字典记录每个样本的两个标签(gender和age)。在collate_fn
函数中,我们将图片、gender和age分别组成一个tensor返回给网络。
总结
collate_fn
参数是DataLoader
中非常方便的一个参数,可以自由地组合各种不同类型的样本,让我们可以更好地适应各种实际场景。在具体使用时,我们需要根据实际情况进行自定义。如果遇到数据格式比较特殊的话,建议先将数据读取出来后,观察其格式,然后再结合collate_fn
的使用方式,进行适当地组合。