📜  PyTorch-术语(1)

📅  最后修改于: 2023-12-03 15:34:33.113000             🧑  作者: Mango

PyTorch 术语介绍

PyTorch 是一个基于 Python 的科学计算库。本文将介绍 PyTorch 中常见的术语。

张量(Tensor)

张量(Tensor)是 PyTorch 中最基本的数据结构,它是一个多维数组。可以将张量看作是一个数字矩阵。在 PyTorch 中,张量是所有计算的基本单位。

import torch

# 创建一个 3x3 的零张量
x = torch.zeros(3, 3)
print(x)

# 创建一个 2x2 的随机张量
y = torch.randn(2, 2)
print(y)
模型(Module)

模型(Module)是 PyTorch 中的一个重要概念,它指的是一个可调用的对象,这个对象可以接受输入数据并返回输出数据。一般来说,模型通常由一堆层(Layer)组成。

import torch.nn as nn

# 创建一个简单的全连接层
layer = nn.Linear(2, 1)
print(layer)

# 对输入数据进行前向计算
x = torch.ones(1, 2)
y = layer(x)
print(y)
优化器(Optimizer)

优化器(Optimizer)是 PyTorch 中实现梯度下降算法的一种方式,它可以自动计算梯度并更新模型参数,从而优化模型。在 PyTorch 中,可以通过创建一个优化器实例并调用它的 step() 方法来更新模型参数。

import torch.optim as optim

# 创建一个优化器
optimizer = optim.SGD(layer.parameters(), lr=0.01)

# 对损失函数进行反向传播并更新模型
loss_fn = nn.MSELoss()
loss = loss_fn(y, torch.ones(1, 1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
数据集(Dataset)

数据集(Dataset)是 PyTorch 中常用的一种数据类型,它可以包含训练数据、验证数据和测试数据等。在 PyTorch 中,可以通过创建一个数据集实例并调用它的 __getitem__() 方法来获取数据。

from torch.utils.data import Dataset

# 创建一个简单的数据集
class MyDataset(Dataset):
    def __getitem__(self, index):
        return torch.randn(2, 2), torch.randn(1)
    
    def __len__(self):
        return 100
    
dataset = MyDataset()
print(len(dataset))

# 从数据集中获取数据
x, y = dataset[0]
print(x, y)
数据加载器(DataLoader)

数据加载器(DataLoader)是 PyTorch 中用于加载数据的一种方式,它可以自动将数据集分成小批量(Batch)并提供多线程数据读取和处理等功能。

from torch.utils.data import DataLoader

# 创建一个数据加载器
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# 从加载器中获取数据
for x, y in dataloader:
    print(x, y)

以上是 PyTorch 中常见的几个术语,它们是深度学习中必不可少的一部分。PyTorch 提供了简单易用的 API,使得开发深度学习模型变得更加容易。