📜  如何在 PyTorch 中使用 DataLoader?

📅  最后修改于: 2022-05-13 01:55:29.916000             🧑  作者: Mango

如何在 PyTorch 中使用 DataLoader?

处理大型数据集需要一次性将它们加载到内存中。在大多数情况下,由于系统中可用的内存量有限,我们会面临内存中断。此外,由于一次加载的大量数据集,程序往往运行缓慢。 PyTorch 提供了一种解决方案,通过使用 DataLoader 将数据加载过程与自动批处理并行化。 Dataloader 已被用于并行化数据加载,因为这可以提高速度并节省内存。

数据加载器构造函数位于 torch.utils.data 包中。它有各种参数,其中唯一要传递的强制参数是必须加载的数据集,其余都是可选参数。

句法:

自定义数据集上的 DataLoader:



要在自定义数据集上实现数据加载器,我们需要覆盖以下两个子类函数:

  • _len_()函数:返回数据集的大小。
  • _getitem_()函数:从数据集中返回给定索引的样本。
Python3
# importing the required libraries
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
  
# defining the Dataset class
class data_set(Dataset):
    def __init__(self):
        numbers = list(range(0, 100, 1))
        self.data = numbers
  
    def __len__(self):
        return len(self.data)
  
    def __getitem__(self, index):
        return self.data[index]
  
  
dataset = data_set()
  
# implementing dataloader on the dataset and printing per batch
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
for i, batch in enumerate(dataloader):
    print(i, batch)


Python3
# importing the required libraries
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import seaborn as sns
from torch.utils.data import TensorDataset
  
# defining the dataset consisting of 
# two columns from iris dataset
iris = sns.load_dataset('iris')
petal_length = torch.tensor(iris['petal_length'])
petal_width = torch.tensor(iris['petal_width'])
dataset = TensorDataset(petal_length, petal_width)
  
# implementing dataloader on the dataset 
# and printing per batch
dataloader = DataLoader(dataset, 
                        batch_size=5, 
                        shuffle=True)
  
for i in dataloader:
    print(i)


输出:

内置数据集上的数据加载器:

蟒蛇3

# importing the required libraries
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import seaborn as sns
from torch.utils.data import TensorDataset
  
# defining the dataset consisting of 
# two columns from iris dataset
iris = sns.load_dataset('iris')
petal_length = torch.tensor(iris['petal_length'])
petal_width = torch.tensor(iris['petal_width'])
dataset = TensorDataset(petal_length, petal_width)
  
# implementing dataloader on the dataset 
# and printing per batch
dataloader = DataLoader(dataset, 
                        batch_size=5, 
                        shuffle=True)
  
for i in dataloader:
    print(i)

输出: