📜  PyTorch-加载数据(1)

📅  最后修改于: 2023-12-03 15:04:42.868000             🧑  作者: Mango

PyTorch-加载数据

PyTorch是当前最流行的深度学习框架之一,它提供了许多优秀的API来帮助我们更好地加载和处理数据。本文将介绍如何使用PyTorch加载数据,让你更快更有效地训练你的深度学习模型。

1. 加载普通数据集

如果你有一个普通数据集,比如一些图片或CSV文件,可以使用 torch.utils.data.Dataset 类来加载它们。比如我们有如下的目录结构:

data/
|-- train/
|   |-- 1.jpg
|   |-- 2.jpg
|   |-- ...
|-- test/
|   |-- 1.jpg
|   |-- 2.jpg
|   |-- ...

可以使用以下代码来加载训练集和测试集:

import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.file_list = sorted(os.listdir(data_dir))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        filename = self.file_list[idx]
        filepath = os.path.join(self.data_dir, filename)
        with open(filepath, 'rb') as f:
            image = Image.open(f)
            image = image.convert('RGB')
        return image, filename

train_dataset = MyDataset('data/train')
test_dataset = MyDataset('data/test')

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

在这个例子中,我们首先定义了一个 MyDataset 类,继承了 torch.utils.data.Dataset 类,并实现了它的 __len____getitem__ 方法。 __len__ 方法返回数据集的长度, __getitem__ 方法返回数据集中某个样本的数据和标签(没有标签就不要返回)。

然后我们利用已经定义好的 MyDataset 类来加载训练集和测试集,都传入 DataLoader 类中。 DataLoader 的作用是从 MyDataset 中按照 batch_size、shuffle 等方式抽取数据并返回一个 batch。在训练时可以逐个 batch 地处理数据。

2. 加载图像数据集

如果你有一个图像数据集,比如 CIFAR-10ImageNet 等,可以使用 torchvision.datasets 中的数据集来加载它们。比如我们可以使用以下代码来加载 CIFAR-10

import torchvision

train_dataset = torchvision.datasets.CIFAR10(root='data', train=True, transform=None, target_transform=None, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='data', train=False, transform=None, target_transform=None, download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

在这个例子中,我们首先使用 torchvision.datasets.CIFAR10 类来加载 CIFAR-10 数据集,该类会帮我们下载数据集并转换为 PIL.Image 类型,然后我们再通过 DataLoader 类来加载数据。需要注意的是,我们可以通过 transform 参数来进行数据增强操作,比如裁剪,翻转等操作。

3. 自定义数据增强

如果你希望自定义数据增强操作,可以使用 torchvision.transforms 来实现。比如我们可以使用以下代码来定义一个数据增强操作:

import random
from torchvision.transforms import functional as F

class MyTransforms:
    def __call__(self, image):
        if random.random() < 0.5:
            image = F.hflip(image)
        image = F.resize(image, (256, 256))
        return image

然后我们可以在加载 CIFAR-10 数据集时指定这个数据增强操作:

import torchvision

transform = MyTransforms()

train_dataset = torchvision.datasets.CIFAR10(root='data', train=True, transform=transform, target_transform=None, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='data', train=False, transform=None, target_transform=None, download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

在这个例子中,我们首先定义了 MyTransforms 类,继承了 torchvision.transforms.functional,并实现了 __call__ 方法,该方法实现了数据增强操作。然后我们将 transform=transform 传入到 CIFAR-10 数据集中,以便在加载数据时应用数据增强操作。

总结

通过本文的介绍,我们学会了如何使用 PyTorch 加载普通数据集和图像数据集,以及如何自定义数据增强操作。当然,我们只是介绍了 PyTorch 中的一些基本操作,想要更深入的了解 PyTorch,还需要不断地学习和实践。