📅  最后修改于: 2023-12-03 15:37:27.711000             🧑  作者: Mango
PyTorch 是一个流行的深度学习框架,可以帮助您轻松地创建和操作张量(tensor)。在本文中,我们将讨论如何在 PyTorch 中创建张量。
张量是一个多维数组(通常是矩阵),用于存储和处理数据。在 PyTorch 中,张量与 NumPy 数组非常相似,可以执行类似的操作,例如索引,切片和数学运算。
PyTorch 张量有四个属性:秩(rank)、形状(shape)、数据类型(dtype)和设备(device)。秩表示张量的维数,形状表示张量的大小,数据类型表示张量中元素的类型,设备表示张量存储在哪个硬件设备上(CPU 或 GPU)。
现在让我们看看如何在 PyTorch 中创建张量。
要在 PyTorch 中创建张量,请使用 torch.tensor()
函数。这个函数的参数可以是列表、元组、NumPy 数组、标量或另一个张量。
以下是一个创建张量的示例:
import torch
# 创建一个长度为 5 的一维张量,分别为 1、2、3、4、5
a = torch.tensor([1, 2, 3, 4, 5])
print(a)
# 创建一个 3x3 的二维张量,每个元素都为 0
b = torch.zeros((3, 3))
print(b)
# 创建一个 2x2 的二维张量,每个元素都为随机数
c = torch.rand((2, 2))
print(c)
# 创建一个 2x3 的二维张量,每个元素都为 1
d = torch.ones((2, 3))
print(d)
# 创建一个 3x3 的单位矩阵
e = torch.eye(3)
print(e)
输出:
tensor([1, 2, 3, 4, 5])
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
tensor([[0.3086, 0.4617],
[0.6900, 0.2407]])
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
有时,您需要在不更改张量数据的情况下更改其形状。在 PyTorch 中,您可以使用 torch.view()
函数来实现这一点。
下面是一个示例:
import torch
# 创建一个 2x3 的二维张量,每个元素都为随机数
a = torch.rand((2, 3))
print(a)
# 将张量形状改为 3x2
b = a.view((3, 2))
print(b)
输出:
tensor([[0.7759, 0.5761, 0.5968],
[0.3343, 0.9667, 0.4979]])
tensor([[0.7759, 0.5761],
[0.5968, 0.3343],
[0.9667, 0.4979]])
PyTorch 中的张量支持各种数学运算,例如加法,减法,乘法和除法。这些运算可以使用 torch.add()
、torch.sub()
、torch.mul()
和 torch.div()
函数来执行。
以下是一个示例:
import torch
# 创建两个长度为 3 的一维张量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 计算两个张量之和
c = torch.add(a, b)
print(c)
# 计算两个张量之差
d = torch.sub(a, b)
print(d)
# 计算两个张量之积
e = torch.mul(a, b)
print(e)
# 计算两个张量之商
f = torch.div(a, b)
print(f)
输出:
tensor([5, 7, 9])
tensor([-3, -3, -3])
tensor([ 4, 10, 18])
tensor([0.2500, 0.4000, 0.5000])
在深度学习中,通常使用 GPU 来加速计算。在 PyTorch 中,在 GPU 上创建张量非常容易。只需要将张量分配给 GPU 设备即可。
以下是一个示例:
import torch
# 创建一个 3x3 的二维张量,存储在 GPU 上
a = torch.ones((3, 3), device="cuda")
print(a)
输出:
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], device='cuda:0')
在本文中,我们讨论了如何在 PyTorch 中创建张量、改变张量形状、进行数学运算以及在 GPU 上创建张量。希望这篇文章能够帮助您更好地了解如何使用 PyTorch 中的张量。