📜  了解 PyTorch 闪电数据模块(1)

📅  最后修改于: 2023-12-03 15:36:02.713000             🧑  作者: Mango

了解 PyTorch 闪电数据模块

简介

PyTorch 闪电数据模块 (PyTorch Lightning) 是 PyTorch 的高级封装,旨在让用户更快、更简单地训练模型。它可以自动处理许多冗长的训练和验证步骤,使用户可以将更多的精力放到模型的设计和调整上。

特性

PyTorch 闪电数据模块的特性包括:

  • 自动批量化和设备移动
  • 准确、可靠且可重复的训练和验证循环
  • 高效的分布式训练和混合精度训练支持
  • 自动调整学习率以降低手动调整学习率的复杂性
  • 通过使用各种插件来扩展框架
示例

以下是一个简单的 PyTorch 闪电数据模块示例,展示如何使用它来训练一个简单的 MNIST 分类器:

import torch
from torch import nn, optim
import pytorch_lightning as pl
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms


class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = nn.functional.softmax(self.fc3(x), dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(out, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(out, y)
        accuracy = pl.metrics.functional.accuracy(out, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_accuracy', accuracy, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = MNIST('data', train=True, download=True, transform=train_transforms)
test_dataset = MNIST('data', train=False, download=True, transform=test_transforms)

train_loader = DataLoader(train_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

model = MNISTClassifier()
trainer = pl.Trainer(gpus=1, max_epochs=10)
trainer.fit(model, train_loader, test_loader)

在该示例中,我们创建了一个 MNIST 分类器模型,并使用 PyTorch 闪电数据模块的 LightningModule 类作为基类。然后我们实现了训练循环和验证循环的方法 training_stepvalidation_step,并使用 configure_optimizers 方法定义了优化器。

我们还使用 PyTorch 的数据加载工具 DataLoader 加载 MNIST 数据集,然后使用 Trainer 类来训练模型。Trainer 类在训练时将处理许多繁琐的任务,例如将数据移动到正确的设备上,自动调整学习率等。