📜  火炬设备 - Python (1)

📅  最后修改于: 2023-12-03 14:56:11.127000             🧑  作者: Mango

火炬设备 - Python

火炬设备是用于PyTorch深度学习框架的工具包,可以帮助开发者更加方便地进行模型训练和调试。它具有如下特点:

  • 灵活:支持各种深度学习模型和常用数据集;
  • 易用:API简单易懂,快速上手;
  • 高效:能够最大程度地利用硬件设备,加速模型训练;
  • 可视化:提供直观的训练曲线和模型性能指标。
安装

使用pip命令进行安装:

pip install torch torchvision torchaudio
pip install pytorch-ignite
示例

下面的示例教程展示了火炬设备如何用于训练MNIST数据集的分类模型。

# 导入所需的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# 定义模型结构
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output

# 加载数据集
mnist_train = MNIST('./data', train=True, download=True,
                    transform=ToTensor())
mnist_test = MNIST('./data', train=False, download=True,
                   transform=ToTensor())
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=64, shuffle=True)

# 定义模型、损失函数和优化器
model = Classifier()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 定义训练函数和评估函数
def training_function(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = batch
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

def evaluation_function(engine, batch):
    model.eval()
    with torch.no_grad():
        x, y = batch
        y_pred = model(x)
        return y_pred, y

# 创建trainer和evaluator
trainer = create_supervised_trainer(model, optimizer, criterion, device='cuda')
evaluator = create_supervised_evaluator(model,
                                        metrics={'accuracy': Accuracy(),
                                                 'loss': Loss(criterion)},
                                        device='cuda')

# 每隔一定epochs输出训练结果
@trainer.on(Events.EPOCH_COMPLETED(every=5))
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print(f"Training - Epoch: {trainer.state.epoch} " +
          f"Avg accuracy: {metrics['accuracy']:.2f} " +
          f"Avg loss: {metrics['loss']:.2f}")

# 在验证集上进行评估,输出结果
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(test_loader)
    metrics = evaluator.state.metrics
    print(f"Validation - Epoch: {trainer.state.epoch} " +
          f"Avg accuracy: {metrics['accuracy']:.2f} " +
          f"Avg loss: {metrics['loss']:.2f}")

# 触发训练过程
trainer.run(train_loader, max_epochs=20)
总结

火炬设备是一个十分优秀的PyTorch工具包,它能够帮助开发者更加快速、高效地进行模型训练和调试。如果你还没有尝试过这个工具包,不妨下载安装后亲自试一下!