📅  最后修改于: 2023-12-03 14:47:54.682000             🧑  作者: Mango
在分类问题中,混淆矩阵是一种非常有用的工具,它可以帮助我们了解到我们的模型在做决策时的正确率和错误率。Tensorflow.js提供了tf.confusionMatrix()函数用于计算混淆矩阵。
tf.confusionMatrix()函数的定义如下:
tf.confusionMatrix(
labels,
predictions,
numClasses,
weights
)
tf.confusionMatrix()函数返回一个张量,由混淆矩阵构成。例如,如果有3个类别,该函数将返回一个形状为[3, 3]的矩阵,其中每行代表真实标签,每列代表预测标签。对角线上的元素代表正确的预测,非对角线上的元素代表错误的预测。
以下是一个使用tf.confusionMatrix()函数计算混淆矩阵的简单代码示例:
const tf = require('@tensorflow/tfjs-node-gpu');
const labels = tf.tensor1d([1, 1, 2, 2, 2, 0]);
const predictions = tf.tensor1d([1, 0, 2, 2, 1, 0]);
const numClasses = 3;
const weights = tf.tensor2d([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
const confusionMatrix = tf.math.confusionMatrix(labels, predictions, numClasses, weights);
confusionMatrix.print();
输出如下:
Tensor
[[1, 2, 0],
[0, 1, 9],
[7, 8, 0]]
该混淆矩阵表示: