📅  最后修改于: 2023-12-03 15:04:42.868000             🧑  作者: Mango
PyTorch是当前最流行的深度学习框架之一,它提供了许多优秀的API来帮助我们更好地加载和处理数据。本文将介绍如何使用PyTorch加载数据,让你更快更有效地训练你的深度学习模型。
如果你有一个普通数据集,比如一些图片或CSV文件,可以使用 torch.utils.data.Dataset
类来加载它们。比如我们有如下的目录结构:
data/
|-- train/
| |-- 1.jpg
| |-- 2.jpg
| |-- ...
|-- test/
| |-- 1.jpg
| |-- 2.jpg
| |-- ...
可以使用以下代码来加载训练集和测试集:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, data_dir):
self.data_dir = data_dir
self.file_list = sorted(os.listdir(data_dir))
def __len__(self):
return len(self.file_list)
def __getitem__(self, idx):
filename = self.file_list[idx]
filepath = os.path.join(self.data_dir, filename)
with open(filepath, 'rb') as f:
image = Image.open(f)
image = image.convert('RGB')
return image, filename
train_dataset = MyDataset('data/train')
test_dataset = MyDataset('data/test')
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
在这个例子中,我们首先定义了一个 MyDataset
类,继承了 torch.utils.data.Dataset
类,并实现了它的 __len__
和 __getitem__
方法。 __len__
方法返回数据集的长度, __getitem__
方法返回数据集中某个样本的数据和标签(没有标签就不要返回)。
然后我们利用已经定义好的 MyDataset
类来加载训练集和测试集,都传入 DataLoader
类中。 DataLoader
的作用是从 MyDataset
中按照 batch_size、shuffle 等方式抽取数据并返回一个 batch。在训练时可以逐个 batch 地处理数据。
如果你有一个图像数据集,比如 CIFAR-10
或 ImageNet
等,可以使用 torchvision.datasets
中的数据集来加载它们。比如我们可以使用以下代码来加载 CIFAR-10
:
import torchvision
train_dataset = torchvision.datasets.CIFAR10(root='data', train=True, transform=None, target_transform=None, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='data', train=False, transform=None, target_transform=None, download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
在这个例子中,我们首先使用 torchvision.datasets.CIFAR10
类来加载 CIFAR-10
数据集,该类会帮我们下载数据集并转换为 PIL.Image
类型,然后我们再通过 DataLoader
类来加载数据。需要注意的是,我们可以通过 transform
参数来进行数据增强操作,比如裁剪,翻转等操作。
如果你希望自定义数据增强操作,可以使用 torchvision.transforms
来实现。比如我们可以使用以下代码来定义一个数据增强操作:
import random
from torchvision.transforms import functional as F
class MyTransforms:
def __call__(self, image):
if random.random() < 0.5:
image = F.hflip(image)
image = F.resize(image, (256, 256))
return image
然后我们可以在加载 CIFAR-10
数据集时指定这个数据增强操作:
import torchvision
transform = MyTransforms()
train_dataset = torchvision.datasets.CIFAR10(root='data', train=True, transform=transform, target_transform=None, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='data', train=False, transform=None, target_transform=None, download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
在这个例子中,我们首先定义了 MyTransforms
类,继承了 torchvision.transforms.functional
,并实现了 __call__
方法,该方法实现了数据增强操作。然后我们将 transform=transform
传入到 CIFAR-10
数据集中,以便在加载数据时应用数据增强操作。
通过本文的介绍,我们学会了如何使用 PyTorch 加载普通数据集和图像数据集,以及如何自定义数据增强操作。当然,我们只是介绍了 PyTorch 中的一些基本操作,想要更深入的了解 PyTorch,还需要不断地学习和实践。