如何在 PyTorch 中找到张量的第 k 个和前“k”个元素?
在本文中,我们将看到如何找到张量的第 k 个和前“k”个元素。
所以我们可以使用 torch.kthvalue() 找到张量的第 k 个元素,我们可以使用 torch.topk() 方法找到张量的前“k”个元素。
- torch.kthvalue()函数:首先该函数对张量进行升序排序,然后返回排序张量的第 k 个元素以及原始张量中第 k 个元素的索引。
Syntax: torch.kthvalue(input_tensor, k, dim=None, keepdim=False, out=None)
Parameters:
- Input_tensor: tensor.
- k: k is integer and it’s for k-th smallest element of tensor.
- dim: dim is for dimension to find the k-th value along of tensor.
- keepdim (bool): keepdim is for whether the output tensor has dim retained or not.
Return: This method returns a tuple (values, indices) of the k-th element of tensor.
- torch.topk()函数:这个函数帮助我们找到给定张量的前“k”个元素。它将返回张量的前“k”个元素,它还将返回原始张量中前“k”个元素的索引。
Syntax: torch.topk(input_tensor, k, dim=None, largest=True, sorted=True, out=None)
Parameters:
- input_tensor: tensor.
- k: k is integer value and it’s for the k in top-k.
- dim: the dim is for the dimension to sort along of tensor.
- largest: this is used to controls whether return largest or smallest elements of tensor.
- sorted: it controls whether to return the elements in sorted order.
Return: this function is returns the ‘k’ largest elements of tensor along a given dimension.
示例 1:以下程序是查找张量的第 k 个元素。
Python3
# import torch library
import torch
# define a tensor
tens = torch.Tensor([4, 5, -3, 9, 7])
print("Original Tensor:\n", tens)
# find 3 largest element from the tensor
value, index = torch.kthvalue(tens, 3)
# print value along with index
print("\nIndex:", index, "Value:", value)
Python3
# import torch library
import torch
# define tensor
tens = torch.Tensor([5.344, 8.343, -2.398, -0.995, 5, 30.421])
print("Original tensor: ", tens)
# find top 2 elements
values, indexes = torch.topk(tens, 2)
# print top 2 elements
print("Top 2 element values:", values)
# print index of top 2 elements
print("Top 2 element indices:", indexes)
输出:
例2:下面的程序是找张量的前k个元素
Python3
# import torch library
import torch
# define tensor
tens = torch.Tensor([5.344, 8.343, -2.398, -0.995, 5, 30.421])
print("Original tensor: ", tens)
# find top 2 elements
values, indexes = torch.topk(tens, 2)
# print top 2 elements
print("Top 2 element values:", values)
# print index of top 2 elements
print("Top 2 element indices:", indexes)
输出: