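How to sort the elements of a tensor in PyTorch?
In this article, we will see how to sort the elements of a PyTorch tensor in Python.
To sort the elements of a PyTorch tensor, we use the torch.sort() method. For a 2D tensor, we can sort the elements along the columns (dim=0, each column is sorted independently) or along the rows (dim=1, each row is sorted independently).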
Syntax: torch.sort(input, dim=-1, descending=False)
- input: It is an input PyTorch tensor.
- dim: The dimension along which the tensor is sorted. It is an optional int value; the default, -1, means the last dimension.
- descending: An optional boolean value that selects ascending or descending order. Defaults to False, which sorts in ascending order.
Returns: It returns a named tuple (values, indices), where values holds the sorted elements and indices holds the position of each of those elements along dim in the original input tensor.
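A minimal sketch of the return value, using a small made-up tensor `t` (not part of the original examples): the result can be read through its `values` and `indices` fields or unpacked like an ordinary tuple, and for a 1D tensor, indexing the input with `indices` reproduces `values`.

```python
import torch

# A small illustrative tensor, made up for this sketch
t = torch.tensor([3.0, 1.0, 2.0])

# torch.sort returns a named tuple, so both fields are accessible by name
result = torch.sort(t)
print(result.values)   # tensor([1., 2., 3.])
print(result.indices)  # tensor([1, 2, 0])

# It also unpacks like a plain (values, indices) tuple
values, indices = result

# For a 1D tensor, indexing the input with the indices gives the sorted values
print(torch.equal(t[indices], values))  # True
```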
Example 1: Sorting a 1D tensor
# importing required library
import torch
# defining a PyTorch Tensor
tensor = torch.tensor([-12, -23, 0.0, 32,
1.32, 201, 5.02])
print("Tensor:\n", tensor)
# sorting the tensor in ascending order
print("Sorting tensor in ascending order:")
values, indices = torch.sort(tensor)
# printing values of sorted tensor
print("Sorted values:\n", values)
# printing indices of sorted value
print("Indices:\n", indices)
# sorting the tensor in descending order
print("Sorting tensor in descending order:")
values, indices = torch.sort(tensor, descending=True)
# printing values of sorted tensor
print("Sorted values:\n", values)
# printing indices of sorted value
print("Indices:\n", indices)
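Because `indices` records where each sorted element came from, it can also reorder a second tensor that is aligned element-by-element with the first. The sketch below reuses the data from Example 1; `item_ids` is a hypothetical tensor invented for illustration, holding one ID per element.

```python
import torch

# Same data as Example 1
scores = torch.tensor([-12, -23, 0.0, 32, 1.32, 201, 5.02])
# Hypothetical IDs, one per score (made up for this sketch)
item_ids = torch.arange(100, 107)

values, indices = torch.sort(scores, descending=True)

# Indexing with the returned indices applies the same permutation to item_ids
ranked_ids = item_ids[indices]
print(indices)     # tensor([5, 3, 6, 4, 2, 0, 1])
print(ranked_ids)  # tensor([105, 103, 106, 104, 102, 100, 101])
```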
Example 2: Sorting a 2D tensor along the columns (dim=0)
# importing the library
import torch
# define a 2D torch tensor
tensor = torch.tensor([[43, 31, -92],
[3, -4.3, 53],
[-4.2, 7, -6.2]])
print("Tensor:\n", tensor)
# sorting the tensor in ascending order
print("Sorting tensor in \
ascending order along the column:")
values, indices = torch.sort(tensor, dim = 0)
# printing values in sorted tensor
print("Sorted values:\n", values)
# print indices of values in sorted tensor
print("Indices:\n", indices)
# sorting the tensor in descending order
print("Sorting tensor in \
descending order along the column:")
values, indices = torch.sort(tensor, dim=0, descending=True)
# printing values in sorted tensor
print("Sorted values:\n", values)
# print indices of values in sorted tensor
print("Indices:\n", indices)
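One way to see what the indices from a dim=0 sort mean: `indices[i, j]` is the original row of the element that ended up at `values[i, j]`, so `torch.gather` along dimension 0 rebuilds the sorted tensor. A minimal sketch using the tensor from this example:

```python
import torch

tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])

# Sort each column independently (ascending)
values, indices = torch.sort(tensor, dim=0)

# Each column of indices lists the original row positions in sorted order
print(indices)  # tensor([[2, 1, 0], [1, 2, 2], [0, 0, 1]])

# Gathering along dim 0 with those indices reproduces the sorted values
rebuilt = torch.gather(tensor, 0, indices)
print(torch.equal(rebuilt, values))  # True
```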
Example 3: Sorting a 2D tensor along the rows (dim=1)
# importing the library
import torch
# define a 2D torch tensor
tensor = torch.tensor([[43, 31, -92],
[3, -4.3, 53],
[-4.2, 7, -6.2]])
print("Tensor:\n", tensor)
# sorting the tensor in ascending order
print("Sorting tensor in \
ascending order along the row:")
values, indices = torch.sort(tensor, dim=1)
print("Sorted values:\n", values)
# print indices of values in sorted tensor
print("Indices:\n", indices)
# sorting the tensor in descending order
print("Sorting tensor in \
descending order along the row:")
values, indices = torch.sort(tensor, dim=1, descending=True)
# printing values in sorted tensor
print("Sorted values:\n", values)
# printing indices of values in sorted tensor
print("Indices:\n", indices)
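Since dim defaults to -1 (the last dimension), calling torch.sort() on a 2D tensor without a dim argument sorts within each row, exactly like dim=1. A minimal sketch comparing the two calls on the tensor from this example:

```python
import torch

tensor = torch.tensor([[43, 31, -92],
                       [3, -4.3, 53],
                       [-4.2, 7, -6.2]])

# No dim argument: torch.sort uses dim=-1, the last dimension
default_vals, default_idx = torch.sort(tensor)
# Explicit dim=1: sort within each row
row_vals, row_idx = torch.sort(tensor, dim=1)

print(torch.equal(default_vals, row_vals))  # True
print(torch.equal(default_idx, row_idx))    # True
```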