📜  使用 PyTorch 对 MNIST 进行逻辑回归(1)

📅  最后修改于: 2023-12-03 15:22:16.087000             🧑  作者: Mango

使用 PyTorch 对 MNIST 进行逻辑回归

本教程将介绍如何使用 PyTorch 对 MNIST 数据集进行逻辑回归。我们将使用 PyTorch 提供的自动求导以及优化器来训练模型,并且在最后将展示模型的准确率。

数据集

MNIST 数据集包含了一系列手写数字的灰度图片,每张图片大小为 $28 \times 28$ 像素。我们的目标是通过模型预测出每张图片所表示的数字。PyTorch 提供了方便的数据加载器并且可以自动进行数据增强,例如随机旋转、翻转和裁剪等操作。以下代码将加载 MNIST 数据集:

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

train_dataset = datasets.MNIST(root='./data', train=True,
                               transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False,
                              transform=transforms.ToTensor(), download=True)
模型定义

我们使用一个简单的逻辑回归模型来进行分类,它包含了一个线性层和一个 softmax 函数。以下是该模型的定义:

class LogisticRegression(torch.nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_size, num_classes)

    def forward(self, x):
        x = x.view(-1, self.linear.in_features)
        x = self.linear(x)
        return torch.nn.functional.softmax(x, dim=1)

我们的模型输入大小为 $28 \times 28 = 784$,输出大小为 $10$,表示每个数字的概率。

模型训练

我们将使用交叉熵作为损失函数,使用随机梯度下降进行优化。以下代码将训练我们的模型:

model = LogisticRegression(784, 10)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10):
    for i, (images, labels) in enumerate(train_dataset):
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, 10, i + 1, len(train_dataset) // 100, loss.item()))
模型评估

我们在测试集上评估模型的准确率。以下代码将展示如何计算模型的准确率:

model.eval()

correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_dataset:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += 1
        correct += (predicted == labels).sum().item()

print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

在我的本地环境上,训练 10 个 epoch 后模型在测试集上的准确率为 87.64%。

总结

本教程介绍了如何使用 PyTorch 对 MNIST 数据集进行逻辑回归。我们展示了如何加载数据集、定义模型、训练模型、评估模型。使用 PyTorch 可以使我们方便地进行模型训练并且得到不错的结果。