📅  最后修改于: 2023-12-03 15:23:16.837000             🧑  作者: Mango
本篇教程将介绍如何使用 PyTorch 实现逻辑回归,并将其应用于手写数字的分类任务。
逻辑回归是一种基础的分类方法,它通过对输入的特征进行加权求和,并经过一个 sigmoid 函数将结果限制在 0 到 1 的范围内,从而预测样本属于某一类别的概率。
我们将使用 MNIST 数据集,该数据集包含大量的手写数字图片及其对应的标签,其中训练集包含 60000 张图片,测试集包含 10000 张图片。
首先我们需要下载数据集:
import torch
import torchvision
train_dataset = torchvision.datasets.MNIST(
root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=64, shuffle=True)
test_dataset = torchvision.datasets.MNIST(
root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset, batch_size=64, shuffle=True)
我们使用 PyTorch 中的 nn.Module
构建逻辑回归模型,其结构包含一个线性层和一个 sigmoid 函数。
class LogisticRegression(torch.nn.Module):
def __init__(self, input_dim, output_dim):
super(LogisticRegression, self).__init__()
self.linear = torch.nn.Linear(input_dim, output_dim)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.linear(x)
x = self.sigmoid(x)
return x
model = LogisticRegression(28*28, 10)
接着,我们定义损失函数和优化器,并使用训练集对模型进行训练。
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
num_epochs = 10
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.reshape(-1, 28*28)
labels = torch.nn.functional.one_hot(labels, 10)
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
epoch+1, num_epochs, i+1, len(train_loader), loss.item()))
最后,我们使用测试集对模型进行评估。
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.reshape(-1, 28*28)
labels = torch.nn.functional.one_hot(labels, 10)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy: {:.2f}%'.format(100 * correct / total))
本篇教程介绍了如何使用 PyTorch 实现逻辑回归,并将其应用于手写数字的分类任务。通过这个例子,您可以熟悉使用 PyTorch 进行机器学习的流程,并了解逻辑回归的基本原理。