📅  最后修改于: 2023-12-03 15:34:33.102000             🧑  作者: Mango
PyTorch 数据集是针对深度学习任务开发的数据集库,旨在简化数据准备和处理成为管道。在 PyTorch 中,可以使用自带的 TorchVision 库或自定义数据集加载器来使用数据集。
PyTorch 数据集库支持多种数据类型,包括图像、文本、语音、视频等,可以通过 TorchVision 库中的数据集类来访问常用的数据集。其中包括 MNIST、CIFAR10、CIFAR100、ImageNet 等常用的图像数据集,而 Text、Audio、Video等类型数据集则可以通过自定义数据集来实现。
TorchVision 库是 PyTorch官方提供的一个基于 PyTorch 的视觉工具库。其中包含了一系列用于图像预处理、数据加载和模型训练的函数和工具。
TorchVision 库中提供了多个数据集类:torchvision.datasets.MNIST
、torchvision.datasets.CIFAR10
、torchvision.datasets.CIFAR100
、torchvision.datasets.ImageNet
等。每个数据集类都有以下参数:
root
:数据集存放路径;train
:是否使用训练集;transform
:对数据进行处理的函数;target_transform
:对标签进行处理的函数;download
:如果数据集未下载,是否自动下载。以下是加载 CIFAR10 数据集的代码示例:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
trainset = datasets.CIFAR10(root='./data', train=True, transform=transforms.ToTensor(), download=True)
testset = datasets.CIFAR10(root='./data', train=False, transform=transforms.ToTensor(), download=False)
TorchVision 库提供了一系列用于图像预处理的函数,这些函数可以用于数据增强、数据规范化等任务。
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
在 PyTorch 中,数据加载器是一种迭代器,用于多线程的数据加载。TorchVision 库中提供了 torch.utils.data.DataLoader
类来实现数据加载器。以下是使用 DataLoader 加载 CIFAR10 数据集的代码示例:
import torch.utils.data as data
trainloader = data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testloader = data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)
除了使用 TorchVision 库提供的数据集类,还可以通过自定义数据集类来加载和处理自己的数据集。自定义数据集类需要继承 torch.utils.data.Dataset
类,实现 __getitem__()
和 __len__()
方法。
以下是加载自定义数据集的代码示例:
import torch.utils.data as data
class MyDataset(data.Dataset):
def __init__(self, data, target, transform=None, target_transform=None):
self.data = data
self.target = target
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
img, target = self.data[index], self.target[index]
if self.transform:
img = self.transform(img)
if self.target_transform:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
PyTorch 数据集库提供了方便的方式来加载和处理各种类型的数据集。使用 TorchVision 库可以方便地加载常用的图像数据集,同时也可以通过自定义数据集来处理不同类型的数据。