📅  最后修改于: 2023-12-03 15:36:33.709000             🧑  作者: Mango
PyTorch Lightning 是一个轻量级的高级框架,它可以帮助您更快,更简单地构建 PyTorch 模型,而无需担心训练循环和设备设置。在本文中,我们将使用 PyTorch Lightning 训练一个简单的神经网络。
安装 PyTorch Lightning 及其依赖项的最简单方法是使用 pip。
pip install pytorch-lightning
我们将构建一个简单的神经网络,它可以识别手写数字。我们使用 MNIST 数据集来训练我们的模型。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(13*13*64, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = nn.functional.relu(self.conv1(x))
x = nn.functional.max_pool2d(x, 2)
x = nn.functional.relu(self.conv2(x))
x = nn.functional.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return nn.functional.log_softmax(x)
我们使用两个卷积层和两个全连接层构建了一个简单的卷积神经网络。使用 flatten
层的替代方法是使用 view
直接改变张量的维度.
我们将使用 PyTorch Lightning 来训练和验证我们的模型。PyTorch Lightning 提供了一个 LightningModule
类,它将我们的神经网络转换为可在 PyTorch Lightning 中使用的类。
import pytorch_lightning as pl
from argparse import Namespace
class LightningNet(pl.LightningModule):
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self.model = Net()
def forward(self, x):
return self.model(x)
def train_dataloader(self):
return DataLoader(
datasets.MNIST('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=self.hparams.batch_size, shuffle=True)
def val_dataloader(self):
return DataLoader(
datasets.MNIST('data', train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=self.hparams.batch_size)
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=self.hparams.lr)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.forward(x)
loss = nn.functional.nll_loss(y_hat, y)
logs = {'train_loss': loss}
return {'loss': loss, 'log': logs}
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': nn.functional.nll_loss(y_hat, y)}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
logs = {'val_loss': avg_loss}
return {'val_loss': avg_loss, 'log': logs, 'progress_bar': logs}
hparams = Namespace(lr=1e-3, batch_size=64, num_workers=4)
model = LightningNet(hparams)
trainer = pl.Trainer(gpus=1, max_epochs=5)
trainer.fit(model)
在这个示例中,我们通过继承 LightningModule
类来创建了我们的模型。我们定义了两个数据加载器,一个用于训练,一个用于验证。我们还定义了一个优化器和损失函数。最后,我们实现了训练和验证步骤以及验证指标。
我们创建了一个 LightningNet
对象并传入我们的超参数。我们使用 Trainer
类来训练我们的模型。我们指定了使用一个 GPU 和最多进行 5 个训练周期。Trainer
类负责处理训练循环和设备设置。
PyTorch Lightning 简化了训练和验证神经网络的过程。通过继承 LightningModule
类,我们可以更快,更简单地构建模型,而不必担心训练循环和设备设置。