📜  在 PyTorch 中使用逻辑回归识别手写数字(1)

📅  最后修改于: 2023-12-03 15:23:16.837000             🧑  作者: Mango

在 PyTorch 中使用逻辑回归识别手写数字

本篇教程将介绍如何使用 PyTorch 实现逻辑回归,并将其应用于手写数字的分类任务。

什么是逻辑回归?

逻辑回归是一种基础的分类方法,它通过对输入的特征进行加权求和,并经过一个 sigmoid 函数将结果限制在 0 到 1 的范围内,从而预测样本属于某一类别的概率。

准备数据

我们将使用 MNIST 数据集,该数据集包含大量的手写数字图片及其对应的标签,其中训练集包含 60000 张图片,测试集包含 10000 张图片。

首先我们需要下载数据集:

import torch
import torchvision

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=64, shuffle=True)
构建模型

我们使用 PyTorch 中的 nn.Module 构建逻辑回归模型,其结构包含一个线性层和一个 sigmoid 函数。

class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.linear(x)
        x = self.sigmoid(x)
        return x

model = LogisticRegression(28*28, 10)
训练模型

接着,我们定义损失函数和优化器,并使用训练集对模型进行训练。

criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

num_epochs = 10
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, 28*28)
        labels = torch.nn.functional.one_hot(labels, 10)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch+1, num_epochs, i+1, len(train_loader), loss.item()))
测试模型

最后,我们使用测试集对模型进行评估。

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28)
        labels = torch.nn.functional.one_hot(labels, 10)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy: {:.2f}%'.format(100 * correct / total))
总结

本篇教程介绍了如何使用 PyTorch 实现逻辑回归,并将其应用于手写数字的分类任务。通过这个例子,您可以熟悉使用 PyTorch 进行机器学习的流程,并了解逻辑回归的基本原理。