📜  PyTorch-数据集(1)

📅  最后修改于: 2023-12-03 15:34:33.102000             🧑  作者: Mango

PyTorch 数据集

PyTorch 数据集是针对深度学习任务开发的数据集库,旨在简化数据准备和处理成为管道。在 PyTorch 中,可以使用自带的 TorchVision 库或自定义数据集加载器来使用数据集。

数据集类型

PyTorch 数据集库支持多种数据类型,包括图像、文本、语音、视频等,可以通过 TorchVision 库中的数据集类来访问常用的数据集。其中包括 MNIST、CIFAR10、CIFAR100、ImageNet 等常用的图像数据集,而 Text、Audio、Video等类型数据集则可以通过自定义数据集来实现。

TorchVision 库

TorchVision 库是 PyTorch官方提供的一个基于 PyTorch 的视觉工具库。其中包含了一系列用于图像预处理、数据加载和模型训练的函数和工具。

加载数据集

TorchVision 库中提供了多个数据集类:torchvision.datasets.MNISTtorchvision.datasets.CIFAR10torchvision.datasets.CIFAR100torchvision.datasets.ImageNet 等。每个数据集类都有以下参数:

  • root:数据集存放路径;
  • train:是否使用训练集;
  • transform:对数据进行处理的函数;
  • target_transform:对标签进行处理的函数;
  • download:如果数据集未下载,是否自动下载。

以下是加载 CIFAR10 数据集的代码示例:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

trainset = datasets.CIFAR10(root='./data', train=True, transform=transforms.ToTensor(), download=True)
testset = datasets.CIFAR10(root='./data', train=False, transform=transforms.ToTensor(), download=False)
数据预处理

TorchVision 库提供了一系列用于图像预处理的函数,这些函数可以用于数据增强、数据规范化等任务。

import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
数据加载器

在 PyTorch 中,数据加载器是一种迭代器,用于多线程的数据加载。TorchVision 库中提供了 torch.utils.data.DataLoader 类来实现数据加载器。以下是使用 DataLoader 加载 CIFAR10 数据集的代码示例:

import torch.utils.data as data

trainloader = data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testloader = data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)
自定义数据集

除了使用 TorchVision 库提供的数据集类,还可以通过自定义数据集类来加载和处理自己的数据集。自定义数据集类需要继承 torch.utils.data.Dataset 类,实现 __getitem__()__len__()方法。

以下是加载自定义数据集的代码示例:

import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        self.target_transform = target_transform
    
    def __getitem__(self, index):
        img, target = self.data[index], self.target[index]
        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            target = self.target_transform(target)
        return img, target
    
    def __len__(self):
        return len(self.data)
结语

PyTorch 数据集库提供了方便的方式来加载和处理各种类型的数据集。使用 TorchVision 库可以方便地加载常用的图像数据集,同时也可以通过自定义数据集来处理不同类型的数据。