📜  Tensorflow.js tf.metrics.precision()函数(1)

📅  最后修改于: 2023-12-03 15:05:33.118000             🧑  作者: Mango

TensorFlow.js的tf.metrics.precision()函数

tf.metrics.precision函数是TensorFlow.js的一个函数,用于计算分类问题的精确度(Precision)。精确度是分类问题中分类为正类样本中实际为正类的样本数占比。

语法
tf.metrics.precision(yTrue, yPred, threshold?)
  • yTrue:一个张量,包含分类问题的样本真实标签。
  • yPred:一个张量,包含分类问题的模型预测标签。
  • threshold:一个可选参数,用于二分类问题中控制分类边界的阈值。默认值为0.5。
返回值

一个张量,包含了所有样本的精确度值。

示例
const tf = require('@tensorflow/tfjs');
require('@tensorflow/tfjs-node');

// 创建一个包含5个样本为二分类问题的张量。
const yTrue = tf.tensor2d([[0], [0], [1], [1], [0]]);
// 创建一个预测标签张量,其中有2个错误预测。
const yPred = tf.tensor2d([[0.1], [0.9], [0.8], [0.6], [0.3]]);

// 计算精确度。
const precision = tf.metrics.precision(yTrue, yPred);
precision.print();

输出:

Tensor
    0.5

此示例计算了一个包含5个样本的二分类问题的精度。下图展示了正确分类和错误分类的样本。

| 真实标签 | 预测标签 | 注意 | |--------|----------|-----| | 0 | 0.1 | | | 0 | 0.9 | ☑️ | | 1 | 0.8 | ☑️ | | 1 | 0.6 | | | 0 | 0.3 | |

其中,预测标签0.9和0.8的样本正确分类,精度为2/4=0.5。

总结

tf.metrics.precision函数是TensorFlow.js的一个分类问题度量函数,用于计算模型在样本分类中的精度。通过设置阈值,可以处理二分类和多分类问题。在开发分类问题的深度学习模型时,该函数是一个非常有用的工具。