Pytorch使用技巧之Dataloader中的collate_fn参数详析

  • Post category:Python

Pytorch使用技巧之Dataloader中的collate_fn参数详析

在PyTorch中,DataLoader是一个非常常用的数据加载模块,用于从数据集中批量加载数据。作为一个优秀的数据加载模块,DataLoader可以并行加载数据、支持自定义数据读取器等,能够极大地提高数据加载效率和准确性。

然而,当我们面对一些非常规的数据集时(比如长度不一致的样本、带有多个标签的样本等),DataLoader的默认行为可能无法满足我们的需求。在这种情况下,我们就需要使用collate_fn参数来自定义样本的组合方式,从而满足我们的需求。

collate_fn参数的作用

collate_fnDataLoader的一个参数,用于指定如何将多个样本组合成一个batch。具体来说,collate_fn的作用是将多个样本合并成一个batch,并将每个样本的特征以及标签分别组合成一个tensor返回给神经网络模型。

如果我们不显式地指定collate_fn,那么DataLoader的默认行为是将多个样本按第一维度堆叠起来,并将样本的特征和标签分别组成一个tensor返回给模型。但是,如果我们要在一个batch中使用长度不一致的样本,或者要在一个batch中组合多个标签,那么默认的collate_fn显然无法满足我们的需求。

因此,我们需要自定义collate_fn来满足我们的需求。下面,我们将通过两个实例来具体讲解collate_fn的使用方式。

示例1:使用collate_fn组合长度不一致的样本

在一些任务中,我们的样本长度是不一致的,如文本分类中,一个文本可能有不同的句子数量。这时,我们需要使用collate_fn参数,将不同长度的样本组合到一个batch中,便于网络同时处理多个样本。

下面是一个使用collate_fn组合长度不一致的样本的实例:

import torch
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return sample

def collate_fn(batch):
    seqs = [torch.LongTensor(item) for item in batch]
    lengths = torch.LongTensor([len(item) for item in seqs])
    seqs = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
    return seqs, lengths

text_data = [[1, 2, 3], [4, 5], [6], [7, 8, 9, 10]]
text_dataset = TextDataset(text_data)
text_dataloader = DataLoader(text_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

for batch in text_dataloader:
    print(batch)

在这个实例中,我们的数据集是一个长度不一致的文本序列,定义了一个TextDataset数据集。collate_fn函数中,我们使用了pad_sequence函数将不同长度的文本序列对其到同一长度,并用0进行padding。每个文本序列和其长度组成一个tuple,同时作为一个batch返回给网络,方便进行训练。

示例2:使用collate_fn组合多个标签

在一些任务中,我们的样本可能需要有多个标签,比如一个图片可能对应多个标签,比如人脸识别中的性别、年龄等。这时,我们需要使用collate_fn参数,将不同标签的数据组合到一个tensor中,便于网络处理多个标签。

下面是一个使用collate_fn组合多个标签的实例:

import torch
from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.targets[index]
        return sample, target['gender'], target['age']

def collate_fn(batch):
    imgs = [item[0] for item in batch]
    genders = torch.LongTensor([item[1] for item in batch])
    ages = torch.LongTensor([item[2] for item in batch])
    return torch.stack(imgs, 0), genders, ages

fake_data = torch.rand(4, 3, 32, 32)
fake_targets = [{'gender': 0, 'age': 20}, {'gender': 1, 'age': 30}, {'gender': 0, 'age': 25}, {'gender': 1, 'age': 35}]
image_dataset = ImageDataset(fake_data, fake_targets)
image_dataloader = DataLoader(image_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

for batch in image_dataloader:
    print(batch)

在这个实例中,我们的数据集是一个图片数据集,定义了一个ImageDataset数据集,其实可以用torchvision自带的数据集,这里为了方便自定义了一个。其中我们使用了一个字典记录每个样本的两个标签(gender和age)。在collate_fn函数中,我们将图片、gender和age分别组成一个tensor返回给网络。

总结

collate_fn参数是DataLoader中非常方便的一个参数,可以自由地组合各种不同类型的样本,让我们可以更好地适应各种实际场景。在具体使用时,我们需要根据实际情况进行自定义。如果遇到数据格式比较特殊的话,建议先将数据读取出来后,观察其格式,然后再结合collate_fn的使用方式,进行适当地组合。