📅  最后修改于: 2023-12-03 15:08:46.098000             🧑  作者: Mango
在 PyTorch 中,我们可以使用 torch.sort
方法对张量元素进行排序。该方法可以对张量进行升序或降序排序,并返回排序后的结果和对应的索引。
sorted_tensor, indices = torch.sort(input, dim=None, descending=False, stable=False)
参数说明:
input
:需要进行排序的张量。dim
:指定进行排序的维度,如果为 None
,则默认对整个张量排序。descending
:是否采用降序排序,默认为 False
。stable
:当排序中有相等的元素时,是否保持它们在原序列中的相对位置关系不变,默认为 False
。import torch
# 生成一个大小为5x3的张量
x = torch.randn(5, 3)
print(f"x:\n{x}\n")
# 对整个张量进行升序排序
sorted_tensor, indices = torch.sort(x)
print(f"sorted_tensor:\n{sorted_tensor}\n")
print(f"indices:\n{indices}\n")
# 对第0维度进行降序排序
sorted_tensor, indices = torch.sort(x, dim=0, descending=True)
print(f"sorted_tensor:\n{sorted_tensor}\n")
print(f"indices:\n{indices}\n")
输出结果:
x:
tensor([[ 0.6488, -0.6001, -1.7620],
[-2.8995, 0.2005, -0.6887],
[ 0.4121, -0.2223, -0.1305],
[ 1.1422, 0.2251, -0.6016],
[ 1.1444, -0.7302, 0.7982]])
sorted_tensor:
tensor([[-1.7620, -0.6001, 0.6488],
[-2.8995, -0.6887, 0.2005],
[-0.2223, -0.1305, 0.4121],
[-0.6016, 0.2251, 1.1422],
[-0.7302, 0.7982, 1.1444]])
indices:
tensor([[2, 1, 0],
[0, 2, 1],
[1, 2, 0],
[2, 1, 0],
[1, 2, 0]])
sorted_tensor:
tensor([[ 1.1444, 0.2251, 0.7982],
[ 1.1422, 0.2005, -0.1305],
[ 0.6488, -0.2223, -0.6016],
[ 0.4121, -0.6001, -0.6887],
[-2.8995, -0.7302, -1.7620]])
indices:
tensor([[4, 3, 4],
[3, 1, 2],
[0, 2, 3],
[2, 0, 1],
[1, 4, 0]])
从输出结果中可以看出,升序排序之后,结果以升序排列,且每一列的元素在原序列中的相对位置都不变;降序排序之后,结果以降序排列,且每一行的元素在原序列中的相对位置都不变。
以上就是在 PyTorch 中对张量元素进行排序的方法,希望对大家有所帮助。