📜  Tensorflow.js tf.metrics.recall()函数(1)

📅  最后修改于: 2023-12-03 14:47:55.605000             🧑  作者: Mango

Tensorflow.js - tf.metrics.recall()

Tensorflow.js中的tf.metrics.recall()函数用于计算召回率(Recall)指标。Recall是一个用于度量分类模型的重要性能指标,它指在所有正样本中,被分类器正确分类的比例。

语法
tf.metrics.recall(labels, predictions, classWeight?)
参数
  • labels:实际的标签数据,一维的张量类型。
  • predictions:模型的预测结果,一维的张量类型。
  • classWeight:类别权重,可选,用于解决不平衡类别问题,它是一个用于对每个类别赋权重的对象,其中每个键/值对应一个类别/权重。
返回值

函数返回一个召回率的标量张量。

示例
const tf = require('@tensorflow/tfjs-node');

// 实际标签数据
const labels = tf.tensor1d([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1]);
// 模型预测结果
const predictions = tf.tensor1d([1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1]);

// 计算召回率
const recall = tf.metrics.recall(labels, predictions);
recall.print(); // 输出:0.6666666666666666
注意事项
  • labels和predictions的长度必须相同。
  • labels和predictions的取值只能为0或1。