在 Pytorch 中加载数据
在本文中,我们将讨论如何在 PyTorch 中加载不同类型的数据。
出于演示目的,Pytorch 附带了 3 个数据集部分,即 torchaudio、torchvision 和 torchtext。我们可以利用这些演示数据集来了解如何使用 Pytorch 加载声音、图像和文本数据。
Torchaudio 数据集
使用 Pytorch 在 torchaudio 中加载演示 yes_no 音频数据集。
Yes_No 数据集是一个音频波形数据集,它的值以 3 个值的元组形式存储,即波形、采样率、标签,其中波形表示音频信号,采样率表示频率,标签表示是或否。
- 导入 torch 和 torchaudio 包。 (如有必要,使用 pip install torchaudio 安装)
- 使用带有数据集访问器的 torchaudio函数,后跟数据集名称。
- 现在,传递必须存储数据集的路径并指定 download = True 以下载数据集。这里,'./' 指定根目录。
- 现在,使用 for 循环遍历加载的数据集,并访问存储在元组中的 3 个值以查看数据集的样本。
要加载您的自定义数据:
Syntax: torch.utils.data.DataLoader(data, batch_size, shuffle)
Parameters:
- data – audio dataset or the path to the audio dataset
- batch_size – for large dataset, batch_size specifies how much data to load at once
- shuffle – a bool type. Setting it to True will shuffle the data.
Python3
# import the torch and torchaudio dataset packages.
import torch
import torchaudio
# access the dataset in torchaudio package using
# datasets followed by dataset name.
# './' makes sure that the dataset is stored
# in a root directory.
# download = True ensures that the
# data gets downloaded
yesno_data = torchaudio.datasets.YESNO('./',
download=True)
# loading the first 5 data from yesno_data
for i in range(5):
waveform, sample_rate, labels = yesno_data[i]
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(
waveform, sample_rate, labels))
Python3
# import the torch and
# torchvision dataset packages.
import torch
import torchvision
# access the dataset in torchvision package using
# .datasets followed by dataset name.
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
Python3
# import necessary function
# from torchvision package
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
# specify the image dataset folder
data_dir = r'path to dataset\train'
# perform some transformations like resizing,
# centring and tensorconversion
# using transforms function
transform = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor()])
# pass the image data folder and
# transform function to the datasets
# .imagefolder function
dataset = datasets.ImageFolder(data_dir,
transform=transform)
# now use dataloder function load the
# dataset in the specified transformation.
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=32,
shuffle=True)
# iter function iterates through all the
# images and labels and stores in two variables
images, labels = next(iter(dataloader))
# print the total no of samples
print('Number of samples: ', len(images))
image = images[2][0] # load 3rd sample
# visualize the image
plt.imshow(image, cmap='gray')
# print the size of image
print("Image Size: ", image.size())
# print the label
print(label)
Python3
# import the torch and torchtext dataset packages.
import torch
import torchtext
# access the dataset in torchtext package
# using .datasets followed by dataset name.
text_data = torchtext.datasets.IMDB(split='train')
# define a function to tokenize
# the words in the corpus
def tokenize(label, line):
return line.split()
# define a empty list to store
# the tokenized words
tokens = []
# iterate over the text_data and
# tokenize each line and store
# it in the list tokens
for label, line in text_data:
tokens += tokenize(label, line)
print('The total no. of tokens in imdb dataset is',
len(tokens))
输出:
Torchvision 数据集
使用 Pytorch 在 torchvision 中加载演示 ImageNet 视觉数据集。单击此处通过注册下载数据集。
Python3
# import the torch and
# torchvision dataset packages.
import torch
import torchvision
# access the dataset in torchvision package using
# .datasets followed by dataset name.
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
代码说明:
- 该过程与加载音频数据几乎相同。
- 在这里,必须导入 torchvision 而不是 torchaudio。
- 使用带有数据集访问器的 torchvision函数,后跟数据集名称。
- 现在,传递数据集所在的路径。由于 ImageNet 数据集不再可公开访问,因此请在本地系统中下载根数据并将路径传递给此函数。这将轻松加载视觉数据。
要加载您的自定义图像数据,请使用上面提到的 torch.utils.data.DataLoader(data, batch_size, shuffle)。
Python3
# import necessary function
# from torchvision package
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
# specify the image dataset folder
data_dir = r'path to dataset\train'
# perform some transformations like resizing,
# centring and tensorconversion
# using transforms function
transform = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor()])
# pass the image data folder and
# transform function to the datasets
# .imagefolder function
dataset = datasets.ImageFolder(data_dir,
transform=transform)
# now use dataloder function load the
# dataset in the specified transformation.
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=32,
shuffle=True)
# iter function iterates through all the
# images and labels and stores in two variables
images, labels = next(iter(dataloader))
# print the total no of samples
print('Number of samples: ', len(images))
image = images[2][0] # load 3rd sample
# visualize the image
plt.imshow(image, cmap='gray')
# print the size of image
print("Image Size: ", image.size())
# print the label
print(label)
输出:
Image size: torch.Size([224,224])
tensor([0, 0, 0, 1, 1, 1])
Torchtext 数据集
使用 Pytorch 在 torchtext 中加载演示 IMDB 文本数据集。要加载您的自定义文本数据,我们使用 torch.utils.data.DataLoader() 方法。
Syntax: torch.utils.data.DataLoader(‘path to/imdb_data’, batch_size, shuffle=True)
代码说明:
- 该过程与加载图像和音频数据几乎相同。
- 在这里,必须导入torchtext 而不是torchvision。
- 将 torchtext函数与数据集访问器一起使用,后跟数据集名称 (IMDB)。
- 现在,将 split函数传递给 torchtext函数以拆分数据集以训练和测试数据。
- 现在定义一个函数,通过迭代语料库中的每一行,将语料库中的每一行拆分为单独的标记,如图所示。这样,我们就可以轻松地使用 Pytorch 加载文本数据。
Python3
# import the torch and torchtext dataset packages.
import torch
import torchtext
# access the dataset in torchtext package
# using .datasets followed by dataset name.
text_data = torchtext.datasets.IMDB(split='train')
# define a function to tokenize
# the words in the corpus
def tokenize(label, line):
return line.split()
# define a empty list to store
# the tokenized words
tokens = []
# iterate over the text_data and
# tokenize each line and store
# it in the list tokens
for label, line in text_data:
tokens += tokenize(label, line)
print('The total no. of tokens in imdb dataset is',
len(tokens))
输出: