Tensorflow.js tf.topk()函数
Tensorflow.js 是谷歌开发的一个开源库,用于在浏览器或节点环境中运行机器学习模型和深度学习神经网络。
tf.topk()函数以及最后一个维度也用于查找 k 个最大条目的值和索引。
句法:
tf.topk (x, k?, sorted?)
参数:
- x:一维或更高的 tf.Tensor,最后一维至少为 k。
- k:它是要查找的元素的数量。
- 排序:它是布尔值。如果为真,则生成的 k 个元素将按值降序排序。
返回值: {值:tf.Tensor,索引:tf.Tensor}。它返回一个包含两个张量的对象,其中包含值和索引。
示例 1:
Javascript
const tf = require("@tensorflow/tfjs")
// Creating a 2d tensor
const a = tf.tensor2d([[1, 20, 3], [4, 3, 1], [8, 9, 10]]);
const {values, indices} = tf.topk(a);
// Printing the values and indices
values.print();
indices.print();
Javascript
const tf = require("@tensorflow/tfjs")
// Creating a 2d tensor
const a = tf.tensor2d([[1, 20, 3], [4, 3, 1], [8, 9, 10]]);
const {values, indices} = tf.topk(a, 3);
// Printing the values and indices
values.print();
indices.print();
输出:
Tensor
[[20],
[4 ],
[10]]
Tensor
[[1],
[0],
[2]]
示例 2:在此示例中,我们将提供参数 k,以获取最大的 k 个条目。
Javascript
const tf = require("@tensorflow/tfjs")
// Creating a 2d tensor
const a = tf.tensor2d([[1, 20, 3], [4, 3, 1], [8, 9, 10]]);
const {values, indices} = tf.topk(a, 3);
// Printing the values and indices
values.print();
indices.print();
输出:
当我们通过 k = 3 时,我们在结果中得到 3 个最大值。
Tensor
[[20, 3, 1],
[4 , 3, 1],
[10, 9, 8]]
Tensor
[[1, 2, 0],
[0, 1, 2],
[2, 1, 0]]
参考: https://js.tensorflow.org/api/latest/#topk