📜  PyTorch 入门(1)

📅  最后修改于: 2023-12-03 15:04:42.761000             🧑  作者: Mango

PyTorch 入门

PyTorch 是一个开放源代码的机器学习框架,主要由Facebook人工智能研究团队开发和维护。它可以帮助程序员们更加方便快捷地使用 Python 进行机器学习和深度学习的项目开发。

安装

在使用 PyTorch 前,你需要在你的电脑上安装 PyTorch 包,建议使用 Anaconda 安装。(注:以下指令在 Linux 系统上通过测试)

使用以下命令安装 PyTorch:

conda install pytorch torchvision -c pytorch
快速上手

PyTorch 的主要数据结构是张量,与 NumPy 的多维数组类似,你可以使用它进行向量、矩阵等数组运算。

下面是一个最简单的例子,通过 PyTorch 实现了矩阵加法和乘法:

import torch

# 矩阵加法
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = x + y  # 或者 z = torch.add(x, y)
print(z)

# 矩阵乘法
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[5, 6], [7, 8]])
z = x.mm(y)  # 或者 z = torch.mm(x, y)
print(z)

输出结果如下:

tensor([5, 7, 9])
tensor([[19, 22],
        [43, 50]])
神经网络

PyTorch 也可以用来构建神经网络。

以下是一个简单的三层全连接神经网络实现,其中输入层有 784 个神经元,隐藏层 1 有 256 个神经元,隐藏层 2 有 128 个神经元,输出层有 10 个神经元。输入层到隐藏层 1,隐藏层 1 到隐藏层 2,隐藏层 2 到输出层均为全连接层。

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Full Connection Layers
        self.fc1 = nn.Linear(784, 256) 
        self.fc2 = nn.Linear(256, 128) 
        self.fc3 = nn.Linear(128, 10) 

    def forward(self, x):
        # Flatten the input image to a vector
        x = x.view(-1, 784)

        x = nn.ReLU()(self.fc1(x))
        x = nn.ReLU()(self.fc2(x))
        x = nn.LogSoftmax(dim=1)(self.fc3(x))
        
        return x

# 实例化神经网络
net = Net()

# 定义损失函数
loss_function = nn.NLLLoss()

# 定义优化器
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# 训练神经网络模型
for epoch in range(5):
    for data in train_loader:
        x, y = data

        optimizer.zero_grad()

        output = net(x)
        loss = loss_function(output, y)
        loss.backward()
        optimizer.step()

# 测试神经网络模型
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        x, y = data
        output = net(x)
        for idx, i in enumerate(output):
            if torch.argmax(i) == y[idx]:
                correct += 1
            total += 1

print("Accuracy: ", round(correct/total, 3))

以上代码演示了如何使用 PyTorch 实现一个简单的全连接神经网络,并在 MNIST 手写数字识别数据集上进行了测试。在这个例子中,我们使用 NLLLoss 作为损失函数,并使用随机梯度下降算法进行优化。最后输出了测试正确率。

结语

通过本文的介绍,希望你对 PyTorch 所提供的强大功能有了初步了解,有了勇气去尝试构建一个自己的神经网络模型。此外,PyTorch 的官方文档提供了更加详细的 API 说明与示例代码,以及许多使用 PyTorch 的优秀项目的源代码。