📅  最后修改于: 2023-12-03 15:04:42.888000             🧑  作者: Mango
在Pytorch中,可以使用索引对Tensor进行设置或取出部分数据。索引操作方法类似于NumPy中的数组索引,但有一些Pytorch独有的细节需要注意。
首先我们创建一个5x5的Tensor:
import torch
x = torch.arange(25).reshape(5, 5)
print(x)
# >>> tensor([[ 0, 1, 2, 3, 4],
# [ 5, 6, 7, 8, 9],
# [10, 11, 12, 13, 14],
# [15, 16, 17, 18, 19],
# [20, 21, 22, 23, 24]])
其中第i行第j列的元素可以通过x[i, j]进行访问,例如:
print(x[0, 1]) # >>> 1
除了使用单个索引,也可以使用范围索引:
print(x[:, 1:3]) # >>> tensor([[ 1, 2],
# [ 6, 7],
# [11, 12],
# [16, 17],
# [21, 22]])
除此之外,Pytorch还支持使用掩码数组进行索引,例如:
mask = x > 10
print(mask)
# >>> tensor([[False, False, False, False, False],
# [False, False, False, False, False],
# [False, True, True, True, True],
# [ True, True, True, True, True],
# [ True, True, True, True, True]])
print(x[mask]) # >>> tensor([11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
在Pytorch中,采取类似于NumPy的方式对数据类型进行索引。例如:
x = torch.tensor([1, 2, 3], dtype=torch.float)
print(x)
# >>> tensor([1., 2., 3.])
print(x[0].item())
# >>> 1.0
可以使用LongTensor类型的索引数组对tensor进行索引。例如:
x = torch.randn(3, 4)
print(x)
# >>> tensor([[-0.5474, -0.7157, -0.0070, -1.6918],
# [ 2.3551, -0.5569, 0.9610, 0.3642],
# [ 0.1221, -0.6849, -0.4440, -0.3760]])
indices = torch.tensor([0, 2])
print(x[indices])
# >>> tensor([[-0.5474, -0.7157, -0.0070, -1.6918],
# [ 0.1221, -0.6849, -0.4440, -0.3760]])
scatter_操作是一个重要、灵活的操作,该操作能够根据索引将一些值写入Tensor,例如:
x = torch.zeros(2, 4)
indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
values = torch.tensor([1.0, 2.0, 3.0])
x.scatter_(1, indices, values)
print(x)
# >>> tensor([[1., 3., 0., 0.],
# [2., 1., 3., 0.]])
上述代码中,我们将值为[1.0, 2.0, 3.0]分别写入x的第0行第1列、第1行第0列和第1行第1列。
gather操作是另一个非常有用的操作,它与scatter_操作相对应,可以根据索引提取Tensor中的一些值,例如:
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
indices = torch.tensor([1, 0])
y = torch.gather(x, 1, indices.unsqueeze(0).t())
print(y)
# >>> tensor([[2.],
# [3.]])
上述代码中,我们将x的第0列和第1列互换,并提取出第0行和第1行,最终得到一个列向量。