How to Sort the Elements of a Tensor in PyTorch?
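In this article, we will see how to sort the elements of a PyTorch tensor in Python.
To sort the elements of a PyTorch tensor, we use the torch.sort() method. When the tensor is 2-D, we can sort the elements along the columns (dim=0) or along the rows (dim=1).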
Syntax: torch.sort(input, dim=-1, descending=False)
Parameters:
- input: The input PyTorch tensor.
- dim: The dimension along which to sort. It is an optional int value that defaults to -1, the last dimension.
- descending: An optional boolean value that controls the sort order. It defaults to False, which sorts in ascending order; pass True to sort in descending order.
Returns: A named tuple (values, indices), where values contains the sorted elements and indices contains the positions of those elements in the original input tensor along dim.
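Because the result is a named tuple, its two fields can be read by name as well as unpacked. A minimal sketch, assuming a small made-up tensor `x` that is not part of the examples in this article:

```python
import torch

# Illustrative input with distinct values
x = torch.tensor([3.0, 1.0, 2.0])

# Read the fields of the named tuple by name...
result = torch.sort(x)
print(result.values)   # tensor([1., 2., 3.])
print(result.indices)  # tensor([1, 2, 0])

# ...or unpack it positionally like an ordinary tuple
values, indices = result
assert torch.equal(values, result.values)

# The indices point back into the original tensor, so indexing
# the input with them reproduces the sorted values
assert torch.equal(x[indices], values)
```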
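Example 1:
In the following example, we sort the elements of a 1-D tensor in ascending and descending order. We apply the torch.sort() method to the input tensor; to sort in descending order, pass descending=True to the method.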
```python
# importing the required library
import torch

# defining a 1-D PyTorch tensor
tensor = torch.tensor([-12, -23, 0.0, 32,
                       1.32, 201, 5.02])
print("Tensor:\n", tensor)

# sorting the tensor in ascending order
print("Sorting tensor in ascending order:")
values, indices = torch.sort(tensor)
# printing the sorted values
print("Sorted values:\n", values)
# printing the position of each sorted value in the original tensor
print("Indices:\n", indices)

# sorting the tensor in descending order
print("Sorting tensor in descending order:")
values, indices = torch.sort(tensor, descending=True)
# printing the sorted values
print("Sorted values:\n", values)
# printing the position of each sorted value in the original tensor
print("Indices:\n", indices)
```
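Since every element in this tensor is distinct, the descending result should simply be the ascending result reversed, and indexing the original tensor with the returned indices should reproduce the sorted values. A short sketch that checks both on the Example 1 data (torch.flip is used here only for the comparison; with repeated values the order of tied indices is not guaranteed, so the flip check would not be reliable in general):

```python
import torch

tensor = torch.tensor([-12, -23, 0.0, 32, 1.32, 201, 5.02])

asc_values, asc_indices = torch.sort(tensor)
desc_values, desc_indices = torch.sort(tensor, descending=True)

# Indexing the original tensor with the indices gives the sorted values
assert torch.equal(tensor[asc_indices], asc_values)

# With no repeated elements, descending order is ascending order reversed
assert torch.equal(desc_values, torch.flip(asc_values, dims=[0]))
assert torch.equal(desc_indices, torch.flip(asc_indices, dims=[0]))

print(asc_indices.tolist())   # [1, 0, 2, 4, 6, 3, 5]
print(desc_indices.tolist())  # [5, 3, 6, 4, 2, 0, 1]
```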
Example 2:
In this example, we sort the elements of a 2-D tensor along the columns (dim=0), in both ascending and descending order.
```python
# importing the library
import torch

# defining a 2-D PyTorch tensor
tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])
print("Tensor:\n", tensor)

# sorting each column in ascending order
print("Sorting tensor in ascending order along the columns:")
values, indices = torch.sort(tensor, dim=0)
# printing the sorted values
print("Sorted values:\n", values)
# printing the original row index of each sorted value
print("Indices:\n", indices)

# sorting each column in descending order
print("Sorting tensor in descending order along the columns:")
values, indices = torch.sort(tensor, dim=0, descending=True)
# printing the sorted values
print("Sorted values:\n", values)
# printing the original row index of each sorted value
print("Indices:\n", indices)
```
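With dim=0, each column of the returned indices lists the original row positions of that column's elements. One way to see this, sketched here with torch.gather (which the examples above do not use), is to gather the original tensor along dim 0 with those indices and confirm it rebuilds the sorted values:

```python
import torch

tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])

values, indices = torch.sort(tensor, dim=0)

# Each column holds the original row positions of that column's elements
print(indices.tolist())  # [[2, 1, 0], [1, 2, 2], [0, 0, 1]]

# Gathering along dim 0 with these row positions rebuilds the sorted values
assert torch.equal(torch.gather(tensor, 0, indices), values)

# Every column is now non-decreasing from top to bottom
assert bool((values[1:] >= values[:-1]).all())
```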
Example 3:
In this example, we sort the elements of a 2-D tensor along the rows (dim=1), in both ascending and descending order.
```python
# importing the library
import torch

# defining a 2-D PyTorch tensor
tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])
print("Tensor:\n", tensor)

# sorting each row in ascending order
print("Sorting tensor in ascending order along the rows:")
values, indices = torch.sort(tensor, dim=1)
# printing the sorted values
print("Sorted values:\n", values)
# printing the original column index of each sorted value
print("Indices:\n", indices)

# sorting each row in descending order
print("Sorting tensor in descending order along the rows:")
values, indices = torch.sort(tensor, dim=1, descending=True)
# printing the sorted values
print("Sorted values:\n", values)
# printing the original column index of each sorted value
print("Indices:\n", indices)
```
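Because dim defaults to -1 (the last dimension), calling torch.sort() on a 2-D tensor without dim should give the same row-wise result as dim=1. The same operation is also exposed as the Tensor.sort() method. A small sketch checking both on the Example 3 data:

```python
import torch

tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])

by_row = torch.sort(tensor, dim=1)

# Default dim=-1 is the last dimension, i.e. dim=1 for a 2-D tensor
default = torch.sort(tensor)
assert torch.equal(default.values, by_row.values)
assert torch.equal(default.indices, by_row.indices)

# The same operation called as a Tensor method
method = tensor.sort(dim=1)
assert torch.equal(method.values, by_row.values)

print(by_row.indices.tolist())  # [[2, 1, 0], [1, 0, 2], [2, 0, 1]]
```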