📜  TensorFlow.js 指标完整参考(1)

📅  最后修改于: 2023-12-03 14:47:56.170000             🧑  作者: Mango

TensorFlow.js 指标完整参考

简介

TensorFlow.js 是一个基于 JavaScript 的机器学习库,可以在浏览器或 Node.js 环境中运行。TensorFlow.js 提供了一系列的指标(metrics)用于评估机器学习模型的性能。

本文将为程序员提供 TensorFlow.js 中支持的完整指标参考,以便于了解和选择适合自己项目的指标,并使用它们来度量模型的性能。

1. 精确度(Accuracy)

tf.metrics.accuracy(yTrue, yPred)

计算模型预测结果与真实标签之间的精确度。

  • yTrue: 真实标签的张量。
  • yPred: 预测结果的张量。
const trueLabels = tf.tensor1d([0, 1, 1, 0]);
const predictedLabels = tf.tensor1d([0, 1, 0, 0]);

const accuracy = tf.metrics.accuracy(trueLabels, predictedLabels);
console.log(`Accuracy: ${accuracy}`);
2. 真正例(True Positives)

tf.metrics.truePositives(yTrue, yPred, thresholds)

计算模型预测为正例且真实也为正例的样本数量。

  • yTrue: 真实标签的张量。
  • yPred: 预测结果的张量。
  • thresholds: 阈值或阈值数组,用于二元分类的概率预测。
const trueLabels = tf.tensor1d([0, 1, 1, 0]);
const predictedLabels = tf.tensor1d([0, 1, 0, 0]);

const truePositives = tf.metrics.truePositives(trueLabels, predictedLabels);
console.log(`True Positives: ${truePositives}`);
3. 真负例(True Negatives)

tf.metrics.trueNegatives(yTrue, yPred, thresholds)

计算模型预测为负例且真实也为负例的样本数量。

  • yTrue: 真实标签的张量。
  • yPred: 预测结果的张量。
  • thresholds: 阈值或阈值数组,用于二元分类的概率预测。
const trueLabels = tf.tensor1d([0, 1, 1, 0]);
const predictedLabels = tf.tensor1d([0, 1, 0, 0]);

const trueNegatives = tf.metrics.trueNegatives(trueLabels, predictedLabels);
console.log(`True Negatives: ${trueNegatives}`);
4. 假正例(False Positives)

tf.metrics.falsePositives(yTrue, yPred, thresholds)

计算模型预测为正例但真实为负例的样本数量。

  • yTrue: 真实标签的张量。
  • yPred: 预测结果的张量。
  • thresholds: 阈值或阈值数组,用于二元分类的概率预测。
const trueLabels = tf.tensor1d([0, 1, 1, 0]);
const predictedLabels = tf.tensor1d([0, 1, 0, 0]);

const falsePositives = tf.metrics.falsePositives(trueLabels, predictedLabels);
console.log(`False Positives: ${falsePositives}`);
5. 假负例(False Negatives)

tf.metrics.falseNegatives(yTrue, yPred, thresholds)

计算模型预测为负例但真实为正例的样本数量。

  • yTrue: 真实标签的张量。
  • yPred: 预测结果的张量。
  • thresholds: 阈值或阈值数组,用于二元分类的概率预测。
const trueLabels = tf.tensor1d([0, 1, 1, 0]);
const predictedLabels = tf.tensor1d([0, 1, 0, 0]);

const falseNegatives = tf.metrics.falseNegatives(trueLabels, predictedLabels);
console.log(`False Negatives: ${falseNegatives}`);
6. 精确率(Precision)

tf.metrics.precision(yTrue, yPred, thresholds)

计算模型预测为正例的正确率。

  • yTrue: 真实标签的张量。
  • yPred: 预测结果的张量。
  • thresholds: 阈值或阈值数组,用于二元分类的概率预测。
const trueLabels = tf.tensor1d([0, 1, 1, 0]);
const predictedLabels = tf.tensor1d([0, 1, 0, 0]);

const precision = tf.metrics.precision(trueLabels, predictedLabels);
console.log(`Precision: ${precision}`);
7. 召回率(Recall)

tf.metrics.recall(yTrue, yPred, thresholds)

计算模型预测为正例的覆盖率。

  • yTrue: 真实标签的张量。
  • yPred: 预测结果的张量。
  • thresholds: 阈值或阈值数组,用于二元分类的概率预测。
const trueLabels = tf.tensor1d([0, 1, 1, 0]);
const predictedLabels = tf.tensor1d([0, 1, 0, 0]);

const recall = tf.metrics.recall(trueLabels, predictedLabels);
console.log(`Recall: ${recall}`);
8. F1 值

tf.metrics.f1Score(yTrue, yPred, thresholds)

计算模型的 F1 值,综合了精确率和召回率。

  • yTrue: 真实标签的张量。
  • yPred: 预测结果的张量。
  • thresholds: 阈值或阈值数组,用于二元分类的概率预测。
const trueLabels = tf.tensor1d([0, 1, 1, 0]);
const predictedLabels = tf.tensor1d([0, 1, 0, 0]);

const f1Score = tf.metrics.f1Score(trueLabels, predictedLabels);
console.log(`F1 Score: ${f1Score}`);
9. AUC 值

tf.metrics.auc(yTrue, yPred)

计算模型的 AUC(Area Under the Curve)值,用于评估二元分类器的性能。

  • yTrue: 真实标签的张量。
  • yPred: 预测结果的张量。
const trueLabels = tf.tensor1d([0, 1, 1, 0]);
const predictedLabels = tf.tensor1d([0.2, 0.8, 0.6, 0.3]);

const auc = tf.metrics.auc(trueLabels, predictedLabels);
console.log(`AUC: ${auc}`);
结论

本文介绍了 TensorFlow.js 中的常用指标,用于评估机器学习模型的性能。根据项目需求,可以选择适合的指标来度量模型的准确性、召回率、精确率和其他性能指标。

请注意,在使用指标之前,需要导入 tf.metrics 模块并遵循相应的使用方法。

希望本文能够帮助程序员更好地了解并使用 TensorFlow.js 的指标功能。详情请参考 TensorFlow.js 官方文档