了解 PyTorch 闪电数据模块
PyTorch Lightning 旨在使 PyTorch 代码更具结构化和可读性,这不仅限于 PyTorch 模型,还包括数据本身。在 PyTorch 中,我们使用 DataLoaders 来训练或测试我们的模型。虽然我们也可以在 PyTorch Lightning 中使用 DataLoaders 来训练模型,但 PyTorch Lightning 还为我们提供了一种更好的方法,称为 DataModules。 DataModule 是一个可重用和可共享的类,它封装了 DataLoader 以及处理数据所需的步骤。创建数据加载器可能会变得混乱,这就是为什么最好以 DataModule 的形式对数据集进行分组的原因。建议您了解如何使用 PyTorch Lightning 定义神经网络。
安装 PyTorch 闪电:
安装 Lightning 与在Python中安装任何其他库相同。
pip install pytorch-lightning
或者,如果您想在 conda 环境中安装它,您可以使用以下命令:-
conda install -c conda-forge pytorch-lightning
Pytorch 闪电数据模块格式
要定义 Lightning DataModule,我们遵循以下格式:-
import pytorch-lightning as pl
from torch.utils.data import random_split, DataLoader
class DataModuleClass(pl.LightningDataModule):
def __init__(self):
#Define required parameters here
def prepare_data(self):
# Define steps that should be done
# on only one GPU, like getting data.
def setup(self, stage=None):
# Define steps that should be done on
# every GPU, like splitting data, applying
# transform etc.
def train_dataloader(self):
# Return DataLoader for Training Data here
def val_dataloader(self):
# Return DataLoader for Validation Data here
def test_dataloader(self):
# Return DataLoader for Testing Data here
注意:上述函数的名称应该完全相同。
了解 DataModule 类
在本文中,我将使用 MNIST 数据作为示例。如我们所见,创建 Lightning DataModule 的第一个要求是继承 pytorch-lightning 中的 LightningDataModule 类:
import pytorch-lightning as pl
from torch.utils.data import random_split, DataLoader
class DataModuleMNIST(pl.LightningDataModule):
__init__() 方法:
它用于存储有关批量大小、转换等的信息。
def __init__(self):
super().__init__()
self.download_dir = ''
self.batch_size = 32
self.transform = transforms.Compose([
transforms.ToTensor()
])
prepare_data() 方法:
此方法用于定义仅由一个 GPU 执行的进程。它通常用于处理下载数据的任务。
def prepare_data(self):
datasets.MNIST(self.download_dir,
train=True, download=True)
datasets.MNIST(self.download_dir, train=False,
download=True)
设置()方法:
此方法用于定义由所有可用 GPU 执行的进程。它通常用于处理加载数据的任务。
def setup(self, stage=None):
data = datasets.MNIST(self.download_dir,
train=True, transform=self.transform)
self.train_data, self.valid_data = random_split(data, [55000, 5000])
self.test_data = datasets.MNIST(self.download_dir,
train=False, transform=self.transform)
train_dataloader() 方法:
此方法用于创建训练数据数据加载器。在这个函数中,你通常只返回训练数据的数据加载器。
def train_dataloader(self):
return DataLoader(self.train_data, batch_size=self.batch_size)
val_dataloader() 方法:
此方法用于创建验证数据数据加载器。在这个函数中,你通常只返回验证数据的数据加载器。
def val_dataloader(self):
return DataLoader(self.valid_data, batch_size=self.batch_size)
test_dataloader() 方法:
此方法用于创建测试数据数据加载器。在这个函数中,你通常只返回测试数据的数据加载器。
def test_dataloader(self):
return DataLoader(self.test_data, batch_size=self.batch_size)
使用 DataModule 训练 Pytorch 闪电模型:
在 Pytorch Lighting 中,我们使用 Trainer() 来训练我们的模型,在这种情况下,我们可以将数据作为 DataLoader 或 DataModule 传递。我们以我在本文中定义的模型为例:
class model(pl.LightningModule):
def __init__(self):
super(model, self).__init__()
self.fc1 = nn.Linear(28*28, 256)
self.fc2 = nn.Linear(256, 128)
self.out = nn.Linear(128, 10)
self.lr = 0.01
self.loss = nn.CrossEntropyLoss()
def forward(self, x):
batch_size, _, _, _ = x.size()
x = x.view(batch_size, -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.out(x)
def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=self.lr)
def training_step(self, train_batch, batch_idx):
x, y = train_batch
logits = self.forward(x)
loss = self.loss(logits, y)
return loss
def validation_step(self, valid_batch, batch_idx):
x, y = valid_batch
logits = self.forward(x)
loss = self.loss(logits, y)
现在要训练这个模型,我们将创建一个 Trainer() 对象并通过将我们的模型和数据模块作为参数传递给它进行 fit() 。
clf = model()
mnist = DataModuleMNIST()
trainer = pl.Trainer(gpus=1)
trainer.fit(clf, mnist)
下面是完整的实现:
Python3
# import module
import torch
# To get the layers and losses for our model
from torch import nn
import pytorch_lightning as pl
# To get the activation function for our model
import torch.nn.functional as F
# To get MNIST data and transforms
from torchvision import datasets, transforms
# To get the optimizer for our model
from torch.optim import SGD
# To get random_split to split training
# data into training and validation data
# and DataLoader to create dataloaders for train,
# valid and test data to be returned
# by our data module
from torch.utils.data import random_split, DataLoader
class model(pl.LightningModule):
def __init__(self):
super(model, self).__init__()
# Defining our model architecture
self.fc1 = nn.Linear(28*28, 256)
self.fc2 = nn.Linear(256, 128)
self.out = nn.Linear(128, 10)
# Defining learning rate
self.lr = 0.01
# Defining loss
self.loss = nn.CrossEntropyLoss()
def forward(self, x):
# Defining the forward pass of the model
batch_size, _, _, _ = x.size()
x = x.view(batch_size, -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.out(x)
def configure_optimizers(self):
# Defining and returning the optimizer for our model
# with the defines parameters
return torch.optim.SGD(self.parameters(), lr = self.lr)
def training_step(self, train_batch, batch_idx):
# Defining training steps for our model
x, y = train_batch
logits = self.forward(x)
loss = self.loss(logits, y)
return loss
def validation_step(self, valid_batch, batch_idx):
# Defining validation steps for our model
x, y = valid_batch
logits = self.forward(x)
loss = self.loss(logits, y)
class DataModuleMNIST(pl.LightningDataModule):
def __init__(self):
super().__init__()
# Directory to store MNIST Data
self.download_dir = ''
# Defining batch size of our data
self.batch_size = 32
# Defining transforms to be applied on the data
self.transform = transforms.Compose([
transforms.ToTensor()
])
def prepare_data(self):
# Downloading our data
datasets.MNIST(self.download_dir,
train = True, download = True)
datasets.MNIST(self.download_dir,
train = False, download = True)
def setup(self, stage=None):
# Loading our data after applying the transforms
data = datasets.MNIST(self.download_dir,
train = True,
transform = self.transform)
self.train_data, self.valid_data = random_split(data,
[55000, 5000])
self.test_data = datasets.MNIST(self.download_dir,
train = False,
transform = self.transform)
def train_dataloader(self):
# Generating train_dataloader
return DataLoader(self.train_data,
batch_size = self.batch_size)
def val_dataloader(self):
# Generating val_dataloader
return DataLoader(self.valid_data,
batch_size = self.batch_size)
def test_dataloader(self):
# Generating test_dataloader
return DataLoader(self.test_data,
batch_size = self.batch_size)
clf = model()
mnist = DataModuleMNIST()
trainer = pl.Trainer()
trainer.fit(clf, mnist)
输出: