📌  相关文章
📜  如何在 PyTorch 中找到张量的第 k 个和前“k”个元素?

📅  最后修改于: 2022-05-13 01:55:20.955000             🧑  作者: Mango

如何在 PyTorch 中找到张量的第 k 个和前“k”个元素?

在本文中,我们将看到如何找到张量的第 k 个和前“k”个元素。

所以我们可以使用 torch.kthvalue() 找到张量的第 k 个元素,我们可以使用 torch.topk() 方法找到张量的前“k”个元素。

  • torch.kthvalue()函数:首先该函数对张量进行升序排序,然后返回排序张量的第 k 个元素以及原始张量中第 k 个元素的索引。
  • torch.topk()函数:这个函数帮助我们找到给定张量的前“k”个元素。它将返回张量的前“k”个元素,它还将返回原始张量中前“k”个元素的索引。

示例 1:以下程序是查找张量的第 k 个元素。

Python3
# import torch library
import torch
  
# define a tensor
tens = torch.Tensor([4, 5, -3, 9, 7])
print("Original Tensor:\n", tens)
  
# find 3 largest element from the tensor
value, index = torch.kthvalue(tens, 3)
  
# print value along with index
print("\nIndex:", index, "Value:", value)


Python3
# import torch library
import torch
  
# define tensor
tens = torch.Tensor([5.344, 8.343, -2.398, -0.995, 5, 30.421])
print("Original tensor: ", tens)
  
# find top 2 elements
values, indexes = torch.topk(tens, 2)
  
# print top 2 elements
print("Top 2 element values:", values)
  
  
# print index of top 2 elements
print("Top 2 element indices:", indexes)


输出:

例2:下面的程序是找张量的前k个元素

Python3

# import torch library
import torch
  
# define tensor
tens = torch.Tensor([5.344, 8.343, -2.398, -0.995, 5, 30.421])
print("Original tensor: ", tens)
  
# find top 2 elements
values, indexes = torch.topk(tens, 2)
  
# print top 2 elements
print("Top 2 element values:", values)
  
  
# print index of top 2 elements
print("Top 2 element indices:", indexes)

输出: