📅  最后修改于: 2023-12-03 15:36:02.713000             🧑  作者: Mango
PyTorch 闪电数据模块 (PyTorch Lightning) 是 PyTorch 的高级封装,旨在让用户更快、更简单地训练模型。它可以自动处理许多冗长的训练和验证步骤,使用户可以将更多的精力放到模型的设计和调整上。
PyTorch 闪电数据模块的特性包括:
以下是一个简单的 PyTorch 闪电数据模块示例,展示如何使用它来训练一个简单的 MNIST 分类器:
import torch
from torch import nn, optim
import pytorch_lightning as pl
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
class MNISTClassifier(pl.LightningModule):
def __init__(self):
super(MNISTClassifier, self).__init__()
self.fc1 = nn.Linear(28 * 28, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = nn.functional.softmax(self.fc3(x), dim=1)
return x
def training_step(self, batch, batch_idx):
x, y = batch
out = self(x)
criterion = nn.CrossEntropyLoss()
loss = criterion(out, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
out = self(x)
criterion = nn.CrossEntropyLoss()
loss = criterion(out, y)
accuracy = pl.metrics.functional.accuracy(out, y)
self.log('val_loss', loss, on_step=False, on_epoch=True)
self.log('val_accuracy', accuracy, on_step=False, on_epoch=True)
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
return optimizer
train_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
test_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = MNIST('data', train=True, download=True, transform=train_transforms)
test_dataset = MNIST('data', train=False, download=True, transform=test_transforms)
train_loader = DataLoader(train_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)
model = MNISTClassifier()
trainer = pl.Trainer(gpus=1, max_epochs=10)
trainer.fit(model, train_loader, test_loader)
在该示例中,我们创建了一个 MNIST 分类器模型,并使用 PyTorch 闪电数据模块的 LightningModule
类作为基类。然后我们实现了训练循环和验证循环的方法 training_step
和 validation_step
,并使用 configure_optimizers
方法定义了优化器。
我们还使用 PyTorch 的数据加载工具 DataLoader
加载 MNIST 数据集,然后使用 Trainer
类来训练模型。Trainer
类在训练时将处理许多繁琐的任务,例如将数据移动到正确的设备上,自动调整学习率等。