如何在 PyTorch 中使用 DataLoader?
处理大型数据集需要一次性将它们加载到内存中。在大多数情况下,由于系统中可用的内存量有限,我们会面临内存中断。此外,由于一次加载的大量数据集,程序往往运行缓慢。 PyTorch 提供了一种解决方案,通过使用 DataLoader 将数据加载过程与自动批处理并行化。 Dataloader 已被用于并行化数据加载,因为这可以提高速度并节省内存。
数据加载器构造函数位于 torch.utils.data 包中。它有各种参数,其中唯一要传递的强制参数是必须加载的数据集,其余都是可选参数。
句法:
DataLoader(dataset, shuffle=True, sampler=None, batch_size=32)
自定义数据集上的 DataLoader:
要在自定义数据集上实现数据加载器,我们需要覆盖以下两个子类函数:
- _len_()函数:返回数据集的大小。
- _getitem_()函数:从数据集中返回给定索引的样本。
Python3
# importing the required libraries
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# defining the Dataset class
class data_set(Dataset):
def __init__(self):
numbers = list(range(0, 100, 1))
self.data = numbers
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
dataset = data_set()
# implementing dataloader on the dataset and printing per batch
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
for i, batch in enumerate(dataloader):
print(i, batch)
Python3
# importing the required libraries
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import seaborn as sns
from torch.utils.data import TensorDataset
# defining the dataset consisting of
# two columns from iris dataset
iris = sns.load_dataset('iris')
petal_length = torch.tensor(iris['petal_length'])
petal_width = torch.tensor(iris['petal_width'])
dataset = TensorDataset(petal_length, petal_width)
# implementing dataloader on the dataset
# and printing per batch
dataloader = DataLoader(dataset,
batch_size=5,
shuffle=True)
for i in dataloader:
print(i)
输出:
内置数据集上的数据加载器:
蟒蛇3
# importing the required libraries
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import seaborn as sns
from torch.utils.data import TensorDataset
# defining the dataset consisting of
# two columns from iris dataset
iris = sns.load_dataset('iris')
petal_length = torch.tensor(iris['petal_length'])
petal_width = torch.tensor(iris['petal_width'])
dataset = TensorDataset(petal_length, petal_width)
# implementing dataloader on the dataset
# and printing per batch
dataloader = DataLoader(dataset,
batch_size=5,
shuffle=True)
for i in dataloader:
print(i)
输出: